EX1: The n-dimensional tensor mastery challenge: Combine the `Head` and `MultiHeadAttention` into one class that processes all the heads in parallel, treating the heads as another batch dimension (answer is in nanoGPT).

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Pulling in our code from lecture, we left off with the following implementation:

In [11]:
class SelfAttentionHead(nn.Module):
    """A single causal (masked) self-attention head.

    Projects the input into query, key and value vectors of size
    ``output_channels``; each timestep attends only to itself and earlier
    timesteps.

    Args:
        input_channels: size of the last dimension of the input (C).
        output_channels: head size (H), the last dimension of the output.
        context_length: longest sequence the causal mask supports.
    """

    def __init__(self, input_channels, output_channels, context_length) -> None:
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        def projection():
            return nn.Linear(input_channels, output_channels, bias=False)

        # Created in key -> query -> value order so seeded initialisation is unchanged.
        self.key = projection()
        self.query = projection()
        self.value = projection()
        causal_mask = torch.tril(torch.ones((context_length, context_length)))
        self.register_buffer('tril', causal_mask)

    def forward(self, idx):
        _, seq_len, _ = idx.shape

        keys = self.key(idx)        # (B, T, C) -> (B, T, H)
        queries = self.query(idx)   # (B, T, C) -> (B, T, H)
        values = self.value(idx)    # (B, T, C) -> (B, T, H)

        # Scaled dot-product scores: (B, T, H) @ (B, H, T) -> (B, T, T)
        scores = queries @ keys.transpose(-2, -1)
        scores *= self.output_channels ** -0.5

        # Hide future positions before normalising.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = F.softmax(scores, dim=-1)  # (B, T, T)

        return weights @ values  # (B, T, T) @ (B, T, H) -> (B, T, H)
    

class FeedForward(nn.Module):
    """Linear projection followed by a ReLU, applied over the last dimension.

    Args:
        input_channels: size of the last dimension of the input.
        output_channels: size of the last dimension of the output.
    """

    def __init__(self, input_channels, output_channels) -> None:
        super().__init__()
        layers = [nn.Linear(input_channels, output_channels), nn.ReLU()]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        out = self.ff(x)
        return out

class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention built from independent heads.

    Each of the ``num_heads`` SelfAttentionHead modules maps ``emb_size``
    input channels to ``head_size`` output channels. Their outputs are
    concatenated along the last dimension, giving ``num_heads * head_size``
    channels.
    """

    def __init__(self, emb_size, context_length, head_size, num_heads) -> None:
        super().__init__()
        self.heads = nn.ModuleList(
            SelfAttentionHead(emb_size, head_size, context_length)
            for _ in range(num_heads)
        )

    def forward(self, x):
        per_head = list(map(lambda head: head(x), self.heads))
        return torch.cat(per_head, dim=-1)
    

class ResidualTransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each residual.

    ``head_size`` is the total attention width, split evenly over
    ``num_multi_attn_heads`` heads. It is also the width of the LayerNorm and
    FeedForward layers, so the block input must have ``head_size`` channels.
    """

    def __init__(self, emb_size, head_size, context_length, num_multi_attn_heads) -> None:
        super().__init__()
        per_head_size = head_size // num_multi_attn_heads
        self.attn = MultiHeadAttention(emb_size, context_length, per_head_size, num_heads=num_multi_attn_heads)
        self.ff = FeedForward(head_size, head_size)
        self.norm1 = nn.LayerNorm(head_size)
        self.norm2 = nn.LayerNorm(head_size)

    def forward(self, x):
        # Each sublayer reads a normalized copy of the stream and adds its result back.
        for norm, sublayer in ((self.norm1, self.attn), (self.norm2, self.ff)):
            x = x + sublayer(norm(x))
        return x
    
