
# Prime number classification attempt.

<b style="color: red">In this task, we test whether a Transformer model can learn to recognize the primality of numbers. Specifically, given a number, the model must output whether that number is prime.

We hypothesized that, on a small range of numbers (in this case, 1 to 1000), the Transformer model would be able to pick up patterns in the primes, such as those visible in the Ulam spiral, or learn an algorithm such as the Sieve of Eratosthenes to classify them.

Unfortunately, the model overfitted in every run. We tried different learning rates and weight decays to see whether they changed the outcome. After many unsuccessful attempts, we concluded that our model could not learn primality.

We identified several likely reasons for what we observed.
First, the problem itself is very challenging. Primality is not a simple function of a number's magnitude, and the primes follow no simple pattern. Although patterns such as the Ulam spiral and algorithms such as the Sieve of Eratosthenes exist, a more sophisticated model would likely be needed to learn them from such limited data.
Second, the dataset could be better constructed. Of the numbers from 1 to 1000, only 168 are prime (16.8%). This class imbalance leaves the model few positive examples to learn from.
Third, the task does not suit the HookedTransformer, because a Transformer's attention operates over a sequence of tokens. The experiments in "Attention Is All You Need" were on language tasks such as machine translation, where modeling long-range dependencies and processing tokens in parallel are what make the Transformer effective.
In our case, each number is a single token (n_ctx=1), so attention has only one position to attend to, and tokenizing the numbers by their digits would offer no obvious insight into their primality either.</b>
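The Sieve of Eratosthenes mentioned above finds all primes up to a limit by crossing out the multiples of each prime in turn. A minimal pure-Python sketch (our own illustration; the notebook itself labels numbers with SymPy's `isprime`) also reproduces the class balance quoted for 1 to 1000:

```python
def sieve_of_eratosthenes(limit: int) -> list[bool]:
    """Illustrative sketch: flags[n] is True iff n is prime, for 0 <= n <= limit."""
    flags = [True] * (limit + 1)
    flags[0] = False
    if limit >= 1:
        flags[1] = False  # 0 and 1 are not prime
    n = 2
    while n * n <= limit:
        if flags[n]:
            # Multiples of n below n*n were already crossed out by smaller primes,
            # so start at n*n and mark every n-th number as composite.
            for multiple in range(n * n, limit + 1, n):
                flags[multiple] = False
        n += 1
    return flags


flags = sieve_of_eratosthenes(1000)
num_primes = sum(flags[1:])  # count the primes among 1..1000
print(num_primes, num_primes / 1000)  # 168 0.168
```

The sieve needs only additions and a loop up to the square root of the limit, but it relies on iterating over earlier results, which a one-layer model seeing a single token per number has no obvious way to emulate.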

# Setup
(No need to read)

In [None]:
TRAIN_MODEL = True

In [None]:
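DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens
    %pip install circuitsvis
    %pip install einops
    %pip install fancy_einsum
    %pip install torchtyping
    %pip install numpy
    %pip install sympy
    %pip install wandb
except ImportError:
    # Not on Colab: assume the packages above are already installed locally
    IN_COLAB = False
    print("Running as a Jupyter notebook")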


In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Model Training

## Config

## Define Task
* Define the primality task: is the input number prime?
* Define the dataset & labels

Input format:
|n| (a single token for the number n; label 1 if n is prime, 0 otherwise)
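The introduction argues that tokenizing a number's digits would offer no obvious insight into its primality. For comparison with the single-token format used here, a hypothetical digit-level encoding (not used anywhere in this notebook; the helper `digit_tokens` is our own) would turn each number into a fixed-length sequence, so attention would have several positions to work over and numbers never seen in training would still share digit embeddings:

```python
import torch


def digit_tokens(numbers: torch.Tensor, n_digits: int = 4) -> torch.Tensor:
    """Hypothetical encoding (not used in this notebook): zero-padded decimal digits.

    numbers: 1D integer tensor, e.g. torch.arange(1, 1001)
    returns: [len(numbers), n_digits] tensor of tokens in 0..9,
             most significant digit first.
    """
    powers = 10 ** torch.arange(n_digits - 1, -1, -1)  # [1000, 100, 10, 1] for 4 digits
    return (numbers.unsqueeze(1) // powers) % 10


tokens = digit_tokens(torch.arange(1, 1001))
print(tokens.shape)  # torch.Size([1000, 4])
print(tokens[96])    # digits of 97
```

With this encoding the config would need `d_vocab=10` and `n_ctx=4` instead of one token per number; whether that helps with primality is an open question, not something tested here.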

In [None]:
import sympy
import torch

def primality_flags(limit):
    """Return a list whose i-th entry is True if (i + 1) is prime, for the numbers 1..limit."""
    return [sympy.isprime(i) for i in range(1, limit + 1)]


p = 1000
is_prime_array = primality_flags(p)

# Generate dataset and labels: one token per number, label 1 for prime, 0 for not prime
dataset = torch.arange(1, p + 1).unsqueeze(1)  # shape [p, 1]
labels = torch.tensor(is_prime_array, dtype=torch.long)


In [None]:
print(dataset[:5])
print(dataset.shape)
print(labels.shape)
print(labels[:5])
print(len(dataset))

tensor([[1],
        [2],
        [3],
        [4],
        [5]])
torch.Size([1000, 1])
torch.Size([1000])
tensor([0, 1, 1, 0, 1])
1000


In [None]:
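DATA_SEED = 6
frac_train = 0.7

torch.manual_seed(DATA_SEED)  # note: the split below does not shuffle, so this seed has no effect on it
cutoff = int(p * frac_train)

# Contiguous split: numbers 1..700 are used for training, 701..1000 for testing
train_data = dataset[:cutoff]
train_labels = labels[:cutoff]
test_data = dataset[cutoff:]
test_labels = labels[cutoff:]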


print(train_data[:5])
print(train_labels[:5])
print(train_data.shape)
print(test_data[:5])
print(test_labels[:5])
print(test_data.shape)


tensor([[1],
        [2],
        [3],
        [4],
        [5]])
tensor([0, 1, 1, 0, 1])
torch.Size([700, 1])
tensor([[701],
        [702],
        [703],
        [704],
        [705]])
tensor([1, 0, 0, 0, 0])
torch.Size([300, 1])
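Because the split above is contiguous, the two halves do not contain the same share of primes, and always predicting "not prime" is already a strong baseline on both. A minimal sketch that recomputes the labels with simple trial division so it runs on its own (the helper `is_prime_number` is ours, not the notebook's):

```python
import torch


def is_prime_number(n: int) -> bool:
    """Trial division (illustrative helper); fast enough for n <= 1000."""
    if n < 2:
        return False
    return all(n % d != 0 for d in range(2, int(n ** 0.5) + 1))


p, frac_train = 1000, 0.7
labels = torch.tensor([is_prime_number(n) for n in range(1, p + 1)], dtype=torch.long)
cutoff = int(p * frac_train)

train_primes = int(labels[:cutoff].sum())  # primes among 1..700
test_primes = int(labels[cutoff:].sum())   # primes among 701..1000

# Accuracy of a classifier that always answers "not prime" on each split
train_baseline = 1 - train_primes / cutoff
test_baseline = 1 - test_primes / (p - cutoff)
print(train_primes, test_primes)                    # 125 43
print(f"{train_baseline:.3f} {test_baseline:.3f}")  # 0.821 0.857
```

Any accuracy reported for this setup should be read against these baselines rather than against 50%.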


## Define Model

In [None]:
# Optimizer config
lr = 1e-4
wd = 1e-4
betas = (0.9, 0.98)

num_epochs = 25000
checkpoint_every = 100
momentum = 0.9  # unused: AdamW uses `betas` instead

In [None]:
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 128,
    d_head = 32,
    d_mlp = 256,
    act_fn = "relu",
    normalization_type=None,
    d_vocab=p+1,
    d_vocab_out=2,
    n_ctx=1,
    init_weights=True,
    device="cpu",
    seed = 999,
)

model = HookedTransformer(cfg)
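With `n_ctx=1` and `d_vocab=p+1`, every number is its own token and every sequence has a single position. Two consequences can be checked with plain PyTorch, independent of TransformerLens; the tensors below are illustrative stand-ins with shapes chosen to mirror this config, not the model's own weights. First, a softmax over a single key is always exactly 1, so attention can only copy the token's own value. Second, an embedding row for a number that appears only in the test set receives no gradient from the training loss. HookedTransformer's `W_E` is indexed by token in the same way, so we would expect the same behaviour there; AdamW's weight decay still shrinks those rows, but nothing ties them to the labels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p, cutoff, d_model, n_heads = 1000, 700, 128, 4  # mirrors the config above

# 1) Attention scores with one position: [batch, head, query_pos, key_pos] = [700, 4, 1, 1]
scores = torch.randn(cutoff, n_heads, 1, 1)
pattern = scores.softmax(dim=-1)
print(bool((pattern == 1).all()))  # True: each query can only attend to itself

# 2) Stand-in embedding: tokens 1..700 appear in training, 701..1000 only in testing
embed = nn.Embedding(p + 1, d_model)
train_tokens = torch.arange(1, cutoff + 1)
embed(train_tokens).pow(2).sum().backward()  # any loss computed on training tokens

grad_norm_per_row = embed.weight.grad.norm(dim=-1)
print(bool((grad_norm_per_row[1 : cutoff + 1] > 0).all()))  # True: training rows get gradient
print(bool((grad_norm_per_row[cutoff + 1 :] == 0).all()))   # True: test-only rows get none
```

This suggests that, with one token per number and a contiguous split, the test-set embeddings stay close to their random initialization, which fits the overfitting observed in every run.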

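All parameters, including the biases, are left trainable. (To disable the biases instead, which can make the model easier to interpret, set `requires_grad = False` for the parameters whose names contain `b_`.)

In [None]:
for name, param in model.named_parameters():
    param.requires_grad = True  # train every parameter, biases included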

## Define Optimizer + Loss

In [None]:
# Use the device the model was created on (cfg.device above), so that the model,
# data and labels always live on the same device.
device = model.cfg.device

# Cross-entropy loss on the logits at the final (here: only) position
def loss_fn(logits, labels):
    if len(logits.shape) == 3:
        logits = logits[:, -1]  # [batch, pos, d_vocab_out] -> [batch, d_vocab_out]
    logits = logits.to(torch.float64)
    labels = labels.to(torch.int64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()

# train_labels and test_labels are 1D tensors of shape [batch_size]
train_data, train_labels = train_data.to(device), train_labels.to(device)
test_data, test_labels = test_data.to(device), test_labels.to(device)

# Forward pass through the model for both train and test sets
train_logits = model(train_data)
train_loss = loss_fn(train_logits, train_labels)

test_logits = model(test_data)
test_loss = loss_fn(test_logits, test_labels)

# Print shapes to debug
print("train_logits shape:", train_logits.shape)
print("train_labels shape:", train_labels.shape)
print("test_logits shape:", test_logits.shape)
print("test_labels shape:", test_labels.shape)


train_logits shape: torch.Size([700, 1, 2])
train_labels shape: torch.Size([700])
test_logits shape: torch.Size([300, 1, 2])
test_labels shape: torch.Size([300])
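The introduction names class imbalance as one likely problem. One common mitigation, not something this notebook tried, is to weight the cross-entropy by inverse class frequency. A sketch assuming logits shaped `[batch, 1, 2]` as printed above; the function name `weighted_loss_fn` is hypothetical:

```python
import torch
import torch.nn.functional as F


def weighted_loss_fn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Class-balanced cross-entropy (hypothetical alternative to loss_fn).

    Each class c is weighted by N / (2 * n_c), so both classes contribute
    equally to the loss however rare the primes are.
    """
    if logits.dim() == 3:
        logits = logits[:, -1]  # [batch, pos, 2] -> [batch, 2]
    labels = labels.long()
    counts = torch.bincount(labels, minlength=2).clamp(min=1).to(logits.dtype)
    weights = counts.sum() / (2 * counts)
    return F.cross_entropy(logits, labels, weight=weights)


# Example on random logits with roughly the training set's prime fraction
torch.manual_seed(0)
logits = torch.randn(700, 1, 2)
labels = (torch.rand(700) < 0.18).long()

unweighted = F.cross_entropy(logits[:, -1], labels)  # plain mean cross-entropy, as in loss_fn
weighted = weighted_loss_fn(logits, labels)
print(unweighted.item(), weighted.item())
```

With these weights the loss equals the average of the two per-class mean losses, so the roughly 83% of numbers that are not prime no longer dominate the gradient.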


In [None]:
import wandb
wandb.init(
    # set the wandb project where this run will be logged
    project="prime-classification",

    # track hyperparameters and run metadata
    config={
    "learning_rate": lr,
    "epochs": num_epochs,
    "weight decay": wd,
    }
)


In [None]:
wandb.watch(model, log='all')

In [None]:
import tqdm

train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []

train_data = train_data.to(device)
train_labels = train_labels.to(device)
test_data = test_data.to(device)
test_labels = test_labels.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd,betas=betas)

for epoch in tqdm.tqdm(range(num_epochs)):
    # model.train()
    train_logits = model(train_data).to(device)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    train_losses.append(train_loss.item())

    optimizer.step()
    optimizer.zero_grad()

    with torch.inference_mode():
        test_logits = model(test_data).to(device)
        test_loss = loss_fn(test_logits, test_labels)
        test_losses.append(test_loss.item())


    # Checkpointing and logging
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        # Log plain Python floats rather than tensors
        wandb.log({"train_loss": train_loss.item(), "test_loss": test_loss.item()})
        print(f"Epoch {epoch + 1} Train Loss: {train_loss.item():.4f} Test Loss: {test_loss.item():.4f}")

wandb.finish()
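The loop above tracks only loss. Accuracy, compared against always predicting "not prime" (about 82% of the training numbers 1..700 and about 86% of the test numbers 701..1000 are not prime), makes it easier to tell whether the model learned anything beyond memorizing the training tokens. A minimal sketch of two helpers, assuming logits shaped `[batch, 1, 2]` like the model outputs printed earlier; both names are our own:

```python
import torch


@torch.no_grad()
def accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of examples whose arg-max class matches the label (illustrative helper)."""
    if logits.dim() == 3:
        logits = logits[:, -1]  # [batch, pos, 2] -> [batch, 2]
    preds = logits.argmax(dim=-1)
    return (preds == labels).float().mean().item()


def majority_baseline(labels: torch.Tensor) -> float:
    """Accuracy of always predicting the most common class (illustrative helper)."""
    counts = torch.bincount(labels.long(), minlength=2)
    return (counts.max() / counts.sum()).item()


# Toy check: predictions are [0, 1, 0, 1] against labels [0, 1, 1, 1]
logits = torch.tensor([[[2.0, -1.0]], [[0.5, 1.5]], [[3.0, 0.0]], [[-1.0, 1.0]]])
labels = torch.tensor([0, 1, 1, 1])
print(accuracy(logits, labels))   # 0.75
print(majority_baseline(labels))  # 0.75
```

In the notebook, `accuracy(model(test_data), test_labels)` can then be compared with `majority_baseline(test_labels)`; a test accuracy at or below the baseline means the model has not learned to recognize primes beyond the training set.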


In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Validation Loss')
plt.title('Training and Validation Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
