# Install

In [None]:
!pip install einops datasets jaxtyping better_abc fancy_einsum wandb netcal

# Setup

In [None]:
import sys
from google.colab import drive
# Mount Google Drive under /content/drive (force_remount=True is passed on every run).
drive.mount('/content/drive', force_remount=True)
print("Drive mounted.")

# Project root on Drive; it is appended to sys.path (presumably so the project packages import).
path_to_root = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'
data_path = f'{path_to_root}/data'
sys.path.append(path_to_root)

In [None]:
import torch
import tqdm
from reprshift.learning.algorithms import ERM
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI, CivilComments
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader

from reprshift.models.model_param_maps import ERM_to_HookedEncoder, load_focal, load_groupdro, load_jtt, load_lff
from reprshift.models.HookedEncoderConfig import bert_config

from transformer_lens2 import HookedEncoder, HookedTransformerConfig
import numpy as np
import pandas as pd

# Dataset

In [None]:
# Dataset and seed selection for the whole notebook.
DATASET = 'MultiNLI'  # 'CivilComments' , 'MultiNLI'
SEED = 2

# Per-dataset settings: class/attribute counts and the folders holding the
# trained models and the cached representations.
if DATASET == 'MultiNLI':
    NUM_CLASSES = 3
    NUM_ATTRIBUTES = 2
    # train_dataset = MultiNLI(data_path, 'tr',  hparams=hparams_f('ERM'))
    # val_dataset = MultiNLI(data_path, 'va', hparams=hparams_f('ERM'))
    # te_dataset = MultiNLI(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_mnli'
    representations_path = path_to_root + '/representations/representations_mnli'
elif DATASET == 'CivilComments':
    NUM_CLASSES = 2
    NUM_ATTRIBUTES = 8
    # train_dataset = CivilComments(data_path, 'tr',  hparams=hparams_f('ERM'), granularity="fine")
    # val_dataset = CivilComments(data_path, 'va', hparams=hparams_f('ERM'))
    # te_dataset = CivilComments(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_civilcomments'
    representations_path = path_to_root + '/representations/representations_civilcomments'
else:
    # Fail here rather than with a NameError on NUM_CLASSES / models_path in a later cell.
    raise ValueError(
        f"Dataset Not Implemented: {DATASET!r} (expected 'MultiNLI' or 'CivilComments')"
    )
print(DATASET)

# Models

In [None]:
# Stored validation results; the 'A Selection' entry for this dataset and seed
# is indexed by algorithm name below to pick each checkpoint epoch.
valmetrics = torch.load(path_to_root + '/results/ValidationMetrics/clean_val_results.pth')
CURRENT_BEST_EPOCH = valmetrics[DATASET][SEED]['A Selection']

### MODELS ###
# One entry per algorithm: checkpoint directory, a function applied to the raw
# loaded checkpoint ('load_f'), and the epoch of the checkpoint to load.
MODELS = {
    name: {'path': models_path + subdir, 'load_f': load_f, 'epoch': epoch}
    for name, subdir, load_f, epoch in (
        ('pretrained', '/00_randominit/', lambda x: x, 0),
        ('erm', '/01_erm/', lambda x: x, CURRENT_BEST_EPOCH['erm']),
        ('groupdro', '/03_groupdro/', load_groupdro, CURRENT_BEST_EPOCH['groupdro']),
        ('jtt', '/06_jtt/', load_jtt, CURRENT_BEST_EPOCH['jtt']),
        ('lff', '/07_lff/', load_lff, CURRENT_BEST_EPOCH['lff']),
        ('focal', '/15_focal/', lambda x: x, CURRENT_BEST_EPOCH['focal']),
    )
}

def load_and_init(algorithm_name, seed):
    """Build a HookedEncoder and load one algorithm's stored checkpoint into it.

    Args:
        algorithm_name: Key into ``MODELS`` selecting the checkpoint directory,
            the ``load_f`` preprocessing function and the epoch.
        seed: Seed number; selects the ``seed{seed}`` sub-directory.

    Returns:
        The HookedEncoder after ``load_state_dict`` with the converted weights.
    """
    spec = MODELS[algorithm_name]
    checkpoint_file = f"{spec['path']}seed{seed}/sd_epoch{spec['epoch']}.pth"
    # The algorithm-specific load_f is applied to the raw checkpoint first.
    sd = spec['load_f'](torch.load(checkpoint_file))
    bert = HookedEncoder(HookedTransformerConfig(**bert_config(NUM_CLASSES)))
    # Convert the checkpoint to the HookedEncoder layout, using the fresh model's state dict.
    converted_sd = ERM_to_HookedEncoder(sd, bert.state_dict())
    bert.load_state_dict(converted_sd)
    return bert