class MultiBlockModel(nn.Module):
    """Transformer language model built from a stack of residual blocks.

    Token and positional embeddings (width ``emb_size``) are summed. The sum
    passes through ``num_blocks`` ResidualTransformerBlocks and a final
    LayerNorm, then a FeedForward layer and a linear projection to
    ``vocab_size`` logits. The blocks normalize their input with
    LayerNorm(head_size), so ``emb_size`` must equal ``head_size`` in practice.
    """

    def __init__(self, vocab_size, emb_size, head_size, context_length, num_multi_attn_heads, num_blocks) -> None:
        super().__init__()
        assert head_size % num_multi_attn_heads == 0
        self.token_embeddings = nn.Embedding(vocab_size, emb_size)
        self.positional_embeddings = nn.Embedding(context_length, emb_size)
        layers = [
            ResidualTransformerBlock(emb_size, head_size, context_length, num_multi_attn_heads)
            for _ in range(num_blocks)
        ]
        layers.append(nn.LayerNorm(head_size))  # final normalization after the last block
        self.blocks = nn.Sequential(*layers)
        self.ff = FeedForward(head_size, head_size)
        self.output_layer = nn.Linear(head_size, vocab_size)

        self.context_length = context_length  # longest window generate() feeds the model

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (B, T).

        Without ``targets``, logits are (B, T, V) and loss is None. With
        ``targets``, logits are flattened to (B*T, V) and loss is the mean
        cross-entropy against ``targets``.
        """
        _, steps = idx.shape
        positions = torch.arange(steps, device=idx.device)
        x = self.token_embeddings(idx) + self.positional_embeddings(positions)  # (B, T, C)

        hidden = self.blocks(x)           # (B, T, C) -> (B, T, H)
        hidden = self.ff(hidden)          # (B, T, H) -> (B, T, H)
        logits = self.output_layer(hidden)  # (B, T, H) -> (B, T, V)

        if targets is None:
            return logits, None

        batch, steps, vocab = logits.shape
        logits = logits.view(batch * steps, vocab)
        loss = F.cross_entropy(logits, targets.view(batch * steps))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Only the last ``context_length`` tokens are fed to the model at each step.
        """
        for _ in range(max_new_tokens):
            window = idx[:, -self.context_length:]
            logits, _ = self(window)
            next_probs = F.softmax(logits[:, -1, :], dim=-1)  # distribution for the next token
            sampled = torch.multinomial(next_probs, num_samples=1)
            idx = torch.cat((idx, sampled), dim=1)
        return idx


In lecture we implemented multi-head attention by feeding data into several different 'Head' modules in series and then concatenating the results. We can implement a vectorized version by combining the linear layers of each head into unified linear layers. We just need to make sure the dimensions are matched up correctly: each head gets its own dimension, treated like an extra batch dimension, before we compute the query-key dot product. I've outlined the shapes of each intermediate tensor in comments.

