In [1]:
import os, json, numpy
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import VPacker, TextArea, DrawingArea, AnchoredOffsetbox, HPacker
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
from matplotlib.colors import LinearSegmentedColormap

Load the log data from the files

In [2]:
%%capture 

# Root directory holding one sub-folder per modelling setting, each with `fold_*`
# sub-folders containing a `log_history.json`. The default is the original local
# path; set the LRTI_RESULTS_DIR environment variable to point the notebook at
# another location without editing this cell.
results_dir = os.environ.get(
    'LRTI_RESULTS_DIR',
    '/Users/ispiero2/Documents/Research/Study 3 - Comparison of LLM for LRTI Symptom Extraction/Scripts/logsonly-11juli2025/'
)

log_files = []

# Loop over the folders with the log results.
# os.listdir returns entries in arbitrary order, so sort for a reproducible row order.
for folder in sorted(os.listdir(results_dir)):
    folder_path = os.path.join(results_dir, folder)
    if not os.path.isdir(folder_path):
        continue

    folds_dict = {}

    # Loop over the folds within the folder
    for fold_folder in sorted(os.listdir(folder_path)):
        fold_path = os.path.join(folder_path, fold_folder)
        if not (os.path.isdir(fold_path) and fold_folder.startswith('fold_')):
            continue

        # A fold folder without a log file still raises, so missing results are noticed
        log_path = os.path.join(fold_path, 'log_history.json')
        with open(log_path, encoding='utf-8') as f:
            # Add the log results of the current fold
            folds_dict[fold_folder] = json.load(f)

    # Add the log results of all folds to the respective modeling setting
    if folds_dict:
        log_files.append({folder: folds_dict})

log_files

Convert the log data into a raw DataFrame

In [3]:
%%capture 

# Folder names are split on '_' and read as:
#   <classifier>_<symptom word(s)>_<model>_<sample size>
# where the model token is the first one starting with one of these prefixes.
model_prefixes = ('models', 'robbert', 'medroberta')

rows = []

for folder_entry in log_files:
    for folder, folds in folder_entry.items():

        # The folder name is the same for every fold, so parse it once per folder
        parts = folder.split('_')
        model_positions = [pos for pos, token in enumerate(parts) if token.startswith(model_prefixes)]
        if not model_positions:
            # No recognisable model token: nothing from this folder is kept
            continue
        model_start_idx = model_positions[0]

        classifier = parts[0]
        model = parts[model_start_idx]
        symptom = '_'.join(parts[1:model_start_idx])

        # The token after the model name is the training sample size (empty if absent)
        size_tokens = parts[model_start_idx + 1:model_start_idx + 2]
        sample_size = size_tokens[0] if size_tokens else ''

        for fold_name, log_data in folds.items():
            # Only the last evaluation log of a fold is used; folds without any are skipped
            eval_logs = log_data.get('Eval Logs', [])
            if not eval_logs:
                continue
            last_eval = eval_logs[-1]

            rows.append({
                'Classifier': classifier,
                'Model': model,
                'Symptom': symptom,
                'Number of samples': sample_size,
                'Fold': fold_name,
                'Confusion matrix': last_eval.get('eval_confusion_matrix')
            })

# One row per (modelling setting, fold)
df_raw = pd.DataFrame(rows)

df_raw.head()

Clean the dataframe into the desired format

In [4]:
%%capture 

# Lookup tables translating raw folder tokens into display names
classifier_names = {'run': 'Direct',
                    'pbrun': 'Prompt-based'}
symptom_names = {'Pijn_Borst': 'Chest pain',
                 'Zieke_Indruk': 'Ill appearance',
                 'Auscultatie': 'Crackles upon auscultation',
                 'Hoesten': 'Cough',
                 'Dyspnoe': 'Shortness of breath',
                 'Rillingen': 'Chills',
                 'Sputum': 'Sputum',
                 'Verwardheid': 'Confusion',
                 'Crepitaties': 'Crackles upon auscultation',
                 'Koorts': 'Fever'}
model_names = {'models--pdelobelle--robbert-v2-dutch-base': 'RobBERT',
               'models--CLTL--MedRoBERTa.nl': 'MedRoBERTa.nl',
               'robbert-prompt': 'RobBERT',
               'medroberta-prompt': 'MedRoBERTa.nl'}

# Build the cleaned frame in one pass; df_raw itself is left untouched
df_cleaned = df_raw.assign(**{
    'Classifier': lambda d: d['Classifier'].replace(classifier_names),
    'Symptom': lambda d: d['Symptom'].replace(symptom_names),
    'Model': lambda d: d['Model'].replace(model_names),
    # Strip the '-samples' suffix and convert to a number (unparseable values become NaN)
    'Number of samples': lambda d: pd.to_numeric(
        d['Number of samples'].astype(str).str.replace('-samples', '', regex=False),
        errors='coerce'
    ),
    # 'fold_3' -> '3'
    'Fold': lambda d: d['Fold'].str.replace('fold_', '', regex=False),
})

# Remove the sample size of 25 that is not used in the analysis (too low)
keep_rows = df_cleaned['Number of samples'] != 25
df_cleaned = df_cleaned[keep_rows]

df_cleaned.head()

In [5]:
%%capture 

def _is_2x2_nested_list(value):
    """Return True when `value` is a list holding exactly two lists of two items each."""
    if not isinstance(value, list) or len(value) != 2:
        return False
    return all(isinstance(inner, list) and len(inner) == 2 for inner in value)


# Flag the rows in which the confusion matrix is 2x2 (instead of 3x3)
mask_2x2 = df_cleaned['Confusion matrix'].apply(_is_2x2_nested_list)
df_2x2 = df_cleaned[mask_2x2]

# Keep the row labels of those entries for later reference
indices_2x2 = df_2x2.index

df_2x2

In [6]:
%%capture 

# Per-fold counts of each class (0, 1, 2) for every symptom: one row per
# (Fold, Class) pair, one column per symptom.
# The symptom column names must match the cleaned 'Symptom' labels exactly, because
# fix_confusion_matrices looks the counts up by symptom name (hence 'Ill appearance').
columns = ['Fold', 'Class', 'Fever', 'Cough', 'Shortness of breath', 'Sputum', 'Confusion', 'Chest pain',
           'Chills', 'Ill appearance', 'Crackles upon auscultation']
data = [
    [0, 0, 147, 6, 58, 14, 9, 15, 2, 81, 259],
    [0, 1, 98, 309, 165, 94, 6, 48, 20, 46, 76],
    [0, 2, 147, 77, 169, 284, 377, 329, 370, 256, 57],
    [1, 0, 127, 4, 37, 5, 4, 16, 1, 75, 222],
    [1, 1, 94, 299, 159, 87, 8, 37, 19, 42, 107],
    [1, 2, 171, 89, 196, 300, 380, 339, 372, 275, 63],
    [2, 0, 113, 7, 41, 11, 7, 12, 3, 90, 257],
    [2, 1, 112, 289, 150, 87, 1, 51, 18, 38, 89],
    [2, 2, 166, 95, 200, 293, 383, 328, 370, 263, 45],
    [3, 0, 139, 4, 56, 13, 7, 22, 1, 83, 238],
    [3, 1, 91, 299, 152, 105, 5, 37, 17, 41, 90],
    [3, 2, 161, 88, 183, 273, 379, 332, 373, 267, 63],
    [4, 0, 125, 6, 46, 12, 10, 12, 0, 57, 232],
    [4, 1, 99, 294, 147, 91, 2, 37, 16, 44, 97],
    [4, 2, 167, 91, 198, 288, 379, 342, 375, 290, 62]
]
df_counts = pd.DataFrame(data, columns=columns)

df_counts

In [7]:
%%capture 

# Expand the 2x2 matrices to 3x3 by adding zeros in case a class happened to not occur in the data sample

def expand_confusion_matrix(cm, true_counts, num_classes=3):
    """Embed a reduced confusion matrix into a full `num_classes` x `num_classes` one.

    A class counts as "present" when its entry in `true_counts` is greater than zero.
    The rows/columns of `cm` are taken to correspond to the present classes in
    ascending class order; rows and columns of absent classes are filled with zeros.

    Parameters
    ----------
    cm : array-like
        Square confusion matrix with one row/column per present class.
    true_counts : sequence of numbers
        Count per class, indexed by class id.
    num_classes : int, default 3
        Size of the returned matrix.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (num_classes, num_classes).

    Raises
    ------
    ValueError
        If a present class id does not fit in `num_classes`, or if the shape of `cm`
        does not match the number of present classes.
    """
    cm = np.asarray(cm)
    present_classes = [cls for cls, count in enumerate(true_counts) if count > 0]

    if any(cls >= num_classes for cls in present_classes):
        raise ValueError(
            f"true_counts marks classes {present_classes} as present, "
            f"which does not fit in num_classes={num_classes}."
        )

    # Without this check a too-small matrix raises an IndexError and a too-large one
    # is silently truncated.
    expected_shape = (len(present_classes), len(present_classes))
    if cm.shape != expected_shape:
        raise ValueError(
            f"Confusion matrix has shape {cm.shape}, but {len(present_classes)} "
            f"present classes require shape {expected_shape}."
        )

    full_cm = np.zeros((num_classes, num_classes), dtype=int)
    if present_classes:
        # Scatter the reduced matrix onto the rows/columns of the present classes
        full_cm[np.ix_(present_classes, present_classes)] = cm

    return full_cm


