# Evaluation of reasons for misclassifications

**Possible categories**:
- **Lack of context** (insufficient emphasis or indication of the context of the query, such as domain-specific background information about psychological assessments)
- **Lack of examples** (absence of representative few-shot examples showing correct classifications for similar psychological cases before posing the actual question)
- **Lack of feedback** (missing iterative refinements and dialogues where the user gives feedback and interactively refines the prompt)
- **Lack of counterfactual demonstrations** (absence of contrasting cases with false psychological diagnoses or classifications that would improve the model's robustness against wrong decisions)
- **Lack of opinion-based information** (missing subjective clinical judgement and contextual interpretations, e.g., reframing the data as a narrator’s statement and opinions rather than relying solely on quantitative symptoms)
- **Knowledge conflicts** (outdated memorized facts or contradictions in the model's training data between psychological diagnostic criteria or between different theoretical frameworks)
- **Prediction with Abstention** (model uncertainty regarding the binary classification or insufficient confidence in distinguishing between presence and absence of mental disorder development)

## 0 Imports

In [9]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import holoviews as hv
import plotly.graph_objects as go
from holoviews import opts
from matplotlib.colors import ListedColormap
from sklearn.metrics import confusion_matrix, recall_score, matthews_corrcoef, accuracy_score
from sklearn.model_selection import train_test_split

## 1 Misclassifications

In [10]:
# GPT_4: main-reasons tables, one CSV per prompt variant; the first CSV column becomes the index.
main_reasons_GPT_4_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_simple.csv", index_col=0, sep=",")
main_reasons_GPT_4_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_class_definitions.csv", index_col=0, sep=",")
main_reasons_GPT_4_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_profiled_simple.csv", index_col=0, sep=",")
main_reasons_GPT_4_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_few_shot.csv", index_col=0, sep=",")
main_reasons_GPT_4_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_vignette.csv", index_col=0, sep=",")
main_reasons_GPT_4_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_cot.csv", index_col=0, sep=",")

In [11]:
# GPT_o3: main-reasons tables, one CSV per prompt variant; the first CSV column becomes the index.
main_reasons_GPT_o3_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_simple.csv", index_col=0, sep=",")
main_reasons_GPT_o3_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_class_definitions.csv", index_col=0, sep=",")
main_reasons_GPT_o3_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_profiled_simple.csv", index_col=0, sep=",")
main_reasons_GPT_o3_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_few_shot.csv", index_col=0, sep=",")
main_reasons_GPT_o3_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_vignette.csv", index_col=0, sep=",")
main_reasons_GPT_o3_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_cot.csv", index_col=0, sep=",")

In [12]:
# Gemini: main-reasons tables, one CSV per prompt variant; the first CSV column becomes the index.
main_reasons_Gemini_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_simple.csv", index_col=0, sep=",")
main_reasons_Gemini_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_class_definitions.csv", index_col=0, sep=",")
main_reasons_Gemini_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_profiled_simple.csv", index_col=0, sep=",")
main_reasons_Gemini_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_few_shot.csv", index_col=0, sep=",")
main_reasons_Gemini_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_vignette.csv", index_col=0, sep=",")
main_reasons_Gemini_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_cot.csv", index_col=0, sep=",")

In [13]:
# Gemma: main-reasons tables, one CSV per prompt variant; the first CSV column becomes the index.
main_reasons_Gemma_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_simple.csv", index_col=0, sep=",")
main_reasons_Gemma_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_class_definitions.csv", index_col=0, sep=",")
main_reasons_Gemma_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_profiled_simple.csv", index_col=0, sep=",")
main_reasons_Gemma_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_few_shot.csv", index_col=0, sep=",")
main_reasons_Gemma_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_vignette.csv", index_col=0, sep=",")
main_reasons_Gemma_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_cot.csv", index_col=0, sep=",")

In [14]:
# Claude: main-reasons tables, one CSV per prompt variant; the first CSV column becomes the index.
main_reasons_Claude_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_simple.csv", index_col=0, sep=",")
main_reasons_Claude_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_class_definitions.csv", index_col=0, sep=",")
main_reasons_Claude_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_profiled_simple.csv", index_col=0, sep=",")
main_reasons_Claude_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_few_shot.csv", index_col=0, sep=",")
main_reasons_Claude_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_vignette.csv", index_col=0, sep=",")
main_reasons_Claude_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_cot.csv", index_col=0, sep=",")

In [15]:
# DeepSeek: main-reasons tables, one CSV per prompt variant; the first CSV column becomes the index.
main_reasons_DeepSeek_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_simple.csv", index_col=0, sep=",")
main_reasons_DeepSeek_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_class_definitions.csv", index_col=0, sep=",")
main_reasons_DeepSeek_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_profiled_simple.csv", index_col=0, sep=",")
main_reasons_DeepSeek_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_few_shot.csv", index_col=0, sep=",")
main_reasons_DeepSeek_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_vignette.csv", index_col=0, sep=",")
main_reasons_DeepSeek_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_cot.csv", index_col=0, sep=",")

In [16]:
# Grok: main-reasons tables, one CSV per prompt variant; the first CSV column becomes the index.
main_reasons_Grok_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_simple.csv", index_col=0, sep=",")
main_reasons_Grok_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_class_definitions.csv", index_col=0, sep=",")
main_reasons_Grok_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_profiled_simple.csv", index_col=0, sep=",")
main_reasons_Grok_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_few_shot.csv", index_col=0, sep=",")
main_reasons_Grok_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_vignette.csv", index_col=0, sep=",")
main_reasons_Grok_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_cot.csv", index_col=0, sep=",")

In [17]:
# GPT_4: order every main-reasons table by its "count" column, largest first.
main_reasons_GPT_4_simple_df = main_reasons_GPT_4_simple_df.sort_values("count", ascending=False)
main_reasons_GPT_4_class_def_df = main_reasons_GPT_4_class_def_df.sort_values("count", ascending=False)
main_reasons_GPT_4_profiled_simple_df = main_reasons_GPT_4_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_GPT_4_few_shot_df = main_reasons_GPT_4_few_shot_df.sort_values("count", ascending=False)
main_reasons_GPT_4_vignette_df = main_reasons_GPT_4_vignette_df.sort_values("count", ascending=False)
main_reasons_GPT_4_cot_df = main_reasons_GPT_4_cot_df.sort_values("count", ascending=False)

In [18]:
# GPT_o3: order every main-reasons table by its "count" column, largest first.
main_reasons_GPT_o3_simple_df = main_reasons_GPT_o3_simple_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_class_def_df = main_reasons_GPT_o3_class_def_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_profiled_simple_df = main_reasons_GPT_o3_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_few_shot_df = main_reasons_GPT_o3_few_shot_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_vignette_df = main_reasons_GPT_o3_vignette_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_cot_df = main_reasons_GPT_o3_cot_df.sort_values("count", ascending=False)

In [19]:
# Gemini: order every main-reasons table by its "count" column, largest first.
main_reasons_Gemini_simple_df = main_reasons_Gemini_simple_df.sort_values("count", ascending=False)
main_reasons_Gemini_class_def_df = main_reasons_Gemini_class_def_df.sort_values("count", ascending=False)
main_reasons_Gemini_profiled_simple_df = main_reasons_Gemini_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_Gemini_few_shot_df = main_reasons_Gemini_few_shot_df.sort_values("count", ascending=False)
main_reasons_Gemini_vignette_df = main_reasons_Gemini_vignette_df.sort_values("count", ascending=False)
main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df.sort_values("count", ascending=False)

In [20]:
# Gemma: order every main-reasons table by its "count" column, largest first.
main_reasons_Gemma_simple_df = main_reasons_Gemma_simple_df.sort_values("count", ascending=False)
main_reasons_Gemma_class_def_df = main_reasons_Gemma_class_def_df.sort_values("count", ascending=False)
main_reasons_Gemma_profiled_simple_df = main_reasons_Gemma_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_Gemma_few_shot_df = main_reasons_Gemma_few_shot_df.sort_values("count", ascending=False)
main_reasons_Gemma_vignette_df = main_reasons_Gemma_vignette_df.sort_values("count", ascending=False)
main_reasons_Gemma_cot_df = main_reasons_Gemma_cot_df.sort_values("count", ascending=False)

In [21]:
# Claude: order every main-reasons table by its "count" column, largest first.
main_reasons_Claude_simple_df = main_reasons_Claude_simple_df.sort_values("count", ascending=False)
main_reasons_Claude_class_def_df = main_reasons_Claude_class_def_df.sort_values("count", ascending=False)
main_reasons_Claude_profiled_simple_df = main_reasons_Claude_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_Claude_few_shot_df = main_reasons_Claude_few_shot_df.sort_values("count", ascending=False)
main_reasons_Claude_vignette_df = main_reasons_Claude_vignette_df.sort_values("count", ascending=False)
main_reasons_Claude_cot_df = main_reasons_Claude_cot_df.sort_values("count", ascending=False)

In [22]:
# DeepSeek: order every main-reasons table by its "count" column, largest first.
main_reasons_DeepSeek_simple_df = main_reasons_DeepSeek_simple_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_class_def_df = main_reasons_DeepSeek_class_def_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_profiled_simple_df = main_reasons_DeepSeek_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_few_shot_df = main_reasons_DeepSeek_few_shot_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_vignette_df = main_reasons_DeepSeek_vignette_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_cot_df = main_reasons_DeepSeek_cot_df.sort_values("count", ascending=False)

In [23]:
# Grok: order every main-reasons table by its "count" column, largest first.
main_reasons_Grok_simple_df = main_reasons_Grok_simple_df.sort_values("count", ascending=False)
main_reasons_Grok_class_def_df = main_reasons_Grok_class_def_df.sort_values("count", ascending=False)
main_reasons_Grok_profiled_simple_df = main_reasons_Grok_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_Grok_few_shot_df = main_reasons_Grok_few_shot_df.sort_values("count", ascending=False)
main_reasons_Grok_vignette_df = main_reasons_Grok_vignette_df.sort_values("count", ascending=False)
main_reasons_Grok_cot_df = main_reasons_Grok_cot_df.sort_values("count", ascending=False)

In [16]:
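# One-off manual cleanup, kept for reference: drop the rows of main_reasons_Gemini_cot_df whose index is several reason names run together, then set the affected base-category counts by hand.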
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Lack of counterfactual demonstrationsKnowledge conflictsLack of context"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of examplesLack of context"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of counterfactual demonstrationsLack of context"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Lack of counterfactual demonstrationsKnowledge conflictsLack of examples"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Lack of counterfactual demonstrationsKnowledge conflicts"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of examplesLack of counterfactual demonstrations"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of opinionbased informationLack of contextLack of examples"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of contextLack of counterfactual demonstrations"]
#
# main_reasons_Gemini_cot_df.loc["Lack of counterfactual demonstrations", "count"] = 15
# main_reasons_Gemini_cot_df.loc["Knowledge conflicts", "count"] = 40

In [17]:
# main_reasons_DeepSeek_simple_df.to_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_simple.csv", sep = ",")

In [24]:
# GPT_4: case-count tables, one CSV per prompt variant (read later via 'total', 'correct', 'missclassified').
cases_GPT_4_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_simple.csv", index_col=0, sep=",")
cases_GPT_4_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_class_definitions.csv", index_col=0, sep=",")
cases_GPT_4_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_profiled_simple.csv", index_col=0, sep=",")
cases_GPT_4_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_few_shot.csv", index_col=0, sep=",")
cases_GPT_4_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_vignette.csv", index_col=0, sep=",")
cases_GPT_4_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_cot.csv", index_col=0, sep=",")

In [25]:
# GPT_o3: case-count tables, one CSV per prompt variant (read later via 'total', 'correct', 'missclassified').
cases_GPT_o3_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_simple.csv", index_col=0, sep=",")
cases_GPT_o3_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_class_definitions.csv", index_col=0, sep=",")
cases_GPT_o3_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_profiled_simple.csv", index_col=0, sep=",")
cases_GPT_o3_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_few_shot.csv", index_col=0, sep=",")
cases_GPT_o3_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_vignette.csv", index_col=0, sep=",")
cases_GPT_o3_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_cot.csv", index_col=0, sep=",")

In [26]:
# Gemini: case-count tables, one CSV per prompt variant (read later via 'total', 'correct', 'missclassified').
cases_Gemini_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_simple.csv", index_col=0, sep=",")
cases_Gemini_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_class_definitions.csv", index_col=0, sep=",")
cases_Gemini_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_profiled_simple.csv", index_col=0, sep=",")
cases_Gemini_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_few_shot.csv", index_col=0, sep=",")
cases_Gemini_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_vignette.csv", index_col=0, sep=",")
cases_Gemini_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_cot.csv", index_col=0, sep=",")

In [27]:
# Gemma: case-count tables, one CSV per prompt variant (read later via 'total', 'correct', 'missclassified').
cases_Gemma_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_simple.csv", index_col=0, sep=",")
cases_Gemma_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_class_definitions.csv", index_col=0, sep=",")
cases_Gemma_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_profiled_simple.csv", index_col=0, sep=",")
cases_Gemma_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_few_shot.csv", index_col=0, sep=",")
cases_Gemma_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_vignette.csv", index_col=0, sep=",")
cases_Gemma_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_cot.csv", index_col=0, sep=",")

In [28]:
# Claude: case-count tables, one CSV per prompt variant (read later via 'total', 'correct', 'missclassified').
cases_Claude_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_simple.csv", index_col=0, sep=",")
cases_Claude_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_class_definitions.csv", index_col=0, sep=",")
cases_Claude_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_profiled_simple.csv", index_col=0, sep=",")
cases_Claude_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_few_shot.csv", index_col=0, sep=",")
cases_Claude_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_vignette.csv", index_col=0, sep=",")
cases_Claude_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_cot.csv", index_col=0, sep=",")

In [29]:
# DeepSeek: case-count tables, one CSV per prompt variant (read later via 'total', 'correct', 'missclassified').
cases_DeepSeek_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_simple.csv", index_col=0, sep=",")
cases_DeepSeek_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_class_definitions.csv", index_col=0, sep=",")
cases_DeepSeek_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_profiled_simple.csv", index_col=0, sep=",")
cases_DeepSeek_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_few_shot.csv", index_col=0, sep=",")
cases_DeepSeek_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_vignette.csv", index_col=0, sep=",")
cases_DeepSeek_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_cot.csv", index_col=0, sep=",")

In [30]:
# Grok: case-count tables, one CSV per prompt variant (read later via 'total', 'correct', 'missclassified').
cases_Grok_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_simple.csv", index_col=0, sep=",")
cases_Grok_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_class_definitions.csv", index_col=0, sep=",")
cases_Grok_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_profiled_simple.csv", index_col=0, sep=",")
cases_Grok_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_few_shot.csv", index_col=0, sep=",")
cases_Grok_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_vignette.csv", index_col=0, sep=",")
cases_Grok_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_cot.csv", index_col=0, sep=",")

In [31]:
def lighten_color(color, amount=0.8):
    """Blend a color towards white and return it as a CSS ``rgb(r, g, b)`` string.

    Parameters
    ----------
    color : str or tuple
        Anything ``matplotlib.colors.to_rgb`` accepts, e.g. a hex string, a
        named color, or an RGB tuple of floats.
    amount : float, default 0.8
        Share of the original color that is kept: each channel becomes
        ``1 - amount * (1 - channel)``. So 1 leaves the color unchanged and 0
        yields white, i.e. *smaller* values give a lighter result.

    Returns
    -------
    str
        ``"rgb(R, G, B)"`` with each channel scaled to 0-255 and truncated
        to an integer.

    Raises
    ------
    ValueError
        If ``color`` cannot be interpreted by matplotlib.
    """
    import matplotlib.colors as mc

    # to_rgb resolves named colors itself, so neither a separate name lookup
    # nor a catch-all `except` around it is needed.
    red, green, blue = (1 - amount * (1 - channel) for channel in mc.to_rgb(color))
    return f"rgb({int(red * 255)}, {int(green * 255)}, {int(blue * 255)})"

In [32]:
def plot_sankey(name, cases_df, main_reasons_df):
    """Show a Sankey diagram from total cases to outcomes to misclassification reasons.

    The flow is ``Total Cases -> Correct | Misclassified`` followed by
    ``Misclassified -> one node per row of main_reasons_df``.

    Parameters
    ----------
    name : str
        Inserted into the figure title.
    cases_df : pandas.DataFrame
        Provides the columns ``'total'``, ``'correct'`` and ``'missclassified'``
        (spelled as in the data); only the first row is read.
    main_reasons_df : pandas.DataFrame
        Indexed by reason name with a ``'count'`` column. Rows are drawn in the
        order given, so sort beforehand if a ranking is wanted. Counts are
        converted to ``int`` for both the link values and the node labels.

    Returns
    -------
    None
        The figure is rendered with ``fig.show()``.
    """
    # Read the first row by position. `series[0]` is a label lookup that only
    # falls back to position when the index has no label 0, and pandas has
    # deprecated that fallback.
    total_cases = int(cases_df['total'].iloc[0])
    correct_cases = int(cases_df['correct'].iloc[0])
    misclassified_cases = int(cases_df['missclassified'].iloc[0])

    # Take names and counts positionally so that whitespace in an index label
    # (stripped for display) or a repeated label cannot break a `.loc` lookup.
    main_reason_labels = [str(reason).strip() for reason in main_reasons_df.index]
    reason_counts = [int(count) for count in main_reasons_df['count']]
    n_reasons = len(main_reason_labels)

    # Node order: 0 = total, 1 = correct, 2 = misclassified, 3.. = reasons.
    total_idx, correct_idx, misclassified_idx = 0, 1, 2
    reason_indices = list(range(3, 3 + n_reasons))

    # Every node label carries its count in parentheses.
    labels = [
        f"Total Cases ({total_cases})",
        f"Correct ({correct_cases})",
        f"Misclassified ({misclassified_cases})"
    ] + [
        f"{reason} ({count})" for reason, count in zip(main_reason_labels, reason_counts)
    ]

    sources = [total_idx, total_idx] + [misclassified_idx] * n_reasons
    targets = [correct_idx, misclassified_idx] + reason_indices
    values = [correct_cases, misclassified_cases] + reason_counts

    base_colors = {
        "Total Cases": "#042940",      # dark navy
        "Correct": "#728C14",          # olive green
        "Misclassified": "#BF2C53",    # raspberry red
    }

    final_reason_palette = [
        "#03588C",
        "#89CFF0",
        "#730220",
        "#BF9BB9",
        "#F27405",
        "#F29F05",
        "#F26680",
    ]

    # Colors are assigned by row position and cycle if there are more reasons
    # than palette entries.
    reason_node_colors = [
        final_reason_palette[i % len(final_reason_palette)] for i in range(n_reasons)
    ]

    node_colors = [
        base_colors["Total Cases"],
        base_colors["Correct"],
        base_colors["Misclassified"]
    ] + reason_node_colors

    # Links use a lightened variant of their target node's color.
    link_colors = [
        lighten_color(base_colors["Correct"], 0.4),
        lighten_color(base_colors["Misclassified"], 0.4),
    ] + [lighten_color(color, 0.4) for color in reason_node_colors]

    # Manual node placement to avoid overlap.
    # NOTE(review): Plotly documents node x/y as normalized positions; some
    # values below lie outside [0, 1] (e.g. -0.8). Kept exactly as tuned --
    # confirm the rendered layout is what is intended.
    node_x = [0.0, 0.4, 0.4] + [0.9] * n_reasons
    node_y = [0.5, -0.8, 0.8] + [
        (0.9 - (0.1 * n_reasons)) + (0.2 * n_reasons) * i / max(n_reasons - 1, 1)
        for i in range(n_reasons)
    ]

    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=node_colors,
            x=node_x,
            y=node_y
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors
        )
    )])

    fig.update_layout(
        title_text=f"Reasons for misclassifications for {name}",
        font=dict(size=14, color='black', family='Times New Roman', weight=750),
        # The bottom margin grows with the number of reason nodes.
        margin=dict(t=50, l=10, r=10, b=20 * n_reasons),
        height=300,
        width=900
    )
    fig.show()


In [33]:
# One Sankey diagram per model for the simple prompt.
plot_sankey("GPT 4 Simple Prompt", cases_GPT_4_simple_df, main_reasons_GPT_4_simple_df)
plot_sankey("GPT o3 Simple Prompt", cases_GPT_o3_simple_df, main_reasons_GPT_o3_simple_df)
plot_sankey("Gemini Simple Prompt", cases_Gemini_simple_df, main_reasons_Gemini_simple_df)
plot_sankey("Gemma Simple Prompt", cases_Gemma_simple_df, main_reasons_Gemma_simple_df)
plot_sankey("Claude Simple Prompt", cases_Claude_simple_df, main_reasons_Claude_simple_df)
plot_sankey("DeepSeek Simple Prompt", cases_DeepSeek_simple_df, main_reasons_DeepSeek_simple_df)
plot_sankey("Grok Simple Prompt", cases_Grok_simple_df, main_reasons_Grok_simple_df)





This means that static image generation (e.g. `fig.write_image()`) will not work.

Please upgrade Plotly to version 6.1.1 or greater, or downgrade Kaleido to version 0.2.1.




In [34]:
# One Sankey diagram per model for the CoT prompt.
plot_sankey("GPT 4 CoT Prompt", cases_GPT_4_cot_df, main_reasons_GPT_4_cot_df)
plot_sankey("GPT o3 CoT Prompt", cases_GPT_o3_cot_df, main_reasons_GPT_o3_cot_df)
plot_sankey("Gemini CoT Prompt", cases_Gemini_cot_df, main_reasons_Gemini_cot_df)
plot_sankey("Gemma CoT Prompt", cases_Gemma_cot_df, main_reasons_Gemma_cot_df)
plot_sankey("Claude CoT Prompt", cases_Claude_cot_df, main_reasons_Claude_cot_df)
plot_sankey("DeepSeek CoT Prompt", cases_DeepSeek_cot_df, main_reasons_DeepSeek_cot_df)
plot_sankey("Grok CoT Prompt", cases_Grok_cot_df, main_reasons_Grok_cot_df)

### Calculate for all prompts the percentage of misclassifications that are due to each category

In [35]:
def cal_percentage(all_main_df, all_cases_df):
    """Scale each reason's share of all reason counts by the misclassification figure.

    Rows of ``all_main_df`` that share an index label are summed first. Each
    reason's share of the summed ``'count'`` column is then multiplied by
    ``all_cases_df["missclassified"].sum()`` and rounded to two decimals.

    The value is rounded once, at the end. Rounding the share to one decimal
    before scaling (as an earlier version did) leaks up to 0.05 percentage
    points of error into the result.

    Parameters
    ----------
    all_main_df : pandas.DataFrame
        Indexed by reason, with a ``'count'`` column. Not modified.
    all_cases_df : pandas.DataFrame or pandas.Series
        Must support ``all_cases_df["missclassified"].sum()``.
        NOTE(review): presumably the output of ``cal_percentage_cases`` so the
        figure is a percentage -- confirm against the caller.

    Returns
    -------
    pandas.DataFrame
        One row per distinct reason with a ``'percentage'`` column replacing
        ``'count'``, sorted by ``'percentage'`` in descending order.
    """
    # Collapse duplicate reasons (e.g. after concatenating several tables).
    reason_totals = all_main_df.groupby(all_main_df.index).sum()

    total_reason_count = reason_totals["count"].sum()
    misclassified_figure = all_cases_df["missclassified"].sum()

    # share of reason * misclassification figure, rounded once.
    reason_totals["percentage"] = (
        reason_totals["count"] / total_reason_count * misclassified_figure
    ).round(2)

    return (
        reason_totals
        .drop(columns=["count"])
        .sort_values(by="percentage", ascending=False)
    )

In [36]:
def cal_percentage_cases(all_cases_df):
    """Average the case columns and express them as a percentage of the mean total.

    Parameters
    ----------
    all_cases_df : pandas.DataFrame
        Must contain the columns ``'total'``, ``'correct'`` and
        ``'missclassified'``.

    Returns
    -------
    pandas.Series
        Column means of ``all_cases_df`` in which ``'correct'``,
        ``'missclassified'`` and ``'total'`` are replaced by their value
        relative to the mean ``'total'``, in percent, rounded to one decimal.
    """
    column_means = all_cases_df.mean()

    # Capture the denominator first: the 'total' entry itself is overwritten below.
    mean_total = column_means["total"]
    for column in ("correct", "missclassified", "total"):
        column_means[column] = round(column_means[column] / mean_total * 100, 1)

    return column_means

In [37]:
def plot_global_misclassification_sankey(name, prompt_cases_dict, prompt_reasons_dict):
    """Show a Sankey diagram: total misclassifications -> prompt types -> reasons.

    Parameters
    ----------
    name : str
        Identifier of the plotted model. It is accepted for interface
        compatibility but is currently not used in the figure.
    prompt_cases_dict : dict
        Maps a prompt name to a DataFrame with a ``"missclassified"`` column.
        The truncated mean of that column is shown in the prompt label and is
        the width of the link from the root node to the prompt node.
    prompt_reasons_dict : dict
        Maps a prompt name to a DataFrame indexed by reason with a ``"count"``
        column. Every key must also be a key of ``prompt_cases_dict``,
        otherwise the prompt node cannot be looked up (``KeyError``).

    Returns
    -------
    None
        The figure is rendered with ``fig.show()``.

    Notes
    -----
    Uses ``lighten_color`` defined elsewhere in this notebook.
    The frames in ``prompt_reasons_dict`` are not modified: reason names are
    whitespace-stripped on a copy of the index only.
    """
    labels = []
    sources = []
    targets = []
    values = []

    label_to_index = {}
    idx_counter = 0

    # Colours of the root node and of the prompt nodes
    base_colors = {
        "Total Misclassifications": "#042940",
        "Simple": "#03588C",
        "Class Definitions": "#BF2C53",
        "Profiled Simple": "#6F8F00",
        "Few Shot": "#EB7801",
        "Vignette": "#CCB900",
        "CoT": "#8C0327"
    }

    # Colours of the reason nodes (unknown reasons fall back to grey below)
    reasons_colors = {
        "Lack of context": "#7EB5D6",
        "Lack of examples": "#ED77A0",
        "Lack of counterfactual demonstrations": "#AAC72A",
        "Lack of opinionbased information": "#DACE65",
        "Knowledge conflicts": "#EBAA64",
        "Prediction with Abstention": "#BA5D77"
    }

    # Shortened display labels for the reason nodes
    label_aliases = {
        "Lack of context": "Lack of Context",
        "Lack of examples": "Lack of Examples",
        "Lack of counterfactual demonstrations": "Lack of Countf. Ex.",
        "Lack of opinionbased information": "Lack of Opinionb. Info.",
        "Knowledge conflicts": "Knowledge Conflicts",
        "Prediction with Abstention": "Pred. with Abstention"
    }

    node_colors = []
    link_colors = []
    reason_color_map = {}

    # Sum the counts per (whitespace-stripped) reason once for every prompt.
    # Grouping by a stripped copy of the index keeps the caller's frames intact.
    grouped_by_prompt = {
        prompt_name: reasons_df.groupby(reasons_df.index.astype(str).str.strip()).sum()
        for prompt_name, reasons_df in prompt_reasons_dict.items()
    }

    # 1. Total node
    labels.append("Total Misclassifications")
    label_to_index["Total Misclassifications"] = idx_counter
    node_colors.append(base_colors["Total Misclassifications"])
    idx_counter += 1

    # 2. Prompt nodes, each linked from the total node
    for prompt_name, cases_df in prompt_cases_dict.items():
        misclassified = int(cases_df["missclassified"].mean())
        prompt_color = base_colors.get(prompt_name, "#888888")

        labels.append(f"{prompt_name} ({misclassified})")
        label_to_index[prompt_name] = idx_counter
        node_colors.append(prompt_color)

        sources.append(label_to_index["Total Misclassifications"])
        targets.append(label_to_index[prompt_name])
        values.append(misclassified)
        link_colors.append(lighten_color(prompt_color, 0.4))

        idx_counter += 1

    # 3. Reason nodes + dummy nodes
    # Aggregate the total count of every reason across all prompts
    reason_totals = {}
    for grouped in grouped_by_prompt.values():
        for reason, row in grouped.iterrows():
            reason_totals[reason] = reason_totals.get(reason, 0) + int(row["count"])

    # Sort reasons by total count descending
    sorted_reasons = sorted(reason_totals.items(), key=lambda x: x[1], reverse=True)

    # Add reason and dummy nodes in sorted order
    for cleaned_reason, _ in sorted_reasons:
        if cleaned_reason not in label_to_index:
            labels.append(label_aliases.get(cleaned_reason, cleaned_reason))
            label_to_index[cleaned_reason] = idx_counter

            color = reasons_colors.get(cleaned_reason, "#AAAAAA")
            reason_color_map[cleaned_reason] = color
            node_colors.append(color)
            idx_counter += 1

            # Invisible dummy node fed by a near-zero, transparent link
            dummy_label = f"_{cleaned_reason}_dummy"
            labels.append("")  # blank label
            label_to_index[dummy_label] = idx_counter
            node_colors.append("rgba(0,0,0,0)")
            idx_counter += 1

            sources.append(label_to_index[cleaned_reason])
            targets.append(label_to_index[dummy_label])
            values.append(1e-6)
            link_colors.append("rgba(255,255,255,0)")

    # Links from prompt -> reasons
    for prompt_name, grouped in grouped_by_prompt.items():
        for reason, row in grouped.iterrows():
            sources.append(label_to_index[prompt_name])
            targets.append(label_to_index[reason])
            values.append(int(row["count"]))
            link_colors.append(lighten_color(reason_color_map[reason], 0.4))

    # --------- Layout: aligned spacing for all nodes ---------
    num_total_nodes = len(labels)
    num_prompt_nodes = len(prompt_cases_dict)
    num_reason_nodes = len(reason_color_map)

    # Compute y positions for prompts and reasons
    prompt_ys = np.linspace(0.1, 0.9, num_prompt_nodes)
    spread_factor = 0.19 * num_reason_nodes if num_reason_nodes <= 5 else 0.17 * num_reason_nodes
    reason_ys = np.linspace(0.05, spread_factor, num_reason_nodes)

    # Nodes without an explicit position (the dummy nodes created above) stay None
    node_x = [None] * num_total_nodes
    node_y = [None] * num_total_nodes

    # Position of the total node
    node_x[label_to_index["Total Misclassifications"]] = 0.0
    node_y[label_to_index["Total Misclassifications"]] = 0.5

    # Positions of the prompt nodes
    for i, prompt_name in enumerate(prompt_cases_dict):
        idx = label_to_index[prompt_name]
        node_x[idx] = 0.3
        node_y[idx] = prompt_ys[i]

    # Positions of the reason nodes
    for i, reason in enumerate(reason_color_map):
        idx = label_to_index[reason]
        node_x[idx] = 0.65
        node_y[idx] = reason_ys[i]

    # Second set of invisible nodes, placed right of the reason nodes to space the labels
    for i, reason in enumerate(reason_color_map):
        labels.append("")  # blank label
        node_colors.append("rgba(0,0,0,0)")  # transparent node
        dummy_idx = len(labels) - 1

        # Tiny dummy flow so the node is part of the diagram
        sources.append(label_to_index[reason])
        targets.append(dummy_idx)
        values.append(1e-10)
        link_colors.append("white")

        node_x.append(0.95)
        node_y.append(reason_ys[i])

    # --------- Final figure ---------
    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=node_colors,
            x=node_x,
            y=node_y
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors
        )
    )])

    fig.update_layout(
        font=dict(size=22, color='black', family='Times New Roman', weight=549),
        margin=dict(t=100, l=10, r=320, b=100),
        height=450,
        width=1400
    )

    fig.show()

In [38]:
# prompt_cases_dict = {
#     "Simple": cases_GPT_4_simple_df,
#     "Class Definitions": cases_GPT_4_class_def_df,
#     "Profiled Simple": cases_GPT_4_profiled_simple_df,
#     "Few Shot": cases_GPT_4_few_shot_df,
#     "Vignette": cases_GPT_4_vignette_df,
#     "CoT": cases_GPT_4_cot_df
# }
#
# prompt_reasons_dict = {
#     "Simple": main_reasons_GPT_4_simple_df,
#     "Class Definitions": main_reasons_GPT_4_class_def_df,
#     "Profiled Simple": main_reasons_GPT_4_profiled_simple_df,
#     "Few Shot": main_reasons_GPT_4_few_shot_df,
#     "Vignette": main_reasons_GPT_4_vignette_df,
#     "CoT": main_reasons_GPT_4_cot_df
# }
#
# plot_global_misclassification_sankey("GPT", prompt_cases_dict, prompt_reasons_dict)

In [39]:
# GPT-o3: one row per prompt type -> (label, case counts, main reasons).
# Keeping both frames of a prompt on one row keeps the two dicts in sync.
gpt_o3_frames = [
    ("Simple", cases_GPT_o3_simple_df, main_reasons_GPT_o3_simple_df),
    ("Class Definitions", cases_GPT_o3_class_def_df, main_reasons_GPT_o3_class_def_df),
    ("Profiled Simple", cases_GPT_o3_profiled_simple_df, main_reasons_GPT_o3_profiled_simple_df),
    ("Few Shot", cases_GPT_o3_few_shot_df, main_reasons_GPT_o3_few_shot_df),
    ("Vignette", cases_GPT_o3_vignette_df, main_reasons_GPT_o3_vignette_df),
    ("CoT", cases_GPT_o3_cot_df, main_reasons_GPT_o3_cot_df),
]

prompt_cases_dict = {label: cases for label, cases, _ in gpt_o3_frames}
prompt_reasons_dict = {label: reasons for label, _, reasons in gpt_o3_frames}

plot_global_misclassification_sankey("GPT", prompt_cases_dict, prompt_reasons_dict)

In [40]:
# Gemini: one row per prompt type -> (label, case counts, main reasons).
# Keeping both frames of a prompt on one row keeps the two dicts in sync.
gemini_frames = [
    ("Simple", cases_Gemini_simple_df, main_reasons_Gemini_simple_df),
    ("Class Definitions", cases_Gemini_class_def_df, main_reasons_Gemini_class_def_df),
    ("Profiled Simple", cases_Gemini_profiled_simple_df, main_reasons_Gemini_profiled_simple_df),
    ("Few Shot", cases_Gemini_few_shot_df, main_reasons_Gemini_few_shot_df),
    ("Vignette", cases_Gemini_vignette_df, main_reasons_Gemini_vignette_df),
    ("CoT", cases_Gemini_cot_df, main_reasons_Gemini_cot_df),
]

prompt_cases_dict = {label: cases for label, cases, _ in gemini_frames}
prompt_reasons_dict = {label: reasons for label, _, reasons in gemini_frames}

plot_global_misclassification_sankey("Gemini", prompt_cases_dict, prompt_reasons_dict)

In [41]:
# Gemma: one row per prompt type -> (label, case counts, main reasons).
# Keeping both frames of a prompt on one row keeps the two dicts in sync.
gemma_frames = [
    ("Simple", cases_Gemma_simple_df, main_reasons_Gemma_simple_df),
    ("Class Definitions", cases_Gemma_class_def_df, main_reasons_Gemma_class_def_df),
    ("Profiled Simple", cases_Gemma_profiled_simple_df, main_reasons_Gemma_profiled_simple_df),
    ("Few Shot", cases_Gemma_few_shot_df, main_reasons_Gemma_few_shot_df),
    ("Vignette", cases_Gemma_vignette_df, main_reasons_Gemma_vignette_df),
    ("CoT", cases_Gemma_cot_df, main_reasons_Gemma_cot_df),
]

prompt_cases_dict = {label: cases for label, cases, _ in gemma_frames}
prompt_reasons_dict = {label: reasons for label, _, reasons in gemma_frames}

plot_global_misclassification_sankey("Gemma", prompt_cases_dict, prompt_reasons_dict)

In [42]:
# Claude: one row per prompt type -> (label, case counts, main reasons).
# Keeping both frames of a prompt on one row keeps the two dicts in sync.
claude_frames = [
    ("Simple", cases_Claude_simple_df, main_reasons_Claude_simple_df),
    ("Class Definitions", cases_Claude_class_def_df, main_reasons_Claude_class_def_df),
    ("Profiled Simple", cases_Claude_profiled_simple_df, main_reasons_Claude_profiled_simple_df),
    ("Few Shot", cases_Claude_few_shot_df, main_reasons_Claude_few_shot_df),
    ("Vignette", cases_Claude_vignette_df, main_reasons_Claude_vignette_df),
    ("CoT", cases_Claude_cot_df, main_reasons_Claude_cot_df),
]

prompt_cases_dict = {label: cases for label, cases, _ in claude_frames}
prompt_reasons_dict = {label: reasons for label, _, reasons in claude_frames}

plot_global_misclassification_sankey("Claude", prompt_cases_dict, prompt_reasons_dict)

In [43]:
# DeepSeek: one row per prompt type -> (label, case counts, main reasons).
# Keeping both frames of a prompt on one row keeps the two dicts in sync.
deepseek_frames = [
    ("Simple", cases_DeepSeek_simple_df, main_reasons_DeepSeek_simple_df),
    ("Class Definitions", cases_DeepSeek_class_def_df, main_reasons_DeepSeek_class_def_df),
    ("Profiled Simple", cases_DeepSeek_profiled_simple_df, main_reasons_DeepSeek_profiled_simple_df),
    ("Few Shot", cases_DeepSeek_few_shot_df, main_reasons_DeepSeek_few_shot_df),
    ("Vignette", cases_DeepSeek_vignette_df, main_reasons_DeepSeek_vignette_df),
    ("CoT", cases_DeepSeek_cot_df, main_reasons_DeepSeek_cot_df),
]

prompt_cases_dict = {label: cases for label, cases, _ in deepseek_frames}
prompt_reasons_dict = {label: reasons for label, _, reasons in deepseek_frames}

plot_global_misclassification_sankey("DeepSeek", prompt_cases_dict, prompt_reasons_dict)

In [44]:
# Grok: one row per prompt type -> (label, case counts, main reasons).
# Keeping both frames of a prompt on one row keeps the two dicts in sync.
grok_frames = [
    ("Simple", cases_Grok_simple_df, main_reasons_Grok_simple_df),
    ("Class Definitions", cases_Grok_class_def_df, main_reasons_Grok_class_def_df),
    ("Profiled Simple", cases_Grok_profiled_simple_df, main_reasons_Grok_profiled_simple_df),
    ("Few Shot", cases_Grok_few_shot_df, main_reasons_Grok_few_shot_df),
    ("Vignette", cases_Grok_vignette_df, main_reasons_Grok_vignette_df),
    ("CoT", cases_Grok_cot_df, main_reasons_Grok_cot_df),
]

prompt_cases_dict = {label: cases for label, cases, _ in grok_frames}
prompt_reasons_dict = {label: reasons for label, _, reasons in grok_frames}

plot_global_misclassification_sankey("Grok", prompt_cases_dict, prompt_reasons_dict)