In [6]:
import json
from pathlib import Path
import sys
from textwrap import wrap

from matplotlib.colors import LogNorm
from numpy import NaN


import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

def to_title_case(input_string, acronyms=("TSH", "FPR", "FNR")):
    """Turn a snake_case identifier into a human-readable title.

    The input is lower-cased, split on underscores, and each part is
    capitalised; parts are joined with single spaces. A part that equals one
    of ``acronyms`` (compared case-insensitively) is written fully upper-case
    instead. Matching is done on whole underscore-separated parts, so words
    that merely contain an acronym's letters (e.g. ``"fprime"``) are left alone.

    Args:
        input_string: Identifier such as ``"false_positive_rate"`` or ``"fpr"``.
        acronyms: Words to render in upper case. Defaults to the acronyms this
            script's metric/column names use.

    Returns:
        The title string, e.g. ``"fpr"`` -> ``"FPR"``,
        ``"tsh_level"`` -> ``"TSH Level"``, ``"f10"`` -> ``"F10"``.
    """
    # Map lower-case spelling -> upper-case display form for O(1) per-word lookup.
    upper_forms = {acronym.lower(): acronym.upper() for acronym in acronyms}
    words = input_string.lower().split("_")
    return " ".join(upper_forms.get(word, word.capitalize()) for word in words)

def note(description):
    """Add a text-only entry to the current axes' legend.

    Nothing visible is drawn: an empty line with the blank format string
    ``" "`` is plotted only so that ``description`` appears as its legend
    label. The text is word-wrapped at 35 characters (existing whitespace is
    preserved) and prefixed with a newline to space consecutive notes apart.
    """
    wrapped_lines = wrap(description, 35, replace_whitespace=False)
    legend_label = "\n" + "\n".join(wrapped_lines)
    plt.plot([], [], " ", label=legend_label)


sns.set_style("whitegrid")
# Columns of `data` to plot, one heatmap per entry.
# NOTE(review): "fold" is presumably the cross-validation fold index rather than
# a score; it is averaged and plotted like the other metrics — confirm intended.
metrics = ["sensitivity", "specificity", "f10", "f1", "fnr", "fpr", "fold", "precision", "recall"]

file_path = '../analysis.csv'  # Adjust this path as necessary
data = pd.read_csv(file_path)

# Keep only rows without resampling. `.copy()` makes this an independent frame,
# so the column assignments below do not write into the result of a boolean
# filter (which triggers pandas' SettingWithCopyWarning).
data = data[data['resampler'] == 'none'].copy()

data['fold'] = data['fold'].astype(int)
data['precision'] = data['true_positives'] / (data['true_positives'] + data['false_positives'])
data['recall'] = data['true_positives'] / (data['true_positives'] + data['false_negatives'])

# F-beta score with beta = 10, i.e. recall weighted much more heavily than
# precision: (1 + b^2) * P * R / (b^2 * P + R). Stored in the "f10" column.
F_BETA = 10
beta_squared = F_BETA ** 2
data['f10'] = (1 + beta_squared) * (
    (data['precision'] * data['recall'])
    / ((beta_squared * data['precision']) + data['recall'])
)

# Restrict the comparison to these tags; they become the heatmap columns.
data = data[data['tag'].isin(["baseline_ge_14.0", "baseline_ge_15.2", "baseline_ge_15.3", "baseline_ge_16.0"])]

 # Generate and save each plot individually

# Output directory for one heatmap PNG per metric.
mean_dir = Path('mean')
mean_dir.mkdir(parents=True, exist_ok=True)
for metric in metrics:
    # Average the metric over all remaining rows (e.g. CV folds) for every
    # (model, tag) pair.
    columns = ["model", "tag", metric]
    df = data[columns].groupby(["model", "tag"]).mean().reset_index()

    fig = plt.figure(figsize=(13, 9))
    # Rows are learners (models), columns are tags.
    # NOTE(review): the colour scale is fixed to [0, 0.5]; any mean above 0.5
    # is clipped to the top colour of the map — confirm this is intended.
    sns.heatmap(
        data=df.pivot(index="model", columns="tag", values=metric),
        annot=True,
        fmt=".3f",
        cmap="YlGnBu",
        cbar_kws={'label': to_title_case(metric)},
        vmax=0.5,
        vmin=0,
    )
    plt.title(f'Performance by {to_title_case(metric)}')
    # The heatmap columns are values of the "tag" column (the data was already
    # restricted to resampler == 'none'), so label the axis accordingly.
    plt.xlabel("Tag")
    plt.ylabel("Learner")

    note("These models are trained using F1 optimization.")
    note("The heatmap shows the mean performance of each model and tag combination.")
    note("The performance is calculated using 5-fold cross-validation.")

    # The notes above are gathered into a single legend that is currently
    # hidden; drop `.set_visible(False)` to render it to the right of the plot.
    plt.legend(title='Notes', bbox_to_anchor=(1.2, 1), loc='upper left').set_visible(False)

    plt.tight_layout()
    plt.savefig(mean_dir / f'{metric}_performance.png', dpi=400)
    # plt.show()
    plt.close(fig)  # Release the figure before drawing the next metric.
