In [1]:
import einops
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils


In [19]:
# model parameter constants
N_LAYERS = 1
N_HEADS = 1
D_MODEL = 32
D_HEAD = 32
D_MLP = None
D_VOCAB = 64
SEED = 123
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Generating lists of a fixed length parameter

In [20]:
# data generation
import random

FIXED_LENGTH = 2

def generateLists(n):
    output = []
    for _ in range(n):
        curr = []
        for _ in range(FIXED_LENGTH):
            curr.append(random.randint(0, D_VOCAB - 1))

        # maximum = max(curr)
        # output.append((curr, maximum))
        output.append(curr)

    output = torch.tensor(output)
    return output

Model paramters. We are using one layer, one attention head (which pays to the 
tokens in contex to another), the dimensions of the model, dimension of the head,
vocab is the size of the logits

In [21]:
# model setup
cfg = HookedTransformerConfig(
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_head=D_HEAD,
    n_ctx=2,
    d_vocab=D_VOCAB,
    act_fn="relu",
    seed=SEED,
    device=DEVICE,
    attn_only=True
)

# hooked transformer used for interpretation later
model = HookedTransformer(cfg, move_to_device=True)

In [51]:
def loss_function(logits, tokens, return_per_token=True):
    # we take the last element of the logits to make the next prediction
    logits = logits[:, -1, :]
    print("tokens", tokens)
    answer = torch.max(tokens, dim=1)[0]
    log_prob = logits.log_softmax(-1)
    print("predicted", torch.argmax(logits, dim=-1))
    # shape is (batch_size, 1) which represents probabilities 
    # of the correct answer
    output_prob = log_prob.gather(-1, answer.unsqueeze(-1))
    if return_per_token:
        return -1 * output_prob.squeeze()
    return -1 * output_prob.mean()

In [46]:
def accuracy(logits, tokens, return_per_token=False):
    logits = logits[:, -1, :]
    predicted = torch.argmax(logits, dim=1)
    answer = torch.max(tokens, dim=1)[0]
    if return_per_token:
        return (predicted == answer).float()
    return (predicted == answer).float().mean().item()

In [10]:
x = torch.tensor([[1], [2], [3], [4]])

torch.Size([4, 1])
tensor([1, 2, 3, 4])


In [47]:
def train_model(model, n_epochs, batch_size, batches_per, sequence_length):
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    train_losses = []
    for epoch in range(n_epochs):
        epoch_losses = []
        for _ in range(batches_per):
            tokens = generateLists(batch_size)
            logits = model(tokens)
            # print(tokens.shape)
            # print(logits.shape)
            losses = loss_function(logits, tokens)
            losses.mean().backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses.extend(losses.detach())

        train_losses.append(np.mean(epoch_losses))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()
    test_data = generateLists(1280)
    logits = model(test_data)
    acc = accuracy(logits, test_data, return_per_token=False)

    print(f"Test accuracy: {acc}")

    return losses
                

In [52]:
losses = train_model(model, 1, 128, 10, 2)

tokens tensor([[22,  7],
        [23,  2],
        [52, 16],
        [15, 53],
        [40, 22],
        [36, 28],
        [40, 43],
        [20, 12],
        [50,  1],
        [46, 18],
        [25, 38],
        [25, 44],
        [55, 62],
        [ 4, 31],
        [33, 38],
        [55, 59],
        [44,  8],
        [28, 32],
        [51, 61],
        [39,  6],
        [61, 15],
        [ 3, 40],
        [19, 25],
        [17, 28],
        [ 6, 54],
        [47, 36],
        [44, 32],
        [31, 51],
        [62, 40],
        [18, 42],
        [ 3, 23],
        [62, 38],
        [37, 21],
        [23,  8],
        [ 4, 39],
        [26, 23],
        [58, 27],
        [20, 17],
        [56, 27],
        [33, 54],
        [21,  4],
        [47, 28],
        [ 0, 11],
        [35, 54],
        [32, 46],
        [62, 17],
        [32, 55],
        [ 5, 29],
        [51, 43],
        [ 3, 60],
        [13, 27],
        [55, 53],
        [27, 42],
        [16, 61],
        [32,  7],
   