In [1]:
import sys
import os

# Make the parent of the current working directory importable, adding it to
# the module search path only when it is not already listed.
project_root = os.path.abspath(os.pardir)
if sys.path.count(project_root) == 0:
    sys.path.append(project_root)

In [2]:
import einops
import torch
from torch.optim.lr_scheduler import LambdaLR
from models import *
from branch_datasets import *
from torch.utils.data import DataLoader, Dataset

from torch import Tensor
from typing import Optional, Tuple
from jaxtyping import Float, Int

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.offline as pyo


In [3]:
# Define Config

In [4]:
length = 50 # number of tokens to remember

# Seed torch's global RNG before the model is constructed so that anything
# drawn from it during construction is repeatable after a kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is visible; otherwise fall back to the CPU instead of
# failing on a hard-coded "cuda" device string.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

config = TransformerConfig(
    vocab_size = 128,
    d_model = 192,
    d_ffn = 768,
    h = 1,
    max_len = 3*length + 1,  # matches the `3*length+1` sequence axis named in loss_fn's signature
    attn_only = True,
    layer_norm = False
)

# Number of iterations of the training loop; each iteration consumes one batch.
num_epochs = 4000
model = Transformer(config).to(DEVICE)

In [5]:
# Dataset

In [6]:

# Batch size shared by the training and evaluation loaders.
BATCH_SIZE = 64

# NOTE(review): the first constructor argument is presumably the number of
# sequences to generate — confirm against branch_datasets.InductionHeadsDataset.
train_dataset = InductionHeadsDataset(4096, length, config.vocab_size)
test_dataset = InductionHeadsDataset(1024, length, config.vocab_size)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation data is not shuffled: every `iter(test_dataloader)` then yields the
# batches in the same dataset order, so inspected samples do not change between calls.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
# Train

In [8]:
cross_entropy = torch.nn.CrossEntropyLoss()

def loss_fn(logits: Float[Tensor, "batch seq vocab"], tokens: Int[Tensor, "batch seq"], gaps: Int[Tensor, "batch"]) -> Float[Tensor, ""]:
    """
    Cross-entropy restricted to a per-row window of ``length - 1`` positions.

    Reads the module-level ``length``. For row ``b`` the logits at positions
    ``gaps[b] + length + 1 + k`` (``k = 0 .. length-2``) are scored against
    ``tokens[b, 2 + k]``; every other position is ignored.

    Args:
        logits: model outputs; the last axis is treated as the class axis.
        tokens: input token ids; columns ``2 .. length`` supply the targets.
        gaps: per-row offset that shifts where the scored window starts.

    Returns:
        Scalar tensor: ``CrossEntropyLoss`` (default mean reduction) over all
        selected positions of all rows.
    """
    device = logits.device
    # First scored position of each row.
    start_indices = gaps + length + 1
    # (batch, length-1): absolute sequence positions of the scored logits.
    positions = start_indices[:, None] + torch.arange(length - 1, device=device)
    # (batch, 1): row selector that broadcasts against `positions`.
    rows = torch.arange(logits.shape[0], device=device)[:, None]
    # Indexing the first two axes with broadcast index tensors keeps the class
    # axis whole, giving (batch, length-1, vocab).
    selected_logits = logits[rows, positions]
    targets = tokens[:, 2:length + 1].reshape(-1)
    return cross_entropy(selected_logits.reshape(-1, logits.size(-1)), targets)

In [9]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=0.01) # Change weight decay if necessary
# Linear warm-up: the learning-rate multiplier is i/100 for the first 100
# scheduler steps and 1 afterwards.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i/100, 1))
losses = []

# The test cell switches the model to eval mode; restore training mode so that
# re-running this cell afterwards still trains with train-time behaviour.
model.train()
# Follow whatever device the model already lives on instead of assuming CUDA.
train_device = next(model.parameters()).device
# Keep one iterator alive and only rebuild it when the loader is exhausted,
# rather than constructing a fresh DataLoader iterator on every step.
batch_iter = iter(train_dataloader)

