In [72]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
import numpy as np
from dataclasses import dataclass
import matplotlib.pyplot as plt

import os
import datetime
import pickle

from jaxtyping import Float
from functools import partial

import wandb

In [73]:
import transformer_lens

In [74]:
# Fall back to CPU so the notebook can still run on machines without a GPU.
# NOTE(review): if the checkpoint loaded below was pickled from CUDA tensors,
# unpickling it may still require CUDA -- confirm on a CPU-only machine.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [75]:
# Model architecture. These hyperparameters presumably have to agree with the
# checkpoint loaded in the next cell, so change them only together with it.
N_CTX = 64    # context length, in tokens
D_VOCAB = 11  # ten digit tokens (0-9) plus one comma token
cfg = transformer_lens.HookedTransformerConfig(
    # sequence length and vocabulary
    n_ctx=N_CTX,
    d_vocab=D_VOCAB,
    # one transformer block: a single attention head followed by a ReLU MLP
    n_layers=1,
    n_heads=1,
    d_model=128,
    d_head=64,
    d_mlp=128,
    attn_only=False,
    act_fn="relu",
    # runtime settings
    seed=42,
    device=DEVICE,
)
model = transformer_lens.HookedTransformer(cfg, move_to_device=True)

In [76]:
def load_model_state(
    model: transformer_lens.HookedTransformer,
    filename: str,
    models_dir: str = "models",
):
    """Load a pickled state dict from disk into ``model`` (in place).

    Args:
        model: Model whose weights are replaced via ``model.load_state_dict``.
        filename: Checkpoint file. A path that already starts with
            ``"<models_dir>/"`` or is absolute is used as given; anything
            else is resolved relative to ``models_dir``.
        models_dir: Directory holding the checkpoints. Defaults to
            ``"models"``, the previously hard-coded location.

    Raises:
        FileNotFoundError: If the resolved checkpoint file does not exist.
    """
    if filename.startswith(f"{models_dir}/"):
        path = filename
    else:
        # os.path.join returns `filename` unchanged when it is an absolute path.
        path = os.path.join(models_dir, filename)
    # An explicit exception (unlike `assert`) is not stripped under `python -O`.
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"Checkpoint {path!r} not found; put model state dicts in `{models_dir}/`"
        )
    # SECURITY: pickle.load can execute arbitrary code embedded in the file --
    # only load checkpoints that come from a trusted source.
    with open(path, "rb") as f:
        state_dict = pickle.load(f)
    model.load_state_dict(state_dict)

load_model_state(model, "addition_model_state_dict_2024-06-19T16-18.pkl")

In [77]:
def tokenize(c: str) -> int:
    """Map a single character to its token id.

    ASCII digits '0'-'9' map to ids 0-9; every other character maps to 10,
    the comma token.
    """
    # Compare against the ASCII range instead of using str.isdigit():
    # isdigit() is also True for characters such as '²' or '٣', for which
    # ord(c) - ord("0") lands far outside the 11-token vocabulary.
    if "0" <= c <= "9":
        return ord(c) - ord("0")
    return 10  # comma token (also the fallback for any other character)

def str_to_tokens(seq_str: str, device=None) -> torch.Tensor:
    """Tokenize a string character by character into a 1-D integer tensor.

    Args:
        seq_str: Input text, e.g. "1234,0373,".
        device: Target device; when None, the notebook-wide ``DEVICE`` is used.

    Returns:
        Tensor of shape (len(seq_str),) with dtype torch.long.
    """
    # Explicit dtype keeps the result an integer tensor even for "":
    # torch.tensor([]) would otherwise default to float32.
    return torch.tensor(
        [tokenize(c) for c in seq_str],
        dtype=torch.long,
        device=DEVICE if device is None else device,
    )

In [78]:
# Prompt: two comma-terminated 4-digit numbers; the model should continue with
# the answer (presumably their sum, given the checkpoint name -- confirm).
test_example = "1234,0373,"
example_toks = str_to_tokens(test_example)
# The model accepts at most n_ctx positions: keep only the last n_ctx tokens of
# a long prompt, and left-pad a short one with comma tokens.
example_toks = example_toks[-cfg.n_ctx:]
if len(example_toks) < cfg.n_ctx:
    example_toks = F.pad(
        example_toks, (cfg.n_ctx - len(example_toks), 0), value=tokenize(",")
    )
example_toks

tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         1,  2,  3,  4, 10,  0,  3,  7,  3, 10], device='cuda:0')

In [79]:
# Greedy autoregressive decoding of 6 tokens. The input is a fixed-length
# window of n_ctx tokens: after each prediction the window drops its first
# token and the prediction is appended. Decoding works on the separate name
# `window` (torch.cat never modifies its inputs), so `example_toks` keeps the
# original prompt and re-running this cell reproduces the same output.
window = example_toks
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in range(6):
        logits = model(window)[:, -1, :]  # logits at the last position
        pred = logits.argmax(dim=-1)      # printed as a 1-element tensor
        print(pred)
        window = torch.cat([window[1:], pred], dim=0)

tensor([1], device='cuda:0')
tensor([6], device='cuda:0')
tensor([0], device='cuda:0')
tensor([7], device='cuda:0')
tensor([8], device='cuda:0')
tensor([2], device='cuda:0')
