In [27]:
import jax
jax.config.update('jax_default_matmul_precision', 'float32')
from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp
from transformer_lens import HookedTransformerConfig, HookedTransformer

from utils import cfg_from_tracr, load_tracr_weights





input_size = 7 # Length of sequences
vocab_size = 5  # Vocabulary size


vocab = {*range(vocab_size)}
program = lib.make_sort(rasp.tokens, rasp.tokens, max_seq_len=input_size, min_key=0)

tracr_model = compiling.compile_rasp_to_model(
    program=program,
    vocab=vocab,
    max_seq_len=input_size,
    compiler_bos="bos",
    mlp_exactness=100)

cfg = cfg_from_tracr(tracr_model)
model = HookedTransformer(cfg)
# model = load_tracr_weights(model, tracr_model, cfg)

In [32]:
import torch as tc
import itertools
def get_all_sequences(max_seq_len, vocab_size):
    # Generate all possible sequences of length n with vocabulary size m
    sequences = itertools.product(range(vocab_size), repeat=max_seq_len)
    sequences = [tc.tensor(seq) for seq in sequences]
    return sequences
sequences = get_all_sequences(input_size, vocab_size)
print(f"Generated {len(sequences)} sequences of length {input_size} with vocabulary size {vocab_size}")

train_test_split = 0.7
train_size = int(len(sequences) * train_test_split)
train_sequences = sequences[:train_size]
test_sequences = sequences[train_size:]
print(f"Train size: {len(train_sequences)}")
print(f"Test size: {len(test_sequences)}")

Generated 78125 sequences of length 7 with vocabulary size 5
Train size: 54687
Test size: 23438


In [51]:
import torch as tc

def loss_function(outputs, targets):
    return tc.nn.functional.cross_entropy(outputs, targets, reduce=True, reduction='mean')

def accuracy_function(outputs, targets):
    return (outputs.argmax(1) == targets).float().mean()

In [40]:
train_loader = tc.utils.data.DataLoader(train_sequences, batch_size=16, shuffle=True)
inputs = next(iter(train_loader))
targets = tc.sort(inputs, dim=1).values
print(inputs.shape)
print(targets.shape)

torch.Size([16, 7])
torch.Size([16, 7])


In [52]:
model.to('cpu')
inputs = next(iter(train_loader))
outputs, cache = model.run_with_cache(inputs)
print(outputs.shape)
# output shape is (batch_size, vocab_size, input_size) but we need (batch_size, input_size, vocab_size)
outputs = outputs.permute(0, 2, 1)
accuracy = accuracy_function(outputs, targets) 
loss = loss_function(outputs, targets)
print(f"Initial accuracy: {accuracy.item()}")
print(f"Initial loss: {loss.item()}")

Moving model to device:  cpu
torch.Size([16, 7, 5])
Initial accuracy: 0.1607142835855484
Initial loss: 1.637537956237793




In [53]:
from tqdm import tqdm

train_loader = tc.utils.data.DataLoader(train_sequences, batch_size=256, shuffle=True)

model.to('cpu')
model.train()
optimizer = tc.optim.Adam(model.parameters(), lr=0.001)

for i, inputs in enumerate(tqdm(train_loader)):
    targets = tc.sort(inputs, dim=1).values
    optimizer.zero_grad()
    outputs, cache = model.run_with_cache(inputs)
    outputs = outputs.permute(0, 2, 1)
    loss = loss_function(outputs, targets)
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        accuracy = accuracy_function(outputs, targets) 
        print(f'batch {i}, loss: {loss.item()}', f'accuracy: {accuracy.item()}')

# TODO add testset evaluation

Moving model to device:  cpu


  1%|▏         | 3/214 [00:00<00:07, 26.92it/s]

batch 0, loss: 1.6612650156021118 accuracy: 0.1467633992433548


 50%|████▉     | 106/214 [00:03<00:03, 29.15it/s]

batch 100, loss: 0.1305752545595169 accuracy: 0.9631696343421936


 96%|█████████▌| 205/214 [00:07<00:00, 29.52it/s]

batch 200, loss: 0.0043729376047849655 accuracy: 1.0


100%|██████████| 214/214 [00:07<00:00, 28.21it/s]


78125


In [55]:
# test set
test_loader = tc.utils.data.DataLoader(test_sequences, batch_size=256, shuffle=True)
model.eval()
losses = []
accuracies = []
for i, inputs in enumerate(tqdm(test_loader)):
    targets = tc.sort(inputs, dim=1).values
    outputs, cache = model.run_with_cache(inputs)
    outputs = outputs.permute(0, 2, 1)
    loss = loss_function(outputs, targets)
    accuracy = accuracy_function(outputs, targets)
    losses.append(loss.item())
    accuracies.append(accuracy.item())
loss = tc.tensor(losses).mean()
accuracy = tc.tensor(accuracies).mean()
print(f"test loss: {loss.item()}")
print(f"test accuracy: {accuracy.item()}")


100%|██████████| 92/92 [00:01<00:00, 53.17it/s]

Test loss: 0.04582557454705238
Test accuracy: 0.98520827293396





-> model performs well on train and test

In [68]:
# save model with timestamp
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
tc.save(model.state_dict(), f'saved_models/{timestamp}.pt')
# save config
import pickle
with open(f'saved_models/{timestamp}.cfg', 'wb') as f:
    pickle.dump(cfg, f)

In [69]:
# load config
import pickle
with open(f'saved_models/{timestamp}.cfg', 'rb') as f:
    cfg = pickle.load(f)
# load model
model = HookedTransformer(cfg)
model.load_state_dict(tc.load(f'saved_models/{timestamp}.pt'))

<All keys matched successfully>