In [1]:
import einops
import torch
from torch.optim.lr_scheduler import LambdaLR
from transformer.transformer import Transformer
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
import plotly.graph_objects as go


In [2]:
# Define Config

In [3]:
length = 50 # number of tokens to remember
vocab_size=128
d_model=192
d_ffn=768
h=1 # attention heads per layer
n=2 # number of layers
max_len = 3*length + 1 # every dataset row is 1 + 3*length tokens long
num_epochs = 4000 # training steps; each step uses a single mini-batch
seed = 0
# Seed the global torch RNG (CPU and CUDA) so model init and the randomly
# generated datasets in later cells are reproducible on Restart & Run All.
torch.manual_seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(vocab_size, d_model, d_ffn, h, n, max_len, attn_only=True, layer_norm=False).to(device)

In [4]:
# Dataset

In [5]:
def induction_heads_dataset(batch, length, vocab_size):
    """
    Generate a batch of induction-heads sequences.

    Each row has ``1 + 3 * length`` tokens laid out as::

        [0, s_1 .. s_length, 0 (x gap), s_1 .. s_length, 0 padding ...]

    i.e. a leading 0, ``length`` random tokens, a run of ``gap`` zeros, the
    same ``length`` tokens repeated, then zero padding up to the fixed row
    length. ``gap`` is drawn per row uniformly from ``[1, length - 1]``, so
    ``length`` must be at least 2.

    Note: random tokens are drawn from ``[0, vocab_size)``, so a sequence token
    can also be 0, the same value used for the prefix, the gap and the padding.

    Args:
        batch: number of rows to generate.
        length: number of random tokens per row (the sequence to remember).
        vocab_size: tokens are drawn uniformly from ``range(vocab_size)``.

    Returns:
        (dataset, random_gap): a ``(batch, 1 + 3 * length)`` long tensor and the
        ``(batch,)`` tensor of per-row gap lengths.
    """
    dataset = torch.zeros(batch, 1 + length * 3, dtype=torch.long)

    random_sequences = torch.randint(0, vocab_size, (batch, length))
    dataset[:, 1:length + 1] = random_sequences
    random_gap = torch.randint(1, length, (batch,))
    # Write all second copies in one advanced-indexing assignment: row i gets
    # random_sequences[i] at columns length+1+gap[i] .. 2*length+gap[i].
    columns = (length + 1) + random_gap[:, None] + torch.arange(length)[None, :]
    rows = torch.arange(batch)[:, None]
    dataset[rows, columns] = random_sequences
    return dataset, random_gap


In [6]:
## Test dataset function
demo_length = 12
demo_data, demo_gaps = induction_heads_dataset(4, demo_length, 8192)
print((demo_data, demo_gaps))
# Sanity check the layout instead of eyeballing it: each row's second copy,
# which starts right after its gap, must equal the first copy.
for row, gap in zip(demo_data, demo_gaps.tolist()):
    start = demo_length + 1 + gap
    assert torch.equal(row[start:start + demo_length], row[1:demo_length + 1])

