# Install

In [None]:
# Install third-party packages with pip (IPython `!` shell escape; runs in the notebook's environment).
!pip install einops datasets jaxtyping better_abc fancy_einsum wandb netcal

# Setup

In [None]:
import sys
from google.colab import drive

# Mount Google Drive at /content/drive (force_remount=True is passed through
# to drive.mount).
drive.mount('/content/drive', force_remount=True)

# Project root on the mounted drive, and the data directory beneath it.
path_to_root = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'
data_path = path_to_root + '/data'

# Put the project root on the import path so modules stored there can be
# imported by later cells.
sys.path.append(path_to_root)
print("Drive mounted.")

In [None]:
import torch
import tqdm
from reprshift.learning.algorithms import ERM
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI, CivilComments
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader

from reprshift.models.model_param_maps import ERM_to_HookedEncoder, load_focal, load_groupdro, load_jtt, load_lff
from reprshift.models.HookedEncoderConfig import bert_config

from transformer_lens2 import HookedEncoder, HookedTransformerConfig
import numpy as np
import pandas as pd

# Dataset

In [None]:
# Notebook-level configuration: which dataset and which seed to analyse.
DATASET = 'CivilComments'  # 'CivilComments' , 'MultiNLI'
SEED = 2

# Derive the class/attribute counts and the model / representation
# directories for the selected dataset.
if DATASET == 'MultiNLI':
    NUM_CLASSES = 3
    NUM_ATTRIBUTES = 2
    # Dataset construction is currently disabled in this notebook:
    # train_dataset = MultiNLI(data_path, 'tr', hparams)
    # val_dataset = MultiNLI(data_path, 'va', hparams=hparams_f('ERM'))
    # te_dataset = MultiNLI(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_mnli'
    representations_path = path_to_root + '/representations/representations_mnli'
elif DATASET == 'CivilComments':
    NUM_CLASSES = 2
    NUM_ATTRIBUTES = 8
    # Dataset construction is currently disabled in this notebook:
    # train_dataset = CivilComments(data_path, 'tr', hparams, granularity="fine")
    # val_dataset = CivilComments(data_path, 'va', hparams=hparams_f('ERM'))
    # te_dataset = CivilComments(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_civilcomments'
    representations_path = path_to_root + '/representations/representations_civilcomments'
else:
    # Fail fast: merely printing a message would leave NUM_CLASSES,
    # NUM_ATTRIBUTES, models_path and representations_path undefined (or, in
    # a notebook session, stale from an earlier run), so later cells would
    # either crash with a NameError or silently analyse the wrong dataset.
    raise ValueError(f'Dataset Not Implemented: {DATASET!r}')

# Echo the selected dataset (reached only for a supported dataset).
print(DATASET)

# Representations

In [None]:
# Algorithm names used in this notebook (reassigned to a shorter list in the
# logit-lens section).
algorithm_names = [
    'random', 'randominit', 'pretrained', 'erm',
    'groupdro', 'focal', 'jtt', 'lff',
]

# Number of representations per (class, attribute) group: 3600 split evenly
# over NUM_CLASSES * NUM_ATTRIBUTES groups.
# NOTE(review): presumably 3600 is the total number of stored representations
# per layer -- confirm against the script that produced the file.
PER_GROUP_REPR = int(3600 / (NUM_CLASSES * NUM_ATTRIBUTES))

# Saved representations for the configured seed.
REPRS = torch.load(f'{representations_path}/seed{SEED}' + '_reprs')

# Last expression of the cell: displays the top-level keys of REPRS.
REPRS.keys()

# LogitLens CLS

In [None]:
### MODELS ###

# Validation results; the 'WGA Selection' entry for the current dataset and
# seed is indexed by algorithm name below to obtain the checkpoint epoch.
valmetrics = torch.load(path_to_root + '/results/ValidationMetrics/clean_val_results.pth')
CURRENT_BEST_EPOCH = valmetrics[DATASET][SEED]['WGA Selection']


