In [1]:
%load_ext autoreload
%autoreload 2

import torch

import circuits_benchmark.benchmark.cases.case_3 as case3
import circuits_benchmark.benchmark.cases.case_4 as case4
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.utils.iit.iit_hl_model import IITHLModel
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
from circuits_benchmark.benchmark.vocabs import TRACR_BOS, TRACR_PAD

device = 'cuda' if torch.cuda.is_available() else 'cpu'


#load cases
cases = [case3.Case3(), case4.Case4()]
corrs = []
ll_models = []
hl_models = []
model_pairs = []
for case in cases:
    ll_model_loader = get_ll_model_loader(case, interp_bench=True)
    corr, ll_model = ll_model_loader.load_ll_model_and_correspondence(device=device)
    hl_model = case.get_hl_model()

    if isinstance(hl_model, HookedTracrTransformer):
        hl_model = IITHLModel(hl_model, eval_mode=True)

    model_pair = case.build_model_pair(ll_model=ll_model, hl_model=hl_model)

    corrs.append(corr)
    ll_models.append(ll_model)
    hl_models.append(hl_model)
    model_pairs.append(model_pair)


Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_



Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_

In [2]:
print(TRACR_BOS, TRACR_PAD)
print(sorted(list(case.get_vocab())))
for hl_model in hl_models:
    print(dir(hl_model.tracr_input_encoder))
    print(hl_model.tracr_input_encoder.encoding_map)

