# Pre-training of GPT-2-124-million 
- pre-train on novel Anna Karenina
- adaptations
    - reduce the context length from 1024 to 256 to make sure the training loop will run locally
    - change it back to 1024 when loading / using pre-trained weights
- preprocess
- modules and models
    - define building-block modules
    - define model
- training loop
- evaluation
- inference

In [9]:
import numpy as np
import os
import sys 
from typing import Tuple, Dict, List

# Remember the notebook's working directory; the data path used when reading the raw text is derived from it.
cwd = os.getcwd()

In [None]:
import torch

print(torch.__version__)
# Alternative for NVIDIA GPUs:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use the MPS backend when PyTorch reports it as available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
import torch.nn.functional as F

## preprocess

In [None]:
# load tokenizer
# `import importlib` alone does not guarantee that the `importlib.metadata`
# submodule is loaded (accessing it can raise AttributeError), so import the
# submodule explicitly. This still binds the name `importlib`.
import importlib.metadata

import tiktoken

print("tiktoken version:", importlib.metadata.version("tiktoken"))
# The "gpt2" encoding is the tokenizer used throughout the notebook.
tokenizer = tiktoken.get_encoding("gpt2")

In [None]:
# read in raw text
# The data folder is located by dropping the last 18 characters of the working
# directory and appending the path below (presumably a sibling project folder —
# verify against the local directory layout).
pdata = f"{cwd[:-18]}traditional-NLP/data/"
sys.path.append(pdata)
anna_path = f"{pdata}anna.txt"
with open(anna_path, mode="r", encoding="utf-8") as f:
    text_data = f.read()
print(f"The type of the raw text: {type(text_data)}")
print(f"The beginning of raw text: \n {text_data[:50]}")

In [None]:
# inspect raw text and tokens
# Size of the corpus in characters ...
total_characters = len(text_data)
print(f"total num of characters in Anna Karenina: {total_characters}")
# ... and in tokens after encoding the whole text with the tokenizer.
encoded_text = tokenizer.encode(text_data)
total_tokens = len(encoded_text)
print(f"total num of tokens in Anna Karenina with BPE tokenizer: {total_tokens}")
# Output recorded from a previous run:
# total num of characters in Anna Karenina: 1985223
# total num of tokens in Anna Karenina with BPE tokenizer: 508206

### set parameters

In [None]:
# Hyper-parameters of the model built in this notebook (GPT-2 "124M" layout,
# but with a shortened context window so that it can run locally).
CONFIG_GPT2_124M = {
    "vocab_size": 50257,   # Vocabulary size (rows of the token-embedding table, width of the output logits)
    "context_length": 256, # Shortened context length (orig: 1024); also the number of learned positions
    "emb_dim": 768,        # Embedding dimension
    "n_heads": 12,         # Number of attention heads
    "n_layers": 12,        # Number of layers (stacked transformer blocks)
    "drop_rate": 0.1,      # Dropout rate
    "qkv_bias": False      # Query-key-value bias (whether the Q/K/V projections get a bias term)
}

# Seed PyTorch's random number generator.
torch.manual_seed(123)

### torch dataset dataloader

In [None]:
# create dataset and dataloader

class my_text_dataset(Dataset):
    """Sliding-window dataset of (input, target) token sequences for next-token prediction.

    The raw text is tokenized once. Windows of ``max_length`` tokens are then cut
    every ``stride`` tokens; each target window is its input window shifted one
    token to the right.
    """

    def __init__(self, raw_text:str, tokenizer, max_length:int, stride:int=1):
        """Tokenize ``raw_text`` and pre-compute every input/target window.

        Args:
            raw_text: Text to tokenize.
            tokenizer: Object whose ``encode(text, allowed_special=...)`` returns
                a sequence of integer token ids.
            max_length: Number of tokens in each window.
            stride: Distance (in tokens) between the starts of consecutive windows.
        """
        # one tensor per window; inputs and targets are kept index-aligned
        self.input_tokens_x = []
        self.target_tokens_y = []

        # tokenize the entire text
        tokens = tokenizer.encode(raw_text, allowed_special={"<|endoftext|>"})

        # The target is the input shifted by ONE token (independent of stride).
        # Stopping at len(tokens) - max_length guarantees the target window,
        # which ends one token after the input window, is always full length.
        for i in range(0, (len(tokens)-max_length), stride):
            x_tmp = tokens[i : (i+max_length)]
            y_tmp = tokens[(i+1) : (i+max_length+1)]
            self.input_tokens_x.append(torch.tensor(x_tmp))
            self.target_tokens_y.append(torch.tensor(y_tmp))

    # required by torch.utils.data.Dataset consumers such as DataLoader
    def __len__(self) -> int:
        "Returns the number of rows / pairs of x-y sequences in the dataset"
        return len(self.input_tokens_x)

    # required for subclasses of torch.utils.data.Dataset
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        "Returns one sample: the input token tensor and the target token tensor (X, y)."
        return self.input_tokens_x[idx], self.target_tokens_y[idx]

