# Setup

In [1]:
# Enable IPython's autoreload extension so that edits to imported project modules
# are picked up without restarting the kernel (mode 2 = reload all modules).
%load_ext autoreload
%autoreload 2

In [2]:
import json
import random
from pathlib import Path

from tqdm import tqdm

from chat_checker.breakdown_detection.breakdown_detector import BreakdownIdentifier, OurBreakdownIdentifier
from models.benchmark_dialogues import DBDCErrorClassificationDialogue
from models.configs import BreakdownDetectionConfig
from dbdc_eval.breakdown_detection_analyzer import analyze_error_category_classification_dataset
from dbdc_eval.breakdown_classification_evaluator import compute_dbdc_detection_scores, compute_dbdc_error_classification_scores
from breakdown_dataset_loader import load_error_classification_dataset, load_tested_error_classification_dialogues

In [3]:
# Fix the seed of Python's `random` module so the sampling and shuffling done
# further down are reproducible across runs.
random.seed(42)

# Load Dataset
Data source: downloaded from https://chateval.org/dbdc5 and extracted under ../datasets/dialogue_breakdowns/

In [4]:
# Root folder of the evaluation subset; the per-dialogue annotation logs and the
# per-config debug folders are written below its "annotated_dialogues" sub-folder.
eval_base_dir = Path("./data/dbdc5_error_classification_ja_dev_subset/")
tested_subset_dir = eval_base_dir.joinpath("annotated_dialogues")

In [None]:
# Load the full error-classification dataset; the cell displays its size.
dataset = load_error_classification_dataset()
len(dataset)

In [None]:
# Pretty-print the first dataset entry as JSON (non-ASCII characters are kept
# unescaped so the dialogue text stays readable).
print(json.dumps(dataset[0].model_dump(), indent=4, ensure_ascii=False))

In [None]:
# Convert the first entry with to_chat_checker_dialogue(); as the last expression
# of the cell, the converted object is displayed for inspection.
dataset[0].to_chat_checker_dialogue()

In [None]:
# Run the dataset analysis helper over the complete dataset.
analyze_error_category_classification_dataset(dataset)

In [9]:
# Sample-selection switches for this run.
# load_existing_samples: when True, the previously tested dialogues are loaded and reused.
# recompute_existing_annotations: when True, turns that already carry an LLM annotation
#   for a config are annotated again instead of being skipped.
load_existing_samples = True
recompute_existing_annotations = False
# n_new_samples: how many not-yet-tested dialogues to draw from the dataset.
# max_samples: upper bound on the size of the combined subset used for evaluation.
n_new_samples = 0
max_samples = 200

In [None]:
# Reuse the previously tested dialogues when requested; otherwise start from an
# empty list. The cell displays how many were loaded.
tested_samples: list[DBDCErrorClassificationDialogue] = (
    load_tested_error_classification_dialogues() if load_existing_samples else []
)
len(tested_samples)

In [None]:
# Preview how many new dialogues can be drawn: the requested count, capped by the
# difference between the dataset size and the number of already tested samples.
# NOTE(review): `new_samples` is rebound to the list of sampled dialogues in the
# next cell, so this integer is informational only.
new_samples = min(n_new_samples, len(dataset) - len(tested_samples))
new_samples

In [12]:
# Draw new dialogues from the dataset, excluding the ones that were already tested.
tested_ids = {dialogue.dialogue_id for dialogue in tested_samples}
remaining_samples = [dialogue for dialogue in dataset if dialogue.dialogue_id not in tested_ids]
# random.sample raises ValueError when asked for more items than the population
# holds, so cap the request at the number of dialogues that are actually left.
new_samples = random.sample(remaining_samples, min(n_new_samples, len(remaining_samples)))

In [None]:
# Randomise the order of the already tested dialogues (in place) so that the cap
# below does not always keep the same ones.
random.shuffle(tested_samples)

# Newly drawn dialogues come first, followed by the shuffled tested ones; the
# combined list is truncated to at most `max_samples` entries.
subset_for_testing = (new_samples + tested_samples)[:max_samples]
len(subset_for_testing)

