# Imports

In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import clear_output
import torch
from torch.utils.data import Dataset
import transformer_lens as tl

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair
from iit.utils.iit_dataset import train_test_split
from iit.utils.iit_dataset import IITDataset

import circuits_benchmark.benchmark.cases.case_3 as case3
import circuits_benchmark.benchmark.cases.case_4 as case4
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.utils.iit.iit_hl_model import IITHLModel
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
from circuits_benchmark.benchmark.vocabs import TRACR_BOS, TRACR_PAD

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load cases from huggingface

In [2]:
#load cases
cases = [case3.Case3(), case4.Case4()]
# cases = [case3.Case3(),]
# cases = [case4.Case4()]
corrs = []
ll_models = []
hl_models = []
model_pairs = []
for case in cases:
    ll_model_loader = get_ll_model_loader(case, interp_bench=True)
    corr, ll_model = ll_model_loader.load_ll_model_and_correspondence(device=device)
    hl_model = case.get_hl_model()

    if isinstance(hl_model, HookedTracrTransformer):
        hl_model = IITHLModel(hl_model, eval_mode=True)

    model_pair = case.build_model_pair(ll_model=ll_model, hl_model=hl_model)

    corrs.append(corr)
    ll_models.append(ll_model)
    hl_models.append(hl_model)
    model_pairs.append(model_pair)


Moving model to device:  cpu




Moving model to device:  cpu


# Generate mixed dataset

In [3]:
#find overall min and max sequence length (we'll have to pad some)
min_seq_len = 100000
max_seq_len = 0

for case in cases:
    if case.get_min_seq_len() < min_seq_len:
        min_seq_len = case.get_min_seq_len()
    if case.get_max_seq_len() > max_seq_len:
        max_seq_len = case.get_max_seq_len()
print(min_seq_len, max_seq_len)

4 10


In [4]:
# find max vocab size (some tasks will just use a subset of vocab size)
vocab_size = 0
for case in cases:
    if len(case.get_vocab()) > vocab_size:
        vocab_size = len(case.get_vocab())
print(vocab_size)


5


In [5]:
max_case_samples = 10_000
case_samples = []
for case in cases:
    num_samples = min(max_case_samples, case.get_total_data_len())
    case_samples.append(num_samples)
print(case_samples)

[320, 10000]


In [6]:
# generate clean and corrupted datasets
clean_datasets = []
corrupted_datasets = []
masks = []
for task_id, case, samples, hl_model in zip(range(len(cases)), cases, case_samples, hl_models):
    dataset = case.get_clean_data(max_samples=samples)
    encoder = hl_model.tracr_input_encoder    
    # print(encoder.encoding_map)
    def encode(tok):
        return encoder.encoding_map[tok]


    #Input
    #put task_id after BOS token and pads after EOS.
    inputs = dataset.inputs
    str_tokens = [encoder.decode(inputs[i].tolist()) for i in range(inputs.shape[0])]
    str_task_id = encoder.decode([task_id])
    pads = [TRACR_PAD] * (max_seq_len - case.get_max_seq_len())
    str_tokens = [str_task_id + [TRACR_BOS] + tokens[1:] + pads for tokens in str_tokens]
    inputs = torch.tensor([list(map(encode, tokens)) for tokens in str_tokens])

    #Target
    #add 0 to beginning of seq and a bunch of 0s to end.
    target = dataset.targets
    label = torch.zeros((target.shape[0], 1, target.shape[2]), dtype=target.dtype)
    pads = torch.zeros((target.shape[0], max_seq_len - case.get_max_seq_len(), target.shape[2]), dtype=target.dtype)
    target = torch.cat((label, target, pads), dim=1)
    dataset.inputs = inputs
    dataset.targets = target
    # print(clean_dataset.inputs.shape, dataset.inputs.shape)
    # print(clean_dataset.targets.shape, dataset.targets.shape)
    clean_datasets.append(dataset)

    # if case.get_max_seq_len() < max_seq_len:
    #     # pad sequences

clear_output()



In [None]:
#make all datasets ~the same length by duplicating shorter datasets
max_length = max([clean_dataset.inputs.shape[0] for clean_dataset in clean_datasets])

for dset_list in [clean_datasets,]:
    for dataset in dset_list:
        if dataset.inputs.shape[0] < max_length:
            num_dups = max_length // dataset.inputs.shape[0]
            dataset.inputs = dataset.inputs.repeat(num_dups, 1)
            dataset.targets = dataset.targets.repeat(num_dups, 1, 1)
print(clean_datasets[0].inputs.shape,)# clean_datasets[1].inputs.shape)
# print(corrupted_datasets[0].inputs.shape)#, corrupted_datasets[1].inputs.shape)

torch.Size([9920, 11])