def fix_confusion_matrices(df, df_counts):
    """Return a copy of `df` in which every confusion matrix is a 3x3 NumPy array.

    Matrices that are already 3x3 are kept (converted to arrays). Any other matrix
    is expanded with `expand_confusion_matrix`, using the class counts that
    `df_counts` holds for the row's fold and symptom.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'Fold', 'Symptom' and 'Confusion matrix'.
    df_counts : pandas.DataFrame
        Needs a 'Fold' column and one column per symptom, with three rows per fold
        (one per class). If a 'Class' column exists it is used to order those rows.

    Returns
    -------
    pandas.DataFrame
        New frame with 'Fold' cast to int and 'Confusion matrix' holding arrays.
        Neither input frame is modified.

    Raises
    ------
    ValueError
        If a matrix is not 2-dimensional, if the fold is missing from `df_counts`,
        or if the fold does not have exactly three class rows.
    KeyError
        If `df_counts` has no column for the row's symptom.
    """
    # Work on copies: the caller's frame may be a filtered slice of another frame,
    # and mutating it in place would also make re-running the notebook cell unsafe.
    df = df.copy()
    df_counts = df_counts.copy()

    df['Fold'] = df['Fold'].astype(int)
    df_counts['Fold'] = df_counts['Fold'].astype(int)

    # Object array so that each cell can hold one whole matrix
    fixed_matrices = np.empty(len(df), dtype=object)

    for pos, (_, row) in enumerate(df.iterrows()):
        cm = np.array(row['Confusion matrix'])
        symptom = row['Symptom']
        fold = row['Fold']

        if cm.shape == (3, 3):
            fixed_matrices[pos] = cm
            continue

        if cm.ndim != 2:
            raise ValueError(
                f"Row for symptom={symptom!r}, fold={fold} has no usable confusion "
                f"matrix (array shape {cm.shape})."
            )

        # Get true class counts from df_counts
        counts_row = df_counts[df_counts['Fold'] == fold]
        if counts_row.empty:
            raise ValueError(f"No matching fold={fold} found in df_counts.")

        if symptom not in counts_row.columns:
            raise KeyError(f"No column for symptom {symptom!r} in df_counts.")

        # The position in true_counts is used as the class id, so order rows by class
        if 'Class' in counts_row.columns:
            counts_row = counts_row.sort_values('Class')

        true_counts = counts_row[symptom].tolist()
        if len(true_counts) != 3:
            raise ValueError(f"Expected 3 class rows for fold={fold} in df_counts.")

        fixed_matrices[pos] = expand_confusion_matrix(cm, true_counts)

    # Assigned positionally, so a non-default index on df is fine
    df['Confusion matrix'] = fixed_matrices
    return df

# Bring every confusion matrix of the cleaned results to 3x3, using the per-fold
# class counts in df_counts for matrices that are smaller.
df_correct = fix_confusion_matrices(df_cleaned, df_counts)
df_correct.head()

In [8]:
%%capture 

# Check how many 2x2 matrices there are in the DataFrame (should be zero).
# fix_confusion_matrices stores NumPy arrays, so the shape is compared directly:
# an `isinstance(x, list)` test can never match an ndarray and would always give 0.
int(df_correct['Confusion matrix'].apply(lambda cm: np.shape(cm) == (2, 2)).sum())

Compute the (micro/macro/per-class) averages of recall, precision, and F1-score

In [9]:
%%capture 

def compute_metrics(conf_matrix):
    """Derive precision, recall and F1 from a 3x3 confusion matrix.

    The matrix is unrolled into paired label lists (row index = true class,
    column index = predicted class) which are then scored with scikit-learn.

    Returns a dict with micro and macro averages plus per-class scores, where
    class 1 is reported as 'present', class 0 as 'absent' and class 2 as
    'not reported'. Undefined scores are returned as 0 (`zero_division=0`).
    """
    conf_matrix = np.array(conf_matrix)
    num_classes = 3
    labels = [0, 1, 2]

    # Rebuild one (true, predicted) pair per counted case
    y_true, y_pred = [], []
    for true_cls in range(num_classes):
        for pred_cls in range(num_classes):
            n_cases = conf_matrix[true_cls, pred_cls]
            y_true.extend([true_cls] * n_cases)
            y_pred.extend([pred_cls] * n_cases)

    shared_kwargs = dict(labels=labels, zero_division=0)
    scorers = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('F1', f1_score),
    )

    results = {}

    # Global metrics: 'Micro-precision', ..., 'Macro-F1'
    for average in ('micro', 'macro'):
        for score_name, scorer in scorers:
            key = f'{average.capitalize()}-{score_name}'
            results[key] = scorer(y_true, y_pred, average=average, **shared_kwargs)

    # Per-class metrics, read from the report (its class keys are the labels as strings)
    report = classification_report(
        y_true, y_pred, output_dict=True, **shared_kwargs
    )
    class_names = (('1', 'present'), ('0', 'absent'), ('2', 'not reported'))
    report_fields = (('Precision', 'precision'), ('Recall', 'recall'), ('F1', 'f1-score'))
    for class_key, class_name in class_names:
        for title, field in report_fields:
            results[f"{title} '{class_name}'"] = report[class_key][field]

    return results

# One dict of metrics per row, then spread the dict keys into columns
metric_dicts = df_correct['Confusion matrix'].apply(compute_metrics)
metrics_df = metric_dicts.apply(pd.Series)

# Put the metric columns next to the original result columns
frames_side_by_side = [df_correct, metrics_df]
df_final = pd.concat(frames_side_by_side, axis=1)

df_final.head()

In [10]:
%%capture 

# Columns identifying one modelling setting, and the metric columns to average
group_keys = ['Classifier', 'Model', 'Symptom', 'Number of samples']
metric_columns = [
    'Micro-precision', 'Micro-recall', 'Micro-F1',
    'Macro-precision', 'Macro-recall', 'Macro-F1',
    "Precision 'present'", "Recall 'present'", "F1 'present'",
    "Precision 'absent'", "Recall 'absent'", "F1 'absent'",
    "Precision 'not reported'", "Recall 'not reported'", "F1 'not reported'",
]

# Compute the averages of the metrics across folds
averaged_df = (
    df_final
    .groupby(group_keys)[metric_columns]
    .mean()
    .reset_index()
)

# Stored as text from here on.
# NOTE(review): presumably so the sample size is not coloured by the gradient below - confirm.
averaged_df['Number of samples'] = averaged_df['Number of samples'].astype(str)

# Soft red -> yellow -> green colour scale, applied over the whole table (axis=None)
soft_rgy = LinearSegmentedColormap.from_list('soft_rgy', ['#ffcccc', '#fff2b2', '#ccffcc'])
averaged_df_styled = averaged_df.style.background_gradient(cmap=soft_rgy, axis=None)

averaged_df_styled.to_excel("Table_all.xlsx", engine='openpyxl')

averaged_df_styled

Create a subset with the results for the largest training sample size (1600)

In [11]:
%%capture 

# Work on a copy so averaged_df keeps all sample sizes
rows_1600 = averaged_df.copy()
rows_1600['Number of samples'] = rows_1600['Number of samples'].astype(str)

# Keep only the rows of the 1600-sample setting (compared as text)
rows_1600 = rows_1600[rows_1600['Number of samples'] == '1600']

# Same soft red -> yellow -> green scale, applied over the whole table (axis=None)
soft_rgy_1600 = LinearSegmentedColormap.from_list('soft_rgy', ['#ffcccc', '#fff2b2', '#ccffcc'])
averaged_df_copy = rows_1600.style.background_gradient(cmap=soft_rgy_1600, axis=None)

averaged_df_copy.to_excel("Table_1600.xlsx", engine='openpyxl')
averaged_df_copy

Derive the min, max, mean and median values of the metrics reported in the results section

In [12]:
%%capture 

# Make the sample size numeric again so it can be filtered by value
averaged_df['Number of samples'] = pd.to_numeric(averaged_df['Number of samples'], errors='coerce')

# All sample sizes are summarised here: the restriction to specific sizes is disabled
#filtered_df = averaged_df[averaged_df['Number of samples'].isin([1600, 3])]
filtered_df = averaged_df

summary_columns = [
    'Micro-precision', 'Micro-recall', 'Micro-F1',
    'Macro-precision', 'Macro-recall', 'Macro-F1',
    "Precision 'present'", "Recall 'present'", "F1 'present'",
    "Precision 'absent'", "Recall 'absent'", "F1 'absent'",
    "Precision 'not reported'", "Recall 'not reported'", "F1 'not reported'",
]

