In [1]:
import torch
import math
import torch.nn as nn
from torch import Tensor
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# **Task**

We train a baseline language model on the Tiny Shakespeare corpus. In particular, we use character-level tokenization and predict the next character given the previous $k$ characters (a conditional probability over characters). To achieve this, we train the model to maximize the log-likelihood of the data.

$$
\max_{\theta} \sum_{i=1}^{N} \log \hat{p}(\mathbf{x}_i; \theta) = \max_{\theta} \sum_{i=1}^{N} \log \prod_{j=1}^{k} \hat{p}(x_j | x_{j-1},..., x_{j-k}; \theta) = \max_{\theta} \sum_{i=1}^{N} \sum_{j=1}^{k} \log \hat{p}(x_j | x_{j-1},..., x_{j-k}; \theta)
$$

where $\mathbf{x}_i$ is the $i$-th sequence in the dataset created from the corpus, $N$ is the number of sequences in the dataset and $\theta$ are the parameters of the model. 
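
In practice, maximizing the log-likelihood is usually implemented as minimizing the mean negative log-likelihood, which is what `F.cross_entropy` computes from unnormalized logits. The following is a minimal sketch with random logits standing in for model outputs; the sizes and variable names are illustrative assumptions, not values used later in the notebook.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, k = 6, 4                           # illustrative sizes (assumed)
logits = torch.randn(k, vocab_size)            # stand-in for model outputs, one row per position
targets = torch.randint(0, vocab_size, (k,))   # the characters that actually came next

# log p(x_j | context) for every position j, summed into the sequence log-likelihood
log_probs = F.log_softmax(logits, dim=-1)
log_likelihood = log_probs[torch.arange(k), targets].sum()

# cross_entropy returns the mean negative log-likelihood over positions
nll = F.cross_entropy(logits, targets)
print(f"log-likelihood: {log_likelihood.item():.4f} | mean NLL: {nll.item():.4f}")
```

Minimizing `nll` with gradient descent is therefore equivalent to maximizing the log-likelihood objective above, up to the constant factor `k`.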

### Multiple context lengths $k$

For large language models (LLMs) such as GPT-2, we in fact train over various context lengths $k$, e.g. we may set a maximum context length of $K$ and train the model on all sequences of length $K$ or less. The above formulation is therefore slightly simplified, but the idea is the same. This modification will also allow our model to learn how to generate text based on only a single character which is useful for generating text from scratch (beginning only with a "space" character).

In the context of creating a dataset, we can think of passing the model a sequence of length $K$, but instead of only predicting the next character from the entire sequence, we predict the next character for each prefix of length $k \leq K$. Our target therefore consists of the next character for each of these prefixes:

$$
[x_{1}, x_{2}, ..., x_{K}] \mapsto [x_{2}, x_{3}, ..., x_{K+1}] 
$$ 

In particular, the following subsequence predictions are made:

$$
\begin{aligned}
x_{1} &\mapsto x_{2} \\
x_{1}, x_{2} &\mapsto x_{3} \\
x_{1}, x_{2}, x_{3} &\mapsto x_{4} \\
&\;\;\vdots \\
x_{1}, x_{2}, ..., x_{K} &\mapsto x_{K+1}
\end{aligned}
$$

When using a transformer (decoder) architecture, we can use the same input sequence for all subsequence predictions, but mask out the positions that are not relevant for the prediction.
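
The shifted targets and the per-prefix predictions can be made concrete with a small sketch. The integer sequence below is a stand-in for encoded characters and is purely illustrative.

```python
import torch

seq = torch.arange(10, 17)   # stand-in for encoded characters x_1, ..., x_{K+1} with K = 6
K = len(seq) - 1

x = seq[:K]        # inputs  [x_1, ..., x_K]
y = seq[1:K + 1]   # targets [x_2, ..., x_{K+1}]

# each prefix of x is paired with the character that follows it
pairs = [(x[:t + 1].tolist(), y[t].item()) for t in range(K)]
for context, target in pairs:
    print(f"{context} -> {target}")

# causal mask: row t marks the positions that are visible when predicting y[t]
mask = torch.ones(K, K).tril().bool()
print(mask)
```

A single pass over `x` with this mask yields all `K` predictions at once, which is why one input sequence can serve every prefix.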


# **Dataset**

In [45]:
# load data
with open('tiny_shakespeare.txt', 'r') as f:
    corpus = f.read()
print(corpus[:100], '...', sep="")

# map each character to an integer and vice versa
chars = sorted(set(corpus))
idx2char = dict(enumerate(chars))
char2idx = {v: k for (k, v) in idx2char.items()}
print(f'\nNo. characters: {len(corpus)}. Unique set: {chars}')

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You...

No. characters: 1115393. Unique set: ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


Rather than precompiling all possible sequences with a sliding window, we use a `get_batch(split)` function that samples batches at random on the fly. Consecutive windows overlap heavily, so storing all of them would require roughly $K$ times more memory than the corpus itself.
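
To make the memory argument concrete, the sketch below counts the elements stored by each approach on a toy corpus. The corpus and sizes are illustrative assumptions, and `unfold` is used here only to materialize the sliding windows for comparison.

```python
import torch

torch.manual_seed(0)

corpus_len, block_size = 1000, 20              # illustrative sizes (assumed)
enc = torch.randint(0, 65, (corpus_len,))      # stand-in for an encoded corpus

# sliding window: materialize every (input, target) window up front
windows = enc.unfold(0, block_size + 1, 1)     # (corpus_len - block_size, block_size + 1)
X_all = windows[:, :-1].clone()
Y_all = windows[:, 1:].clone()

# random sampling: keep only the flat corpus and slice a batch on demand
idxs = torch.randint(corpus_len - block_size, (4,))
X_batch = torch.stack([enc[int(i):int(i) + block_size] for i in idxs])

ratio = X_all.numel() / enc.numel()
print(f"precompiled inputs: {X_all.numel()} elements | corpus: {enc.numel()} elements ({ratio:.1f}x)")
```

The precompiled inputs alone take close to `block_size` times the storage of the flat corpus, and the targets double that again.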

In [41]:
# Form dataset from corpus
block_size = 20 # context size for next character prediction
batch_size = 32 # batch size for training

# encode corpus (char to int); note that only the first 1000 characters are used here
enc_corpus = [char2idx[i] for i in corpus[:1000]]

# split into train and validation sets
split_idx = int(len(enc_corpus)*0.9)
enc_train_corpus = enc_corpus[:split_idx]
enc_val_corpus = enc_corpus[split_idx:]

def get_batch(split):
    enc_corpus = enc_train_corpus if split == 'train' else enc_val_corpus
    idxs = torch.randint(len(enc_corpus) - block_size, (batch_size,))
    X_batch = torch.stack([torch.tensor(enc_corpus[i:i+block_size]) for i in idxs]) # stack along batch dimension (0)
    Y_batch = torch.stack([torch.tensor(enc_corpus[i+1:i+1+block_size]) for i in idxs])
    return X_batch, Y_batch

# X, Y = get_batch('train')
# print(f"X.shape: {X.shape} | X.dtype: {X.dtype}")
# print(f"Y.shape: {Y.shape} | Y.dtype: {Y.dtype}")
# for i in range(3):
#     print(f"{X[i]} ---> {Y[i]}")

# **Transformer model**

In this notebook we present a solution using the transformer architecture. As we are only interested in generating text without any context or prompt, we only use the decoder part of the transformer. We begin by implementing the decoder stack without any modularity in order to simplify the code and provide a clear overview of the model. We will then refactor the code to make it more modular and reusable.

When devising such a model, it is important to increase complexity gradually in order to ensure that the model is able to learn and actually benefits from the added complexity. Consequently, one would generally begin without any attention layers and with a reduced context length. One may then increase the context length and use a simplified communication scheme between the tokens (e.g. mean pooling). Finally, one may add attention layers and increase the number of layers. In the code below we skip this process and jump straight to the final model for brevity.
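
As an illustration of the intermediate step mentioned above, the following sketch implements causal mean pooling: each token's representation is replaced by the average of itself and its predecessors. It is not part of the notebook's final model, and the shapes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

B, T, C = 2, 4, 3            # batch, time and channel dimensions (assumed sizes)
x = torch.randn(B, T, C)

# lower-triangular weights, with rows normalized so that row t averages positions 0..t
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=-1, keepdim=True)

x_pooled = wei @ x           # (T, T) @ (B, T, C) -> (B, T, C) via broadcasting

# reference result computed with an explicit loop over positions
x_ref = torch.stack([x[:, :t + 1].mean(dim=1) for t in range(T)], dim=1)
print(torch.allclose(x_pooled, x_ref, atol=1e-6))
```

Self-attention can be viewed as replacing these uniform weights with data-dependent ones while keeping the same lower-triangular structure.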

In [92]:
batch_s = 3     # batch size
seq_l = 5       # sequence length
vocab_s = 6     # vocab size
emb_d = 8       # embedding dimension
h = 2           # number of heads

# basic outline of transformer: 
x_batch = torch.randint(0, vocab_s, (batch_s, seq_l)) # (batch_size, seq_len)

# embedding layer
W_emb = torch.randn(vocab_s, emb_d) # (vocab_size, embedding_dim)
x_emb = W_emb[x_batch] # (batch_size, seq_len, embedding_dim)
# batch_size, seq_len and embedding_dim are often referred to as B, T and C in the literature,
# which stands for batch dimension, temporal dimension and channel dimension, respectively.

# positional encoding
W_pos = torch.randn(seq_l, emb_d) # (seq_len, embedding_dim)
pos_idxs = torch.arange(seq_l) # (seq_len,)
x_pos = W_pos[pos_idxs] # (seq_len, embedding_dim)
x_emb = x_emb + x_pos # (batch_size, seq_len, embedding_dim)

# masked multi-head self-attention 
W_kqv = torch.randn(emb_d, 3*emb_d) # (embedding_dim, 3*embedding_dim)
k, q, v = (x_emb @ W_kqv).split(emb_d, dim=-1) # (batch_size, seq_len, embedding_dim) (3x)
# We split the embedding dimension into h equal parts (one for each attention head).
# Subsequently, we transpose the 1st and 2nd dimensions to obtain batch_s * h batches of
# seq_l x (emb_d // h) matrices over which attention is performed.
# When devising such efficient implementations, it is important to test their correctness on small examples.
k = k.view(batch_s, seq_l, h, emb_d // h).transpose(1, 2)
q = q.view(batch_s, seq_l, h, emb_d // h).transpose(1, 2)
v = v.view(batch_s, seq_l, h, emb_d // h).transpose(1, 2)

# We now have `seq_l` queries, keys and values per head. Each query (a projection of a token/character embedding) is used to
# generate a new embedding for that token/character from itself and its predecessors (at most seq_l positions). To ensure that
# no information flows from later tokens, we mask out the affinities/weights to all tokens that come after it. This is done
# prior to applying softmax by setting those weights to `-inf`, so that after softmax they become `exp(-inf)/sum(...)`, which is 0.
# This process is called causal masking and is what differentiates a decoder transformer from an encoder transformer.
tril_mask = torch.ones((seq_l, seq_l)).tril()
wei = q @ k.transpose(-2, -1) / (emb_d // h) ** 0.5 # dot product amplifies variance ~> divide by sqrt(head_dim) to scale back to unit variance
wei = wei.masked_fill(tril_mask == 0, -torch.inf)
wei = F.softmax(wei, dim=-1) # (batch_size, h, seq_len, seq_len)
x_emb_att = wei @ v # (batch_size, h, seq_len, embedding_dim // h)
x_emb_att = x_emb_att.transpose(1, 2).contiguous().view(batch_s, seq_l, emb_d) # (batch_size, seq_len, embedding_dim)

# add & norm
x_emb = x_emb + x_emb_att
x_emb = F.layer_norm(x_emb, (emb_d,)) # (batch_size, seq_len, embedding_dim)

# cross-attention module
# As we do not require any cross-attention over a prompt or context, we skip this step.

# feed-forward module
d_ff = 4 * emb_d
W_ff1 = torch.randn((emb_d, d_ff))
W_ff2 = torch.randn((d_ff, emb_d))
x_ff = F.relu(x_emb @ W_ff1) @ W_ff2 # (batch_size, seq_len, embedding_dim)

# add & norm
x_emb = x_emb + x_ff
x_emb = F.layer_norm(x_emb, (emb_d,)) # (batch_size, seq_len, embedding_dim)
# In contrast to batch normalization, layer normalization normalizes over the embedding or channel dimension.
# This means that there is also no distinction between training and evaluation time, as the normalization
# is not based on the batch statistics.

# output layer
W_out = torch.randn((emb_d, vocab_s)) 
x_out = x_emb @ W_out # (batch_size, seq_len, vocab_size)

# softmax (x_out holds the logits; softmax turns them into probabilities)
probs = F.softmax(x_out, dim=-1) # (batch_size, seq_len, vocab_size)
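
Following the earlier advice to test such implementations on small examples, one way to check the causal mask is to perturb the last token and confirm that the attention output at all earlier positions is unchanged. The helper `causal_self_attention` below is a hypothetical repackaging of the masked attention steps above, not a function defined elsewhere in this notebook.

```python
import torch
import torch.nn.functional as F

def causal_self_attention(x, W_kqv, h):
    """Masked multi-head self-attention on x of shape (B, T, C)."""
    B, T, C = x.shape
    k, q, v = (x @ W_kqv).split(C, dim=-1)
    # (B, T, C) -> (B, h, T, C // h)
    k, q, v = (t.view(B, T, h, C // h).transpose(1, 2) for t in (k, q, v))
    wei = q @ k.transpose(-2, -1) / (C // h) ** 0.5   # scale by sqrt(head_dim)
    mask = torch.ones(T, T).tril()
    wei = wei.masked_fill(mask == 0, float('-inf'))   # hide future positions
    wei = F.softmax(wei, dim=-1)
    out = wei @ v                                     # (B, h, T, C // h)
    return out.transpose(1, 2).contiguous().view(B, T, C)

torch.manual_seed(0)
B, T, C, h = 3, 5, 8, 2
x = torch.randn(B, T, C)
W_kqv = torch.randn(C, 3 * C)

out = causal_self_attention(x, W_kqv, h)

# perturb only the last token and recompute
x_perturbed = x.clone()
x_perturbed[:, -1] += 1.0
out_perturbed = causal_self_attention(x_perturbed, W_kqv, h)

earlier_unchanged = torch.allclose(out[:, :-1], out_perturbed[:, :-1])
last_changed = not torch.allclose(out[:, -1], out_perturbed[:, -1])
print(f"earlier positions unchanged: {earlier_unchanged} | last position changed: {last_changed}")
```

If the mask were omitted or applied after the softmax, the first check would fail, because earlier positions would then depend on the perturbed token.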