In [172]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils
import random
import circuitsvis as cv
from fancy_einsum import einsum

In [173]:
# Model hyperparameters; these are passed to HookedTransformerConfig in the model-setup cell.
N_LAYERS = 1  # number of transformer layers
N_HEADS = 1  # attention heads per layer
D_MODEL = 32  # model width
D_HEAD = 32  # per-head width; equal to D_MODEL because there is a single head
D_MLP = None  # NOTE(review): never passed to the config (the model is built with attn_only=True)
# D_VOCAB = 12
D_VOCAB = 10  # generate_data only emits the digit tokens 0-9
SEED = 123  # forwarded to the model config as `seed`
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# DEVICE = torch.device("mps")

In [174]:
MAX_DIGITS = 5  # every operand is left-padded with zeros to this many digit tokens
NUM_DIGITS = 2  # NOTE(review): in the visible cells this is only used by the commented-out CTX_LEN formula
# Stale note from an earlier encoding: token 10 separated the numbers and token 11 was filler.
# generate_data emits neither token now, and D_VOCAB is 10.
# CTX_LEN = (1 + NUM_DIGITS) * MAX_DIGITS + (NUM_DIGITS)
# Row layout: MAX_DIGITS tokens for a, MAX_DIGITS for b, MAX_DIGITS for the answer.
CTX_LEN = MAX_DIGITS * 3

In [175]:
def generate_number_with_variable_length(max_digits):
    """Draw a random non-negative integer whose digit count is itself random.

    The digit count is sampled uniformly from 1..max_digits first; the value
    is then sampled uniformly among the integers with exactly that many
    digits (0-9 for a single digit, so zero can only appear in that case).

    Args:
        max_digits: largest digit count that may be drawn.

    Returns:
        Tuple ``(value, digit_count)``.
    """
    digit_count = random.randint(1, max_digits)
    # A one-digit number may be 0; longer numbers must not start with a zero.
    lowest = 0 if digit_count == 1 else 10 ** (digit_count - 1)
    highest = 10 ** digit_count - 1
    return random.randint(lowest, highest), digit_count

def generate_data(n, training=True):
    """Build ``n`` random "copy the larger number" examples.

    Each example is one token row laid out as
    ``[a padded to MAX_DIGITS | b padded to MAX_DIGITS | answer]``.
    Operands are left-padded with zeros, and the answer repeats the digit
    tokens of ``a`` when ``a > b`` and those of ``b`` otherwise (ties go to b).

    Args:
        n: number of examples to generate.
        training: accepted but not used by this function.

    Returns:
        Tuple ``(tokens, target)``: ``tokens`` stacks the full rows and
        ``target`` stacks only the answer segments.
    """
    token_rows = []
    answer_rows = []

    for _ in range(n):
        # a is drawn before b so the random stream is consumed in a fixed order.
        a, _len_a = generate_number_with_variable_length(MAX_DIGITS)
        b, _len_b = generate_number_with_variable_length(MAX_DIGITS)

        # zfill left-pads with "0", which maps to the digit token 0.
        digits_a = [int(ch) for ch in str(a).zfill(MAX_DIGITS)]
        digits_b = [int(ch) for ch in str(b).zfill(MAX_DIGITS)]

        winner = digits_a if a > b else digits_b

        token_rows.append(digits_a + digits_b + winner)
        answer_rows.append(list(winner))

    return torch.tensor(token_rows), torch.tensor(answer_rows)

# Quick visual check: display a small batch as a (tokens, targets) pair.
generate_data(10)

(tensor([[0, 0, 6, 4, 4, 0, 0, 1, 0, 4, 0, 0, 6, 4, 4],
         [0, 9, 1, 0, 6, 0, 0, 0, 9, 1, 0, 9, 1, 0, 6],
         [0, 0, 0, 2, 7, 0, 0, 1, 4, 9, 0, 0, 1, 4, 9],
         [0, 0, 0, 0, 5, 8, 2, 1, 6, 5, 8, 2, 1, 6, 5],
         [4, 9, 6, 8, 5, 0, 2, 6, 0, 0, 4, 9, 6, 8, 5],
         [0, 0, 5, 1, 8, 0, 0, 6, 6, 7, 0, 0, 6, 6, 7],
         [0, 0, 0, 9, 4, 0, 1, 5, 9, 9, 0, 1, 5, 9, 9],
         [6, 7, 1, 6, 9, 0, 0, 0, 0, 3, 6, 7, 1, 6, 9],
         [0, 0, 6, 3, 3, 0, 0, 8, 1, 3, 0, 0, 8, 1, 3],
         [0, 0, 0, 8, 1, 0, 6, 9, 1, 0, 0, 6, 9, 1, 0]]),
 tensor([[0, 0, 6, 4, 4],
         [0, 9, 1, 0, 6],
         [0, 0, 1, 4, 9],
         [8, 2, 1, 6, 5],
         [4, 9, 6, 8, 5],
         [0, 0, 6, 6, 7],
         [0, 1, 5, 9, 9],
         [6, 7, 1, 6, 9],
         [0, 0, 8, 1, 3],
         [0, 6, 9, 1, 0]]))