# Spread of every metric per classifier/model combination
grouped_results = (
    filtered_df
    .groupby(['Classifier', 'Model'])[summary_columns]
    .agg(['min', 'max', 'mean', 'median'])
    .reset_index()
)

grouped_results.round(2).T

In [13]:
%%capture 

# Keep only the rows whose sample size is 200 or 1.
# NOTE(review): 200 is the smallest size on the figure axes below, not the largest -
# confirm which subset is intended here.
size_filter = averaged_df['Number of samples'].isin([200, 1])
filtered_df = averaged_df[size_filter]

subset_columns = [
    'Micro-precision', 'Micro-recall', 'Micro-F1',
    'Macro-precision', 'Macro-recall', 'Macro-F1',
    "Precision 'present'", "Recall 'present'", "F1 'present'",
    "Precision 'absent'", "Recall 'absent'", "F1 'absent'",
    "Precision 'not reported'", "Recall 'not reported'", "F1 'not reported'",
]

# Spread of every metric per classifier/model combination within that subset
grouped_results = (
    filtered_df
    .groupby(['Classifier', 'Model'])[subset_columns]
    .agg(['min', 'max', 'mean'])
    .reset_index()
)

grouped_results.round(2).T

Create a figure for the results of direct classification

In [14]:
%%capture 

# The sample size must be numeric for sorting and plotting
averaged_df['Number of samples'] = pd.to_numeric(averaged_df['Number of samples'], errors='coerce')

# Select rows where Classifier is 'Direct'
direct_classifier_df = averaged_df[averaged_df['Classifier'] == 'Direct']

# Convert wide to long format: one row per (model, symptom, sample size, metric)
df_long = direct_classifier_df.melt(
    id_vars=['Model', 'Symptom', 'Number of samples'],
    value_vars=['Micro-precision',
            'Micro-recall',
            'Micro-F1',
            'Macro-precision',
            'Macro-recall',
            'Macro-F1',
            "Precision 'present'",
            "Recall 'present'",
            "F1 'present'",
            "Precision 'absent'",
            "Recall 'absent'",
            "F1 'absent'",
            "Precision 'not reported'",
            "Recall 'not reported'",
            "F1 'not reported'"],
    var_name='metric',
    value_name='value'
)

# Create the mirrored value column for the mirrored bar plot: RobBERT keeps its value,
# every other model gets the negated value. Done column-wise instead of with a
# row-wise apply, which is slower and ill-defined on an empty frame.
is_robbert = df_long['Model'] == 'RobBERT'
df_long['mirrored_value'] = df_long['value'].where(is_robbert, -df_long['value'])
df_long

In [15]:
%%capture 

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AnchoredOffsetbox, VPacker, HPacker, TextArea, DrawingArea

# Metric families shown in this figure, with the macro-average column plotted for each
metrics = ['Recall', 'Precision', 'F1']
subcategories_per_metric = {
    'Recall': ['Macro-recall'],
    'Precision': ['Macro-precision'],
    'F1': ['Macro-F1']
}

# In this figure both models use the same neutral grey for every macro metric,
# so the two palettes hold identical values (they stay separate objects).
macro_grey = '#bdbdbd'

subcategory_colors_medroberta = {
    family: {subcat: macro_grey for subcat in subcategories_per_metric[family]}
    for family in metrics
}

subcategory_colors_roberta = {
    family: dict(palette)
    for family, palette in subcategory_colors_medroberta.items()
}

# Grid setup: 3 rows (one per metric: Recall, Precision, F1) by one column per symptom.
# All panels share both axes, so tick positions set on one panel apply to the others.
num_symptoms = len(df_long['Symptom'].unique())
fig, axes = plt.subplots(3, num_symptoms, figsize=(40, 20), sharex=True, sharey=True)  # metrics x symptoms

bar_width = 0.45
ytick_spacing = 0.75
font_params = {'fontsize': 18}  # base font size; other sizes are derived from it

# Loop over metrics. NOTE: `col_idx` is used as the FIRST axes index, i.e. it selects
# the figure row; `row_idx` below selects the figure column. The names are swapped
# relative to the layout on screen.
for col_idx, metric in enumerate(metrics):
    subcats = subcategories_per_metric[metric]

    # Loop over symptoms (one figure column each)
    for row_idx, symptom in enumerate(df_long['Symptom'].unique()):
        ax = axes[col_idx, row_idx] if axes.ndim == 2 else axes[row_idx]
        symptom_data = df_long[df_long['Symptom'] == symptom]
        # One y slot per distinct sample size of this symptom
        ytick_count = len(symptom_data['Number of samples'].unique())
        y_positions = np.arange(ytick_count) * ytick_spacing

        # Loop over subcategories
        for k, subcat in enumerate(subcats):
            for model in ['MedRoBERTa.nl', 'RobBERT']:
                subset = symptom_data[
                    (symptom_data['metric'] == subcat) &
                    (symptom_data['Model'] == model)
                ]
                # Ascending sample sizes map onto increasing y positions.
                # The sizes come from `subset` itself, so the `.empty` fallback below
                # never fires; `values` only lines up with `y_positions` when this
                # model has every sample size of the symptom.
                training_sizes = sorted(subset['Number of samples'].unique())
                values = [
                    subset[subset['Number of samples'] == ts]['mirrored_value'].values[0]
                    if not subset[subset['Number of samples'] == ts].empty else np.nan
                    for ts in training_sizes
                ]
                y_pos = y_positions + k * bar_width - bar_width  # offset bars per subcategory index k

                # Pick the model's palette; '#cccccc' is used if the subcategory has no entry
                if model == 'MedRoBERTa.nl':
                    color = subcategory_colors_medroberta[metric].get(subcat, '#cccccc')
                else:
                    color = subcategory_colors_roberta[metric].get(subcat, '#cccccc')

                # Both models share the same y positions; `mirrored_value` is negative for
                # MedRoBERTa.nl, so its bars extend left of zero and RobBERT's extend right
                ax.barh(
                    y=y_pos,
                    width=values,
                    height=bar_width,
                    color=color,
                    edgecolor='none',
                    zorder=2  # Bars on top
                )

        # Tick positions and axis title are set on the first figure column only.
        # The tick labels are hard-coded, so this assumes eight sample sizes (200..1600).
        if row_idx == 0:
            ax.set_yticks(y_positions - 0.45)
            ax.set_yticklabels([200, 400, 600, 800, 1000, 1200, 1400, 1600], fontsize=font_params['fontsize'])
            ax.set_ylabel("Number of samples", fontsize=font_params['fontsize'])
        else:
            ax.set_yticklabels([200, 400, 600, 800, 1000, 1200, 1400, 1600])

        ax.set_xlim(-1, 1)
        ax.axvline(0, color='black', linewidth=0.5, zorder=3)  # Ensure vertical line is on top

        # Grid is drawn with a lower zorder than the bars so it stays behind them
        ax.grid(True, axis='x', linestyle='--', linewidth=0.5, color='black', zorder=1)  # Grid behind the bars

        # Clears the x-tick labels for every column except the last one.
        # NOTE(review): the unconditional set_xticklabels call below sets labels on this
        # same axes again, so this clearing looks superseded - confirm the intent.
        if row_idx != num_symptoms - 1:
            ax.set_xticklabels([])

        # Custom x-tick labels: the axis runs from -1 to 1, but labels show absolute values
        xticks = np.linspace(-1, 1, 9)
        xticklabels = ['1', '0.75', '0.5', '0.25', '0', '0.25', '0.50', '0.75', '1']
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, fontsize=font_params['fontsize']-5)

        # Caption just above each panel, in axes coordinates
        ax.text(0.442, 1.005, 'MedRoBERTa.nl vs RobBERT', transform=ax.transAxes,
                fontsize=font_params['fontsize']-1, ha='center', va='bottom')

# Add "Value" as x-axis title to every panel of the bottom row
for col_idx in range(num_symptoms):
    ax = axes[len(metrics) - 1, col_idx]
    ax.set_xlabel("Value", fontsize=font_params['fontsize'])

# Add column headers (symptom names) above the grid, in figure coordinates
for idx, symptom in enumerate(df_long['Symptom'].unique()):
    fig.text(
        x=(0.065 + idx * 0.102),  # Adjust these to match your layout spacing
        y=0.98,
        s=symptom,
        ha='center',
        va='bottom',
        fontsize=font_params['fontsize'] + 2,
        fontweight='bold'
    )

# Add row labels (Recall, Precision, F1) from top to bottom
row_labels = ['Recall', 'Precision', 'F1-score']
for idx, label in enumerate(row_labels):
    fig.text(
        x=-0.01,  # slightly left of the figure canvas
        y=0.81 - idx * 0.31,  # y-position tuned per row
        s=label,
        ha='left',
        va='center',
        rotation='vertical',
        fontsize=font_params['fontsize'] + 2,
        fontweight='bold'
    )

fig.suptitle("Direct classification", fontsize=20, fontweight='bold', y=1.02, x=0.01, ha='left')

