# Install

In [None]:
!pip install einops datasets jaxtyping better_abc fancy_einsum wandb netcal

# Setup

In [None]:
import sys
from google.colab import drive

# Root folder of the project on Google Drive. It is appended to sys.path so
# packages stored there (presumably `reprshift`, imported in the next cell)
# can be imported.
path_to_root = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'

drive.mount('/content/drive', force_remount=True)
sys.path.append(path_to_root)
print("Drive mounted.")

# Folder containing the datasets passed to the dataset constructors below.
data_path = f'{path_to_root}/data'

In [None]:
import torch
import tqdm
from reprshift.learning.algorithms import ERM
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI, CivilComments
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader

from reprshift.models.model_param_maps import ERM_to_HookedEncoder, load_focal, load_groupdro, load_jtt, load_lff
from reprshift.models.HookedEncoderConfig import bert_config

from transformer_lens2 import HookedEncoder, HookedTransformerConfig

# Define Metrics

In [None]:
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import (accuracy_score, confusion_matrix, roc_auc_score, average_precision_score,
                             balanced_accuracy_score, recall_score, brier_score_loss, log_loss, classification_report)
# import netcal.metrics


## Adapted from the SubpopBench repository ##


def predict_on_set(algorithm, loader, device):
    """Run ``algorithm`` over every batch of ``loader`` and collect outputs.

    Args:
        algorithm: model exposing ``eval()`` and ``predict(x)``; ``predict``'s
            output is treated as logits (passed through sigmoid/softmax here).
        loader: iterable of ``(index, x, y, a)`` batches whose ``dataset``
            attribute has a ``num_labels`` field.
        device: device the inputs ``x`` are moved to before ``predict``.

    Returns:
        Tuple ``(targets, attributes, probs, groups)`` of numpy arrays
        concatenated over all batches. ``probs`` is 1-D (positive-class
        probability) when the model emits one logit per sample or when
        ``num_labels == 2``; otherwise it holds per-class softmax
        probabilities. ``groups`` holds ``'y=<label>,a=<attribute>'`` strings.
    """
    num_labels = loader.dataset.num_labels

    ys, atts, gs, ps = [], [], [], []

    algorithm.eval()
    with torch.no_grad():
        for _, x, y, a in loader:
            logits = algorithm.predict(x.to(device))
            # Choose the activation from the trailing (class) dimension, not
            # from ``squeeze().ndim``: squeezing a batch of size 1 removes the
            # batch axis, which sent a (1, 2) output to the sigmoid branch
            # (breaking concatenation) and a (1, 1) output to the softmax
            # branch (IndexError on ``p[:, 1]``).
            if logits.ndim == 1 or logits.shape[-1] == 1:
                # Single logit per sample: flatten so every batch is 1-D.
                p = torch.sigmoid(logits.reshape(-1)).cpu().numpy()
            else:
                p = torch.softmax(logits, dim=-1).cpu().numpy()
                if num_labels == 2:
                    # Keep only the positive-class probability.
                    p = p[:, 1]

            ps.append(p)
            ys.append(y)
            atts.append(a)
            gs.append([f'y={yi},a={gi}' for yi, gi in zip(y, a)])

    return np.concatenate(ys, axis=0), np.concatenate(atts, axis=0), np.concatenate(ps, axis=0), np.concatenate(gs)