for epoch in range(num_epochs):
    try:
        tokens, gaps = next(batch_iter)
    except StopIteration:
        batch_iter = iter(train_dataloader)
        tokens, gaps = next(batch_iter)
    tokens, gaps = tokens.to(train_device), gaps.to(train_device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens, gaps)
    # Read the scalar once; it is reused for both bookkeeping and logging.
    loss_value = loss.item()
    losses.append(loss_value)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    if epoch % 10 == 0 or epoch < 20:
        print(f'Epoch {epoch}: {loss_value}')
    

Epoch 0: 5.529040336608887
Epoch 1: 5.509031772613525
Epoch 2: 5.507711410522461
Epoch 3: 5.451091766357422
Epoch 4: 5.459390640258789
Epoch 5: 5.453052043914795
Epoch 6: 5.432517051696777
Epoch 7: 5.381065368652344
Epoch 8: 5.310576915740967
Epoch 9: 5.311708450317383
Epoch 10: 5.24062967300415
Epoch 11: 5.228074550628662
Epoch 12: 5.199084758758545
Epoch 13: 5.1618523597717285
Epoch 14: 5.124350547790527
Epoch 15: 5.108597278594971
Epoch 16: 5.068211078643799
Epoch 17: 5.064356327056885
Epoch 18: 5.027658462524414
Epoch 19: 5.057249069213867
Epoch 20: 5.021362781524658
Epoch 30: 4.9785566329956055
Epoch 40: 4.9509992599487305
Epoch 50: 4.932194709777832
Epoch 60: 4.930891036987305
Epoch 70: 4.905892848968506
Epoch 80: 4.896651268005371
Epoch 90: 4.898183345794678
Epoch 100: 4.891670227050781
Epoch 110: 4.8934006690979
Epoch 120: 4.883333206176758
Epoch 130: 4.882626056671143
Epoch 140: 4.8681321144104
Epoch 150: 4.8742265701293945
Epoch 160: 4.868419647216797
Epoch 170: 4.86420440673

In [10]:
# Plot the loss curve with plotly library

pyo.init_notebook_mode(connected=True)


fig = go.Figure()
# Add the loss series exactly once. Adding a second, unstyled trace of the same
# data would draw the curve twice and put an extra unnamed entry in the legend.
fig.add_trace(go.Scatter(
    x=list(range(len(losses))),
    y=losses,
    mode='lines+markers',
    name='Loss',
    hoverinfo='x+y',
    line=dict(color='blue', width=2)
))

# Add titles and labels
fig.update_layout(
    title='Training Loss Over Epochs',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    hovermode='x unified'
)

# Show the plot
pyo.iplot(fig)

In [11]:
# Test

In [12]:
# generate some tokens from the test dataloader
model.eval()
tokens, gaps = next(iter(test_dataloader))
# Keep only the first sequence, and slice before the transfer so a single row
# (not the whole batch) is copied to the model's device.
tokens = tokens[:1].to(next(model.parameters()).device)
# Pure inference: no autograd graph is needed for an argmax over the logits.
with torch.no_grad():
    logits = model(tokens)
# Most likely next token at every position.
generated_tokens = torch.argmax(logits, dim=-1)
print(tokens)
print(f'Generated tokens: {generated_tokens}')

tensor([[  0,  59, 127, 127,  32,  84,  29, 105,  45,  25,  99,  41,  68,  18,
          57,  59,  84,  48,  38,  26,  85, 100,  86,  86,  29,  38,  14,  22,
          39,  95,  91,  44,  56,  74,   1,  54,  12,  66,  80,  46,  72,  61,
         126, 117,  63,  32,  27,  24,  44,  47, 123,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,  59, 127, 127,  32,  84,
          29, 105,  45,  25,  99,  41,  68,  18,  57,  59,  84,  48,  38,  26,
          85, 100,  86,  86,  29,  38,  14,  22,  39,  95,  91,  44,  56,  74,
           1,  54,  12,  66,  80,  46,  72,  61, 126, 117,  63,  32,  27,  24,
          44,  47, 123,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]],
       device='cuda:0')