plt.tight_layout(rect=[0, 0, 0.93, 1])
# The title (y=1.02) and the row labels (x=-0.01) lie outside the figure canvas;
# bbox_inches='tight' grows the saved area to include them, as the per-class figure does.
plt.savefig('Direct_classifiers_macro.png', format='png', bbox_inches='tight')
plt.show()


In [16]:
%%capture 

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AnchoredOffsetbox, VPacker, HPacker, TextArea, DrawingArea

# One colour per class, shared by both models and by all three metric families.
# Every key must match an entry of `subcategories_per_metric` exactly: a key that does
# not match is never found and the plot silently uses its fallback colour instead.
subcategory_colors_models = {
    'Recall': {
        "Recall 'present'": '#66c2a5',        # Teal blue
        "Recall 'absent'": '#fc8d62',         # Warm amber
        "Recall 'not reported'": '#bdbdbd'    # Neutral grey
    },
    'Precision': {
        "Precision 'present'": '#66c2a5',
        "Precision 'absent'": '#fc8d62',
        "Precision 'not reported'": '#bdbdbd'
    },
    'F1': {
        "F1 'present'": '#66c2a5',
        "F1 'absent'": '#fc8d62',
        "F1 'not reported'": '#bdbdbd'
    }
}

# Metric families, and the per-class columns drawn for each (in drawing order)
metrics = ['Recall', 'Precision', 'F1']
subcategories_per_metric = {
    'Recall': ["Recall 'not reported'", "Recall 'absent'", "Recall 'present'"],
    'Precision': ["Precision 'not reported'", "Precision 'absent'", "Precision 'present'"],
    'F1': ["F1 'not reported'", "F1 'absent'", "F1 'present'"]
}

# Grid setup: 3 rows (one per metric family) by one column per symptom.
# All panels share both axes, so tick positions set on one panel apply to the others.
num_symptoms = len(df_long['Symptom'].unique())
fig, axes = plt.subplots(3, num_symptoms, figsize=(40, 20), sharex=True, sharey=True)

bar_width = 0.15
ytick_spacing = 0.75
font_params = {'fontsize': 18}  # base font size; other sizes are derived from it

# NOTE: `col_idx` is used as the FIRST axes index, i.e. it selects the figure row
# (metric); `row_idx` selects the figure column (symptom).
for col_idx, metric in enumerate(metrics):
    subcats = subcategories_per_metric[metric]

    for row_idx, symptom in enumerate(df_long['Symptom'].unique()):
        ax = axes[col_idx, row_idx] if axes.ndim == 2 else axes[row_idx]
        symptom_data = df_long[df_long['Symptom'] == symptom]
        # One y slot per distinct sample size of this symptom
        ytick_count = len(symptom_data['Number of samples'].unique())
        y_positions = np.arange(ytick_count) * ytick_spacing

        for k, subcat in enumerate(subcats):
            for model in ['MedRoBERTa.nl', 'RobBERT']:
                subset = symptom_data[
                    (symptom_data['metric'] == subcat) &
                    (symptom_data['Model'] == model)
                ]
                # Ascending sample sizes map onto increasing y positions.
                # The sizes come from `subset` itself, so the `.empty` fallback below
                # never fires; `values` only lines up with `y_positions` when this
                # model has every sample size of the symptom.
                training_sizes = sorted(subset['Number of samples'].unique())
                values = [
                    subset[subset['Number of samples'] == ts]['mirrored_value'].values[0]
                    if not subset[subset['Number of samples'] == ts].empty else np.nan
                    for ts in training_sizes
                ]
                # With three subcategories (k = 0, 1, 2) the bars sit one bar width
                # below, on, and one bar width above each slot position
                y_pos = y_positions + k * bar_width - bar_width

                # Same colour for both models; '#cccccc' is used if the subcategory has no entry
                color = subcategory_colors_models[metric].get(subcat, '#cccccc')

                # Both models share the same y positions; `mirrored_value` is negative for
                # MedRoBERTa.nl, so its bars extend left of zero and RobBERT's extend right
                ax.barh(
                    y=y_pos,
                    width=values,
                    height=bar_width,
                    color=color,
                    edgecolor='none',
                    zorder=2  # Bars on top
                )

        # Tick positions and axis title are set on the first figure column only.
        # The tick labels are hard-coded, so this assumes eight sample sizes (200..1600).
        if row_idx == 0:
            ax.set_yticks(y_positions)
            ax.set_yticklabels([200, 400, 600, 800, 1000, 1200, 1400, 1600], fontsize=font_params['fontsize'])
            ax.set_ylabel("Number of samples", fontsize=font_params['fontsize'])
        else:
            ax.set_yticklabels([200, 400, 600, 800, 1000, 1200, 1400, 1600])

        ax.set_xlim(-1, 1)
        ax.axvline(0, color='black', linewidth=0.5, zorder=3)  # Ensure vertical line is on top

        # Grid is drawn with a lower zorder than the bars so it stays behind them
        ax.grid(True, axis='x', linestyle='--', linewidth=0.5, color='black', zorder=1)  # Grid behind the bars

        # Clears the x-tick labels for every column except the last one.
        # NOTE(review): the unconditional set_xticklabels call below sets labels on this
        # same axes again, so this clearing looks superseded - confirm the intent.
        if row_idx != num_symptoms - 1:
            ax.set_xticklabels([])

        # Custom x-tick labels: the axis runs from -1 to 1, but labels show absolute values
        xticks = np.linspace(-1, 1, 9)
        xticklabels = ['1', '0.75', '0.5', '0.25', '0', '0.25', '0.50', '0.75', '1']
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, fontsize=font_params['fontsize']-5)

        # Caption just above each panel, in axes coordinates
        ax.text(0.442, 1.005, 'MedRoBERTa.nl vs RobBERT', transform=ax.transAxes,
                fontsize=font_params['fontsize']-1, ha='center', va='bottom')

label_size = font_params['fontsize']
header_style = dict(fontsize=label_size + 2, fontweight='bold')

# x-axis title on every panel of the bottom row
bottom_row = len(metrics) - 1
for col_idx in range(num_symptoms):
    axes[bottom_row, col_idx].set_xlabel("Value", fontsize=label_size)

# Legend (subcategories only, no model distinction)
legend_submetrics = ["'present'", "'absent'", "'not reported'"]
label_mapping = {
    "'present'": "Present",
    "'absent'": "Absent",
    "'not reported'": "Not reported"
}
color_keys = ["Recall 'present'", "Recall 'absent'", "Recall 'not reported'"]


def _legend_entry(text, face_color):
    """Build one legend row: a 20x10 colour swatch followed by its caption."""
    swatch = DrawingArea(20, 10, 0, 0)
    swatch.add_artist(Rectangle((0, 0), 20, 10, fc=face_color, edgecolor='none'))
    caption = TextArea(text, textprops=dict(fontsize=label_size, ha='left'))
    return HPacker(children=[swatch, caption], align="left", pad=0, sep=6)


# The swatch colours are taken from the 'Recall' palette
legend_items = [
    _legend_entry(label_mapping[submetric_label], subcategory_colors_models['Recall'][color_key])
    for submetric_label, color_key in zip(legend_submetrics, color_keys)
]

final_legend = VPacker(children=legend_items, align="left", pad=0, sep=10)

# Anchor the legend in figure coordinates, in the strip right of the panels
anchored_box = AnchoredOffsetbox(
    loc='center left',
    child=final_legend,
    pad=0.,
    frameon=False,
    bbox_to_anchor=(0.94, 0.5),
    bbox_transform=fig.transFigure,
    borderpad=0.
)
fig.add_artist(anchored_box)

# Add column headers (symptom names)
for idx, symptom in enumerate(df_long['Symptom'].unique()):
    fig.text(x=(0.065 + idx * 0.102), y=0.98, s=symptom,
             ha='center', va='bottom', **header_style)

# Row labels (Recall, Precision, F1), top to bottom along the left edge
row_labels = ['Recall', 'Precision', 'F1-score']
for idx, label in enumerate(row_labels):
    fig.text(x=-0.01, y=0.81 - idx * 0.31, s=label,
             ha='left', va='center', rotation='vertical', **header_style)

fig.suptitle("Direct classification", fontsize=20, fontweight='bold', y=1.02, x=0.01, ha='left')

# Leave the right 7% of the figure free for the legend
plt.tight_layout(rect=[0, 0, 0.93, 1])
plt.savefig('Direct_classifiers_perclass.png', format='png', bbox_inches='tight')
plt.show()


In [17]:
%%capture 

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AnchoredOffsetbox, VPacker, HPacker, TextArea, DrawingArea