def eval_metrics(algorithm, loader, device, thres=0.5):
    """Evaluate ``algorithm`` on ``loader`` and return nested metric dicts.

    Args:
        algorithm: model passed through to ``predict_on_set``.
        loader: loader yielding ``(index, x, y, a)`` batches.
        device: device used for inference.
        thres: decision threshold applied when predictions are 1-D
            positive-class probabilities.

    Returns:
        dict with keys:
          * ``'overall'``: ``binary_metrics`` over all samples plus the
            ``macro_avg``/``weighted_avg`` entries of sklearn's
            classification report;
          * ``'per_attribute'`` / ``'per_group'``: ``binary_metrics`` per
            attribute value / per ``'y=..,a=..'`` group;
          * ``'per_class'``: classification-report entry per true label;
          * ``'adjusted_accuracy'``: unweighted mean of per-group accuracies;
          * ``'min_attr'``, ``'max_attr'``, ``'min_group'``, ``'max_group'``:
            per-metric min/max across attributes/groups.
    """
    targets, attributes, preds, gs = predict_on_set(algorithm, loader, device)
    # A single-logit model may yield an (N, 1) column; flatten it so the
    # thresholded predictions are 1-D for sklearn.
    if preds.ndim == 2 and preds.shape[1] == 1:
        preds = preds.reshape(-1)
    # Decide by rank rather than ``squeeze().ndim``: with one sample an (1, C)
    # probability matrix squeezes to 1-D and would be thresholded instead of
    # arg-maxed.
    preds_rounded = preds >= thres if preds.ndim == 1 else preds.argmax(1)
    label_set = np.unique(targets)
    groups = np.unique(gs)

    res = {}
    res['overall'] = {
        **binary_metrics(targets, preds_rounded, label_set),
        # **prob_metrics(targets, preds, label_set)
    }
    res['per_attribute'] = {}
    res['per_class'] = {}
    res['per_group'] = {}

    for a in np.unique(attributes):
        mask = attributes == a
        res['per_attribute'][int(a)] = {
            **binary_metrics(targets[mask], preds_rounded[mask], label_set),
            # **prob_metrics(targets[mask], preds[mask], label_set)
        }

    classes_report = classification_report(targets, preds_rounded, output_dict=True, zero_division=0.)
    res['overall']['macro_avg'] = classes_report['macro avg']
    res['overall']['weighted_avg'] = classes_report['weighted avg']
    for y in np.unique(targets):
        res['per_class'][int(y)] = classes_report[str(y)]

    for g in groups:
        mask = gs == g
        res['per_group'][g] = {
            **binary_metrics(targets[mask], preds_rounded[mask], label_set)
        }

    # Every group from np.unique has at least one sample, so 'accuracy' exists.
    res['adjusted_accuracy'] = sum(res['per_group'][g]['accuracy'] for g in groups) / len(groups)
    res['min_attr'] = pd.DataFrame(res['per_attribute']).min(axis=1).to_dict()
    res['max_attr'] = pd.DataFrame(res['per_attribute']).max(axis=1).to_dict()
    res['min_group'] = pd.DataFrame(res['per_group']).min(axis=1).to_dict()
    res['max_group'] = pd.DataFrame(res['per_group']).max(axis=1).to_dict()
    return res

def binary_metrics(targets, preds, label_set=(0, 1), return_arrays=False):
    """Compute hard-prediction metrics for ``preds`` against ``targets``.

    Args:
        targets: 1-D array of true labels.
        preds: 1-D array of predicted labels.
        label_set: labels passed to the confusion matrix. When it has exactly
            two entries, the first is the negative and the second the positive
            class for the TN/FN/TP/FP counts. Defaults to ``(0, 1)``.
        return_arrays: if True, also store ``targets`` and ``preds`` in the
            result.

    Returns:
        ``{}`` for empty input. Otherwise a dict with ``'accuracy'`` and
        ``'n_samples'``. With two labels it also holds ``TN``, ``FN``, ``TP``,
        ``FP``, ``error`` (FN + FP), ``TPR``, ``FNR``, ``FPR``, ``TNR``,
        ``pred_prevalence`` and ``prevalence``. With more labels it holds
        ``TPR`` as macro-averaged recall. ``'balanced_acc'`` is added when
        ``targets`` contains more than one class.
    """
    if len(targets) == 0:
        return {}
    res = {
        'accuracy': accuracy_score(targets, preds),
        'n_samples': len(targets)
    }

    if len(label_set) == 2:
        CM = confusion_matrix(targets, preds, labels=label_set)

        res['TN'] = CM[0][0].item()
        res['FN'] = CM[1][0].item()
        res['TP'] = CM[1][1].item()
        res['FP'] = CM[0][1].item()
        res['error'] = res['FN'] + res['FP']

        # No actual positives: rates are undefined; use TPR=0 / FNR=1.
        if res['TP'] + res['FN'] == 0:
            res['TPR'] = 0
            res['FNR'] = 1
        else:
            res['TPR'] = res['TP']/(res['TP']+res['FN'])
            res['FNR'] = res['FN']/(res['TP']+res['FN'])

        # No actual negatives: rates are undefined; use FPR=1 / TNR=0.
        if res['FP'] + res['TN'] == 0:
            res['FPR'] = 1
            res['TNR'] = 0
        else:
            res['FPR'] = res['FP']/(res['FP']+res['TN'])
            res['TNR'] = res['TN']/(res['FP']+res['TN'])

        res['pred_prevalence'] = (res['TP'] + res['FP']) / res['n_samples']
        res['prevalence'] = (res['TP'] + res['FN']) / res['n_samples']
    else:
        # Multiclass: only macro recall is reported (the confusion matrix the
        # original computed here was never used).
        res['TPR'] = recall_score(targets, preds, labels=label_set, average='macro', zero_division=0.)

    if len(np.unique(targets)) > 1:
        res['balanced_acc'] = balanced_accuracy_score(targets, preds)
    if return_arrays:
        res['targets'] = targets
        res['preds'] = preds
    return res