def _identity(sd):
    """Return the loaded state dict unchanged (no conversion step)."""
    return sd


# Checkpoint registry: for each algorithm, the checkpoint directory, the
# function applied to the loaded state dict, and the epoch to load.
# NOTE(review): 'pretrained' points at the 00_randominit directory (epoch 0),
# and 'focal' uses no conversion although load_focal is imported -- confirm
# both are intended.
MODELS = {
    'pretrained': {
        'path': models_path + '/00_randominit/',
        'load_f': _identity,
        'epoch': 0,
    },
    'erm': {
        'path': models_path + '/01_erm/',
        'load_f': _identity,
        'epoch': CURRENT_BEST_EPOCH['erm'],
    },
    'groupdro': {
        'path': models_path + '/03_groupdro/',
        'load_f': load_groupdro,
        'epoch': CURRENT_BEST_EPOCH['groupdro'],
    },
    'jtt': {
        'path': models_path + '/06_jtt/',
        'load_f': load_jtt,
        'epoch': CURRENT_BEST_EPOCH['jtt'],
    },
    'lff': {
        'path': models_path + '/07_lff/',
        'load_f': load_lff,
        'epoch': CURRENT_BEST_EPOCH['lff'],
    },
    'focal': {
        'path': models_path + '/15_focal/',
        'load_f': _identity,
        'epoch': CURRENT_BEST_EPOCH['focal'],
    },
}

def load_and_init(algorithm_name, seed=SEED):
    """Build a HookedEncoder initialised from one algorithm's saved checkpoint.

    The checkpoint directory, the state-dict conversion function and the
    epoch are read from ``MODELS[algorithm_name]``.

    Args:
        algorithm_name: Key into ``MODELS``.
        seed: Seed sub-directory to load from. The default is the value
            ``SEED`` had when this function was defined, not when it is
            called.

    Returns:
        A ``HookedEncoder`` configured via ``bert_config(NUM_CLASSES)`` with
        the converted checkpoint weights loaded.
    """
    spec = MODELS[algorithm_name]
    checkpoint_dir, convert, epoch = spec['path'], spec['load_f'], spec['epoch']

    # Load the raw checkpoint and apply the algorithm-specific conversion.
    checkpoint_file = checkpoint_dir + f'seed{seed}/sd_epoch{epoch}.pth'
    raw_state_dict = torch.load(checkpoint_file)
    sd = convert(raw_state_dict)

    # Instantiate the encoder, then map the checkpoint onto its parameters
    # using the encoder's own state dict as the reference.
    config = HookedTransformerConfig(**bert_config(NUM_CLASSES))
    bert = HookedEncoder(config)
    mapped_state_dict = ERM_to_HookedEncoder(sd, bert.state_dict())
    bert.load_state_dict(mapped_state_dict)
    return bert

def LogitLens(bert, res):
    """Apply the model's output head to ``res``.

    ``res`` is passed through ``bert.mlm_head`` and the result through
    ``bert.unembed``; the output of ``unembed`` is returned.
    """
    return bert.unembed(bert.mlm_head(res))

# Algorithms whose trained checkpoints are analysed with the logit lens.
algorithm_names = ['erm', 'groupdro', 'focal', 'jtt', 'lff']

# One encoder per algorithm, moved to the GPU and switched to eval mode.
algorithms = {name: load_and_init(name).cuda().eval() for name in algorithm_names}

# One logit-lens callable per algorithm.
# The model is bound through a default argument so each lambda keeps its OWN
# model. A lambda that looks up the comprehension variable when it is called
# would see only the final value, making every lens use the last model
# ('lff') regardless of its key.
logit_lenses = {
    name: (lambda res, model=model: LogitLens(model, res))
    for name, model in algorithms.items()
}

from sklearn.metrics import accuracy_score