# Colour palettes for the mirrored bar chart: MedRoBERTa.nl in blues, RobBERT in oranges.
# Within each metric the shades run from dark (macro average) to light ('not reported').
# The keys have to match the entries of subcategories_per_metric exactly: colours are
# looked up with .get(subcat, '#cccccc'), so a misspelled key silently yields grey bars.
subcategory_colors_medroberta = {
    'Recall': {
        'Macro-recall': '#2B4C7E',         # dark blue
        "Recall 'present'": '#5178A6',     # medium blue
        "Recall 'absent'": '#7FA2C9',      # light-medium blue
        "Recall 'not reported'": '#A9C3DE' # light blue
    },
    'Precision': {
        'Macro-precision': '#2B4C7E',
        "Precision 'present'": '#5178A6',
        "Precision 'absent'": '#7FA2C9',
        "Precision 'not reported'": '#A9C3DE'
    },
    'F1': {
        'Macro-F1': '#2B4C7E',
        "F1 'present'": '#5178A6',
        "F1 'absent'": '#7FA2C9',
        "F1 'not reported'": '#A9C3DE'
    }
}

subcategory_colors_roberta = {
    'Recall': {
        'Macro-recall': '#B85716',         # dark orange
        "Recall 'present'": '#DA7A34',     # medium orange
        "Recall 'absent'": '#F1A469',      # light-medium orange
        "Recall 'not reported'": '#F6C7A3' # light orange
    },
    'Precision': {
        'Macro-precision': '#B85716',
        "Precision 'present'": '#DA7A34',
        "Precision 'absent'": '#F1A469',
        # Key previously read 'not_reported' (underscore) and therefore never matched.
        "Precision 'not reported'": '#F6C7A3'
    },
    'F1': {
        'Macro-F1': '#B85716',
        "F1 'present'": '#DA7A34',
        "F1 'absent'": '#F1A469',
        "F1 'not reported'": '#F6C7A3'
    }
}

# Metric families (one row of axes each) and, per family, the bars drawn in each
# group, listed from the lowest bar to the highest.
metrics = ['Recall', 'Precision', 'F1']
subcategories_per_metric = {
    'Recall': ["Recall 'not reported'", "Recall 'absent'", "Recall 'present'", 'Macro-recall'],
    'Precision': ["Precision 'not reported'", "Precision 'absent'", "Precision 'present'", 'Macro-precision'],
    'F1': ["F1 'not reported'", "F1 'absent'", "F1 'present'", 'Macro-F1']
}

# Fail fast if any plotted subcategory lacks a colour in either palette.
for _palette_name, _palette in (('MedRoBERTa.nl', subcategory_colors_medroberta),
                                ('RobBERT', subcategory_colors_roberta)):
    for _metric_name, _subcats in subcategories_per_metric.items():
        _missing = [s for s in _subcats if s not in _palette[_metric_name]]
        assert not _missing, f"{_palette_name} palette lacks colours for {_missing}"

# Grid setup: one row of axes per metric (Recall, Precision, F1) and one column per symptom.
symptoms = df_long['Symptom'].unique()
num_symptoms = len(symptoms)
# squeeze=False keeps `axes` two-dimensional even for a single symptom, so the
# axes[row, column] indexing used below never degrades to a 1-D lookup.
fig, axes = plt.subplots(
    len(metrics), num_symptoms, figsize=(40, 20), sharex=True, sharey=True, squeeze=False
)

bar_width = 0.15
ytick_spacing = 1
font_params = {'fontsize': 18}

# The x-axis is mirrored around 0, so the tick labels show magnitudes on both sides.
xticks = np.linspace(-1, 1, 9)
xticklabels = ['1', '0.75', '0.5', '0.25', '0', '0.25', '0.50', '0.75', '1']

# Rows: metrics
for metric_idx, metric in enumerate(metrics):
    subcats = subcategories_per_metric[metric]

    # Columns: symptoms
    for symptom_idx, symptom in enumerate(symptoms):
        ax = axes[metric_idx, symptom_idx]
        symptom_data = df_long[df_long['Symptom'] == symptom]
        # One bar group per sample size present for this symptom, in ascending order.
        sample_sizes = sorted(symptom_data['Number of samples'].unique())
        ytick_count = len(sample_sizes)
        y_positions = np.arange(ytick_count) * ytick_spacing

        for k, subcat in enumerate(subcats):
            # The k-th subcategory is drawn k bar-widths up, starting one width
            # below the group position.
            y_pos = y_positions + k * bar_width - bar_width

            for model in ['MedRoBERTa.nl', 'RobBERT']:
                subset = symptom_data[
                    (symptom_data['metric'] == subcat) &
                    (symptom_data['Model'] == model)
                ]
                # Align the values to the symptom's full list of sample sizes so that
                # `values` always has the same length as `y_pos`; a size with no row
                # for this model/subcategory becomes NaN (no bar) instead of making
                # barh fail on mismatched lengths. The first row wins on duplicates.
                values = (
                    subset.drop_duplicates('Number of samples')
                    .set_index('Number of samples')['mirrored_value']
                    .reindex(sample_sizes)
                    .to_numpy()
                )

                # Model-specific palette; unknown subcategories fall back to grey.
                if model == 'MedRoBERTa.nl':
                    color = subcategory_colors_medroberta[metric].get(subcat, '#cccccc')
                else:
                    color = subcategory_colors_roberta[metric].get(subcat, '#cccccc')

                ax.barh(
                    y=y_pos,
                    width=values,
                    height=bar_width,
                    color=color,
                    edgecolor='none'
                )

        # Y-axis ticks, labels and axis title are configured on the first column;
        # the y-axis is shared across the grid (sharey=True).
        if symptom_idx == 0:
            # Offset of one bar width (0.15) relative to the group position.
            ax.set_yticks(y_positions + 0.15)
            ax.set_yticklabels([200, 400, 600, 800, 1000, 1200, 1400, 1600], fontsize=font_params['fontsize'])
            ax.set_ylabel("Number of samples", fontsize=font_params['fontsize'])
        else:
            ax.set_yticklabels([200, 400, 600, 800, 1000, 1200, 1400, 1600])

        ax.set_xlim(-1, 1)
        ax.axvline(0, color='black', linewidth=0.5)
        ax.grid(True, axis='x', linestyle='--', linewidth=0.5, color='black')

        # Custom x-tick labels (the former set_xticklabels([]) call was always
        # overwritten by these two lines and has been dropped).
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, fontsize=font_params['fontsize']-5)

        ax.text(0.442, 1.005, 'MedRoBERTa.nl vs RobBERT', transform=ax.transAxes,
                fontsize=font_params['fontsize']-1, ha='center', va='bottom')

# Label the x-axis of every panel in the bottom row.
for ax in axes[-1, :]:
    ax.set_xlabel("Value", fontsize=font_params['fontsize'])

# Legend: one titled group of colour swatches per model, listed in the order
# Macro, Present, Absent, Not reported.
legend_submetrics = ['Macro', "'present'", "'absent'", "'not reported'"]
label_mapping = {
    'Macro': 'Macro',
    "'present'": "Present",
    "'absent'": "Absent",
    "'not reported'": "Not reported"
}

# The Recall entries of each palette are the colour source for the legend swatches.
color_keys_med = ['Macro-recall', "Recall 'present'", "Recall 'absent'", "Recall 'not reported'"]
color_keys_rob = list(color_keys_med)

legend_fontsize = font_params['fontsize']

# Each group starts with the model name as its heading.
legend_items_medroberta = [TextArea("MedRoBERTa.nl", textprops=dict(fontsize=legend_fontsize, ha='center'))]
legend_items_roberta = [TextArea("RobBERT", textprops=dict(fontsize=legend_fontsize, ha='center'))]

legend_groups = (
    (legend_items_medroberta, subcategory_colors_medroberta['Recall'], color_keys_med),
    (legend_items_roberta, subcategory_colors_roberta['Recall'], color_keys_rob),
)
for group_items, recall_palette, palette_keys in legend_groups:
    for submetric_label, palette_key in zip(legend_submetrics, palette_keys):
        # 20x10 colour swatch followed by its text label.
        swatch = DrawingArea(20, 10, 0, 0)
        swatch.add_artist(Rectangle((0, 0), 20, 10, fc=recall_palette[palette_key], edgecolor='none'))
        caption = TextArea(label_mapping[submetric_label], textprops=dict(fontsize=legend_fontsize, ha='left'))
        group_items.append(HPacker(children=[swatch, caption], align="left", pad=0, sep=6))

# Stack both groups vertically, separated by an empty text line.
group_spacer = TextArea("", textprops=dict(fontsize=legend_fontsize))
final_legend = VPacker(
    children=[*legend_items_medroberta, group_spacer, *legend_items_roberta],
    align="left", pad=0, sep=18
)

# Anchor the legend to the right-hand side of the figure, vertically centred.
anchored_box = AnchoredOffsetbox(
    loc='center left',
    child=final_legend,
    pad=0.,
    frameon=False,
    bbox_to_anchor=(0.94, 0.5),
    bbox_transform=fig.transFigure,
    borderpad=0.
)
fig.add_artist(anchored_box)


# Column headers: one symptom name above each column of axes.
# The x positions are hand-tuned figure fractions; adjust them to match the layout spacing.
for idx, symptom in enumerate(df_long['Symptom'].unique()):
    fig.text(
        x=(0.065 + idx * 0.102),
        y=0.98,
        s=symptom,
        ha='center',
        va='bottom',
        fontsize=font_params['fontsize'] + 2,
        fontweight='bold'
    )

