# Imports

In [None]:
import packages
from configs import settings, const, components

from configs.settings import logger
import asyncio, os, time, yaml, json, datetime, copy
from typing import Any, AsyncGenerator, Generator, Callable, Literal, Optional, TypeAlias, Union
from tqdm import tqdm
from pprint import pprint
from rich import print as rprint

from toolkit.llm.llama_index import (
	agents, cores, deploys as dpls, evaluation, messages, models, 
	observability, types, utils as utils_llama_index, workflows as wfs
)
from toolkit.llm.llama_index.data import loading, querying, storing

from toolkit.utils import utils, typer as t
from toolkit.utils.llm import measure_performance, main as utils_llm


# User Query Categorization

In [None]:
import json
from typing import Dict, List, Tuple, Set, TypeAlias, Literal, Optional, Callable
from collections import defaultdict
import yaml
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
from rich import print as rprint

def predict_categories(
		input_file: str,
		output_file: str,
		predict_fn: Callable[[str], str],
		valid_categories: Set[str],
		query_field: str = "user_query",
		true_category_field: str = "category",
		pred_category_field: str = "pred_category",
		default_category: Optional[str] = None
) -> None:
		"""
		Predict a category for every item of a JSON file and save the results.

		The input file must hold a non-empty JSON list of objects. Each object is
		passed through ``predict_fn`` and written to ``output_file`` with the
		prediction stored under ``pred_category_field``.

		Args:
				input_file: Path to input JSON file with items to classify
				output_file: Path to output JSON file for results
				predict_fn: Function that takes a query string and returns a predicted category
				valid_categories: Set of valid category names
				query_field: Name of the field containing the query in input JSON
				true_category_field: Name of the field containing the true category
				pred_category_field: Name of the field to store predicted category
				default_category: Category stored when prediction fails or returns a
						value outside ``valid_categories`` (if None, such a failure raises)

		Raises:
				FileNotFoundError: If input file doesn't exist
				ValueError: If a path or ``valid_categories`` is empty, the input file
						is not valid UTF-8 JSON, is not a non-empty list, its first item
						lacks the required fields, or any item is not a JSON object
				RuntimeError: If predicting an item fails (or yields an invalid
						category) and no ``default_category`` was given; the original
						error is chained as ``__cause__``
				IOError: If writing to output file fails
		"""
		# 1. Input Validation
		if not input_file or not output_file:
				raise ValueError("Input and output file paths must be provided")

		if not valid_categories:
				raise ValueError("Valid categories set must not be empty")

		# 2. Load and Validate Input Data
		try:
				# JSON is UTF-8 by specification; do not depend on the platform locale.
				with open(input_file, 'r', encoding='utf-8') as f:
						items = json.load(f)
		except FileNotFoundError as e:
				raise FileNotFoundError(f"Input file not found: {input_file}") from e
		except (json.JSONDecodeError, UnicodeDecodeError) as e:
				raise ValueError(f"Invalid JSON in input file: {input_file}") from e

		if not items or not isinstance(items, list):
				raise ValueError("Input file must contain a non-empty list of items")

		# Validate required fields in first item. The isinstance check matters:
		# `field in items[0]` would also succeed for a list or string that merely
		# contains the field names.
		required_fields = {query_field, true_category_field}
		first_item = items[0]
		if not isinstance(first_item, dict) or not all(field in first_item for field in required_fields):
				raise ValueError(f"Missing required fields in input data: {required_fields}")

		# 3. Process Items
		processed_items = []
		for index, item in enumerate(tqdm(items, desc="Processing items")):
				if not isinstance(item, dict):
						# A prediction cannot be attached to a non-object item.
						raise ValueError(f"Item at index {index} is not a JSON object: {item!r}")

				# Reset per item so the error message below never reports an unbound
				# name or the query of a previous item.
				query = None
				try:
						query = item[query_field]

						# Get prediction
						pred_category = predict_fn(query)

						# Validate prediction
						if pred_category not in valid_categories:
								if default_category is None:
										raise ValueError(f"Invalid prediction '{pred_category}' and no default category provided")
								print(f"Warning: Invalid prediction '{pred_category}' for query: {query}")
								pred_category = default_category
				except Exception as e:
						# Deliberately broad: predict_fn is caller-supplied and may fail in
						# any way. Without a fallback category the failure is fatal.
						if default_category is None:
								raise RuntimeError(f"Error processing query '{query}': {str(e)}") from e
						print(f"Error processing query '{query}': {str(e)}")
						pred_category = default_category

				# Add to processed items
				item[pred_category_field] = pred_category
				processed_items.append(item)

		# 4. Write Results
		try:
				with open(output_file, 'w', encoding='utf-8') as f:
						json.dump(processed_items, f, indent=2)
		except Exception as e:
				raise IOError(f"Error writing to output file {output_file}: {str(e)}") from e

