# Install

In [None]:
# Install the extra packages with pip through the notebook's `!` shell escape.
!pip install einops datasets jaxtyping better_abc fancy_einsum wandb netcal

# Setup

In [None]:
import sys
from google.colab import drive
# Project root inside the mounted Google Drive; it is added to sys.path below
# so that packages stored under it become importable.
path_to_root = '/content/drive/My Drive/Colab Notebooks/BatuEl_Dissertation'

# Mount Drive at /content/drive (force_remount=True is passed on every run).
drive.mount('/content/drive', force_remount=True)
sys.path.append(path_to_root)
print("Drive mounted.")

# Dataset directory under the project root.
data_path = f'{path_to_root}/data'

In [None]:
import torch
import tqdm
from reprshift.learning.algorithms import ERM
from reprshift.models.hparams import hparams_f
from reprshift.dataset.datasets import MultiNLI, CivilComments
from reprshift.dataset.dataloaders import InfiniteDataLoader, FastDataLoader

from reprshift.models.model_param_maps import ERM_to_HookedEncoder, load_focal, load_groupdro, load_jtt, load_lff
from reprshift.models.HookedEncoderConfig import bert_config

from transformer_lens2 import HookedEncoder, HookedTransformerConfig

# Dataset

In [None]:
# Benchmark to analyse; must match one of the branches handled below.
DATASET = 'MultiNLI'  # 'CivilComments' , 'MultiNLI'

# Each branch sets the class/attribute counts, builds `val_dataset` from the
# 'va' split with the ERM hyperparameters, and points `models_path` and
# `representations_path` at that dataset's directories under the project root.
if DATASET == 'MultiNLI':
    NUM_CLASSES = 3
    NUM_ATTRIBUTES = 2
    # train_dataset = MultiNLI(data_path, 'tr', hparams)
    val_dataset = MultiNLI(data_path, 'va', hparams=hparams_f('ERM'))
    # te_dataset = MultiNLI(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_mnli'
    representations_path = path_to_root + '/representations/representations_mnli'
    print(DATASET)
elif DATASET == 'CivilComments':
    NUM_CLASSES = 2
    NUM_ATTRIBUTES = 8
    # train_dataset = CivilComments(data_path, 'tr', hparams, granularity="fine")
    val_dataset = CivilComments(data_path, 'va', hparams=hparams_f('ERM'))
    # te_dataset = CivilComments(data_path, 'te', hparams=hparams_f('ERM'))
    models_path = path_to_root + '/models/models_civilcomments'
    representations_path = path_to_root + '/representations/representations_civilcomments'
    print(DATASET)
else:
    # Fail fast: merely printing here would leave NUM_CLASSES, val_dataset and
    # models_path undefined (or stale from an earlier kernel run), so later
    # cells would crash with a NameError or silently use the wrong dataset.
    raise NotImplementedError(f'Dataset Not Implemented: {DATASET!r}')

# Load Model

In [None]:
### MODELS ###
# Registry of checkpoints keyed by algorithm name. Each entry holds:
#   'path'   - directory for that algorithm under `models_path`,
#   'load_f' - callable applied to the object returned by torch.load,
#   'epoch'  - epoch number used when building the checkpoint file name.


def _identity(loaded):
    """Return the loaded checkpoint unchanged."""
    return loaded


MODELS = {
    'pretrained': {
        'path': f'{models_path}/00_randominit/',
        'load_f': _identity,
        'epoch': 0,
    },
    'erm': {
        'path': f'{models_path}/01_erm/',
        'load_f': _identity,
        'epoch': 30,
    },
    'groupdro': {
        'path': f'{models_path}/03_groupdro/',
        'load_f': load_groupdro,
        'epoch': 30,
    },
    'jtt': {
        'path': f'{models_path}/06_jtt/',
        'load_f': load_jtt,
        'epoch': 30,
    },
    'lff': {
        'path': f'{models_path}/07_lff/',
        'load_f': load_lff,
        'epoch': 30,
    },
    # NOTE(review): `load_focal` is imported at the top of the notebook but the
    # identity loader is used here - confirm that this is intended.
    'focal': {
        'path': f'{models_path}/15_focal/',
        'load_f': _identity,
        'epoch': 30,
    },
}