Generated tokens: tensor([[  0,  59,   0, 127,   0, 127,   0, 127, 127, 127, 127,  84,  84, 127,
         127, 127, 127,

In [13]:
# Visualize Attention

In [14]:
# Print the module tree so the architecture that was actually built can be inspected.
print(model)

Transformer(
  (W_E): Embedding(128, 192)
  (layers): ModuleList(
    (0-1): 2 x DecoderLayer(
      (attn): SubLayer(
        (layer_fn): MultiHeadAttention(
          (linears): Linear(in_features=192, out_features=576, bias=True)
          (O): Linear(in_features=192, out_features=192, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pos): PositionalEncoding()
  (W_U): Linear(in_features=192, out_features=128, bias=True)
)


In [15]:
# The embedding / unembedding entries do not depend on the head: read them once.
W_E = model.cache["W_E"]
W_U = model.cache["W_U"]
for head in range(config.h):
    # view OV circuit of the first layer for this head
    W_OV = model.layers[0].attn.layer_fn.cache['W_OV'][head]
    ov_circuit = W_E @ W_OV.T @ W_U.T # (vocab_size, vocab_size)
    # detach()/cpu() are no-ops for a grad-free CPU tensor, and otherwise keep
    # .numpy() from raising on a CUDA or grad-tracking tensor.
    plotly_fig = go.Figure(data=go.Heatmap(z=ov_circuit.detach().cpu().numpy()))
    plotly_fig.update_layout(title=f'OV Circuit {head}')
    pyo.iplot(plotly_fig)
    

In [16]:
def mask_scores(attn_scores: Float[Tensor, "... query key"]) -> Float[Tensor, "... query key"]:
    '''Causally mask attention scores so that no position attends to later positions.

    Over the last two dimensions the lower triangle (including the diagonal) is
    kept unchanged; every entry above the diagonal is replaced by -1.0e6 so that
    it receives near-zero weight in a subsequent softmax. Any leading dimensions
    are masked independently with the same pattern.
    '''
    # True where key index <= query index.
    mask = torch.tril(torch.ones_like(attn_scores)).bool()
    # Large negative fill value, created directly on the scores' device.
    neg_inf = torch.tensor(-1.0e6, device=attn_scores.device)
    return torch.where(mask, attn_scores, neg_inf)


# The cached positional matrix is the same for every head: read it once.
W_pos = model.cache["W_pos"]
for head in range(config.h):
    # view position-by-position QK circuit of the first layer for this head
    W_QK = model.layers[0].attn.layer_fn.cache['W_QK'][head]
    pos_by_pos_scores = W_pos @ W_QK @ W_pos.T # square: one row and one column per row of W_pos
    masked_scaled = mask_scores(pos_by_pos_scores / ((config.d_head) ** 0.5))
    pos_by_pos_pattern = torch.softmax(masked_scaled, dim=-1)
    # detach()/cpu() are no-ops for a grad-free CPU tensor, and otherwise keep
    # .numpy() from raising on a CUDA or grad-tracking tensor.
    plotly_fig = go.Figure(data=go.Heatmap(z=pos_by_pos_pattern.detach().cpu().numpy()))
    plotly_fig.update_layout(title=f'QK Circuit {head}')
    pyo.iplot(plotly_fig)

In [17]:
# One heatmap per (layer, head), layers in order and heads within each layer.
# The layer index is part of the title so figures from different layers can be
# told apart.
for layer_idx, layer in enumerate(model.layers):
    attn_cache = layer.attn.layer_fn.cache['attn']
    for head in range(config.h):
        # Index 0 on the first axis selects the first cached entry.
        # NOTE(review): presumably a (query position, key position) pattern for
        # the first sequence of the last forward pass — confirm against the model code.
        head_attn = attn_cache[0, head]
        # detach()/cpu() are no-ops for a grad-free CPU tensor, and otherwise keep
        # .numpy() from raising on a CUDA or grad-tracking tensor.
        plotly_fig = go.Figure(data=go.Heatmap(z=head_attn.detach().cpu().numpy()))
        plotly_fig.update_layout(title=f'Layer {layer_idx} attention head {head}')
        pyo.iplot(plotly_fig)