In [1]:
cd ~/git/tracr

/Users/ellenar/git/tracr


In [2]:
#@title Imports
import jax
import numpy as np

# The default of float16 can lead to discrepancies between outputs of
# the compiled model and the RASP program.
jax.config.update('jax_default_matmul_precision', 'float32')

from tracr.compiler import compiling
from tracr.rasp import rasp
import torch
import torch.nn as nn

In [3]:
from model.transformer_model import TransformerModel

In [4]:
def make_length():
  all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
  return rasp.SelectorWidth(all_true_selector)

In [5]:
def reverse():
    length = make_length()  # `length` is not a primitive in our implementation.
    opp_index = length - rasp.indices - 1
    flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
    reverse = rasp.Aggregate(flip, rasp.tokens)
    return reverse
reverse= reverse()



In [6]:
bos = "bos"
compiled_model = compiling.compile_rasp_to_model(
    reverse,
    vocab={0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
    max_seq_len=9,
    compiler_bos=bos,
)

In [7]:
def extract_config(model, act_fn = "relu"):
    model_config = {'activation_function': act_fn}
    for key,val in model.model_config.__dict__.items():
        if key == 'activation_function':
            continue
        model_config[key]=val
    
    model_config["max_seq_len"] = model.params["pos_embed"]['embeddings'].shape[0] 
    model_config["vocab_size"] = model.params["token_embed"]['embeddings'].shape[0] # Vocab size plus 2 for BOS and PAD
    model_config["vocab_size_out"] = model_config["vocab_size"] - 2
    model_config["hidden_size"] = model.params["token_embed"]['embeddings'].shape[1]
    return model_config 
model_config = extract_config(compiled_model)

In [8]:
sd = {}
for name in compiled_model.params:
    if 'transformer' in name:
        _, layer, module, param = name.split('/')
        layer_num = layer.split('_')[1]
        sd[f"layers.{layer_num}.{module}.{param}.weight"] = torch.transpose(torch.tensor(np.array(compiled_model.params[name]['w'])), 0, 1)
        sd[f"layers.{layer_num}.{module}.{param}.bias"] = torch.tensor(np.array(compiled_model.params[name]['b']))
    else:
        sd[f"{name}.embeddings"] = torch.tensor(np.array(compiled_model.params[name]['embeddings']))

In [9]:
trans_model = TransformerModel(model_config).eval()
trans_model.load_state_dict(sd)

<All keys matched successfully>

In [11]:
x = compiled_model.input_encoder.encode(["bos", 5,4,3])
out = trans_model(x)

In [12]:
out[0,0]

tensor(0., grad_fn=<SelectBackward0>)

In [13]:
max_output_indices = torch.argmax(out, dim=1)
max_output_indices

tensor([0, 0, 0, 0])