def _print_evaluation_report(
		total: int,
		correct: int,
		accuracy: float,
		category_metrics: Dict,
		confusion_matrix: Dict,
		all_categories: Set
) -> None:
		"""
		Print the evaluation summary as three rich tables.

		Args:
				total: Number of evaluated samples
				correct: Number of samples whose prediction equals the true category
				accuracy: Overall accuracy (correct / total)
				category_metrics: Per-category metrics keyed by true category
				confusion_matrix: Nested mapping ``[true_category][pred_category] -> count``
				all_categories: Every category seen as a true or predicted label
		"""
		console = Console()
		# key=str keeps the ordering well-defined even if a label in the data is
		# null or otherwise not a string; for plain strings the order is unchanged.
		ordered_categories = sorted(all_categories, key=str)

		# Overall metrics table
		console.print("\n[bold cyan]Evaluation Results:[/bold cyan]")
		overall_table = Table(show_header=True, header_style="bold magenta")
		overall_table.add_column("Metric", style="bright_blue")
		overall_table.add_column("Value", justify="right", style="bright_yellow")

		overall_table.add_row("Total samples", str(total))
		overall_table.add_row("Correct predictions", str(correct))
		overall_table.add_row("Overall accuracy", f"{accuracy:.2%}")

		console.print(overall_table)

		# Per-category performance table
		console.print("\n[bold cyan]Per-category Performance:[/bold cyan]")
		perf_table = Table(show_header=True, header_style="bold magenta")
		perf_table.add_column("Category", style="cyan")
		perf_table.add_column("Accuracy", justify="right")
		perf_table.add_column("Precision", justify="right")
		perf_table.add_column("Recall", justify="right")
		perf_table.add_column("F1 Score", justify="right")
		perf_table.add_column("Correct/Total", justify="right")

		for category in sorted(category_metrics, key=str):
				metrics = category_metrics[category]
				perf_table.add_row(
						str(category),
						f"{metrics['accuracy']:.2%}",
						f"{metrics['precision']:.2%}",
						f"{metrics['recall']:.2%}",
						f"{metrics['f1']:.2%}",
						f"{metrics['correct']}/{metrics['total']}"
				)

		console.print(perf_table)

		# Confusion matrix
		console.print("\n[bold cyan]Confusion Matrix:[/bold cyan]")
		matrix_table = Table(show_header=True, header_style="bold magenta")
		matrix_table.add_column("True ↓ Pred →", style="cyan")
		for category in ordered_categories:
				matrix_table.add_column(str(category), justify="right")

		for true_cat in ordered_categories:
				row = [str(true_cat)]
				row.extend(str(confusion_matrix[true_cat][pred_cat]) for pred_cat in ordered_categories)
				matrix_table.add_row(*row)

		console.print(matrix_table)


