## GPT Development

In [None]:
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Unique characters in the text, sorted so the character <-> integer mapping is
# deterministic for a given corpus
chars = sorted(set(text))
vocab_size = len(chars)

# Mapping from characters to integers and vice versa
char_to_int = {c: i for i, c in enumerate(chars)}
int_to_char = {i: c for i, c in enumerate(chars)}


def encode(x: str) -> list[int]:
    """Convert a string into a list of integer token ids.

    Raises KeyError if `x` contains a character that does not occur in the corpus.
    """
    return [char_to_int[c] for c in x]


def decode(x: list[int]) -> str:
    """Convert a list of integer token ids back into a string."""
    return ''.join(int_to_char[i] for i in x)


print(encode('Hello World!'))
print(decode(encode('Hello World!')))


In [None]:
# Encode the whole corpus once, then hold out the final 10% for validation
data = torch.tensor(encode(text), dtype=torch.long)
split = int(len(data) * 0.9)  # index of the first validation token
train_data = data[:split]
val_data = data[split:]

In [None]:
# A chunk of block_size + 1 tokens contains block_size training examples:
# for each prefix x[:t+1] the target is the token that follows it, y[t].
block_size = 8
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    # .item() prints the target as a plain int, matching the list used for the context
    print(f'context: {x[:t+1].tolist()} -> target: {y[t].item()}')

Training on contexts of every length from 1 up to `block_size` is important: it teaches the transformer to handle contexts of varying length. This is useful at inference time, because the model can then generate text from as little as one character of context.

In [None]:
batch_size = 4 # Number of sequences to process in parallel
block_size = 8 # Maximum context length for predictions

def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate a random batch of context and target sequences.

    Args:
        split: 'train' to sample from train_data, 'val' to sample from val_data.

    Returns:
        (x, y), both of shape (batch_size, block_size). y is x shifted one
        position to the left, so y[b, t] is the token that follows x[b, :t+1].

    Raises:
        ValueError: if `split` is neither 'train' nor 'val'.

    batch_size and block_size are read from the module globals at call time,
    so reassigning them later changes the shape of the returned batch.
    """
    if split == 'train':
        data = train_data
    elif split == 'val':
        data = val_data
    else:
        # Reject anything else instead of silently sampling from val_data
        # (e.g. a typo such as 'training').
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # Randomly sample batch_size starting indices; the upper bound leaves room
    # for the one-token shift of the targets.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

# Draw one training batch and show the raw tensors;
# xb and yb both have shape (batch_size, block_size).
xb, yb = get_batch('train')
print('xb: {}\nyb: {}'.format(xb, yb))

In [None]:
# Display the context and target sequences for each batch element
for b in range(batch_size):
    for t in range(block_size):
        # .item() prints the target as a plain int, matching the context list
        print(f'context: {xb[b, :t+1].tolist()} -> target: {yb[b, t].item()}')

### Bigram Model

In [None]:
# B - batch size, T - block size (time step), C - embedding dimension (vocab size)

class BigramLanguageModel(nn.Module):
    """Character-level bigram language model.

    Each token id indexes a row of a (vocab_size, vocab_size) embedding table,
    and that row is used directly as the logits for the next token.
    """

    def __init__(self):
        super().__init__()
        # vocab_size is read from the module globals when the model is created.
        self.token_embed_table = nn.Embedding(vocab_size, vocab_size) # (B,T) -> (B,T,C)

    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            x: integer token ids of shape (B, T).
            y: optional integer target ids of shape (B, T).

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and
            loss is None. With targets, logits is flattened to (B*T, C) and
            loss is the mean cross-entropy over all positions.
        """
        logits = self.token_embed_table(x)

        if y is None:
            loss = None
        else:
            B, T, C = logits.shape
            # Flatten batch and sequence dimensions to use F.cross_entropy
            logits = logits.view(B*T, C)
            y = y.view(B*T)
            loss = F.cross_entropy(logits, y)
        return logits, loss

    @torch.no_grad()
    def generate(self, x: torch.Tensor, max_tokens: int) -> torch.Tensor:
        """Autoregressively sample `max_tokens` new tokens after the context `x`.

        Args:
            x: integer context of shape (B, T).
            max_tokens: number of tokens to append.

        Returns:
            Tensor of shape (B, T + max_tokens): the context followed by the samples.
        """
        # no_grad: sampling never backpropagates, so skip building the autograd graph.
        for _ in range(max_tokens):
            # Get the predictions for every position
            logits, _ = self(x) # (B,T,C)
            # Keep only the last prediction
            logits = logits[:, -1, :] # (B,C)
            # Apply softmax to convert logits into probabilities
            probs = F.softmax(logits, dim=-1) # (B,C)
            # Sample from the probability distribution
            x_next = torch.multinomial(probs, num_samples=1) # (B,1)
            # Concatenate the new prediction to the previous context
            x = torch.cat([x, x_next], dim=1) # (B,T+1)
        return x
    
model = BigramLanguageModel()

# Generate from the untrained model: one sequence whose context is token id 0
context = torch.zeros((1, 1), dtype=torch.long)
generated = model.generate(context, max_tokens=32)
first_sequence = generated[0].tolist()
print(decode(first_sequence))

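The integer associated with each character is used as an index to look up the corresponding row in the embedding table. This row is a trainable vector representation of the character. In the bigram model the table has shape `(vocab_size, vocab_size)`, so each row has `vocab_size` entries and is used directly as the logits (unnormalised scores) for the next character.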

In [None]:
# Hyperparameters
batch_size = 32 # Sequences to process in parallel (get_batch reads this global)
max_iters = 2500 # Iterations to train the model
lr = 1e-2 # Learning rate

# Training the model
optimiser = torch.optim.AdamW(model.parameters(), lr=lr)

# Report about ten times during training; max(1, ...) avoids a modulo by zero
# when max_iters < 10.
log_every = max(1, max_iters // 10)

for i in range(max_iters):
    # Get a batch of context and target sequences
    xb, yb = get_batch('train')

    # Compute the gradients and update the weights
    _, loss = model(xb, yb) # Forward pass
    optimiser.zero_grad(set_to_none=True)
    loss.backward()
    optimiser.step()

    # Log after the forward pass so the printed value is this iteration's
    # batch loss, including for the first and the final iteration.
    if i % log_every == 0 or i == max_iters - 1:
        print(f'iteration {i}, loss: {loss.item()}')

In [None]:
# Generate 32 characters from the trained model, starting from token id 0
context = torch.zeros((1, 1), dtype=torch.long)
tokens = model.generate(context, max_tokens=32)[0].tolist()
print(decode(tokens))

### Self-Attention

In [None]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Version 1. Bag of words with explicit loops: x[b,t] = mean_{t'<=t} x[b,t']
xbow_1 = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow_1[b, t] = torch.mean(x[b, :t+1], 0)

# Version 2. Parallelised. W is a lower triangular matrix which can be used for weighted aggregation
W = torch.tril(torch.ones(T, T))
W = W / W.sum(1, keepdim=True)
xbow_2 = W @ x # (T,T) @ (B,T,C) -> (B,T,C) by broadcasting over the batch

# Version 3. Parallelised. Uses softmax. W represents the same lower triangular matrix as before
tril = torch.tril(torch.ones(T, T))
W = torch.zeros((T, T))
W = W.masked_fill(tril == 0, float('-inf'))
W = F.softmax(W, dim=-1)
xbow_3 = W @ x

# Check that the three methods are equivalent. They accumulate in different
# orders, so results differ by float32 rounding; allclose's default atol=1e-8
# is too strict for entries close to zero, hence the explicit tolerance.
equivalent = (
    torch.allclose(xbow_1, xbow_2, atol=1e-6)
    and torch.allclose(xbow_1, xbow_3, atol=1e-6)
)
print(equivalent)


In [None]:
# B - batch size, T - block size (time step), C - embedding dimension, H - head size

# Single head self-attention. Uses x (B,T,C) and T from the bag-of-words cell,
# where C = 32 matches n_embed below.
head_size = 16
n_embed = 32
key = nn.Linear(n_embed, head_size, bias=False) # (B,T,C) -> (B,T,H)
query = nn.Linear(n_embed, head_size, bias=False) # (B,T,C) -> (B,T,H)
value = nn.Linear(n_embed, head_size, bias=False) # (B,T,C) -> (B,T,H)
k = key(x)
q = query(x)

# Compute the scaled dot-product attention: the scores are multiplied by
# 1/sqrt(H) so their variance stays around 1 and the softmax does not saturate.
W = q @ k.transpose(-2, -1) * head_size**-0.5 # (B,T,H) @ (B,H,T) -> (B,T,T)
tril = torch.tril(torch.ones(T, T))
W = W.masked_fill(tril == 0, float('-inf')) # causal mask: position t only sees positions <= t
W = F.softmax(W, dim=-1)
v = value(x)
out = W @ v # (B,T,T) @ (B,T,H) -> (B,T,H)

**Notes:**
- Attention is a communication mechanism. It can be viewed as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- In the attention layer of a Transformer decoder, every token attends only to itself and the tokens that precede it in the sequence. This is called causal self-attention.
- There is no notion of space. Attention simply acts over a set of vectors. This is why tokens need to be positionally encoded.
- Examples along the batch dimension are processed independently and never interact with each other.
- In an encoder attention block just delete the single line that performs masking with `tril`, allowing all tokens to communicate with each other and not just the previous ones. The block implemented above is called a decoder attention block because it has triangular masking and is used in autoregressive settings like language modelling.
- 'Self-attention' just means that the keys and the values are produced from the same source as the queries (`x` in this case). In 'cross attention', the queries still get produced from `x`, but the keys and values come from a different source (such as an encoder module).
- 'Scaled' attention additionally multiplies `W` by $1/\sqrt{H}$ (i.e. divides it by $\sqrt{H}$), where $H$ is the head size. This ensures that when `Q` and `K` have unit variance, `W` has unit variance as well, so the softmax stays diffuse instead of saturating towards a one-hot vector (see below).

In [None]:
# Unit-variance keys and queries: scaling the scores by 1/sqrt(head_size)
# keeps W.var() close to 1 instead of growing with head_size.
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
scale = head_size**-0.5
W = q @ k.transpose(1, 2) * scale

k.var(), q.var(), W.var()

In [None]:
scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
print(torch.softmax(scores, dim=-1))
# Multiplying the same scores by 10 sharpens the distribution: as the inputs
# grow, softmax converges towards a one-hot vector on the largest entry.
print(torch.softmax(scores * 10, dim=-1))