In [19]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with all heads computed in parallel.

    One Linear layer per projection produces the queries/keys/values for every
    head at once (``output_channels = num_heads * head_size``). The head index
    is then split into its own dimension and treated as an extra batch
    dimension, so a single batched matmul computes every head's attention.
    The result equals running ``num_heads`` separate heads of size
    ``output_channels // num_heads`` and concatenating their outputs.

    Args:
        input_channels: size of the last dimension of the input (C).
        output_channels: total output width Nh * Hs; must be divisible by num_heads.
        context_length: longest sequence the causal mask supports.
        num_heads: number of attention heads (Nh).
    """

    def __init__(self, input_channels, output_channels, context_length, num_heads) -> None:
        assert output_channels % num_heads == 0
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.head_size = output_channels // num_heads
        self.num_heads = num_heads
        # Creation order (key, query, value) kept so seeded initialisation is unchanged.
        self.key = nn.Linear(input_channels, output_channels, bias=False)
        self.query = nn.Linear(input_channels, output_channels, bias=False)
        self.value = nn.Linear(input_channels, output_channels, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones((context_length, context_length))))

    def _split_heads(self, x):
        """(B, T, Nh * Hs) -> (B, Nh, T, Hs): move the head index into a batch dim."""
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, idx):
        # B = batches, T = timesteps, C = input channels,
        # Nh = attention heads, Hs = head size
        B, T, _ = idx.shape

        q = self._split_heads(self.query(idx))  # (B, T, C) -> (B, Nh, T, Hs)
        k = self._split_heads(self.key(idx))    # (B, T, C) -> (B, Nh, T, Hs)
        v = self._split_heads(self.value(idx))  # (B, T, C) -> (B, Nh, T, Hs)

        # Scaled dot-product scores for every head at once:
        # (B, Nh, T, Hs) @ (B, Nh, Hs, T) -> (B, Nh, T, T)
        w = (q @ k.transpose(-2, -1)) * self.head_size ** -0.5
        # Causal mask: position t may only attend to positions <= t.
        w = w.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        w = F.softmax(w, dim=-1)

        out = w @ v  # (B, Nh, T, T) @ (B, Nh, T, Hs) -> (B, Nh, T, Hs)
        # transpose() returns a non-contiguous view, and view() cannot merge
        # dimensions that are not contiguous in memory. reshape() copies when
        # necessary, producing (B, T, Nh * Hs) with each head's channels adjacent.
        return out.transpose(1, 2).reshape(B, T, self.output_channels)
    

# Classes that depend on MultiHeadAttention need to be redefined as well, since its constructor signature changed!

class ResidualTransformerBlock(nn.Module):
    """Pre-norm transformer block using the vectorized MultiHeadAttention.

    ``head_size`` is the total attention width, shared across
    ``num_multi_attn_heads`` heads. It is also the width of the LayerNorm and
    FeedForward layers, so the block input must have ``head_size`` channels.
    """

    def __init__(self, emb_size, head_size, context_length, num_multi_attn_heads) -> None:
        super().__init__()
        self.attn = MultiHeadAttention(
            emb_size, head_size, context_length, num_heads=num_multi_attn_heads
        )
        self.ff = FeedForward(head_size, head_size)
        self.norm1, self.norm2 = nn.LayerNorm(head_size), nn.LayerNorm(head_size)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))             # residual around attention
        return attended + self.ff(self.norm2(attended))     # residual around feed-forward
    
class MultiBlockModel(nn.Module):
    """Transformer language model built from a stack of residual blocks.

    Token and positional embeddings (width ``emb_size``) are summed. The sum
    passes through ``num_blocks`` ResidualTransformerBlocks and a final
    LayerNorm, then a FeedForward layer and a linear projection to
    ``vocab_size`` logits. The blocks normalize their input with
    LayerNorm(head_size), so ``emb_size`` must equal ``head_size`` in practice.
    """

    def __init__(self, vocab_size, emb_size, head_size, context_length, num_multi_attn_heads, num_blocks) -> None:
        super().__init__()
        assert head_size % num_multi_attn_heads == 0
        self.token_embeddings = nn.Embedding(vocab_size, emb_size)
        self.positional_embeddings = nn.Embedding(context_length, emb_size)
        layers = [
            ResidualTransformerBlock(emb_size, head_size, context_length, num_multi_attn_heads)
            for _ in range(num_blocks)
        ]
        layers.append(nn.LayerNorm(head_size))  # final normalization after the last block
        self.blocks = nn.Sequential(*layers)
        self.ff = FeedForward(head_size, head_size)
        self.output_layer = nn.Linear(head_size, vocab_size)

        self.context_length = context_length  # longest window generate() feeds the model

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (B, T).

        Without ``targets``, logits are (B, T, V) and loss is None. With
        ``targets``, logits are flattened to (B*T, V) and loss is the mean
        cross-entropy against ``targets``.
        """
        _, steps = idx.shape
        positions = torch.arange(steps, device=idx.device)
        x = self.token_embeddings(idx) + self.positional_embeddings(positions)  # (B, T, C)

        hidden = self.blocks(x)           # (B, T, C) -> (B, T, H)
        hidden = self.ff(hidden)          # (B, T, H) -> (B, T, H)
        logits = self.output_layer(hidden)  # (B, T, H) -> (B, T, V)

        if targets is None:
            return logits, None

        batch, steps, vocab = logits.shape
        logits = logits.view(batch * steps, vocab)
        loss = F.cross_entropy(logits, targets.view(batch * steps))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Only the last ``context_length`` tokens are fed to the model at each step.
        """
        for _ in range(max_new_tokens):
            window = idx[:, -self.context_length:]
            logits, _ = self(window)
            next_probs = F.softmax(logits[:, -1, :], dim=-1)  # distribution for the next token
            sampled = torch.multinomial(next_probs, num_samples=1)
            idx = torch.cat((idx, sampled), dim=1)
        return idx


Now we can try training a model using this new component. If done correctly we should receive nearly identical results when compared to those from the lecture notebook, as the underlying structure of the network should not be fundamentally different and none of the hyperparameters have otherwise changed. 

First I'll pull in all the training utility code:

In [20]:
import string

# Character-level vocabulary: every character in string.printable, plus one
# extra id (the last one) reserved for characters outside that set.
chars = sorted(set(string.printable))
vocab_size = len(chars) + 1
unk_idx = len(chars)  # id of the unknown token; equals vocab_size - 1

# stoi: character -> id, itos: id -> character. The unknown token is registered
# in both directions so decode() works for every id the model can produce.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
stoi['ukn'] = unk_idx
itos[unk_idx] = 'ukn'


def encode(s):
    """Encode string ``s`` as a list of token ids; unknown characters map to ``unk_idx``."""
    return [stoi.get(c, unk_idx) for c in s]


def decode(l):
    """Decode a list of token ids into a string; ``unk_idx`` decodes to 'ukn'."""
    return ''.join(itos[i] for i in l)

# Read the raw corpus and encode it as one long tensor of token ids.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
data = torch.tensor(encode(text), dtype=torch.long, device=device)

# 90/10 split: the first 90% of the tokens for training, the rest for validation.
spl = int(0.9 * len(data))
train_data, val_data = data[:spl], data[spl:]

torch.manual_seed(1337)  # seed all subsequent torch random draws

batch_size, context_length = 4, 8

def get_batch(split, batch_size, context_length):
    """Sample ``batch_size`` random (input, target) windows from a data split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.
        batch_size: number of windows to sample.
        context_length: length of each window.

    Returns:
        ``(x, y)`` stacked along a new leading batch dimension. ``y[b, t]`` is
        the token that follows ``x[b, t]`` in the source data (assumes the
        splits are 1-D token tensors).
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - context_length, (batch_size,))
    inputs = torch.stack([source[s:s + context_length] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + context_length] for s in starts])
    return inputs, targets