In [None]:
# List the ids of all dialogues selected for this evaluation run.
print([dialogue.dialogue_id for dialogue in subset_for_testing])

In [None]:
# Show the first selected dialogue as pretty-printed JSON (non-ASCII kept unescaped).
print("First dialogue from subset for testing:")
print(json.dumps(subset_for_testing[0].model_dump(), indent=2, ensure_ascii=False))

# Build Evaluation Variants

In [14]:
# Short display name -> model identifier. The identifier is stored on the config
# and handed to the breakdown identifier as `llm_name`. Comment entries in or out
# to choose which models are evaluated.
models = {
    # 'gpt-3.5': 'gpt-3.5-turbo-0125',
    "gpt-4o": "gpt-4o-2024-08-06",
    # "o3-mini": "o3-mini-2025-01-31",
    # "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
    # "gemini-2.5-pro": "gemini/gemini-2.5-pro-preview-03-25",
    # "gemini-2.0-flash": "gemini/gemini-2.0-flash-001",
}

# Short display name -> breakdown identifier implementation to evaluate.
breakdown_identifiers: dict[str, BreakdownIdentifier] = {
    "ours": OurBreakdownIdentifier(),
}

# Variant label -> whether task-oriented errors are included; the flag ends up in
# the config's `include_task_oriented_errors` field.
te_inclusion = {
    "no-tes": False,
    # "with-tes": True,
}

In [None]:
# One config per (model, breakdown identifier, task-oriented-error variant)
# combination; the key joins the three short names with underscores.
eval_configs: list[BreakdownDetectionConfig] = [
    BreakdownDetectionConfig(
        key=f"{model_name}_{identifier_name}_{te_variant}",
        model=model_version,
        breakdown_identifier=identifier,
        include_task_oriented_errors=include_te,
    )
    for model_name, model_version in models.items()
    for identifier_name, identifier in breakdown_identifiers.items()
    for te_variant, include_te in te_inclusion.items()
]

print(f"Total number of eval configs: {len(eval_configs)}")
print(f"Config keys:\n{[config.key for config in eval_configs]}")

# Generate the breakdown annotations with each config

In [None]:
from litellm import completion_cost

from chat_checker.models.dialogue import SpeakerRole
from chat_checker.utils.misc_utils import write_prompt_to_txt_file


# For every eval config, annotate each system turn of each selected dialogue and
# write the dialogue (with all annotations collected so far) back to disk.
for config in eval_configs:
    config_dir = tested_subset_dir / config.key
    config_dir.mkdir(parents=True, exist_ok=True)
    print(config_dir)
    # One prompt/response/cost sample is dumped per config for manual inspection.
    first_debug_stored = False
    # Wrap the list itself rather than an enumerate() generator: tqdm then knows
    # the total and can render a real progress bar (the index was unused anyway).
    for dialogue in tqdm(subset_for_testing):
        chat_checker_dialogue = dialogue.to_chat_checker_dialogue()
        for k, turn in enumerate(chat_checker_dialogue.chat_history):
            # Only system utterances are checked for breakdowns.
            if turn.role != SpeakerRole.DIALOGUE_SYSTEM:
                continue
            # NOTE(review): relies on dialogue.turns being index-aligned with
            # chat_checker_dialogue.chat_history — confirm.
            dbdc_turn = dialogue.turns[k]
            existing_annotations = dbdc_turn.llm_breakdown_annotations
            has_llm_label = bool(existing_annotations) and existing_annotations.get(config.key) is not None
            if has_llm_label and not recompute_existing_annotations:
                continue
            # Everything before the current system utterance is the context.
            conversation_history = chat_checker_dialogue.chat_history[:k]
            last_bot_utterance = turn.content
            try:
                breakdown_info, prompt, model_response = config.breakdown_identifier.identify_breakdowns(
                    chat_history=conversation_history,
                    last_bot_utterance=last_bot_utterance,
                    is_task_oriented=config.include_task_oriented_errors,
                    llm_name=config.model,
                )
            except Exception as e:
                print(f"Error processing dialogue {dialogue.dialogue_id} at turn {k} with config {config}: {e}")
                # We simply skip this turn and continue to the next one (sometimes OpenAI refuses to answer {'refusal': "I'm sorry, I can't assist with that request."})
                continue
            if not dbdc_turn.llm_breakdown_annotations:
                dbdc_turn.llm_breakdown_annotations = {}
            dbdc_turn.llm_breakdown_annotations[config.key] = breakdown_info
            # Store the first freshly annotated turn that has a non-empty history
            # (k > 0) as a debugging sample for this config.
            if k > 0 and not first_debug_stored:
                first_debug_stored = True
                write_prompt_to_txt_file(prompt, config_dir / "sample_0_prompt.txt")
                with open(
                    config_dir / "sample_0_model_response.json", "w", encoding="utf-8"
                ) as f:
                    json.dump(model_response.model_dump(), f, ensure_ascii=False, indent=2)
                cost = completion_cost(model_response)
                with open(
                    config_dir / "sample_0_response_cost.txt", "w", encoding="utf-8"
                ) as f:
                    f.write(f"Model response cost: {cost:.8f} USD\n")

        # Persist after every dialogue so an interrupted run keeps its progress.
        with open(tested_subset_dir / f"{dialogue.dialogue_id}.log.json", "w", encoding="utf-8") as f:
            json.dump(dialogue.model_dump(by_alias=True), f, indent=2, ensure_ascii=False)