def my_text_dataloader(raw_text:str, batch_size:int=4, max_length:int=256,
                       stride:int=128, shuffle=True, drop_last=True, num_workers=0):
    """Build a ``DataLoader`` of (input, target) token batches from raw text.

    The text is tokenized with tiktoken's "gpt2" encoding and cut into windows
    by ``my_text_dataset`` (``max_length`` tokens each, ``stride`` apart).
    ``batch_size``, ``shuffle``, ``drop_last`` and ``num_workers`` are passed
    straight through to ``DataLoader``.
    """
    gpt2_tokenizer = tiktoken.get_encoding("gpt2")
    windows = my_text_dataset(raw_text, gpt2_tokenizer, max_length, stride)
    return DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

#### split into T, V, H

In [None]:
# Split the raw text by character position into train / validation / hold-out parts.
total_characters = len(text_data)
print(f"total num of characters in Anna Karenina: {total_characters}")
prop_t, prop_v, prop_h = (0.8,0.1,0.1)
split_idx_t, split_idx_v = int(prop_t * total_characters), int((prop_t+prop_v) * total_characters)
print(f"Split at character index {split_idx_t} between train and valid sets, and at {split_idx_v} betwee valid and hold sets")

d_train = text_data[:split_idx_t]
d_valid = text_data[split_idx_t:split_idx_v]
d_hold  = text_data[split_idx_v:]
print(len(d_train), len(d_valid), len(d_hold))

# Sanity check: every split needs more tokens than one context window, otherwise
# its dataset would contain no samples. `total_tokens` is an int, so the estimated
# per-split token count is compared directly (calling len() on a number raises
# TypeError).
min_tokens = CONFIG_GPT2_124M["context_length"]
assert total_tokens * prop_t > min_tokens, "Not enough tokens for loader_t (training dataloader)"
assert total_tokens * prop_v > min_tokens, "Not enough tokens for loader_v (validation dataloader)"
assert total_tokens * prop_h > min_tokens, "Not enough tokens for loader_h (testing dataloader)"

In [None]:
# Settings shared by every loader. The stride equals the window length, so
# consecutive windows do not overlap.
window = CONFIG_GPT2_124M["context_length"]
loader_common = dict(
    batch_size=2,  # this is only for learning purpose; in practice, batch_size >= 1024 is common
    max_length=window,
    stride=window,
    num_workers=0,
)

# Training batches: shuffled, incomplete last batch dropped.
loader_t = my_text_dataloader(raw_text=d_train, shuffle=True, drop_last=True, **loader_common)

# Validation batches: fixed order, every batch kept.
loader_v = my_text_dataloader(raw_text=d_valid, shuffle=False, drop_last=False, **loader_common)

# Hold-out loader, currently disabled:
# loader_h = my_text_dataloader(raw_text=d_hold, shuffle=False, drop_last=False, **loader_common)

### inspect loaded data

In [None]:
# Print the shapes of every (input, target) batch produced by each loader.
for heading, loader in (("Train loader:", loader_t), ("\nValidation loader:", loader_v)):
    print(heading)
    for x, y in loader:
        print(x.shape, y.shape)

In [None]:
# Count the input tokens each loader yields over one full pass.
train_tokens = sum(batch_x.numel() for batch_x, _ in loader_t)
val_tokens = sum(batch_x.numel() for batch_x, _ in loader_v)

print("Training tokens:", train_tokens)
print("Validation tokens:", val_tokens)
print("All tokens:", train_tokens + val_tokens)

