In [1]:
import torch
from torch.optim.lr_scheduler import LambdaLR
from transformer import Transformer
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
import plotly.graph_objects as go


In [2]:
# Define Config

In [3]:
length = 64 # number of tokens to remember
vocab_size=128
d_model=192
d_ffn=768
h=12
n=2
max_len=256
num_epochs = 1000
model = Transformer(vocab_size, d_model, d_ffn, h, n, max_len).to("cuda")

In [4]:
# Dataset

In [5]:
def induction_heads_dataset(batch, length, vocab_size):
    """
    Generate the induction heads dataset. Format is length number of randomly generated
    tokens, then with the same tokens repeated. 
    """
    
    random_sequences = torch.randint(0, vocab_size, (batch, length))
    repeated_sequences = torch.cat([random_sequences, random_sequences], dim=-1)    
    return repeated_sequences


In [6]:
## Test dataset function
print(induction_heads_dataset(4, 12, 8192))

tensor([[4862, 2224, 8032, 5612, 4912, 1079, 2975, 1175, 3263,  289, 7419, 6808,
         4862, 2224, 8032, 5612, 4912, 1079, 2975, 1175, 3263,  289, 7419, 6808],
        [7120, 2391, 4075, 2248, 3271,  849, 5252, 3925, 6822,  340, 5049, 2800,
         7120, 2391, 4075, 2248, 3271,  849, 5252, 3925, 6822,  340, 5049, 2800],
        [ 196, 8091, 2317, 6543, 1541, 4006, 3657, 4023, 8171, 4779, 1143, 1806,
          196, 8091, 2317, 6543, 1541, 4006, 3657, 4023, 8171, 4779, 1143, 1806],
        [5205, 3532, 2683, 5142, 3756, 1867, 1309,   19, 6456, 4316, 4846, 6364,
         5205, 3532, 2683, 5142, 3756, 1867, 1309,   19, 6456, 4316, 4846, 6364]])


In [7]:
class InductionHeadsDataset(Dataset):
    def __init__(self, num_samples, length, vocab_size):
        """
        num_samples: number of samples in the dataset
        length: number of tokens to remember (actual segment length should be 2*length)
        vocab_size: number of possible tokens to choose from
        """
        self.num_samples = num_samples
        self.length = length
        self.vocab_size = vocab_size
        self.data = induction_heads_dataset(num_samples, length, vocab_size)

    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.data[idx]

train_dataset = InductionHeadsDataset(4096, length, vocab_size)
test_dataset = InductionHeadsDataset(1024, length, vocab_size)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [8]:
# Train

In [9]:
cross_entropy = torch.nn.CrossEntropyLoss()

def loss_fn(logits, tokens):
    logits = logits[:, -length-1:-1, :]
    tokens = tokens[:, -length:]
    tokens = tokens.reshape(-1)
    logits = logits.reshape(-1, logits.size(-1))
    loss = cross_entropy(logits, tokens)
    return loss

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i/100, 1.))
losses = []

for epoch in range(num_epochs):
    tokens = next(iter(train_dataloader))
    tokens = tokens.cuda()
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    if epoch % 10 == 0 or epoch < 20:
        print(f'Epoch {epoch}: {loss.item()}')
    

Epoch 0: 5.479373455047607
Epoch 1: 5.464622497558594
Epoch 2: 5.451710224151611
Epoch 3: 5.413358211517334
Epoch 4: 5.3861260414123535
Epoch 5: 5.346640586853027
Epoch 6: 5.34651517868042
Epoch 7: 5.285891532897949
Epoch 8: 5.2473225593566895
Epoch 9: 5.2092156410217285
Epoch 10: 5.166035175323486
Epoch 11: 5.131063461303711
Epoch 12: 5.09135103225708
Epoch 13: 5.064133167266846
Epoch 14: 5.045584201812744
Epoch 15: 5.031469821929932
Epoch 16: 4.988376617431641
Epoch 17: 4.97529411315918
Epoch 18: 4.946115493774414
Epoch 19: 4.933976173400879
Epoch 20: 4.915812969207764
Epoch 30: 4.8133864402771
Epoch 40: 4.763978958129883
Epoch 50: 4.537522792816162
Epoch 60: 3.1122524738311768
Epoch 70: 0.5826408267021179
Epoch 80: 0.05032998323440552
Epoch 90: 0.01411201898008585
Epoch 100: 0.006703242193907499
Epoch 110: 0.004226430784910917
Epoch 120: 0.0031679582316428423
Epoch 130: 0.002706334460526705
Epoch 140: 0.002332075033336878
Epoch 150: 0.0020824705716222525
Epoch 160: 0.001906992867588

In [11]:
# Plot the loss curve with plotly library
import plotly.offline as pyo
import nbformat

pyo.init_notebook_mode(connected=True)


fig = go.Figure()
fig.add_trace(go.Scatter(y=losses))
fig.add_trace(go.Scatter(
    x=list(range(len(losses))),
    y=losses,
    mode='lines+markers',
    name='Loss',
    hoverinfo='x+y',
    line=dict(color='blue', width=2)
))

# Add titles and labels
fig.update_layout(
    title='Training Loss Over Epochs',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    hovermode='x unified'
)

# Show the plot
pyo.iplot(fig)

In [12]:
# Test

In [13]:
# generate some tokens from the test dataloader
tokens = next(iter(test_dataloader))
tokens = tokens.cuda()[:4]
logits = model(tokens)
generated_tokens = torch.argmax(logits, dim=-1)
print(tokens)
print(f'Generated tokens: {generated_tokens}')

tensor([[ 22,  47, 115,  18,  93,  93, 121,  64,  39,  90,  41,  72, 110,  75,
         113, 124,  72,  92, 112,  23,  18, 102,  73, 126,  18,  34,  66,  39,
          96,  47,  48,  75, 105,  74, 114,  39,  31,  99,  75,  51, 109,   8,
          15,   5,  14,  71, 100,  51,  32,  98, 123,  21,  18,  64,   1, 109,
           6,  68,  64,  38, 103,   5,  62,  44,  22,  47, 115,  18,  93,  93,
         121,  64,  39,  90,  41,  72, 110,  75, 113, 124,  72,  92, 112,  23,
          18, 102,  73, 126,  18,  34,  66,  39,  96,  47,  48,  75, 105,  74,
         114,  39,  31,  99,  75,  51, 109,   8,  15,   5,  14,  71, 100,  51,
          32,  98, 123,  21,  18,  64,   1, 109,   6,  68,  64,  38, 103,   5,
          62,  44],
        [ 53,  59,  10,  50,  22,  74, 118,  64, 110,  85, 113,  10,   6,  45,
          35, 118, 119,  79,  45, 118,  34,  18, 119,  47,  23,  70, 118,   1,
          64, 125, 121,  44,  66,  37,  88,  30,  18,  17, 109,   3, 124,  20,
          42,  61, 106,  87,  39

In [14]:
# Visualize Attention