In [None]:
# Load the 'erm' checkpoint for the configured seed, then move it to the GPU in eval mode.
bert = load_and_init('erm', SEED)
bert = bert.cuda().eval()
# Freeze every BERT parameter.
for weight in bert.parameters():
    weight.requires_grad_(False)

In [None]:
# Display the 'A Selection' entry that supplies each algorithm's checkpoint epoch.
CURRENT_BEST_EPOCH

# Cache Representations

In [None]:
# Per-group sample count (the confusion-matrix cell below divides by it).
if DATASET == 'CivilComments':
    per_group_repr = 150
elif DATASET == 'MultiNLI':
    per_group_repr = 600
else:
    # Without this branch per_group_repr would silently stay undefined.
    raise ValueError(f"No per_group_repr configured for dataset {DATASET!r}")

# Label / attribute indices and their string keys ('y0', 'a0', ...).
unique_y = list(range(NUM_CLASSES))
unique_a = list(range(NUM_ATTRIBUTES))
y_s = [f'y{i}' for i in unique_y]
a_s = [f'a{i}' for i in unique_a]

# Hook names of the cached activations.
REPR_KEYS = ['blocks.11.ln1.hook_normalized', 'blocks.11.mlp.hook_post' , 'blocks.11.hook_mlp_out', 'blocks.11.ln2.hook_normalized']
# Empty containers: REPRS[hook][y][a] and LOGITS[y][a] each start as an empty list.
REPRS = {repr_key: {y_key: {a_key: [] for a_key in a_s} for y_key in y_s} for repr_key in REPR_KEYS}
LOGITS = {y_key: {a_key: [] for a_key in a_s} for y_key in y_s}
LOGITS, REPRS

In [None]:
# Split whose cached logits/representations are used; it is part of the cache file names.
# The caching loop below is commented out; the next cell loads previously saved tensors.
SPLIT = 'va'  # 'tr', 'va', 'te'

# if SPLIT == 'tr':
#     dataset = train_dataset
# elif SPLIT == 'va':
#     dataset = val_dataset
# elif SPLIT == 'te':
#     dataset = te_dataset
# else:
#     print('Split Unavailable')

# for Y_CURR in unique_y:
#     for A_CURR in unique_a:
#         print(f'y{Y_CURR}', f'a{A_CURR}')
#         val_loader = FastDataLoader(  dataset=dataset,
#                               batch_size=32,
#                               num_workers=1,
#                               )
#         train_minibatches_iterator = iter(val_loader)

#         for step in tqdm.tqdm(range(len(val_loader))):
#             total_group_members = np.sum([batch.shape[0] for batch in LOGITS[f'y{Y_CURR}'][f'a{A_CURR}']])
#             if total_group_members > per_group_repr:
#                 break
#             i, x, y, a = next(train_minibatches_iterator)

#             A_MASK = (a == A_CURR)
#             Y_MASK = (y == Y_CURR)
#             MASK = A_MASK & Y_MASK
#             x = x[MASK]
#             input_ids = x[:,:,0].cuda()
#             one_zero_attention_mask = x[:,:,1].cuda()
#             token_type_ids = x[:,:,2].cuda()

#             with torch.no_grad():
#                 logits, cache = bert.run_with_cache(input_ids, one_zero_attention_mask=one_zero_attention_mask, token_type_ids=token_type_ids)
#                 LOGITS[f'y{Y_CURR}'][f'a{A_CURR}'].append(logits[:,0,:].detach().cpu())
#                 for repr_key in REPR_KEYS:
#                     REPRS[repr_key][f'y{Y_CURR}'][f'a{A_CURR}'].append(cache[repr_key][:,0,:].detach().cpu())
#             del logits, cache
#             torch.cuda.empty_cache()
#         LOGITS[f'y{Y_CURR}'][f'a{A_CURR}'] = torch.cat(LOGITS[f'y{Y_CURR}'][f'a{A_CURR}'])[:per_group_repr]
#         for repr_key in REPR_KEYS:
#             REPRS[repr_key][f'y{Y_CURR}'][f'a{A_CURR}'] =  torch.cat(REPRS[repr_key][f'y{Y_CURR}'][f'a{A_CURR}'])[:per_group_repr]

