# Evaluation of reasons for misclassifications

**Possible categories**:
- **Lack of context** (insufficient emphasis or indication of the context of the query, such as domain-specific background information about psychological assessments)
- **Lack of examples** (absence of representative few-shot examples showing correct classifications for similar psychological cases before posing the actual question)
- **Lack of feedback** (missing iterative refinements and dialogues in which the user gives feedback and interactively refines the prompt)
- **Lack of counterfactual demonstrations** (absence of contrasting cases with false psychological diagnoses or classifications that would improve the model's robustness against wrong decisions)
- **Lack of opinion-based information** (missing subjective clinical judgement and contextual interpretation, e.g., reframing the data as a narrator's statements and opinions rather than relying solely on quantitative symptom measures)
- **Knowledge conflicts** (outdated memorized facts or contradictions in the model's training data between psychological diagnostic criteria or between different theoretical frameworks)
- **Prediction with Abstention** (model uncertainty regarding the binary classification or insufficient confidence in distinguishing between presence and absence of mental disorder development)

## 0 Imports

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import holoviews as hv
import plotly.graph_objects as go
from holoviews import opts
from matplotlib.colors import ListedColormap
from sklearn.metrics import confusion_matrix, recall_score, matthews_corrcoef, accuracy_score
from sklearn.model_selection import train_test_split

## 1 Misclassifications

In [35]:
# GPT-4: per-reason misclassification counts, one CSV per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
main_reasons_GPT_4_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_simple.csv", index_col=0)
main_reasons_GPT_4_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_class_definitions.csv", index_col=0)
main_reasons_GPT_4_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_profiled_simple.csv", index_col=0)
main_reasons_GPT_4_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_few_shot.csv", index_col=0)
main_reasons_GPT_4_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_vignette.csv", index_col=0)
main_reasons_GPT_4_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/main_reasons_GPT_4_cot.csv", index_col=0)

In [36]:
# GPT o3: per-reason misclassification counts, one CSV per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
main_reasons_GPT_o3_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_simple.csv", index_col=0)
main_reasons_GPT_o3_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_class_definitions.csv", index_col=0)
main_reasons_GPT_o3_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_profiled_simple.csv", index_col=0)
main_reasons_GPT_o3_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_few_shot.csv", index_col=0)
main_reasons_GPT_o3_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_vignette.csv", index_col=0)
main_reasons_GPT_o3_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/main_reasons_GPT_o3_cot.csv", index_col=0)

In [37]:
# Gemini: per-reason misclassification counts, one CSV per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
main_reasons_Gemini_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_simple.csv", index_col=0)
main_reasons_Gemini_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_class_definitions.csv", index_col=0)
main_reasons_Gemini_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_profiled_simple.csv", index_col=0)
main_reasons_Gemini_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_few_shot.csv", index_col=0)
main_reasons_Gemini_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_vignette.csv", index_col=0)
main_reasons_Gemini_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/main_reasons_Gemini_cot.csv", index_col=0)

In [38]:
# Gemma: per-reason misclassification counts, one CSV per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
main_reasons_Gemma_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_simple.csv", index_col=0)
main_reasons_Gemma_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_class_definitions.csv", index_col=0)
main_reasons_Gemma_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_profiled_simple.csv", index_col=0)
main_reasons_Gemma_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_few_shot.csv", index_col=0)
main_reasons_Gemma_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_vignette.csv", index_col=0)
main_reasons_Gemma_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/main_reasons_Gemma_cot.csv", index_col=0)

In [39]:
# Claude: per-reason misclassification counts, one CSV per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
main_reasons_Claude_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_simple.csv", index_col=0)
main_reasons_Claude_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_class_definitions.csv", index_col=0)
main_reasons_Claude_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_profiled_simple.csv", index_col=0)
main_reasons_Claude_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_few_shot.csv", index_col=0)
main_reasons_Claude_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_vignette.csv", index_col=0)
main_reasons_Claude_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/main_reasons_Claude_cot.csv", index_col=0)

In [40]:
# DeepSeek: per-reason misclassification counts, one CSV per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
main_reasons_DeepSeek_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_simple.csv", index_col=0)
main_reasons_DeepSeek_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_class_definitions.csv", index_col=0)
main_reasons_DeepSeek_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_profiled_simple.csv", index_col=0)
main_reasons_DeepSeek_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_few_shot.csv", index_col=0)
main_reasons_DeepSeek_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_vignette.csv", index_col=0)
main_reasons_DeepSeek_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_cot.csv", index_col=0)

In [41]:
# Grok: per-reason misclassification counts, one CSV per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
main_reasons_Grok_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_simple.csv", index_col=0)
main_reasons_Grok_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_class_definitions.csv", index_col=0)
main_reasons_Grok_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_profiled_simple.csv", index_col=0)
main_reasons_Grok_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_few_shot.csv", index_col=0)
main_reasons_Grok_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_vignette.csv", index_col=0)
main_reasons_Grok_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/main_reasons_Grok_cot.csv", index_col=0)

In [42]:
# Order every GPT-4 reason table from the most to the least frequent reason.
main_reasons_GPT_4_simple_df = main_reasons_GPT_4_simple_df.sort_values("count", ascending=False)
main_reasons_GPT_4_class_def_df = main_reasons_GPT_4_class_def_df.sort_values("count", ascending=False)
main_reasons_GPT_4_profiled_simple_df = main_reasons_GPT_4_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_GPT_4_few_shot_df = main_reasons_GPT_4_few_shot_df.sort_values("count", ascending=False)
main_reasons_GPT_4_vignette_df = main_reasons_GPT_4_vignette_df.sort_values("count", ascending=False)
main_reasons_GPT_4_cot_df = main_reasons_GPT_4_cot_df.sort_values("count", ascending=False)

In [43]:
# Order every GPT o3 reason table from the most to the least frequent reason.
main_reasons_GPT_o3_simple_df = main_reasons_GPT_o3_simple_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_class_def_df = main_reasons_GPT_o3_class_def_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_profiled_simple_df = main_reasons_GPT_o3_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_few_shot_df = main_reasons_GPT_o3_few_shot_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_vignette_df = main_reasons_GPT_o3_vignette_df.sort_values("count", ascending=False)
main_reasons_GPT_o3_cot_df = main_reasons_GPT_o3_cot_df.sort_values("count", ascending=False)

In [44]:
# Order every Gemini reason table from the most to the least frequent reason.
main_reasons_Gemini_simple_df = main_reasons_Gemini_simple_df.sort_values("count", ascending=False)
main_reasons_Gemini_class_def_df = main_reasons_Gemini_class_def_df.sort_values("count", ascending=False)
main_reasons_Gemini_profiled_simple_df = main_reasons_Gemini_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_Gemini_few_shot_df = main_reasons_Gemini_few_shot_df.sort_values("count", ascending=False)
main_reasons_Gemini_vignette_df = main_reasons_Gemini_vignette_df.sort_values("count", ascending=False)
main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df.sort_values("count", ascending=False)

In [45]:
# Order every Gemma reason table from the most to the least frequent reason.
main_reasons_Gemma_simple_df = main_reasons_Gemma_simple_df.sort_values("count", ascending=False)
main_reasons_Gemma_class_def_df = main_reasons_Gemma_class_def_df.sort_values("count", ascending=False)
main_reasons_Gemma_profiled_simple_df = main_reasons_Gemma_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_Gemma_few_shot_df = main_reasons_Gemma_few_shot_df.sort_values("count", ascending=False)
main_reasons_Gemma_vignette_df = main_reasons_Gemma_vignette_df.sort_values("count", ascending=False)
main_reasons_Gemma_cot_df = main_reasons_Gemma_cot_df.sort_values("count", ascending=False)

In [46]:
# Order every Claude reason table from the most to the least frequent reason.
main_reasons_Claude_simple_df = main_reasons_Claude_simple_df.sort_values("count", ascending=False)
main_reasons_Claude_class_def_df = main_reasons_Claude_class_def_df.sort_values("count", ascending=False)
main_reasons_Claude_profiled_simple_df = main_reasons_Claude_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_Claude_few_shot_df = main_reasons_Claude_few_shot_df.sort_values("count", ascending=False)
main_reasons_Claude_vignette_df = main_reasons_Claude_vignette_df.sort_values("count", ascending=False)
main_reasons_Claude_cot_df = main_reasons_Claude_cot_df.sort_values("count", ascending=False)

In [47]:
# Order every DeepSeek reason table from the most to the least frequent reason.
main_reasons_DeepSeek_simple_df = main_reasons_DeepSeek_simple_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_class_def_df = main_reasons_DeepSeek_class_def_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_profiled_simple_df = main_reasons_DeepSeek_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_few_shot_df = main_reasons_DeepSeek_few_shot_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_vignette_df = main_reasons_DeepSeek_vignette_df.sort_values("count", ascending=False)
main_reasons_DeepSeek_cot_df = main_reasons_DeepSeek_cot_df.sort_values("count", ascending=False)

In [48]:
# Order every Grok reason table from the most to the least frequent reason.
main_reasons_Grok_simple_df = main_reasons_Grok_simple_df.sort_values("count", ascending=False)
main_reasons_Grok_class_def_df = main_reasons_Grok_class_def_df.sort_values("count", ascending=False)
main_reasons_Grok_profiled_simple_df = main_reasons_Grok_profiled_simple_df.sort_values("count", ascending=False)
main_reasons_Grok_few_shot_df = main_reasons_Grok_few_shot_df.sort_values("count", ascending=False)
main_reasons_Grok_vignette_df = main_reasons_Grok_vignette_df.sort_values("count", ascending=False)
main_reasons_Grok_cot_df = main_reasons_Grok_cot_df.sort_values("count", ascending=False)

In [771]:
# Disabled manual cleanup: drop malformed (concatenated) reason labels from main_reasons_Gemini_cot_df and overwrite two counts
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Lack of counterfactual demonstrationsKnowledge conflictsLack of context"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of examplesLack of context"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of counterfactual demonstrationsLack of context"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Lack of counterfactual demonstrationsKnowledge conflictsLack of examples"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Lack of counterfactual demonstrationsKnowledge conflicts"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of examplesLack of counterfactual demonstrations"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of opinionbased informationLack of contextLack of examples"]
# main_reasons_Gemini_cot_df = main_reasons_Gemini_cot_df[main_reasons_Gemini_cot_df.index != "Knowledge conflictsLack of contextLack of counterfactual demonstrations"]
#
# main_reasons_Gemini_cot_df.loc["Lack of counterfactual demonstrations", "count"] = 15
# main_reasons_Gemini_cot_df.loc["Knowledge conflicts", "count"] = 40

In [686]:
# main_reasons_DeepSeek_simple_df.to_csv("03_Reasons_Misclassifications/reasons/DeepSeek/main_reasons_DeepSeek_simple.csv", sep = ",")

In [49]:
# GPT-4: total / correct / misclassified case counts per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
cases_GPT_4_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_simple.csv", index_col=0)
cases_GPT_4_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_class_definitions.csv", index_col=0)
cases_GPT_4_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_profiled_simple.csv", index_col=0)
cases_GPT_4_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_few_shot.csv", index_col=0)
cases_GPT_4_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_vignette.csv", index_col=0)
cases_GPT_4_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_4/cases_GPT_4_cot.csv", index_col=0)

In [50]:
# GPT o3: total / correct / misclassified case counts per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
cases_GPT_o3_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_simple.csv", index_col=0)
cases_GPT_o3_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_class_definitions.csv", index_col=0)
cases_GPT_o3_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_profiled_simple.csv", index_col=0)
cases_GPT_o3_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_few_shot.csv", index_col=0)
cases_GPT_o3_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_vignette.csv", index_col=0)
cases_GPT_o3_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/GPT_o3/cases_GPT_o3_cot.csv", index_col=0)

In [51]:
# Gemini: total / correct / misclassified case counts per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
cases_Gemini_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_simple.csv", index_col=0)
cases_Gemini_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_class_definitions.csv", index_col=0)
cases_Gemini_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_profiled_simple.csv", index_col=0)
cases_Gemini_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_few_shot.csv", index_col=0)
cases_Gemini_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_vignette.csv", index_col=0)
cases_Gemini_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemini/cases_Gemini_cot.csv", index_col=0)

In [52]:
# Gemma: total / correct / misclassified case counts per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
cases_Gemma_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_simple.csv", index_col=0)
cases_Gemma_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_class_definitions.csv", index_col=0)
cases_Gemma_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_profiled_simple.csv", index_col=0)
cases_Gemma_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_few_shot.csv", index_col=0)
cases_Gemma_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_vignette.csv", index_col=0)
cases_Gemma_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Gemma/cases_Gemma_cot.csv", index_col=0)

In [53]:
# Claude: total / correct / misclassified case counts per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
cases_Claude_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_simple.csv", index_col=0)
cases_Claude_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_class_definitions.csv", index_col=0)
cases_Claude_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_profiled_simple.csv", index_col=0)
cases_Claude_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_few_shot.csv", index_col=0)
cases_Claude_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_vignette.csv", index_col=0)
cases_Claude_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Claude/cases_Claude_cot.csv", index_col=0)

In [54]:
# DeepSeek: total / correct / misclassified case counts per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
cases_DeepSeek_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_simple.csv", index_col=0)
cases_DeepSeek_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_class_definitions.csv", index_col=0)
cases_DeepSeek_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_profiled_simple.csv", index_col=0)
cases_DeepSeek_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_few_shot.csv", index_col=0)
cases_DeepSeek_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_vignette.csv", index_col=0)
cases_DeepSeek_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/DeepSeek/cases_DeepSeek_cot.csv", index_col=0)

In [55]:
# Grok: total / correct / misclassified case counts per prompting strategy
# (first CSV column becomes the index; comma is pandas' default separator).
cases_Grok_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_simple.csv", index_col=0)
cases_Grok_class_def_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_class_definitions.csv", index_col=0)
cases_Grok_profiled_simple_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_profiled_simple.csv", index_col=0)
cases_Grok_few_shot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_few_shot.csv", index_col=0)
cases_Grok_vignette_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_vignette.csv", index_col=0)
cases_Grok_cot_df = pd.read_csv("03_Reasons_Misclassifications/reasons/Grok/cases_Grok_cot.csv", index_col=0)

In [None]:
def lighten_color(color, amount=0.8):
    """Blend a colour with white and return it as a Plotly ``"rgb(R, G, B)"`` string.

    Each channel ``x`` (in [0, 1]) becomes ``1 - amount * (1 - x)``, so
    ``amount`` is the weight kept from the original colour: ``1.0`` leaves the
    colour unchanged, ``0.0`` gives pure white, and smaller values are lighter.

    Parameters
    ----------
    color : str or sequence of float
        Anything ``matplotlib.colors.to_rgb`` accepts, e.g. a hex string
        (``"#728C14"``), a named colour (``"red"``) or an RGB(A) tuple.
    amount : float, default 0.8
        Weight of the original colour in the blend (see above).

    Returns
    -------
    str
        ``"rgb(R, G, B)"`` with integer channels; values are truncated, not rounded.
    """
    import matplotlib.colors as mc

    # ``to_rgb`` already resolves colour names, hex strings and RGB(A) tuples,
    # so the former ``mc.cnames`` lookup wrapped in a bare ``except`` (which
    # also hid unrelated errors) is unnecessary.
    rgb = mc.to_rgb(color)
    lightened = tuple(1 - amount * (1 - channel) for channel in rgb)
    return f"rgb({int(lightened[0]*255)}, {int(lightened[1]*255)}, {int(lightened[2]*255)})"

In [56]:
def plot_sankey(name, cases_df, main_reasons_df):
    """Show a Sankey diagram: total cases -> correct / misclassified -> main reasons.

    Parameters
    ----------
    name : str
        Inserted into the title "Reasons for misclassifications for {name}".
    cases_df : pandas.DataFrame
        Case counts for one model/prompt. The first row of the ``'total'``,
        ``'correct'`` and ``'missclassified'`` columns is used (the column
        name is spelled exactly as in the ``cases_*.csv`` files).
    main_reasons_df : pandas.DataFrame
        One row per misclassification reason, indexed by reason label, with a
        ``'count'`` column. Reason nodes are drawn in the frame's row order.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    # ``.iloc[0]`` is explicitly positional; plain ``series[0]`` on a frame
    # read with ``index_col=0`` relied on pandas' deprecated positional fallback.
    total_cases = int(cases_df['total'].iloc[0])
    correct_cases = int(cases_df['correct'].iloc[0])
    misclassified_cases = int(cases_df['missclassified'].iloc[0])

    # Strip only the *displayed* text and read the counts positionally, so
    # index labels with surrounding whitespace no longer raise KeyError.
    main_reason_labels = [str(reason).strip() for reason in main_reasons_df.index]
    reason_counts = main_reasons_df['count'].tolist()
    n_reasons = len(main_reason_labels)

    # Node labels include the counts.
    labels = [
        f"Total Cases ({total_cases})",
        f"Correct ({correct_cases})",
        f"Misclassified ({misclassified_cases})",
    ] + [f"{reason} ({count})" for reason, count in zip(main_reason_labels, reason_counts)]

    # Node indices follow the order of ``labels``: 0 total, 1 correct,
    # 2 misclassified, 3.. reasons. Positional indices (instead of a
    # label -> index dict) keep links correct even if two labels coincide.
    total_node, correct_node, misclassified_node = 0, 1, 2
    reason_nodes = list(range(3, 3 + n_reasons))

    sources = [total_node, total_node] + [misclassified_node] * n_reasons
    targets = [correct_node, misclassified_node] + reason_nodes
    values = [correct_cases, misclassified_cases] + [int(count) for count in reason_counts]

    base_colors = {
        "Total Cases": "#042940",      # dark navy
        "Correct": "#728C14",          # olive green
        "Misclassified": "#BF2C53",    # raspberry red
    }

    final_reason_palette = [
        "#03588C",
        "#89CFF0",
        "#730220",
        "#BF9BB9",
        "#F27405",
        "#F29F05",
        "#F26680",
    ]
    # Cycle through the palette when there are more reasons than colours.
    reason_colors = [final_reason_palette[i % len(final_reason_palette)] for i in range(n_reasons)]

    node_colors = [
        base_colors["Total Cases"],
        base_colors["Correct"],
        base_colors["Misclassified"],
    ] + reason_colors

    # Links use lighter variants of their target node's colour.
    link_colors = [
        lighten_color(base_colors["Correct"], 0.4),
        lighten_color(base_colors["Misclassified"], 0.4),
    ] + [lighten_color(color, 0.4) for color in reason_colors]

    # Manual node positions to avoid overlap.
    # NOTE(review): some y values (-0.8, and values > 1 when there are many
    # reasons) fall outside the normalised [0, 1] range Plotly describes for
    # node positions; presumably tuned by eye with arrangement="snap" — confirm.
    node_x = [0.0, 0.4, 0.4] + [0.9] * n_reasons
    node_y = [0.5, -0.8, 0.8] + [
        (0.9 - (0.1 * n_reasons)) + (0.2 * n_reasons) * i / max(n_reasons - 1, 1)
        for i in range(n_reasons)
    ]

    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=node_colors,
            x=node_x,
            y=node_y
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors
        )
    )])

    fig.update_layout(
        title_text=f"Reasons for misclassifications for {name}",
        font=dict(size=14, color='black', family='Times New Roman', weight=750),
        margin=dict(t=50, l=10, r=10, b=20 * n_reasons),
        height=300,
        width=900
    )
    fig.show()


In [782]:
# Sankey diagram: GPT 4, simple prompt.
plot_sankey(
    name="GPT 4 Simple Prompt",
    cases_df=cases_GPT_4_simple_df,
    main_reasons_df=main_reasons_GPT_4_simple_df,
)

In [783]:
# Sankey diagram: GPT o3, simple prompt.
plot_sankey(
    name="GPT o3 Simple Prompt",
    cases_df=cases_GPT_o3_simple_df,
    main_reasons_df=main_reasons_GPT_o3_simple_df,
)

In [784]:
# Sankey diagram: Gemini, simple prompt.
plot_sankey(
    name="Gemini Simple Prompt",
    cases_df=cases_Gemini_simple_df,
    main_reasons_df=main_reasons_Gemini_simple_df,
)

In [785]:
# Sankey diagram: Gemma, simple prompt.
plot_sankey(
    name="Gemma Simple Prompt",
    cases_df=cases_Gemma_simple_df,
    main_reasons_df=main_reasons_Gemma_simple_df,
)

In [786]:
# Sankey diagram: Claude, simple prompt.
plot_sankey(
    name="Claude 4 Simple Prompt",
    cases_df=cases_Claude_simple_df,
    main_reasons_df=main_reasons_Claude_simple_df,
)

In [787]:
# Sankey diagram: DeepSeek, simple prompt.
plot_sankey(
    name="DeepSeek Simple Prompt",
    cases_df=cases_DeepSeek_simple_df,
    main_reasons_df=main_reasons_DeepSeek_simple_df,
)

In [788]:
# Sankey diagram: Grok, simple prompt.
plot_sankey(
    name="Grok Simple Prompt",
    cases_df=cases_Grok_simple_df,
    main_reasons_df=main_reasons_Grok_simple_df,
)

In [789]:
# Sankey diagram: GPT 4, chain-of-thought prompt.
plot_sankey(
    name="GPT 4 CoT Prompt",
    cases_df=cases_GPT_4_cot_df,
    main_reasons_df=main_reasons_GPT_4_cot_df,
)

In [790]:
# Sankey diagram: GPT o3, chain-of-thought prompt.
plot_sankey(
    name="GPT o3 CoT Prompt",
    cases_df=cases_GPT_o3_cot_df,
    main_reasons_df=main_reasons_GPT_o3_cot_df,
)

In [791]:
# Sankey diagram: Gemini, chain-of-thought prompt.
plot_sankey(
    name="Gemini CoT Prompt",
    cases_df=cases_Gemini_cot_df,
    main_reasons_df=main_reasons_Gemini_cot_df,
)

In [792]:
# Pool the simple-prompt results of all models, then (below) compute which
# share of the misclassifications falls on each reason category.

# Reason tables stacked on top of each other; the same reason label can
# appear once per model and is summed later.
all_main_reasons_simple_df = pd.concat(
    [
        main_reasons_GPT_4_simple_df,
        main_reasons_GPT_o3_simple_df,
        main_reasons_Gemini_simple_df,
        main_reasons_Gemma_simple_df,
        main_reasons_Claude_simple_df,
        main_reasons_DeepSeek_simple_df,
        main_reasons_Grok_simple_df,
    ]
)

# Case-count tables of the same models, stacked in the same order.
all_cases_simple_df = pd.concat(
    [
        cases_GPT_4_simple_df,
        cases_GPT_o3_simple_df,
        cases_Gemini_simple_df,
        cases_Gemma_simple_df,
        cases_Claude_simple_df,
        cases_DeepSeek_simple_df,
        cases_Grok_simple_df,
    ]
)

def cal_percentage(all_main_df, all_cases_df):
    """Attribute the overall misclassification rate to each reason.

    Counts are summed per reason (index label) across all stacked frames.
    Each reason's share of the summed count is then scaled by the
    misclassification percentage, giving the percentage of *all* cases
    attributed to that reason.

    Parameters
    ----------
    all_main_df : pandas.DataFrame
        Stacked reason tables, indexed by reason label, with a ``'count'`` column.
    all_cases_df : pandas.Series
        Result of ``cal_percentage_cases``; its ``'missclassified'`` entry is
        the misclassification rate in percent.

    Returns
    -------
    pandas.DataFrame
        One row per reason, sorted by descending ``'percentage'`` (rounded
        once, to 2 decimals). The ``'count'`` column is dropped.
    """
    per_reason = all_main_df.groupby(all_main_df.index).sum()
    total_count = per_reason["count"].sum()

    # ``np.sum`` yields the same value as the former ``.sum()`` on the numpy
    # scalar, and also accepts a plain Python float.
    misclassified_pct = float(np.sum(all_cases_df["missclassified"]))

    # Guard against 0/0 (NaN) when no misclassifications were recorded.
    if total_count:
        share = per_reason["count"] / total_count
    else:
        share = per_reason["count"] * 0.0

    # Round only the final value: the old code rounded the share to 1 decimal
    # first, e.g. 1/3 of a 30 % rate came out as 9.99 instead of 10.0.
    per_reason["percentage"] = (share * misclassified_pct).round(2)
    per_reason = per_reason.drop(columns=["count"])

    return per_reason.sort_values(by="percentage", ascending=False)

def cal_percentage_cases(all_cases_df):
    """Pool per-model case counts and express them as percentages of all cases.

    Parameters
    ----------
    all_cases_df : pandas.DataFrame
        Stacked case tables with numeric ``'total'``, ``'correct'`` and
        ``'missclassified'`` columns (one row per model).

    Returns
    -------
    pandas.Series
        Column means with ``'correct'`` and ``'missclassified'`` converted to
        percent of ``'total'`` and ``'total'`` set to 100.0 (all rounded to
        1 decimal). Any other numeric columns keep their plain mean.
    """
    # ``numeric_only`` skips text columns, on which ``mean()`` raises
    # TypeError in pandas >= 2.
    means = all_cases_df.mean(numeric_only=True)

    # A ratio of means equals the ratio of sums, i.e. counts pooled over models.
    total_cases = means["total"]
    for column in ("correct", "missclassified", "total"):
        means[column] = round(means[column] / total_cases * 100, 1)

    return means

# Convert the pooled counts to percentages first: cal_percentage reads the
# misclassification percentage from that result.
all_cases_simple_df = cal_percentage_cases(all_cases_df=all_cases_simple_df)
all_main_reasons_simple_df = cal_percentage(
    all_main_df=all_main_reasons_simple_df,
    all_cases_df=all_cases_simple_df,
)

In [793]:
def plot_sankey_percentage(name, cases_df, main_reasons_df):
    """Show a Sankey diagram of percentages: total -> correct / misclassified -> reasons.

    Parameters
    ----------
    name : str
        Inserted into the title "Reasons for misclassifications for {name}".
    cases_df : pandas.Series or pandas.DataFrame
        Percentages with ``'total'``, ``'correct'`` and ``'missclassified'``
        entries, as returned by ``cal_percentage_cases``. For a DataFrame the
        first row of each column is used.
    main_reasons_df : pandas.DataFrame
        One row per reason, indexed by reason label, with a ``'percentage'``
        column (as returned by ``cal_percentage``). Drawn in row order.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    def _first_value(value):
        # A Series (a DataFrame column) contributes its first element.
        return float(value.iloc[0]) if isinstance(value, pd.Series) else float(value)

    # Keep percentages as floats: ``int()`` truncated them, so links of
    # reasons below 1 % had zero width and the flows no longer added up.
    total_pct = _first_value(cases_df['total'])
    correct_pct = _first_value(cases_df['correct'])
    misclassified_pct = _first_value(cases_df['missclassified'])

    # Strip only the *displayed* text and read the values positionally, so
    # index labels with surrounding whitespace no longer raise KeyError.
    main_reason_labels = [str(reason).strip() for reason in main_reasons_df.index]
    reason_pcts = [float(pct) for pct in main_reasons_df['percentage'].tolist()]
    n_reasons = len(main_reason_labels)

    # ``:g`` drops trailing zeros, e.g. 100.0 -> "100", 76.7 -> "76.7".
    labels = [
        f"Total Cases ({total_pct:g}%)",
        f"Correct ({correct_pct:g}%)",
        f"Misclassified ({misclassified_pct:g}%)",
    ] + [f"{reason} ({pct:g}%)" for reason, pct in zip(main_reason_labels, reason_pcts)]

    # Node indices follow the order of ``labels``: 0 total, 1 correct,
    # 2 misclassified, 3.. reasons. Positional indices (instead of a
    # label -> index dict) keep links correct even if two labels coincide.
    total_node, correct_node, misclassified_node = 0, 1, 2
    reason_nodes = list(range(3, 3 + n_reasons))

    sources = [total_node, total_node] + [misclassified_node] * n_reasons
    targets = [correct_node, misclassified_node] + reason_nodes
    values = [correct_pct, misclassified_pct] + reason_pcts

    base_colors = {
        "Total Cases": "#042940",      # dark navy
        "Correct": "#728C14",          # olive green
        "Misclassified": "#BF2C53",    # raspberry red
    }

    final_reason_palette = [
        "#03588C",
        "#89CFF0",
        "#730220",
        "#BF9BB9",
        "#F27405",
        "#F29F05",
        "#F26680",
    ]
    # Cycle through the palette when there are more reasons than colours.
    reason_colors = [final_reason_palette[i % len(final_reason_palette)] for i in range(n_reasons)]

    node_colors = [
        base_colors["Total Cases"],
        base_colors["Correct"],
        base_colors["Misclassified"],
    ] + reason_colors

    # Links use lighter variants of their target node's colour.
    link_colors = [
        lighten_color(base_colors["Correct"], 0.4),
        lighten_color(base_colors["Misclassified"], 0.4),
    ] + [lighten_color(color, 0.4) for color in reason_colors]

    # Manual node positions to avoid overlap.
    # NOTE(review): some y values (-0.8, and values > 1 when there are many
    # reasons) fall outside the normalised [0, 1] range Plotly describes for
    # node positions; presumably tuned by eye with arrangement="snap" — confirm.
    node_x = [0.0, 0.4, 0.4] + [0.9] * n_reasons
    node_y = [0.5, -0.8, 0.8] + [
        (0.9 - (0.1 * n_reasons)) + (0.2 * n_reasons) * i / max(n_reasons - 1, 1)
        for i in range(n_reasons)
    ]

    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=node_colors,
            x=node_x,
            y=node_y
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors
        )
    )])

    fig.update_layout(
        title_text=f"Reasons for misclassifications for {name}",
        font=dict(size=14, color='black', family='Times New Roman', weight=750),
        margin=dict(t=50, l=10, r=10, b=20 * n_reasons),
        height=300,
        width=900
    )

    fig.show()

In [794]:
# Pooled Sankey diagram over all models' simple-prompt runs (percentages).
plot_sankey_percentage(
    name="Simple Prompts",
    cases_df=all_cases_simple_df,
    main_reasons_df=all_main_reasons_simple_df,
)

In [795]:
# Collect the pooled results of every prompt type; only the simple prompt is
# included so far, so each concat currently just copies a single object.
all_main_reasons_df = pd.concat(objs=[all_main_reasons_simple_df])
all_cases_df = pd.concat(objs=[all_cases_simple_df])