BOS PAD
['(', ')', 'a', 'b', 'c']
['__abstractmethods__', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_bos_token', '_max_seq_len', '_pad_token', 'basis', 'bos_encoding', 'bos_token', 'decode', 'encode', 'encoding_map', 'enforce_bos', 'pad_encoding', 'pad_token', 'vocab_size']
{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
['__abstractmethods__', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce

In [3]:
#find overall min and max sequence length (we'll have to pad some)
min_seq_len = 100000
max_seq_len = 0

for case in cases:
    if case.get_min_seq_len() < min_seq_len:
        min_seq_len = case.get_min_seq_len()
    if case.get_max_seq_len() > max_seq_len:
        max_seq_len = case.get_max_seq_len()
print(min_seq_len, max_seq_len)

4 10


In [4]:
# find max vocab size (some tasks will just use a subset of vocab size)
vocab_size = 0
for case in cases:
    if len(case.get_vocab()) > vocab_size:
        vocab_size = len(case.get_vocab())
print(vocab_size)


5


In [5]:
max_case_samples = 1_000
case_samples = []
for case in cases:
    num_samples = min(max_case_samples, case.get_total_data_len())
    case_samples.append(num_samples)
print(case_samples)

[320, 1000]


In [41]:
# generate clean and corrupted datasets
clean_datasets = []
corrupted_datasets = []
masks = []
for task_id, case, samples, hl_model in zip(range(len(cases)), cases, case_samples, hl_models):
    clean_dataset = case.get_clean_data(max_samples=samples)
    corrupted_dataset = case.get_corrupted_data(max_samples=samples)
    encoder = hl_model.tracr_input_encoder    
    print(encoder.encoding_map)
    def encode(tok):
        return encoder.encoding_map[tok]

    for dataset in (clean_dataset, corrupted_dataset):
        #Input
        #put task_id after BOS token and pads after EOS.
        inputs = dataset.inputs
        str_tokens = [encoder.decode(inputs[i].tolist()) for i in range(inputs.shape[0])]
        str_task_id = encoder.decode([task_id])
        pads = [TRACR_PAD] * (max_seq_len - case.get_max_seq_len())
        str_tokens = [str_task_id + [TRACR_BOS] + tokens[1:] + pads for tokens in str_tokens]
        inputs = torch.tensor([list(map(encode, tokens)) for tokens in str_tokens])

        #Target
        #add 0 to beginning of seq and a bunch of 0s to end.
        target = dataset.targets
        label = torch.zeros((target.shape[0], 1, target.shape[2]), dtype=target.dtype)
        pads = torch.zeros((target.shape[0], max_seq_len - case.get_max_seq_len(), target.shape[2]), dtype=target.dtype)
        target = torch.cat((label, target, pads), dim=1)
        dataset.inputs = inputs
        dataset.targets = target
    print(clean_dataset.inputs.shape, corrupted_dataset.inputs.shape)
    print(clean_dataset.targets.shape, corrupted_dataset.targets.shape)
    clean_datasets.append(clean_dataset)
    corrupted_datasets.append(corrupted_dataset)

    # if case.get_max_seq_len() < max_seq_len:
    #     # pad sequences
        




{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
torch.Size([320, 11]) torch.Size([320, 11])
torch.Size([320, 11, 1]) torch.Size([320, 11, 1])




{'(': 0, ')': 1, 'BOS': 2, 'PAD': 3, 'a': 4, 'b': 5, 'c': 6}
torch.Size([1000, 11]) torch.Size([1000, 11])
torch.Size([1000, 11, 1]) torch.Size([1000, 11, 1])


In [46]:
#make all datasets ~the same length by duplicating shorter datasets

max_length = max([clean_dataset.inputs.shape[0] for clean_dataset in clean_datasets])

for dset_list in [clean_datasets, corrupted_datasets]:
    for dataset in dset_list:
        if dataset.inputs.shape[0] < max_length:
            num_dups = max_length // dataset.inputs.shape[0]
            dataset.inputs = dataset.inputs.repeat(num_dups, 1)
            dataset.targets = dataset.targets.repeat(num_dups, 1, 1)
print(clean_datasets[0].inputs.shape, clean_datasets[1].inputs.shape)
print(corrupted_datasets[0].inputs.shape, corrupted_datasets[1].inputs.shape)

torch.Size([960, 11]) torch.Size([1000, 11])
torch.Size([960, 11]) torch.Size([1000, 11])


In [48]:
from circuits_benchmark.benchmark.tracr_encoded_dataset import TracrEncodedDataset
#smush datasets together into one big dataset, then shuffle
datasets = []
for dset_list in [clean_datasets, corrupted_datasets]:
    inputs = torch.cat([dset.inputs for dset in dset_list], dim=0)
    targets = torch.cat([dset.targets for dset in dset_list], dim=0)

    # shuffle dataset contents, keeping inputs and targets in sync
    indices = torch.randperm(inputs.shape[0])
    inputs = inputs[indices]
    targets = targets[indices]

    datasets.append(TracrEncodedDataset(inputs, targets))


In [49]:
loaders = [d.make_loader(batch_size = 256, shuffle=True, device=device) for d in datasets]

In [51]:
print(loaders)

[<torch.utils.data.dataloader.DataLoader object at 0x333f90d10>, <torch.utils.data.dataloader.DataLoader object at 0x333ed95d0>]


In [55]:
print(corrs[0].suffixes)

{'attn': 'attn.hook_result', 'mlp': 'mlp.hook_post'}


In [59]:
print(hl_model.cfg.n_heads)

2


In [77]:
from transformer_lens.hook_points import HookedRootModule, HookPoint

class MultiHighLevelModel(HookedRootModule):

    def __init__(self, hl_models, corrs, cases):
        super().__init__()
        self.hl_models = hl_models
        self.corrs = corrs
        self.cases = cases

        n_layers = max([mod.cfg.n_layers for mod in self.hl_models])
        n_heads = max([mod.cfg.n_heads for mod in self.hl_models])
        self.n_ctx = max([mod.cfg.n_ctx for mod in self.hl_models]) + 1

        #make hooks for each attn head and each mlp
        self.attn_hooks = [[HookPoint() for _ in range(n_heads)] for _ in range(n_layers)]
        self.mlp_hooks = [[HookPoint() for _ in range(n_layers)] for _ in range(n_layers)]

        self.setup()
    
    def forward(self, x):
        # get sorting indices by task id
        task_ids = x[:, 0]
        argsort = torch.argsort(task_ids)
        reverse_indices = torch.argsort(argsort)
        outputs = torch.zeros((x.shape[0], x.shape[1], 1))
        for i, model in enumerate(self.hl_models):
            case = self.cases[i]
            inputs = x[task_ids == i, 1:case.get_max_seq_len()+1]
            model_output = model(inputs)
            outputs[task_ids == i, 1:case.get_max_seq_len()+1,:] = model_output
            print(model_output.shape)
        # outputs = outputs[reverse_indices]

        return outputs

model = MultiHighLevelModel(hl_models, corrs, cases)

for b1, b2 in zip(*tuple(loaders)):
    input, target = b1
    output = model(input)
    print(input[:1])
    print(output[:1])
    print(target[:1])
    break

torch.Size([113, 5, 1])
torch.Size([143, 10, 1])
tensor([[1, 2, 1, 1, 4, 0, 5, 5, 0, 0, 0]])
tensor([[[ 0.0000],
         [ 0.0000],
         [-1.0000],
         [-1.0000],
         [-0.6667],
         [-0.2500],
         [-0.2000],
         [-0.1667],
         [ 0.0000],
         [ 0.1250],
         [ 0.2222]]])
tensor([[[ 0.0000],
         [ 0.0000],
         [-1.0000],
         [-1.0000],
         [-0.6667],
         [-0.2500],
         [-0.2000],
         [-0.1667],
         [ 0.0000],
         [ 0.1250],
         [ 0.2222]]])
