In [1]:
import einops
import torch
from torch.optim.lr_scheduler import LambdaLR
from transformer.transformer import Transformer
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
import plotly.graph_objects as go


In [2]:
# Define Config

In [3]:
length = 50 # number of tokens to remember
vocab_size=128
d_model=192
d_ffn=768
h=1
n=2
max_len = 2*length + 1
num_epochs = 2000
model = Transformer(vocab_size, d_model, d_ffn, h, n, max_len, attn_only=True, layer_norm=False).to("cuda")

In [4]:
# Dataset

In [29]:
def induction_heads_dataset(batch, length, vocab_size):
    """
    Generate the induction heads dataset. Format is length number of randomly generated
    tokens, then with the same tokens repeated. 
    """
    dataset = torch.zeros(batch, 1 + length*2).long
    random_sequences = torch.randint(0, vocab_size, (batch, length))
    dataset[:, 1:] = einops.repeat(random_sequences, "b s -> b (t s)", t=2)  
    return dataset


In [30]:
## Test dataset function
print(induction_heads_dataset(4, 12, 8192))

RuntimeError: shape mismatch: value tensor of shape [48, 12] cannot be broadcast to indexing result of shape [48]

In [7]:
class InductionHeadsDataset(Dataset):
    def __init__(self, num_samples, length, vocab_size):
        """
        num_samples: number of samples in the dataset
        length: number of tokens to remember (actual segment length should be 2*length)
        vocab_size: number of possible tokens to choose from
        """
        self.num_samples = num_samples
        self.length = length
        self.vocab_size = vocab_size
        self.data = induction_heads_dataset(num_samples, length, vocab_size)

    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.data[idx]

train_dataset = InductionHeadsDataset(4096, length, vocab_size)
test_dataset = InductionHeadsDataset(1024, length, vocab_size)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [8]:
# Train

In [9]:
cross_entropy = torch.nn.CrossEntropyLoss()

def loss_fn(logits, tokens):
    logits = logits[:, -length-1:-1, :]
    tokens = tokens[:, -length:]
    tokens = tokens.reshape(-1)
    logits = logits.reshape(-1, logits.size(-1))
    loss = cross_entropy(logits, tokens)
    return loss

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i/100, 1.))
losses = []

for epoch in range(num_epochs):
    tokens = next(iter(train_dataloader))
    tokens = tokens.cuda()
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    if epoch % 10 == 0 or epoch < 20:
        print(f'Epoch {epoch}: {loss.item()}')
    

Epoch 0: 5.408360004425049
Epoch 1: 5.387166976928711
Epoch 2: 5.407962799072266
Epoch 3: 5.386505126953125
Epoch 4: 5.357471942901611
Epoch 5: 5.366971015930176
Epoch 6: 5.315518379211426
Epoch 7: 5.311689376831055
Epoch 8: 5.246866226196289
Epoch 9: 5.229741096496582
Epoch 10: 5.192896842956543
Epoch 11: 5.154097080230713
Epoch 12: 5.137121677398682
Epoch 13: 5.110810279846191
Epoch 14: 5.052134037017822
Epoch 15: 5.063539981842041
Epoch 16: 5.038213729858398
Epoch 17: 5.037948131561279
Epoch 18: 4.998725414276123
Epoch 19: 5.007537364959717
Epoch 20: 4.994324207305908
Epoch 30: 4.959488868713379
Epoch 40: 4.940614700317383
Epoch 50: 4.92567777633667
Epoch 60: 4.905764579772949
Epoch 70: 4.890811920166016
Epoch 80: 4.889928340911865
Epoch 90: 4.889792442321777
Epoch 100: 4.882784366607666
Epoch 110: 4.876563549041748
Epoch 120: 4.865502834320068
Epoch 130: 4.861303806304932
Epoch 140: 4.847307205200195
Epoch 150: 4.837465763092041
Epoch 160: 4.80520486831665
Epoch 170: 4.731040954589

In [11]:
# Plot the loss curve with plotly library
import plotly.offline as pyo
import nbformat

pyo.init_notebook_mode(connected=True)


fig = go.Figure()
fig.add_trace(go.Scatter(y=losses))
fig.add_trace(go.Scatter(
    x=list(range(len(losses))),
    y=losses,
    mode='lines+markers',
    name='Loss',
    hoverinfo='x+y',
    line=dict(color='blue', width=2)
))

# Add titles and labels
fig.update_layout(
    title='Training Loss Over Epochs',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    hovermode='x unified'
)

