In [10]:
%load_ext autoreload
%autoreload 2

import torch

import circuits_benchmark.benchmark.cases.case_3 as case3
import circuits_benchmark.benchmark.cases.case_4 as case4
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.utils.iit.iit_hl_model import IITHLModel
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
from circuits_benchmark.benchmark.vocabs import TRACR_BOS, TRACR_PAD

device = 'cuda' if torch.cuda.is_available() else 'cpu'


#load cases
cases = [case3.Case3(), case4.Case4()]
corrs = []
ll_models = []
hl_models = []
model_pairs = []
for case in cases:
    ll_model_loader = get_ll_model_loader(case, interp_bench=True)
    corr, ll_model = ll_model_loader.load_ll_model_and_correspondence(device=device)
    hl_model = case.get_hl_model()

    if isinstance(hl_model, HookedTracrTransformer):
        hl_model = IITHLModel(hl_model, eval_mode=True)

    model_pair = case.build_model_pair(ll_model=ll_model, hl_model=hl_model)

    corrs.append(corr)
    ll_models.append(ll_model)
    hl_models.append(hl_model)
    model_pairs.append(model_pair)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Moving model to device:  cpu




Moving model to device:  cpu


In [11]:
print(TRACR_BOS, TRACR_PAD)
print(sorted(list(case.get_vocab())))
for hl_model in hl_models:
    print(dir(hl_model.tracr_input_encoder))
    print(hl_model.tracr_input_encoder.encoding_map)

