# Imports

In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import clear_output
import torch
from torch.utils.data import Dataset

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair
from iit.utils.iit_dataset import train_test_split
from iit.utils.iit_dataset import IITDataset

import circuits_benchmark.benchmark.cases.case_3 as case3
import circuits_benchmark.benchmark.cases.case_4 as case4
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.utils.iit.iit_hl_model import IITHLModel
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
from circuits_benchmark.benchmark.vocabs import TRACR_BOS, TRACR_PAD

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load cases from huggingface

In [2]:
#load cases
cases = [case3.Case3(), case4.Case4()]
corrs = []
ll_models = []
hl_models = []
model_pairs = []
for case in cases:
    ll_model_loader = get_ll_model_loader(case, interp_bench=True)
    corr, ll_model = ll_model_loader.load_ll_model_and_correspondence(device=device)
    hl_model = case.get_hl_model()

    if isinstance(hl_model, HookedTracrTransformer):
        hl_model = IITHLModel(hl_model, eval_mode=True)

    model_pair = case.build_model_pair(ll_model=ll_model, hl_model=hl_model)

    corrs.append(corr)
    ll_models.append(ll_model)
    hl_models.append(hl_model)
    model_pairs.append(model_pair)




Moving model to device:  cpu
Moving model to device:  cpu


# Generate mixed dataset

In [3]:
#find overall min and max sequence length (we'll have to pad some)
min_seq_len = 100000
max_seq_len = 0

for case in cases:
    if case.get_min_seq_len() < min_seq_len:
        min_seq_len = case.get_min_seq_len()
    if case.get_max_seq_len() > max_seq_len:
        max_seq_len = case.get_max_seq_len()
print(min_seq_len, max_seq_len)

4 10


In [4]:
# find max vocab size (some tasks will just use a subset of vocab size)
vocab_size = 0
for case in cases:
    if len(case.get_vocab()) > vocab_size:
        vocab_size = len(case.get_vocab())
print(vocab_size)


5


In [5]:
max_case_samples = 10_000
case_samples = []
for case in cases:
    num_samples = min(max_case_samples, case.get_total_data_len())
    case_samples.append(num_samples)
print(case_samples)

[320, 10000]


In [6]:
# generate clean and corrupted datasets
clean_datasets = []
corrupted_datasets = []
masks = []
for task_id, case, samples, hl_model in zip(range(len(cases)), cases, case_samples, hl_models):
    dataset = case.get_clean_data(max_samples=samples)
    encoder = hl_model.tracr_input_encoder    
    # print(encoder.encoding_map)
    def encode(tok):
        return encoder.encoding_map[tok]


    #Input
    #put task_id after BOS token and pads after EOS.
    inputs = dataset.inputs
    str_tokens = [encoder.decode(inputs[i].tolist()) for i in range(inputs.shape[0])]
    str_task_id = encoder.decode([task_id])
    pads = [TRACR_PAD] * (max_seq_len - case.get_max_seq_len())
    str_tokens = [str_task_id + [TRACR_BOS] + tokens[1:] + pads for tokens in str_tokens]
    inputs = torch.tensor([list(map(encode, tokens)) for tokens in str_tokens])

    #Target
    #add 0 to beginning of seq and a bunch of 0s to end.
    target = dataset.targets
    label = torch.zeros((target.shape[0], 1, target.shape[2]), dtype=target.dtype)
    pads = torch.zeros((target.shape[0], max_seq_len - case.get_max_seq_len(), target.shape[2]), dtype=target.dtype)
    target = torch.cat((label, target, pads), dim=1)
    dataset.inputs = inputs
    dataset.targets = target
    # print(clean_dataset.inputs.shape, dataset.inputs.shape)
    # print(clean_dataset.targets.shape, dataset.targets.shape)
    clean_datasets.append(dataset)

    # if case.get_max_seq_len() < max_seq_len:
    #     # pad sequences

clear_output()

In [7]:
#make all datasets ~the same length by duplicating shorter datasets
max_length = max([clean_dataset.inputs.shape[0] for clean_dataset in clean_datasets])

for dset_list in [clean_datasets,]:
    for dataset in dset_list:
        if dataset.inputs.shape[0] < max_length:
            num_dups = max_length // dataset.inputs.shape[0]
            dataset.inputs = dataset.inputs.repeat(num_dups, 1)
            dataset.targets = dataset.targets.repeat(num_dups, 1, 1)
print(clean_datasets[0].inputs.shape, clean_datasets[1].inputs.shape)
# print(corrupted_datasets[0].inputs.shape)#, corrupted_datasets[1].inputs.shape)

torch.Size([9920, 11]) torch.Size([10000, 11])


In [8]:
#smush datasets together into one big dataset, then shuffle
datasets = []
for dset_list in [clean_datasets,]:
    print(dset_list[0].inputs)
    inputs = torch.cat([dset.inputs for dset in dset_list], dim=0)
    targets = torch.cat([dset.targets for dset in dset_list], dim=0)

    # shuffle dataset contents, keeping inputs and targets in sync
    indices = torch.randperm(inputs.shape[0])
    inputs = inputs[indices]
    targets = targets[indices]

    class CustomDataset(Dataset):
        def __init__(self, data, targets):
            """
            Args:
                data (list or numpy array): List or array of input data.
                targets (list or numpy array): List or array of target data.
            """
            self.data = torch.tensor(data).to(int)
            self.targets = torch.tensor(targets)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            """
            Args:
                idx (int): Index
            Returns:
                tuple: (input tensor, target tensor)
            """
            return self.data[idx], self.targets[idx]

    decorated_dset = CustomDataset(
        data = inputs,
        targets = targets,
    )
    print(decorated_dset.data.shape)
    print(decorated_dset.targets.shape)

    train_dataset, test_dataset = train_test_split(
        decorated_dset, test_size=0.2, random_state=42
    )
    train_set = IITDataset(train_dataset, train_dataset, seed=0)
    test_set = IITDataset(test_dataset, test_dataset, seed=0)