In [176]:
# model setup: one-layer, single-head configuration built from the constants defined earlier

# MIGHT HAVE TO CHANGE CONTEXT OR VOCAB
cfg = HookedTransformerConfig(
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_head=D_HEAD,
    n_ctx=CTX_LEN,  # must cover the full a | b | answer row produced by generate_data
    d_vocab=D_VOCAB,  # generate_data only emits digit tokens 0-9
    act_fn="relu",  # NOTE(review): presumably irrelevant when attn_only=True (no MLP) — confirm
    seed=SEED,
    device=DEVICE,
    attn_only=True
)

# hooked transformer used for interpretation later
model = HookedTransformer(cfg, move_to_device=True)
# model.to(DEVICE)

In [177]:
def loss_function(logits, target, return_per_token=True, print_tokens=False):
    """Negative log-likelihood of the answer digits.

    Args:
        logits: model output, indexed here as (batch, position, vocab).
        target: integer answer tokens, one per scored position, as returned
            by generate_data (second element of its result).
        return_per_token: if True, return the per-position losses (same shape
            as ``target``); otherwise return their scalar mean.
        print_tokens: if True, print the targets and the argmax predictions
            for the scored positions.

    Returns:
        Tensor of negative log-probabilities, per token or averaged.
    """
    # generate_data builds tensors on the CPU while the logits live wherever
    # the model does; gather() requires both operands on the same device.
    answer = target.to(logits.device)

    # NOTE(review): generate_data places the answer digits at input positions
    # 2 * MAX_DIGITS onward and the logits at those same positions are scored
    # here, so every scored position already has its own target as its input
    # token. If next-token prediction was intended, the slice would be
    # logits[:, 2 * MAX_DIGITS - 1 : -1, :], and accuracy() would need the
    # identical change. Left as-is so both functions keep agreeing.
    answer_logits = logits[:, 2 * MAX_DIGITS:, :]
    log_prob = answer_logits.log_softmax(-1)
    # Pick, at every scored position, the log-probability of the target token.
    output_prob = log_prob.gather(-1, answer[..., None])[..., 0]

    if print_tokens:
        print("target", target)
        print("predicted", torch.argmax(answer_logits, dim=-1))

    if return_per_token:
        return -output_prob
    return -output_prob.mean()

In [178]:
def accuracy(logits, target, return_per_token=False):
    """Fraction of answer positions whose argmax prediction equals the target.

    Args:
        logits: model output, indexed here as (batch, position, vocab).
        target: integer answer tokens, one per scored position, as returned
            by generate_data (second element of its result).
        return_per_token: if True, return a float tensor of 0/1 hits with the
            same shape as ``target``; otherwise return the mean as a Python
            float.

    Returns:
        Per-token hit tensor, or the overall accuracy as a float.
    """
    # NOTE(review): these are the same positions that hold the answer tokens
    # in the input (see generate_data), so a model can score well by copying
    # its current input token. Keep this slice identical to loss_function().
    answer_logits = logits[:, 2 * MAX_DIGITS:, :]
    predicted = torch.argmax(answer_logits, dim=-1)
    # Targets come from generate_data on the CPU; comparing tensors on
    # different devices raises, so align them with the predictions first.
    answer = target.to(predicted.device)
    hits = (predicted == answer).float()
    if return_per_token:
        return hits
    return hits.mean().item()

In [179]:
import torch
import torch.nn as nn
import numpy as np