BOS PAD
['(', ')', 'a', 'b', 'c']
['__abstractmethods__', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_bos_token', '_max_seq_len', '_pad_token', 'basis', 'bos_encoding', 'bos_token', 'decode', 'encode', 'encoding_map', 'enforce_bos', 'pad_encoding', 'pad_token', 'vocab_size']
{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
['__abstractmethods__', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce

In [12]:
#find overall min and max sequence length (we'll have to pad some)
min_seq_len = 100000
max_seq_len = 0

for case in cases:
    if case.get_min_seq_len() < min_seq_len:
        min_seq_len = case.get_min_seq_len()
    if case.get_max_seq_len() > max_seq_len:
        max_seq_len = case.get_max_seq_len()
print(min_seq_len, max_seq_len)

4 10


In [13]:
# find max vocab size (some tasks will just use a subset of vocab size)
vocab_size = 0
for case in cases:
    if len(case.get_vocab()) > vocab_size:
        vocab_size = len(case.get_vocab())
print(vocab_size)


5


In [14]:
max_case_samples = 1_000
case_samples = []
for case in cases:
    num_samples = min(max_case_samples, case.get_total_data_len())
    case_samples.append(num_samples)
print(case_samples)

[320, 1000]


In [15]:
# generate clean and corrupted datasets
clean_datasets = []
corrupted_datasets = []
masks = []
for task_id, case, samples, hl_model in zip(range(len(cases)), cases, case_samples, hl_models):
    dataset = case.get_clean_data(max_samples=samples)
    encoder = hl_model.tracr_input_encoder    
    print(encoder.encoding_map)
    def encode(tok):
        return encoder.encoding_map[tok]


    #Input
    #put task_id after BOS token and pads after EOS.
    inputs = dataset.inputs
    str_tokens = [encoder.decode(inputs[i].tolist()) for i in range(inputs.shape[0])]
    str_task_id = encoder.decode([task_id])
    pads = [TRACR_PAD] * (max_seq_len - case.get_max_seq_len())
    str_tokens = [str_task_id + [TRACR_BOS] + tokens[1:] + pads for tokens in str_tokens]
    inputs = torch.tensor([list(map(encode, tokens)) for tokens in str_tokens])

    #Target
    #add 0 to beginning of seq and a bunch of 0s to end.
    target = dataset.targets
    label = torch.zeros((target.shape[0], 1, target.shape[2]), dtype=target.dtype)
    pads = torch.zeros((target.shape[0], max_seq_len - case.get_max_seq_len(), target.shape[2]), dtype=target.dtype)
    target = torch.cat((label, target, pads), dim=1)
    dataset.inputs = inputs
    dataset.targets = target
    # print(clean_dataset.inputs.shape, dataset.inputs.shape)
    # print(clean_dataset.targets.shape, dataset.targets.shape)
    clean_datasets.append(dataset)

    # if case.get_max_seq_len() < max_seq_len:
    #     # pad sequences
        




{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}




{'(': 0, ')': 1, 'BOS': 2, 'PAD': 3, 'a': 4, 'b': 5, 'c': 6}


In [16]:
#make all datasets ~the same length by duplicating shorter datasets

max_length = max([clean_dataset.inputs.shape[0] for clean_dataset in clean_datasets])

for dset_list in [clean_datasets,]:
    for dataset in dset_list:
        if dataset.inputs.shape[0] < max_length:
            num_dups = max_length // dataset.inputs.shape[0]
            dataset.inputs = dataset.inputs.repeat(num_dups, 1)
            dataset.targets = dataset.targets.repeat(num_dups, 1, 1)
print(clean_datasets[0].inputs.shape, clean_datasets[1].inputs.shape)
# print(corrupted_datasets[0].inputs.shape)#, corrupted_datasets[1].inputs.shape)

torch.Size([960, 11]) torch.Size([1000, 11])


In [49]:
from circuits_benchmark.benchmark.tracr_encoded_dataset import TracrEncodedDataset
from iit.utils.iit_dataset import train_test_split
from iit.utils.iit_dataset import IITDataset
from torch.utils.data import Dataset
#smush datasets together into one big dataset, then shuffle
datasets = []
for dset_list in [clean_datasets,]:
    print(dset_list[0].inputs)
    inputs = torch.cat([dset.inputs for dset in dset_list], dim=0)
    targets = torch.cat([dset.targets for dset in dset_list], dim=0)

    # shuffle dataset contents, keeping inputs and targets in sync
    indices = torch.randperm(inputs.shape[0])
    inputs = inputs[indices]
    targets = targets[indices]

    class CustomDataset(Dataset):
        def __init__(self, data, targets):
            """
            Args:
                data (list or numpy array): List or array of input data.
                targets (list or numpy array): List or array of target data.
            """
            self.data = torch.tensor(data).to(int)
            self.targets = torch.tensor(targets)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            """
            Args:
                idx (int): Index
            Returns:
                tuple: (input tensor, target tensor)
            """
            return self.data[idx], self.targets[idx]

    decorated_dset = CustomDataset(
        data = inputs,
        targets = targets,
    )
    print(decorated_dset.data.shape)
    print(decorated_dset.targets.shape)

    train_dataset, test_dataset = train_test_split(
        decorated_dset, test_size=0.2, random_state=42
    )
    train_set = IITDataset(train_dataset, train_dataset, seed=0)
    test_set = IITDataset(test_dataset, test_dataset, seed=0)

    # datasets.append(TracrEncodedDataset(inputs, targets))


tensor([[0, 0, 2,  ..., 1, 1, 1],
        [0, 0, 4,  ..., 1, 1, 1],
        [0, 0, 5,  ..., 1, 1, 1],
        ...,
        [0, 0, 5,  ..., 1, 1, 1],
        [0, 0, 2,  ..., 1, 1, 1],
        [0, 0, 3,  ..., 1, 1, 1]])
torch.Size([1960, 11])
torch.Size([1960, 11, 1])


  self.data = torch.tensor(data).to(int)
  self.targets = torch.tensor(targets)


In [50]:
loader = train_set.make_loader(batch_size=32, num_workers=0)
for b, s in loader: #something is wrong here, meh.
    print(b)
    break

(tensor([[0, 0, 4, 5, 4, 5, 1, 1, 1, 1, 1],
        [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1],
        [1, 2, 6, 5, 6, 1, 6, 1, 6, 0, 0],
        [0, 0, 2, 2, 5, 2, 1, 1, 1, 1, 1],
        [0, 0, 4, 5, 3, 3, 1, 1, 1, 1, 1],
        [0, 0, 2, 5, 2, 5, 1, 1, 1, 1, 1],
        [1, 2, 5, 0, 1, 0, 6, 4, 6, 4, 5],
        [1, 2, 6, 6, 1, 0, 1, 0, 0, 6, 0],
        [0, 0, 3, 5, 4, 4, 1, 1, 1, 1, 1],
        [1, 2, 4, 5, 5, 0, 4, 1, 0, 0, 6],
        [1, 2, 0, 5, 0, 4, 0, 5, 1, 0, 4],
        [0, 0, 5, 4, 4, 3, 1, 1, 1, 1, 1],
        [0, 0, 3, 4, 4, 3, 1, 1, 1, 1, 1],
        [1, 2, 4, 1, 4, 0, 0, 6, 0, 4, 5],
        [0, 0, 3, 4, 2, 5, 1, 1, 1, 1, 1],
        [1, 2, 5, 6, 6, 6, 6, 6, 4, 5, 6],
        [1, 2, 6, 6, 4, 6, 5, 1, 0, 4, 6],
        [1, 2, 1, 0, 1, 6, 6, 5, 1, 1, 0],
        [0, 0, 2, 4, 5, 3, 1, 1, 1, 1, 1],
        [0, 0, 3, 4, 2, 5, 1, 1, 1, 1, 1],
        [1, 2, 6, 1, 6, 1, 6, 4, 1, 6, 0],
        [1, 2, 1, 0, 5, 1, 4, 5, 6, 0, 6],
        [0, 0, 2, 5, 5, 4, 1, 1, 1, 1, 1],
        [1

In [51]:
for hl_model in hl_models:
    print(hl_model.tracr_input_encoder.encoding_map)
    print(hl_model.cfg.d_vocab)

{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
6
{'(': 0, ')': 1, 'BOS': 2, 'PAD': 3, 'a': 4, 'b': 5, 'c': 6}
7


In [52]:
print(corrs[0].suffixes)
for corr in corrs[1].items():
    print(corr[0])
    print(corr[1])
    print()
print(corrs[1])

{'attn': 'attn.hook_result', 'mlp': 'mlp.hook_post'}
blocks.0.mlp.hook_post
{LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post
{LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}

blocks.1.attn.hook_result
{LLNode(name='blocks.1.attn.hook_result', index=[:, :, 0, :], subspace=None)}

blocks.1.attn.hook_result
{LLNode(name='blocks.1.attn.hook_result', index=[:, :, 1, :], subspace=None)}

blocks.1.mlp.hook_post
{LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}

{TracrHLNode(name: blocks.0.mlp.hook_post,
 label: bools_close_7,
 classes: 0,
 index: [:]
): {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}, TracrHLNode(name: blocks.0.mlp.hook_post,
 label: bools_open_5,
 classes: 0,
 index: [:]
): {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}, TracrHLNode(name: blocks.1.attn.hook_result,
 label: opens_2,
 classes: 0,
 index: [:, :, 0, :]
): {LLNode(name='blocks.1.attn.hook_result', index=

In [53]:
from collections import defaultdict


    

In [54]:
print(ll_models[1].cfg.d_model)

20


In [55]:
from functools import partial
from typing import Optional
from transformer_lens.hook_points import HookedRootModule, HookPoint
from torch import nn

from iit.utils.index import TorchIndex, Ix

class MultiHighLevelModel(HookedRootModule):

    def __init__(self, hl_models, corrs, cases):
        super().__init__()
        self.hl_models = hl_models
        self.corrs = corrs
        self.cases = cases

        self.n_layers = max([mod.cfg.n_layers for mod in self.hl_models])
        self.n_heads = max([mod.cfg.n_heads for mod in self.hl_models])
        self.n_ctx = max([mod.cfg.n_ctx for mod in self.hl_models]) + 1
        self.tracr_d_heads = [mod.W_Q.shape[-1] for mod in self.hl_models]
        self.d_head = max(self.tracr_d_heads)
        self.tracr_d_mlps = [mod.W_in[0].shape[1] for mod in self.hl_models]
        self.d_mlp = max(self.tracr_d_mlps)
        self.d_models = [mod.cfg.d_model for mod in self.hl_models]
        self.attn_shapes = []

        #make hooks for each necessary attn head and each mlp
        self.input_hook = HookPoint()
        self.task_hook = HookPoint()
        self.attn_hooks = nn.ModuleList([nn.ModuleList([HookPoint() for _ in range(self.n_heads)]) for _ in range(self.n_layers)])
        self.mlp_hooks = nn.ModuleList([HookPoint() for _ in range(self.n_layers)])


        corr_dict = defaultdict(list)

        for model_number, corr in enumerate(corrs):
            if corr.suffixes['attn'] == 'attn.hook_result':
                self.attn_shapes.append(self.d_models[model_number])
            else:
                self.attn_shapes.append(self.tracr_d_heads[model_number])
            corr_keys = defaultdict(list)
            for k in corr.keys():
                corr_keys[k.name].append(k)
            for i in range(self.n_layers):
                for k, suff in corr.suffixes.items():

                    name = f'blocks.{i}.{suff}'
                    if k == 'attn':
                        if model_number == 0:
                            for _ in range(self.n_heads):
                                corr_dict[name].append([])
                        if name in corr_keys.keys():
                            hl_nodes = corr_keys[name]
                            for hl_node in hl_nodes:
                                ll_nodes = corr[hl_node]
                                for head in range(self.n_heads):
                                    used = False
                                    for node in ll_nodes:
                                        if node.index.as_index[2] == head:
                                            corr_dict[name][head].append(corr_keys[name])
                                            used = True
                                            break
                                    if not used:
                                        corr_dict[name][head].append(None)
                        else:
                            for head in range(self.n_heads):
                                corr_dict[name][head].append(None)
                    elif k == 'mlp':
                        if name in corr_keys.keys():
                            corr_dict[name].append(corr_keys[name])
                        else:
                            corr_dict[name].append(None)
                    else:
                        raise ValueError(f"Unknown suffix in correspondence: {k}")
        self.attn_shape = max(self.attn_shapes)
        # for k, v in corr_dict.items():
        #     print(k)
        #     print(v)
        #     print()
        self.corr_dict = corr_dict

        self.setup()
    
    def is_categorical(self):
        return False
    
    def forward(self, x):
        # get sorting indices by task id
        if isinstance(x, tuple):
            x = x[0]
        tokens = self.input_hook(x)
        task_ids = tokens[:, 0]
        task_ids = self.task_hook(task_ids)

        # Step 1 -- get all the activations.
        caches = []
        for i, hl_model in enumerate(self.hl_models):
            if isinstance(hl_model, IITHLModel):
                hl_model = hl_model.hl_model
            #clip token vocab down to task vocab size; replace out-of-task tokens with PAD.
            encoder = hl_model.tracr_input_encoder
            task_tokens = torch.clone(tokens)[:, 1:hl_model.cfg.n_ctx+1]
            task_tokens[task_tokens >= hl_model.cfg.d_vocab] = encoder.encoding_map[TRACR_PAD]
            _, cache = hl_model.run_with_cache(task_tokens)
            caches.append(cache)
        
        # for k, i in caches[0].items():
        #     print(k, i.shape)
        
        #ActivationCache isn't writable so we need to copy it.
        caches = [{k: i for k, i in cache.items()} for cache in caches]


        # Step 2 -- construct all the hooks for THIS model.
        for layer in range(self.n_layers):
            # create MLP hooks
            mlp_storage = torch.zeros((len(self.hl_models), tokens.shape[0], tokens.shape[1], self.d_mlp))
            attn_storage = torch.zeros((len(self.hl_models), tokens.shape[0], tokens.shape[1], self.n_heads, self.attn_shape))
            for i, hl_model in enumerate(self.hl_models):

                #MLP
                suffix = self.corrs[i].suffixes['mlp']
                hook_name = f'blocks.{layer}.{suffix}'
                #unpack from cache
                if self.corr_dict[hook_name][i] is not None:
                    print(hook_name, mlp_storage.shape, caches[i][hook_name].shape)
                    mlp_storage[i,:,1:hl_model.cfg.n_ctx+1,:self.tracr_d_mlps[i]] = caches[i][hook_name]
                

                #Attn
                suffix = self.corrs[i].suffixes['attn']
                hook_name = f'blocks.{layer}.{suffix}'
                attn_shape = self.attn_shapes[i]
                #unpack from cache
                for head in range(self.n_heads):
                    if self.corr_dict[hook_name][head][i] is not None:
                        print(attn_storage[i,:,1:hl_model.cfg.n_ctx+1,head,:attn_shape].shape, caches[i][hook_name][:,:,head].shape)
                        attn_storage[i,:,1:hl_model.cfg.n_ctx+1,head,:attn_shape] = caches[i][hook_name][:,:,head]

            #modify with hook
            mlp_storage = self.mlp_hooks[layer](mlp_storage)
            for head in range(self.n_heads):
                attn_storage[:,:,:,head] = self.attn_hooks[layer][head](attn_storage[:,:,:,head])

                
            for i, hl_model in enumerate(self.hl_models):
                #pack back into cache
                # MLP
                suffix = self.corrs[i].suffixes['mlp']
                hook_name = f'blocks.{layer}.{suffix}'
                caches[i][hook_name] = mlp_storage[i,:,1:hl_model.cfg.n_ctx+1,:self.tracr_d_mlps[i]]
                # Attn
                suffix = self.corrs[i].suffixes['attn']
                hook_name = f'blocks.{layer}.{suffix}'
                attn_shape = self.attn_shapes[i]
                caches[i][hook_name] = attn_storage[i,:,1:hl_model.cfg.n_ctx+1,:hl_model.cfg.n_heads,:attn_shape]


        # Step 3 -- run with a bunch of hooks on the tracr models using the cache, actually doing interventions, and return intervened output.
        outputs = torch.zeros((tokens.shape[0], tokens.shape[1], 1))
        for i, hl_model in enumerate(self.hl_models):
            if isinstance(hl_model, IITHLModel):
                hl_model = hl_model.hl_model
            #clip token vocab down to task vocab size; replace out-of-task tokens with PAD.
            encoder = hl_model.tracr_input_encoder
            task_tokens = torch.clone(tokens)[:, 1:hl_model.cfg.n_ctx+1]
            task_tokens[task_tokens >= hl_model.cfg.d_vocab] = encoder.encoding_map[TRACR_PAD]

            # define hooks.
            def mlp_replacement_hook(x, hook):
                x[:] = caches[i][hook.name]
            
            def attn_replacement_hook(x, hook, index: Optional[TorchIndex] = None):
                x[index.as_index] = caches[i][hook.name][index.as_index]

            hooks = []
            #go through self.corr_dict and link up each hook with corresponding hook(s) in HL tracr models.
            for layer in range(self.n_layers):

                #MLP
                suffix = self.corrs[i].suffixes['mlp']
                hook_name = f'blocks.{layer}.{suffix}'
                #unpack from cache
                if self.corr_dict[hook_name][i] is not None:
                    hooks.append((hook_name, mlp_replacement_hook))
                
                # Attn
                suffix = self.corrs[i].suffixes['attn']
                hook_name = f'blocks.{layer}.{suffix}'
                for head in range(hl_model.cfg.n_heads):
                    if self.corr_dict[hook_name][head][i] is not None:
                        hooks.append((hook_name, partial(attn_replacement_hook, index=TorchIndex([None,None,head,None]))))

            # print(hooks)
            model_output = hl_model.run_with_hooks(task_tokens, fwd_hooks=hooks)
            outputs[task_ids == i, 1:hl_model.cfg.n_ctx+1,:] = model_output[task_ids == i]
        print(caches[0].keys())

        return outputs

model = MultiHighLevelModel(hl_models, corrs, cases)

input, target = train_dataset[0]
output, cache = model.run_with_cache(input[None,:])
print(input)
print(output[:1])
print(target)

blocks.0.mlp.hook_post torch.Size([2, 1, 11, 4]) torch.Size([1, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 1, 11, 4]) torch.Size([1, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 1, 11, 4]) torch.Size([1, 10, 4])
torch.Size([1, 10, 22]) torch.Size([1, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn

In [56]:
# Create a correspondence from the corr_dict
from iit.utils.correspondence import Correspondence

#don't forget to link input to resid_pre
# also don't forget to link task_id to first free attention head.

print(model.n_layers)
print(model.n_heads)
corr_dict = {}
task_id_set = False
corr_dict['input_hook'] = [('blocks.0.hook_resid_pre', Ix[[None]], None),]
for i in range(model.n_layers):
    for j in range(model.n_heads):
        use_attn_head = False
        use_mlp = False
        for k in range(len(model.hl_models)):
            corr = model.corrs[k]
            suffixes = corr.suffixes
            attn_hook_name = f'blocks.{i}.{suffixes["attn"]}'
            mlp_hook_name = f'blocks.{i}.{suffixes["mlp"]}'
            if model.corr_dict[attn_hook_name][j][k] is not None:
                use_attn_head = True
            if model.corr_dict[mlp_hook_name][k] is not None:
                use_mlp = True
        if use_attn_head:
            corr_dict[f'attn_hooks.{i}.{j}'] = [(attn_hook_name, Ix[[None,None,j,None]], None),]
        elif not task_id_set:
            corr_dict[f'task_hook'] = [(attn_hook_name, Ix[[None,None,j,None]], None),]
            task_id_set = True
        if use_mlp and j == 0:
            corr_dict[f'mlp_hooks.{i}'] = [(mlp_hook_name, Ix[[None]], None),]
corr = Correspondence.make_corr_from_dict(corr_dict)
print(corr)

2
2
{input_hook: {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}, task_hook: {LLNode(name='blocks.0.attn.hook_result', index=[:, :, 0, :], subspace=None)}, mlp_hooks.0: {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}, attn_hooks.1.0: {LLNode(name='blocks.1.attn.hook_result', index=[:, :, 0, :], subspace=None)}, mlp_hooks.1: {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}}


In [57]:
import transformer_lens as tl

D_MODEL = max(model.d_models)
N_CTX = model.n_ctx
N_LAYERS = model.n_layers
N_HEADS = model.n_heads
D_VOCAB = max([hl_model.cfg.d_vocab for hl_model in model.hl_models])

#Want to specify this somewhere central, e.g., iit.tasks.ioi but iit.tasks.parens
ll_cfg = tl.HookedTransformerConfig(
        n_layers = N_LAYERS,
        d_model = D_MODEL,
        n_ctx = N_CTX,
        d_head = D_MODEL // N_HEADS,
        d_vocab = D_VOCAB,
        act_fn = "relu",
)

class SingleOutputHookedTransformer(tl.HookedTransformer):
    def forward(self, x):
        output = super().forward(x)
        return output[:,:,:1]

ll_model = SingleOutputHookedTransformer(ll_cfg).to(device)

Moving model to device:  cpu


In [58]:
idx = 1
input = train_set[idx][0][0]
output = train_set[idx][0][1]
print(input)
print(output)
print(model(torch.stack([input,])))
print(ll_model(torch.stack([input,])))

tensor([0, 0, 5, 2, 3, 3, 1, 1, 1, 1, 1])
tensor([[0.0000],
        [0.0000],
        [1.0000],
        [0.5000],
        [0.3333],
        [0.2500],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000]])
blocks.0.mlp.hook_post torch.Size([2, 1, 11, 4]) torch.Size([1, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 1, 11, 4]) torch.Size([1, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 1, 11, 4]) torch.Size([1, 10, 4])
torch.Size([1, 10, 22]) torch.Size([1, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0

In [59]:
from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair

for k in corr.keys():
    print(str(k))
print(model.hook_dict)
assert all([str(k) in model.hook_dict for k in corr.keys()])
model_pair = StrictIITModelPair(hl_model=model, ll_model=ll_model, corr=corr)

input_hook
task_hook
mlp_hooks.0
attn_hooks.1.0
mlp_hooks.1
{'input_hook': HookPoint(), 'task_hook': HookPoint(), 'attn_hooks.0.0': HookPoint(), 'attn_hooks.0.1': HookPoint(), 'attn_hooks.1.0': HookPoint(), 'attn_hooks.1.1': HookPoint(), 'mlp_hooks.0': HookPoint(), 'mlp_hooks.1': HookPoint()}


In [60]:
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=10
)

training_args={'batch_size': 256, 'lr': 0.001, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': None, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'clip_grad_norm': 1.0, 'seed': 0, 'detach_while_caching': True, 'atol': 0.05, 'use_single_loss': False, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0}


  0%|          | 0/10 [00:00<?, ?it/s]

blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



torch.Size([256, 11, 1]) torch.Size([256, 11, 1]) [:]
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'b



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



torch.Size([256, 11, 1]) torch.Size([256, 11, 1]) [:]
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'b



torch.Size([256, 11, 1]) torch.Size([256, 11, 1]) [:]
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'b

100%|██████████| 7/7 [00:00<00:00,  9.97it/s]

blocks.0.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 10, 4])
torch.Size([32, 10, 22]) torch.Size([32, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'block




blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'

 10%|█         | 1/10 [00:00<00:08,  1.12it/s]


Epoch 0: train/iit_loss: 0.2587, train/behavior_loss: 0.1447, train/strict_loss: 0.1920, val/iit_loss: 0.0685, val/IIA: 25.94%, val/accuracy: 32.93%, val/strict_accuracy: 32.93%, 




blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])




dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_result', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.hook_mlp_in', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post'])
blocks.0.mlp



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



torch.Size([256, 11, 1]) torch.Size([256, 11, 1]) [:]
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'b



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'

100%|██████████| 7/7 [00:00<00:00,  8.79it/s]


dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_result', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.hook_mlp_in', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post'])
torch.Size([

 20%|██        | 2/10 [00:01<00:07,  1.09it/s]

blocks.0.mlp.hook_post torch.Size([2, 136, 11, 4]) torch.Size([136, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 136, 11, 4]) torch.Size([136, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 136, 11, 4]) torch.Size([136, 10, 4])
torch.Size([136, 10, 22]) torch.Size([136, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'

100%|██████████| 7/7 [00:00<00:00,  9.46it/s]


blocks.0.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 10, 4])
torch.Size([32, 10, 22]) torch.Size([32, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'block

 30%|███       | 3/10 [00:02<00:06,  1.12it/s]

blocks.0.mlp.hook_post torch.Size([2, 136, 11, 4]) torch.Size([136, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 136, 11, 4]) torch.Size([136, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 136, 11, 4]) torch.Size([136, 10, 4])
torch.Size([136, 10, 22]) torch.Size([136, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



torch.Size([256, 11, 1]) torch.Size([256, 11, 1]) [:]
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'b



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'



blocks.0.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 32, 11, 4]) torch.Size([32, 10, 4])
torch.Size([32, 10, 22]) torch.Size([32, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'block

100%|██████████| 7/7 [00:00<00:00,  8.12it/s]


blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'

 40%|████      | 4/10 [00:03<00:05,  1.06it/s]

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_result', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.hook_mlp_in', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post'])
blocks.0.mlp



blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 5, 1])
blocks.0.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
blocks.1.mlp.hook_post torch.Size([2, 256, 11, 4]) torch.Size([256, 10, 4])
torch.Size([256, 10, 22]) torch.Size([256, 10, 22])
dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores'

  0%|          | 0/7 [00:00<?, ?it/s]
 40%|████      | 4/10 [00:03<00:05,  1.04it/s]


KeyboardInterrupt: 