In [529]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils


In [530]:
# Model hyperparameter constants. Each one is passed to the matching
# HookedTransformerConfig field in the model-setup cell, except D_MLP, which no
# visible cell reads.
N_LAYERS = 1
N_HEADS = 1
D_MODEL = 32
D_HEAD = 32
D_MLP = None  # NOTE(review): not referenced anywhere in the visible notebook
D_VOCAB = 64  # also bounds the token ids sampled by the data generators
SEED = 123  # forwarded as the config `seed`; NOTE(review): Python's `random`, used for data generation, is never seeded from it
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when PyTorch reports one

Generate lists of random tokens, each with a fixed length (`FIXED_LENGTH`).

In [531]:
# data generation
import random

# Number of tokens in each list built by generateLists; also passed as n_ctx
# when the model config is created.
FIXED_LENGTH = 10

def generateLists(n, training=True):
    """Sample `n` random token lists of length `FIXED_LENGTH` as one tensor.

    With a truthy ``training`` every token is drawn uniformly from
    ``[0, D_VOCAB - 1]``; otherwise only from the upper half of the vocabulary,
    ``[D_VOCAB // 2, D_VOCAB - 1]``. Both bounds are inclusive.
    """
    lowest_token = 0 if training else D_VOCAB // 2
    highest_token = D_VOCAB - 1
    rows = [
        [random.randint(lowest_token, highest_token) for _ in range(FIXED_LENGTH)]
        for _ in range(n)
    ]
    return torch.tensor(rows)

Data-generation function that enumerates every token pair, shuffles the pairs, and splits them into a training set and a held-out test set.

In [532]:
# NOTE(review): repeats the identical assignment from the data-generation cell;
# the value is later passed as n_ctx when the model config is created.
FIXED_LENGTH = 10

def cross_generation(n, split=0.7):
    """Enumerate every ordered token pair and split it randomly into train/test.

    All ``D_VOCAB * D_VOCAB`` pairs ``[i, j]`` are built, shuffled in place with
    ``random.shuffle``, and cut at ``int(total * split)``. The ``n`` argument is
    accepted but never used.

    Returns a ``(training, testing)`` tuple of tensors.
    """
    pairs = [[first, second] for first in range(D_VOCAB) for second in range(D_VOCAB)]
    random.shuffle(pairs)

    cut = int(len(pairs) * split)
    training_pairs = torch.tensor(pairs[:cut])
    testing_pairs = torch.tensor(pairs[cut:])
    return training_pairs, testing_pairs

In [533]:
def output_data(data, batch_size=128):
    """Lazily yield consecutive slices of `data`, each at most `batch_size` long."""
    total = len(data)
    starts = range(0, total, batch_size)
    yield from (data[start:start + batch_size] for start in starts)

Model parameters. We are using one layer and one attention head (which lets each
token attend to the other tokens in its context). We also set the dimension of the
model and of the head; `d_vocab` is the vocabulary size, i.e. the size of the logits.

In [534]:
# model setup
# Build the config from the constants defined earlier. n_ctx reuses
# FIXED_LENGTH — presumably the maximum sequence length the model accepts;
# confirm against the TransformerLens docs.
cfg = HookedTransformerConfig(
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_head=D_HEAD,
    n_ctx=FIXED_LENGTH,
    d_vocab=D_VOCAB,
    act_fn="relu",
    seed=SEED,
    device=DEVICE,
    attn_only=True  # NOTE(review): presumably drops the MLP blocks, which would leave act_fn unused — confirm in the TransformerLens docs
)

# hooked transformer used for interpretation later
model = HookedTransformer(cfg, move_to_device=True)

In [535]:
def loss_function(logits, tokens, return_per_token=True, print_tokens=False):
    """Cross-entropy of the last-position prediction against each row's max token.

    Args:
        logits: model output indexed as ``[batch, position, vocab]``; only the
            final position is scored.
        tokens: integer tensor indexed as ``[batch, position]``; the target for
            each row is its largest token.
        return_per_token: when true, return one loss per row with shape
            ``(batch,)``; otherwise return the scalar mean over the batch.
        print_tokens: when true, print the input tokens and the argmax
            prediction for each row.

    Returns:
        The negative log-probability the model assigns to the correct answer.
    """
    # we take the last element of the logits to make the next prediction
    final_logits = logits[:, -1, :]
    # gather() needs the index on the same device as the log-probabilities; the
    # caller's token tensor may live on a different device than the logits.
    answer = torch.max(tokens, dim=1)[0].to(final_logits.device)
    log_prob = final_logits.log_softmax(-1)
    if print_tokens:
        print("tokens", tokens)
        print("predicted", torch.argmax(final_logits, dim=-1))
    # shape is (batch_size, 1) which represents log-probabilities
    # of the correct answer
    output_prob = log_prob.gather(-1, answer.unsqueeze(-1))
    if return_per_token:
        # Squeeze only the gathered axis: a bare squeeze() would also drop the
        # batch axis for a batch of one and return a 0-d tensor.
        return -1 * output_prob.squeeze(-1)
    return -1 * output_prob.mean()

In [536]:
def accuracy(logits, tokens, return_per_token=False):
    """Score whether the last-position argmax equals each row's largest token.

    Args:
        logits: model output indexed as ``[batch, position, vocab]``; only the
            final position is used.
        tokens: integer tensor indexed as ``[batch, position]``.
        return_per_token: when true, return a float tensor with 1.0 for each
            correct row and 0.0 otherwise; when false, return the mean as a
            Python float.
    """
    final_logits = logits[:, -1, :]
    predicted = torch.argmax(final_logits, dim=1)
    # Compare on the logits' device; the caller's token tensor may live on a
    # different device than the model output.
    answer = torch.max(tokens, dim=1)[0].to(predicted.device)
    correct = (predicted == answer).float()
    if return_per_token:
        return correct
    return correct.mean().item()