tensor([[0, 0, 2,  ..., 1, 1, 1],
        [0, 0, 4,  ..., 1, 1, 1],
        [0, 0, 5,  ..., 1, 1, 1],
        ...,
        [0, 0, 5,  ..., 1, 1, 1],
        [0, 0, 2,  ..., 1, 1, 1],
        [0, 0, 3,  ..., 1, 1, 1]])
torch.Size([19920, 11])
torch.Size([19920, 11, 1])


  self.data = torch.tensor(data).to(int)
  self.targets = torch.tensor(targets)


In [9]:
#Test making a loader works
loader = train_set.make_loader(batch_size=32, num_workers=0)
for b, s in loader:
    break

# Build Polysemantic model

In [10]:
from poly_hl_model import PolyHLModel

model = PolyHLModel(hl_models, corrs, cases)

input, target = train_dataset[4]
output, cache = model.run_with_cache(input[None,:])
print(input)
# print(hl_models[0](input[1:6]))
print(output[:1])
print(target)
print(model.mask)

tensor([0, 0, 4, 3, 3, 3, 1, 1, 1, 1, 1])
tensor([[[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]]])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])
tensor([[[False],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False],
         [False],
         [False],
         [False]]])


# Import LL model and build model pair

In [11]:
import transformer_lens as tl

D_MODEL = max(model.d_models)
N_CTX = model.n_ctx
N_LAYERS = model.n_layers
N_HEADS = model.n_heads
D_VOCAB = max([hl_model.cfg.d_vocab for hl_model in model.hl_models])

#Want to specify this somewhere central, e.g., iit.tasks.ioi but iit.tasks.parens
ll_cfg = tl.HookedTransformerConfig(
        n_layers = N_LAYERS,
        d_model = D_MODEL,
        n_ctx = N_CTX,
        d_head = D_MODEL // N_HEADS,
        d_vocab = D_VOCAB,
        act_fn = "relu",
)

class SingleOutputHookedTransformer(tl.HookedTransformer):
    def forward(self, x):
        output = super().forward(x)
        return output[:,:,:1]

ll_model = SingleOutputHookedTransformer(ll_cfg).to(device)

Moving model to device:  cpu


In [12]:
n_epochs = 100
training_args = {
    "batch_size": 256,
    "lr": 0.0001,
    "num_workers": 0,
    "use_single_loss": True,
    "behavior_weight": 0.25,
    "iit_weight": 0.5,
    "strict_weight": 0.25,
    "clip_grad_norm": False, #or float = 1.0
    "early_stop" : True,
    "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
    "scheduler_kwargs": dict(start_factor=1, end_factor=0, total_iters=n_epochs),
    "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
    "scheduler_mode": "max", #for ReduceLRonPlateau
}
model_pair = StrictIITModelPair(hl_model=model, ll_model=ll_model, corr=model.corr, training_args=training_args)

### train model:

In [13]:
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs
)

training_args={'batch_size': 256, 'lr': 0.001, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0, 'total_iters': 100}, 'clip_grad_norm': False, 'seed': 0, 'detach_while_caching': True, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 0.5, 'behavior_weight': 0.25, 'strict_weight': 0.25}


VBox(children=(Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s],))

Epoch 1: lr: 9.90e-04, train/iit_loss: 0.0495, train/behavior_loss: 0.0205, train/strict_loss: 0.0051, val/iit_loss: 0.0487, val/IIA: 27.00%, val/accuracy: 29.57%, val/strict_accuracy: 29.57%
Epoch 2: lr: 9.80e-04, train/iit_loss: 0.0219, train/behavior_loss: 0.0050, train/strict_loss: 0.0012, val/iit_loss: 0.0515, val/IIA: 31.91%, val/accuracy: 41.93%, val/strict_accuracy: 41.93%
Epoch 3: lr: 9.70e-04, train/iit_loss: 0.0183, train/behavior_loss: 0.0026, train/strict_loss: 0.0006, val/iit_loss: 0.0320, val/IIA: 46.52%, val/accuracy: 60.13%, val/strict_accuracy: 60.13%
Epoch 4: lr: 9.60e-04, train/iit_loss: 0.0171, train/behavior_loss: 0.0018, train/strict_loss: 0.0004, val/iit_loss: 0.0305, val/IIA: 40.53%, val/accuracy: 57.51%, val/strict_accuracy: 57.51%
Epoch 5: lr: 9.50e-04, train/iit_loss: 0.0162, train/behavior_loss: 0.0016, train/strict_loss: 0.0004, val/iit_loss: 0.0233, val/IIA: 57.10%, val/accuracy: 71.05%, val/strict_accuracy: 71.05%
Epoch 6: lr: 9.40e-04, train/iit_loss: 0