# Show the plot
pyo.iplot(fig)

In [12]:
# Test

In [13]:
# generate some tokens from the test dataloader
model.eval()
tokens = next(iter(test_dataloader))
tokens = tokens.cuda()[:4]
logits = model(tokens)
generated_tokens = torch.argmax(logits, dim=-1)
print(tokens)
print(f'Generated tokens: {generated_tokens}')

tensor([[  0,  94, 122,  30,  22, 109,  69, 111,  97, 115,  85,  84,   1,  62,
         118,  96,  58,  17, 116,  35,  80,  44,  63,  76,  14, 106, 106,  94,
          91,  87,  32, 120,  33,  27,  12,  26,  69,  88,  91,  82,   2, 117,
          82,  70,  47, 110,  56, 100,  58,  92,  54,  94, 122,  30,  22, 109,
          69, 111,  97, 115,  85,  84,   1,  62, 118,  96,  58,  17, 116,  35,
          80,  44,  63,  76,  14, 106, 106,  94,  91,  87,  32, 120,  33,  27,
          12,  26,  69,  88,  91,  82,   2, 117,  82,  70,  47, 110,  56, 100,
          58,  92,  54],
        [  0,  86,  35,  80, 114,   2,  37, 124,  36,  18, 116,  71,  19,   7,
         119,  82,   1,  27,  60, 124,  98, 120,  20, 112,  73,  55,  33,  36,
         121,  36,  25,  28,  38, 122,  46,  94, 114,  94, 109,  53,  67,  40,
           0, 103,  14,  99,  32,  97,   6,  43,  69,  86,  35,  80, 114,   2,
          37, 124,  36,  18, 116,  71,  19,   7, 119,  82,   1,  27,  60, 124,
          98, 120,  20, 112

In [14]:
# Visualize Attention

In [15]:
print(model)

Transformer(
  (W_E): Embedding(128, 192)
  (layers): ModuleList(
    (0-1): 2 x DecoderLayer(
      (attn): SubLayer(
        (layer_fn): MultiHeadAttention(
          (linears): Linear(in_features=192, out_features=576, bias=True)
          (O): Linear(in_features=192, out_features=192, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pos): PositionalEncoding()
  (W_U): Linear(in_features=192, out_features=128, bias=True)
)


In [24]:
for head in range(h):
    # view OV circuit
    ov_circuit = model.cache["W_E"] @ model.layers[0].attn.layer_fn.cache['W_OV'][head].T @ model.cache["W_U"].T # (vocab_size, vocab_size)
    plotly_fig = go.Figure(data=go.Heatmap(z=ov_circuit.numpy()))
    plotly_fig.update_layout(title=f'OV Circuit {head}')
    pyo.iplot(plotly_fig)
    

In [26]:
def mask_scores(attn_scores):
    '''Mask the attention scores so that tokens don't attend to previous tokens.'''
    mask = torch.tril(torch.ones_like(attn_scores)).bool()
    neg_inf = torch.tensor(-1.0e6).to(attn_scores.device)
    masked_attn_scores = torch.where(mask, attn_scores, neg_inf)
    return masked_attn_scores


for head in range(h):
    # view OV circuit
    W_pos = model.cache["W_pos"]
    pos_by_pos_scores = W_pos @ model.layers[0].attn.layer_fn.cache['W_QK'][head] @ W_pos.T # (vocab_size, vocab_size)
    masked_scaled = mask_scores(pos_by_pos_scores / ((d_model // h) ** 0.5))
    pos_by_pos_pattern = torch.softmax(masked_scaled, dim=-1)
    plotly_fig = go.Figure(data=go.Heatmap(z=pos_by_pos_pattern.numpy()))
    plotly_fig.update_layout(title=f'QK Circuit {head}')
    pyo.iplot(plotly_fig)

In [19]:
for head in range(h):
    # view layer 0
    head_attn = model.layers[0].attn.layer_fn.cache['attn'][0, head] # (vocab_size, vocab_size)
    plotly_fig = go.Figure(data=go.Heatmap(z=head_attn.numpy()))
    plotly_fig.update_layout(title=f'Attention head {head}')
    pyo.iplot(plotly_fig)

for head in range(h):
    # view layer 1
    head_attn = model.layers[1].attn.layer_fn.cache['attn'][0, head] # (vocab_size, vocab_size)
    plotly_fig = go.Figure(data=go.Heatmap(z=head_attn.numpy()))
    plotly_fig.update_layout(title=f'Attention head {head}')
    pyo.iplot(plotly_fig)