def evaluate_predictions(
		prediction_file: str,
		true_category_field: str = "category",
		pred_category_field: str = "pred_category",
		query_field: str = "user_query",
		valid_categories: Optional[Set[str]] = None
) -> Tuple[float, Dict]:
		"""
		Evaluate classification predictions stored in a JSON file and print a report.

		Per-category metrics are produced for every category that occurs as a
		true label; the confusion matrix covers every category seen as a true or
		predicted label.

		Args:
				prediction_file: Path to JSON file containing true and predicted categories
				true_category_field: Name of the field containing true category
				pred_category_field: Name of the field containing predicted category
				query_field: Name of the field containing the query text
				valid_categories: Optional set of valid categories; labels outside it
						only trigger a printed warning

		Returns:
				Tuple containing:
				- overall accuracy (float)
				- detailed metrics dictionary with confusion matrix and per-category metrics

		Raises:
				FileNotFoundError: If prediction file doesn't exist
				ValueError: If the file is not valid UTF-8 JSON, is empty, is not a list
						of objects, or any record is missing a required field
		"""
		# 1. Load and Validate Input Data
		try:
				# JSON is UTF-8 by specification; do not depend on the platform locale.
				with open(prediction_file, 'r', encoding='utf-8') as f:
						predictions = json.load(f)
		except FileNotFoundError as e:
				raise FileNotFoundError(f"Prediction file not found: {prediction_file}") from e
		except (json.JSONDecodeError, UnicodeDecodeError) as e:
				raise ValueError(f"Invalid JSON in prediction file: {prediction_file}") from e

		if not predictions:
				raise ValueError("Empty prediction file")

		if not isinstance(predictions, list):
				raise ValueError("Prediction file must contain a list of items")

		# Validate required fields on every record, so a malformed record is
		# reported up front instead of surfacing as a KeyError mid-computation.
		required_fields = {query_field, true_category_field, pred_category_field}
		for index, item in enumerate(predictions):
				if not isinstance(item, dict) or not all(field in item for field in required_fields):
						raise ValueError(f"Missing required fields: {required_fields} (item {index})")

		# 2. Initialize Data Structures
		total = len(predictions)
		correct = 0
		confusion_matrix = defaultdict(lambda: defaultdict(int))
		category_metrics = defaultdict(lambda: {
				'correct': 0,
				'total': 0,
				'incorrect_examples': [],
				'precision': 0.0,
				'recall': 0.0,
				'f1': 0.0
		})

		# Get all unique categories (true and predicted)
		all_categories: Set[str] = set()
		for item in predictions:
				all_categories.add(item[true_category_field])
				all_categories.add(item[pred_category_field])

		# Validate categories if provided
		if valid_categories:
				# set(...) lets callers pass any iterable of names, not only a set
				invalid_categories = all_categories - set(valid_categories)
				if invalid_categories:
						print(f"Warning: Found invalid categories: {invalid_categories}")

		# 3. Calculate Basic Metrics
		for item in predictions:
				true_category = item[true_category_field]
				pred_category = item[pred_category_field]
				query = item[query_field]

				category_metrics[true_category]['total'] += 1
				confusion_matrix[true_category][pred_category] += 1

				if true_category == pred_category:
						correct += 1
						category_metrics[true_category]['correct'] += 1
				else:
						category_metrics[true_category]['incorrect_examples'].append({
								'query': query,
								'true': true_category,
								'predicted': pred_category
						})

		# 4. Calculate Advanced Metrics
		accuracy = correct / total if total > 0 else 0

		for category, metrics in category_metrics.items():
				total_cat = metrics['total']
				metrics['accuracy'] = metrics['correct'] / total_cat if total_cat > 0 else 0

				# Column of the matrix (minus the diagonal) = predicted as `category`
				# but truly something else; row (minus the diagonal) = the reverse.
				true_positives = confusion_matrix[category][category]
				false_positives = sum(
						confusion_matrix[other_cat][category]
						for other_cat in all_categories if other_cat != category
				)
				false_negatives = sum(
						confusion_matrix[category][other_cat]
						for other_cat in all_categories if other_cat != category
				)

				predicted_as_category = true_positives + false_positives
				actually_category = true_positives + false_negatives
				metrics['precision'] = true_positives / predicted_as_category if predicted_as_category > 0 else 0
				metrics['recall'] = true_positives / actually_category if actually_category > 0 else 0

				precision_plus_recall = metrics['precision'] + metrics['recall']
				if precision_plus_recall > 0:
						metrics['f1'] = 2 * metrics['precision'] * metrics['recall'] / precision_plus_recall
				else:
						metrics['f1'] = 0.0

		# 5. Prepare Output Metrics
		detailed_metrics = {
				'total_samples': total,
				'correct_predictions': correct,
				'overall_accuracy': accuracy,
				'per_category_metrics': dict(category_metrics),
				'confusion_matrix': dict(confusion_matrix)
		}

		# 6. Print Results
		_print_evaluation_report(
				total=total,
				correct=correct,
				accuracy=accuracy,
				category_metrics=category_metrics,
				confusion_matrix=confusion_matrix,
				all_categories=all_categories
		)

		return accuracy, detailed_metrics