# torch.save(LOGITS, representations_path + f'/ModelEdit/LOGITS_{SPLIT}_seed{SEED}' )
# torch.save(REPRS, representations_path + f'/ModelEdit/REPRS_{SPLIT}_seed{SEED}' )

In [None]:
# Replace the empty containers with the tensors saved for this split and seed.
LOGITS, REPRS = (
    torch.load(f'{representations_path}/ModelEdit/{kind}_{SPLIT}_seed{SEED}')
    for kind in ('LOGITS', 'REPRS')
)

In [None]:
# Display the seed used in the cache file names loaded above.
SEED

In [None]:
### Sanity Check:  bert.W_out[11] ###
def mlp_out(x):
    """Map `x` through bert.W_out[11] and add bert.b_out[11] (x is moved to the GPU first)."""
    return (x.cuda() @ bert.W_out[11]) + bert.b_out[11]

# Recompute the cached 'hook_mlp_out' from the cached 'hook_post' for group (y0, a0).
cached_mlp_post = REPRS['blocks.11.mlp.hook_post']['y0']['a0']
cached_mlp_out = REPRS['blocks.11.hook_mlp_out']['y0']['a0']
recomputed_mlp_out = mlp_out(cached_mlp_post)

# Show both shapes and one tolerance-based verdict. The previous version discarded
# the shapes (not the last expression) and compared rounded values element-wise,
# which is sensitive to values sitting on a rounding boundary.
(
    recomputed_mlp_out.shape,
    cached_mlp_out.shape,
    torch.allclose(cached_mlp_out.cuda(), recomputed_mlp_out, atol=1e-4),
)

In [None]:
### Sanity Check:  Logit Lens ###
def Layer11Norm2(res):
    """Apply bert.blocks[11].ln2 to `res` after inserting a dimension at index 1 and moving it to the GPU."""
    expanded = res.unsqueeze(dim=1)
    return bert.blocks[11].ln2(expanded.cuda())
def LogitLens(res):
    """Pass `res` through bert.mlm_head and then bert.unembed."""
    return bert.unembed(bert.mlm_head(res))
def LogitLenswLayer11Norm2(res):
    """Apply block 11's ln2, then mlm_head and unembed, and return position 0 of the result.

    `res` gets a dimension inserted at index 1 and is moved to the GPU before
    the layer norm, so the returned slice `[:, 0, :]` selects that inserted position.
    """
    normalized = bert.blocks[11].ln2(res.unsqueeze(dim=1).cuda())
    logits = bert.unembed(bert.mlm_head(normalized))
    return logits[:, 0, :]

# Implementation with Matrix Multiplication
def logit_lens_with_norm(cls_reprs, sd=None, eps=1e-12):
    """Layer-norm `cls_reprs` and map them to logits with explicit matrix products.

    Args:
        cls_reprs: 2-D tensor (rows are examples); normalised along dim 1.
        sd: State dict providing 'blocks.11.ln2.w/b', 'mlm_head.W/b' and
            'unembed.W_U/b_U'. Defaults to `bert.state_dict()` looked up at call
            time, so a reloaded `bert` is picked up (a default evaluated at
            definition time would keep pointing at the old weights).
        eps: Value added to the variance before the square root.
            NOTE(review): 1e-12 is the usual BERT LayerNorm epsilon - confirm
            against the model config.

    Returns:
        Tensor of logits, one row per input row.
    """
    if sd is None:
        sd = bert.state_dict()
    ln_weight = sd['blocks.11.ln2.w']
    # Follow the weights' device instead of hard-coding .cuda().
    cls_reprs = cls_reprs.to(ln_weight.device)
    centered = cls_reprs - cls_reprs.mean(dim=1, keepdim=True)
    # LayerNorm uses the biased (population) variance plus eps; Tensor.std()
    # defaults to the unbiased estimator and would give slightly different values.
    variance = centered.pow(2).mean(dim=1, keepdim=True)
    normalized = centered / torch.sqrt(variance + eps)
    normalized = (ln_weight * normalized) + sd['blocks.11.ln2.b']
    out = normalized @ sd['mlm_head.W'].T
    out = out + sd['mlm_head.b']
    out = torch.tanh(out)
    out = out @ sd['unembed.W_U']
    out = out + sd['unembed.b_U']
    return out