def evaluate_model(model, X_test, Y_test):
    """Return the accuracy of ``model``'s predictions on ``X_test``.

    The model is called without gradient tracking; the prediction is the
    index of the largest output along dimension 2. Predictions are moved to
    the CPU and scored against ``Y_test`` with ``accuracy_score``.
    """
    with torch.no_grad():
        outputs = model(X_test)
        # torch.max over a dimension returns (values, indices); only the
        # indices (predicted classes) are needed.
        predicted = torch.max(outputs.data, 2).indices
    return accuracy_score(Y_test, predicted.detach().cpu())

# Per Group Logit Lens Accuracy

In [None]:
# Label tensors per group: LABEL_Y[y_key][a_key] holds PER_GROUP_REPR copies
# of the position of y_key among the keys of REPRS['erm']['layer0'].
# The attribute keys are always taken from the 'y0' entry.
LABEL_Y = {}
for y_idx, y_key in enumerate(REPRS['erm']['layer0'].keys()):
    LABEL_Y[y_key] = {
        a_key: torch.tensor([y_idx] * PER_GROUP_REPR)
        for a_key in REPRS['erm']['layer0']['y0'].keys()
    }

In [None]:
# Logit-lens accuracy for every algorithm, layer and group, stored as
# Y_PROBE_LOGITLENS[algorithm][layer]['<y_key>-<a_key>'].
Y_PROBE_LOGITLENS = {}

# TODO (original author's note): add rest
for algorithm_key in ['erm', 'groupdro', 'focal', 'jtt', 'lff']:
    per_layer = Y_PROBE_LOGITLENS[algorithm_key] = {}
    for layer_key in tqdm.tqdm(REPRS[algorithm_key].keys()):
        per_group = per_layer[layer_key] = {}
        for y_key, by_attribute in REPRS[algorithm_key][layer_key].items():
            for a_key, group_reprs in by_attribute.items():
                # Insert a singleton dimension at position 1 and move the
                # representations to the GPU before applying the lens.
                inputs = group_reprs.unsqueeze(dim=1).cuda()
                per_group[f'{y_key}-{a_key}'] = evaluate_model(
                    logit_lenses[algorithm_key],
                    inputs,
                    LABEL_Y[y_key][a_key],
                )

In [None]:
# One DataFrame per algorithm, built from its layer -> group -> accuracy
# mapping, concatenated side by side under an algorithm-level column key.
_frames_by_algorithm = {
    algorithm_key: pd.DataFrame(Y_PROBE_LOGITLENS[algorithm_key])
    for algorithm_key in ['erm', 'groupdro', 'jtt', 'lff', 'focal']
}

# Scale the accuracies by 100 and write the table to the results directory.
Y_PROBE_LOGITLENS_df = pd.concat(_frames_by_algorithm, axis=1) * 100
Y_PROBE_LOGITLENS_df.to_csv(path_to_root + f'/results/LogitLens/{DATASET}_seed{SEED}')

In [None]:
# Read the saved table back (first column as index, first two rows as a
# two-level column header) and round the values to one decimal place.
_saved_table = pd.read_csv(
    path_to_root + f'/results/LogitLens/{DATASET}_seed{SEED}',
    index_col=[0],
    header=[0, 1],
)
Y_PROBE_LOGITLENS_df = _saved_table.round(1)

In [None]:
import torch
import pandas as pd
import numpy as np

# Collect the saved logit-lens tables of seeds 0, 1 and 2 for one dataset.
DATASET = 'MultiNLI' # ['MultiNLI', 'CivilComments']
SEED = 0
LogitLensTables = []

# NOTE: the loop variable is the module-level SEED, so SEED is 2 once the
# loop has finished.
for SEED in (0, 1, 2):
    results_file = path_to_root + f'/results/LogitLens/{DATASET}_seed{SEED}'
    df = pd.read_csv(results_file, index_col=[0], header=[0, 1])
    LogitLensTables.append(df)