# Row labels (Recall, Precision, F1) from top to bottom, written vertically at the
# left edge of the figure.
row_labels = ['Recall', 'Precision', 'F1-score']
for idx, label in enumerate(row_labels):
    fig.text(
        x=-0.01,  # slightly left of the figure area
        y=0.81 - idx * 0.31,  # y-position tuned per row
        s=label,
        ha='left',
        va='center',
        rotation='vertical',
        fontsize=font_params['fontsize'] + 2,
        fontweight='bold'
    )

fig.suptitle("Direct classification", fontsize=20, fontweight='bold', y=1.02, x=0.01, ha='left')

# Leave the right 7% of the figure free for the legend.
plt.tight_layout(rect=[0, 0, 0.93, 1])
# bbox_inches='tight' expands the saved canvas to include the row labels (x=-0.01)
# and the title (y=1.02), which are placed outside the 0-1 figure area and would
# otherwise be cut off in the PNG.
plt.savefig('Direct_classifiers.png', format='png', bbox_inches='tight')
plt.show()


Create a figure for the results of prompt-based classification

In [18]:
%%capture 

# Select rows where Classifier is 'Prompt-based'
prompt_based_df = averaged_df[averaged_df['Classifier'] == 'Prompt-based']

# Metric columns that are unpivoted into (metric, value) rows.
metric_columns = [
    'Micro-precision',
    'Micro-recall',
    'Micro-F1',
    'Macro-precision',
    'Macro-recall',
    'Macro-F1',
    "Precision 'present'",
    "Recall 'present'",
    "F1 'present'",
    "Precision 'absent'",
    "Recall 'absent'",
    "F1 'absent'",
    "Precision 'not reported'",
    "Recall 'not reported'",
    "F1 'not reported'",
]

# Convert wide to long format: one row per (Model, Symptom, Number of samples, metric).
df_long = prompt_based_df.melt(
    id_vars=['Model', 'Symptom', 'Number of samples'],
    value_vars=metric_columns,
    var_name='metric',
    value_name='value'
)

# Mirrored value for the back-to-back bar plot: RobBERT keeps its sign (bars to the
# right), every other model is negated (bars to the left). Computed column-wise
# instead of with a row-wise apply, which also keeps working on an empty selection.
df_long['mirrored_value'] = df_long['value'].where(
    df_long['Model'] == 'RobBERT', -df_long['value']
)
df_long

In [19]:
%%capture 

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AnchoredOffsetbox, VPacker, HPacker, TextArea, DrawingArea

# This figure draws only the macro-averaged score of each metric, and both models
# use the same neutral grey for it.
macro_grey = '#bdbdbd'
macro_key_per_metric = {
    'Recall': 'Macro-recall',
    'Precision': 'Macro-precision',
    'F1': 'Macro-F1',
}

# Per-model palettes, keyed by metric family and then by subcategory name.
subcategory_colors_medroberta = {
    family: {macro_key: macro_grey} for family, macro_key in macro_key_per_metric.items()
}
subcategory_colors_roberta = {
    family: {macro_key: macro_grey} for family, macro_key in macro_key_per_metric.items()
}

# Metric families in plotting order and the single subcategory drawn for each.
metrics = list(macro_key_per_metric)
subcategories_per_metric = {
    family: [macro_key] for family, macro_key in macro_key_per_metric.items()
}

# Grid setup: one row of axes per metric (Recall, Precision, F1) and one column per symptom.
symptoms = df_long['Symptom'].unique()
num_symptoms = len(symptoms)
# squeeze=False keeps `axes` two-dimensional even for a single symptom, so the
# axes[row, column] indexing used below never degrades to a 1-D lookup.
fig, axes = plt.subplots(
    len(metrics), num_symptoms, figsize=(40, 20), sharex=True, sharey=True, squeeze=False
)

bar_width = 0.45
ytick_spacing = 0.75
font_params = {'fontsize': 18}

# The x-axis is mirrored around 0 (mirrored_value is negative for MedRoBERTa.nl and
# positive for RobBERT), so the tick labels show magnitudes on both sides.
xticks = np.linspace(-1, 1, 9)
xticklabels = ['1', '0.75', '0.5', '0.25', '0', '0.25', '0.50', '0.75', '1']

# Rows: metrics
for metric_idx, metric in enumerate(metrics):
    subcats = subcategories_per_metric[metric]

    # Columns: symptoms
    for symptom_idx, symptom in enumerate(symptoms):
        ax = axes[metric_idx, symptom_idx]
        symptom_data = df_long[df_long['Symptom'] == symptom]
        # One bar group per 'Number of samples' value present for this symptom, ascending.
        sample_sizes = sorted(symptom_data['Number of samples'].unique())
        ytick_count = len(sample_sizes)
        y_positions = np.arange(ytick_count) * ytick_spacing

        for k, subcat in enumerate(subcats):
            # The k-th subcategory is drawn k bar-widths up, starting one width
            # below the group position.
            y_pos = y_positions + k * bar_width - bar_width

            for model in ['MedRoBERTa.nl', 'RobBERT']:
                subset = symptom_data[
                    (symptom_data['metric'] == subcat) &
                    (symptom_data['Model'] == model)
                ]
                # Align the values to the symptom's full list of sample sizes so that
                # `values` always has the same length as `y_pos`; a size with no row
                # for this model/subcategory becomes NaN (no bar) instead of making
                # barh fail on mismatched lengths. The first row wins on duplicates.
                values = (
                    subset.drop_duplicates('Number of samples')
                    .set_index('Number of samples')['mirrored_value']
                    .reindex(sample_sizes)
                    .to_numpy()
                )

                # Model-specific palette; unknown subcategories fall back to grey.
                if model == 'MedRoBERTa.nl':
                    color = subcategory_colors_medroberta[metric].get(subcat, '#cccccc')
                else:
                    color = subcategory_colors_roberta[metric].get(subcat, '#cccccc')

                # zorder=2 draws the bars above the grid lines (zorder=1).
                ax.barh(
                    y=y_pos,
                    width=values,
                    height=bar_width,
                    color=color,
                    edgecolor='none',
                    zorder=2
                )

        # Y-axis ticks, labels and axis title are configured on the first column;
        # the y-axis is shared across the grid (sharey=True).
        if symptom_idx == 0:
            # With a single bar per group the bar sits one bar width (0.45) below
            # the group position, so the tick is moved there as well.
            ax.set_yticks(y_positions - 0.45)
            ax.set_yticklabels([1, 2, 3], fontsize=font_params['fontsize'])
            ax.set_ylabel("Number of samples", fontsize=font_params['fontsize'])
        else:
            ax.set_yticklabels([1, 2, 3])

        ax.set_xlim(-1, 1)
        ax.axvline(0, color='black', linewidth=0.5, zorder=3)  # centre line on top of the bars

        # Grid behind the bars.
        ax.grid(True, axis='x', linestyle='--', linewidth=0.5, color='black', zorder=1)

        # Custom x-tick labels (the former set_xticklabels([]) call was always
        # overwritten by these two lines and has been dropped).
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, fontsize=font_params['fontsize']-5)

        ax.text(0.442, 1.005, 'MedRoBERTa.nl vs RobBERT', transform=ax.transAxes,
                fontsize=font_params['fontsize']-1, ha='center', va='bottom')

# Label the x-axis of every panel in the bottom row.
for ax in axes[-1, :]:
    ax.set_xlabel("Value", fontsize=font_params['fontsize'])

# Column headers: one symptom name above each column of axes.
# The x positions are hand-tuned figure fractions; adjust them to match the layout spacing.
for idx, symptom in enumerate(df_long['Symptom'].unique()):
    fig.text(
        x=(0.065 + idx * 0.102),
        y=0.98,
        s=symptom,
        ha='center',
        va='bottom',
        fontsize=font_params['fontsize'] + 2,
        fontweight='bold'
    )

# Row labels (Recall, Precision, F1) from top to bottom, written vertically at the
# left edge of the figure.
row_labels = ['Recall', 'Precision', 'F1-score']
for idx, label in enumerate(row_labels):
    fig.text(
        x=-0.01,  # slightly left of the figure area
        y=0.81 - idx * 0.31,  # y-position tuned per row
        s=label,
        ha='left',
        va='center',
        rotation='vertical',
        fontsize=font_params['fontsize'] + 2,
        fontweight='bold'
    )

fig.suptitle("Prompt-based classification", fontsize=20, fontweight='bold', y=1.02, x=0.01, ha='left')

plt.tight_layout(rect=[0, 0, 0.93, 1])
# bbox_inches='tight' expands the saved canvas to include the row labels (x=-0.01)
# and the title (y=1.02), which are placed outside the 0-1 figure area and would
# otherwise be cut off in the PNG.
plt.savefig('Prompt-based_classifiers_macro.png', format='png', bbox_inches='tight')
plt.show()

In [20]:
%%capture 

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AnchoredOffsetbox, VPacker, HPacker, TextArea, DrawingArea