# Sum of the cached ln1 output and the cached MLP output for group (y0, a0).
resid_post = (
    REPRS['blocks.11.ln1.hook_normalized']['y0']['a0']
    + REPRS['blocks.11.hook_mlp_out']['y0']['a0']
)
# Alternative check using the matrix-multiplication implementation:
# logit_lens_with_norm(resid_post)[:,:], LOGITS['y0']['a0']
# Component-based implementation, shown next to the cached logits of the same group.
(LogitLenswLayer11Norm2(resid_post), LOGITS['y0']['a0'])

# Logit Shift

In [None]:
# Mean logit vector of every (label, attribute) group, one DataFrame per attribute.
mean_logits_dict = {}
for ai in a_s:
    per_label_means = [LOGITS[yi][ai].mean(dim=0).numpy() for yi in y_s]
    # Transpose so columns correspond to the true label and rows to the logit index.
    mean_logits_dict[ai] = pd.DataFrame(np.array(per_label_means).T, columns=y_s)

mean_logits_df = pd.concat(mean_logits_dict, axis=1)
mean_logits_df.index = y_s

# Name the output after the split the cached logits were loaded for, instead of
# hard-coding "Val" and toggling commented-out lines by hand.
split_tag = {'tr': 'Train', 'va': 'Val', 'te': 'Test'}[SPLIT]
mean_logits_df_PATH = path_to_root + f'/results/LogitShift/MeanLogitDiff{split_tag}_{DATASET}_seed{SEED}'
mean_logits_df.to_csv(mean_logits_df_PATH)

# Confusion Matrix

In [None]:
# Confusion matrix per attribute: conf_matrix[ai][true_label, pred_label] holds counts.
conf_matrix = {ai: np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int) for ai in a_s}
conf_matrix_dict = {}

pred_names = [f'Pred y{i}' for i in range(NUM_CLASSES)]
true_names = [f'True y{i}' for i in range(NUM_CLASSES)]

for ai in a_s:
    for true_label, yi in enumerate(y_s):
        pred_labels_ai = LOGITS[yi][ai].argmax(dim=1).numpy()
        # Count every predicted class in one vectorised call.
        conf_matrix[ai][true_label] += np.bincount(pred_labels_ai, minlength=NUM_CLASSES)
    # Normalise each true-label row by the number of samples actually cached for
    # that group; dividing by per_group_repr is wrong when a group has fewer
    # samples. np.maximum guards against an empty group.
    group_sizes = conf_matrix[ai].sum(axis=1, keepdims=True)
    percentages = conf_matrix[ai] / np.maximum(group_sizes, 1) * 100
    # Transposed so rows are predictions and columns are true labels.
    conf_matrix_dict[ai] = pd.DataFrame(percentages, columns=pred_names, index=true_names).T

accuracy_df = pd.concat(conf_matrix_dict, axis=1).round(2)

# Name the output after the split the cached logits were loaded for.
split_tag = {'tr': 'Train', 'va': 'Val', 'te': 'Test'}[SPLIT]
accuracy_df_PATH = path_to_root + f'/results/LogitShift/ConfusionMatrix{split_tag}_{DATASET}_seed{SEED}'
accuracy_df.to_csv(accuracy_df_PATH)

In [None]:
# Display the per-attribute confusion matrices (percentages, rounded to 2 decimals).
accuracy_df

# Logit Modification

In [None]:
# Amount subtracted from one logit in the next cell.
modification_factor = 1

In [None]:
# Class whose logit is lowered and whose prediction count is reported.
label = 0

# Cached logits of the (y2, a1) group.
logits = LOGITS['y2']['a1']
# Number (a count, not a ratio) of examples predicted as `label` before the modification.
accuracy = (logits.argmax(dim=1) == label).sum()

# Build the offset from the logits' own width rather than a hard-coded
# 3-element tensor, so the cell also works when the model has a different
# number of classes. For label == 0 this equals the previous [factor, 0, 0].
logit_offset = torch.zeros(logits.shape[1], dtype=logits.dtype)
logit_offset[label] = modification_factor
modified_logits = logits - logit_offset

# Same count after lowering the `label` logit by modification_factor.
modified_accuracy = (modified_logits.argmax(dim=1) == label).sum()
accuracy, modified_accuracy