In [None]:
### Load Statedict ###
# Select the registry entry for the algorithm whose checkpoint is loaded.
algorithm_name = 'erm'
state_dict_PATH = MODELS[algorithm_name]['path']
load_f = MODELS[algorithm_name]['load_f']

seed = 0
# Read the epoch from the MODELS registry instead of hard-coding 30, so that
# switching `algorithm_name` (e.g. to 'pretrained', epoch 0) picks the file
# that entry actually refers to.
epoch = MODELS[algorithm_name]['epoch']
algorithm_state_dict_PATH = state_dict_PATH + f'seed{seed}/sd_epoch{epoch}.pth'
# `load_f` post-processes the object returned by torch.load for this algorithm.
sd = load_f(torch.load(algorithm_state_dict_PATH))

### Load ERM Model ###
# NUM_CLASSES replaces the literal 3 so the encoder follows the selected
# dataset. Assumes bert_config's argument is the number of output classes
# (3 for MultiNLI) - TODO confirm against bert_config.
bert = HookedEncoder(HookedTransformerConfig(**bert_config(NUM_CLASSES)))
# The model's own state dict is passed to ERM_to_HookedEncoder as the target.
bert.load_state_dict(ERM_to_HookedEncoder(sd, bert.state_dict()))

In [None]:
# Load the seed-1 checkpoint of the same algorithm into a second encoder.
# Relies on `algorithm_name`, `state_dict_PATH` and `load_f` set in the
# previous cell.
seed = 1
# Epoch comes from the MODELS registry rather than a hard-coded 30.
epoch = MODELS[algorithm_name]['epoch']
algorithm_state_dict_PATH = state_dict_PATH + f'seed{seed}/sd_epoch{epoch}.pth'
sd1 = load_f(torch.load(algorithm_state_dict_PATH))

### Load ERM Model ###
# Assumes bert_config's argument is the number of output classes (3 for
# MultiNLI) - TODO confirm against bert_config.
bert1 = HookedEncoder(HookedTransformerConfig(**bert_config(NUM_CLASSES)))
# Use bert1's own state dict as the target. The original passed
# bert.state_dict() here, which could let seed-0 tensors leak into bert1
# through any key the mapping does not overwrite.
bert1.load_state_dict(ERM_to_HookedEncoder(sd1, bert1.state_dict()))

In [None]:
# Show the W_U attribute of both loaded encoders together as the cell output.
bert.W_U, bert1.W_U

In [None]:
### Load Statedict ###
# Select the registry entry; path, loader and epoch all come from MODELS.
algorithm_name = 'erm'
state_dict_PATH = MODELS[algorithm_name]['path']
load_f = MODELS[algorithm_name]['load_f']
epoch = MODELS[algorithm_name]['epoch']
seed = 0
algorithm_state_dict_PATH = state_dict_PATH + f'seed{seed}/sd_epoch{epoch}.pth'
# `load_f` post-processes the object returned by torch.load for this algorithm.
sd = load_f(torch.load(algorithm_state_dict_PATH))

### Initialize ERM Model ###
# Build the ERM algorithm for the selected dataset and restore its weights.
hparams = hparams_f('ERM')
algorithm = ERM(num_classes=NUM_CLASSES, num_attributes=NUM_ATTRIBUTES, hparams=hparams)
algorithm.load_state_dict(sd)

### Load ERM Model ###
# NUM_CLASSES replaces the literal 3 so the hooked encoder agrees with the ERM
# model built above. Assumes bert_config's argument is the number of output
# classes (3 for MultiNLI) - TODO confirm against bert_config.
bert = HookedEncoder(HookedTransformerConfig(**bert_config(NUM_CLASSES)))
# The model's own state dict is passed to ERM_to_HookedEncoder as the target.
bert.load_state_dict(ERM_to_HookedEncoder(sd, bert.state_dict()))