(tensor([[   0,  810, 4136, 1454, 8114, 8092, 1223, 7292, 5165, 1316, 3954, 2897,
           47,    0,    0,    0,    0,  810, 4136, 1454, 8114, 8092, 1223, 7292,
         5165, 1316, 3954, 2897,   47,    0,    0,    0,    0,    0,    0,    0,
            0],
        [   0, 6065, 3311, 6222,  948, 5410, 1791, 6058, 5960, 7704, 7157, 3339,
         6223,    0,    0,    0,    0, 6065, 3311, 6222,  948, 5410, 1791, 6058,
         5960, 7704, 7157, 3339, 6223,    0,    0,    0,    0,    0,    0,    0,
            0],
        [   0, 5401, 6625, 4048, 7084, 7916,  361, 7255, 1984, 6742, 4987, 4416,
         4313,    0,    0,    0,    0,    0,    0,    0,    0, 5401, 6625, 4048,
         7084, 7916,  361, 7255, 1984, 6742, 4987, 4416, 4313,    0,    0,    0,
            0],
        [   0, 6735, 3341, 3852, 4580, 3229, 8120, 7330, 3594,  114, 8172, 4364,
         3691,    0,    0, 6735, 3341, 3852, 4580, 3229, 8120, 7330, 3594,  114,
         8172, 4364, 3691,    0,    0,    0,    0,    0,    

In [7]:
class InductionHeadsDataset(Dataset):
    """Map-style dataset over a pre-generated set of induction-heads sequences.

    Each item is ``(tokens, gap)``: a row of ``1 + 3 * length`` token ids and
    the number of zeros separating the two copies of the random sequence.
    """

    def __init__(self, num_samples, length, vocab_size):
        """
        num_samples: number of samples in the dataset
        length: number of random tokens to remember; each sample row holds
            1 + 3*length tokens (a leading 0, the sequence, a gap of zeros,
            the repeated sequence, and zero padding)
        vocab_size: number of possible tokens to choose from
        """
        self.num_samples = num_samples
        self.length = length
        self.vocab_size = vocab_size
        # All samples are generated once here, so the dataset content is fixed
        # for its lifetime (every pass over it sees the same sequences).
        self.data, self.gaps = induction_heads_dataset(num_samples, length, vocab_size)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(tokens, gap)`` for sample ``idx``."""
        return self.data[idx], self.gaps[idx]

batch_size = 64
train_dataset = InductionHeadsDataset(4096, length, vocab_size)
test_dataset = InductionHeadsDataset(1024, length, vocab_size)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# No shuffling for evaluation: the Test cell takes the first batch of this
# loader, so keep that batch the same on every run.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
# Train

In [9]:
cross_entropy = torch.nn.CrossEntropyLoss()

def loss_fn(logits, tokens, gaps, seq_len=None):
    """
    Next-token cross-entropy on the repeated (second) copy of each sequence.

    In each row the second copy starts at position ``start = seq_len + 1 + gap``.
    The logits at positions ``start .. start + seq_len - 2`` are scored against
    the tokens that follow them, i.e. the remembered tokens ``s_2 .. s_length``,
    which equal ``tokens[:, 2:seq_len + 1]``. The first repeated token is
    skipped because it follows a gap zero and cannot be predicted by copying.

    Args:
        logits: (batch, positions, vocab) model outputs.
        tokens: (batch, positions) inputs laid out as by induction_heads_dataset.
        gaps: (batch,) gap length of each row.
        seq_len: number of remembered tokens; defaults to the notebook-level
            ``length`` so existing three-argument calls behave as before.

    Returns:
        Scalar mean cross-entropy loss.
    """
    if seq_len is None:
        seq_len = length
    device = logits.device
    start = gaps.to(device) + seq_len + 1
    positions = start[:, None] + torch.arange(seq_len - 1, device=device)  # (batch, seq_len - 1)
    batch_idx = torch.arange(logits.shape[0], device=device)[:, None]
    selected = logits[batch_idx, positions]  # (batch, seq_len - 1, vocab)
    targets = tokens[:, 2:seq_len + 1].to(device)
    return cross_entropy(selected.reshape(-1, logits.size(-1)), targets.reshape(-1))

In [11]:
# Each "epoch" below is one optimizer step on a single randomly drawn mini-batch.
# NOTE(review): the logged loss in the captured run sits near ln(128) ~= 4.85,
# i.e. chance level for vocab_size=128; the model had not yet learned the task.
device = next(model.parameters()).device
model.train()  # the Test cell calls model.eval(); restore training mode on re-runs
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=0.01) # Change weight decay if necessary
# Linear warm-up over the first 100 steps. (i + 1) keeps the very first step
# from running with a learning rate of exactly zero.
scheduler = LambdaLR(optimizer, lambda i: min((i + 1) / 100, 1))
losses = []

for epoch in range(num_epochs):
    # A fresh iterator per step makes the shuffled loader yield a new random batch.
    tokens, gaps = next(iter(train_dataloader))
    tokens, gaps = tokens.to(device), gaps.to(device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens, gaps)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    loss_value = loss.item()
    losses.append(loss_value)
    if epoch % 10 == 0 or epoch < 20:
        print(f'Epoch {epoch}: {loss_value}')
    

Epoch 0: 5.310470104217529
Epoch 1: 5.315279006958008
Epoch 2: 5.293283462524414
Epoch 3: 5.277859210968018
Epoch 4: 5.2813849449157715
Epoch 5: 5.282442569732666
Epoch 6: 5.206539154052734
Epoch 7: 5.238861560821533
Epoch 8: 5.168002605438232
Epoch 9: 5.18679141998291
Epoch 10: 5.131943225860596
Epoch 11: 5.102895736694336
Epoch 12: 5.1013407707214355
Epoch 13: 5.080920219421387
Epoch 14: 5.060436725616455
Epoch 15: 5.027204990386963
Epoch 16: 5.038907527923584
Epoch 17: 5.034437656402588
Epoch 18: 5.017813682556152
Epoch 19: 4.998685359954834
Epoch 20: 4.982463836669922
Epoch 30: 4.967130184173584
Epoch 40: 4.9450364112854
Epoch 50: 4.929076194763184
Epoch 60: 4.92122745513916
Epoch 70: 4.915471076965332
Epoch 80: 4.89203405380249
Epoch 90: 4.894589900970459
Epoch 100: 4.888542175292969
Epoch 110: 4.881584644317627
Epoch 120: 4.883635997772217
Epoch 130: 4.873321056365967
Epoch 140: 4.869999408721924
Epoch 150: 4.861929893493652
Epoch 160: 4.867027282714844
Epoch 170: 4.8703508377075

KeyboardInterrupt: 

In [12]:
# Plot the loss curve with plotly library
import plotly.offline as pyo
import nbformat  # NOTE(review): not referenced below; presumably imported so a missing nbformat (used by plotly's notebook rendering) fails early — confirm

pyo.init_notebook_mode(connected=True)


fig = go.Figure()
# Exactly one trace of the loss, indexed by training step (one entry per
# optimizer step appended by the training loop).
fig.add_trace(go.Scatter(
    x=list(range(len(losses))),
    y=losses,
    mode='lines+markers',
    name='Loss',
    hoverinfo='x+y',
    line=dict(color='blue', width=2)
))

# Add titles and labels
fig.update_layout(
    title='Training Loss Over Epochs',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    hovermode='x unified'
)

# Show the plot
pyo.iplot(fig)

In [None]:
# Test

In [None]:
# generate some tokens from the test dataloader
model.eval()
device = next(model.parameters()).device
tokens, gaps = next(iter(test_dataloader))
tokens = tokens.to(device)[:1]
with torch.no_grad():  # inference only: don't build an autograd graph
    logits = model(tokens)
# logits[:, p] is the prediction for position p + 1 (see loss_fn), so
# generated_tokens is shifted one position left relative to `tokens`.
generated_tokens = torch.argmax(logits, dim=-1)
print(tokens)
print(f'Generated tokens: {generated_tokens}')

tensor([[  0,   0,  98, 112, 113,  21,   1,  58,  31,  65,  87,  87,  64,  82,
          69,  74, 126,  85, 124,  58,   1,  13,  38,  39,  64,  29,  78,  95,
          38, 111, 121,  43, 113,  31,  84,  11,  77,  26, 126, 118,  87,  21,
         115,  16,  84, 117,  14,  50,  89,   9,  96,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,  98, 112, 113,  21,   1,  58,  31,  65,  87,  87,  64,  82,
          69,  74, 126,  85, 124,  58,   1,  13,  38,  39,  64,  29,  78,  95,
          38, 111, 121,  43, 113,  31,  84,  11,  77,  26, 126, 118,  87,  21,
         115,  16,  84, 117,  14,  50,  89,   9,  96,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]],
       device='cuda:0')
Generated tokens: tensor([[  0,   0,  98,   0, 112,   0,   0,   1,   1, 113,   0,  98,  98,   0,
          98,   0,   0,

In [None]:
# Visualize Attention

In [None]:
print(f"{model}")  # module architecture summary: embeddings, decoder layers, unembedding

Transformer(
  (W_E): Embedding(128, 192)
  (layers): ModuleList(
    (0-1): 2 x DecoderLayer(
      (attn): SubLayer(
        (layer_fn): MultiHeadAttention(
          (linears): Linear(in_features=192, out_features=576, bias=True)
          (O): Linear(in_features=192, out_features=192, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pos): PositionalEncoding()
  (W_U): Linear(in_features=192, out_features=128, bias=True)
)


In [None]:
# OV circuits for every layer (layer 1 is where an induction head would sit).
for layer_idx, layer in enumerate(model.layers):
    ov_weights = layer.attn.layer_fn.cache['W_OV']
    for head in range(h):
        # Full OV circuit through this head: W_E @ W_OV[head]^T @ W_U^T.
        # Expected to be (vocab_size, vocab_size) — confirm against the cached shapes.
        ov_circuit = model.cache["W_E"] @ ov_weights[head].T @ model.cache["W_U"].T
        # Cached tensors may live on the GPU and/or require grad; .numpy() rejects both.
        plotly_fig = go.Figure(data=go.Heatmap(z=ov_circuit.detach().cpu().numpy()))
        plotly_fig.update_layout(title=f'OV Circuit layer {layer_idx} head {head}')
        pyo.iplot(plotly_fig)
    

In [None]:
def mask_scores(attn_scores):
    '''Apply a causal mask to attention scores.

    Over the last two dimensions, entries above the diagonal (row i, column j
    with j > i, i.e. a token attending to a *later* token) are replaced by
    -1e6, so after a softmax over the last dimension each row only attends to
    itself and earlier positions. Entries on or below the diagonal are
    returned unchanged.
    '''
    causal = torch.tril(torch.ones_like(attn_scores)).bool()
    # A large finite negative (rather than -inf) keeps the softmax NaN-free
    # even if an entire row were masked.
    neg_inf = torch.tensor(-1.0e6, device=attn_scores.device)
    return torch.where(causal, attn_scores, neg_inf)


# Positional QK circuit of layer 0 (loop-invariant lookups hoisted).
W_pos = model.cache["W_pos"]
qk_weights = model.layers[0].attn.layer_fn.cache['W_QK']
for head in range(h):
    # view QK circuit: W_pos @ W_QK[head] @ W_pos^T is a position-by-position
    # score matrix — presumably (max_len, max_len); confirm.
    pos_by_pos_scores = W_pos @ qk_weights[head] @ W_pos.T
    # NOTE(review): assumes the model scales scores by sqrt(d_model // h) — confirm.
    masked_scaled = mask_scores(pos_by_pos_scores / ((d_model // h) ** 0.5))
    # Attention pattern implied by the positional component alone.
    pos_by_pos_pattern = torch.softmax(masked_scaled, dim=-1)
    # Cached tensors may live on the GPU and/or require grad; .numpy() rejects both.
    plotly_fig = go.Figure(data=go.Heatmap(z=pos_by_pos_pattern.detach().cpu().numpy()))
    plotly_fig.update_layout(title=f'QK Circuit {head}')
    pyo.iplot(plotly_fig)

In [None]:
# Attention patterns cached by the most recent forward pass, batch element 0,
# for every layer and head (replaces one copy-pasted loop per layer).
for layer_idx, layer in enumerate(model.layers):
    for head in range(h):
        # Presumably (query position, key position) — confirm against the model.
        head_attn = layer.attn.layer_fn.cache['attn'][0, head]
        # Cached tensors may live on the GPU and/or require grad; .numpy() rejects both.
        plotly_fig = go.Figure(data=go.Heatmap(z=head_attn.detach().cpu().numpy()))
        plotly_fig.update_layout(title=f'Layer {layer_idx} attention head {head}')
        pyo.iplot(plotly_fig)