In [1]:
%load_ext autoreload
%autoreload 2

import torch

import circuits_benchmark.benchmark.cases.case_3 as case3
import circuits_benchmark.benchmark.cases.case_4 as case4
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.utils.iit.iit_hl_model import IITHLModel
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
from circuits_benchmark.benchmark.vocabs import TRACR_BOS, TRACR_PAD

device = 'cuda' if torch.cuda.is_available() else 'cpu'


#load cases
cases = [case3.Case3(), case4.Case4()]
corrs = []
ll_models = []
hl_models = []
model_pairs = []
for case in cases:
    ll_model_loader = get_ll_model_loader(case, interp_bench=True)
    corr, ll_model = ll_model_loader.load_ll_model_and_correspondence(device=device)
    hl_model = case.get_hl_model()

    if isinstance(hl_model, HookedTracrTransformer):
        hl_model = IITHLModel(hl_model, eval_mode=True)

    model_pair = case.build_model_pair(ll_model=ll_model, hl_model=hl_model)

    corrs.append(corr)
    ll_models.append(ll_model)
    hl_models.append(hl_model)
    model_pairs.append(model_pair)


Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_



Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_

In [2]:
print(TRACR_BOS, TRACR_PAD)
print(sorted(list(case.get_vocab())))
for hl_model in hl_models:
    print(dir(hl_model.tracr_input_encoder))
    print(hl_model.tracr_input_encoder.encoding_map)

