# Setup

In [None]:
import sys
from google.colab import drive
# Mount Google Drive at /content/drive, requesting a forced re-mount.
drive.mount('/content/drive', force_remount=True)
print("Drive mounted.")

# Project root on the mounted Drive. It is appended to sys.path, presumably so
# that the project's own packages can be imported from there (confirm layout).
path_to_root = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'
data_path = f'{path_to_root}/data'
sys.path.append(path_to_root)

In [None]:
import torch
import tqdm
from reprshift.learning.algorithms import ERM
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI, CivilComments
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader

from reprshift.models.model_param_maps import ERM_to_HookedEncoder, load_focal, load_groupdro, load_jtt, load_lff
from reprshift.models.HookedEncoderConfig import bert_config

# from transformer_lens2 import HookedEncoder, HookedTransformerConfig
import pandas as pd

# Validation Metrics

In [None]:
# Load the raw validation metrics of one (dataset, seed) run.
DATASET = 'CivilComments'  # options: 'CivilComments' or 'MultiNLI'
SEED = 0

valmetrics_PATH = f'{path_to_root}/results/ValidationMetrics/{DATASET}_valmetrics_seed{SEED}.pth'
val_metrics = torch.load(valmetrics_PATH)

In [None]:
# Summarise the raw validation metrics of every (dataset, seed) run and save
# all summaries into a single file.
#
# For each run the following entries are stored under valmetrics[DATASET][SEED]:
#   'WGA Selection'    - per algorithm, the epoch whose minimum per-group
#                        accuracy is highest
#   'A Selection'      - per algorithm, the epoch with the highest overall
#                        accuracy
#   'df'               - per-group and overall accuracy (in percent) for every
#                        algorithm (rows) and epoch (columns)
#   'Validation Table' - overall accuracy, per-group accuracy and epoch of the
#                        checkpoints chosen by 'WGA Selection'
#
# NOTE: the loops rebind the notebook-level DATASET and SEED variables; after
# this cell they hold the last values iterated.

EPOCHS = range(1, 31)
# The stored frames and the validation table use different algorithm orders;
# both are spelled out so the saved layout stays stable.
FRAME_ALGORITHMS = ['erm', 'groupdro', 'jtt', 'lff', 'focal']
TABLE_ALGORITHMS = ['erm', 'groupdro', 'focal', 'jtt', 'lff']


def _accuracy_percent(metrics):
    """Return ``metrics['accuracy']`` rounded to 4 decimals and scaled to percent."""
    return metrics['accuracy'].round(4) * 100


valmetrics = {}

for DATASET in ['CivilComments', 'MultiNLI']:
    valmetrics[DATASET] = {}
    for SEED in [0, 1, 2]:
        valmetrics_PATH = path_to_root + f'/results/ValidationMetrics/{DATASET}_valmetrics_seed{SEED}.pth'
        val_metrics = torch.load(valmetrics_PATH)

        ### Compute Statistics ###
        # Group names are taken from ERM at epoch 1 and reused for every
        # algorithm and epoch.
        group_keys = val_metrics['erm'][1]['per_group'].keys()

        # Rows: (algorithm, group); columns: epoch.
        group_acc = pd.concat(
            {
                algo: pd.DataFrame({
                    epoch: {
                        group: _accuracy_percent(val_metrics[algo][epoch]['per_group'][group])
                        for group in group_keys
                    }
                    for epoch in EPOCHS
                })
                for algo in FRAME_ALGORITHMS
            },
            axis=0,
        )
        # argmax() is positional; adding 1 turns it into the epoch label
        # because the epochs are the consecutive integers starting at 1.
        wga_selection = {
            algo: group_acc.loc[algo].min().argmax() + 1 for algo in TABLE_ALGORITHMS
        }

        # Rows: epoch; columns: algorithm.
        overall_acc = pd.DataFrame({
            algo: {
                epoch: _accuracy_percent(val_metrics[algo][epoch]['overall'])
                for epoch in EPOCHS
            }
            for algo in FRAME_ALGORITHMS
        })
        acc_selection = {
            algo: overall_acc[algo].argmax() + 1 for algo in TABLE_ALGORITHMS
        }

        # Re-key the overall accuracies as (algorithm, 'overall') rows so they
        # can be appended below each algorithm's per-group rows.
        overall_rows = pd.concat({"overall": overall_acc.T}).swaplevel()
        # Visit each algorithm once, in order of first appearance in the index.
        combined = pd.concat({
            algo: pd.concat([group_acc.loc[algo], overall_rows.loc[algo]])
            for algo in group_acc.index.get_level_values(0).unique()
        })

        ### Val Table ###
        selected_stats = pd.DataFrame({
            algo: group_acc.loc[algo][epoch] for algo, epoch in wga_selection.items()
        })
        selected_stats.loc['Epoch'] = wga_selection
        overall_row = pd.DataFrame(
            index=['Overall'],
            data={
                algo: _accuracy_percent(val_metrics[algo][epoch]['overall'])
                for algo, epoch in wga_selection.items()
            },
        )
        validation_table = pd.concat([overall_row, selected_stats])

        ### Save ###
        valmetrics[DATASET][SEED] = {
            'WGA Selection': wga_selection,
            'A Selection': acc_selection,
            'df': combined,
            'Validation Table': validation_table,
        }

