In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import string
import random
import matplotlib.pyplot as plt
import tqdm

%matplotlib inline

In [None]:
# --- Hyperparameters ---
context_length = 8       # max number of tokens the model sees at once (size of the causal mask / position table)
embedding_size = 32      # width of token/position embeddings and of the residual stream
num_heads = 4            # attention heads per block; each head gets embedding_size // num_heads features
num_layers = 4           # number of transformer blocks

batch_size = 64          # sequences per training batch
dropout = 0.2            # dropout probability

num_iterations = 10_000  # optimizer steps
learning_rate = 1e-3
use_subset = False       # True: keep only the first 1000 characters of the corpus (quick debugging)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs used by the notebook so Restart & Run All reproduces the same
# weight init, batches and samples.
seed = 1337
random.seed(seed)
torch.manual_seed(seed)

print(f"Using {device}")

In [None]:
# Load the raw corpus. The `with` block guarantees the file handle is closed,
# and an explicit encoding avoids depending on the platform default
# (NOTE(review): assumes data/wiki.txt is UTF-8 — confirm).
with open("data/wiki.txt", "r", encoding="utf-8") as corpus_file:
    text = corpus_file.read().lower()

if use_subset:
    text = text[:1000]

print(f"Input has size: {len(text)}")
print()
print(text[:1000])

In [None]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables in both directions; ids follow the sorted character order.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))

print(stoi)
print(itos)

def encode(s):
    """Map a string to the list of its character ids (KeyError on unknown characters)."""
    return list(map(stoi.__getitem__, s))

def decode(d):
    """Map an iterable of character ids back to the string they encode."""
    return "".join(map(itos.__getitem__, d))

In [None]:
# Encode the whole corpus once into a 1-D long tensor of token ids,
# allocated directly on the selected device.
token_ids = encode(text)
data = torch.tensor(token_ids, dtype=torch.long, device=device)

print(data.shape)
print()
print(data[:1000])

In [None]:
# Contiguous split: the first 90% of tokens train, the final 10% validate.
train_size = int(0.9 * len(data))
train_data, val_data = data[:train_size], data[train_size:]

print(f"Train size: {len(train_data)}")
print(f"Validation size: {len(val_data)}")

In [None]:
def get_batch(split):
    """Sample ``batch_size`` random windows from a data split.

    Args:
        split: ``"train"`` selects the training tokens; any other value
            selects the validation tokens.

    Returns:
        ``(x, y)``, both ``(batch_size, context_length)``; ``y`` is ``x``
        shifted one token to the right (the next-token targets).
    """
    source = val_data if split != "train" else train_data
    # Start offsets are chosen so that the shifted target window still fits.
    starts = torch.randint(len(source) - context_length, (batch_size,))
    inputs = torch.stack([source[s:s + context_length] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + context_length] for s in starts])
    return inputs, targets

def print_batch_example(max_examples=None):
    """Print one training batch and the (context -> target) pairs it encodes.

    Each row ``b`` of the batch yields ``context_length`` examples: the
    prefix ``xb[b, :c + 1]`` is trained to predict ``yb[b, c]``.

    Args:
        max_examples: Number of batch rows to expand into context/target
            pairs. ``None`` (the default) expands every row.
    """
    xb, yb = get_batch("train")
    print(xb.shape)
    print(xb)
    print(yb.shape)
    print(yb)

    num_rows = batch_size if max_examples is None else min(max_examples, batch_size)
    for b in range(num_rows):
        for c in range(context_length):
            context = xb[b, :c + 1]
            targets = yb[b, c]
            print(f"{context} -> {targets}")

# Expanding every row prints batch_size * context_length lines; a single row
# is enough to show the shifted-target layout.
print_batch_example(max_examples=1)