# One colour per class label, shared by both models and by all three metrics.
class_colours = {
    'present': '#66c2a5',       # Teal blue
    'absent': '#fc8d62',        # Warm amber
    'not reported': '#bdbdbd',  # Neutral grey
}

metrics = ['Recall', 'Precision', 'F1']

# Palette keyed by metric family, then by the per-class metric name, e.g. "Recall 'present'".
subcategory_colors_models = {
    family: {f"{family} '{label}'": colour for label, colour in class_colours.items()}
    for family in metrics
}

# Bars within a group, listed from the lowest bar to the highest.
subcategories_per_metric = {
    family: [f"{family} '{label}'" for label in ('not reported', 'absent', 'present')]
    for family in metrics
}

# Grid setup: one row of axes per metric (Recall, Precision, F1) and one column per symptom.
symptoms = df_long['Symptom'].unique()
num_symptoms = len(symptoms)
# squeeze=False keeps `axes` two-dimensional even for a single symptom, so the
# axes[row, column] indexing used below never degrades to a 1-D lookup.
fig, axes = plt.subplots(
    len(metrics), num_symptoms, figsize=(40, 20), sharex=True, sharey=True, squeeze=False
)

bar_width = 0.15
ytick_spacing = 0.75
font_params = {'fontsize': 18}

# The x-axis is mirrored around 0 (mirrored_value is negative for MedRoBERTa.nl and
# positive for RobBERT), so the tick labels show magnitudes on both sides.
xticks = np.linspace(-1, 1, 9)
xticklabels = ['1', '0.75', '0.5', '0.25', '0', '0.25', '0.50', '0.75', '1']

# Rows: metrics
for metric_idx, metric in enumerate(metrics):
    subcats = subcategories_per_metric[metric]

    # Columns: symptoms
    for symptom_idx, symptom in enumerate(symptoms):
        ax = axes[metric_idx, symptom_idx]
        symptom_data = df_long[df_long['Symptom'] == symptom]
        # One bar group per 'Number of samples' value present for this symptom, ascending.
        sample_sizes = sorted(symptom_data['Number of samples'].unique())
        ytick_count = len(sample_sizes)
        y_positions = np.arange(ytick_count) * ytick_spacing

        for k, subcat in enumerate(subcats):
            # Three bars per group at -1, 0 and +1 bar widths around the group position.
            y_pos = y_positions + k * bar_width - bar_width
            # Same colour for both models; unknown subcategories fall back to grey.
            color = subcategory_colors_models[metric].get(subcat, '#cccccc')

            for model in ['MedRoBERTa.nl', 'RobBERT']:
                subset = symptom_data[
                    (symptom_data['metric'] == subcat) &
                    (symptom_data['Model'] == model)
                ]
                # Align the values to the symptom's full list of sample sizes so that
                # `values` always has the same length as `y_pos`; a size with no row
                # for this model/subcategory becomes NaN (no bar) instead of making
                # barh fail on mismatched lengths. The first row wins on duplicates.
                values = (
                    subset.drop_duplicates('Number of samples')
                    .set_index('Number of samples')['mirrored_value']
                    .reindex(sample_sizes)
                    .to_numpy()
                )

                # zorder=2 draws the bars above the grid lines (zorder=1).
                ax.barh(
                    y=y_pos,
                    width=values,
                    height=bar_width,
                    color=color,
                    edgecolor='none',
                    zorder=2
                )

        # Y-axis ticks, labels and axis title are configured on the first column;
        # the y-axis is shared across the grid (sharey=True).
        if symptom_idx == 0:
            ax.set_yticks(y_positions)
            ax.set_yticklabels([1, 2, 3], fontsize=font_params['fontsize'])
            ax.set_ylabel("Number of samples", fontsize=font_params['fontsize'])
        else:
            ax.set_yticklabels([1, 2, 3])

        ax.set_xlim(-1, 1)
        ax.axvline(0, color='black', linewidth=0.5, zorder=3)  # centre line on top of the bars

        # Grid behind the bars.
        ax.grid(True, axis='x', linestyle='--', linewidth=0.5, color='black', zorder=1)

        # Custom x-tick labels (the former set_xticklabels([]) call was always
        # overwritten by these two lines and has been dropped).
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, fontsize=font_params['fontsize']-5)

        ax.text(0.442, 1.005, 'MedRoBERTa.nl vs RobBERT', transform=ax.transAxes,
                fontsize=font_params['fontsize']-1, ha='center', va='bottom')

# Label the x-axis of every panel in the bottom row.
for ax in axes[-1, :]:
    ax.set_xlabel("Value", fontsize=font_params['fontsize'])

# Legend with one entry per class label; the colours are the same for both models.
legend_submetrics = ["'present'", "'absent'", "'not reported'"]
label_mapping = {
    "'present'": "Present",
    "'absent'": "Absent",
    "'not reported'": "Not reported"
}
# The Recall palette entries are the colour source for the swatches.
color_keys = ["Recall 'present'", "Recall 'absent'", "Recall 'not reported'"]


def _legend_row(face_colour, text):
    """Build one legend row: a 20x10 colour swatch followed by its text label."""
    swatch = DrawingArea(20, 10, 0, 0)
    swatch.add_artist(Rectangle((0, 0), 20, 10, fc=face_colour, edgecolor='none'))
    caption = TextArea(text, textprops=dict(fontsize=font_params['fontsize'], ha='left'))
    return HPacker(children=[swatch, caption], align="left", pad=0, sep=6)


legend_items = [
    _legend_row(subcategory_colors_models['Recall'][colour_key], label_mapping[submetric_label])
    for submetric_label, colour_key in zip(legend_submetrics, color_keys)
]

final_legend = VPacker(children=legend_items, align="left", pad=0, sep=10)

# Anchor the legend to the right-hand side of the figure, vertically centred.
anchored_box = AnchoredOffsetbox(
    loc='center left',
    child=final_legend,
    pad=0.,
    frameon=False,
    bbox_to_anchor=(0.94, 0.5),
    bbox_transform=fig.transFigure,
    borderpad=0.
)
fig.add_artist(anchored_box)

# Shared text styling for the bold header and row-label annotations.
annotation_fontsize = font_params['fontsize'] + 2

# Symptom names as column headers; x positions are hand-tuned figure fractions.
for column, symptom_name in enumerate(df_long['Symptom'].unique()):
    header_x = 0.065 + column * 0.102
    fig.text(
        header_x, 0.98, symptom_name,
        ha='center', va='bottom',
        fontsize=annotation_fontsize, fontweight='bold'
    )

# Metric names as vertical row labels at the left edge, from top to bottom.
row_labels = ['Recall', 'Precision', 'F1-score']
for row, row_text in enumerate(row_labels):
    label_y = 0.81 - row * 0.31
    fig.text(
        -0.01, label_y, row_text,
        ha='left', va='center', rotation='vertical',
        fontsize=annotation_fontsize, fontweight='bold'
    )

fig.suptitle("Prompt-based classification", fontsize=20, fontweight='bold', y=1.02, x=0.01, ha='left')

# Keep the right 7% of the figure free for the legend, then write the PNG with a
# tight bounding box and display the figure.
plt.tight_layout(rect=[0, 0, 0.93, 1])
plt.savefig('Prompt-based_classifiers_perclass.png', format='png', bbox_inches='tight')
plt.show()


In [21]:
%%capture 

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AnchoredOffsetbox, VPacker, HPacker, TextArea, DrawingArea

# Colour palettes for the mirrored bar chart: MedRoBERTa.nl in blues, RobBERT in oranges.
# Within each metric the shades run from dark (macro average) to light ('not reported').
# The keys have to match the entries of subcategories_per_metric exactly: colours are
# looked up with .get(subcat, '#cccccc'), so a misspelled key silently yields grey bars.
subcategory_colors_medroberta = {
    'Recall': {
        'Macro-recall': '#2B4C7E',         # dark blue
        "Recall 'present'": '#5178A6',     # medium blue
        "Recall 'absent'": '#7FA2C9',      # light-medium blue
        "Recall 'not reported'": '#A9C3DE' # light blue
    },
    'Precision': {
        'Macro-precision': '#2B4C7E',
        "Precision 'present'": '#5178A6',
        "Precision 'absent'": '#7FA2C9',
        "Precision 'not reported'": '#A9C3DE'
    },
    'F1': {
        'Macro-F1': '#2B4C7E',
        "F1 'present'": '#5178A6',
        "F1 'absent'": '#7FA2C9',
        "F1 'not reported'": '#A9C3DE'
    }
}

subcategory_colors_roberta = {
    'Recall': {
        'Macro-recall': '#B85716',         # dark orange
        "Recall 'present'": '#DA7A34',     # medium orange
        "Recall 'absent'": '#F1A469',      # light-medium orange
        "Recall 'not reported'": '#F6C7A3' # light orange
    },
    'Precision': {
        'Macro-precision': '#B85716',
        "Precision 'present'": '#DA7A34',
        "Precision 'absent'": '#F1A469',
        # Key previously read 'not_reported' (underscore) and therefore never matched.
        "Precision 'not reported'": '#F6C7A3'
    },
    'F1': {
        'Macro-F1': '#B85716',
        "F1 'present'": '#DA7A34',
        "F1 'absent'": '#F1A469',
        "F1 'not reported'": '#F6C7A3'
    }
}