torch.save(valmetrics, path_to_root + '/results/ValidationMetrics/clean_val_results.pth')

# Means and Standard Deviations Across Seeds

In [None]:
# Reload the per-(dataset, seed) summaries from the combined results file.
# NOTE(review): the file holds pickled pandas objects; recent PyTorch releases
# may default torch.load to weights_only=True and refuse such files - confirm,
# and pass weights_only=False if loading fails.
clean_results_path = path_to_root + '/results/ValidationMetrics/clean_val_results.pth'
valmetrics = torch.load(clean_results_path)

In [None]:
import numpy as np  # needed in this cell; the shared import cell does not provide it

DATASET = 'CivilComments'


def _mean_std_across_seeds(frames):
    """Return element-wise (mean, std) DataFrames over equally-shaped frames.

    Both results reuse the index and columns of the first frame. The standard
    deviation is numpy's default (population, ddof=0).
    """
    stacked = np.stack(frames)
    template = frames[0]
    mean = pd.DataFrame(np.mean(stacked, axis=0), columns=template.columns, index=template.index)
    std = pd.DataFrame(np.std(stacked, axis=0), columns=template.columns, index=template.index)
    return mean, std


# Aggregate over every per-seed entry. Seeds are the integer keys; the string
# keys are the aggregates this cell writes back, so re-running stays safe.
seeds = sorted(key for key in valmetrics[DATASET] if isinstance(key, int))

### df ###
df_mean, df_std = _mean_std_across_seeds([valmetrics[DATASET][seed]['df'] for seed in seeds])
valmetrics[DATASET]['df Mean'] = df_mean.round(1)
valmetrics[DATASET]['df Std'] = df_std.round(1)

### ValTable ###
df_mean, df_std = _mean_std_across_seeds(
    [valmetrics[DATASET][seed]['Validation Table'] for seed in seeds]
)
valmetrics[DATASET]['Validation Table Mean'] = df_mean.round(1)
valmetrics[DATASET]['Validation Table Std'] = df_std.round(1)

In [None]:
import numpy as np

df = valmetrics[DATASET]['df Mean']
df_std = valmetrics[DATASET]['df Std']

# Blank out every JTT row (all groups and the overall row) at checkpoint 15 in
# the mean frame - presumably that checkpoint is unusable; confirm the reason.
# A single .loc assignment on the frame itself is used because chained
# indexing (df.loc[...][...][...] = value) may write into a temporary copy and
# leave df unchanged. The standard-deviation frame is not modified.
jtt_rows = df.index.get_level_values(0) == 'jtt'
df.loc[jtt_rows, 15] = np.nan

# Group Accuracy Across Epochs

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Plot one panel per algorithm: mean validation accuracy of every row of the
# aggregated frame (groups and overall) across checkpoints, with a shaded
# band of +/- one standard deviation.

df = valmetrics[DATASET]['df Mean']
df_std = valmetrics[DATASET]['df Std']

# Display titles; the insertion order also fixes the panel order.
Titles = {'erm': 'ERM', 'groupdro': 'GroupDRO', 'focal': 'Focal', 'jtt': 'JTT', 'lff': 'LFF'}

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 10))
axes = axes.flatten()

# One colour per plotted line; indexing fails if a frame has more than 17 rows.
Colors = sns.color_palette("tab20", n_colors=17)

for i, (algorithm, title) in enumerate(Titles.items()):
    # Transpose so that checkpoints form the index and each line is a column.
    mean_values = df.loc[algorithm].T
    std_values = df_std.loc[algorithm].T

    for j, col in enumerate(mean_values.columns):
        axes[i].plot(mean_values.index, mean_values[col], label=col, color=Colors[j])
        axes[i].fill_between(
            mean_values.index,
            mean_values[col] - std_values[col],
            mean_values[col] + std_values[col],
            alpha=0.2,
            color=Colors[j],
        )

    axes[i].set_title(title)
    axes[i].set_xlabel('Model Checkpoints')
    axes[i].set_ylabel('Accuracy')
    axes[i].set_ylim(0, 100)

# Remove the grid cells that have no algorithm assigned.
for unused_ax in axes[len(Titles):]:
    fig.delaxes(unused_ax)

# A single shared legend built from the first panel's lines.
handles, labels = axes[0].get_legend_handles_labels()
# The column count must be an integer, hence floor division.
fig.legend(handles, labels, loc='upper center', ncol=len(labels) // 3 + 1, bbox_to_anchor=(0.5, 1.05))

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()