In [537]:
# NOTE(review): `x` is not referenced by any other visible cell; this looks like
# leftover scratch — confirm before removing.
x = torch.tensor([[1], [2], [3], [4]])

In [538]:
def train_model(model, n_epochs, batch_size, batches_per, sequence_length=2):
    """Train `model` on freshly sampled random lists, then report accuracy.

    Each epoch draws ``batches_per`` batches of ``batch_size`` lists from
    ``generateLists(..., training=True)`` and takes one Adam step per batch.
    After training, accuracy is printed for 1280 newly sampled lists from the
    same distribution.

    Args:
        model: module called as ``model(tokens)`` that returns logits.
        n_epochs: number of epochs to run.
        batch_size: number of lists per batch.
        batches_per: number of batches per epoch.
        sequence_length: accepted for compatibility; not used.

    Returns:
        The detached per-example losses of the last training batch, or an
        empty tensor when no batch was run.
    """
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
    # Keep data on the model's device so the loss/accuracy helpers never mix
    # CPU tokens with GPU logits.
    device = next(model.parameters()).device

    # The function ends in eval mode, so switch back in case it is called again.
    model.train()
    train_losses = []
    losses = torch.empty(0, device=device)
    for epoch in range(n_epochs):
        epoch_losses = []
        for _ in range(batches_per):
            tokens = generateLists(batch_size, training=True).to(device)
            logits = model(tokens)
            losses = loss_function(logits, tokens, print_tokens=False)
            optimizer.zero_grad()
            losses.mean().backward()
            optimizer.step()
            # reshape(-1) keeps a 1-d tensor even if the loss came back 0-d.
            epoch_losses.append(losses.detach().reshape(-1))

        # Average in torch: np.mean over a list of tensors fails for CUDA tensors.
        if epoch_losses:
            train_losses.append(torch.cat(epoch_losses).mean().item())
        else:
            train_losses.append(float("nan"))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()
    # might want to create the training and testing set beforehand
    test_data = generateLists(1280, training=True).to(device)
    with torch.no_grad():  # evaluation needs no autograd graph
        logits = model(test_data)
    acc = accuracy(logits, test_data, return_per_token=False)

    print(f"Test accuracy: {acc}")

    return losses.detach()
                

In [539]:
def train_model2(model, n_epochs, batch_size, batches_per, sequence_length=2):
    """Train `model` on a fixed train split of all token pairs, then test it.

    The pairs come from ``cross_generation``; the training split is walked in
    order with ``output_data`` and at most ``batches_per`` batches are used per
    epoch. After training, accuracy on the held-out split is printed.

    Args:
        model: module called as ``model(tokens)`` that returns logits.
        n_epochs: number of epochs to run.
        batch_size: number of pairs per batch.
        batches_per: maximum number of batches per epoch.
        sequence_length: accepted for compatibility; not used.

    Returns:
        The detached per-example losses of the last training batch, or an
        empty tensor when no batch was run.
    """
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
    # Keep data on the model's device so the loss/accuracy helpers never mix
    # CPU tokens with GPU logits.
    device = next(model.parameters()).device

    train_losses = []
    # cross_generation ignores its first argument; batch_size is passed only to
    # keep the original call unchanged.
    training_data, testing_data = cross_generation(batch_size)
    training_data = training_data.to(device)
    testing_data = testing_data.to(device)
    print(training_data.shape)
    print(testing_data.shape)

    # The function ends in eval mode, so switch back in case it is called again.
    model.train()
    losses = torch.empty(0, device=device)
    for epoch in range(n_epochs):
        data_generator = output_data(training_data, batch_size)
        epoch_losses = []
        # zip() stops at whichever runs out first, so asking for more batches
        # than the split holds no longer raises StopIteration.
        # NOTE(review): the split is not reshuffled, so only its first
        # batches_per * batch_size rows are seen each epoch — confirm intended.
        for _, tokens in zip(range(batches_per), data_generator):
            logits = model(tokens)
            losses = loss_function(logits, tokens, print_tokens=False)
            optimizer.zero_grad()
            losses.mean().backward()
            optimizer.step()
            # reshape(-1) keeps a 1-d tensor even if the loss came back 0-d.
            epoch_losses.append(losses.detach().reshape(-1))

        # Average in torch: np.mean over a list of tensors fails for CUDA tensors.
        if epoch_losses:
            train_losses.append(torch.cat(epoch_losses).mean().item())
        else:
            train_losses.append(float("nan"))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()
    with torch.no_grad():  # evaluation needs no autograd graph
        logits = model(testing_data)
    acc = accuracy(logits, testing_data, return_per_token=False)

    print(f"Test accuracy: {acc}")

    return losses.detach()
                

In [540]:
losses = train_model2(model, 100, 128, 10, 3)

torch.Size([2867, 2])
torch.Size([1229, 2])
Epoch 0, train loss: 4.462099552154541
Epoch 10, train loss: 1.206030249595642
Epoch 20, train loss: 0.2497977912425995
Epoch 30, train loss: 0.09170398861169815
Epoch 40, train loss: 0.04693559929728508
Epoch 50, train loss: 0.0287907924503088
Epoch 60, train loss: 0.01964835822582245
Epoch 70, train loss: 0.014348438009619713
Epoch 80, train loss: 0.010973022319376469
Epoch 90, train loss: 0.008676661178469658
Test accuracy: 0.9568755030632019
