# Evaluation of reasons for misclassifications

**Possible categories**:
- **Lack of context** (insufficient emphasis or indication of the context of the query, such as domain-specific background information about psychological assessments)
- **Lack of examples** (absence of representative few-shot examples showing correct classifications for similar psychological cases before posing the actual question)
- **Lack of feedback** (missing iterative refinements and dialogues in which the user gives feedback and interactively refines the prompt)
- **Lack of counterfactual demonstrations** (absence of contrasting cases with false psychological diagnoses or classifications that would improve the model's robustness against wrong decisions)
- **Lack of opinion-based information** (missing subjective clinical judgement and contextual interpretations, e.g., reframing the data as a narrator’s statement and opinions rather than relying solely on quantitative symptoms)
- **Knowledge conflicts** (outdated memorized facts or contradictions in the model's training data between psychological diagnostic criteria or between different theoretical frameworks)
- **Prediction with Abstention** (model uncertainty regarding the binary classification or insufficient confidence in distinguishing between presence and absence of mental disorder development)

## 0 Imports

In [1]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

## 1 Reasons for Misclassifications

In [None]:
main_reasons_GPT_o3_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/GPT/main_reasons_GPT_o3_simple.csv", sep =",", index_col = 0)
main_reasons_GPT_o3_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/GPT/main_reasons_GPT_o3_class_definitions.csv", sep =",", index_col = 0)
main_reasons_GPT_o3_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/GPT/main_reasons_GPT_o3_profiled_simple.csv", sep =",", index_col = 0)
main_reasons_GPT_o3_few_shot_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/GPT/main_reasons_GPT_o3_few_shot.csv", sep =",", index_col = 0)
main_reasons_GPT_o3_vignette_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/GPT/main_reasons_GPT_o3_vignette.csv", sep =",", index_col = 0)
main_reasons_GPT_o3_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/GPT/main_reasons_GPT_o3_cot.csv", sep =",", index_col = 0)

In [None]:
main_reasons_Gemini_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_simple.csv", sep =",", index_col = 0)
main_reasons_Gemini_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_class_definitions.csv", sep =",", index_col = 0)
main_reasons_Gemini_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_profiled_simple.csv", sep =",", index_col = 0)
main_reasons_Gemini_few_shot_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_few_shot.csv", sep =",", index_col = 0)
main_reasons_Gemini_vignette_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_vignette.csv", sep =",", index_col = 0)
main_reasons_Gemini_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_cot.csv", sep =",", index_col = 0)

In [5]:
main_reasons_Gemma_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_simple.csv", sep =",", index_col = 0)
main_reasons_Gemma_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_class_definitions.csv", sep =",", index_col = 0)
main_reasons_Gemma_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_profiled_simple.csv", sep =",", index_col = 0)
main_reasons_Gemma_few_shot_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_few_shot.csv", sep =",", index_col = 0)
main_reasons_Gemma_vignette_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_vignette.csv", sep =",", index_col = 0)
main_reasons_Gemma_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_cot.csv", sep =",", index_col = 0)

In [6]:
main_reasons_Claude_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_simple.csv", sep =",", index_col = 0)
main_reasons_Claude_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_class_definitions.csv", sep =",", index_col = 0)
main_reasons_Claude_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_profiled_simple.csv", sep =",", index_col = 0)
main_reasons_Claude_few_shot_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_few_shot.csv", sep =",", index_col = 0)
main_reasons_Claude_vignette_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_vignette.csv", sep =",", index_col = 0)
main_reasons_Claude_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_cot.csv", sep =",", index_col = 0)

In [7]:
main_reasons_DeepSeek_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_simple.csv", sep =",", index_col = 0)
main_reasons_DeepSeek_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_class_definitions.csv", sep =",", index_col = 0)
main_reasons_DeepSeek_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_profiled_simple.csv", sep =",", index_col = 0)
main_reasons_DeepSeek_few_shot_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_few_shot.csv", sep =",", index_col = 0)
main_reasons_DeepSeek_vignette_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_vignette.csv", sep =",", index_col = 0)
main_reasons_DeepSeek_cot_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_cot.csv", sep =",", index_col = 0)

In [8]:
main_reasons_Grok_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_simple.csv", sep =",", index_col = 0)
main_reasons_Grok_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_class_definitions.csv", sep =",", index_col = 0)
main_reasons_Grok_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_profiled_simple.csv", sep =",", index_col = 0)
main_reasons_Grok_few_shot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_few_shot.csv", sep =",", index_col = 0)
main_reasons_Grok_vignette_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_vignette.csv", sep =",", index_col = 0)
main_reasons_Grok_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_cot.csv", sep =",", index_col = 0)
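
The per-model loading cells above all follow the same path pattern, so they could also be written as one loop. The following is a minimal sketch, not the notebook's own code: `load_tables`, `PROMPTS`, and the `kind` parameter are hypothetical names, and the directory layout is assumed from the paths used above.

```python
from pathlib import Path

import pandas as pd

# Prompt suffixes as they appear in the file names above.
PROMPTS = ["simple", "class_definitions", "profiled_simple", "few_shot", "vignette", "cot"]


def load_tables(base_dir, model_dir, file_prefix, kind="main_reasons"):
    """Return {prompt: DataFrame} for one model (hypothetical helper).

    Assumed layout: <base_dir>/<model_dir>/<kind>_<file_prefix>_<prompt>.csv
    """
    tables = {}
    for prompt in PROMPTS:
        path = Path(base_dir) / model_dir / f"{kind}_{file_prefix}_{prompt}.csv"
        df = pd.read_csv(path, sep=",", index_col=0)
        # Reason tables carry a "count" column; sort them like the cells below do.
        if "count" in df.columns:
            df = df.sort_values(by="count", ascending=False)
        tables[prompt] = df
    return tables
```

Under these assumptions, `load_tables("04_Reasons_Misclassifications/reasons", "GPT", "GPT_o3")` would return the six GPT reason tables keyed by prompt suffix, and passing `kind="cases"` would load the corresponding case tables.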

In [10]:
main_reasons_GPT_o3_simple_df = main_reasons_GPT_o3_simple_df.sort_values(by = "count", ascending = False)
main_reasons_GPT_o3_class_def_df = main_reasons_GPT_o3_class_def_df.sort_values(by = "count", ascending = False)
main_reasons_GPT_o3_profiled_simple_df = main_reasons_GPT_o3_profiled_simple_df.sort_values(by = "count", ascending = False)
main_reasons_GPT_o3_few_shot_df = main_reasons_GPT_o3_few_shot_df.sort_values(by = "count", ascending = False)
main_reasons_GPT_o3_vignette_df = main_reasons_GPT_o3_vignette_df.sort_values(by = "count", ascending = False)
main_reasons_GPT_o3_cot_df = main_reasons_GPT_o3_cot_df.sort_values(by = "count", ascending = False)

In [11]:
main_reasons_Gemini_simple_df = main_reasons_Gemini_simple_df.sort_values(by = "count", ascending = False)
main_reasons_Gemini_class_def_df = main_reasons_Gemini_class_def_df.sort_values(by = "count", ascending = False)
main_reasons_Gemini_profiled_simple_df = main_reasons_Gemini_profiled_simple_df.sort_values(by = "count", ascending = False)
main_reasons_Gemini_few_shot_df = main_reasons_Gemini_few_shot_df.sort_values(by = "count", ascending = False)
main_reasons_Gemini_vignette_df = main_reasons_Gemini_vignette_df.sort_values(by = "count", ascending = False)
main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df.sort_values(by = "count", ascending = False)

In [12]:
main_reasons_Gemma_simple_df = main_reasons_Gemma_simple_df.sort_values(by = "count", ascending = False)
main_reasons_Gemma_class_def_df = main_reasons_Gemma_class_def_df.sort_values(by = "count", ascending = False)
main_reasons_Gemma_profiled_simple_df = main_reasons_Gemma_profiled_simple_df.sort_values(by = "count", ascending = False)
main_reasons_Gemma_few_shot_df = main_reasons_Gemma_few_shot_df.sort_values(by = "count", ascending = False)
main_reasons_Gemma_vignette_df = main_reasons_Gemma_vignette_df.sort_values(by = "count", ascending = False)
main_reasons_Gemma_cot_df = main_reasons_Gemma_cot_df.sort_values(by = "count", ascending = False)

In [13]:
main_reasons_Claude_simple_df = main_reasons_Claude_simple_df.sort_values(by = "count", ascending = False)
main_reasons_Claude_class_def_df = main_reasons_Claude_class_def_df.sort_values(by = "count", ascending = False)
main_reasons_Claude_profiled_simple_df = main_reasons_Claude_profiled_simple_df.sort_values(by = "count", ascending = False)
main_reasons_Claude_few_shot_df = main_reasons_Claude_few_shot_df.sort_values(by = "count", ascending = False)
main_reasons_Claude_vignette_df = main_reasons_Claude_vignette_df.sort_values(by = "count", ascending = False)
main_reasons_Claude_cot_df = main_reasons_Claude_cot_df.sort_values(by = "count", ascending = False)

In [14]:
main_reasons_DeepSeek_simple_df = main_reasons_DeepSeek_simple_df.sort_values(by = "count", ascending = False)
main_reasons_DeepSeek_class_def_df = main_reasons_DeepSeek_class_def_df.sort_values(by = "count", ascending = False)
main_reasons_DeepSeek_profiled_simple_df = main_reasons_DeepSeek_profiled_simple_df.sort_values(by = "count", ascending = False)
main_reasons_DeepSeek_few_shot_df = main_reasons_DeepSeek_few_shot_df.sort_values(by = "count", ascending = False)
main_reasons_DeepSeek_vignette_df = main_reasons_DeepSeek_vignette_df.sort_values(by = "count", ascending = False)
main_reasons_DeepSeek_cot_df = main_reasons_DeepSeek_cot_df.sort_values(by = "count", ascending = False)

In [15]:
main_reasons_Grok_simple_df = main_reasons_Grok_simple_df.sort_values(by = "count", ascending = False)
main_reasons_Grok_class_def_df = main_reasons_Grok_class_def_df.sort_values(by = "count", ascending = False)
main_reasons_Grok_profiled_simple_df = main_reasons_Grok_profiled_simple_df.sort_values(by = "count", ascending = False)
main_reasons_Grok_few_shot_df = main_reasons_Grok_few_shot_df.sort_values(by = "count", ascending = False)
main_reasons_Grok_vignette_df = main_reasons_Grok_vignette_df.sort_values(by = "count", ascending = False)
main_reasons_Grok_cot_df = main_reasons_Grok_cot_df.sort_values(by = "count", ascending = False)

In [None]:
cases_GPT_o3_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/GPT/cases_GPT_o3_simple.csv", sep =",", index_col = 0)
cases_GPT_o3_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/GPT/cases_GPT_o3_class_definitions.csv", sep =",", index_col = 0)
cases_GPT_o3_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/GPT/cases_GPT_o3_profiled_simple.csv", sep =",", index_col = 0)
cases_GPT_o3_few_shot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/GPT/cases_GPT_o3_few_shot.csv", sep =",", index_col = 0)
cases_GPT_o3_vignette_df = pd.read_csv("04_Reasons_Misclassifications/reasons/GPT/cases_GPT_o3_vignette.csv", sep =",", index_col = 0)
cases_GPT_o3_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/GPT/cases_GPT_o3_cot.csv", sep =",", index_col = 0)

In [None]:
cases_Gemini_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_simple.csv", sep =",", index_col = 0)
cases_Gemini_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_class_definitions.csv", sep =",", index_col = 0)
cases_Gemini_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_profiled_simple.csv", sep =",", index_col = 0)
cases_Gemini_few_shot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_few_shot.csv", sep =",", index_col = 0)
cases_Gemini_vignette_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_vignette.csv", sep =",", index_col = 0)
cases_Gemini_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_cot.csv", sep =",", index_col = 0)

In [21]:
cases_Gemma_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_simple.csv", sep =",", index_col = 0)
cases_Gemma_class_def_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_class_definitions.csv", sep =",", index_col = 0)
cases_Gemma_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_profiled_simple.csv", sep =",", index_col = 0)
cases_Gemma_few_shot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_few_shot.csv", sep =",", index_col = 0)
cases_Gemma_vignette_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_vignette.csv", sep =",", index_col = 0)
cases_Gemma_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_cot.csv", sep =",", index_col = 0)

In [22]:
cases_Claude_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Claude/cases_Claude_simple.csv", sep =",", index_col = 0)
cases_Claude_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Claude/cases_Claude_class_definitions.csv", sep =",", index_col = 0)
cases_Claude_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/Claude/cases_Claude_profiled_simple.csv", sep =",", index_col = 0)
cases_Claude_few_shot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Claude/cases_Claude_few_shot.csv", sep =",", index_col = 0)
cases_Claude_vignette_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Claude/cases_Claude_vignette.csv", sep =",", index_col = 0)
cases_Claude_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Claude/cases_Claude_cot.csv", sep =",", index_col = 0)

In [23]:
cases_DeepSeek_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_simple.csv", sep =",", index_col = 0)
cases_DeepSeek_class_def_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_class_definitions.csv", sep =",", index_col = 0)
cases_DeepSeek_profiled_simple_df = pd.read_csv(
    "04_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_profiled_simple.csv", sep =",", index_col = 0)
cases_DeepSeek_few_shot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_few_shot.csv", sep =",", index_col = 0)
cases_DeepSeek_vignette_df = pd.read_csv("04_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_vignette.csv", sep =",", index_col = 0)
cases_DeepSeek_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_cot.csv", sep =",", index_col = 0)

In [24]:
cases_Grok_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/cases_Grok_simple.csv", sep =",", index_col = 0)
cases_Grok_class_def_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/cases_Grok_class_definitions.csv", sep =",", index_col = 0)
cases_Grok_profiled_simple_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/cases_Grok_profiled_simple.csv", sep =",", index_col = 0)
cases_Grok_few_shot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/cases_Grok_few_shot.csv", sep =",", index_col = 0)
cases_Grok_vignette_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/cases_Grok_vignette.csv", sep =",", index_col = 0)
cases_Grok_cot_df = pd.read_csv("04_Reasons_Misclassifications/reasons/Grok/cases_Grok_cot.csv", sep =",", index_col = 0)

## 2 Sankey Diagram

In [25]:
def lighten_color(color, amount=0.8):
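    """Lighten a named, hex, or RGB color by blending it with white.

    amount=1.0 returns the original color; smaller values move it toward white.
    """
    import matplotlib.colors as mc

    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a named color (e.g. a hex string or an RGB tuple): use it as given.
        c = color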
    c = mc.to_rgb(c)
    lightened = tuple(1 - amount * (1 - x) for x in c)
    return f"rgb({int(lightened[0]*255)}, {int(lightened[1]*255)}, {int(lightened[2]*255)})"
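
To make the blend formula `1 - amount * (1 - x)` concrete, here is a small standard-library sketch that applies the same idea directly on the 0..255 scale. `blend_with_white` is a hypothetical name, it only accepts `#RRGGBB` strings, and it rounds instead of truncating with `int()`, so individual channels may differ by one from `lighten_color`.

```python
def blend_with_white(hex_color, amount=0.8):
    """Blend a '#RRGGBB' color with white.

    amount=1.0 keeps the original color, amount=0.0 yields pure white.
    """
    hex_color = hex_color.lstrip("#")
    channels = [int(hex_color[i:i + 2], 16) for i in (0, 2, 4)]
    # Same blend as 1 - amount * (1 - x), written for integer channels.
    blended = [round(255 - amount * (255 - c)) for c in channels]
    return f"rgb({blended[0]}, {blended[1]}, {blended[2]})"


# The "Simple prompt" base color with the amount used for link colors.
print(blend_with_white("#03588C", 0.4))  # rgb(154, 188, 209)
```

With `amount=0.4` each channel moves 60% of the way toward 255, which is why the link colors appear as pale versions of the node colors.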

In [29]:
def cal_percentage(all_main_df, all_cases_df):
    # group by index and sum the counts
    all_main_df = all_main_df.groupby(all_main_df.index).sum()
    # calculate the total count of misclassifications
    total_misclassifications_simple = all_main_df["count"].sum()
    # calculate the percentage of misclassifications that are due to each category
    all_main_df["percentage"] = round(all_main_df["count"] / total_misclassifications_simple * 100, 1)
    # delete count column
    all_main_df = all_main_df.drop(columns=["count"])
    # sort df
    all_main_df = all_main_df.sort_values(by="percentage", ascending=False)

    # scale each reason's share by the summed "missclassified" value of all_cases_df
    # (expected to be the misclassification rate in percent)
    all_main_df["percentage"] = round(all_main_df["percentage"] * all_cases_df["missclassified"].sum() / 100, 2)

    return all_main_df
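
A toy example may help to read the two-step scaling in `cal_percentage`: each reason's share of all reasons is multiplied by the overall misclassification rate. The sketch below uses made-up numbers and a hypothetical helper `scaled_reason_shares`; unlike the function above it skips the intermediate rounding to one decimal.

```python
import pandas as pd


def scaled_reason_shares(reason_counts, misclassified_pct):
    """Share of each reason (in %) scaled by the misclassification rate (in %)."""
    shares = reason_counts / reason_counts.sum() * 100
    scaled = (shares * misclassified_pct / 100).round(2)
    return scaled.sort_values(ascending=False)


# Made-up counts: 40 reasons in total, 40% of all cases misclassified.
counts = pd.Series({"Lack of context": 30, "Lack of examples": 10})
result = scaled_reason_shares(counts, misclassified_pct=40.0)
print(result)
```

Here "Lack of context" accounts for 75% of the reasons, so it is assigned 30 of the 40 percentage points of misclassified cases; the scaled values add up to the misclassification rate itself.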

In [30]:
def cal_percentage_cases(all_cases_df):
    # average each column over all rows
    all_cases_df = all_cases_df.mean()

    # divide by the total number of cases to get percentages
    total_cases = all_cases_df["total"]
    all_cases_df["correct"] = round(all_cases_df["correct"] / total_cases * 100, 1)
    all_cases_df["missclassified"] = round(all_cases_df["missclassified"] / total_cases * 100, 1)
    all_cases_df["total"] = round(all_cases_df["total"] / total_cases * 100, 1)

    return all_cases_df

In [31]:
def plot_global_misclassification_sankey(name, prompt_cases_dict, prompt_reasons_dict):
    labels = []
    sources = []
    targets = []
    values = []

    label_to_index = {}
    idx_counter = 0

    base_colors = {
        "Total misclassifications": "#042940",
        "Simple prompt": "#03588C",
        "Class def. prompt": "#BF2C53",
        "Profiled s. prompt": "#6F8F00",
        "Few-shot prompt": "#EB7801",
        "Vignette prompt": "#CCB900",
        "CoT prompt": "#8C0327"
    }

    reasons_colors = {
        "Lack of context": "#7EB5D6",
        "Lack of examples": "#AAC72A",
        "Lack of counterfactual demonstrations": "#BA5D77",
        "Lack of opinionbased information": "#DACE65",
        "Knowledge conflicts": "#EBAA64",
        "Prediction with Abstention": "#ED77A0"
    }

    label_aliases = {
        "Lack of context": "Lack of context",
        "Lack of examples": "Lack of examples",
        "Lack of counterfactual demonstrations": "Lack of countf. ex.",
        "Lack of opinionbased information": "Lack of opin. info.",
        "Knowledge conflicts": "Knowledge conflicts",
        "Prediction with Abstention": "Pred. with abstention"
    }

    node_colors = []
    link_colors = []
    reason_color_map = {}

    total_misscl = 0

    # 1. Total Node
    total_cases = sum(int(df['total'].mean()) for df in prompt_cases_dict.values())
    labels.append("Total misclassifications")
    label_to_index["Total misclassifications"] = idx_counter
    node_colors.append(base_colors["Total misclassifications"])
    idx_counter += 1

    # 2. Prompt Nodes
    for prompt_name, cases_df in prompt_cases_dict.items():
        prompt_label = f"{prompt_name} <span style='color:#555; font-weight:normal;'>({int(cases_df['missclassified'].mean())})</span>"
        total_misscl += int(cases_df["missclassified"].mean())
        labels.append(prompt_label)
        label_to_index[prompt_name] = idx_counter
        node_colors.append(base_colors.get(prompt_name, "#888888"))

        sources.append(label_to_index["Total misclassifications"])
        targets.append(label_to_index[prompt_name])
        values.append(int(cases_df["missclassified"].mean()))
        link_colors.append(lighten_color(base_colors.get(prompt_name, "#888888"), 0.4))

        idx_counter += 1

    # 3. Reason Nodes + Dummy Nodes
    reason_totals = {}

    # Aggregate total reason counts across all prompts
    for reasons_df in prompt_reasons_dict.values():
        reasons_df.index = reasons_df.index.astype(str).str.strip()
        grouped = reasons_df.groupby(reasons_df.index).sum()
        for reason, row in grouped.iterrows():
            cleaned_reason = reason.strip()
            reason_totals[cleaned_reason] = reason_totals.get(cleaned_reason, 0) + int(row["count"])

    # Sort reasons by total count descending
    sorted_reasons = sorted(reason_totals.items(), key=lambda x: x[1], reverse=True)

    # Add reason and dummy nodes in sorted order
    for cleaned_reason, _ in sorted_reasons:
        if cleaned_reason not in label_to_index:

            reason_percent = round(reason_totals[cleaned_reason] / total_misscl * 100, 1)
            reason_text = label_aliases.get(cleaned_reason, cleaned_reason)
            short_label = f"{reason_text} <span style='color:#555; font-weight:normal;'>({reason_percent}%)</span>"
            labels.append(short_label)

            label_to_index[cleaned_reason] = idx_counter

            color = reasons_colors.get(cleaned_reason, "#AAAAAA")
            reason_color_map[cleaned_reason] = color
            node_colors.append(color)
            idx_counter += 1

            # Dummy node
            dummy_label = f"_{cleaned_reason}_dummy"
            labels.append("")  # blank label
            label_to_index[dummy_label] = idx_counter
            node_colors.append("rgba(0,0,0,0)")
            idx_counter += 1

            sources.append(label_to_index[cleaned_reason])
            targets.append(label_to_index[dummy_label])
            values.append(1e-6)
            link_colors.append("rgba(255,255,255,0)")

    # Now build links from prompt → reasons
    for prompt_name, reasons_df in prompt_reasons_dict.items():
        reasons_df.index = reasons_df.index.astype(str).str.strip()
        grouped = reasons_df.groupby(reasons_df.index).sum()

        for reason, row in grouped.iterrows():
            cleaned_reason = reason.strip()
            sources.append(label_to_index[prompt_name])
            targets.append(label_to_index[cleaned_reason])
            values.append(int(row["count"]))
            link_colors.append(lighten_color(reason_color_map[cleaned_reason], 0.4))


    num_total_nodes = len(labels)
    num_prompt_nodes = len(prompt_cases_dict)
    num_reason_nodes = len(reason_color_map)

    # Compute y positions for prompts and reasons
    prompt_ys = np.linspace(0.1, 0.9, num_prompt_nodes)
    spread_factor = 0.19 * num_reason_nodes if num_reason_nodes <= 5 else 0.17 * num_reason_nodes
    reason_ys = np.linspace(0.05, spread_factor, num_reason_nodes)

    node_x = [None] * num_total_nodes
    node_y = [None] * num_total_nodes

    # Assign Total Misclassification node
    node_x[label_to_index["Total misclassifications"]] = 0.0
    node_y[label_to_index["Total misclassifications"]] = 0.5

    # Assign prompt positions
    for i, (prompt_name, _) in enumerate(prompt_cases_dict.items()):
        idx = label_to_index[prompt_name]
        node_x[idx] = 0.3
        node_y[idx] = prompt_ys[i]

    # Assign reason positions
    for i, reason in enumerate(reason_color_map.keys()):
        idx = label_to_index[reason]
        node_x[idx] = 0.65
        node_y[idx] = reason_ys[i]

    # Add dummy nodes for spacing labels right of reason nodes
    dummy_label_indices = []
    for i, reason in enumerate(reason_color_map.keys()):

        labels.append("\u200B")

        # Make node color fully transparent
        node_colors.append("rgba(0,0,0,0)")
        dummy_idx = len(labels) - 1

        # Invisible link (tiny value)
        sources.append(label_to_index[reason])
        targets.append(dummy_idx)
        values.append(1e-10)
        link_colors.append("rgba(0,0,0,0)")  # invisible

        # Position it at the absolute right
        node_x.append(1.05)
        node_y.append(reason_ys[i])


    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=node_colors,
            x=node_x,
            y=node_y
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors
        )
    )])

    fig.update_layout(
        font=dict(size=22, color='black', family='Times New Roman', weight=549),
        margin=dict(t=50, l=10, r=300, b=50),
        height=350,
        width=1300
    )

    fig.show()
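
The core bookkeeping of the function above is the mapping from labels to integer node ids, which Plotly's Sankey trace expects in its `source` and `target` lists. The following standard-library sketch isolates that step. `build_links` and the nested-dict input are hypothetical, and it takes the root-to-prompt value as the sum of the reason counts, whereas the function above uses the mean of the `missclassified` column.

```python
def build_links(prompt_reason_counts):
    """Flatten {prompt: {reason: count}} into Sankey labels and link lists."""
    labels, index_of = [], {}

    def node(label):
        # Register each label once and return its integer node id.
        if label not in index_of:
            index_of[label] = len(labels)
            labels.append(label)
        return index_of[label]

    sources, targets, values = [], [], []
    root = node("Total misclassifications")
    for prompt, reasons in prompt_reason_counts.items():
        # Root -> prompt link (assumption: total = sum of reason counts).
        sources.append(root)
        targets.append(node(prompt))
        values.append(sum(reasons.values()))
        # Prompt -> reason links; a reason shared by prompts reuses its node.
        for reason, count in reasons.items():
            sources.append(node(prompt))
            targets.append(node(reason))
            values.append(count)
    return labels, sources, targets, values


labels, sources, targets, values = build_links({
    "Simple prompt": {"Lack of context": 3, "Lack of examples": 1},
    "CoT prompt": {"Lack of context": 2},
})
print(labels)
print(list(zip(sources, targets, values)))
```

The resulting lists can be passed to `go.Sankey(node=dict(label=labels), link=dict(source=sources, target=targets, value=values))`; colors, fixed positions, and the invisible spacer nodes are left out of this sketch.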

In [32]:
prompt_cases_dict_GPT = {
    "Simple prompt": cases_GPT_o3_simple_df,
    "Class def. prompt": cases_GPT_o3_class_def_df,
    "Profiled s. prompt": cases_GPT_o3_profiled_simple_df,
    "Few-shot prompt": cases_GPT_o3_few_shot_df,
    "Vignette prompt": cases_GPT_o3_vignette_df,
    "CoT prompt": cases_GPT_o3_cot_df
}

prompt_reasons_dict_GPT = {
    "Simple prompt": main_reasons_GPT_o3_simple_df,
    "Class def. prompt": main_reasons_GPT_o3_class_def_df,
    "Profiled s. prompt": main_reasons_GPT_o3_profiled_simple_df,
    "Few-shot prompt": main_reasons_GPT_o3_few_shot_df,
    "Vignette prompt": main_reasons_GPT_o3_vignette_df,
    "CoT prompt": main_reasons_GPT_o3_cot_df
}

plot_global_misclassification_sankey("GPT", prompt_cases_dict_GPT, prompt_reasons_dict_GPT)

In [43]:
prompt_cases_dict_Gemini = {
    "Simple prompt": cases_Gemini_simple_df,
    "Class def. prompt": cases_Gemini_class_def_df,
    "Profiled s. prompt": cases_Gemini_profiled_simple_df,
    "Few-shot prompt": cases_Gemini_few_shot_df,
    "Vignette prompt": cases_Gemini_vignette_df,
    "CoT prompt": cases_Gemini_cot_df
}

prompt_reasons_dict_Gemini = {
    "Simple prompt": main_reasons_Gemini_simple_df,
    "Class def. prompt": main_reasons_Gemini_class_def_df,
    "Profiled s. prompt": main_reasons_Gemini_profiled_simple_df,
    "Few-shot prompt": main_reasons_Gemini_few_shot_df,
    "Vignette prompt": main_reasons_Gemini_vignette_df,
    "CoT prompt": main_reasons_Gemini_cot_df
}

plot_global_misclassification_sankey("Gemini", prompt_cases_dict_Gemini, prompt_reasons_dict_Gemini)

In [44]:
prompt_cases_dict_Gemma = {
    "Simple prompt": cases_Gemma_simple_df,
    "Class def. prompt": cases_Gemma_class_def_df,
    "Profiled s. prompt": cases_Gemma_profiled_simple_df,
    "Few-shot prompt": cases_Gemma_few_shot_df,
    "Vignette prompt": cases_Gemma_vignette_df,
    "CoT prompt": cases_Gemma_cot_df
}

prompt_reasons_dict_Gemma = {
    "Simple prompt": main_reasons_Gemma_simple_df,
    "Class def. prompt": main_reasons_Gemma_class_def_df,
    "Profiled s. prompt": main_reasons_Gemma_profiled_simple_df,
    "Few-shot prompt": main_reasons_Gemma_few_shot_df,
    "Vignette prompt": main_reasons_Gemma_vignette_df,
    "CoT prompt": main_reasons_Gemma_cot_df
}

plot_global_misclassification_sankey("Gemma", prompt_cases_dict_Gemma, prompt_reasons_dict_Gemma)

In [45]:
prompt_cases_dict_Claude = {
    "Simple prompt": cases_Claude_simple_df,
    "Class def. prompt": cases_Claude_class_def_df,
    "Profiled s. prompt": cases_Claude_profiled_simple_df,
    "Few-shot prompt": cases_Claude_few_shot_df,
    "Vignette prompt": cases_Claude_vignette_df,
    "CoT prompt": cases_Claude_cot_df
}

prompt_reasons_dict_Claude = {
    "Simple prompt": main_reasons_Claude_simple_df,
    "Class def. prompt": main_reasons_Claude_class_def_df,
    "Profiled s. prompt": main_reasons_Claude_profiled_simple_df,
    "Few-shot prompt": main_reasons_Claude_few_shot_df,
    "Vignette prompt": main_reasons_Claude_vignette_df,
    "CoT prompt": main_reasons_Claude_cot_df
}

plot_global_misclassification_sankey("Claude", prompt_cases_dict_Claude, prompt_reasons_dict_Claude)

In [46]:
prompt_cases_dict_DeepSeek = {
    "Simple prompt": cases_DeepSeek_simple_df,
    "Class def. prompt": cases_DeepSeek_class_def_df,
    "Profiled s. prompt": cases_DeepSeek_profiled_simple_df,
    "Few-shot prompt": cases_DeepSeek_few_shot_df,
    "Vignette prompt": cases_DeepSeek_vignette_df,
    "CoT prompt": cases_DeepSeek_cot_df
}

prompt_reasons_dict_DeepSeek = {
    "Simple prompt": main_reasons_DeepSeek_simple_df,
    "Class def. prompt": main_reasons_DeepSeek_class_def_df,
    "Profiled s. prompt": main_reasons_DeepSeek_profiled_simple_df,
    "Few-shot prompt": main_reasons_DeepSeek_few_shot_df,
    "Vignette prompt": main_reasons_DeepSeek_vignette_df,
    "CoT prompt": main_reasons_DeepSeek_cot_df
}

plot_global_misclassification_sankey("DeepSeek", prompt_cases_dict_DeepSeek, prompt_reasons_dict_DeepSeek)

In [47]:
prompt_cases_dict_Grok = {
    "Simple prompt": cases_Grok_simple_df,
    "Class def. prompt": cases_Grok_class_def_df,
    "Profiled s. prompt": cases_Grok_profiled_simple_df,
    "Few-shot prompt": cases_Grok_few_shot_df,
    "Vignette prompt": cases_Grok_vignette_df,
    "CoT prompt": cases_Grok_cot_df
}

prompt_reasons_dict_Grok = {
    "Simple prompt": main_reasons_Grok_simple_df,
    "Class def. prompt": main_reasons_Grok_class_def_df,
    "Profiled s. prompt": main_reasons_Grok_profiled_simple_df,
    "Few-shot prompt": main_reasons_Grok_few_shot_df,
    "Vignette prompt": main_reasons_Grok_vignette_df,
    "CoT prompt": main_reasons_Grok_cot_df
}

plot_global_misclassification_sankey("Grok", prompt_cases_dict_Grok, prompt_reasons_dict_Grok)

## 3 Percentage of misclassifications per category

In [39]:
# stack the reason tables and the case tables of all LLMs and all prompts
all_main_reasons_df = pd.concat([

    main_reasons_GPT_o3_simple_df,
    main_reasons_GPT_o3_class_def_df,
    main_reasons_GPT_o3_profiled_simple_df,
    main_reasons_GPT_o3_few_shot_df,
    main_reasons_GPT_o3_vignette_df,
    main_reasons_GPT_o3_cot_df,

    main_reasons_Gemini_simple_df,
    main_reasons_Gemini_class_def_df,
    main_reasons_Gemini_profiled_simple_df,
    main_reasons_Gemini_few_shot_df,
    main_reasons_Gemini_vignette_df,
    main_reasons_Gemini_cot_df,

    main_reasons_Gemma_simple_df,
    main_reasons_Gemma_class_def_df,
    main_reasons_Gemma_profiled_simple_df,
    main_reasons_Gemma_few_shot_df,
    main_reasons_Gemma_vignette_df,
    main_reasons_Gemma_cot_df,

    main_reasons_Claude_simple_df,
    main_reasons_Claude_class_def_df,
    main_reasons_Claude_profiled_simple_df,
    main_reasons_Claude_few_shot_df,
    main_reasons_Claude_vignette_df,
    main_reasons_Claude_cot_df,

    main_reasons_DeepSeek_simple_df,
    main_reasons_DeepSeek_class_def_df,
    main_reasons_DeepSeek_profiled_simple_df,
    main_reasons_DeepSeek_few_shot_df,
    main_reasons_DeepSeek_vignette_df,
    main_reasons_DeepSeek_cot_df,

    main_reasons_Grok_simple_df,
    main_reasons_Grok_class_def_df,
    main_reasons_Grok_profiled_simple_df,
    main_reasons_Grok_few_shot_df,
    main_reasons_Grok_vignette_df,
    main_reasons_Grok_cot_df
], axis=0)

all_cases_df = pd.concat([
    cases_GPT_o3_simple_df,
    cases_GPT_o3_class_def_df,
    cases_GPT_o3_profiled_simple_df,
    cases_GPT_o3_few_shot_df,
    cases_GPT_o3_vignette_df,
    cases_GPT_o3_cot_df,

    cases_Gemini_simple_df,
    cases_Gemini_class_def_df,
    cases_Gemini_profiled_simple_df,
    cases_Gemini_few_shot_df,
    cases_Gemini_vignette_df,
    cases_Gemini_cot_df,

    cases_Gemma_simple_df,
    cases_Gemma_class_def_df,
    cases_Gemma_profiled_simple_df,
    cases_Gemma_few_shot_df,
    cases_Gemma_vignette_df,
    cases_Gemma_cot_df,

    cases_Claude_simple_df,
    cases_Claude_class_def_df,
    cases_Claude_profiled_simple_df,
    cases_Claude_few_shot_df,
    cases_Claude_vignette_df,
    cases_Claude_cot_df,

    cases_DeepSeek_simple_df,
    cases_DeepSeek_class_def_df,
    cases_DeepSeek_profiled_simple_df,
    cases_DeepSeek_few_shot_df,
    cases_DeepSeek_vignette_df,
    cases_DeepSeek_cot_df,

    cases_Grok_simple_df,
    cases_Grok_class_def_df,
    cases_Grok_profiled_simple_df,
    cases_Grok_few_shot_df,
    cases_Grok_vignette_df,
    cases_Grok_cot_df
], axis=0)

In [40]:
# mean count per reason category, averaged over all LLM/prompt tables in which the category appears
all_main_reasons_df = all_main_reasons_df.copy()
all_main_reasons_df = all_main_reasons_df.groupby(all_main_reasons_df.index).mean()
all_main_reasons_df["count"] = round(all_main_reasons_df["count"], 2)
all_main_reasons_df = all_main_reasons_df.sort_values(by="count", ascending=False)
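
One detail of the group-and-average step is easy to miss: a category that is absent from a table is skipped rather than counted as zero. The toy example below, with made-up numbers and hypothetical variable names, illustrates this.

```python
import pandas as pd

# Two made-up reason tables; the second one has no "Lack of examples" row.
file_a = pd.DataFrame({"count": [60, 20]}, index=["Lack of context", "Lack of examples"])
file_b = pd.DataFrame({"count": [40]}, index=["Lack of context"])

stacked = pd.concat([file_a, file_b], axis=0)
mean_counts = (
    stacked.groupby(stacked.index)
    .mean()
    .round(2)
    .sort_values(by="count", ascending=False)
)
print(mean_counts)
```

"Lack of context" averages to 50.0 over both tables, while "Lack of examples" stays at 20.0 because it is averaged over the single table that contains it. If missing categories should count as zero, the tables would need to be reindexed to a common category list before concatenation.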

In [41]:
all_main_reasons_df

| Reason | count |
|---|---|
| Lack of context | 58.34 |
| Lack of opinionbased information | 28.47 |
| Lack of counterfactual demonstrations | 22.05 |
| Lack of examples | 16.12 |
| Knowledge conflicts | 12.72 |
| Prediction with Abstention | 3.07 |
