In [1]:
import circuitsvis as cv

In [2]:
from IPython import get_ipython
ipython = get_ipython()
# Code to automatically update the HookedTransformer code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [3]:
import plotly.io as pio
pio.renderers.default = "browser" # or use "browser" if you want plots to open with browser


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from einops import rearrange
from tqdm.notebook import tqdm_notebook
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader, Dataset

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

from IPython.display import HTML, display

import transformer_lens as tl
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)
    return px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)
    return px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)
    return px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)

In [4]:
def build_fibonacci_sequences(seq_len, max_start):
    x, y = [], []
    for i in range(0, max_start):
        seq = [i, i+1]
        for j in range(2, seq_len+1):
            seq.append(seq[j-1] + seq[j-2])
        x.append(seq[:seq_len])
        y.append(seq[1:])
    return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)

In [5]:
def build_fibmodp_sequences(seq_len, dataset_size, p):
    
    combined_seq = [0, 1]
    for i in range(2, dataset_size):
        combined_seq.append((combined_seq[-1] + combined_seq[-2]) % p)
    
    x, y = [], []
    for i in range(0, dataset_size-seq_len-1):
        x.append(combined_seq[i:i+seq_len])
        y.append(combined_seq[i+1:i+seq_len+1])
 
    return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)

In [6]:
def build_lookback_addition_sequences(seq_len, max_start, k=1):
    x, y = [], []
    for i in range(0, max_start):
        seq = [i, i+1]
        for j in range(2, seq_len+1):
            seq.append(seq[j-k] + 1)
        x.append(seq[:seq_len])
        y.append(seq[1:])
    return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)

In [7]:
def build_mod3_sequences(seq_len, max_start):
    x, y = [], []
    for i in range(0, max_start):
        seq = [i]
        for j in range(1, seq_len+1):
            if seq[j-1] % 3 == 0:
                seq.append(seq[j-1]+1)
            else:
                seq.append(seq[j-1]+2)
        x.append(seq[:seq_len])
        y.append(seq[1:])
    return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)

In [8]:
class NumSequenceDataset(Dataset):
    def __init__(self, x, y):
        """Initialize the dataset
        Args:
            x (list): list of input sequences
            y (list): list of output sequences
        """
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [9]:
from typing import Tuple

@dataclass
class DataArgs():
    max_seq_len: int
    batch_size: int
    num_workers: int
    dataset_size: int

@dataclass
class TrainingArgs():
    batch_size: int
    epochs: int
    optimizer: torch.optim.Optimizer
    lr: float
    betas: Tuple[float]
    weight_decay: float
    track: bool
    cuda: bool

In [10]:
data_args = DataArgs(
    max_seq_len = 128,
    batch_size = 256,
    num_workers = 4,
    dataset_size = 2048
)

#x, y = build_fibonacci_sequences(args.max_seq_len, 256)
x, y = build_fibmodp_sequences(data_args.max_seq_len, data_args.dataset_size, 100)
#x, y = build_mod3_sequences(args.max_seq_len, 512)
dataset = NumSequenceDataset(x, y)
vocab_size = int(torch.max(y).item())+1
print(vocab_size)

train_set, val_set = torch.utils.data.random_split(dataset, [int(len(dataset)*0.8), int(len(dataset) - int(len(dataset)*0.8))])

trainloader = DataLoader(train_set, batch_size=data_args.batch_size, shuffle=True, num_workers=4)
valloader = DataLoader(val_set, batch_size=data_args.batch_size, shuffle=False, num_workers=4)



100


In [11]:
config = tl.EasyTransformerConfig(
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=256,
    n_layers=1,
    n_ctx=128,
    #use_local_attn=True,
    #attn_types=["local","global"],
    #window_size=2,
    act_fn="solu_ln",
    d_vocab=vocab_size,
    normalization_type="LN",
    seed=23,
)

args = TrainingArgs(
    batch_size = 128,
    epochs = 50,
    optimizer = torch.optim.AdamW,
    lr = 0.001,
    betas = (0.99, 0.999),
    weight_decay = 1,
    track = False,
    cuda = False
)

In [12]:
from typing import Callable

loss_fn = nn.CrossEntropyLoss()

