# Setup

## Install

In [None]:
!pip install einops datasets jaxtyping better_abc fancy_einsum wandb

## Mount Drive

In [None]:
import sys
from google.colab import drive
# Mount Google Drive (forcing a fresh mount) and make the dissertation
# package root importable.
path_to_tez = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'
drive.mount('/content/drive', force_remount=True)
sys.path += [path_to_tez]
print("Drive mounted.")

## Imports

In [None]:
import torch
import tqdm
from reprshift.learning.algorithms import ERM
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader
import pandas as pd
import numpy as np

from transformers import AutoTokenizer
from transformer_lens2 import HookedEncoder, HookedTransformerConfig
from reprshift.models.model_param_maps import ERM_to_HookedEncoder

from jaxtyping import Float
from functools import partial
from transformer_lens2 import utils, ActivationCache
import einops
import random
import pickle

## Plots

In [None]:
# Reference: https://colab.research.google.com/drive/1IgVv13RLWO_YnpgSxN2XG8qp_MHy6q0j#scrollTo=MKaFMPo8egxm

import plotly.express as px

def imshow(tensor, **kwargs):
    """Show a 2-D tensor as a red/blue heatmap whose colour scale is centred on 0.

    Extra keyword arguments are forwarded to ``plotly.express.imshow``.
    """
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        **kwargs,
    )
    fig.show()


