In [6]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils
import circuitsvis as cv
import plotly.express as px
import pandas as pd
from fancy_einsum import einsum

import itertools

In [7]:
# model parameter constants
N_LAYERS = 1
N_HEADS = 1
D_MODEL = 32
D_HEAD = 32
D_MLP = None
D_VOCAB = 64
SEED = 123
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [8]:

def line(tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().numpy()
    
    df = pd.DataFrame({
        xaxis: np.arange(len(tensor)),
        yaxis: tensor
    })
    
    fig = px.line(df, x=xaxis, y=yaxis, labels={xaxis: xaxis, yaxis: yaxis}, **kwargs)
    
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label

    fig.show()

In [9]:
# data generation
import random

FIXED_LENGTH = 2

def generateLists(n, training=True):
    output = []
    for _ in range(n):
        curr = []
        for _ in range(FIXED_LENGTH):
            if training:
                curr.append(random.randint(0, D_VOCAB - 1))
            else:
                curr.append(random.randint(D_VOCAB // 2, D_VOCAB - 1))
            # curr.append(random.randint(0, 100))

        output.append(curr)

    output = torch.tensor(output)
    return output

In [10]:
def separate_data(n, split=0.7):
    output = list(itertools.product(range(D_VOCAB), repeat=FIXED_LENGTH))

    random.shuffle(output)

    split_index = int(len(output) * split)
    # return training, testing
    return torch.tensor(output[:split_index]), torch.tensor(output[split_index:])

print(separate_data(2, 0.7))

(tensor([[24, 43],
        [57, 54],
        [40, 43],
        ...,
        [47, 41],
        [61, 53],
        [46,  7]]), tensor([[16, 60],
        [31, 33],
        [25, 54],
        ...,
        [45, 37],
        [ 6, 46],
        [24, 59]]))


In [11]:
def output_data(data, batch_size=128):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

In [12]:
def raw_generate():
    output = []
    for i in range(D_VOCAB):
        for j in range(D_VOCAB):
            curr = [i, j]
            output.append(curr)

    random.shuffle(output)
    return output

def cross_val_generate(output, epoch, total_epoch):
    testing_size = int(len(output) / total_epoch)
    split_index = testing_size * epoch
    # return training, testing
    testing = torch.tensor(output[split_index:split_index + testing_size])
    training = torch.tensor(output[:split_index] + output[split_index + testing_size:])
    return training, testing

In [13]:
def generateVariableLists(n, training=True):
    output = []
    for _ in range(n):
        i = 0
        j = random.randint(0, FIXED_LENGTH)
        curr = []
        while i < j:
            curr.append(random.randint(0, D_VOCAB - 1))
            i += 1
        while i < FIXED_LENGTH:
            curr.append(0)
            i += 1
        output.append(curr)

    output = torch.tensor(output)
    return output

In [14]:
# model setup
cfg = HookedTransformerConfig(
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_head=D_HEAD,
    n_ctx=FIXED_LENGTH,
    d_vocab=D_VOCAB,
    act_fn="relu",
    seed=SEED,
    device=DEVICE,
    attn_only=True
)

# hooked transformer used for interpretation later
model = HookedTransformer(cfg, move_to_device=True)

In [15]:
def loss_function(logits, tokens, return_per_token=True, print_tokens=False):
    # we take the last element of the logits to make the next prediction
    logits = logits[:, -1, :]
    answer = torch.max(tokens, dim=1)[0]
    log_prob = logits.log_softmax(-1)
    if print_tokens:
        print("tokens", tokens)
        print("predicted", torch.argmax(logits, dim=-1))
    # shape is (batch_size, 1) which represents probabilities 
    # of the correct answer
    output_prob = log_prob.gather(-1, answer.unsqueeze(-1))
    if return_per_token:
        return -output_prob.squeeze()
    return -output_prob.mean()

In [16]:
def accuracy(logits, tokens, return_per_token=False):
    logits = logits[:, -1, :]
    predicted = torch.argmax(logits, dim=1)
    answer = torch.max(tokens, dim=1)[0]
    if return_per_token:
        return (predicted == answer).float()
    
    mismatches = predicted != answer
    if mismatches.any():
        print("Mismatches found:")
        for idx in torch.where(mismatches)[0]:
            print(f"Index: {tokens[idx]}, Predicted: {predicted[idx]}, Actual: {answer[idx]}")

    return (predicted == answer).float().mean().item()

In [17]:
x = torch.tensor([[1], [2], [3], [4]])

In [18]:
def train_model(model, n_epochs, batch_size, batches_per, sequence_length=2):
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    train_losses = []
    for epoch in range(n_epochs):
        epoch_losses = []
        for _ in range(batches_per):
            tokens = generateVariableLists(batch_size, training=True)
            logits = model(tokens)
            # print(tokens.shape)
            # print(logits.shape)
            losses = loss_function(logits, tokens, print_tokens=False)
            losses.mean().backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses.extend(losses.detach())

        train_losses.append(np.mean(epoch_losses))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()
    # might want to create the training and testing set beforehand
    test_data = generateVariableLists(1280, training=True)
    logits = model(test_data)
    acc = accuracy(logits, test_data, return_per_token=False)

    print(f"Test accuracy: {acc}")

    return train_losses
                

In [19]:
def train_model2(model, n_epochs, batch_size, batches_per, sequence_length=2):
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    train_losses = []
    training_data, testing_data = separate_data(batch_size)

    for epoch in range(n_epochs):
        data_generator = output_data(training_data, batch_size)
        epoch_losses = []
        for _ in range(batches_per):
            tokens = next(data_generator)
            logits = model(tokens)
            # print(tokens.shape)
            # print(logits.shape)
            losses = loss_function(logits, tokens, print_tokens=False)
            losses.mean().backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses.extend(losses.detach())

        train_losses.append(np.mean(epoch_losses))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()
    # might want to create the training and testing set beforehand
    logits = model(testing_data)
    acc = accuracy(logits, testing_data, return_per_token=False)

    print(f"Test accuracy: {acc}")

    return train_losses
                

In [20]:
def cross_val(model, n_epochs, batch_size, batches_per, sequence_length=2):
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    train_losses = []
    all_data = raw_generate()

    for epoch in range(n_epochs):
        tokens, test = cross_val_generate(all_data, epoch, n_epochs - 1)
        epoch_losses = []
        for _ in range(batches_per):
            logits = model(tokens)
            # print(tokens.shape)
            # print(logits.shape)
            losses = loss_function(logits, tokens, print_tokens=False)
            losses.mean().backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses.extend(losses.detach())

        train_losses.append(np.mean(epoch_losses))

        model.eval()
        logits = model(test)
        acc = accuracy(logits, test, return_per_token=False)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")
            print(f"Test accuracy: {acc}")

    return train_losses
                

In [21]:
losses = train_model2(model, 200, 128, 10, 2)

Epoch 0, train loss: 4.1202392578125
Epoch 10, train loss: 0.9903205037117004
Epoch 20, train loss: 0.2248077094554901
Epoch 30, train loss: 0.08770783990621567
Epoch 40, train loss: 0.045312389731407166
Epoch 50, train loss: 0.027901560068130493
Epoch 60, train loss: 0.019090991467237473
Epoch 70, train loss: 0.013963868841528893
Epoch 80, train loss: 0.010690105147659779
Epoch 90, train loss: 0.008459190838038921
Epoch 100, train loss: 0.006864064838737249
Epoch 110, train loss: 0.005680583883076906
Epoch 120, train loss: 0.00477637629956007
Epoch 130, train loss: 0.004068854730576277
Epoch 140, train loss: 0.003504190593957901
Epoch 150, train loss: 0.0030459598638117313
Epoch 160, train loss: 0.0026687586214393377
Epoch 170, train loss: 0.002354439813643694
Epoch 180, train loss: 0.0020896729547530413
Epoch 190, train loss: 0.0018645531963557005
Mismatches found:
Index: tensor([41, 43]), Predicted: 41, Actual: 43
Index: tensor([17, 16]), Predicted: 16, Actual: 17
Index: tensor([30,

In [22]:
line(losses, xaxis="Epoch", yaxis="Loss")

In [23]:
training_data, testing_data = separate_data(128)
train_data_gen = output_data(training_data)
tokens = next(train_data_gen)

inserted = torch.tensor([[43, 44]])
tokens = torch.cat((inserted, tokens), dim=0)
logits, cache = model.run_with_cache(tokens)

In [24]:
attention_pattern = cache["pattern", 0, "attn"]

print(torch.argmax(logits[:, -1, :], dim=-1))

tensor([44, 40, 50, 47, 39, 57, 60, 55, 51, 54, 31, 51, 52, 46, 49, 30, 34, 57,
        57, 54, 63, 59, 43, 45, 28, 46, 54, 40, 42, 63, 31, 58, 36, 31, 38, 57,
        18, 35, 60, 34, 37, 29, 11, 29, 39, 44, 11, 61, 39, 59, 63, 61, 55, 28,
        55, 60, 10, 62, 36, 21, 42, 21, 42, 51, 55, 44, 48, 44, 18, 57, 48, 39,
         8, 54, 51, 11, 48, 60, 48, 45, 58, 51, 32, 63, 19, 49, 21, 58, 60, 60,
        63, 49, 14, 26, 40, 39, 36, 59, 62, 21, 10, 39, 40, 47, 57, 62,  7, 48,
        29, 42,  9, 45, 54, 52, 35, 54, 58, 61, 30, 55, 49, 44, 11, 49, 43, 33,
        22, 50, 15])


In [25]:
cv.attention.attention_heads(tokens=list(map(lambda t: str(t.item()), tokens[0])), attention=attention_pattern[0])

In [26]:
cv.attention.attention_patterns(tokens=list(map(lambda t: str(t.item()), tokens[0])), attention=attention_pattern[0])

In [27]:
cv.attention.attention_patterns(tokens=list(map(lambda t: str(t.item()), tokens[1])), attention=attention_pattern[1])

In [28]:
cv.attention.attention_patterns(tokens=list(map(lambda t: str(t.item()), tokens[2])), attention=attention_pattern[2])

In [29]:
cv.attention.attention_patterns(tokens=list(map(lambda t: str(t.item()), tokens[3])), attention=attention_pattern[3])

In [30]:
cv.attention.attention_patterns(tokens=list(map(lambda t: str(t.item()), tokens[4])), attention=attention_pattern[4])

In [31]:
cv.attention.attention_patterns(tokens=list(map(lambda t: str(t.item()), tokens[5])), attention=attention_pattern[5])