# def prob_metrics(targets, preds, label_set, return_arrays=False):
#     if len(targets) == 0:
#         return {}

#     res = {
#         'AUROC_ovo': roc_auc_score(targets, preds, multi_class='ovo', labels=label_set),
#         'BCE': log_loss(targets, preds, eps=1e-6, labels=label_set),
#         'ECE': netcal.metrics.ECE().measure(preds, targets)
#     }

#     # happens when you predict a class, but there are no samples with that class in the dataset
#     try:
#         res['AUROC'] = roc_auc_score(targets, preds, multi_class='ovr', labels=label_set)
#     except:
#         res['AUROC'] = roc_auc_score(targets, preds, multi_class='ovo', labels=label_set)

#     if len(set(targets)) == 2:
#         res['AUPRC'] = average_precision_score(targets, preds, average='macro')
#         res['brier'] = brier_score_loss(targets, preds)

#     if return_arrays:
#         res['targets'] = targets
#         res['preds'] = preds

#     return res

# Dataset

In [None]:
DATASET = 'MultiNLI'  # 'CivilComments' , 'MultiNLI'

# Select the test split, the label/attribute counts used to build the model,
# and the Drive folders holding trained models and representations.
if DATASET == 'MultiNLI':
    NUM_CLASSES = 3
    NUM_ATTRIBUTES = 2
    # train_dataset = MultiNLI(data_path, 'tr', hparams)
    # val_dataset = MultiNLI(data_path, 'va', hparams=hparams_f('ERM'))
    te_dataset = MultiNLI(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_mnli'
    representations_path = path_to_root + '/representations/representations_mnli'
    print(DATASET)
elif DATASET == 'CivilComments':
    NUM_CLASSES = 2
    NUM_ATTRIBUTES = 8
    # train_dataset = CivilComments(data_path, 'tr', hparams, granularity="fine")
    # val_dataset = CivilComments(data_path, 'va', hparams=hparams_f('ERM'))
    te_dataset = CivilComments(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_civilcomments'
    representations_path = path_to_root + '/representations/representations_civilcomments'
    print(DATASET)
else:
    # Fail here instead of printing and continuing: later cells need
    # te_dataset / NUM_CLASSES and would otherwise die with a NameError.
    raise NotImplementedError(f'Dataset Not Implemented: {DATASET}')

# Load Model

In [None]:
# Per-seed amount subtracted from the class-0 entry of the ERM classifier
# bias before evaluation.
# NOTE(review): how these values were chosen is not shown here — confirm.
LogitShiftEdits = {0: 1.6, 1: 2.1, 2: 1.7}

# The validation summary does not depend on the seed, so load it once.
# weights_only=False: this results file is produced by the project itself and
# may contain non-tensor objects that the weights-only loader (the default
# since torch 2.6) refuses to unpickle.
valmetrics = torch.load(path_to_root + '/results/ValidationMetrics/clean_val_results.pth', weights_only=False)

for SEED in [0, 1, 2]:
    print('SEED: ', SEED)
    CURRENT_BEST_EPOCH = valmetrics[DATASET][SEED]['A Selection']

    ### MODELS ###
    MODELS = {'pretrained': {'path': models_path + '/00_randominit/', 'load_f': lambda x: x, 'epoch': 0},
              'erm': {'path': models_path + '/01_erm/', 'load_f': lambda x: x, 'epoch': CURRENT_BEST_EPOCH['erm']},
              'groupdro': {'path': models_path + '/03_groupdro/', 'load_f': load_groupdro, 'epoch': CURRENT_BEST_EPOCH['groupdro']},
              'jtt': {'path': models_path + '/06_jtt/', 'load_f': load_jtt, 'epoch': CURRENT_BEST_EPOCH['jtt']},
              'lff': {'path': models_path + '/07_lff/', 'load_f': load_lff, 'epoch': CURRENT_BEST_EPOCH['lff']},
              'focal': {'path': models_path + '/15_focal/', 'load_f': lambda x: x, 'epoch': CURRENT_BEST_EPOCH['focal']}}

    ### Load Statedict ###
    algorithm_name = 'erm'
    state_dict_PATH = MODELS[algorithm_name]['path']
    load_f = MODELS[algorithm_name]['load_f']
    epoch = MODELS[algorithm_name]['epoch']
    algorithm_state_dict_PATH = state_dict_PATH + f'seed{SEED}/sd_epoch{epoch}.pth'
    sd = load_f(torch.load(algorithm_state_dict_PATH))

    ### Edit State Dict ###
    bias_key = 'network.1.classifier.bias'
    print(sd[bias_key])
    # Build the shift with the bias's own shape/device/dtype. The previous
    # hard-coded 3-element CUDA tensor broke for 2-class datasets and for
    # state dicts loaded on CPU.
    shift = torch.zeros_like(sd[bias_key])
    shift[0] = LogitShiftEdits[SEED]
    sd[bias_key] = sd[bias_key] - shift
    print(sd[bias_key])

    ### Initialize ERM Model ###
    hparams = hparams_f('ERM')
    algorithm = ERM(num_classes=NUM_CLASSES, num_attributes=NUM_ATTRIBUTES, hparams=hparams)
    algorithm.load_state_dict(sd)

    te_metrics = {}
    # Separate loop name so algorithm_name (the checkpoint selector) is not shadowed.
    for eval_name in tqdm.tqdm([algorithm_name]):
        print(eval_name)
        te_loader = FastDataLoader(dataset=te_dataset,
                                   batch_size=128,
                                   num_workers=1)
        te_metrics[eval_name] = eval_metrics(algorithm.cuda(), te_loader, 'cuda')

    temetrics_PATH = path_to_root + f'/results/ModelEdits/{DATASET}_logitshift_seed{SEED}.pth'
    torch.save(te_metrics, temetrics_PATH)

# Load Test Metrics

In [None]:
import torch

DATASET = 'MultiNLI'
TestTables = []


def _pct(value):
    """Convert an accuracy fraction to a percentage rounded to 2 decimals.

    Works for both Python floats and numpy scalars (``.round`` exists only on
    the latter), and avoids artefacts like 87.33999999999999 in the CSVs.
    """
    return round(float(value) * 100, 2)


for SEED in [0, 1, 2]:
    # weights_only=False: these metric files are written by this notebook and
    # hold non-tensor objects (e.g. numpy-string group keys produced by
    # np.unique), which the weights-only loader (default since torch 2.6)
    # rejects.
    te_metrics_load = torch.load(
        path_to_root + f'/results/ModelEdits/{DATASET}_logitshift_seed{SEED}.pth',
        weights_only=False,
    )
    algo_keys = ['erm']
    group_keys = te_metrics_load['erm']['per_group'].keys()

    # One row per group, one column per algorithm. The old version added a
    # dummy column level keyed by the stale global ``epoch`` from the previous
    # cell (a NameError after a kernel restart) and then dropped it again.
    per_group = pd.DataFrame({
        algo_key: {group_key: _pct(te_metrics_load[algo_key]['per_group'][group_key]['accuracy'])
                   for group_key in group_keys}
        for algo_key in algo_keys
    })
    overall = pd.DataFrame(
        index=['Overall'],
        data={algo_key: _pct(te_metrics_load[algo_key]['overall']['accuracy']) for algo_key in algo_keys},
    )
    TestTable = pd.concat([overall, per_group])[['erm']]
    TestTable.to_csv(path_to_root + f'/results/ModelEdits/{DATASET}_logitshift_seed{SEED}_TestTable.csv')
    TestTables.append(TestTable)

In [None]:
# TestTables[0]

In [None]:
### Test Table ###
# For each seed, average the group accuracies (excluding the 'Overall' row),
# then take the mean and standard deviation (np.std default, ddof=0) across
# seeds. Every loaded seed is used; previously range(2) silently dropped the
# third seed.
dfs = [pd.DataFrame(table.drop('Overall').mean()) for table in TestTables]
stacked_dfs = np.stack(dfs)
df_mean_values = np.mean(stacked_dfs, axis=0)
df_std_values = np.std(stacked_dfs, axis=0)
df_mean = pd.DataFrame(df_mean_values, columns=dfs[0].columns, index=dfs[0].index)
df_std = pd.DataFrame(df_std_values, columns=dfs[0].columns, index=dfs[0].index)
df_std.round(1)

In [None]:
# temetrics_PATH = path_to_root + f'/results/{DATASET}_temetrics_seed{SEED}.pth'
# print(temetrics_PATH)
# te_metrics_load = torch.load(temetrics_PATH)