BOS PAD
['(', ')', 'a', 'b', 'c']
['__abstractmethods__', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_bos_token', '_max_seq_len', '_pad_token', 'basis', 'bos_encoding', 'bos_token', 'decode', 'encode', 'encoding_map', 'enforce_bos', 'pad_encoding', 'pad_token', 'vocab_size']
{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
['__abstractmethods__', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce

In [3]:
#find overall min and max sequence length (we'll have to pad some)
min_seq_len = 100000
max_seq_len = 0

for case in cases:
    if case.get_min_seq_len() < min_seq_len:
        min_seq_len = case.get_min_seq_len()
    if case.get_max_seq_len() > max_seq_len:
        max_seq_len = case.get_max_seq_len()
print(min_seq_len, max_seq_len)

4 10


In [4]:
# find max vocab size (some tasks will just use a subset of vocab size)
vocab_size = 0
for case in cases:
    if len(case.get_vocab()) > vocab_size:
        vocab_size = len(case.get_vocab())
print(vocab_size)


5


In [5]:
max_case_samples = 1_000
case_samples = []
for case in cases:
    num_samples = min(max_case_samples, case.get_total_data_len())
    case_samples.append(num_samples)
print(case_samples)

[320, 1000]


In [41]:
# generate clean and corrupted datasets
clean_datasets = []
corrupted_datasets = []
masks = []
for task_id, case, samples, hl_model in zip(range(len(cases)), cases, case_samples, hl_models):
    clean_dataset = case.get_clean_data(max_samples=samples)
    corrupted_dataset = case.get_corrupted_data(max_samples=samples)
    encoder = hl_model.tracr_input_encoder    
    print(encoder.encoding_map)
    def encode(tok):
        return encoder.encoding_map[tok]

    for dataset in (clean_dataset, corrupted_dataset):
        #Input
        #put task_id after BOS token and pads after EOS.
        inputs = dataset.inputs
        str_tokens = [encoder.decode(inputs[i].tolist()) for i in range(inputs.shape[0])]
        str_task_id = encoder.decode([task_id])
        pads = [TRACR_PAD] * (max_seq_len - case.get_max_seq_len())
        str_tokens = [str_task_id + [TRACR_BOS] + tokens[1:] + pads for tokens in str_tokens]
        inputs = torch.tensor([list(map(encode, tokens)) for tokens in str_tokens])

        #Target
        #add 0 to beginning of seq and a bunch of 0s to end.
        target = dataset.targets
        label = torch.zeros((target.shape[0], 1, target.shape[2]), dtype=target.dtype)
        pads = torch.zeros((target.shape[0], max_seq_len - case.get_max_seq_len(), target.shape[2]), dtype=target.dtype)
        target = torch.cat((label, target, pads), dim=1)
        dataset.inputs = inputs
        dataset.targets = target
    print(clean_dataset.inputs.shape, corrupted_dataset.inputs.shape)
    print(clean_dataset.targets.shape, corrupted_dataset.targets.shape)
    clean_datasets.append(clean_dataset)
    corrupted_datasets.append(corrupted_dataset)

    # if case.get_max_seq_len() < max_seq_len:
    #     # pad sequences
        




{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
torch.Size([320, 11]) torch.Size([320, 11])
torch.Size([320, 11, 1]) torch.Size([320, 11, 1])




{'(': 0, ')': 1, 'BOS': 2, 'PAD': 3, 'a': 4, 'b': 5, 'c': 6}
torch.Size([1000, 11]) torch.Size([1000, 11])
torch.Size([1000, 11, 1]) torch.Size([1000, 11, 1])


In [46]:
#make all datasets ~the same length by duplicating shorter datasets

max_length = max([clean_dataset.inputs.shape[0] for clean_dataset in clean_datasets])

for dset_list in [clean_datasets, corrupted_datasets]:
    for dataset in dset_list:
        if dataset.inputs.shape[0] < max_length:
            num_dups = max_length // dataset.inputs.shape[0]
            dataset.inputs = dataset.inputs.repeat(num_dups, 1)
            dataset.targets = dataset.targets.repeat(num_dups, 1, 1)
print(clean_datasets[0].inputs.shape, clean_datasets[1].inputs.shape)
print(corrupted_datasets[0].inputs.shape, corrupted_datasets[1].inputs.shape)

torch.Size([960, 11]) torch.Size([1000, 11])
torch.Size([960, 11]) torch.Size([1000, 11])


In [48]:
from circuits_benchmark.benchmark.tracr_encoded_dataset import TracrEncodedDataset
#smush datasets together into one big dataset, then shuffle
datasets = []
for dset_list in [clean_datasets, corrupted_datasets]:
    inputs = torch.cat([dset.inputs for dset in dset_list], dim=0)
    targets = torch.cat([dset.targets for dset in dset_list], dim=0)

    # shuffle dataset contents, keeping inputs and targets in sync
    indices = torch.randperm(inputs.shape[0])
    inputs = inputs[indices]
    targets = targets[indices]

    datasets.append(TracrEncodedDataset(inputs, targets))


In [49]:
loaders = [d.make_loader(batch_size = 256, shuffle=True, device=device) for d in datasets]

In [51]:
print(loaders)

[<torch.utils.data.dataloader.DataLoader object at 0x333f90d10>, <torch.utils.data.dataloader.DataLoader object at 0x333ed95d0>]


In [81]:
for hl_model in hl_models:
    print(hl_model.tracr_input_encoder.encoding_map)
    print(hl_model.cfg.d_vocab)

{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
6
{'(': 0, ')': 1, 'BOS': 2, 'PAD': 3, 'a': 4, 'b': 5, 'c': 6}
7


In [128]:
print(corrs[0].suffixes)
for corr in corrs[1].items():
    print(corr[0])
    print(corr[1])
    print()
print(corrs[1])

{'attn': 'attn.hook_result', 'mlp': 'mlp.hook_post'}
blocks.0.mlp.hook_post
{LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post
{LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}

blocks.1.attn.hook_result
{LLNode(name='blocks.1.attn.hook_result', index=[:, :, 0, :], subspace=None)}

blocks.1.attn.hook_result
{LLNode(name='blocks.1.attn.hook_result', index=[:, :, 1, :], subspace=None)}

blocks.1.mlp.hook_post
{LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}

{TracrHLNode(name: blocks.0.mlp.hook_post,
 label: bools_close_7,
 classes: 0,
 index: [:]
): {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}, TracrHLNode(name: blocks.0.mlp.hook_post,
 label: bools_open_5,
 classes: 0,
 index: [:]
): {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}, TracrHLNode(name: blocks.1.attn.hook_result,
 label: opens_2,
 classes: 0,
 index: [:, :, 0, :]
): {LLNode(name='blocks.1.attn.hook_result', index=

In [115]:
from collections import defaultdict


    

defaultdict(<class 'list'>, {'blocks.0.attn.hook_result': [None, None], 'blocks.0.mlp.hook_post': [TracrHLNode(name: blocks.0.mlp.hook_post,
 label: is_x_3,
 classes: 0,
 index: [:]
), TracrHLNode(name: blocks.0.mlp.hook_post,
 label: bools_open_5,
 classes: 0,
 index: [:]
)], 'blocks.1.attn.hook_result': [TracrHLNode(name: blocks.1.attn.hook_result,
 label: frac_prevs_1,
 classes: 0,
 index: [:, :, 0, :]
), TracrHLNode(name: blocks.1.attn.hook_result,
 label: closes_3,
 classes: 0,
 index: [:, :, 1, :]
)], 'blocks.1.mlp.hook_post': [None, TracrHLNode(name: blocks.1.mlp.hook_post,
 label: pair_balance_1,
 classes: 0,
 index: [:]
)]})


In [80]:
print(ll_models[1].cfg.d_model)

20


In [203]:
from functools import partial
from typing import Optional
from transformer_lens.hook_points import HookedRootModule, HookPoint
from torch import nn

from iit.utils.index import TorchIndex

class MultiHighLevelModel(HookedRootModule):

    def __init__(self, hl_models, corrs, cases):
        super().__init__()
        self.hl_models = hl_models
        self.corrs = corrs
        self.cases = cases

        self.n_layers = max([mod.cfg.n_layers for mod in self.hl_models])
        self.n_heads = max([mod.cfg.n_heads for mod in self.hl_models])
        self.n_ctx = max([mod.cfg.n_ctx for mod in self.hl_models]) + 1
        self.tracr_d_heads = [mod.W_Q.shape[-1] for mod in self.hl_models]
        self.d_head = max(self.tracr_d_heads)
        self.tracr_d_mlps = [mod.W_in[0].shape[1] for mod in self.hl_models]
        self.d_mlp = max(self.tracr_d_mlps)
        self.d_models = [mod.cfg.d_model for mod in self.hl_models]
        self.attn_shapes = []

        #make hooks for each necessary attn head and each mlp
        self.input_hook = HookPoint()
        self.attn_hooks = nn.ModuleList([nn.ModuleList([HookPoint() for _ in range(self.n_heads)]) for _ in range(self.n_layers)])
        self.mlp_hooks = nn.ModuleList([HookPoint() for _ in range(self.n_layers)])


        corr_dict = defaultdict(list)

        for model_number, corr in enumerate(corrs):
            if corr.suffixes['attn'] == 'attn.hook_result':
                self.attn_shapes.append(self.d_models[model_number])
            else:
                self.attn_shapes.append(self.tracr_d_heads[model_number])
            corr_keys = defaultdict(list)
            for k in corr.keys():
                corr_keys[k.name].append(k)
            for i in range(self.n_layers):
                for k, suff in corr.suffixes.items():

                    name = f'blocks.{i}.{suff}'
                    if k == 'attn':
                        if model_number == 0:
                            for _ in range(self.n_heads):
                                corr_dict[name].append([])
                        if name in corr_keys.keys():
                            hl_nodes = corr_keys[name]
                            for hl_node in hl_nodes:
                                ll_nodes = corr[hl_node]
                                for head in range(self.n_heads):
                                    used = False
                                    for node in ll_nodes:
                                        if node.index.as_index[2] == head:
                                            corr_dict[name][head].append(corr_keys[name])
                                            used = True
                                            break
                                    if not used:
                                        corr_dict[name][head].append(None)
                        else:
                            for head in range(self.n_heads):
                                corr_dict[name][head].append(None)
                    elif k == 'mlp':
                        if name in corr_keys.keys():
                            corr_dict[name].append(corr_keys[name])
                        else:
                            corr_dict[name].append(None)
                    else:
                        raise ValueError(f"Unknown suffix in correspondence: {k}")
        self.attn_shape = max(self.attn_shapes)
        for k, v in corr_dict.items():
            print(k)
            print(v)
            print()
        self.corr_dict = corr_dict

        self.setup()
    
    def forward(self, x):
        # get sorting indices by task id
        tokens = self.input_hook(x)

        # Step 1 -- get all the activations.
        caches = []
        for i, hl_model in enumerate(self.hl_models):
            if isinstance(hl_model, IITHLModel):
                hl_model = hl_model.hl_model
            #clip token vocab down to task vocab size; replace out-of-task tokens with PAD.
            encoder = hl_model.tracr_input_encoder
            task_tokens = torch.clone(tokens)[:, 1:hl_model.cfg.n_ctx+1]
            task_tokens[task_tokens >= hl_model.cfg.d_vocab] = encoder.encoding_map[TRACR_PAD]
            _, cache = hl_model.run_with_cache(task_tokens)
            caches.append(cache)
        
        for k, i in caches[0].items():
            print(k, i.shape)
        
        #ActivationCache isn't writable so we need to copy it.
        caches = [{k: i for k, i in cache.items()} for cache in caches]
        # Step 2 -- construct all the hooks for THIS model.
        for layer in range(self.n_layers):
            # create MLP hooks
            mlp_storage = torch.zeros((len(self.hl_models), tokens.shape[0], tokens.shape[1], self.d_mlp))
            attn_storage = torch.zeros((len(self.hl_models), tokens.shape[0], tokens.shape[1], self.n_heads, self.attn_shape))
            for i, hl_model in enumerate(self.hl_models):

                #MLP
                suffix = self.corrs[i].suffixes['mlp']
                hook_name = f'blocks.{layer}.{suffix}'
                #unpack from cache
                if self.corr_dict[hook_name][i] is not None:
                    print(hook_name, mlp_storage.shape, caches[i][hook_name].shape)
                    mlp_storage[i,:,1:hl_model.cfg.n_ctx+1,:self.tracr_d_mlps[i]] = caches[i][hook_name]
                

                #Attn
                suffix = self.corrs[i].suffixes['attn']
                hook_name = f'blocks.{layer}.{suffix}'
                attn_shape = self.attn_shapes[i]
                #unpack from cache
                for head in range(self.n_heads):
                    if self.corr_dict[hook_name][head][i] is not None:
                        print(attn_storage[i,:,1:hl_model.cfg.n_ctx+1,head,:attn_shape].shape, caches[i][hook_name][:,:,head].shape)
                        attn_storage[i,:,1:hl_model.cfg.n_ctx+1,head,:attn_shape] = caches[i][hook_name][:,:,head]

            #modify with hook
            mlp_storage = self.mlp_hooks[layer](mlp_storage)
            for head in range(self.n_heads):
                attn_storage[:,:,:,head] = self.attn_hooks[layer][head](attn_storage[:,:,:,head])

                
            for i, hl_model in enumerate(self.hl_models):
                #pack back into cache
                # MLP
                suffix = self.corrs[i].suffixes['mlp']
                hook_name = f'blocks.{layer}.{suffix}'
                caches[i][hook_name] = mlp_storage[i,:,1:hl_model.cfg.n_ctx+1,:self.tracr_d_mlps[i]]
                # Attn
                suffix = self.corrs[i].suffixes['attn']
                hook_name = f'blocks.{layer}.{suffix}'
                attn_shape = self.attn_shapes[i]
                caches[i][hook_name] = attn_storage[i,:,1:hl_model.cfg.n_ctx+1,:hl_model.cfg.n_heads,:attn_shape]


        # Step 3 -- run with a bunch of hooks on the tracr models using the caches.
        task_ids = tokens[:, 0]
        outputs = torch.zeros((tokens.shape[0], tokens.shape[1], 1))
        for i, hl_model in enumerate(self.hl_models):
            if isinstance(hl_model, IITHLModel):
                hl_model = hl_model.hl_model
            #clip token vocab down to task vocab size; replace out-of-task tokens with PAD.
            encoder = hl_model.tracr_input_encoder
            task_tokens = torch.clone(tokens)[:, 1:hl_model.cfg.n_ctx+1]
            task_tokens[task_tokens >= hl_model.cfg.d_vocab] = encoder.encoding_map[TRACR_PAD]

            #TODO: add hooks here
            # basically, for each operation in the LL model's correspondence, we want to 

            def mlp_replacement_hook(x, hook):
                x[:] = caches[i][hook.name]
            
            def attn_replacement_hook(x, hook, index: Optional[TorchIndex] = None):
                x[index.as_index] = caches[i][hook.name][index.as_index]


            hooks = []
            #go through self.corr_dict and link up each hook with corresponding hook(s) in HL tracr models.
            for layer in range(self.n_layers):

                #MLP
                suffix = self.corrs[i].suffixes['mlp']
                hook_name = f'blocks.{layer}.{suffix}'
                #unpack from cache
                if self.corr_dict[hook_name][i] is not None:
                    hooks.append((hook_name, mlp_replacement_hook))
                
                # Attn
                suffix = self.corrs[i].suffixes['attn']
                hook_name = f'blocks.{layer}.{suffix}'
                for head in range(hl_model.cfg.n_heads):
                    if self.corr_dict[hook_name][head][i] is not None:
                        hooks.append((hook_name, partial(attn_replacement_hook, index=TorchIndex([None,None,head,None]))))

            # print(hooks)
            model_output = hl_model.run_with_hooks(task_tokens, fwd_hooks=hooks)
            outputs[task_ids == i, 1:hl_model.cfg.n_ctx+1,:] = model_output[task_ids == i]

        return outputs

model = MultiHighLevelModel(hl_models, corrs, cases)

for b1, b2 in zip(*tuple(loaders)):
    input, target = b1
    output = model(input)
    print(input[:1])
    print(output[:1])
    print(target[:1])
    break

blocks.0.attn.hook_result
[[None, None], [None, None]]

blocks.0.mlp.hook_post
[[TracrHLNode(name: blocks.0.mlp.hook_post,
 label: is_x_3,
 classes: 0,
 index: [:]
)], [TracrHLNode(name: blocks.0.mlp.hook_post,
 label: bools_close_7,
 classes: 0,
 index: [:]
), TracrHLNode(name: blocks.0.mlp.hook_post,
 label: bools_open_5,
 classes: 0,
 index: [:]
)]]

blocks.1.attn.hook_result
[[None, [TracrHLNode(name: blocks.1.attn.hook_result,
 label: opens_2,
 classes: 0,
 index: [:, :, 0, :]
), TracrHLNode(name: blocks.1.attn.hook_result,
 label: closes_3,
 classes: 0,
 index: [:, :, 1, :]
)], None], [None, None, [TracrHLNode(name: blocks.1.attn.hook_result,
 label: opens_2,
 classes: 0,
 index: [:, :, 0, :]
), TracrHLNode(name: blocks.1.attn.hook_result,
 label: closes_3,
 classes: 0,
 index: [:, :, 1, :]
)]]]

blocks.1.mlp.hook_post
[None, [TracrHLNode(name: blocks.1.mlp.hook_post,
 label: pair_balance_1,
 classes: 0,
 index: [:]
)]]

hook_embed torch.Size([256, 5, 13])
hook_pos_embed torch.Si