def evaluate(model, batch_size, context_length, num_batches=100):
    """Estimate the validation loss as the mean over ``num_batches`` random batches.

    A fresh validation batch is drawn on every iteration, so the average covers
    ``num_batches`` distinct samples. The loop runs under ``torch.no_grad()``
    because no gradients are needed. The model's previous train/eval mode is
    restored afterwards.

    Returns:
        The mean cross-entropy loss (a numpy float).
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for _ in range(num_batches):
                x, y = get_batch('val', batch_size, context_length)
                _, loss = model(x, y)
                losses.append(loss.item())
    finally:
        model.train(was_training)
    return np.mean(losses)


def train(model, num_steps, batch_size, context_length, learning_rate=1e-3, optimizer=None, print_every=1000, evaluate_every=1000):
    """Train ``model`` on random training batches.

    Uses AdamW with ``learning_rate`` unless an ``optimizer`` is given.

    Returns:
        ``(losses, eval_losses)``:
        - ``losses`` holds the training loss at every step.
        - ``eval_losses`` holds the validation loss measured once before any
          update (step 0) and then every ``evaluate_every`` steps, so entry k
          corresponds to step ``k * evaluate_every``.
    """
    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    model.train()
    losses = []
    # Steps are numbered from 1, so the step-0 baseline is measured before the loop.
    eval_losses = [evaluate(model, batch_size, context_length)]
    for step in range(1, num_steps + 1):
        optimizer.zero_grad()
        x, y = get_batch('train', batch_size, context_length)
        _, loss = model(x, y)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        if step % print_every == 0:
            print(f'step {step}: loss {loss_value}')
        if step % evaluate_every == 0:
            eval_losses.append(evaluate(model, batch_size, context_length))

    return losses, eval_losses

def plot_ema(losses, eval_losses, gamma=0.99, title='ema'):
    """Plot an exponential moving average of the training loss next to the validation losses.

    The EMA starts from ``losses[0]`` and is updated as
    ``ema = gamma * ema + (1 - gamma) * loss``. Validation losses are drawn at
    evenly spaced x positions spanning 0..len(losses). ``title`` is used as the
    plot title. The final EMA train loss and the last validation loss are printed.
    """
    smoothed = []
    running = losses[0]
    for value in losses:
        running = gamma * running + (1 - gamma) * value
        smoothed.append(running)
    eval_x = np.linspace(0, len(losses), len(eval_losses), dtype=int)

    plt.plot(smoothed, label='train', color='red')
    plt.plot(eval_x, eval_losses, label='val', color='blue')
    plt.title(title)
    plt.legend()
    plt.show()
    print('final train loss (ema):', smoothed[-1])
    print('final validation loss:', eval_losses[-1])

def generate_text(model, starting_text=' ', max_new_tokens=100):
    """Print ``starting_text`` followed by ``max_new_tokens`` sampled tokens.

    The prompt is encoded as ONE sequence of shape (1, T), so every prompt
    character serves as context. The generated sequence already begins with
    the prompt, so it is decoded and printed as-is. The model's previous
    train/eval mode is restored afterwards.

    Raises:
        ValueError: if ``starting_text`` is empty.
    """
    if not starting_text:
        raise ValueError('starting_text must contain at least one character')
    prompt = torch.tensor(encode(starting_text), dtype=torch.long, device=device).reshape(1, -1)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated = model.generate(prompt, max_new_tokens=max_new_tokens)
    finally:
        model.train(was_training)
    print(decode(generated[0].tolist()))

Now we can train the new model with identical parameters to the last training run from lecture:

In [21]:
# Same hyperparameters as the final lecture run. emb_size is also passed as
# head_size because the residual blocks require the two widths to match.
context_length = 128
batch_size = 32
emb_size = 64
multi_block_model = MultiBlockModel(vocab_size, emb_size, emb_size, context_length, num_multi_attn_heads=8, num_blocks=8).to(device)
losses, eval_losses = train(multi_block_model, 10000, batch_size, context_length)
plot_ema(losses, eval_losses)
# Only a cell's last expression is displayed, so print the evaluation explicitly.
print('validation loss:', evaluate(multi_block_model, batch_size, context_length))
generate_text(multi_block_model)

KeyboardInterrupt: 

With a loss of ~1.49, our results are almost exactly the same, which means our implementation should be correct. This new implementation computes the entire multi-head attention layer in a vectorized manner rather than concatenating results from separately computed heads.

EX2: Train the GPT on your own dataset of choice! What other data could be fun to blabber on about? (A fun advanced suggestion if you like: train a GPT to do addition of two numbers, i.e. a+b=c. You may find it helpful to predict the digits of c in reverse order, as the typical addition algorithm (that you're hoping it learns) would proceed right to left too. You may want to modify the data loader to simply serve random problems and skip the generation of train.bin, val.bin. You may want to mask out the loss at the input positions of a+b that just specify the problem using y=-1 in the targets (see CrossEntropyLoss ignore_index). Does your Transformer learn to add? Once you have this, swole doge project: build a calculator clone in GPT, for all of +-*/. Not an easy problem. You may need Chain of Thought traces.)

In [34]:
from datasets import load_dataset

data = load_dataset("ccdv/cnn_dailymail", "1.0.0")

# get_batch() slices train_data / val_data as flat 1-D tensors of token ids, so
# the raw Hugging Face splits must be turned into text and encoded first.
# Encoding every article into a Python list would need a huge amount of memory,
# so only a prefix of each split is used.
max_train_articles = 5000
max_val_articles = 500


def _encode_articles(split, max_articles):
    """Join the first ``max_articles`` articles of ``split`` and encode them as a token tensor."""
    subset = split.select(range(min(max_articles, len(split))))
    # NOTE(review): assumes the dataset exposes an 'article' text column — confirm.
    text = '\n'.join(subset['article'])
    return torch.tensor(encode(text), dtype=torch.long, device=device)


train_data = _encode_articles(data['train'], max_train_articles)
val_data = _encode_articles(data['validation'], max_val_articles)


Found cached dataset cnn_dailymail (/Users/marshingjay/.cache/huggingface/datasets/ccdv___cnn_dailymail/1.0.0/1.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f)
100%|██████████| 3/3 [00:00<00:00, 130.78it/s]


In [87]:
# Same configuration as the Shakespeare run; get_batch() reads whatever
# train_data / val_data currently hold.
context_length, batch_size, emb_size = 128, 32, 64
multi_block_model = MultiBlockModel(
    vocab_size,
    emb_size=emb_size,
    head_size=emb_size,
    context_length=context_length,
    num_multi_attn_heads=8,
    num_blocks=8,
).to(device)
losses, eval_losses = train(multi_block_model, 10000, batch_size, context_length)
plot_ema(losses, eval_losses)
generate_text(multi_block_model)

KeyboardInterrupt: 

EX3: Find a dataset that is very large, so large that you can't see a gap between train and val loss. Pretrain the transformer on this data, then initialize with that model and finetune it on tiny shakespeare with a smaller number of steps and lower learning rate. Can you obtain a lower validation loss by the use of pretraining?

EX4: Read some transformer papers and implement one additional feature or change that people seem to use. Does it improve the performance of your GPT?