## modules and model
- key components:
    - tokenization - done in my_text_dataloader
    - input embedding
    - positional encoding
    - dropout
    - transformer block
        - layernorm
        - multi-head attention (CONFIG_GPT2_124M["n_heads"] heads per block; the block is repeated CONFIG_GPT2_124M["n_layers"] times)
        - dropout + shortcut
        - layernorm
        - feed-forward
        - dropout + shortcut
    - layernorm
    - output linear layer


### define modules

In [None]:
class Multihead_Causal_Attention(nn.Module):
    """Multi-head self-attention with a causal mask.

    The input is projected to queries, keys and values of width ``d_out``,
    split evenly across ``n_heads`` heads, attended per head with scaled
    dot-product attention (future positions masked out), merged back to
    ``d_out`` features and passed through a final linear layer.
    """

    def __init__(self, d_in, d_out, context_length, n_heads, dropout_rate, qkv_bias=False):
        super().__init__()

        # every head gets an equal slice of the output features
        assert (d_out % n_heads == 0), "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.n_heads = n_heads
        self.d_head = d_out // n_heads

        # learned projections of the input into query / key / value space
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Ones strictly above the diagonal mark the positions a token must not
        # attend to. Registered as a buffer: it is part of the module's state
        # but is not a trainable parameter.
        causal_mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer('mask', causal_mask)

        # dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout_rate)
        # linear layer that mixes the concatenated head outputs
        self.combine_heads = nn.Linear(d_out, d_out)

    def _split_heads(self, projected, batch, n_tokens):
        """Reshape (batch, n_tokens, d_out) into (batch, n_heads, n_tokens, d_head)."""
        per_head = projected.view(batch, n_tokens, self.n_heads, self.d_head)
        return per_head.transpose(1, 2)

    def forward(self, x):
        # x is unpacked as (batch, n_tokens, d_in)
        batch, n_tokens, _ = x.shape

        # project, then give every head its own slice of the features
        q = self._split_heads(self.w_query(x), batch, n_tokens)
        k = self._split_heads(self.w_key(x), batch, n_tokens)
        v = self._split_heads(self.w_value(x), batch, n_tokens)

        # Per-head dot products: (batch, n_heads, n_tokens, d_head) times
        # (batch, n_heads, d_head, n_tokens) gives (batch, n_heads, n_tokens, n_tokens).
        scores = torch.matmul(q, k.transpose(-2, -1))

        # Causal mask, cut down to the current sequence length; masked entries
        # become -inf so softmax assigns them zero weight. The trailing
        # underscore means the fill happens in place.
        blocked = self.mask.bool()[:n_tokens, :n_tokens]
        scores.masked_fill_(blocked, -torch.inf)

        # scale by sqrt(head dimension) before normalising with softmax
        weights = torch.softmax(scores / (self.d_head ** 0.5), dim=-1)
        weights = self.dropout(weights)

        # Weighted sum of the values, then back to (batch, n_tokens, n_heads, d_head)
        # and flattened to (batch, n_tokens, d_out), where d_out = n_heads * d_head.
        per_head_context = torch.matmul(weights, v).transpose(1, 2)
        merged = per_head_context.contiguous().view(batch, n_tokens, self.d_out)

        return self.combine_heads(merged)

In [None]:
# Hand-written counterpart of nn.LayerNorm(emb_dim).
# Mean and variance are computed over the last (embedding) dimension, i.e.
# separately for every token; the learnable scale and shift then apply an
# element-wise linear transformation to the normalised values.
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift."""

    def __init__(self, emb_dim):
        super().__init__()
        # small constant added to the variance so the division never hits zero
        self.eps = 1e-5
        # nn.Parameter (unlike a buffer) is updated during training
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mu = torch.mean(x, dim=-1, keepdim=True)
        # biased variance (divide by N, not N - 1)
        sigma2 = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        x_hat = (x - mu) / (sigma2 + self.eps).sqrt()
        return x_hat * self.scale + self.shift

In [None]:
# Hand-written GELU activation, using the tanh approximation
# (cf. nn.GELU, whose default is the exact erf-based form).
class GELU(nn.Module):
    """GELU activation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""

    def forward(self, x):
        sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
        cubic_term = x + 0.044715 * torch.pow(x, 3)
        gate = 1 + torch.tanh(sqrt_2_over_pi * cubic_term)
        return 0.5 * x * gate

