In [1]:
import jax
# jax.config.update('jax_default_matmul_precision', 'float32')
from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp

In [2]:

input_size = 10
vocab = {*range(input_size)}
program = lib.make_sort(rasp.tokens, rasp.tokens, max_seq_len=input_size, min_key=0)

assembled_model = compiling.compile_rasp_to_model(
      program=program,
      vocab=vocab,
      max_seq_len=input_size,
      compiler_bos="bos",
      mlp_exactness=100)

In [3]:
import pprint
pprint.pprint(assembled_model.model_config)

TransformerConfig(num_heads=1,
                  num_layers=3,
                  key_size=12,
                  mlp_hidden_size=100,
                  dropout_rate=0.0,
                  activation_function=<jax._src.custom_derivatives.custom_jvp object at 0x120d6fdd0>,
                  layer_norm=False,
                  causal=False)


In [4]:
from utils import cfg_from_tracr, load_tracr_weights
from transformer_lens import HookedTransformer
cfg = cfg_from_tracr(assembled_model)
model = HookedTransformer(cfg)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
assembled_model

AssembledTransformerModel(forward=<function without_apply_rng.<locals>.apply_fn at 0x144aff560>, get_compiled_model=<function assemble_craft_model.<locals>.get_compiled_model at 0x104fe5440>, params={'pos_embed': {'embeddings': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.],
       [0., 0

In [8]:
model = load_tracr_weights(model, assembled_model, cfg)

In [9]:
model.state_dict().keys()

odict_keys(['embed.W_E', 'pos_embed.W_pos', 'blocks.0.attn.W_Q', 'blocks.0.attn.W_O', 'blocks.0.attn.b_Q', 'blocks.0.attn.b_O', 'blocks.0.attn.W_K', 'blocks.0.attn.W_V', 'blocks.0.attn.b_K', 'blocks.0.attn.b_V', 'blocks.0.attn.mask', 'blocks.0.attn.IGNORE', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_Q', 'blocks.1.attn.W_O', 'blocks.1.attn.b_Q', 'blocks.1.attn.b_O', 'blocks.1.attn.W_K', 'blocks.1.attn.W_V', 'blocks.1.attn.b_K', 'blocks.1.attn.b_V', 'blocks.1.attn.mask', 'blocks.1.attn.IGNORE', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_Q', 'blocks.2.attn.W_O', 'blocks.2.attn.b_Q', 'blocks.2.attn.b_O', 'blocks.2.attn.W_K', 'blocks.2.attn.W_V', 'blocks.2.attn.b_K', 'blocks.2.attn.b_V', 'blocks.2.attn.mask', 'blocks.2.attn.IGNORE', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out', 'unembed.W_U', 'unembed.b_U'])

In [10]:
for k, v in model.state_dict().items():
    print(k, v.shape)

embed.W_E torch.Size([12, 55])
pos_embed.W_pos torch.Size([11, 55])
blocks.0.attn.W_Q torch.Size([1, 55, 12])
blocks.0.attn.W_O torch.Size([1, 12, 55])
blocks.0.attn.b_Q torch.Size([1, 12])
blocks.0.attn.b_O torch.Size([55])
blocks.0.attn.W_K torch.Size([1, 55, 12])
blocks.0.attn.W_V torch.Size([1, 55, 12])
blocks.0.attn.b_K torch.Size([1, 12])
blocks.0.attn.b_V torch.Size([1, 12])
blocks.0.attn.mask torch.Size([11, 11])
blocks.0.attn.IGNORE torch.Size([])
blocks.0.mlp.W_in torch.Size([55, 100])
blocks.0.mlp.b_in torch.Size([100])
blocks.0.mlp.W_out torch.Size([100, 55])
blocks.0.mlp.b_out torch.Size([55])
blocks.1.attn.W_Q torch.Size([1, 55, 12])
blocks.1.attn.W_O torch.Size([1, 12, 55])
blocks.1.attn.b_Q torch.Size([1, 12])
blocks.1.attn.b_O torch.Size([55])
blocks.1.attn.W_K torch.Size([1, 55, 12])
blocks.1.attn.W_V torch.Size([1, 55, 12])
blocks.1.attn.b_K torch.Size([1, 12])
blocks.1.attn.b_V torch.Size([1, 12])
blocks.1.attn.mask torch.Size([11, 11])
blocks.1.attn.IGNORE torch.Si

In [12]:
# print the content of the model
for k, v in model.state_dict().items():
    print(k, v)

embed.W_E tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 