In [1]:
%load_ext autoreload
%autoreload 2

import torch

import circuits_benchmark.benchmark.cases.case_3 as case3
import circuits_benchmark.benchmark.cases.case_4 as case4
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader
from circuits_benchmark.utils.iit.iit_hl_model import IITHLModel
from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer
from circuits_benchmark.benchmark.vocabs import TRACR_BOS, TRACR_PAD

device = 'cuda' if torch.cuda.is_available() else 'cpu'


#load cases
cases = [case3.Case3(), case4.Case4()]
corrs = []
ll_models = []
hl_models = []
model_pairs = []
for case in cases:
    ll_model_loader = get_ll_model_loader(case, interp_bench=True)
    corr, ll_model = ll_model_loader.load_ll_model_and_correspondence(device=device)
    hl_model = case.get_hl_model()

    if isinstance(hl_model, HookedTracrTransformer):
        hl_model = IITHLModel(hl_model, eval_mode=True)

    model_pair = case.build_model_pair(ll_model=ll_model, hl_model=hl_model)

    corrs.append(corr)
    ll_models.append(ll_model)
    hl_models.append(hl_model)
    model_pairs.append(model_pair)


Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_



Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_

In [2]:
print(TRACR_BOS, TRACR_PAD)
print(sorted(list(case.get_vocab())))
for hl_model in hl_models:
    print(dir(hl_model.tracr_input_encoder))
    print(hl_model.tracr_input_encoder.encoding_map)

BOS PAD
['(', ')', 'a', 'b', 'c']
['__abstractmethods__', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_bos_token', '_max_seq_len', '_pad_token', 'basis', 'bos_encoding', 'bos_token', 'decode', 'encode', 'encoding_map', 'enforce_bos', 'pad_encoding', 'pad_token', 'vocab_size']
{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
['__abstractmethods__', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce

In [3]:
#find overall min and max sequence length (we'll have to pad some)
min_seq_len = 100000
max_seq_len = 0

for case in cases:
    if case.get_min_seq_len() < min_seq_len:
        min_seq_len = case.get_min_seq_len()
    if case.get_max_seq_len() > max_seq_len:
        max_seq_len = case.get_max_seq_len()
print(min_seq_len, max_seq_len)

4 10


In [4]:
# find max vocab size (some tasks will just use a subset of vocab size)
vocab_size = 0
for case in cases:
    if len(case.get_vocab()) > vocab_size:
        vocab_size = len(case.get_vocab())
print(vocab_size)


5


In [5]:
max_case_samples = 1_000
case_samples = []
for case in cases:
    num_samples = min(max_case_samples, case.get_total_data_len())
    case_samples.append(num_samples)
print(case_samples)

[320, 1000]


In [39]:
# generate clean and corrupted datasets
clean_datasets = []
corrupted_datasets = []
for task_id, case, samples, hl_model in zip(range(len(cases)), cases, case_samples, hl_models):
    clean_dataset = case.get_clean_data(max_samples=samples)
    corrupted_dataset = case.get_corrupted_data(max_samples=samples)
    encoder = hl_model.tracr_input_encoder    
    print(encoder.encoding_map)
    def encode(tok):
        return encoder.encoding_map[tok]

    for dataset in (clean_dataset, corrupted_dataset):
        #Input
        #put task_id after BOS token and pads after EOS.
        inputs = dataset.inputs
        str_tokens = [encoder.decode(inputs[i].tolist()) for i in range(inputs.shape[0])]
        str_task_id = encoder.decode([task_id])
        pads = [TRACR_PAD] * (max_seq_len - case.get_max_seq_len())
        str_tokens = [str_task_id + [TRACR_BOS] + tokens[1:] + pads for tokens in str_tokens]
        inputs = torch.tensor([list(map(encode, tokens)) for tokens in str_tokens])

        #Target
        #add 0 to beginning of seq and a bunch of 0s to end.
        target = dataset.targets
        label = torch.zeros((target.shape[0], 1, target.shape[2]), dtype=target.dtype)
        pads = torch.zeros((target.shape[0], max_seq_len - case.get_max_seq_len(), target.shape[2]), dtype=target.dtype)
        target = torch.cat((label, target, pads), dim=1)
        dataset.inputs = inputs
        dataset.targets = target
    print(clean_dataset.inputs.shape, corrupted_dataset.inputs.shape)
    print(clean_dataset.targets.shape, corrupted_dataset.targets.shape)

    # if case.get_max_seq_len() < max_seq_len:
    #     # pad sequences
        




{'BOS': 0, 'PAD': 1, 'a': 2, 'b': 3, 'c': 4, 'x': 5}
torch.Size([320, 11]) torch.Size([320, 11])
torch.Size([320, 11, 1]) torch.Size([320, 11, 1])




{'(': 0, ')': 1, 'BOS': 2, 'PAD': 3, 'a': 4, 'b': 5, 'c': 6}
torch.Size([1000, 11]) torch.Size([1000, 11])
torch.Size([1000, 11, 1]) torch.Size([1000, 11, 1])