In [None]:
#smush datasets together into one big dataset, then shuffle
datasets = []
for dset_list in [clean_datasets,]:
    print(dset_list[0].inputs)
    inputs = torch.cat([dset.inputs for dset in dset_list], dim=0)
    targets = torch.cat([dset.targets for dset in dset_list], dim=0)

    # shuffle dataset contents, keeping inputs and targets in sync
    indices = torch.randperm(inputs.shape[0])
    inputs = inputs[indices]
    targets = targets[indices]

    class CustomDataset(Dataset):
        def __init__(self, data, targets):
            """
            Args:
                data (list or numpy array): List or array of input data.
                targets (list or numpy array): List or array of target data.
            """
            self.data = torch.tensor(data).to(int)
            self.targets = torch.tensor(targets)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            """
            Args:
                idx (int): Index
            Returns:
                tuple: (input tensor, target tensor)
            """
            return self.data[idx], self.targets[idx]

    decorated_dset = CustomDataset(
        data = inputs,
        targets = targets,
    )
    print(decorated_dset.data.shape)
    print(decorated_dset.targets.shape)

    train_dataset, test_dataset = train_test_split(
        decorated_dset, test_size=0.2, random_state=42
    )
    train_set = IITDataset(train_dataset, train_dataset, seed=0)
    test_set = IITDataset(test_dataset, test_dataset, seed=0)


tensor([[0, 0, 2,  ..., 1, 1, 1],
        [0, 0, 4,  ..., 1, 1, 1],
        [0, 0, 5,  ..., 1, 1, 1],
        ...,
        [0, 0, 5,  ..., 1, 1, 1],
        [0, 0, 2,  ..., 1, 1, 1],
        [0, 0, 3,  ..., 1, 1, 1]])
torch.Size([19920, 11])
torch.Size([19920, 11, 1])


  self.data = torch.tensor(data).to(int)
  self.targets = torch.tensor(targets)


In [None]:
#Test making a loader works
loader = train_set.make_loader(batch_size=32, num_workers=0)
for b, s in loader:
    break

# Build Polysemantic model

In [None]:
from poly_hl_model import PolyHLModel

model = PolyHLModel(hl_models, corrs, cases)

input, target = train_dataset[4]
output, cache = model.run_with_cache(input[None,:])
print(input)
# print(hl_models[0](input[1:6]))
print(output[:1])
print(target)
print(model.mask)

tensor([0, 0, 4, 3, 3, 3, 1, 1, 1, 1, 1])
tensor([[[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]]])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])
tensor([[[False],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False],
         [False],
         [False],
         [False]]])


# Import LL model and build model pair

In [None]:


D_MODEL = max(model.tracr_d_models)
N_CTX = model.n_ctx
N_LAYERS = model.n_layers
N_HEADS = model.n_heads
D_VOCAB = max([hl_model.cfg.d_vocab for hl_model in model.hl_models])

#Want to specify this somewhere central, e.g., iit.tasks.ioi but iit.tasks.parens
ll_cfg = tl.HookedTransformerConfig(
        n_layers = N_LAYERS,
        d_model = D_MODEL,
        n_ctx = N_CTX,
        d_head = D_MODEL // N_HEADS,
        d_vocab = D_VOCAB,
        act_fn = "relu",
)

class SingleOutputHookedTransformer(tl.HookedTransformer):
    def forward(self, x):
        output = super().forward(x)
        return output[:,:,:1]

ll_model = SingleOutputHookedTransformer(ll_cfg).to(device)

Moving model to device:  cpu


In [None]:
n_epochs = 1000
training_args = {
    "batch_size": 256,
    "lr": 0.001,
    "num_workers": 0,
    "use_single_loss": True,
    "behavior_weight": 1.,
    "iit_weight": 1.,
    "strict_weight": 0.4,
    "clip_grad_norm": 1.0,
    "early_stop" : True,
    "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
    "scheduler_kwargs": dict(start_factor=1, end_factor=0, total_iters=n_epochs),
    "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
    "scheduler_mode": "max", #for ReduceLRonPlateau
}
model_pair = StrictIITModelPair(hl_model=model, ll_model=ll_model, corr=model.corr, training_args=training_args)

In [None]:
print(model.corr)
print(model_pair.nodes_not_in_circuit)

{input_hook: {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}, task_hook: {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}, mlp_hooks.0: {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}, attn_hooks.1.0: {LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None)}, mlp_hooks.1: {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None)]


### train model:

In [None]:
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
    optimizer_kwargs=dict(weight_decay=1e-4),
)