In [None]:
# Feed-forward network applied after multi-head attention in each transformer block:
# expand to 4x the embedding width, apply GELU, project back down.
# Open question from the author: why does this architecture use a 4x expansion and contraction?
class FeedForward(nn.Module):
    """Two-layer MLP (emb_dim -> 4 * emb_dim -> emb_dim) with a GELU in between."""

    def __init__(self, config):
        super().__init__()
        emb_dim = config["emb_dim"]
        hidden_dim = 4 * emb_dim
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x):
        return self.layers(x)

In [None]:
class TransformerBlock(nn.Module):
    """
    One decoder-only transformer block in the style of GPT-2 (124M).

    Two sub-layers, each wrapped as
    layernorm -> sub-layer -> dropout -> shortcut (add the sub-layer's input back):

    - attention sub-layer
        - layernorm
        - multi-head causal attention with config["n_heads"] heads
        - dropout
        - shortcut
    - feed-forward sub-layer
        - layernorm
        - feed-forward
        - dropout
        - shortcut

    Config keys read here: "emb_dim", "context_length", "n_heads", "drop_rate"
    and, optionally, "qkv_bias" (defaults to False when absent).
    """

    def __init__(self, config):
        super().__init__()

        self.lnorm1 = LayerNorm(config["emb_dim"])
        self.mhca = Multihead_Causal_Attention(d_in=config["emb_dim"],
                                               d_out=config["emb_dim"],
                                               context_length=config["context_length"],
                                               n_heads=config["n_heads"],
                                               dropout_rate=config["drop_rate"],
                                               # forward the config's bias setting; without this
                                               # the "qkv_bias" entry was silently ignored
                                               qkv_bias=config.get("qkv_bias", False))
        # one dropout module is reused after both sub-layers
        self.drop_out = nn.Dropout(p=config["drop_rate"])
        self.lnorm2 = LayerNorm(config["emb_dim"])
        self.ff = FeedForward(config)

    def forward(self, x):
        # ---- attention sub-layer ----
        # keep the input for the shortcut / residual connection
        residual_conn = x
        # layer norm is applied before attention
        x = self.lnorm1(x)
        x = self.mhca(x)
        x = self.drop_out(x)
        x = x + residual_conn

        # ---- feed-forward sub-layer ----
        residual_conn = x
        x = self.lnorm2(x)
        x = self.ff(x)
        x = self.drop_out(x)
        x = x + residual_conn
        return x

### define model

In [None]:
# Assemble the building blocks into the full model.
class GPT2_124_model(nn.Module):
    """
    Maps token ids to next-token logits.

    Pipeline:
    - token (input) embedding
    - learned absolute positional embedding, added to the token embedding
    - dropout
    - config["n_layers"] transformer blocks
    - layernorm
    - output linear layer (emb_dim -> vocab_size, no bias)
    """

    def __init__(self, config):
        super().__init__()
        emb_dim = config["emb_dim"]
        vocab_size = config["vocab_size"]

        # token-id -> vector lookup
        self.input_emb = nn.Embedding(vocab_size, emb_dim)
        # one learned vector per position, up to context_length positions
        self.pos_emb = nn.Embedding(config["context_length"], emb_dim)
        # TO TRY RoPE
        # from torchtune.modules import RotaryPositionalEmbeddings
        # rope_dim = config["emb_dim"]/config["n_head"]
        # self.pos_emb = RotaryPositionalEmbeddings(dim=rope_dim)

        self.drop_out = nn.Dropout(config["drop_rate"])

        # the same block architecture stacked n_layers times (separate weights each)
        blocks = [TransformerBlock(config) for _ in range(config["n_layers"])]
        self.transformer_block = nn.Sequential(*blocks)

        self.lnorm = LayerNorm(emb_dim)

        # projects every token vector to vocab_size logits; no extra bias
        self.out_layer = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, input_tokens):
        # input_tokens is unpacked as (batch_size, seq_len)
        _, seq_len = input_tokens.shape
        # position ids 0..seq_len-1, created on the same device as the input
        positions = torch.arange(seq_len, device=input_tokens.device)
        x = self.input_emb(input_tokens) + self.pos_emb(positions)
        x = self.drop_out(x)
        x = self.transformer_block(x)
        x = self.lnorm(x)
        return self.out_layer(x)