def train_model(model, n_epochs, batch_size, batches_per, sequence_length=2):
    """Train ``model`` with Adam on freshly generated batches.

    Every epoch runs ``batches_per`` optimisation steps on new random batches
    from generate_data, then measures accuracy on another freshly generated
    batch of 1280 examples (random samples, not a fixed held-out set). The
    training loss is printed every 10th epoch and the accuracy every epoch.

    Args:
        model: module that maps a token batch to logits.
        n_epochs: number of epochs to run.
        batch_size: examples per training batch.
        batches_per: optimisation steps per epoch.
        sequence_length: accepted for backward compatibility; unused.

    Returns:
        List with the mean per-token training loss (float) of each epoch.
    """
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    n_validation = 1280
    train_losses = []
    for epoch in range(n_epochs):
        batch_losses = []

        model.train()
        for _ in range(batches_per):
            tokens, targets = generate_data(batch_size, training=True)
            logits = model(tokens)
            # Data is generated on the CPU; keep targets next to the logits so
            # the loss also works when the model lives on an accelerator.
            targets = targets.to(logits.device)

            losses = loss_function(logits, targets, print_tokens=False)

            # Clear any stale gradients before accumulating this step's.
            optimizer.zero_grad()
            losses.mean().backward()
            optimizer.step()

            # Keep detached CPU copies so logging never holds on to the graph
            # and never needs a device-to-numpy conversion.
            batch_losses.append(losses.detach().flatten().cpu())

        if batch_losses:
            train_losses.append(torch.cat(batch_losses).mean().item())
        else:
            # No training steps were requested; there is no loss to report.
            train_losses.append(float("nan"))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

        model.eval()
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            test_data, targets = generate_data(n_validation, training=True)
            logits = model(test_data)
            acc = accuracy(logits, targets.to(logits.device), return_per_token=False)

        print(f"Validation {epoch} accuracy: {acc}")

    return train_losses


In [180]:
# Raise the element count at which torch starts summarising printed tensors.
torch.set_printoptions(threshold=10000)
# 100 epochs, batch size 128, 10 batches per epoch; keeps the per-epoch training losses.
losses = train_model(model, 100, 128, 10)

Epoch 0, train loss: 2.020228385925293
Validation 0 accuracy: 0.43671876192092896
Validation 1 accuracy: 0.6449999809265137
Validation 2 accuracy: 0.8089062571525574
Validation 3 accuracy: 0.9496874809265137
Validation 4 accuracy: 0.9879687428474426
Validation 5 accuracy: 0.9957812428474426
Validation 6 accuracy: 0.9982812404632568
Validation 7 accuracy: 0.9987499713897705
Validation 8 accuracy: 0.9996874928474426
Validation 9 accuracy: 0.9987499713897705
Epoch 10, train loss: 0.0765344426035881
Validation 10 accuracy: 0.9996874928474426
Validation 11 accuracy: 0.9995312690734863
Validation 12 accuracy: 0.9995312690734863
Validation 13 accuracy: 0.9998437762260437
Validation 14 accuracy: 1.0
Validation 15 accuracy: 1.0
Validation 16 accuracy: 0.9996874928474426
Validation 17 accuracy: 1.0
Validation 18 accuracy: 1.0
Validation 19 accuracy: 1.0
Epoch 20, train loss: 0.019552577286958694
Validation 20 accuracy: 0.9998437762260437
Validation 21 accuracy: 0.9998437762260437
Validation 22 a

In [181]:
# Run a fresh batch through the trained model while recording intermediate activations.
tokens, targets = generate_data(128, training=True)
logits, cache = model.run_with_cache(tokens)
# Attention pattern of layer 0; presumably (batch, head, query_pos, key_pos) — TODO confirm
attention_pattern = cache["pattern", 0, "attn"]

In [182]:
# Per-head attention view for example 0; token ids are rendered as string labels.
cv.attention.attention_heads(
    tokens=[str(tok.item()) for tok in tokens[0]],
    attention=attention_pattern[0],
)

In [183]:
# Attention pattern view for example 0; token ids are rendered as string labels.
cv.attention.attention_patterns(
    tokens=[str(tok.item()) for tok in tokens[0]],
    attention=attention_pattern[0],
)

In [184]:
# Attention pattern view for example 2; token ids are rendered as string labels.
cv.attention.attention_patterns(
    tokens=[str(tok.item()) for tok in tokens[2]],
    attention=attention_pattern[2],
)

In [185]:
# Attention pattern view for example 3; token ids are rendered as string labels.
cv.attention.attention_patterns(
    tokens=[str(tok.item()) for tok in tokens[3]],
    attention=attention_pattern[3],
)

In [186]:
# Attention pattern view for example 4; token ids are rendered as string labels.
cv.attention.attention_patterns(
    tokens=[str(tok.item()) for tok in tokens[4]],
    attention=attention_pattern[4],
)

In [187]:
# Attention pattern view for example 5; token ids are rendered as string labels.
cv.attention.attention_patterns(
    tokens=[str(tok.item()) for tok in tokens[5]],
    attention=attention_pattern[5],
)