In [None]:
class SelfAttentionHead(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Output size of the key/query/value projections.
        embed_dim: Input feature size. Defaults to the global ``embedding_size``.
        block_size: Longest sequence the causal mask supports. Defaults to the
            global ``context_length``.
        dropout_p: Dropout probability applied to the attention weights.
            Defaults to the global ``dropout``.
    """

    def __init__(self, head_size, embed_dim=None, block_size=None, dropout_p=None):
        super().__init__()
        # Fall back to the notebook hyperparameters so existing callers that
        # pass only head_size keep their behaviour.
        if embed_dim is None:
            embed_dim = embedding_size
        if block_size is None:
            block_size = context_length
        if dropout_p is None:
            dropout_p = dropout
        self.K = nn.Linear(embed_dim, head_size, bias=False)
        self.Q = nn.Linear(embed_dim, head_size, bias=False)
        self.V = nn.Linear(embed_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        # Lower-triangular mask: position t may only attend to positions <= t.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (B, T, embed_dim) with T <= block_size.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        T = x.shape[1]
        k = self.K(x)  # (B, T, head_size)
        q = self.Q(x)  # (B, T, head_size) -- queries come from Q, not K
        v = self.V(x)  # (B, T, head_size)

        # Scale by sqrt of the dimension the dot product runs over (head_size),
        # not by sqrt of the input width, to keep the softmax from saturating.
        head_size = k.shape[-1]
        wei = q @ k.transpose(-2, -1) / head_size**0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = torch.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        return wei @ v  # (B, T, T) @ (B, T, head_size) = (B, T, head_size)
    
class MultiHeadAttention(nn.Module):
    """Several causal self-attention heads in parallel, concatenated and projected.

    Args:
        num_heads: Number of independent attention heads.
        head_size: Output size of each head; the concatenation therefore has
            ``num_heads * head_size`` features.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(head_size) for _ in range(num_heads)])
        # Project the concatenated heads back to the residual-stream width.
        # Sizing the input as num_heads * head_size keeps this correct even when
        # embedding_size is not divisible by num_heads.
        self.proj = nn.Linear(num_heads * head_size, embedding_size)
        # Use the configured rate: a bare nn.Dropout() silently defaults to p=0.5.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map (B, T, embedding_size) to (B, T, embedding_size)."""
        x = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        x = self.proj(x)
        return self.dropout(x)
    