# Metric families (one row of axes each) and, per family, the bars drawn in each
# group, listed from the lowest bar to the highest.
metrics = ['Recall', 'Precision', 'F1']
subcategories_per_metric = {
    'Recall': ["Recall 'not reported'", "Recall 'absent'", "Recall 'present'", 'Macro-recall'],
    'Precision': ["Precision 'not reported'", "Precision 'absent'", "Precision 'present'", 'Macro-precision'],
    'F1': ["F1 'not reported'", "F1 'absent'", "F1 'present'", 'Macro-F1']
}

# Fail fast if any plotted subcategory lacks a colour in either palette.
for _palette_name, _palette in (('MedRoBERTa.nl', subcategory_colors_medroberta),
                                ('RobBERT', subcategory_colors_roberta)):
    for _metric_name, _subcats in subcategories_per_metric.items():
        _missing = [s for s in _subcats if s not in _palette[_metric_name]]
        assert not _missing, f"{_palette_name} palette lacks colours for {_missing}"

# Grid setup: one row of axes per metric (Recall, Precision, F1) and one column per symptom.
symptoms = df_long['Symptom'].unique()
num_symptoms = len(symptoms)
# squeeze=False keeps `axes` two-dimensional even for a single symptom, so the
# axes[row, column] indexing used below never degrades to a 1-D lookup.
fig, axes = plt.subplots(
    len(metrics), num_symptoms, figsize=(40, 20), sharex=True, sharey=True, squeeze=False
)

bar_width = 0.15
ytick_spacing = 1
font_params = {'fontsize': 18}

# The x-axis is mirrored around 0 (mirrored_value is negative for MedRoBERTa.nl and
# positive for RobBERT), so the tick labels show magnitudes on both sides.
xticks = np.linspace(-1, 1, 9)
xticklabels = ['1', '0.75', '0.5', '0.25', '0', '0.25', '0.50', '0.75', '1']

# Rows: metrics
for metric_idx, metric in enumerate(metrics):
    subcats = subcategories_per_metric[metric]

    # Columns: symptoms
    for symptom_idx, symptom in enumerate(symptoms):
        ax = axes[metric_idx, symptom_idx]
        symptom_data = df_long[df_long['Symptom'] == symptom]
        # One bar group per 'Number of samples' value present for this symptom, ascending.
        sample_sizes = sorted(symptom_data['Number of samples'].unique())
        ytick_count = len(sample_sizes)
        y_positions = np.arange(ytick_count) * ytick_spacing

        for k, subcat in enumerate(subcats):
            # The k-th subcategory is drawn k bar-widths up, starting one width
            # below the group position.
            y_pos = y_positions + k * bar_width - bar_width

            for model in ['MedRoBERTa.nl', 'RobBERT']:
                subset = symptom_data[
                    (symptom_data['metric'] == subcat) &
                    (symptom_data['Model'] == model)
                ]
                # Align the values to the symptom's full list of sample sizes so that
                # `values` always has the same length as `y_pos`; a size with no row
                # for this model/subcategory becomes NaN (no bar) instead of making
                # barh fail on mismatched lengths. The first row wins on duplicates.
                values = (
                    subset.drop_duplicates('Number of samples')
                    .set_index('Number of samples')['mirrored_value']
                    .reindex(sample_sizes)
                    .to_numpy()
                )

                # Model-specific palette; unknown subcategories fall back to grey.
                if model == 'MedRoBERTa.nl':
                    color = subcategory_colors_medroberta[metric].get(subcat, '#cccccc')
                else:
                    color = subcategory_colors_roberta[metric].get(subcat, '#cccccc')

                ax.barh(
                    y=y_pos,
                    width=values,
                    height=bar_width,
                    color=color,
                    edgecolor='none'
                )

        # Y-axis ticks, labels and axis title are configured on the first column;
        # the y-axis is shared across the grid (sharey=True).
        if symptom_idx == 0:
            # Offset of one bar width (0.15) relative to the group position.
            ax.set_yticks(y_positions + 0.15)
            ax.set_yticklabels([1, 2, 3], fontsize=font_params['fontsize'])
            ax.set_ylabel("Number of samples", fontsize=font_params['fontsize'])
        else:
            ax.set_yticklabels([1, 2, 3])

        ax.set_xlim(-1, 1)
        ax.axvline(0, color='black', linewidth=0.5)
        ax.grid(True, axis='x', linestyle='--', linewidth=0.5, color='black')

        # Custom x-tick labels (the former set_xticklabels([]) call was always
        # overwritten by these two lines and has been dropped).
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, fontsize=font_params['fontsize']-5)

        ax.text(0.442, 1.005, 'MedRoBERTa.nl vs RobBERT', transform=ax.transAxes,
                fontsize=font_params['fontsize']-1, ha='center', va='bottom')

# Label the x-axis of every panel in the bottom row.
for ax in axes[-1, :]:
    ax.set_xlabel("Value", fontsize=font_params['fontsize'])

# Legend: one titled group of colour swatches per model, listed in the order
# Macro, Present, Absent, Not reported.
legend_submetrics = ['Macro', "'present'", "'absent'", "'not reported'"]
label_mapping = {
    'Macro': 'Macro',
    "'present'": "Present",
    "'absent'": "Absent",
    "'not reported'": "Not reported"
}

# The Recall entries of each palette are the colour source for the legend swatches.
color_keys_med = ['Macro-recall', "Recall 'present'", "Recall 'absent'", "Recall 'not reported'"]
color_keys_rob = list(color_keys_med)

legend_fontsize = font_params['fontsize']

# Each group starts with the model name as its heading.
legend_items_medroberta = [TextArea("MedRoBERTa.nl", textprops=dict(fontsize=legend_fontsize, ha='center'))]
legend_items_roberta = [TextArea("RobBERT", textprops=dict(fontsize=legend_fontsize, ha='center'))]

legend_groups = (
    (legend_items_medroberta, subcategory_colors_medroberta['Recall'], color_keys_med),
    (legend_items_roberta, subcategory_colors_roberta['Recall'], color_keys_rob),
)
for group_items, recall_palette, palette_keys in legend_groups:
    for submetric_label, palette_key in zip(legend_submetrics, palette_keys):
        # 20x10 colour swatch followed by its text label.
        swatch = DrawingArea(20, 10, 0, 0)
        swatch.add_artist(Rectangle((0, 0), 20, 10, fc=recall_palette[palette_key], edgecolor='none'))
        caption = TextArea(label_mapping[submetric_label], textprops=dict(fontsize=legend_fontsize, ha='left'))
        group_items.append(HPacker(children=[swatch, caption], align="left", pad=0, sep=6))

# Stack both groups vertically, separated by an empty text line.
group_spacer = TextArea("", textprops=dict(fontsize=legend_fontsize))
final_legend = VPacker(
    children=[*legend_items_medroberta, group_spacer, *legend_items_roberta],
    align="left", pad=0, sep=18
)

# Anchor the legend to the right-hand side of the figure, vertically centred.
anchored_box = AnchoredOffsetbox(
    loc='center left',
    child=final_legend,
    pad=0.,
    frameon=False,
    bbox_to_anchor=(0.94, 0.5),
    bbox_transform=fig.transFigure,
    borderpad=0.
)
fig.add_artist(anchored_box)


# Column headers: one symptom name above each column of axes.
# The x positions are hand-tuned figure fractions; adjust them to match the layout spacing.
for idx, symptom in enumerate(df_long['Symptom'].unique()):
    fig.text(
        x=(0.065 + idx * 0.102),
        y=0.98,
        s=symptom,
        ha='center',
        va='bottom',
        fontsize=font_params['fontsize'] + 2,
        fontweight='bold'
    )

# Row labels (Recall, Precision, F1) from top to bottom, written vertically at the
# left edge of the figure.
row_labels = ['Recall', 'Precision', 'F1-score']
for idx, label in enumerate(row_labels):
    fig.text(
        x=-0.01,  # slightly left of the figure area
        y=0.81 - idx * 0.31,  # y-position tuned per row
        s=label,
        ha='left',
        va='center',
        rotation='vertical',
        fontsize=font_params['fontsize'] + 2,
        fontweight='bold'
    )

fig.suptitle("Prompt-based classification", fontsize=20, fontweight='bold', y=1.02, x=0.01, ha='left')

# Leave the right 7% of the figure free for the legend.
plt.tight_layout(rect=[0, 0, 0.93, 1])
# bbox_inches='tight' expands the saved canvas to include the row labels (x=-0.01)
# and the title (y=1.02), which are placed outside the 0-1 figure area and would
# otherwise be cut off in the PNG.
plt.savefig('Prompt-based_classifiers.png', format='png', bbox_inches='tight')
plt.show()
