In [550]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils
import random

In [551]:
# model parameter constants
N_LAYERS = 1
N_HEADS = 1
D_MODEL = 32
D_HEAD = 32
D_MLP = None
D_VOCAB = 12
SEED = 123
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [552]:
MAX_DIGITS = 5
NUM_DIGITS = 2
# 10 is between numbers, 11 for filler space
CTX_LEN = NUM_DIGITS * MAX_DIGITS + 1

In [553]:
def generate_number_with_variable_length(max_digits):
    num_digits = random.randint(1, max_digits)
    if num_digits == 1:
        return random.randint(0, 9)  # For a single digit, range is 0-9
    else:
        return random.randint(10**(num_digits-1), 10**num_digits - 1)

def generate_data(n, training=True):
    tokens = []
    target = []

    for _ in range(n):
        curr = []
        answer = ""
        
        a = (generate_number_with_variable_length(MAX_DIGITS))
        b = (generate_number_with_variable_length(MAX_DIGITS))

        if a > b:
            answer = "a"
        else:
            answer = "b"

        i = 0
        for digit in str(a):
            curr.append(int(digit))
            i += 1
        while i < MAX_DIGITS:
            curr.append(11)
            i += 1
        
        curr.append(10)
        
        i = 0
        for digit in str(b):
            curr.append(int(digit))
            i += 1
        while i < MAX_DIGITS:
            curr.append(11)
            i += 1
        
        # if answer == "a":
        #     answer_list = curr[:MAX_DIGITS]
        # else:
        #     answer_list = curr[MAX_DIGITS+1:]

        answer_list = curr[MAX_DIGITS + 1:]

        # while len(answer_list) < CTX_LEN:
        #     answer_list.append(11)
            
        tokens.append(curr)
        target.append(answer_list)

    return torch.tensor(tokens), torch.tensor(target)

generate_data(10)

(tensor([[ 6,  2, 11, 11, 11, 10,  4,  5,  1, 11, 11],
         [ 2, 11, 11, 11, 11, 10,  4,  4,  3,  5, 11],
         [ 6,  2,  7,  5,  6, 10,  2,  0,  7, 11, 11],
         [ 1,  4,  7, 11, 11, 10,  8,  0,  5,  9,  6],
         [ 0, 11, 11, 11, 11, 10,  1, 11, 11, 11, 11],
         [ 5,  2,  7,  3,  6, 10,  4,  2,  1,  2, 11],
         [ 0, 11, 11, 11, 11, 10,  2,  6,  6,  4,  8],
         [ 8,  8,  0,  3,  9, 10,  4,  2, 11, 11, 11],
         [ 9,  7,  0, 11, 11, 10,  9,  2,  6,  2,  0],
         [ 8,  1,  4,  3, 11, 10,  2,  1,  3, 11, 11]]),
 tensor([[ 4,  5,  1, 11, 11],
         [ 4,  4,  3,  5, 11],
         [ 2,  0,  7, 11, 11],
         [ 8,  0,  5,  9,  6],
         [ 1, 11, 11, 11, 11],
         [ 4,  2,  1,  2, 11],
         [ 2,  6,  6,  4,  8],
         [ 4,  2, 11, 11, 11],
         [ 9,  2,  6,  2,  0],
         [ 2,  1,  3, 11, 11]]))

In [554]:
# model setup

# MIGHT HAVE TO CHANGE CONTEXT OR VOCAB
cfg = HookedTransformerConfig(
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_head=D_HEAD,
    n_ctx=CTX_LEN,
    d_vocab=D_VOCAB,
    act_fn="relu",
    seed=SEED,
    device=DEVICE,
    attn_only=True
)

# hooked transformer used for interpretation later
model = HookedTransformer(cfg, move_to_device=True)

In [555]:
def loss_function(logits, target, return_per_token=True, print_tokens=False):
    answer = target
    logits = logits[:, MAX_DIGITS : -1, :]
    log_prob = logits.log_softmax(-1)
    output_prob = log_prob.gather(-1, answer[..., None])[..., 0]

    mask = (target != 11).float()
    masked_output_prob = output_prob * mask

    if print_tokens:
        print("target", target)
        print("predicted", torch.argmax(logits, dim=-1))
        # print(answer.unsqueeze(-1))
        # print(output_prob)
        
    if return_per_token:
        return -(masked_output_prob / mask.sum(dim=1, keepdim=True).clamp(min=1))
    return -(masked_output_prob.sum() / mask.sum())

In [556]:
def accuracy(logits, target, return_per_token=False):
    logits = logits[:, MAX_DIGITS : -1, :]
    predicted = torch.argmax(logits, dim=-1)
    answer = target
    print(predicted, answer)
    if return_per_token:
        return (predicted == answer).float()
    return (predicted == answer).float().mean().item()

In [557]:
import torch
import torch.nn as nn
import numpy as np

def train_model(model, n_epochs, batch_size, batches_per, sequence_length=2):
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
    
    # Assuming you've defined your model's vocabulary size somewhere
    vocab_size = 12 # Adjust this based on your actual vocab size
    train_losses = []
    # model.train()
    for epoch in range(n_epochs):
        epoch_losses = []
        for _ in range(batches_per):
            tokens, targets = generate_data(batch_size, training=True)  # Adjusted to expect tokens and targets
            logits = model(tokens)
            
            # print(tokens)
            losses = loss_function(logits, targets, print_tokens=False)
            
            losses.mean().backward()
            optimizer.step()
            optimizer.zero_grad()
            
            epoch_losses.extend(losses.detach())

        train_losses.append(np.mean(epoch_losses))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()

    test_data, targets = generate_data(1280, training=True)
    logits = model(test_data)

    acc = accuracy(logits, targets, return_per_token=False)

    print(f"Validation accuracy: {acc}")

    return train_losses


In [559]:
torch.set_printoptions(threshold=10000)
losses = train_model(model, 1, 128, 10)

Epoch 0, train loss: 0.4602588415145874
tensor([[8, 0, 0, 0, 8],
        [8, 8, 8, 8, 8],
        [8, 8, 2, 2, 3],
        [2, 2, 2, 2, 2],
        [2, 2, 0, 2, 3],
        [2, 2, 2, 0, 0],
        [8, 8, 8, 2, 0],
        [8, 8, 2, 2, 3],
        [2, 8, 2, 2, 3],
        [2, 8, 2, 2, 0],
        [8, 2, 2, 2, 0],
        [8, 8, 3, 2, 8],
        [8, 8, 2, 2, 3],
        [8, 8, 0, 0, 3],
        [2, 2, 0, 2, 3],
        [8, 2, 2, 0, 0],
        [8, 8, 8, 2, 3],
        [8, 8, 2, 2, 3],
        [2, 2, 0, 2, 3],
        [2, 8, 2, 2, 3],
        [2, 8, 2, 8, 8],
        [2, 2, 2, 2, 2],
        [2, 8, 3, 2, 3],
        [8, 8, 2, 2, 2],
        [8, 8, 0, 0, 3],
        [2, 8, 8, 8, 0],
        [2, 0, 2, 2, 0],
        [2, 8, 0, 2, 0],
        [2, 8, 2, 2, 3],
        [8, 8, 8, 2, 2],
        [8, 2, 3, 2, 3],
        [8, 2, 0, 2, 2],
        [8, 8, 2, 8, 3],
        [2, 8, 2, 2, 0],
        [8, 8, 2, 2, 3],
        [8, 2, 2, 2, 0],
        [2, 8, 8, 8, 3],
        [2, 2, 0, 2, 2],
        [2