MODEL_FILENAME = "./fibonacci_model.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_transformer(trainloader: DataLoader, args) -> list:
    '''
    Defines a Transformer from our custom modules, and trains it on the algotimic seq dataset.
    '''
    epochs = args.epochs
    
    #model = NumSequenceTransformer(args).to(device).train()
    model = tl.EasyTransformer(config).to(device)
    optimizer = args.optimizer(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    loss_list = []
    accuracy_list = []
    #scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #    optimizer, max_lr=0.01, steps_per_epoch=len(trainloader), epochs=args.epochs)
    
    progress_bar = tqdm_notebook(range(epochs))
    for epoch in progress_bar:
        
        for (x, y) in trainloader:
        #for batch in trainloader:
            
            #x = x.to(device)
            #y = y.to(device)

            #logits = model(x)
            logits = model(x.long(), return_type="logits")
            logits = rearrange(logits, 'B S V -> (B S) V')
            y = rearrange(y, 'B S -> (B S)')

            loss = F.cross_entropy(logits, y)
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            #scheduler.step()

            loss_list.append(loss.item())

            progress_bar.set_description(f"Epoch = {epoch}, Loss = {loss.item():.4f}")

    print(f"Saving model to: {MODEL_FILENAME}")
    torch.save(model, MODEL_FILENAME)
    return model, loss_list, accuracy_list

In [13]:
model, loss_list, accuracy_list = train_transformer(trainloader, args=args)

fig = px.line(y=loss_list, template="simple_white")
fig.update_layout(title="Cross entropy loss on Fibonacci", yaxis_range=[0, max(loss_list)])
fig.show()

Moving model to device:  cpu


  0%|          | 0/50 [00:00<?, ?it/s]

Saving model to: ./fibonacci_model.pt
Opening in existing browser session.


[1216/142421.923285:ERROR:file_io_posix.cc(152)] open /home/curttigges/.config/BraveSoftware/Brave-Browser/Crash Reports/pending/0e0868d6-92df-4383-b1e1-85e8fe993f22.lock: File exists (17)


libva error: /usr/lib/x86_64-linux-gnu/dri/i965_drv_video.so init failed
[38895:38895:0100/000000.159654:ERROR:sandbox_linux.cc(376)] InitializeSandbox() called with multiple threads in process gpu-process.


In [14]:
import sys
sys.path.append('common_modules/')

import sample_methods as s

model = torch.load(MODEL_FILENAME, map_location=torch.device('cpu'))
model.eval()

initial_seq = [84, 81, 65, 46, 11]

output = s.sample_tokens_no_detokenization(
    model, initial_seq, max_tokens_generated=100, max_seq_len=data_args.max_seq_len, 
    temperature=0, top_k=10
)

print(output)

[84, 81, 65, 46, 11, 57, 68, 25, 93, 18, 11, 29, 40, 69, 9, 78, 87, 65, 52, 17, 69, 86, 55, 41, 96, 37, 33, 70, 3, 73, 76, 49, 25, 74, 99, 73, 72, 45, 17, 62, 79, 41, 20, 61, 81, 42, 23, 65, 88, 53, 41, 94, 35, 29, 64, 93, 57, 50, 7, 57, 64, 21, 85, 6, 91, 97, 88, 85, 73, 58, 31, 89, 20, 9, 29, 38, 67, 5, 72, 77, 49, 26, 75, 1, 76, 77, 53, 30, 83, 13, 96, 9, 5, 14, 19, 33, 52, 85, 37, 22, 59, 81, 40, 21, 61]


In [22]:
example_input = dataset.__getitem__(1)[0].long()[:20]
loss, cache = model.run_with_cache(example_input, return_type="loss")

In [23]:
print(type(cache))
attention_pattern = cache["pattern", 0, "attn"]
attention_pattern = attention_pattern.squeeze(0)
print(attention_pattern.shape)

token_labels = [str(int(t)) for t in example_input]

print("Layer 0 Head Attention Patterns:")
cv.attention.attention_heads(tokens=token_labels, attention=attention_pattern)

# Use the following if outputs not displaying in notebook or if you want to save the html file
#html = cv.attention.attention_heads(tokens=list(map(str, example_input)), attention=attention_pattern)
#with open("cv_attn_2.html", "w") as f:
#    f.write(str(html))

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([2, 20, 20])
Layer 0 Head Attention Patterns:
