Imports

In [1]:
# This will be much more dense in terms of imports
import torch
import torch.nn.functional
import numpy as np
import math
import random
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from datasets import load_dataset
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
import matplotlib.pyplot
import contextlib
from torch.utils.checkpoint import checkpoint
from torch.optim import AdamW

Setting the device

In [2]:
# Training this model on a CPU is impractically slow, so connect to a GPU runtime if you can
# The first CUDA device is used when available; otherwise everything falls back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Classes (Attention_Head, Multi_Headed_Attention, Transformer_Block)

In [3]:
class Attention_Head(torch.nn.Module):
  """A single masked self-attention head (scaled dot-product attention).

  Args:
    embedding_size: Width of the embeddings fed into the head.
    head_embedding_size: Width of the query/key/value projections, and of the head's output.
  """
  def __init__(self, embedding_size, head_embedding_size):
    super().__init__()
    self.head_embedding_size = head_embedding_size
    # Bias-free torch.nn.Linear layers behave like plain projection matrices while
    # still being registered as trainable parameters of the module
    self.query = torch.nn.Linear(embedding_size, head_embedding_size, bias = False)
    self.key = torch.nn.Linear(embedding_size, head_embedding_size, bias = False)
    self.value = torch.nn.Linear(embedding_size, head_embedding_size, bias = False)

  def forward(self, embeddings, mask):
    """Return one context vector per position.

    Args:
      embeddings: Tensor whose last dimension is embedding_size.
      mask: Boolean tensor broadcastable to the attention scores; True marks the
        positions a query is NOT allowed to attend to.
    """
    # Project the embeddings down to the head's width
    q = self.query(embeddings)
    k = self.key(embeddings)
    v = self.value(embeddings)
    # Similarity of every query with every key, scaled by sqrt(head width)
    scale = math.sqrt(self.head_embedding_size)
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    # Masked positions get -inf so they receive zero weight after the softmax
    scores = scores.masked_fill(mask, float("-inf"))
    # Normalise to attention weights; the cast keeps the weights in the same dtype as the values
    weights = torch.nn.functional.softmax(scores, dim = -1).to(v.dtype)
    # Weighted sum of the value vectors
    return torch.matmul(weights, v)
  # Note: there is no optimisation method on this class; parameter updates are left to a PyTorch optimizer

class Multi_Headed_Attention(torch.nn.Module):
  """Runs several Attention_Head modules in parallel and mixes their outputs.

  Args:
    embedding_size: Width of the incoming embeddings; must be divisible by num_heads.
    num_heads: Number of attention heads; each head works at width embedding_size // num_heads.

  Raises:
    ValueError: If num_heads is not positive or does not divide embedding_size.
  """
  def __init__(self, embedding_size, num_heads):
    super().__init__()
    # Without this check the concatenated head outputs would be narrower than embedding_size
    # and the output projection would fail later with an opaque shape error
    if num_heads <= 0 or embedding_size % num_heads != 0:
      raise ValueError(
        f"embedding_size ({embedding_size}) must be divisible by a positive num_heads ({num_heads})"
      )
    # A ModuleList registers every head so its parameters are visible to the optimizer
    self.attentionHeads = torch.nn.ModuleList([
      Attention_Head(embedding_size, embedding_size // num_heads) for _ in range(num_heads)
    ])
    # Output projection applied to the concatenated head outputs
    self.linear = torch.nn.Linear(embedding_size,embedding_size, bias = True)

  def forward(self, inputs, mask):
    """Apply every head to `inputs`, concatenate along the last dimension and project.

    Args:
      inputs: Tensor whose last dimension is embedding_size.
      mask: Attention mask forwarded unchanged to every head.
    """
    # Heads are invoked through __call__ (not .forward) so that module hooks keep working
    attn_outputs = torch.cat([head(inputs, mask) for head in self.attentionHeads], dim = -1)
    return self.linear(attn_outputs)

class Transformer_Block(torch.nn.Module):
  """One pre-norm Transformer layer: self-attention followed by a feed-forward network.

  Both sub-layers are wrapped in a residual connection with dropout applied to the
  sub-layer output before it is added back.

  Args:
    embedding_size: Width of the embeddings flowing through the block.
    num_heads: Number of attention heads used by the attention sub-layer.
  """
  def __init__(self, embedding_size, num_heads):
    super().__init__()
    hidden_size = 4*embedding_size
    # Dropout (p = 0.1) regularises both residual branches
    self.dropout = torch.nn.Dropout(p = 0.1)
    # One LayerNorm in front of each sub-layer
    self.layer_norm_1 = torch.nn.LayerNorm(embedding_size, eps = 1e-5)
    self.layer_norm_2 = torch.nn.LayerNorm(embedding_size, eps = 1e-5)
    # Feed-forward network: expand to four times the width, apply the activation, project back down
    self.linear_i = torch.nn.Linear(embedding_size,hidden_size, bias = True)
    self.linear_f = torch.nn.Linear(hidden_size,embedding_size, bias = True)
    # Attention sub-layer
    self.attention = Multi_Headed_Attention(embedding_size,num_heads)

  def forward(self, inputs, mask):
    """Return `inputs` enriched with attention and feed-forward updates.

    Args:
      inputs: Tensor whose last dimension is embedding_size.
      mask: Attention mask handed to the attention sub-layer.
    """
    # Residual branch 1: normalise, attend, drop out, add back
    attended = self.attention(self.layer_norm_1(inputs),mask)
    inputs = inputs + self.dropout(attended)
    # Residual branch 2: normalise, expand, GELU (tanh approximation), project back, drop out, add back
    hidden = self.linear_i(self.layer_norm_2(inputs))
    hidden = torch.nn.functional.gelu(hidden, approximate='tanh')
    return inputs + self.dropout(self.linear_f(hidden))

Creating and pre-processing our dataset

In [None]:
# Tunable parameters. A smaller corpus is quicker to fit on but may give less intelligible text than a larger one; the same trade-off applies to the vocabulary size
len_corpus = 200_000_000
vocab_size = 17500
# Fixed seed for the shuffle below, so the chunk order (and therefore the train/val split) is the same on every run
shuffle_seed = 0
current_chars = 0
target_chars = len_corpus
subset_text = []
# The corpus is streamed from the "rojagtap/bookcorpus" dataset. The goal is coherent phrase completion, so a corpus of book text is used
ds = load_dataset("rojagtap/bookcorpus", split="train", streaming=True)
# Collect chunks until the character budget is reached; the final chunk is truncated so the budget is not exceeded
for chunk in ds:
  text = chunk["text"]
  if current_chars + len(text) > target_chars:
    subset_text.append(text[:target_chars - current_chars])
    break
  subset_text.append(text)
  current_chars += len(text)

# Cleaning: strip surrounding whitespace, drop chunks shorter than 10 characters and drop exact duplicates
added = set()
cleaned = []
for chunk in subset_text:
  chunk = chunk.strip()
  if len(chunk) < 10 or chunk in added:
    continue
  cleaned.append(chunk)
  added.add(chunk)
if not cleaned:
  raise ValueError("No usable text chunks were collected; check the dataset and len_corpus")

# Build a BPE tokenizer that splits on whitespace/punctuation before learning merges
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size=vocab_size, show_progress=True)
# A private Random instance gives a reproducible shuffle without touching the global random state
random.Random(shuffle_seed).shuffle(cleaned)
# Trains the tokenizer
tokenizer.train_from_iterator(iter(cleaned), trainer=trainer)
assert len(tokenizer.get_vocab()) > 0
print(len(tokenizer.get_vocab()))

# Encode the whole cleaned corpus into one long sequence of token ids
ids = tokenizer.encode("\n\n".join(cleaned)).ids
print(f"Total number of tokens in dataset: {len(ids):,}")
data = torch.tensor(ids,dtype = torch.long).cpu()
assert data.numel() > 0, "the encoded corpus is empty"
# Every id is later used as a row index into an embedding table with vocab_size rows
assert int(data.max()) < vocab_size, "tokenizer produced an id outside the embedding table"
# 90% of the tokens are used for training, the remaining 10% for validation
n_train = int(0.9*data.numel())
assert 0 < n_train < data.numel(), "corpus too small to split into train and validation sets"
train_data = data[:n_train]
val_data = data[n_train:]

Defining a batching method

In [6]:
# minGPT by Andrej Karpathy
# https://github.com/karpathy/minGPT
# MIT License
# This block of code is inspired by Andrej Karpathy's "Let's Build GPT from Scratch"
def get_batch(split, batch_size, seq_len):
  """Sample a random batch of (input, target) token-id windows.

  The target is the input shifted one token to the right, so a single window yields
  seq_len training examples. For the input "The dog was happy that" the target is
  "dog was happy that he", giving predictions for "The", "The dog", "The dog was",
  "The dog was happy" and "The dog was happy that".

  Args:
    split: 'train' samples from train_data; any other value samples from val_data.
    batch_size: Number of windows in the batch.
    seq_len: Number of tokens in each window.

  Returns:
    A tuple (x, y) of CPU tensors of token ids, each of shape (batch_size, seq_len).
    They are placed in pinned memory when CUDA is available.

  Raises:
    ValueError: If the chosen split holds fewer than seq_len + 1 tokens.
  """
  data_split = train_data if split =='train' else val_data
  # A window starting at i reads tokens i .. i + seq_len (inclusive, because of the shifted
  # target), so the valid starting positions are 0 .. len(data_split) - seq_len - 1
  num_starts = len(data_split) - seq_len
  if num_starts <= 0:
    raise ValueError(
      f"split '{split}' has {len(data_split)} tokens but seq_len={seq_len} needs at least {seq_len + 1}"
    )
  # The upper bound of randint is exclusive
  ix = torch.randint(0, num_starts, (batch_size,), device='cpu')
  offsets = torch.arange(seq_len, device='cpu').unsqueeze(0)
  x = data_split[ix.unsqueeze(1) + offsets]
  y = data_split[ix.unsqueeze(1) + offsets + 1]
  # Pinned memory only helps host-to-GPU copies and needs an accelerator, so skip it on CPU-only runtimes
  if torch.cuda.is_available():
    x = x.pin_memory()
    y = y.pin_memory()
  return x, y

Transformer Class

In [5]:
class Transformer(torch.nn.Module):
  """Decoder-only Transformer language model bundled with its training loop and a top-k sampler.

  The class reads the notebook-level globals `vocab_size` and `device` when it is
  constructed, and calls the notebook-level `get_batch` while training.

  Args:
    embedding_size: Width of the token embeddings.
    num_heads: Number of attention heads in every layer.
    num_layers: Number of Transformer_Block layers.
    lr: Peak learning rate handed to AdamW.

  Note:
    Compared with earlier revisions, `vocab_bias` is now a registered parameter and the
    final LayerNorm is applied in `forward`. Checkpoints written by earlier revisions may
    lack the `vocab_bias` entry and were trained without the final norm.
  """
  def __init__(self,embedding_size,num_heads,num_layers, lr):
    super().__init__()
    self.embedding_size = embedding_size
    self.num_heads = num_heads
    # A ModuleList registers every layer so its parameters are visible to the optimizer
    self.layers = torch.nn.ModuleList([
        Transformer_Block(embedding_size,num_heads) for _ in range(num_layers)
    ])
    # Final LayerNorm, applied to the output of the last layer before the vocabulary projection
    self.layer_norm = torch.nn.LayerNorm(embedding_size)
    # Loss histories, kept for plotting
    self.training_loss_history = []
    self.val_loss_history = []
    # Token embedding table (a lookup indexed by token id) and the bias of the output projection
    self.embedding_table = torch.nn.Embedding(vocab_size, embedding_size)
    # The Parameter is assigned directly: calling .to(device) on it would return a plain tensor
    # on CUDA, which would then not be registered, optimised or saved in the state_dict
    self.vocab_bias = torch.nn.Parameter(torch.zeros(vocab_size))
    # Move every parameter to the target device BEFORE building the optimizer, so the fused
    # AdamW implementation sees CUDA parameters
    self.to(device)
    # AdamW optimizer; the learning rate is adjusted over time by the schedulers in train_model
    self.optimizer_lr = lr
    fused = (device.type == "cuda")
    self.optimizer = torch.optim.AdamW(self.parameters(),lr=lr,betas=(0.9, 0.95),eps=1e-8,weight_decay=0.0,fused=fused)

  def positional_encoding(self, seq_len, embedding_size):
    """Return the (seq_len, embedding_size) sinusoidal table from "Attention Is All You Need".

    Even columns hold sines and odd columns hold cosines. Row i depends only on i, so a
    longer table always starts with the rows of a shorter one.
    """
    pos_encoding = torch.zeros(seq_len,embedding_size)
    pos = torch.arange(0,seq_len, dtype = torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embedding_size, 2).float() * -(math.log(10000.0) / embedding_size))
    angles = pos * div_term
    pos_encoding[:, 0::2] = torch.sin(angles)
    # With an odd embedding_size there is one fewer odd column than even columns
    pos_encoding[:, 1::2] = torch.cos(angles[:, :embedding_size // 2])
    return pos_encoding

  def forward(self, inputs):
    """Map token ids of shape (batch_size, seq_len) to logits of shape (batch_size, seq_len, vocab_size).

    The logits at position t are the unnormalised scores for the token that follows position t.
    """
    _, seq_len = inputs.shape
    # Reuse the cached positional table when it is long enough and on the right device; otherwise
    # rebuild it, so forward also works when called before train_model or generate
    pos_encoding = getattr(self, "pos_encoding", None)
    if pos_encoding is None or pos_encoding.size(0) < seq_len or pos_encoding.device != inputs.device:
      pos_encoding = self.positional_encoding(seq_len, self.embedding_size).to(inputs.device)
      self.pos_encoding = pos_encoding
    # Look up the token embeddings and add the positional encoding
    embeddings = self.embedding_table(inputs) + pos_encoding[:seq_len]
    # Causal mask: True above the diagonal, i.e. on the future positions a token must not see
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=inputs.device), diagonal=1)
    mask = mask.unsqueeze(0)
    # Pass the embeddings through every layer
    for block in self.layers:
      embeddings = block(embeddings,mask)
    # Normalise, then project to the vocabulary. The projection reuses the embedding table
    # (weight tying), so no separate output matrix is stored
    embeddings = self.layer_norm(embeddings)
    logits = torch.matmul(embeddings, self.embedding_table.weight.T)+ self.vocab_bias
    return logits

  def train_model(self, num_batches, batch_size, seq_len):
    """Train for `num_batches` steps on random batches returned by get_batch.

    The learning rate warms up linearly, is held at its peak until 35% of the run, then
    follows a cosine decay. A training loss is printed every 50 steps and a validation
    loss every 200 steps; both are appended to the loss histories.

    Raises:
      RuntimeError: If the forward pass fails on 10 consecutive batches.
    """
    # Training mode enables dropout
    self.train()
    model_device = self.embedding_table.weight.device
    hold_until = int(0.35 * num_batches)
    # Warm-up lasts 1000 steps, but never longer than the hold point: SequentialLR needs
    # non-decreasing milestones, which short runs would otherwise violate
    warmup_steps = max(1, min(1000, hold_until))
    hold_until = max(warmup_steps, hold_until)
    # Linear warm-up towards the peak learning rate
    warmup = LambdaLR(self.optimizer, lr_lambda=lambda s: (s+1)/max(1,warmup_steps))
    # Constant phase at the peak learning rate
    flat = LambdaLR(self.optimizer, lr_lambda=lambda s: 1.0)
    # Cosine decay down to min_lr, which is capped so it never exceeds the peak learning rate
    min_lr = min(1e-4, self.optimizer_lr)
    cosine = CosineAnnealingLR(self.optimizer, T_max=max(1, num_batches - hold_until), eta_min=min_lr)
    # Composite scheduler
    scheduler = SequentialLR(self.optimizer, [warmup, flat, cosine], milestones=[warmup_steps, hold_until])
    # Mixed precision: the forward pass runs under bfloat16 autocast on CUDA devices that support it.
    # The device type is compared, because comparing a torch.device with the string "cuda" is never True
    use_autocast = (model_device.type == "cuda") and torch.cuda.is_bf16_supported()
    autocast_ctx = (lambda: torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)) if use_autocast else contextlib.nullcontext
    # Build the positional table once for the whole run
    self.pos_encoding = self.positional_encoding(seq_len, self.embedding_size).to(model_device)
    # A failing batch is reported and skipped, but a run of failures aborts training instead of
    # silently printing one error per remaining batch
    max_consecutive_failures = 10
    consecutive_failures = 0
    for idx in range(num_batches):
      # Gets a training batch
      xb, yb = get_batch("train", batch_size, seq_len)
      xb = xb.to(model_device, non_blocking=True)
      yb = yb.to(model_device, non_blocking=True)
      try:
        with autocast_ctx():
          logits = self(xb)
          assert logits.shape[:-1] == yb.shape, f"Shape mismatch: {logits.shape} vs {yb.shape}"
          loss = torch.nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
      except Exception as e:
        print(f"Error at batch {idx}: {e}")
        consecutive_failures += 1
        if consecutive_failures >= max_consecutive_failures:
          raise RuntimeError(f"Training aborted after {consecutive_failures} consecutive failed batches") from e
        continue
      consecutive_failures = 0
      # Optimisation step with gradient clipping
      self.optimizer.zero_grad(set_to_none = True)
      loss.backward()
      torch.nn.utils.clip_grad_norm_(self.parameters(), 2.0)
      self.optimizer.step()
      # Advance the learning-rate schedule
      scheduler.step()
      # Read the loss value once, for both the history and the log line
      loss_value = loss.item()
      self.training_loss_history.append(loss_value)
      if idx%50 == 0:
        print(f"[train] {idx}: {loss_value:.4f}")
      # Validation every 200 steps
      if idx%200 == 0 :
        # Evaluation mode disables dropout
        self.eval()
        try:
          with torch.no_grad(), autocast_ctx():
            vx, vy = get_batch("val", batch_size, seq_len)
            vx = vx.to(model_device, non_blocking=True)
            vy = vy.to(model_device, non_blocking=True)
            v_logits = self(vx)
            v_loss = torch.nn.functional.cross_entropy(
                v_logits.reshape(-1, v_logits.size(-1)), vy.reshape(-1)
            )
          v_loss_value = v_loss.item()
          self.val_loss_history.append(v_loss_value)
          print(f"[val] {idx}: {v_loss_value:.4f}")
        finally:
          # Always return to training mode, even if validation raised
          self.train()

  def generate(self,prompt, tokenizer, max_new_tokens, k):
    """Continue `prompt` with up to `max_new_tokens` tokens using top-k sampling.

    At every step the next token is sampled from the k highest-scoring candidates, which
    tends to be less repetitive than always taking the single most likely token.

    Args:
      prompt: Text to continue.
      tokenizer: Object providing encode(text).ids and decode(list_of_ids).
      max_new_tokens: Number of tokens to append.
      k: Number of candidates to sample from; capped at the vocabulary size.

    Returns:
      The decoded prompt followed by the generated tokens.

    Raises:
      ValueError: If the prompt encodes to no tokens, or k is smaller than 1.
    """
    if k < 1:
      raise ValueError("k must be at least 1")
    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
      raise ValueError("prompt must encode to at least one token")
    model_device = self.embedding_table.weight.device
    # Encode the prompt and add a batch dimension
    prompt = torch.tensor(prompt_ids, dtype = torch.long, device = model_device).unsqueeze(0)
    # Sampling must run without dropout; the previous mode is restored afterwards
    was_training = self.training
    self.eval()
    try:
      with torch.no_grad():
        # One positional table covering the final length; forward slices what it needs
        total_len = prompt.shape[1] + max(0, max_new_tokens)
        self.pos_encoding = self.positional_encoding(total_len, self.embedding_size).to(model_device)
        for _ in range(max_new_tokens):
          logits = self(prompt)
          # Keep only the scores for the token that follows the last position
          next_token_logits = logits[:,-1,:]
          top_values, _ = torch.topk(next_token_logits, min(k, next_token_logits.size(-1)))
          # Everything below the k-th best score is removed from the distribution
          next_token_logits = next_token_logits.masked_fill(next_token_logits < top_values[:, -1, None], float("-inf"))
          probs = torch.nn.functional.softmax(next_token_logits, dim = -1)
          # Sample one token and append it to the sequence
          next_token = torch.multinomial(probs, num_samples=1, replacement=False)
          prompt = torch.cat((prompt,next_token), dim = 1)
    finally:
      self.train(was_training)

    return tokenizer.decode(prompt.squeeze(0).tolist())

Defining the model

In [None]:
# Model configuration: 512-dimensional embeddings, 8 attention heads per layer, 7 layers, learning rate 5e-4
transformer = Transformer(embedding_size=512, num_heads=8, num_layers=7, lr=5e-4)
# Place the model on the selected device (the GPU when one is available)
transformer.to(device)
# Speed-oriented settings that only apply on a CUDA device: TF32 matmuls and a compiled model
# NOTE(review): train_model and generate invoke self.forward from inside the class; confirm that
# the compiled wrapper is actually exercised on those paths
if device.type == "cuda":
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

# Print the name and shape of every parameter in the model
for name, param in transformer.named_parameters():
    print(name, param.shape)

Training

In [None]:
#transformer.train_model(20000,128,256)

Loading a saved model

In [None]:
 # Models aren't loaded perfectly after saving, so a small amount of training might be necessary to get the model back to its original performance
# The snippet below is wrapped in a triple-quoted string, so none of it executes; remove the
# surrounding quotes to run it. It rebuilds the model, reloads the tokenizer from
# "my_tokenizer.json", applies the same CUDA settings and torch.compile wrapping as the
# model-definition cell, and then loads the weights from "transformer.pt".
# NOTE(review): torch.load is called with weights_only=False, which allows arbitrary pickled
# objects to be deserialised. Only load checkpoint files you trust, and confirm whether
# weights_only=True would be sufficient for this state_dict.
"""
transformer = Transformer(512,8,7,5e-4) # Make sure the model is instantiated the same as your saved run
tokenizer = Tokenizer.from_file("my_tokenizer.json")
if device.type == "cuda":
  torch.set_float32_matmul_precision("high")
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True
  transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
transformer.load_state_dict(torch.load("transformer.pt", map_location=device, weights_only=False))
transformer.to(device)
"""


Saving a pre-trained model

In [25]:
# Write the model weights (state_dict only) to "transformer.pt"
torch.save(
    transformer.state_dict(),
    "transformer.pt",
    _use_new_zipfile_serialization=True,
)
# Write the trained tokenizer next to the weights so both can be restored together
tokenizer.save("my_tokenizer.json")

Prompting

In [None]:
# With the saved parameters the model can complete short sentences, but the output may become
# less coherent as the generated text gets longer.
prompt = "Hello, how are you doing? I'm doing fine, responded"
# generate() decodes the whole sequence (prompt plus new tokens), so nothing is prepended here;
# prefixing prompt[0] would only duplicate the first character of the prompt.
# Arguments: append 20 new tokens, sampling each one from the 25 highest-scoring candidates.
response = transformer.generate(prompt, tokenizer, 20, 25)
print(response)