class FeedForward(nn.Module):
    """Position-wise MLP: widen 4x, ReLU, project back, then dropout.

    Args:
        size: Input and output feature dimension.
    """

    def __init__(self, size):
        super().__init__()
        hidden = size * 4
        layers = [
            nn.Linear(size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, size),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        out = self.net(x)
        return out
        
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each on a residual branch.

    Args:
        embedding_size: Width of the residual stream.
        num_heads: Number of attention heads; each gets
            ``embedding_size // num_heads`` features.
    """

    def __init__(self, embedding_size, num_heads):
        super().__init__()
        self.sa = MultiHeadAttention(num_heads, embedding_size // num_heads)
        self.ffwd = FeedForward(embedding_size)
        self.ln1 = nn.LayerNorm(embedding_size)
        self.ln2 = nn.LayerNorm(embedding_size)

    def forward(self, x):
        """Add the attention branch, then the feed-forward branch, to ``x``."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
    
class Model(nn.Module):
    """Character-level decoder-only transformer language model.

    Token embeddings plus learned position embeddings go through
    ``num_layers`` transformer blocks and a final LayerNorm, then are
    projected to ``vocab_size`` logits.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_size)
        self.position_embedding_table = nn.Embedding(context_length, embedding_size)
        self.blocks = nn.Sequential(
            *[Block(embedding_size, num_heads=num_heads) for _ in range(num_layers)]
        )
        self.lm_head = nn.Linear(embedding_size, vocab_size)
        self.ln = nn.LayerNorm(embedding_size)

    def forward(self, context, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            context: (B, T) tensor of token ids with T <= context_length.
            targets: Optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits is (B, T, vocab_size)
            and loss is None. With targets, logits is flattened to
            (B * T, vocab_size) and loss is the mean cross-entropy.
        """
        B, T = context.shape

        tok_emb = self.token_embedding_table(context)  # (B, T, C)
        # Build positions on the input's device (not the global `device`) so the
        # model works wherever it and its inputs actually live.
        positions = torch.arange(T, device=context.device)
        pos_emb = self.position_embedding_table(positions)  # (T, C)
        x = tok_emb + pos_emb  # (B, T, C); pos_emb broadcasts over the batch
        x = self.blocks(x)  # (B, T, C)
        logits = self.lm_head(self.ln(x))  # (B, T, vocab_size)

        if targets is None:
            return logits, None

        logits = logits.view(B * T, -1)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoid building autograd graphs
    def generate(self, context, num_new_tokens):
        """Autoregressively sample ``num_new_tokens`` tokens after ``context``.

        Dropout stays active unless the caller has put the model in eval mode.

        Args:
            context: (B, T) tensor of token ids used as the prompt.
            num_new_tokens: Number of tokens to append.

        Returns:
            (B, T + num_new_tokens) tensor: the prompt followed by the samples.
        """
        for _ in range(num_new_tokens):
            # The position table covers only context_length positions, so feed
            # at most the last context_length tokens.
            cropped_context = context[:, -context_length:]
            logits, _ = self(cropped_context)  # (B, T, vocab_size)
            # Only the prediction for the last position matters.
            probs = torch.softmax(logits[:, -1, :], dim=-1)  # (B, vocab_size)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            context = torch.cat([context, next_token], dim=1)
        return context
            
            
    
m = Model()
m.to(device)  # nn.Module.to moves the parameters in place and returns the same module

# Smoke test: one forward pass on a training batch before any training.
xb, yb = get_batch("train")
logits, loss = m(xb, yb)
for item in (logits.shape, loss):
    print(item)

In [None]:
@torch.no_grad()
def estimate_loss(num_batches=10):
    """Estimate the mean train and validation loss of the global model ``m``.

    The loss is averaged over ``num_batches`` freshly sampled batches per split,
    with the model in eval mode (dropout off) and gradients disabled. The model
    is switched back to training mode before returning.

    Args:
        num_batches: Number of random batches averaged per split.

    Returns:
        ``(train_loss, val_loss)`` as Python floats.
    """
    m.eval()

    split_losses = {}
    for split in ("train", "val"):
        losses = torch.zeros(num_batches)
        for i in range(num_batches):
            xb, yb = get_batch(split)
            _, loss = m(xb, yb)
            # .item() copies the scalar to the host explicitly (losses lives on CPU).
            losses[i] = loss.item()
        split_losses[split] = losses.mean().item()

    m.train()

    return split_losses["train"], split_losses["val"]

estimate_loss()

In [None]:
# Train model
eval_interval = 200  # iterations between loss estimates

eval_steps = []
train_losses = []
val_losses = []
optimizer = torch.optim.Adam(m.parameters(), lr=learning_rate)

m.train()  # make sure dropout is active regardless of what earlier cells did
for iteration in (pbar := tqdm.tqdm(range(num_iterations))):
    if iteration % eval_interval == 0:
        train_loss, val_loss = estimate_loss()
        pbar.set_description(f"T: {train_loss:.3}, V: {val_loss:.3}")
        eval_steps.append(iteration)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

    # Sample batch of data
    xb, yb = get_batch("train")

    # Forward pass
    logits, loss = m(xb, yb)

    # Backwards pass
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Evaluate once more after the last update so the curve ends at the trained model.
train_loss, val_loss = estimate_loss()
eval_steps.append(num_iterations)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Final: T: {train_loss:.3}, V: {val_loss:.3}")

fig, ax = plt.subplots()
ax.plot(eval_steps, train_losses, label="train")
ax.plot(eval_steps, val_losses, label="val")
ax.set(xlabel="iteration", ylabel="cross-entropy loss", title="Training curve")
ax.legend();

In [None]:
# Let the model generate something.
# Sample with dropout disabled, and create the prompt on the model's device:
# a CPU index tensor fails the embedding lookup when the model is on CUDA.
m.eval()
seed_context = torch.zeros((1, 1), dtype=torch.long, device=device)  # prompt: a single token id 0
with torch.no_grad():
    raw_output = m.generate(seed_context, 1000)
print(decode(raw_output[0].tolist()))

## History

Train (T) and validation (V) cross-entropy loss recorded for earlier versions of the model:

```
Bigram:      T: 2.48, V: 2.48
```