def line(tensor, **kwargs):
    """Plot a 1-D tensor as a line (x = element index); kwargs go to ``px.line``."""
    values = utils.to_numpy(tensor)
    px.line(y=values, **kwargs).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Scatter-plot ``y`` against ``x``.

    ``xaxis``, ``yaxis`` and ``caxis`` label the x axis, y axis and colour
    scale; remaining kwargs are forwarded to ``px.scatter``.
    """
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(
        x=utils.to_numpy(x),
        y=utils.to_numpy(y),
        labels=axis_labels,
        **kwargs,
    )
    fig.show()

# Data

## Load Data

In [None]:
data_path = 'drive/MyDrive/Colab Notebooks/BatuEl_Dissertation/data/multinli'

metadata_file = 'metadata_multinli.csv'
train_uncased_128_mnli = 'glue_data/MNLI/cached_train_bert-base-uncased_128_mnli'
dev_uncased_128_mnli = 'glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli'
dev_uncased_128_mnli_mm = 'glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm'

### Features
def _load_cached_features(relative_path):
    """Load one cached feature file (a pickled list of feature objects)."""
    with open(data_path + '/' + relative_path, 'rb') as f:
        # The caches hold arbitrary pickled Python objects, not plain tensors, so
        # PyTorch >= 2.6 (default weights_only=True) refuses them. Only do this for
        # trusted, locally generated files: unpickling can execute code.
        return torch.load(f, weights_only=False)

train = _load_cached_features(train_uncased_128_mnli)
dev_matched = _load_cached_features(dev_uncased_128_mnli)
dev_mismatched = _load_cached_features(dev_uncased_128_mnli_mm)
# Row order (train, dev matched, dev mismatched) is assumed to match the rows
# of metadata_multinli.csv.
features = np.array(train + dev_matched + dev_mismatched)

### Metadata
metadata = pd.read_csv(data_path + '/' + metadata_file)
# get_features selects rows of `features` with boolean masks built from
# `metadata`, so both must have one entry per example.
assert len(features) == len(metadata), (len(features), len(metadata))

# Preprocessing: pack token ids, attention mask and segment ids into one tensor
def preprocess_features(features_array):
    """Stack the BERT input fields of each feature object into one LongTensor.

    Args:
        features_array: iterable (list or 1-D object array) of feature objects
            exposing equal-length ``input_ids``, ``input_mask`` and
            ``segment_ids`` sequences.

    Returns:
        LongTensor of shape (num_examples, seq_len, 3): channel 0 holds the
        token ids, channel 1 the attention mask and channel 2 the segment
        (token type) ids.
    """
    all_input_ids = torch.tensor([f.input_ids for f in features_array], dtype=torch.long)
    all_input_masks = torch.tensor([f.input_mask for f in features_array], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features_array], dtype=torch.long)
    # label_id is not part of the returned tensor, so it is no longer read.
    return torch.stack((all_input_ids, all_input_masks, all_segment_ids), dim=2)

## Validation & Test Sets

In [None]:
def get_features(split, num_examples, device='cuda'):
    """Randomly sample up to ``num_examples`` examples per (attribute, label) group.

    Reads the module-level ``features`` array and ``metadata`` frame.

    Args:
        split: value of ``metadata['split']`` to sample from.
        num_examples: maximum number of rows kept per (a, y) group; smaller
            groups are returned whole.
        device: device the returned tensors are moved to.

    Returns:
        ``(X_array_a0, X_array_a1)``: dicts mapping label y in {0, 1, 2} to a
        LongTensor (n, seq_len, 3) as produced by ``preprocess_features``.

    Raises:
        ValueError: if an (a, y) group has no rows in this split.
    """
    in_split = metadata['split'] == split
    groups = {}
    for a in (0, 1):
        groups[a] = {}
        for y in (0, 1, 2):
            group = features[(in_split & (metadata['a'] == a) & (metadata['y'] == y)).to_numpy()]
            if len(group) == 0:
                raise ValueError(f'No examples for split={split}, a={a}, y={y}')
            # Draw the permutation first and preprocess only the kept rows. This
            # consumes the RNG in the same order and yields the same rows as
            # "preprocess everything, then shuffle", without building the whole
            # group and copying it to the GPU.
            keep = torch.randperm(len(group))[:num_examples]
            groups[a][y] = preprocess_features(group[keep.numpy()]).to(device)
    return groups[0], groups[1]

# Fix the RNG state so the sampled examples (and later random draws such as the
# control-token choices) are reproducible across kernel restarts.
DATA_SEED = 0
torch.manual_seed(DATA_SEED)
random.seed(DATA_SEED)
X_array_a0_va, X_array_a1_va = get_features(1, 600)   # split 1 -> validation
X_array_a0_te, X_array_a1_te = get_features(2, 800)   # split 2 -> test

## Define Shortcuts and Counterfactuals

In [None]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def _single_token_id(word):
    """Vocabulary id of ``word``; raises ValueError unless it is one WordPiece."""
    input_ids = tokenizer(word)['input_ids']
    # Expected encoding is [CLS] word [SEP]. Any other length means the word was
    # split into several pieces, and a one-for-one id swap would be wrong.
    if len(input_ids) != 3:
        raise ValueError(f'{word!r} is not a single WordPiece token: {input_ids}')
    return input_ids[1]


# Negation words used as shortcut cues, and their non-negated counterparts.
shortcut_tokens = ['no', 'nothing' ,  'nobody' ,  'never']
shortcut_token_index = {token: _single_token_id(token) for token in shortcut_tokens}
shortcut_token_ids = [shortcut_token_index[token] for token in shortcut_tokens]
print('Shortcuts:', shortcut_token_index)

counterfactual_tokens = ['some', 'everything' ,  'everybody' ,  'always']
counterfactual_token_index = {token: _single_token_id(token) for token in counterfactual_tokens}
counterfactual_token_ids = [counterfactual_token_index[token] for token in counterfactual_tokens]
print('Counterfactuals:', counterfactual_token_index)

shortcuts_to_counterfactuals = {'no':'some', 'nothing':'everything', 'nobody':'everybody','never':'always'}
shortcuts_to_counterfactuals_tokens = {
    shortcut_token_index[shortcut]: counterfactual_token_index[counterfactual]
    for shortcut, counterfactual in shortcuts_to_counterfactuals.items()
}
print('Map:', shortcuts_to_counterfactuals_tokens)

In [None]:
def corrupt_shortcuts(X_array):
    X_array_corrupt = X_array.detach().clone()
    num_examples = X_array.shape[0]
    seq_length =  X_array.shape[1]
    for example_idx in range(num_examples):
        for pos_idx in range(seq_length):
            if X_array[example_idx][pos_idx][0] in list(shortcuts_to_counterfactuals_tokens.keys()):
                if X_array[example_idx][pos_idx][2] == 1:
                    shortcut_token = X_array[example_idx][pos_idx][0].item()
                    counterfactual_token = shortcuts_to_counterfactuals_tokens[shortcut_token]
                    X_array_corrupt[example_idx][pos_idx][0] = counterfactual_token
    return X_array_corrupt

# Shortcut-corrupted copies of the a=1 validation and test examples, per label.
X_array_a1_va_corrupt = {label: corrupt_shortcuts(X_by_label) for label, X_by_label in X_array_a1_va.items()}
X_array_a1_te_corrupt = {label: corrupt_shortcuts(X_by_label) for label, X_by_label in X_array_a1_te.items()}

In [None]:
# Shapes of the corrupted sets: validation on the first line, test on the second.
for corrupt_split in (X_array_a1_va_corrupt, X_array_a1_te_corrupt):
    print(*(corrupt_split[label].shape for label in (0, 1, 2)))

In [None]:
# Control inputs: overwrite one random segment-1 token with a random counterfactual word.
def corrupt_controls(X_array, replacement_token_ids=None, exclude_token_ids=None):
    """Replace one randomly chosen segment-1 token per example with a random word.

    The candidate positions are those whose segment id (channel 2) is 1, minus
    the last one (the closing [SEP]). One position is drawn with
    ``torch.randint``, then the replacement id with ``random.choice``, in that
    order, so a fixed seed reproduces the original behaviour.

    Args:
        X_array: LongTensor (num_examples, seq_len, 3) laid out as in
            preprocess_features.
        replacement_token_ids: ids to sample the inserted word from; defaults to
            the module-level ``counterfactual_token_ids``.
        exclude_token_ids: optional ids whose positions may not be chosen. By
            default the shortcut word itself can be picked, and the control then
            coincides with the corruption. Pass ``shortcut_token_ids`` to avoid that.

    Returns:
        A new tensor; ``X_array`` is not modified.

    Raises:
        ValueError: if an example has no eligible position.
    """
    if replacement_token_ids is None:
        replacement_token_ids = counterfactual_token_ids
    X_array_control = X_array.detach().clone()
    positions = torch.arange(X_array.shape[1], device=X_array.device)
    excluded = None
    if exclude_token_ids is not None:
        excluded = torch.tensor(list(exclude_token_ids), device=X_array.device)
    for example_idx in range(X_array.shape[0]):
        candidates = positions[X_array[example_idx, :, 2] == 1]
        candidates = candidates[:-1]  # exclude [SEP]
        if excluded is not None:
            keep = ~torch.isin(X_array[example_idx, candidates, 0], excluded)
            candidates = candidates[keep]
        if candidates.numel() == 0:
            raise ValueError(f'Example {example_idx} has no eligible segment-1 token to replace')
        pick = torch.randint(0, candidates.size(0), (1,))
        X_array_control[example_idx, candidates[pick], 0] = random.choice(replacement_token_ids)
    return X_array_control

# Control copies (one random hypothesis token swapped) of the a=1 examples, per label.
X_array_a1_va_control = {label: corrupt_controls(X_by_label) for label, X_by_label in X_array_a1_va.items()}
X_array_a1_te_control = {label: corrupt_controls(X_by_label) for label, X_by_label in X_array_a1_te.items()}

In [None]:
from transformers import BertTokenizer
# Reuse the tokenizer already loaded for the shortcut definitions; load it only
# if this cell is run without that one.
try:
    tokenizer
except NameError:
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def show_sents(X_array, n_examples):
    """Print the first ``n_examples`` rows of ``X_array`` decoded to text.

    Each row is truncated to the number of ones in its attention mask
    (channel 1), so trailing padding is not shown; special tokens are kept.
    """
    # (The old version also batch-decoded every row into an unused variable.)
    lengths = X_array[:n_examples, :, 1].sum(dim=1).tolist()
    for idx, (token_ids, length) in enumerate(zip(X_array[:n_examples, :, 0], lengths)):
        print(idx, tokenizer.decode(token_ids[:length]))

# Two validation examples per label, each as clean, shortcut-corrupted and control input.
label_names = ("Contradiction", "Entailment", "Neutral")
input_versions = (
    ('Clean:', X_array_a1_va),
    ('Corrupted:', X_array_a1_va_corrupt),
    ('Control:', X_array_a1_va_control),
)
for label_idx, label_name in enumerate(label_names):
    print(label_name)
    for caption, X_by_label in input_versions:
        print(caption)
        show_sents(X_by_label[label_idx], 2)

## Shortcuts Breakdown

In [None]:
# NOTE(review): dead code kept for reference (per-shortcut breakdown).
# If re-enabled: `X_array_corrupt = X_array_a1[label_key]` aliases the clean
# tensor, so the in-place swaps below would also corrupt X_array_a1; use .clone().
# X_array_a0 = X_array_a0_va
# X_array_a1 = X_array_a1_va

# X_a1 = {0:{2053:[],2498:[],6343:[], 2196:[],},1:{2053:[],2498:[],6343:[], 2196:[],},2:{2053:[],2498:[],6343:[], 2196:[],},}
# X_a1_corrupt = {0:{2053:[],2498:[],6343:[], 2196:[],},1:{2053:[],2498:[],6343:[], 2196:[],},2:{2053:[],2498:[],6343:[], 2196:[],},}

# for label_key in [0,1,2]:

#     X_array = X_array_a1[label_key]
#     X_array_corrupt = X_array_a1[label_key]
#     num_examples = X_array.shape[0]
#     print(num_examples)
#     seq_length = X_array.shape[1]

#     ### Clean Data ###
#     for example_idx in range(num_examples):
#         for pos_idx in range(seq_length):
#             if X_array[example_idx][pos_idx][2] == 1:
#                 if X_array[example_idx][pos_idx][0] in list(shortcuts_to_counterfactuals_tokens.keys()):
#                     X_a1[label_key][X_array[example_idx][pos_idx][0].item()].append(X_array[example_idx])
#                     # some of the sentences contain more than one shortcut token
#     for key in X_a1[label_key].keys():
#         X_a1[label_key][key] = torch.stack(X_a1[label_key][key])
#         # print('clean:', X_a1[label_key][key].shape)

#     ### Corrupt Data ###
#     for example_idx in range(num_examples):
#         for pos_idx in range(seq_length):
#             if X_array[example_idx][pos_idx][2] == 1:
#                 if X_array[example_idx][pos_idx][0] in list(shortcuts_to_counterfactuals_tokens.keys()):
#                     shortcut_token = X_array[example_idx][pos_idx][0].item()
#                     counterfactual_token = shortcuts_to_counterfactuals_tokens[shortcut_token]
#                     X_array_corrupt[example_idx][pos_idx][0] = counterfactual_token
#                     X_a1_corrupt[label_key][shortcut_token].append(X_array_corrupt[example_idx])
#     for key in X_a1_corrupt[label_key].keys():
#         X_a1_corrupt[label_key][key] = torch.stack(X_a1_corrupt[label_key][key])
#         # print('corrput:', X_a1_corrupt[label_key][key].shape)
# # 78 + 15 + 2 + 35, 84 + 8 + 5 + 35, 71 + 8 + 7 + 44, 525 + 85 +21+185

In [None]:
# NOTE(review): dead code; it duplicates show_sents defined earlier and refers to
# X_array / X_array_corrupt from the commented-out breakdown cell.
# ### Sanity Check ###
# def show_sents (X_array, n_examples):
#     X_array_len = list(X_array[:n_examples, :, 1].sum(dim=1))
#     sent_tokens = tokenizer.batch_decode(X_array[:n_examples, :, 0])
#     sents = [tokenizer.decode(sent[:X_array_len[i]]) for i, sent in enumerate(X_array[:n_examples, :, 0])]
#     for idx, sent in enumerate(sents):
#         print(idx, sent)
# print('Clean:')
# show_sents (X_array, 2)
# print('Corrupted:')
# show_sents (X_array_corrupt, 2)

# Model

In [None]:
# HookedEncoder config: 12 layers, d_model 768, 12 heads and a 3-way output
# head (d_vocab_out=3) on the 30522-token vocabulary.
# NOTE(review): model_name says 'bert-base-cased' while tokenizer_name is
# uncased; presumably only a label, since the weights come from the ERM
# checkpoint - confirm.
config = HookedTransformerConfig(
    act_fn='gelu',
    attention_dir='bidirectional',
    attn_only=False,
    attn_types=None,
    checkpoint_index=None,
    checkpoint_label_type=None,
    checkpoint_value=None,
    d_head=64,
    d_mlp=3072,
    d_model=768,
    d_vocab=30522,
    d_vocab_out=3,
    default_prepend_bos=True,
    device='cuda',
    dtype=torch.float32,
    eps=1e-12,
    final_rms=False,
    from_checkpoint=False,
    gated_mlp=False,
    init_mode='gpt2',
    init_weights=True,
    initializer_range=0.02886751345948129,
    model_name='bert-base-cased',
    n_ctx=512,
    n_devices=1,
    n_heads=12,
    n_key_value_heads=None,
    n_layers=12,
    # n_params=84934656,
    normalization_type='LN',
    original_architecture='BertForMaskedLM',
    parallel_attn_mlp=False,
    positional_embedding_type='standard',
    post_embedding_ln=False,
    rotary_adjacent_pairs=False,
    rotary_base=10000,
    rotary_dim=None,
    scale_attn_by_inverse_layer_idx=False,
    seed=None,
    tokenizer_name='bert-base-uncased',
    tokenizer_prepends_bos=None,
    trust_remote_code=False,
    use_attn_in=False,
    use_attn_result=False,
    use_attn_scale=True,
    use_hook_mlp_in=False,
    use_hook_tokens=False,
    use_local_attn=False,
    use_split_qkv_input=False,
    window_size=False,
)

# Load HookedEncoder
bert = HookedEncoder(config)
sd_HookedEncoder = bert.state_dict()
# Load ERM Model
algorithm_name = 'ERM'

random_seed = 2   # training seed of the checkpoint; the tables below are saved per seed
psuedo_epoch = 30
state_dict_path = f'drive/MyDrive/Colab Notebooks/BatuEl_Dissertation/models/models_mnli/01_erm/seed{random_seed}/sd_epoch{psuedo_epoch}.pth'
sd_ERM = torch.load(state_dict_path)

### Copy the ERM weights into the HookedEncoder layout
sd_ERM_HookedEncoder = ERM_to_HookedEncoder(sd_ERM, sd_HookedEncoder)
bert.load_state_dict(sd_ERM_HookedEncoder)
bert.eval()
bert.cuda()
### Reference ERM model, used for the sanity check below
bert_ERM = ERM(num_classes=3, num_attributes=2, hparams=hparams_f('ERM'))
bert_ERM.eval()
bert_ERM.load_state_dict(sd_ERM)
bert_ERM.cuda()
print("------------")
print("Sanity Check")
print("------------")
x_array = X_array_a1_va[2][:5, :, :]
# Inference only: no_grad stops these forwards from building autograd graphs.
with torch.no_grad():
    print('ERM Algorithm:\n', bert_ERM.predict(x_array))
bert_ERM.cpu()  # release GPU memory held by the reference model
###
input_ids = x_array[:, : , 0]
one_zero_attention_mask = x_array[:, : , 1]
token_type_ids = x_array[:, : , 2]
with torch.no_grad():
    print('HookedEncoder:\n', bert(input_ids, one_zero_attention_mask=one_zero_attention_mask, token_type_ids=token_type_ids)[:,0,:])

# Preliminary Tables

## Table 1

In [None]:
import torch

def run_batches(x_array, batchsize, device='cuda', model=None):
    """Return the position-0 ([CLS]) logits of every example, computed in mini-batches.

    Args:
        x_array: LongTensor (num_examples, seq_len, 3) with channels token ids,
            0/1 attention mask and token type ids.
        batchsize: number of examples per forward pass.
        device: device each batch is moved to before the forward pass.
        model: encoder to run; defaults to the module-level ``bert``.

    Returns:
        CPU tensor (num_examples, n_out) of logits at position 0.

    Raises:
        ValueError: if ``batchsize`` is not positive or ``x_array`` is empty.
    """
    if model is None:
        model = bert
    if batchsize <= 0:
        raise ValueError(f'batchsize must be positive, got {batchsize}')
    if x_array.shape[0] == 0:
        raise ValueError('x_array contains no examples')

    all_logits = []
    with torch.no_grad():
        # A stepped range covers a final partial batch without manual ceil logic;
        # moving one batch at a time keeps only that batch on the device.
        for start_idx in range(0, x_array.shape[0], batchsize):
            batch = x_array[start_idx:start_idx + batchsize].to(device)
            logits = model(batch[:, :, 0],
                           one_zero_attention_mask=batch[:, :, 1],
                           token_type_ids=batch[:, :, 2])[:, 0, :]
            all_logits.append(logits.cpu())
    # No per-batch torch.cuda.empty_cache(): it only hands cached blocks back to
    # the driver and forces a sync every iteration, slowing the loop down.
    return torch.cat(all_logits)

In [None]:
# Share of each predicted label for every (attribute group, true label) cell,
# on the validation and the test samples.
res_va = {'a1': {}, 'a0': {}}
for label_idx in [0, 1, 2]:
    for group, X_by_label in (('a0', X_array_a0_va), ('a1', X_array_a1_va)):
        logits = run_batches(X_by_label[label_idx], 32)
        pred = logits.argmax(dim=1)
        res_va[group][f'true{label_idx}'] = {
            f'pred{k}': round(((pred == k).sum() / pred.shape[0]).item(), 4)
            for k in (0, 1, 2)
        }

res_te = {'a1': {}, 'a0': {}}
for label_idx in [0, 1, 2]:
    for group, X_by_label in (('a0', X_array_a0_te), ('a1', X_array_a1_te)):
        logits = run_batches(X_by_label[label_idx], 32)
        pred = logits.argmax(dim=1)
        res_te[group][f'true{label_idx}'] = {
            f'pred{k}': round(((pred == k).sum() / pred.shape[0]).item(), 4)
            for k in (0, 1, 2)
        }

In [None]:
# Validation Table 1 is stored as percentages, with the a0 group first.
df = pd.concat({key: pd.DataFrame(value) for key, value in res_va.items()}, axis=1)[['a0', 'a1']] * 100
df.to_csv(f'{path_to_tez}/results/MNLIexp/va_table1_seed{random_seed}.csv')

In [None]:
df = pd.read_csv(f'{path_to_tez}/results/MNLIexp/va_table1_seed{random_seed}.csv', index_col=0, header=[0, 1])
def style_diagonal(data):
    """CSS styles that highlight the correct-prediction cells of every group.

    ``data`` has one row per predicted label and, for each attribute group,
    one column per true label, with the groups placed side by side. The
    correct cells are therefore the diagonal of each square sub-block. Works
    for any number of groups; the old hard-coded offsets covered exactly two.
    """
    styled_df = pd.DataFrame('', index=data.index, columns=data.columns)
    n_labels, n_cols = data.shape
    if n_labels == 0:
        return styled_df
    for group_start in range(0, n_cols, n_labels):
        for i in range(min(n_labels, n_cols - group_start)):
            styled_df.iat[i, group_start + i] = 'background-color: lightgreen'
    return styled_df
print('Validation Dataset:')
styled_df = df.style.apply(style_diagonal, axis=None)
# Percentages with one decimal; the Styler is the cell's displayed output.
styled_df.format(formatter="{:.1f}")

In [None]:
# NOTE(review): unlike the validation table, test fractions are saved without
# the *100 scaling, so the two tables use different units.
te_table1_csv = f'{path_to_tez}/results/MNLIexp/te_table1_seed{random_seed}.csv'
pd.concat({key: pd.DataFrame(value) for key, value in res_te.items()}, axis=1).to_csv(te_table1_csv)
df = pd.read_csv(te_table1_csv, header=[0, 1], index_col=0)[['a0', 'a1']]
def style_diagonal(data):
    """CSS styles that highlight the correct-prediction cells of every group.

    ``data`` has one row per predicted label and, for each attribute group,
    one column per true label, with the groups placed side by side. The
    correct cells are therefore the diagonal of each square sub-block. Works
    for any number of groups; the old hard-coded offsets covered exactly two.
    """
    styled_df = pd.DataFrame('', index=data.index, columns=data.columns)
    n_labels, n_cols = data.shape
    if n_labels == 0:
        return styled_df
    for group_start in range(0, n_cols, n_labels):
        for i in range(min(n_labels, n_cols - group_start)):
            styled_df.iat[i, group_start + i] = 'background-color: lightgreen'
    return styled_df

styled_df = df.style.apply(style_diagonal, axis=None)
print('Test Dataset:')
# Fractions with two decimals; the Styler is the cell's displayed output.
styled_df.format(formatter="{:.2f}")

## Table 2

In [None]:
# Validation prediction mix for clean a0/a1 inputs plus the shortcut-corrupted
# and control versions of the a1 inputs.
table2_inputs = {
    'a0': X_array_a0_va,
    'a1': X_array_a1_va,
    'a1_corrupt': X_array_a1_va_corrupt,
    'a1_control': X_array_a1_va_control,
}
res_va = {name: {} for name in table2_inputs}
for label_idx in [0, 1, 2]:
    for name, X_by_label in table2_inputs.items():
        logits = run_batches(X_by_label[label_idx], 32)
        pred = logits.argmax(dim=1)
        res_va[name][f'true{label_idx}'] = {
            f'pred{k}': round(((pred == k).sum() / pred.shape[0]).item(), 4)
            for k in (0, 1, 2)
        }

In [None]:
va_table2_csv = f'{path_to_tez}/results/MNLIexp/va_table2_seed{random_seed}.csv'
pd.concat({key: pd.DataFrame(value) for key, value in res_va.items()}, axis=1).to_csv(va_table2_csv)
df = pd.read_csv(va_table2_csv, header=[0, 1], index_col=0)
def style_diagonal(data):
    """CSS styles that highlight the correct-prediction cells of every group.

    ``data`` has one row per predicted label and, for each attribute group,
    one column per true label, with the groups placed side by side. The
    correct cells are therefore the diagonal of each square sub-block. Works
    for any number of groups; the old hard-coded offsets covered exactly four.
    """
    styled_df = pd.DataFrame('', index=data.index, columns=data.columns)
    n_labels, n_cols = data.shape
    if n_labels == 0:
        return styled_df
    for group_start in range(0, n_cols, n_labels):
        for i in range(min(n_labels, n_cols - group_start)):
            styled_df.iat[i, group_start + i] = 'background-color: lightgreen'
    return styled_df

styled_df = df.style.apply(style_diagonal, axis=None)
# Two-decimal fractions; the Styler is the cell's displayed output.
styled_df.format(formatter="{:.2f}")

## Table 3: Logit Difference

In [None]:
def logits_to_ave_logit_diff(logits, answer_label):
    """Batch-mean margin of the true label over the strongest wrong label.

    ``logits`` is indexed as (batch, pos, 3) and only position 0 is used. For
    each example the margin is logit[answer] - max(logit of the two other
    labels). Returns a 0-d tensor.
    """
    cls_logits = logits[:, 0, :]
    first_wrong, second_wrong = ((answer_label + offset) % 3 for offset in (1, 2))
    strongest_wrong = torch.maximum(cls_logits[:, first_wrong], cls_logits[:, second_wrong])
    return (cls_logits[:, answer_label] - strongest_wrong).mean()

In [None]:
def logit_diff(label_idx, X_array, X_array_corrupt, X_array_control, num_forward, batch_size, model=None):
    """Average logit margin of ``label_idx`` on clean, corrupted and control inputs.

    The three arrays differ only in their token ids: all use the attention
    mask and token type ids of ``X_array``. The first ``num_forward *
    batch_size`` examples are scored in batches of ``batch_size``.

    Returns:
        dict with the clean / corrupted / control averages rounded to 3 d.p.
        Batches are weighted by their size, so a short final batch is counted
        correctly, and batches past the end of the data are skipped instead of
        producing NaN.

    Raises:
        ValueError: if the requested batches contain no examples.
    """
    if model is None:
        model = bert
    totals = {'clean': 0.0, 'corrupted': 0.0, 'control': 0.0}
    n_seen = 0

    for i in tqdm.tqdm(range(num_forward)):
        rows = slice(batch_size * i, batch_size * (i + 1))
        one_zero_attention_mask = X_array[rows, :, 1]
        n_batch = one_zero_attention_mask.shape[0]
        if n_batch == 0:
            break  # all remaining batches would be empty too
        token_type_ids = X_array[rows, :, 2]
        with torch.no_grad():
            for key, source in (('clean', X_array), ('corrupted', X_array_corrupt), ('control', X_array_control)):
                batch_logits = model(source[rows, :, 0], one_zero_attention_mask=one_zero_attention_mask,
                                     token_type_ids=token_type_ids, return_type="logits")
                # logits_to_ave_logit_diff returns a batch mean; weight it by batch size.
                totals[key] += logits_to_ave_logit_diff(batch_logits, answer_label=label_idx).item() * n_batch
        n_seen += n_batch

    if n_seen == 0:
        raise ValueError('logit_diff: the requested batches contain no examples')
    return {'Clean Average Logit Diff': round(totals['clean'] / n_seen, 3),
            'Corrupted Average Logit Diff': round(totals['corrupted'] / n_seen, 3),
            'Control Average Logit Diff': round(totals['control'] / n_seen, 3),
            }

# Batch counts are derived from each array's size instead of the hard-coded
# 6 / 8 (= 600 / 800 examples), so a smaller group is not scored on empty batches.
LOGIT_DIFF_BATCH_SIZE = 100
va_logit_diff = {}
te_logit_diff = {}
for label_idx in [0,1,2]:
    for results, clean, corrupt, control in (
        (va_logit_diff, X_array_a1_va, X_array_a1_va_corrupt, X_array_a1_va_control),
        (te_logit_diff, X_array_a1_te, X_array_a1_te_corrupt, X_array_a1_te_control),
    ):
        n_examples = clean[label_idx].shape[0]
        n_batches = (n_examples + LOGIT_DIFF_BATCH_SIZE - 1) // LOGIT_DIFF_BATCH_SIZE
        results[label_idx] = logit_diff(label_idx, clean[label_idx], corrupt[label_idx], control[label_idx],
                                        n_batches, LOGIT_DIFF_BATCH_SIZE)

In [None]:
table3_csv = f'{path_to_tez}/results/MNLIexp/logitdiff_table3_seed{random_seed}.csv'
pd.concat({'Validation': pd.DataFrame(va_logit_diff), 'Test': pd.DataFrame(te_logit_diff)}, axis=1).to_csv(table3_csv)
df = pd.read_csv(table3_csv, header=[0, 1], index_col=0)

In [None]:
# Display which checkpoint seed the tables above were saved for.
random_seed

# Preliminary Tables (3 Seeds - Mean/Std.)

In [None]:
def table1_va(SEED):
    """Validation Table 1 saved for model seed ``SEED``."""
    return pd.read_csv(path_to_tez + f'/results/MNLIexp/va_table1_seed{SEED}.csv', header=[0, 1], index_col=0)


def table1_te(SEED):
    """Test Table 1 saved for model seed ``SEED``."""
    return pd.read_csv(path_to_tez + f'/results/MNLIexp/te_table1_seed{SEED}.csv', header=[0, 1], index_col=0)


def table2_va(SEED):
    """Validation Table 2 saved for model seed ``SEED``."""
    return pd.read_csv(path_to_tez + f'/results/MNLIexp/va_table2_seed{SEED}.csv', header=[0, 1], index_col=0)


def table3(SEED):
    """Logit-difference Table 3 saved for model seed ``SEED``."""
    return pd.read_csv(path_to_tez + f'/results/MNLIexp/logitdiff_table3_seed{SEED}.csv', header=[0, 1], index_col=0)

In [None]:
def _mean_std_over_seeds(stem, seeds=(0, 1, 2)):
    """Mean and sample std (ddof=1) over seeds of a saved ``{stem}_seed{N}.csv`` table."""
    frames = [pd.read_csv(path_to_tez + f'/results/MNLIexp/{stem}_seed{seed}.csv', header=[0, 1], index_col=0)
              for seed in seeds]
    mean = sum(frames) / len(frames)
    stacked = np.stack([frame.reindex(index=mean.index, columns=mean.columns).to_numpy() for frame in frames])
    std = pd.DataFrame(stacked.std(axis=0, ddof=1), index=mean.index, columns=mean.columns)
    return mean, std


# Files are read directly rather than through the table1_va/... loaders, which
# this cell used to overwrite with DataFrames (re-running it then failed).
table1_va, table1_va_std = _mean_std_over_seeds('va_table1')
table1_te, table1_te_std = _mean_std_over_seeds('te_table1')
table2_va, table2_va_std = _mean_std_over_seeds('va_table2')
table3, table3_std = _mean_std_over_seeds('logitdiff_table3')

In [None]:
# (table2_va[['a0','a1','a1_corrupt']] *100).round(1)
table3.round(decimals=2)

# Activation Patching Figures

## Figure 1: Patching the Residual Stream

In [None]:
# `tokens` is only defined by the single-example patching setup cell further
# down, so on a fresh kernel this cell would raise NameError.
if 'tokens' in globals():
    display(tokenizer.batch_decode(tokens))
else:
    print('Run the patching setup cell below first (it defines `tokens`).')

In [None]:
# [CLS] logits of the single patching example. Before the patching setup cell
# below has run, `logits` is missing or is the 2-D output of the table loops,
# and this indexing would fail.
if 'logits' in globals() and logits.ndim == 3:
    display(logits[:, 0, :])
else:
    print('Run the patching setup cell below first (it defines 3-D `logits`).')

In [None]:
# `corrupted_tokens` is only defined by the patching setup cell further down.
if 'corrupted_tokens' in globals():
    display(tokenizer.batch_decode(corrupted_tokens))
else:
    print('Run the patching setup cell below first (it defines `corrupted_tokens`).')

In [None]:
# `corrupted_logits` is only defined by the patching setup cell further down.
if 'corrupted_logits' in globals():
    display(corrupted_logits[:, 0, :])
else:
    print('Run the patching setup cell below first (it defines `corrupted_logits`).')

In [None]:
# Pick one a=1 validation example of label LABEL_IDX and cache the activations
# of its clean and shortcut-corrupted versions.
LABEL_IDX = 1
example_id = 8
num_examples = 1
X_array = X_array_a1_va[LABEL_IDX]
X_array_corrupt = X_array_a1_va_corrupt[LABEL_IDX]
rows = slice(example_id, example_id + num_examples)
tokens, corrupted_tokens = X_array[rows, :, 0], X_array_corrupt[rows, :, 0]
one_zero_attention_mask, token_type_ids = X_array[rows, :, 1], X_array[rows, :, 2]

shared_inputs = dict(one_zero_attention_mask=one_zero_attention_mask, token_type_ids=token_type_ids, return_type="logits")
logits, cache = bert.run_with_cache(tokens, **shared_inputs)
corrupted_logits, corrupted_cache = bert.run_with_cache(corrupted_tokens, **shared_inputs)

In [None]:
def logits_to_ave_logit_diff(logits, answer_label):
    """Batch mean of logit[answer] minus the SUM of the two other labels' logits.

    ``logits`` is indexed as (batch, pos, 3) and only position 0 is used.
    Returns a 0-d tensor.

    NOTE(review): this redefinition shadows the max-based version used for
    Table 3 and computes a different metric (it subtracts the sum of the wrong
    logits, not their max or mean). It is linear in the logits; confirm it is
    the intended patching metric.
    """
    cls_logits = logits[:, 0, :]
    wrong_total = sum(cls_logits[:, (answer_label + offset) % 3] for offset in (1, 2))
    return (cls_logits[:, answer_label] - wrong_total).mean()

# Clean and corrupted baselines used to normalise every patching result.
original_average_logit_diff = logits_to_ave_logit_diff(logits, answer_label=LABEL_IDX)
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_label=LABEL_IDX)
for caption, value in (("Clean Average Logit Diff", original_average_logit_diff),
                       ("Corrupted Average Logit Diff", corrupted_average_logit_diff)):
    print(caption, round(value.item(), 2))

In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
    ):
    """Forward hook: overwrite position ``pos`` with the activation from the clean run.

    The clean activation is looked up in ``clean_cache`` under the hook's
    name. The tensor is modified in place and returned.
    """
    clean_activation = clean_cache[hook.name]
    corrupted_residual_component[:, pos, :] = clean_activation[:, pos, :]
    return corrupted_residual_component

def normalize_patched_logit_diff(patched_logit_diff, clean_diff=None, corrupted_diff=None):
    """Rescale a patched logit difference so the corrupted run maps to 0 and the clean run to 1.

    0 means no change, a negative value means patching made things worse, 1
    means clean performance is fully recovered, and >1 means patching beat
    the clean run.

    Args:
        patched_logit_diff: logit difference measured on the patched run.
        clean_diff: clean baseline; defaults to the global
            ``original_average_logit_diff``.
        corrupted_diff: corrupted baseline; defaults to the global
            ``corrupted_average_logit_diff``.

    If the two baselines are equal the denominator is zero.
    """
    if clean_diff is None:
        clean_diff = original_average_logit_diff
    if corrupted_diff is None:
        corrupted_diff = corrupted_average_logit_diff
    return (patched_logit_diff - corrupted_diff) / (clean_diff - corrupted_diff)

# For every (layer, position): patch the clean resid_pre into the corrupted run
# and record the normalised logit difference that results.
n_positions = tokens.shape[1]
patched_residual_stream_diff = torch.zeros(
    bert.cfg.n_layers, n_positions, device='cuda', dtype=torch.float32
)

for layer in range(bert.cfg.n_layers):
    hook_name = utils.get_act_name("resid_pre", layer)
    for position in range(n_positions):
        with torch.no_grad():
            patched_logits = bert.run_with_hooks(
                corrupted_tokens,
                one_zero_attention_mask=one_zero_attention_mask, token_type_ids=token_type_ids,
                fwd_hooks=[(hook_name, partial(patch_residual_component, pos=position, clean_cache=cache))],
                return_type="logits",
            )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_label=LABEL_IDX)
        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
patched_residual_stream_diff[:, :1].shape

In [None]:
patched_residual_stream_diff[:, 1:].sum(dim=1, keepdim=True).shape

In [None]:
# Column 0: the [CLS] position; column 1: the effects of all other positions summed.
cls_effect = patched_residual_stream_diff[:, :1]
rest_effect = patched_residual_stream_diff[:, 1:].sum(dim=1, keepdim=True)
new = torch.cat([cls_effect, rest_effect], dim=1)

In [None]:
new.size()

In [None]:
tokenizer.batch_decode(sequences=corrupted_tokens)

In [None]:
crop = 44
crop_token_ids = tokens[0].detach().cpu().numpy()
# Label each column with (position, decoded token) for the first `crop` positions.
position_labels = [str((pos, tokenizer.decode(crop_token_ids[pos]))) for pos in range(crop)]
imshow(
    patched_residual_stream_diff[:, :crop],
    x=position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
    zmin=-1,
    zmax=1,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# `new` (layers x 2: [CLS], rest of context) as a numpy array for seaborn.
data = new.detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(data, cmap='coolwarm', vmin=-1, vmax=1, annot=True, fmt=".2f",
            cbar_kws={'label': 'Logit Difference'}, ax=ax)

# The first column is [CLS]; every remaining column is the summed rest of the context.
x_labels = ["CLS"] + ["Rest of Context"] * (data.shape[1] - 1)
ax.set_xticklabels(x_labels)
ax.set(title="Logit Difference From Patched Residual Stream", xlabel="Position", ylabel="Layer")

plt.show()

## Figure 2: Patching Sub-Layers

In [None]:
device = 'cuda'
model = bert

# Same sweep as for the residual stream, but patching the attention and MLP
# outputs of each layer separately.
n_positions = tokens.shape[1]
patched_attn_diff = torch.zeros(model.cfg.n_layers, n_positions, device=device, dtype=torch.float32)
patched_mlp_diff = torch.zeros(model.cfg.n_layers, n_positions, device=device, dtype=torch.float32)


def _patched_sublayer_logit_diff(act_name, layer, hook_fn):
    """Logit difference of the corrupted run with ``hook_fn`` applied at ``act_name`` of ``layer``."""
    with torch.no_grad():
        patched = model.run_with_hooks(
            corrupted_tokens,
            one_zero_attention_mask=one_zero_attention_mask, token_type_ids=token_type_ids,
            fwd_hooks=[(utils.get_act_name(act_name, layer), hook_fn)],
            return_type="logits",
        )
        return logits_to_ave_logit_diff(patched, answer_label=LABEL_IDX)


for layer in range(model.cfg.n_layers):
    for position in range(n_positions):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_logit_diff = _patched_sublayer_logit_diff("attn_out", layer, hook_fn)
        patched_mlp_logit_diff = _patched_sublayer_logit_diff("mlp_out", layer, hook_fn)
        patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff)
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff)

### Figure 2.1: Patching Attention

In [None]:
crop = 128
imshow(
    patched_attn_diff[:, :crop],
    x=list(range(crop)),
    title="Logit Difference From Patched Attention Layer",
    labels={"x": "Position", "y": "Layer"},
    zmin=-1,
    zmax=1,
)

In [None]:
new_attn = torch.cat([patched_attn_diff[:, :1], patched_attn_diff[:, 1:].sum(dim=1, keepdim=True)], dim=1)

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Assuming 'new' is a PyTorch tensor, convert it to numpy array
data = new_attn.detach().cpu().numpy()

# Set up the figure and axis
fig, ax = plt.subplots(figsize=(10, 8))

# Create a heatmap with Seaborn
sns.heatmap(data, cmap='coolwarm', vmin=-1, vmax=1, annot=True, fmt=".2f",
            cbar_kws={'label': 'Logit Difference'}, ax=ax)

# Set custom x-axis labels
x_labels = ["CLS"] + ["Rest of Context"] * (data.shape[1] - 1)
ax.set_xticklabels(x_labels)

# Set labels and title
ax.set_title("Logit Difference From Patched Attention Layer")
ax.set_xlabel("Position")
ax.set_ylabel("Layer")

# Show the plot
plt.show()

### Figure 2.2: Patching MLP

In [None]:
# Same crop as the attention plot (`crop` is defined in that cell). The labels
# are derived from the cropped tensor, so they match the number of plotted columns
# even when the sequence is shorter than `crop`.
mlp_to_plot = patched_mlp_diff[:, :crop]
imshow(
    mlp_to_plot,
    x=list(range(mlp_to_plot.shape[1])),
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
    zmin=-1,
    zmax=1,
)

In [None]:
# Collapse the (n_layers, seq_len) MLP-patching grid into two columns:
# position 0 ([CLS]) and the sum over every remaining position.
# Result: new_mlp has shape (n_layers, 2).
new_mlp = torch.cat(
    [patched_mlp_diff[:, 0:1], patched_mlp_diff[:, 1:].sum(dim=1, keepdim=True)],
    dim=1,
)

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# seaborn needs a CPU NumPy array rather than a (possibly CUDA) torch tensor.
data = new_mlp.detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 8))

# The color scale is pinned to [-1, 1]. The "Rest of Context" column is a SUM
# over positions, so it may fall outside that range. Its color then saturates,
# but the annotation still prints the true value.
sns.heatmap(data, cmap='coolwarm', vmin=-1, vmax=1, annot=True, fmt=".2f",
            cbar_kws={'label': 'Logit Difference'}, ax=ax)

# One label per column: [CLS] first, then the aggregated remaining positions.
x_labels = ["CLS"] + ["Rest of Context"] * (data.shape[1] - 1)
ax.set_xticklabels(x_labels)

ax.set_title("Logit Difference From Patched MLP Layer")
ax.set_xlabel("Position")
ax.set_ylabel("Layer")

plt.show()

## Figure 3: Patching Attention Heads

### Figure 3.1: Patching Head Outputs


In [None]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    """Forward hook that swaps one head's per-head vectors for the clean run's.

    Args:
        corrupted_head_vector: activation at the hooked point in the corrupted
            run, indexed as (batch, pos, head_index, d_head).
        hook: hook point; only ``hook.name`` is read, as the cache key.
        head_index: which head (third axis) to overwrite.
        clean_cache: mapping from hook name to the clean-run activation.

    Returns:
        The same tensor object, modified in place so that only ``head_index``
        carries the clean values.
    """
    clean_head_vector = clean_cache[hook.name]
    corrupted_head_vector[:, :, head_index, :] = clean_head_vector[:, :, head_index, :]
    return corrupted_head_vector

# Patch one head's clean output (z) into the corrupted run at a time and record
# the normalized logit difference in an (n_layers, n_heads) grid.
patched_head_z_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    # The hook point depends only on the layer, so resolve it once per layer.
    z_hook_name = utils.get_act_name("z", layer, "attn")
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        with torch.no_grad():
            patched_logits = model.run_with_hooks(
                corrupted_tokens,
                one_zero_attention_mask=one_zero_attention_mask,
                token_type_ids=token_type_ids,
                fwd_hooks=[(z_hook_name, hook_fn)],
                return_type="logits",
            )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_label=LABEL_IDX)
        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heads on the x-axis, layers on the y-axis. The color scale is pinned to [-1, 1]
# like the layer-level patching plots above.
imshow(
    patched_head_z_diff,
    zmin=-1,
    zmax=1,
    labels=dict(x="Head", y="Layer"),
    title="Logit Difference From Patched Head Output",
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# patched_head_z_diff is the (n_layers, n_heads) grid from the head-output
# patching loop. seaborn needs a CPU NumPy array.
data = patched_head_z_diff.detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 8))

# Annotated diverging heatmap with the color scale pinned to [-1, 1].
# Positional tick labels serve directly as head (x) and layer (y) indices,
# so no custom tick labels are needed.
sns.heatmap(data, cmap='coolwarm', vmin=-1, vmax=1, annot=True, fmt=".2f",
            cbar_kws={'label': 'Logit Difference'}, ax=ax)

ax.set_title("Logit Difference From Patched Attention Head")
ax.set_xlabel("Head")
ax.set_ylabel("Layer")

plt.show()

### Figure 3.2: Patching Head Values


In [None]:
# Value patching: for every (layer, head), overwrite that head's value vectors
# in the corrupted run with the clean ones. Record the normalized logit
# difference in an (n_layers, n_heads) grid.
#
# The forward pass runs under torch.no_grad(), like the other patching loops.
# Without it, when the model's parameters require grad, each forward builds an
# autograd graph. Writing that grad-tracking result into patched_head_v_diff
# would then chain and keep alive the graphs of every run in the loop.
patched_head_v_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        # "v" activations are indexed (batch, pos, head_index, d_head) like "z",
        # so the same per-head patch function applies.
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        with torch.no_grad():
            patched_logits = model.run_with_hooks(
                corrupted_tokens,
                one_zero_attention_mask=one_zero_attention_mask, token_type_ids=token_type_ids,
                fwd_hooks=[(utils.get_act_name("v", layer, "attn"), hook_fn)],
                return_type="logits",
            )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_label=LABEL_IDX)

        patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [None]:
# Value-patching grid (heads on x, layers on y). No zmin/zmax is passed, so the
# color range is chosen automatically, centered at 0 by the imshow helper.
imshow(
    patched_head_v_diff,
    labels=dict(x="Head", y="Layer"),
    title="Logit Difference From Patched Head Value",
)

### Figure 3.3: Patching Head Attention Patterns

In [None]:
def patch_head_pattern(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    hook,
    head_index,
    clean_cache,
):
    """Forward hook that replaces one head's attention pattern with the clean run's.

    Args:
        corrupted_head_pattern: attention pattern at the hooked point in the
            corrupted run, indexed as (batch, head_index, query_pos, key_pos).
            The head axis is the SECOND axis here, unlike the per-head vectors
            handled by patch_head_vector.
        hook: hook point; only ``hook.name`` is read, as the cache key.
        head_index: which head (second axis) to overwrite.
        clean_cache: mapping from hook name to the clean-run activation.

    Returns:
        The same tensor object, modified in place so that only ``head_index``
        carries the clean pattern.
    """
    corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][
        :, head_index, :, :
    ]
    return corrupted_head_pattern

# Patch one head's clean attention pattern into the corrupted run at a time and
# record the normalized logit difference in an (n_layers, n_heads) grid.
patched_head_attn_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    # The hook point depends only on the layer, so resolve it once per layer.
    pattern_hook_name = utils.get_act_name("attn", layer, "attn")
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)
        with torch.no_grad():
            patched_logits = model.run_with_hooks(
                corrupted_tokens,
                one_zero_attention_mask=one_zero_attention_mask,
                token_type_ids=token_type_ids,
                fwd_hooks=[(pattern_hook_name, hook_fn)],
                return_type="logits",
            )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_label=LABEL_IDX)
        patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Pattern-patching grid: heads on the x-axis, layers on the y-axis.
imshow(
    patched_head_attn_diff,
    labels=dict(x="Head", y="Layer"),
    title="Logit Difference From Patched Head Pattern",
)

# "L{layer}H{head}" label for every head, in row-major (layer, head) order.
head_labels = [
    f"L{layer_idx}H{head_idx}"
    for layer_idx in range(model.cfg.n_layers)
    for head_idx in range(model.cfg.n_heads)
]

### Figure 3.4: Output vs. Value Patching

In [None]:
# One hover label per head, in the same row-major (layer, head) order that
# .flatten() produces on the (n_layers, n_heads) patching grids.
head_labels = [
    f"L{layer_idx}H{head_idx}"
    for layer_idx in range(model.cfg.n_layers)
    for head_idx in range(model.cfg.n_heads)
]
# Layer index of every flattened head, used to color the points.
layer_of_head = einops.repeat(
    np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads
)
value_effect = utils.to_numpy(patched_head_v_diff.flatten())
output_effect = utils.to_numpy(patched_head_z_diff.flatten())
# NOTE: the fixed axis ranges hide any head whose effect lies outside
# [-0.05, 0.2] on either axis.
scatter(
    x=value_effect,
    y=output_effect,
    xaxis="Value Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_labels,
    color=layer_of_head,
    range_x=(-0.05, 0.2),
    range_y=(-0.05, 0.2),
    title="Scatter plot of output patching vs value patching",
)

### Figure 3.5: Output vs. Attention Patching


In [None]:
# One hover label per head, in the same row-major (layer, head) order that
# .flatten() produces on the (n_layers, n_heads) patching grids.
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]
# Color points by layer, as in the output-vs-value scatter, so the two plots
# can be read side by side.
scatter(
    x=utils.to_numpy(patched_head_attn_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    hover_name=head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    caxis="Layer",
    color=einops.repeat(
        np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads
    ),
    title="Scatter plot of output patching vs attention patching",
)