# Install

In [None]:
!pip install einops datasets jaxtyping better_abc fancy_einsum wandb netcal

# Setup

In [None]:
import sys
from google.colab import drive
# Mount Google Drive; force_remount=True is passed so the mount is redone
# even when the notebook already has one.
drive.mount('/content/drive', force_remount=True)

# Project root on the mounted drive. It is appended to sys.path so imports can
# be resolved from it (presumably where the `reprshift` package lives — confirm).
path_to_root = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'
sys.path.append(path_to_root)
print("Drive mounted.")

# Data directory under the project root.
data_path = f'{path_to_root}/data'

In [None]:
import torch
import tqdm
from reprshift.learning.algorithms import ERM
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI, CivilComments
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader

from reprshift.models.model_param_maps import ERM_to_HookedEncoder, load_focal, load_groupdro, load_jtt, load_lff
from reprshift.models.HookedEncoderConfig import bert_config

from transformer_lens2 import HookedEncoder, HookedTransformerConfig
import numpy as np

# Dataset

In [None]:
DATASET = 'MultiNLI'  # 'CivilComments' , 'MultiNLI'

# Per-dataset setup: number of labels and attributes, the test split, and the
# Drive folders for model checkpoints and extracted representations.
# Only the test split is loaded; the train/validation constructors are kept
# below as commented-out reference.
if DATASET == 'MultiNLI':
    NUM_CLASSES = 3
    NUM_ATTRIBUTES = 2
    # train_dataset = MultiNLI(data_path, 'tr', hparams=hparams_f('ERM'))
    # val_dataset = MultiNLI(data_path, 'va', hparams=hparams_f('ERM'))
    te_dataset = MultiNLI(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_mnli'
    representations_path = path_to_root + '/representations/representations_mnli'
elif DATASET == 'CivilComments':
    NUM_CLASSES = 2
    NUM_ATTRIBUTES = 8
    # train_dataset = CivilComments(data_path, 'tr', hparams=hparams_f('ERM'), granularity="fine")
    # val_dataset = CivilComments(data_path, 'va', hparams=hparams_f('ERM'))
    te_dataset = CivilComments(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_civilcomments'
    representations_path = path_to_root + '/representations/representations_civilcomments'
else:
    # Fail here rather than continue with undefined (or, in a notebook,
    # stale) NUM_CLASSES / te_dataset / paths from a previous run.
    raise NotImplementedError(f'Dataset Not Implemented: {DATASET!r}')
print(DATASET)

# Save CLS Repr

In [None]:
# Validation results used for checkpoint selection. The file path does not
# depend on the seed, so it is loaded once rather than once per iteration.
valmetrics = torch.load(path_to_root + '/results/ValidationMetrics/clean_val_results.pth')

for SEED in [0, 1, 2]:
    # Per-algorithm epoch stored under 'WGA Selection' for this dataset/seed.
    CURRENT_BEST_EPOCH = valmetrics[DATASET][SEED]['WGA Selection']

    ### MODELS ###
    # algorithm name -> {'path': checkpoint directory,
    #                    'load_f': adapter applied to the loaded state dict,
    #                    'epoch': epoch whose checkpoint file is read}.
    # 'pretrained' reads the epoch-0 checkpoint from the 00_randominit folder.
    MODELS = {
        'pretrained': {'path': models_path + '/00_randominit/', 'load_f': lambda x: x, 'epoch': 0},
        'erm': {'path': models_path + '/01_erm/', 'load_f': lambda x: x, 'epoch': CURRENT_BEST_EPOCH['erm']},
        'groupdro': {'path': models_path + '/03_groupdro/', 'load_f': load_groupdro, 'epoch': CURRENT_BEST_EPOCH['groupdro']},
        'jtt': {'path': models_path + '/06_jtt/', 'load_f': load_jtt, 'epoch': CURRENT_BEST_EPOCH['jtt']},
        'lff': {'path': models_path + '/07_lff/', 'load_f': load_lff, 'epoch': CURRENT_BEST_EPOCH['lff']},
        'focal': {'path': models_path + '/15_focal/', 'load_f': lambda x: x, 'epoch': CURRENT_BEST_EPOCH['focal']},
    }

    def load_and_init(algorithm_name, seed):
        """Return a HookedEncoder loaded with one algorithm's selected checkpoint.

        The checkpoint directory, state-dict adapter and epoch are read from
        MODELS[algorithm_name]. The file ``seed{seed}/sd_epoch{epoch}.pth``
        inside that directory is loaded with torch.load, passed through the
        adapter, and converted with ERM_to_HookedEncoder before being loaded
        into a newly constructed HookedEncoder.

        Args:
            algorithm_name: key into MODELS.
            seed: seed index used in the checkpoint sub-folder name.

        Returns:
            The HookedEncoder with the converted state dict loaded.
        """
        spec = MODELS[algorithm_name]
        checkpoint_file = f"{spec['path']}seed{seed}/sd_epoch{spec['epoch']}.pth"
        sd = spec['load_f'](torch.load(checkpoint_file))

        encoder_config = HookedTransformerConfig(**bert_config(NUM_CLASSES))
        bert = HookedEncoder(encoder_config)
        # The converter receives both the loaded weights and the encoder's
        # current state dict and returns what load_state_dict consumes.
        converted_sd = ERM_to_HookedEncoder(sd, bert.state_dict())
        bert.load_state_dict(converted_sd)
        return bert

    ######################################
    # Checkpointed models loaded for this seed (keys of MODELS).
    algorithm_names = ['pretrained', 'erm', 'groupdro', 'jtt', 'lff', 'focal']

    # Maximum number of samples kept per (label, attribute) group.
    if DATASET == 'CivilComments':
        per_group_repr = 225
    elif DATASET == 'MultiNLI':
        per_group_repr = 600
    else:
        # Without this branch the variable would stay undefined, or keep a
        # stale value from an earlier notebook run with another dataset.
        raise NotImplementedError(f'No per-group sample budget defined for dataset {DATASET!r}')

    # Every model is moved to the GPU and switched to eval mode once, up front.
    algorithms = {
        algorithm_name: load_and_init(algorithm_name, SEED).cuda().eval()
        for algorithm_name in algorithm_names
    }


    ######################################
    ### Add Randomly Initialized Model ###
    ######################################
    # NOTE(review): `cfg` (with init_mode overridden) is built here but is not
    # passed to the encoder below, which uses a fresh bert_config(NUM_CLASSES).
    # It is kept so the namespace is unchanged; confirm whether it was meant
    # to be used. Every parameter is re-drawn by initialize_weights afterwards.
    cfg = bert_config(NUM_CLASSES)
    cfg['init_mode'] = 'xavier_normal'
    # Same device and mode as the checkpointed models, which are all put
    # through .cuda().eval() when they are loaded.
    bert_random = HookedEncoder(HookedTransformerConfig(**bert_config(NUM_CLASSES))).cuda().eval()
    algorithms['randominit'] = bert_random
    # From here on the random baseline is processed together with (and before)
    # the checkpointed models; this order is also the key order of LOGITS/REPRS.
    algorithm_names = ['randominit', 'pretrained', 'erm', 'groupdro', 'focal', 'jtt', 'lff', ]
    # Re-draw every parameter of a model from a normal distribution, in place.
    # NOTE: this is a plain N(mean=0, std=0.02) initialisation, not Xavier, and
    # it is applied uniformly to all parameters the module exposes.
    def initialize_weights(model):
        """Overwrite all parameters of ``model`` with samples from N(0, 0.02**2).

        Args:
            model: a torch.nn.Module; its parameters are modified in place.

        Returns:
            None.
        """
        for param in model.parameters():
            # torch.nn.init functions run under no_grad, so the parameter can
            # be passed directly instead of going through ``.data``.
            torch.nn.init.normal_(param, mean=0.0, std=0.02)
    # Re-initialise the random-baseline encoder in place. The same object is
    # already registered as algorithms['randominit'], so that entry is updated too.
    initialize_weights(bert_random)


    #########################################################
    ### Do the Forward Passes and Get CLS Representations ###
    #########################################################
    unique_y = list(range(NUM_CLASSES))
    unique_a = list(range(NUM_ATTRIBUTES))

    # LOGITS[algorithm]['y<label>']['a<attribute>'] starts as an empty list
    # that per-batch tensors are appended to.
    LOGITS = {
        algorithm_name: {
            f'y{y_idx}': {f'a{a_idx}': [] for a_idx in unique_a}
            for y_idx in unique_y
        }
        for algorithm_name in algorithm_names
    }

    # REPRS has the same layout with an extra leading 'layer<i>' level, one
    # entry for each of the 12 layer indices.
    REPRS = {
        algorithm_name: {
            f'layer{i}': {
                f'y{y_idx}': {f'a{a_idx}': [] for a_idx in unique_a}
                for y_idx in unique_y
            }
            for i in range(12)
        }
        for algorithm_name in algorithm_names
    }

    # For every (label, attribute) group: scan the test loader, keep only the
    # samples belonging to the group, run them through every model, and store
    # logits plus the position-0 output of each layer's
    # 'hook_normalized_resid_post' (the CLS representation, per the heading).
    for Y_CURR in unique_y:
        for A_CURR in unique_a:
            y_key, a_key = f'y{Y_CURR}', f'a{A_CURR}'
            print(y_key, a_key)
            # A new loader and iterator are created for every group.
            val_loader = FastDataLoader(  dataset=te_dataset,
                                  batch_size=32,
                                  num_workers=1,
                                  )
            train_minibatches_iterator = iter(val_loader)

            # Running count of samples collected for this group. It replaces
            # re-summing the shapes of all stored batches on every step.
            collected = 0
            for step in tqdm.tqdm(range(len(val_loader))):
                # Stop as soon as the budget is met. The stored tensors are
                # truncated to per_group_repr below, so running one more batch
                # once the count equals the budget would be wasted work.
                if collected >= per_group_repr:
                    break
                i, x, y, a = next(train_minibatches_iterator)

                A_MASK = (a == A_CURR)
                Y_MASK = (y == Y_CURR)
                MASK = A_MASK & Y_MASK
                x = x[MASK]
                collected += x.shape[0]
                # The last axis of x carries three input fields, indexed 0..2.
                input_ids = x[:,:,0].cuda()
                one_zero_attention_mask = x[:,:,1].cuda()
                token_type_ids = x[:,:,2].cuda()

                for algorithm_name in algorithm_names:
                    bert = algorithms[algorithm_name]
                    with torch.no_grad():
                        logits, cache = bert.run_with_cache(input_ids, one_zero_attention_mask=one_zero_attention_mask, token_type_ids=token_type_ids)
                        LOGITS[algorithm_name][y_key][a_key].append(logits.detach().cpu())
                        for layer_idx in range(12):
                            REPRS[algorithm_name][f'layer{layer_idx}'][y_key][a_key].append(cache[f'blocks.{layer_idx}.hook_normalized_resid_post'][:,0,:].detach().cpu())
                    # Drop the GPU-side results before the next model runs.
                    del logits, cache
                    torch.cuda.empty_cache()

            # Collapse the per-batch lists into one tensor per group, capped at
            # per_group_repr rows.
            for algorithm_name in algorithm_names:
                LOGITS[algorithm_name][y_key][a_key] = torch.cat(LOGITS[algorithm_name][y_key][a_key])[:per_group_repr]
                for layer_idx in range(12):
                    layer_store = REPRS[algorithm_name][f'layer{layer_idx}'][y_key]
                    layer_store[a_key] = torch.cat(layer_store[a_key])[:per_group_repr]

    ######################################
    ### Save Random Matrix of Same Size ##
    ######################################
    # 'random' baseline: for every layer / label / attribute entry of the 'erm'
    # representations, a tensor of the same shape and dtype filled with
    # uniform samples from [0, 1) (torch.rand_like).
    algorithm_key = 'random'
    REPRS[algorithm_key] = {
        layer_key: {
            y_key: {
                a_key: torch.rand_like(reference)
                for a_key, reference in per_attribute.items()
            }
            for y_key, per_attribute in per_label.items()
        }
        for layer_key, per_label in tqdm.tqdm(REPRS['erm'].items())
    }

    ###############################
    ### Save Everything to Drive ##
    ###############################
    # Two files per seed under representations_path:
    # seed<SEED>_logits and seed<SEED>_reprs.
    for payload, suffix in ((LOGITS, '_logits'), (REPRS, '_reprs')):
        torch.save(payload, f'{representations_path}/seed{SEED}' + suffix)