## training loop
- A typical training loop
- loop through all training epochs
    - within each epoch, loop through all batches (n_batches = train_size / batch_size)
        - reset (from previous batch iter) the loss gradient to zero 
        - calculate loss on the current batch
        - backpropagate loss gradient 
        - step to update weight and biases for next loop of training
        - calculate training and validation losses
        - visualize losses and sample texts 

## eval
- loss calculation under the hood:
    - get the logits from the transformer output layer
    - convert logits to probabilities with softmax
    - pick out the probability the model assigned to each target token
    - take (-1) * log(probability) [ log(prob) <= 0, so maximizing it towards 0 is the same as minimizing (-1) * log(prob) towards 0 ]
    - average over all tokens: this is the cross-entropy loss between predictions and targets


In [None]:
# NOTE(review): no earlier cell visible in this notebook defines `model`, `inputs`
# or `targets`, so they are created here: a freshly initialised (untrained) model
# and the first validation batch. Remove these lines if they are defined elsewhere.
model = GPT2_124_model(CONFIG_GPT2_124M)
model.eval()  # switch dropout off so the reported loss is not randomly perturbed
inputs, targets = next(iter(loader_v))

with torch.no_grad():
    logits = model(inputs)

# cross_entropy takes one row of logits per prediction and one class index per
# row, so merge the batch and token dimensions of both tensors.
logits_flat = logits.flatten(0, 1)
targets_flat = targets.flatten()

loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)
print(loss)

# The perplexity is often considered more interpretable
# because it can be understood as the effective vocabulary size
# that the model is uncertain about at each step
# (in the example below, that'd be perplexity number of tokens)
perplexity = torch.exp(loss)
print(perplexity)

In [None]:
# calculate loss per batch
def calc_loss_batch(input_batch, target_batch, model, device):
    """Return the cross-entropy loss of ``model`` on one (input, target) batch.

    Both tensors are moved to ``device`` first; the result is the loss tensor
    returned by ``cross_entropy`` (not a Python float).
    """
    inputs_on_device = input_batch.to(device)
    targets_on_device = target_batch.to(device)
    logits = model(inputs_on_device)
    # merge the first two logit dimensions so each row lines up with one flattened target
    flat_logits = logits.flatten(0, 1)
    flat_targets = targets_on_device.flatten()
    return torch.nn.functional.cross_entropy(flat_logits, flat_targets)

# function to compute loss per loader of data during training and validation
# as sum of loss across all batches in the dataloader / number of batches
def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Return the average per-batch loss over (a prefix of) ``data_loader``.

    Args:
        data_loader: Sized iterable of (input_batch, target_batch) pairs.
        model: Model handed to ``calc_loss_batch``.
        device: Device handed to ``calc_loss_batch``.
        num_batches: Evaluate only the first ``num_batches`` batches. ``None``
            means all batches; values above ``len(data_loader)`` are capped.

    Returns:
        The mean of the per-batch losses as a float, or NaN when there is
        nothing to evaluate (empty loader, or ``num_batches`` <= 0).
    """
    n_available = len(data_loader)
    if n_available == 0:
        return float("nan")

    if num_batches is None:
        # iterate over all batches in the dataloader unless num_batches is otherwise given
        num_batches = n_available
    else:
        # never ask for more batches than the loader can provide
        num_batches = min(num_batches, n_available)

    # Guard against division by zero below: zero (or negative) requested batches
    # is treated like an empty loader.
    if num_batches <= 0:
        return float("nan")

    total_loss = 0.
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i >= num_batches:
            break
        # .item() converts the loss tensor to a Python float before summing
        total_loss += calc_loss_batch(input_batch, target_batch, model, device).item()
    return total_loss / num_batches

In [None]:
# Move the model to the selected device and compute the average loss on the
# training and validation loaders.
model.to(device)
# Evaluation mode switches dropout off; call model.train() again before training.
model.eval()
# No gradients are needed for evaluation (the original `troch` was a typo for `torch`).
with torch.no_grad():
    loss_train = calc_loss_loader(data_loader=loader_t, model=model, device=device)
    loss_valid = calc_loss_loader(data_loader=loader_v, model=model, device=device)

print(f"Training loss: {loss_train}")
print(f"Validation loss: {loss_valid}")

## inference