training_args={'batch_size': 256, 'lr': 0.001, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0, 'total_iters': 1000}, 'clip_grad_norm': 1.0, 'seed': 0, 'detach_while_caching': True, 'atol': 0.05, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.4}
Epoch 1: lr: 9.99e-04, train/iit_loss: 0.0453, train/behavior_loss: 0.0170, train/strict_loss: 0.0056, val/iit_loss: 0.0218, val/IIA: 50.92%, val/accuracy: 68.98%, val/strict_accuracy: 43.54%
Epoch 2: lr: 9.98e-04, train/iit_loss: 0.0206, train/behavior_loss: 0.0030, train/strict_loss: 0.0031, val/iit_loss: 0.0227, val/IIA: 48.09%, val/accuracy: 81.20%, val/strict_accuracy: 48.34%
Epoch 3: lr: 9.97e-04, train/iit_loss: 0.0149, train/behavior_loss: 0.0019, train/strict_loss: 0.0025, val/iit_loss: 0.0127, val/IIA: 62.61%, va

In [None]:
# Save model
from safetensors.torch import save_file
state_dict = model_pair.ll_model.state_dict()
tensors = {key: value.cpu() for key, value in state_dict.items()}
save_file(tensors, "cases_03_04_ll_model_IIA_??_strict_??.pt")

In [None]:
#load it in
from safetensors.torch import load_file
state_dict = load_file("cases_03_04_ll_model_IIA_99p84_strict_99p06.pt")
model_pair.ll_model.load_state_dict(state_dict)

<All keys matched successfully>

# Comparison with pure benchmark HL model

In [None]:
hl_model = hl_models[0]
corr = corrs[0]

ll_cfg = cases[0].get_ll_model_cfg()

class SingleOutputHookedTransformer(tl.HookedTransformer):
    def forward(self, x):
        output = super().forward(x)
        return output[:,:,:1]

ll_model = cases[0].get_ll_model()#SingleOutputHookedTransformer(ll_cfg).to(device)

Moving model to device:  cpu


In [None]:
model_pair = StrictIITModelPair(hl_model=hl_model.hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
print(hl_model.W_E.shape, hl_model.W_pos.shape)
print(inputs[0])
hl_model.run_with_cache(inputs[0][None,1:])

torch.Size([6, 13]) torch.Size([5, 13])
tensor([0, 0, 2, 3, 4, 5])


(tensor([[[0.0000],
          [0.0000],
          [0.0000],
          [0.0000],
          [0.2500]]]),
 ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_result', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.hook_mlp_in

In [None]:
dset_list = clean_datasets
inputs = torch.cat([dset.inputs for dset in dset_list], dim=0)
targets = torch.cat([dset.targets for dset in dset_list], dim=0)

# shuffle dataset contents, keeping inputs and targets in sync
indices = torch.randperm(inputs.shape[0])
inputs = inputs[indices][:,1:]
targets = targets[indices][:,1:]

class CustomDataset(Dataset):
    def __init__(self, data, targets):
        """
        Args:
            data (list or numpy array): List or array of input data.
            targets (list or numpy array): List or array of target data.
        """
        self.data = torch.tensor(data).to(int)
        self.targets = torch.tensor(targets)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index
        Returns:
            tuple: (input tensor, target tensor)
        """
        return self.data[idx], self.targets[idx]

decorated_dset = CustomDataset(
    data = inputs,
    targets = targets,
)
print(decorated_dset.data.shape)
print(decorated_dset.targets.shape)

train_dataset, test_dataset = train_test_split(
    decorated_dset, test_size=0.2, random_state=42
)
this_train_set = IITDataset(train_dataset, train_dataset, seed=0)
this_test_set = IITDataset(test_dataset, test_dataset, seed=0)

torch.Size([320, 5])
torch.Size([320, 5, 1])


  self.data = torch.tensor(data).to(int)
  self.targets = torch.tensor(targets)


In [None]:
model_pair.train(
    train_set=this_train_set,
    test_set=this_test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs
)

training_args={'batch_size': 256, 'lr': 0.001, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0, 'total_iters': 1000}, 'clip_grad_norm': False, 'seed': 0, 'detach_while_caching': True, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 0.5, 'behavior_weight': 0.25, 'strict_weight': 0.25}
Epoch 1: lr: 9.99e-04, train/iit_loss: 0.1611, train/behavior_loss: 0.0867, train/strict_loss: 0.0218, val/iit_loss: 0.3348, val/IIA: 3.75%, val/accuracy: 5.00%, val/strict_accuracy: 5.20%
Epoch 2: lr: 9.98e-04, train/iit_loss: 0.1472, train/behavior_loss: 0.0768, train/strict_loss: 0.0186, val/iit_loss: 0.3080, val/IIA: 5.62%, val/accuracy: 4.69%, val/strict_accuracy: 5.62%
Epoch 3: lr: 9.97e-04, train/iit_loss: 0.1322, train/behavior_loss: 0.0690, train/strict_loss: 0.0170, val/iit_loss: 0.2797, val/IIA: 8.44%, val/ac