# Example usage for the car query classification
if __name__ == "__main__":
		import asyncio
		import threading

		# predict_categories() expects a synchronous predict_fn, while the model
		# helpers below are awaited. A private event loop running on a daemon
		# thread bridges the two; unlike asyncio.run() this also works when the
		# calling thread already has a running loop (e.g. inside a notebook), and
		# one loop is reused for every query.
		background_loop = asyncio.new_event_loop()
		loop_thread = threading.Thread(target=background_loop.run_forever, daemon=True)
		loop_thread.start()

		try:
				prompts_agent_car = settings.prompts_agent_car

				# Asynchronous worker: `await` is only legal inside `async def`.
				# NOTE(review): the awaited helpers are assumed to be coroutines, as the
				# original code awaited them — confirm against toolkit.
				async def _categorize_car_query(query: str) -> str:
						examples = components.retriever_user_query_category.retrieve(query)
						examples = str(await utils_llama_index.extract_retriever_results(examples))

						pred_category = await utils_llama_index.interact_model(
								prompt=prompts_agent_car["CategorizeQuestion"]["dev"],
								mode="achat",
								user_query=query,
								examples=examples,
						)

						return await utils_llm.post_process_llm_output(
								pred_category,
								mode=["remove_quotes", "remove_brackets"],
						)

				# Define prediction function for car queries
				def predict_car_query_category(query: str) -> str:
						# Submit to the background loop and block until the result is ready;
						# exceptions raised by the coroutine propagate from .result().
						future = asyncio.run_coroutine_threadsafe(
								_categorize_car_query(query), background_loop
						)
						return future.result()

				# Set up parameters
				VALID_CATEGORIES = {"car_control", "car_manual"}
				input_file = f"{packages.APP_PATH}/data/QnAs/user_query_category.json"
				output_file = f"{packages.APP_PATH}/data/QnAs/pred-user_query_category.json"

				# Predict categories
				predict_categories(
						input_file=input_file,
						output_file=output_file,
						predict_fn=predict_car_query_category,
						valid_categories=VALID_CATEGORIES,
						query_field="user_query",
						true_category_field="user_query_category",
						pred_category_field="pred_user_query_category",
						default_category="car_manual"
				)

				# Evaluate predictions
				accuracy, metrics = evaluate_predictions(
						prediction_file=output_file,
						true_category_field="user_query_category",
						pred_category_field="pred_user_query_category",
						query_field="user_query",
						valid_categories=VALID_CATEGORIES
				)

		except Exception as e:
				# Top-level boundary of the example: report instead of crashing.
				rprint(f"[bold red]Error during execution:[/bold red] {str(e)}")
		finally:
				# Stop the helper loop from its own thread, then release it.
				background_loop.call_soon_threadsafe(background_loop.stop)
				loop_thread.join()
				background_loop.close()