# Evaluate the breakdown annotations against the ground truth

## Breakdown Detection Scores

In [None]:
# Compute and print the breakdown-detection scores of every config on the subset.
# NOTE(review): the meaning of threshold=0.0 is defined by
# compute_dbdc_detection_scores — confirm it is the intended setting.
for config in eval_configs:
    scores = compute_dbdc_detection_scores(
        dialogues=subset_for_testing,
        config_key=config.key,
        threshold=0.0,
    )
    scores.print_results()


## Error category classification scores on agreed breakdowns

In [None]:
# Error-category classification scores in "agreed_breakdowns" mode. The scores are
# printed directly; category counts and mismatch metrics are kept per config key
# for the cells that follow.
config_category_counts = {}
config_mismatch_metrics = {}
for config in eval_configs:
    scores, category_counts, mismatch_metrics = compute_dbdc_error_classification_scores(
        dialogues=subset_for_testing,
        config_key=config.key,
        mode="agreed_breakdowns",
    )
    print(scores)
    config_category_counts[config.key] = category_counts
    config_mismatch_metrics[config.key] = mismatch_metrics


In [None]:
# Print the category counts of every config and plot them, passing the config's
# own folder as save_dir.
for config in eval_configs:
    print(config_category_counts[config.key])
    config_category_counts[config.key].plot_counts(save_dir=tested_subset_dir / config.key)

In [None]:
# Print the mismatch metrics collected for every config.
for config in eval_configs:
    print(config_mismatch_metrics[config.key])


## Error category classification scores on ground truth breakdowns

In [None]:
# Error-category classification scores in "true_breakdowns" mode. Both result
# dictionaries are re-created here, replacing any values stored under the same
# names by an earlier cell.
config_category_counts = {}
config_mismatch_metrics = {}
for config in eval_configs:
    scores, category_counts, mismatch_metrics = compute_dbdc_error_classification_scores(
        dialogues=subset_for_testing,
        config_key=config.key,
        mode="true_breakdowns",
    )
    print(scores)
    config_category_counts[config.key] = category_counts
    config_mismatch_metrics[config.key] = mismatch_metrics


In [None]:
# Print and plot the category counts of every config.
# NOTE(review): plot_counts is called without a save_dir here, unlike the
# corresponding agreed-breakdowns cell — confirm that this is intentional.
for config in eval_configs:
    print(config_category_counts[config.key])
    config_category_counts[config.key].plot_counts()

In [None]:
# Print the mismatch metrics collected for every config.
for config in eval_configs:
    print(config_mismatch_metrics[config.key])
