# PART 2: Training a Model from Scratch: Leaving No Rock Unturned

Welcome back to the second installment of our series, "Transformers from Scratch: Leaving No Rock Unturned." In Part 1, we laid the groundwork by comprehensively exploring the inner workings of the Transformer architecture. If you haven't had the chance to dive into the foundations, you can access Part 1 [here](https://github.com/jcolano/transformer_step_by_step/tree/main/step1_basic_transformer).

Now we're about to take on one of the most important and complex aspects of working with Transformers: training a model from scratch. During its initial training phase (pretraining), a Large Language Model (LLM) is trained on large datasets containing both natural-language text and programming code. The primary objective of this phase is for the LLM to learn the complex statistical associations between words and ideas, which establishes its fundamental knowledge base.

In this installment, we'll guide you through every step of the training process, ensuring you have a firm grasp on the following key topics:

1. **Data Preprocessing**: We'll start by loading the dataset, preparing it for training, and instantiating a tokenizer that converts the raw text into the token IDs the model consumes. You'll see how to split the data into training, validation, and test sets, setting the stage for robust model training.

2. **Custom Dataset Class**: We'll introduce a custom dataset class that gives you a tailored way to handle the data and lets you plug your dataset seamlessly into the training pipeline.

3. **Custom Collate Function**: The collate function plays a crucial role in batching and preprocessing your data. We'll show you how to create a custom collate function to efficiently handle variable-length sequences and optimize your training process.

4. **Data Loading**: With data preprocessing in place, we'll construct data loaders that efficiently feed batches of data to your model during training. This step is essential for managing memory and training effectively.

5. **Hyperparameters**: We'll go over the hyperparameters governing the training process. From learning rates to batch sizes, we'll explain the significance of each parameter and guide you in making informed choices.

6. **Training Loop**: The heart of this notebook lies in the training loop. We'll break down the loop into its core components, including the forward pass, loss computation, and backward pass for gradient descent. You'll gain a deep understanding of how these components come together to optimize your model.

Throughout this notebook, we'll provide clear code walkthroughs, explanations, and practical tips to ensure you understand all the details involved in training a Transformer model. By the end of our journey, you'll have the knowledge to better understand and manage the different tools that are typically used to tackle NLP tasks, generate creative text, and explore the vast potential of Transformers.

So, if you're ready to unlock the art of training a Transformer model from scratch, join me on this exciting adventure. Together, we'll leave **no rock unturned** in our quest for learning!


In [29]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader, Dataset
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils.rnn import pad_sequence

#for learning rate decay:
from torch.optim.lr_scheduler import ReduceLROnPlateau

import os
import random
import numpy as np
import math

from datasets import load_dataset
from transformers import GPT2Tokenizer 

from tqdm import tqdm

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

## The Transformer

This is the Transformer we built in Part 1: an autoregressive (decoder-only) language model with about 105.5 million parameters. It contains the following standard components:

* Positional Encoding Class: Learn how transformers account for the order of data without inherently processing it sequentially.
* Scaled Dot-Product Attention Function: Dive into the self-attention mechanism and see how different parts of a sequence attend to each other.
* Attention Head Class: Understand the fundamental building block of the self-attention mechanism.
* Multi-head Attention Class: Discover how transformers harness multiple attention heads to capture various features from the data.
* Feed-forward Module Class: Delve into the feed-forward networks present within the transformer and their role.
* Transformer Block Class: See how the various components come together to form a transformer block.
* Transformer Model: Integrate the different classes to build the complete autoregressive transformer model.

For full details about each of these components, review the [first notebook](https://github.com/jcolano/transformer_step_by_step/tree/main/step1_basic_transformer) of the series.
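For the model to be autoregressive, position *i* must not attend to any position after *i*; otherwise next-token prediction can be solved by simply copying the answer from the input. This is usually enforced with a causal mask that sets the attention scores of future positions to negative infinity before the softmax. The dependency-free sketch below illustrates the idea; `causal_mask` and `masked_softmax` are hypothetical helper names, not part of the notebook.

```python
import math


def causal_mask(seq_len):
    """mask[i][j] is True when query position i must NOT attend to key position j (j > i)."""
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


def masked_softmax(scores, mask_row):
    """Softmax over one row of attention scores, with masked positions set to -inf."""
    masked = [float("-inf") if blocked else s for s, blocked in zip(scores, mask_row)]
    peak = max(masked)                           # the diagonal is never masked, so peak is finite
    exps = [math.exp(s - peak) for s in masked]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Equal raw scores everywhere, so each position spreads its attention evenly
# over itself and the positions before it, and gives zero weight to the future.
mask = causal_mask(4)
weights = [masked_softmax([1.0, 1.0, 1.0, 1.0], mask[i]) for i in range(4)]
for row in weights:
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0, 0.0]
# [0.5, 0.5, 0.0, 0.0]
# [0.333, 0.333, 0.333, 0.0]
# [0.25, 0.25, 0.25, 0.25]
```

In PyTorch, the same mask can be built with `torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)` and applied with `scores.masked_fill(mask, float('-inf'))` just before `F.softmax`.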

In [31]:
class PositionalEncoding(nn.Module):
    def __init__(self, max_tokens=1000, embedding_dimensions=768):
        super().__init__()
        pe = torch.zeros(max_tokens, embedding_dimensions)
        position = torch.arange(0, max_tokens, dtype=torch.float).unsqueeze(1)
        div_term = 1 / (10000 ** (torch.arange(0, embedding_dimensions, 2).float() / embedding_dimensions))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register as a non-trainable buffer so model.to(device) moves it with the model;
        # persistent=False keeps it out of the state_dict (it is recomputed on init)
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # Shape [1, max_tokens, embedding_dimensions]

    def forward(self, x):
        # x has shape [batch_size, seq_len, embedding_dimensions], with seq_len <= max_tokens
        # Slice the encodings to the actual sequence length so shorter inputs
        # (e.g., prompts at generation time) also work; broadcasting handles the batch dimension
        return x + self.pe[:, :x.size(1)]

positional_encoding = PositionalEncoding()

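def scaled_dot_product_attention(query, key, value):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(dim_k)
    # Causal mask: position i may only attend to positions <= i. Without it, the model
    # could "see" the very token it is trained to predict and would not be autoregressive.
    seq_len = scores.size(-1)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device), diagonal=1)
    scores = scores.masked_fill(causal_mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)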
    
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)
        
    def forward(self, hidden_state):
        attn_outputs = scaled_dot_product_attention(self.q(hidden_state), 
                                                    self.k(hidden_state), 
                                                    self.v(hidden_state)) 
        return attn_outputs

class MultiHeadAttention(nn.Module): 
    def __init__(self, embedding_dimensions=512, num_attention_heads=8):
        super().__init__()
        embed_dim = embedding_dimensions
        num_heads = num_attention_heads
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, hidden_state):
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1) 
        x = self.output_linear(x)
        return x
    
multihead_attn = MultiHeadAttention(embedding_dimensions=512, num_attention_heads=4)
multihead_attn = multihead_attn.to(device)

class MLP(nn.Module):
    def __init__(self, embed_dim, mlp_dim):
        super().__init__()
        self.c_fc = nn.Linear(embed_dim, mlp_dim)
        self.c_proj = nn.Linear(mlp_dim, embed_dim)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, mlp_dim, num_attention_heads):
        super().__init__()

        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embedding_dimensions=embed_dim, num_attention_heads=num_attention_heads)
        self.mlp = MLP(embed_dim, mlp_dim)
        
    def forward(self, x):
        # First, the attention + residual connection
        attn_output = self.attn(self.ln_1(x))
        x = x + attn_output
        
        # Then, the feed-forward + residual connection
        ff_output = self.mlp(self.ln_2(x))
        x = x + ff_output
        
        return x


class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dimensions, max_tokens, num_blocks, mlp_dim, num_attention_heads):
        super().__init__()

        # Token embeddings
        self.wte = nn.Embedding(vocab_size, embedding_dimensions)
        
        # Positional encodings
        self.wpe = PositionalEncoding(max_tokens, embedding_dimensions)

        # Transformer blocks
        self.blocks = nn.ModuleList(
            [TransformerBlock(embedding_dimensions, mlp_dim, num_attention_heads) for _ in range(num_blocks)]
        )

        # Final layer normalization
        self.ln_f = nn.LayerNorm(embedding_dimensions)
        
        # Output head for causal language modeling (prediction of the next token)
        self.lm_head = nn.Linear(embedding_dimensions, vocab_size, bias=False)

    def forward(self, input_tokens):
        # input_tokens is of shape [batch_size, max_tokens]

        # Get embeddings
        x = self.wte(input_tokens)
        
        # Add positional encodings
        x = self.wpe(x)

        # Go through each block
        for block in self.blocks:
            x = block(x)

        # Final layer normalization
        x = self.ln_f(x)

        # Get token probabilities
        logits = self.lm_head(x)
        
        return logits
    
model = Transformer(vocab_size=50257, embedding_dimensions=768, max_tokens=1000, num_blocks=4, mlp_dim=3072, num_attention_heads=12)
model.to(device) 


Transformer(
  (wte): Embedding(50257, 768)
  (wpe): PositionalEncoding()
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (heads): ModuleList(
          (0-11): 12 x AttentionHead(
            (q): Linear(in_features=768, out_features=64, bias=True)
            (k): Linear(in_features=768, out_features=64, bias=True)
            (v): Linear(in_features=768, out_features=64, bias=True)
          )
        )
        (output_linear): Linear(in_features=768, out_features=768, bias=True)
      )
      (mlp): MLP(
        (c_fc): Linear(in_features=768, out_features=3072, bias=True)
        (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELU(approximate='none')
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [32]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"The model has {count_parameters(model):,} parameters.")

The model has 105,547,776 parameters.
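To see where these 105,547,776 parameters come from, the count can be reproduced by hand. The sketch below (the function name is hypothetical) mirrors the layers of the model defined above; the positional encoding contributes nothing because its sinusoidal table is computed, not learned.

```python
def transformer_param_breakdown(vocab_size, d, num_blocks, mlp_dim, num_heads):
    """Trainable parameters per component of the Transformer defined above."""
    head_dim = d // num_heads

    def linear(n_in, n_out, bias=True):
        return n_in * n_out + (n_out if bias else 0)

    layer_norm = 2 * d  # weight + bias
    per_block = (
        2 * layer_norm                               # ln_1 and ln_2
        + num_heads * 3 * linear(d, head_dim)        # q, k, v in every head
        + linear(d, d)                               # output_linear
        + linear(d, mlp_dim) + linear(mlp_dim, d)    # c_fc and c_proj
    )
    return {
        "wte": vocab_size * d,                       # token embedding table
        "wpe": 0,                                    # sinusoidal: computed, not learned
        "blocks": num_blocks * per_block,
        "ln_f": layer_norm,
        "lm_head": linear(d, vocab_size, bias=False),
    }


parts = transformer_param_breakdown(50257, 768, num_blocks=4, mlp_dim=3072, num_heads=12)
total = sum(parts.values())
for name, count in parts.items():
    print(f"{name:8s} {count:>12,}  ({count / total:.1%})")
print(f"{'total':8s} {total:>12,}")  # matches count_parameters(model): 105,547,776
```

About 73% of the parameters sit in `wte` and `lm_head`, each holding `vocab_size × embedding_dimensions` weights; the four Transformer blocks account for roughly the remaining 27%. Tying `lm_head` to `wte`, as GPT-2 does, would remove about 38.6 million parameters.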


## The Dataset

Before we dive into the details of the loop that trains a Transformer model from scratch, it's essential to understand the data that will feed our training process. The quality and structure of the dataset are pivotal to the model's performance, which makes this phase a critical starting point.

In this section, we'll explore the following key aspects:

1. **Tokenizer**: To make sense of the text data, we'll introduce a tokenizer, a tool that breaks down text into manageable units. You'll learn how to instantiate and use a tokenizer to process your dataset effectively.

2. **Dataset Loading**: We'll begin by loading the dataset that we'll use for training. Whether you're working with a pre-existing dataset or collecting and preparing your own, this step is where you ensure your data is ready for training.

3. **Data Splitting**: Properly dividing your dataset into training, validation, and test sets is crucial for evaluating your model's performance and preventing overfitting. We'll discuss strategies for achieving an appropriate split.

4. **Custom Dataset Class**: In some cases, your dataset may require custom handling to align with your training objectives. We'll demonstrate how to create a custom dataset class tailored to your specific data format and requirements.

5. **Custom Collate Function**: The collate function determines how data is batched and processed during training. We'll walk you through the creation of a custom collate function to ensure efficient and effective batching of variable-length sequences.

By the end of this section, you'll understand how to prepare and structure your data for training a Transformer model.


### Tokenizer

In natural language processing, text is converted into a format that machine learning models can work with through a process called tokenization. Tokenization breaks text into individual units called tokens, such as words or subwords, and maps each token to a unique integer ID from a fixed vocabulary. This process is essential for training and using language models like the one we are building.

Two settings are worth understanding, along with how they influence the capabilities and efficiency of the language model: vocabulary size, which is fixed by the tokenizer, and embedding size, which is chosen when designing the model.

### Vocabulary Size

The vocabulary size, often referred to as "vocab size," is the number of unique tokens the tokenizer knows. It determines how many words and word pieces can be represented as single tokens; anything else is split into several smaller pieces. Specifically:

- **Importance**: A larger vocabulary represents more words and concepts as single tokens, so the same text is encoded in fewer tokens and more of it fits into the model's context window. This can be especially advantageous for tasks requiring nuanced language understanding and generation.

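- **Trade-off**: A larger vocabulary demands more memory and computational resources. Both the token-embedding table and the output layer hold `vocab_size × embedding_dimensions` weights; in our model these two layers alone account for about 77 million of the 105.5 million parameters. Models with larger vocabularies are therefore slower to train and use, which makes them less suitable for resource-constrained environments.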

- **Possible Options**: Different tokenizers come with different vocabulary sizes. All GPT-2 variants (`gpt2`, `gpt2-medium`, `gpt2-large`, `gpt2-xl`) share the same byte-level BPE tokenizer with 50,257 tokens, so picking a larger GPT-2 variant does not change the vocabulary size; changing it means using a different tokenizer or training your own. The choice of vocabulary size should align with your specific task requirements and available resources.
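The effect of vocabulary size on sequence length can be illustrated with a toy tokenizer. The sketch below is a greedy longest-match subword tokenizer chosen only for illustration: it is *not* GPT-2's byte-level BPE, and both vocabularies are made up. Adding a few whole-word entries shrinks the same sentence from 20 tokens to 6.

```python
def greedy_tokenize(text, vocab):
    """Toy longest-match-first subword tokenizer (NOT GPT-2's byte-level BPE).
    Returns the list of token IDs for `text`."""
    ids, i = [], 0
    while i < len(text):
        # Try the longest piece starting at position i that exists in the vocabulary
        for j in range(len(text), i, -1):
            if text[i:j] in vocab:
                ids.append(vocab[text[i:j]])
                i = j
                break
        else:
            raise ValueError(f"character {text[i]!r} is not in the vocabulary")
    return ids


text = "the cat sat on a mat"

# Made-up vocabulary 1: single characters only
small_vocab = {ch: idx for idx, ch in enumerate(sorted(set(text)))}

# Made-up vocabulary 2: the same characters plus a few whole-word tokens
large_vocab = dict(small_vocab)
for piece in ["the", " cat", " sat", " on", " a", " mat"]:
    large_vocab[piece] = len(large_vocab)

print(len(small_vocab), len(greedy_tokenize(text, small_vocab)))  # 10 20
print(len(large_vocab), len(greedy_tokenize(text, large_vocab)))  # 16 6
```

Shorter token sequences are the upside of a bigger vocabulary; the cost is that every extra vocabulary entry adds a row to the embedding table and a column to the output layer.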

### Embedding Size (Hidden Dimension or Model Dimension)

The embedding size determines the dimensionality of the continuous vector representation (embedding) of each token. It is a property of the model rather than of the tokenizer, and these embeddings serve as the foundation for the model's understanding of text. Key considerations include:

- **Importance**: The embedding size limits how much information each token's representation can carry, and therefore how complex the patterns the model can capture. A larger embedding size can potentially enhance the model's ability to represent intricate language features.

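- **Trade-off**: Similar to vocabulary size, a larger embedding size requires more memory and computational power. In a block like ours, where the MLP width is four times the embedding size, each block's parameter count grows roughly with the square of the embedding size, so training and using models with larger embeddings can be resource-intensive and time-consuming.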

- **Possible Options**: The GPT-2 family illustrates the range: `gpt2` (small) uses 768-dimensional embeddings, `gpt2-medium` 1024, `gpt2-large` 1280, and `gpt2-xl` 1600, all with the same tokenizer. Our model uses 768. Your choice of embedding size should align with the complexity of your task and the available computational resources.
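The quadratic growth can be checked with a quick count. The sketch below counts one block of this notebook's design for the GPT-2 embedding sizes, assuming `mlp_dim = 4 × embedding_dim` (as in our `mlp_dim=3072` for 768) and that the attention heads split the embedding evenly; `block_params` is a hypothetical helper.

```python
def block_params(d, mlp_ratio=4):
    """Trainable parameters in one TransformerBlock of this notebook's design,
    assuming mlp_dim = mlp_ratio * d and num_heads * head_dim == d."""
    mlp_dim = mlp_ratio * d
    layer_norms = 2 * (2 * d)                           # ln_1, ln_2: weight + bias each
    qkv = 3 * (d * d + d)                               # q, k, v across all heads
    out_proj = d * d + d                                # output_linear
    mlp = (d * mlp_dim + mlp_dim) + (mlp_dim * d + d)   # c_fc + c_proj
    return layer_norms + qkv + out_proj + mlp


# Embedding sizes used by the GPT-2 family (small, medium, large, xl)
for d in (768, 1024, 1280, 1600):
    print(f"d={d:5d}  params per block={block_params(d):>12,}  (~12*d^2 = {12 * d * d:>12,})")
```

Going from 768 to 1600 dimensions multiplies the per-block cost by about 4.3, while the embedding table grows only about 2.1×.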

Understanding these two settings is crucial when designing a model for your NLP project. Balancing vocabulary size and embedding size helps keep your model both capable and efficient. In the next sections, we will cover dataset preprocessing and data splitting, setting the stage for robust model training.

Let's instantiate our tokenizer:

1. **Initialization of Tokenizer**:
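   - The first line initializes the tokenizer, which converts text into the token IDs our Transformer model consumes. We use the `gpt2-medium` tokenizer from Hugging Face, which has a vocabulary of 50,257 tokens. The 1024-dimensional embeddings associated with gpt2-medium belong to the pretrained *model*, not to the tokenizer; our model learns its own 768-dimensional embeddings from scratch.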
  

2. **Setting Pad Token**:
   - GPT-2's tokenizer has no dedicated padding token, so this line reuses the `eos_token` (end-of-sequence token, `<|endoftext|>`) as the `pad_token`. This is a common workaround that allows sequences to be padded during training.


3. **Configuring Maximum Sequence Length**:
   - The last line sets the maximum sequence length (`model_max_length`) the tokenizer will produce. When the tokenizer is called with `truncation=True`, as our dataset class does, longer texts are cut to this many tokens. This is a number you choose: it caps how much of each text the model sees, and it should not exceed the `max_tokens` the positional encoding was built for (1000 in our model).


In [33]:
# Initialize tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 1000

### Data Preprocessing: Loading and Cleaning

We begin the data preprocessing phase by loading a dataset and performing some initial cleaning to ensure data quality.

1. **Loading the Dataset**:
   - We use the `load_dataset` function from the `datasets` library to load the `roneneldan/TinyStories` dataset. This dataset will serve as the basis for our training data.

2. **Setting a Seed and Shuffling**:
   - We shuffle the dataset using a fixed seed (`seed = 42`). The seed makes the shuffle reproducible: the same seed always yields the same sample, and changing it draws a different random sample.

3. **Selecting the First 'n_sample' Records**:
   - We keep the first `n_samples` records of the shuffled dataset (100 in this notebook). This limits the dataset to a manageable size while still providing a random, diverse sample.

4. **Identifying Empty Samples**:
   - To ensure data quality, we iterate through the dataset and flag every sample whose text is empty after stripping whitespace. This assumes the text field in the dataset is named `'text'`.

5. **Removing Empty Samples**:
   - If any empty samples are identified, we compute the indices of the samples to keep by excluding the indices of the empty ones. We then select and keep only the samples without empty text entries.

6. **Final Sample Count**:
   - We calculate the final number of samples remaining in the dataset after removing empty ones and print the result. This step ensures that we are aware of the dataset size for further processing and training.

The provided code performs these data preprocessing steps to load and clean the dataset, ensuring that we work with high-quality data in subsequent stages of our project.
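The same shuffle, select, and clean sequence can be sketched on a plain list of records, which makes each step easy to test in isolation. This is an illustration under assumptions: `sample_and_clean` is a hypothetical helper, and Python's `random.Random(seed).shuffle` produces a different order than the `datasets` library's `shuffle(seed=...)`, so the two will not select the same records.

```python
import random


def sample_and_clean(records, n_samples, seed=42, text_key="text"):
    """Hypothetical helper: shuffle reproducibly, keep the first n_samples records,
    then drop records whose text is empty after stripping whitespace."""
    shuffled = list(records)                  # copy so the input list is untouched
    random.Random(seed).shuffle(shuffled)     # same seed -> same order on every run
    sample = shuffled[:n_samples]
    # One pass replaces the notebook's "collect empty indices, then keep the rest"
    return [rec for rec in sample if rec[text_key].strip()]


records = [
    {"text": "Once upon a time, there was a little dog."},
    {"text": "   "},
    {"text": ""},
    {"text": "Lily found a red ball in the park."},
    {"text": "\n"},
]
clean = sample_and_clean(records, n_samples=len(records))
print(f"Remaining samples after dropping empty ones: {len(clean)}")  # 2
```

With the `datasets` library itself, the empty-sample removal can also be written as a single `dataset.filter(lambda ex: ex['text'].strip() != '')` call.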

In [34]:
from datasets import load_dataset

# Load the dataset
n_samples = 100

dataset = load_dataset('roneneldan/TinyStories', split='train')

# Shuffle with a fixed seed so the same random sample is drawn on every run
seed = 42  # Change this value whenever you need a different random sample
shuffled_dataset = dataset.shuffle(seed=seed)

# Select the first n_samples records
dataset = shuffled_dataset.select(range(n_samples))

# Identify empty samples
empty_samples_indices = []
for idx, sample in enumerate(dataset):
    # Assuming the key in the dataset for the text is 'text' (adjust if it's different)
    if not sample['text'].strip():  # Check if the text is empty after stripping whitespace
        empty_samples_indices.append(idx)

# Remove empty samples
if empty_samples_indices:
    # Compute indices of the samples to keep (i.e., all indices minus the empty ones)
    indices_to_keep = [idx for idx in range(len(dataset)) if idx not in empty_samples_indices]
    
    # Select only the samples to keep
    dataset = dataset.select(indices_to_keep)
    
n_samples = len(dataset)

print(f"Remaining samples after dropping empty ones: {len(dataset)}")

# Now, `dataset` contains a random sample of n_samples records from the dataset



Remaining samples after dropping empty ones: 100


### Dataset Splitting: Train, Validation, and Test Sets

In this section, we focus on splitting the dataset into three subsets: the training set, the validation set, and the test set. Proper dataset splitting is essential for model training, evaluation, and tuning.

1. **Splitting Percentages**:
   - We define the percentages for splitting the dataset into the training, validation, and test sets. Here, `train_percent` is set to 0.8 and `val_percent` to 0.1. The test set's share is implicitly 1.0 minus the sum of the two, i.e., 0.1, so the three sets add up to 100%.

2. **Calculating Set Sizes**:
   - We calculate the sizes of the training and validation sets from these percentages: `train_size = int(train_percent * n_samples)` and `val_size = int(val_percent * n_samples)`. All remaining samples go to the test set; with 100 samples, that gives 80, 10, and 10.

3. **Creating Split Datasets**:
   - We use the `select` method to create three separate datasets: `train_dataset`, `val_dataset`, and `test_dataset`. Each dataset is generated by selecting a range of samples based on the previously calculated sizes. The training set includes the first `train_size` samples, the validation set includes the next `val_size` samples, and the test set comprises the remaining samples.

This dataset splitting process ensures that we have distinct subsets for training, validation, and testing, facilitating the model development and evaluation phases of our project.
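The split is simple index arithmetic. The sketch below (a hypothetical `split_ranges` helper) reproduces it and shows one subtlety: because `int()` truncates, any rounding remainder ends up in the test set.

```python
def split_ranges(n_samples, train_percent=0.8, val_percent=0.1):
    """Return (start, stop) index ranges for contiguous train/val/test splits,
    using the same int() truncation as the notebook."""
    train_size = int(train_percent * n_samples)
    val_size = int(val_percent * n_samples)
    return {
        "train": (0, train_size),
        "val": (train_size, train_size + val_size),
        "test": (train_size + val_size, n_samples),  # whatever is left over
    }


print(split_ranges(100))  # {'train': (0, 80), 'val': (80, 90), 'test': (90, 100)}
print(split_ranges(97))
```

With 100 samples the split is 80/10/10; with 97 samples it becomes 77/9/11. The split itself is contiguous, so it is only random because the dataset was shuffled beforehand.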


In [35]:
# Split the dataset into train, val, and test
train_percent = 0.8
val_percent = 0.1

# test_percent is implicitly 0.1 since train + val + test = 1.0

train_size = int(train_percent * n_samples)
val_size = int(val_percent * n_samples)
# Remaining samples are for testing

train_dataset = dataset.select(list(range(train_size)))
val_dataset = dataset.select(list(range(train_size, train_size + val_size)))
test_dataset = dataset.select(list(range(train_size + val_size, n_samples)))

### Custom Dataset Class for Text Data

In this section, we define a custom dataset class named `CustomDataset` designed for handling text data. This class is crucial for preparing the data in a format that the Transformer model can consume during training.

1. **Class Definition**:
   - We define a new Python class named `CustomDataset` that inherits from the `Dataset` class, which is typically used in PyTorch for managing and loading datasets.

2. **Initialization**:
   - The `__init__` method initializes the custom dataset. It takes two arguments: `texts` and `tokenizer`.
   - `texts`: This argument represents a list of text samples, such as sentences or documents, that you want to use for training.
   - `tokenizer`: This argument is an instance of a tokenizer (e.g., the GPT-2 tokenizer) that will be used to process the text data.

3. **Length of the Dataset**:
   - The `__len__` method is implemented to return the total number of text samples in the dataset. This value is used to determine the dataset's length.

4. **Getting an Item**:
   - The `__getitem__` method defines how a specific item from the dataset is retrieved. It takes an index (`idx`) as an argument, indicating which text sample to fetch.
   - Inside this method, the text associated with the given index is extracted from the `texts` list.
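   - The `tokenizer` then converts the text into token IDs. Because `truncation=True`, texts longer than the tokenizer's `model_max_length` (1000) are cut; because `padding=False`, no padding is added here, since padding happens later in the collate function.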
   - The method returns a dictionary containing two key-value pairs:
     - `'input_ids'`: This key holds the tokenized input IDs of the text, which are the indices of tokens in the model's vocabulary.
     - `'attention_mask'`: This key holds the attention mask, which indicates which tokens should be attended to (1) and which should be ignored (0).

This custom dataset class allows you to easily convert a list of text samples into a format suitable for training Transformer models. You can use instances of this class with PyTorch's data loaders to efficiently feed data into your model during training.
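`CustomDataset` follows PyTorch's map-style dataset contract: `__len__` returns the number of samples and `__getitem__(idx)` returns one sample. The dependency-free sketch below shows that contract with a hypothetical `ToyTokenizer` that assigns one ID per whitespace-separated word and returns plain lists instead of tensors; it is a stand-in, not the GPT-2 tokenizer.

```python
class ToyTokenizer:
    """Hypothetical stand-in for the GPT-2 tokenizer: one ID per whitespace-separated
    word, truncated to model_max_length. Returns plain lists instead of tensors."""

    def __init__(self, model_max_length=1000):
        self.model_max_length = model_max_length
        self.vocab = {}

    def __call__(self, text, truncation=True):
        ids = [self.vocab.setdefault(w, len(self.vocab)) for w in text.split()]
        if truncation:
            ids = ids[: self.model_max_length]
        return {"input_ids": ids, "attention_mask": [1] * len(ids)}


class ListDataset:
    """Same map-style contract as CustomDataset: __len__ plus __getitem__(idx)."""

    def __init__(self, texts, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokens = self.tokenizer(self.texts[idx], truncation=True)
        return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}


ds = ListDataset(["the cat sat", "the dog ran far away"], ToyTokenizer(model_max_length=4))
print(len(ds))  # 2
print(ds[0])    # {'input_ids': [0, 1, 2], 'attention_mask': [1, 1, 1]}
print(ds[1])    # {'input_ids': [0, 3, 4, 5], 'attention_mask': [1, 1, 1, 1]}
```

Each item keeps its own length, and its attention mask is all ones because nothing has been padded yet. Padding is deferred to the collate function, where the lengths of the other samples in the batch are known.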


In [36]:
class CustomDataset(Dataset):
    def __init__(self, texts, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
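        tokens = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
        # squeeze(0) drops only the batch dimension; a bare squeeze() would turn a
        # one-token text into a 0-d tensor, which pad_sequence cannot handle
        return {
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0)
        }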

### Custom Collate Function for Batch Processing

In this section, we define a custom collate function named `custom_collate_fn`. This function is essential for processing batches of data efficiently during training, especially when working with variable-length sequences.

1. **Function Definition**:
   - The `custom_collate_fn` function is defined to process a batch of data, typically containing multiple text samples.

2. **Input Parameters**:
   - The function takes two parameters:
     - `batch`: A batch of data, where each element in the batch is a dictionary containing 'input_ids' and 'attention_mask' keys.
     - `max_length` (optional, default value is set to 1000): An integer that specifies the maximum sequence length to which sequences in the batch should be padded.

3. **Data Extraction**:
   - Inside the function, 'input_ids' and 'attention_mask' values are extracted from each item in the batch, creating separate lists of input IDs and attention masks.

4. **Padding Sequences**:
   - The extracted sequences in `input_ids` and `attention_masks` are padded to ensure they have the same length within the batch. Padding is performed using the `pad_sequence` function from PyTorch with `batch_first=True`. The padding value is set to 0.
   - This padding ensures that all sequences in the batch have the same length for efficient batch processing.

5. **Further Padding to Max Length**:
   - After initial padding, the function checks if the sequences are shorter than the specified `max_length`. If they are, further padding is applied to extend the sequences to the specified maximum length.
   - Padding is added to the right side of the sequences using `F.pad` from PyTorch, ensuring that all sequences reach the same length.

6. **Batch Output**:
   - Finally, the function returns a dictionary containing two keys:
     - `'input_ids'`: This key holds the batch of padded input IDs.
     - `'attention_mask'`: This key holds the batch of padded attention masks.

This custom collate function is designed to handle variable-length sequences efficiently, ensuring that all sequences in a batch have the same length. It plays a crucial role in preparing data for model training and is typically used in conjunction with data loaders in PyTorch.
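The two padding stages amount to a single target length, `max(longest_in_batch, max_length)`. The pure-Python sketch below (a hypothetical `collate_lists` helper, using lists instead of tensors) produces the same layout as `custom_collate_fn`: right-padding with 0 in both `input_ids` and `attention_mask`.

```python
def collate_lists(batch, max_length=1000, padding_value=0):
    """Pure-Python analogue of custom_collate_fn: pad every sequence to the longest one
    in the batch, or to max_length if that is longer. Never truncates."""
    longest = max(len(item["input_ids"]) for item in batch)
    target = max(longest, max_length)
    input_ids, attention_mask = [], []
    for item in batch:
        pad = target - len(item["input_ids"])
        input_ids.append(item["input_ids"] + [padding_value] * pad)
        attention_mask.append(item["attention_mask"] + [0] * pad)  # 0 = ignore this position
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]},
    {"input_ids": [8], "attention_mask": [1]},
]
out = collate_lists(batch, max_length=5)
print(out["input_ids"])       # [[5, 6, 7, 0, 0], [8, 0, 0, 0, 0]]
print(out["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 0, 0, 0, 0]]
```

Two design notes: token ID 0 is a real GPT-2 token (`!`), so padded positions can only be told apart from text through the attention mask; and padding every batch to 1000 tokens means even short stories cost a full 1000-token forward pass. Padding only to the longest sequence in the batch would be cheaper, provided the model accepts variable sequence lengths.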


In [37]:
def custom_collate_fn(batch, max_length=1000):
    input_ids = [item['input_ids'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]

    # Pad sequences in the batch
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks_padded = pad_sequence(attention_masks, batch_first=True, padding_value=0)

    # Further pad all sequences to a fixed length of max_length
    if input_ids_padded.size(1) < max_length:
        padding_needed = max_length - input_ids_padded.size(1)
        input_ids_padded = F.pad(input_ids_padded, (0, padding_needed), 'constant', 0)
        attention_masks_padded = F.pad(attention_masks_padded, (0, padding_needed), 'constant', 0)

    return {
        'input_ids': input_ids_padded,
        'attention_mask': attention_masks_padded
    }

### Data Loaders: Preparing Batches for Training and Testing

In this section, we set up data loaders to prepare batches of data for training and validation. Data loaders are crucial for efficient batch processing during model training. We also perform a quick test to ensure that the data loaders are functioning as expected.

1. **Batch Size**:
   - We define the batch size as `batch_size = 32`. This parameter determines how many samples are processed together in each iteration during training and validation.

2. **Dataset Splitting**:
   - We extract the raw text of each sample into plain Python lists: `texts_train` for the training set and `texts_val` for the validation set.

3. **Converting to CustomDataset**:
   - We create two instances of the `CustomDataset` class (`train_custom_dataset` and `val_custom_dataset`) to convert the text data into a format suitable for training. Each instance takes a list of text samples and a tokenizer as input.

4. **Data Loaders Setup**:
   - We set up two data loaders (`train_loader` and `val_loader`) using PyTorch's `DataLoader` class. These data loaders enable efficient batch processing for training and validation.
   - For the training loader (`train_loader`):
     - We set `shuffle=True` so the samples are reshuffled at the start of every epoch, and the batches differ from epoch to epoch.
     - We set the batch size to `batch_size` for grouping text samples into batches.
     - We use the `custom_collate_fn` function as the collate function to preprocess and pad the data within each batch.
   - For the validation loader (`val_loader`):
     - We set `shuffle=False` to keep the validation data in its original order.
     - The batch size and collate function settings are the same as those for the training loader.

5. **Testing Data Loaders**:
   - To verify that the data loaders are working as expected, we iterate through the training loader (`train_loader`) in a for loop and print the shapes of the input IDs and attention masks for the first batch.
   - This testing step helps ensure that the data loading process is correctly configured and that batches are generated with the specified batch size and padding.

These data loaders are ready for use in training and validation, providing batches of text data that can be fed into the Transformer model for learning and evaluation.
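It helps to know how many batches each loader produces. With `drop_last=False` (the `DataLoader` default), the last batch holds whatever is left over, so `len(train_loader)` is `ceil(80 / 32) = 3` with batch sizes 32, 32, and 16, and the validation loader yields a single batch of 10. The sketch below reproduces this arithmetic and the effect of shuffling; the helper names are hypothetical, and `DataLoader` uses PyTorch's own random generator rather than Python's `random`.

```python
import math
import random


def batch_sizes(n_samples, batch_size):
    """Batch sizes a DataLoader yields with drop_last=False (the default)."""
    full, rest = divmod(n_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


def epoch_batches(n_samples, batch_size, shuffle, seed=0):
    """Sample indices per batch for one epoch; shuffle reorders samples before batching."""
    order = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(order)
    return [order[i:i + batch_size] for i in range(0, n_samples, batch_size)]


print(batch_sizes(80, 32), math.ceil(80 / 32))  # [32, 32, 16] 3  (train_loader)
print(batch_sizes(10, 32), math.ceil(10 / 32))  # [10] 1  (val_loader)

# Different epochs (seeds) give different batches, but every sample appears exactly once
epoch_1 = epoch_batches(80, 32, shuffle=True, seed=1)
epoch_2 = epoch_batches(80, 32, shuffle=True, seed=2)
print(sorted(i for batch in epoch_1 for i in batch) == list(range(80)))  # True
```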


In [38]:
batch_size = 32
        
# Extract the raw text from each split
texts_train = [ex['text'] for ex in train_dataset]
texts_val = [ex['text'] for ex in val_dataset]

# Convert HuggingFace datasets to CustomDataset
train_custom_dataset = CustomDataset(texts_train, tokenizer)
val_custom_dataset = CustomDataset(texts_val, tokenizer)


train_loader = DataLoader(train_custom_dataset, shuffle=True, batch_size=batch_size, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_custom_dataset, shuffle=False, batch_size=batch_size, collate_fn=custom_collate_fn)


# Testing
for data in train_loader:
    print("input_ids shape:", data['input_ids'].shape)
    print("attention_mask shape:", data['attention_mask'].shape)
    break


input_ids shape: torch.Size([32, 1000])
attention_mask shape: torch.Size([32, 1000])


## The Training Loop

### Hyperparameters and Training Monitoring

In this section, we set various hyperparameters for model training and define parameters for monitoring the training progress. Properly tuning these hyperparameters and monitoring the training process are essential for achieving good model performance.

1. **Number of Epochs (`num_epochs`)**:
   - `num_epochs` is set to `1`, indicating that we plan to train the model for one epoch, which means going through the entire training dataset once.

2. **Learning Rate (`learning_rate`)**:
   - `learning_rate` is set to `5e-4`, specifying the rate at which the model updates its parameters during training. It's a crucial hyperparameter that affects the speed and quality of training.

3. **Gradient Clipping (`grad_clip`)**:
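   - `grad_clip` is set to `1.0`. Gradient clipping prevents gradients from growing so large that a single update destabilizes training. With `clip_grad_norm_`, as used in the training loop, this value is an upper bound on the global L2 norm of all gradients combined: if that norm exceeds `1.0`, every gradient is scaled down by the same factor. Individual gradient values are not clipped separately.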

4. **Warmup Steps (`warmup_steps`)**:
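   - We set `warmup_steps` to the number of batches in one epoch, so the learning rate warms up over the first epoch. Gradually increasing the learning rate at the beginning of training is a common practice that helps stabilize the early updates. Note that because `num_epochs` is also `1`, `warmup_steps` equals `total_steps`: the learning rate is still ramping up when training ends, and the decay phase of the schedule never starts.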

5. **Total Steps (`total_steps`)**:
   - `total_steps` represents the total number of training steps that will be taken over the entire training process. It's calculated as the product of the number of training batches per epoch (`len(train_loader)`) and the number of epochs (`num_epochs`).

6. **Reporting Steps (`report_steps`)**:
   - `report_steps` is set to `5`. This parameter determines how often training progress is reported. In this case, progress will be reported every 5 steps.

7. **Checkpoint Steps (`checkpoint_steps`)**:
   - `checkpoint_steps` is set to `50`. This parameter determines how often model checkpoints (model snapshots) are saved during training. Setting it to `0` means no checkpointing based on the number of steps.

8. **Step Counter (`step`)**:
   - `step` is initialized to `0` and serves as a counter to keep track of the number of training steps completed.

9. **Patience for Early Stopping (`patience`)**:
   - `patience` is set to `5` and is used for early stopping: if the validation loss does not improve for 5 consecutive epochs, training stops early. With `num_epochs = 1`, this condition can never trigger; it only comes into play for longer runs.

10. **Epochs with No Improvement (`epochs_no_improve`)**:
    - `epochs_no_improve` is initialized to `0` and is used to track the number of consecutive epochs where the validation loss does not improve.

11. **Best Validation Loss (`best_val_loss`)**:
    - `best_val_loss` is initialized to positive infinity (`float('inf')`). It is used to store the best validation loss observed during training, which is crucial for early stopping.

These hyperparameters and monitoring parameters are essential for configuring and managing the training process effectively.
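To make the `grad_clip` behavior concrete, here is a pure-Python sketch of clipping by global norm, the rule `torch.nn.utils.clip_grad_norm_` follows: compute one L2 norm over all gradients and, if it exceeds `max_norm`, scale everything by `max_norm / norm`. The helper `clip_by_global_norm` is hypothetical and works on plain lists rather than tensors; the small `eps` mirrors the one PyTorch adds to the denominator.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Hypothetical helper mirroring the idea behind torch.nn.utils.clip_grad_norm_.

    grads: list of flattened gradients, one list per parameter tensor.
    Returns (clipped_grads, total_norm_before_clipping).
    """
    # One L2 norm across *all* parameters, not one per parameter
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Scale factor is capped at 1.0, so gradients that are already small are untouched
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, total_norm

grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm = 5.0
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # approximately [[0.6, 0.0], [0.0, 0.8]]
```

Because every gradient is multiplied by the same factor, the update direction is preserved; only its length is capped.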


In [39]:
# Hyperparameters
num_epochs = 1
learning_rate = 5e-4
grad_clip = 1.0
warmup_steps = len(train_loader) * 1  # Assuming warmup over 1 epoch
total_steps = len(train_loader) * num_epochs

# Define reporting steps
report_steps = 5
checkpoint_steps = 50 # Save a checkpoint every x steps. Set to 0 to disable step-based checkpointing.
step = 0  # Counter for the number of steps

patience = 5
epochs_no_improve = 0
best_val_loss = float('inf')
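The last three variables implement a simple patience counter. The sketch below replays that logic on a made-up list of validation losses (the numbers are illustrative, not results from this notebook), using a hypothetical helper `early_stopping_epoch` that returns the 1-based epoch at which training would stop.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the (1-based) epoch at which training would stop, or None if it never stops."""
    best_val_loss = float('inf')
    epochs_no_improve = 0
    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss < best_val_loss:
            best_val_loss = val_loss  # new best: reset the counter
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1    # no improvement this epoch
        if epochs_no_improve >= patience:
            return epoch
    return None

# Hypothetical losses: the best value is reached at epoch 3
losses = [5.0, 4.2, 3.9, 4.0, 3.95, 4.1]
print(early_stopping_epoch(losses, patience=3))  # 6
print(early_stopping_epoch(losses, patience=5))  # None
```

Only a strict improvement over the best loss so far resets the counter, so a loss that dips slightly but stays above the best (3.95 here) still counts as "no improvement".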

### Optimizer and Loss Function Definition

In this section, we define the optimizer and the loss function, both of which are crucial components of the training process.

1. **Optimizer Selection (`optimizer`)**:
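   - We choose the AdamW optimizer. AdamW is a variant of Adam that applies decoupled weight decay: instead of adding an L2 penalty to the gradient (which Adam's adaptive scaling would then distort), it shrinks the weights directly at each step. Weight decay acts as a regularizer that helps reduce overfitting. Since we don't pass `weight_decay`, PyTorch's default of `0.01` is used.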
   - The optimizer operates on the model's parameters, which are accessed using `model.parameters()`.
   - We specify the learning rate (`lr`) for the optimizer, which determines the step size for updating model parameters during training.

2. **Loss Function**:
   - The loss function is a critical component in training neural networks. It quantifies the difference between the model's predictions and the ground truth.
   - For a language model, next-token prediction is a classification problem over the vocabulary, so we use cross-entropy loss (`torch.nn.CrossEntropyLoss`), defined in the cell below. It expects raw logits (it applies log-softmax internally) and integer token IDs as targets. Other tasks call for other losses, e.g., mean squared error (MSE) for regression.

The optimizer and loss function together form the core of the training process. The optimizer determines how the model's parameters are updated during backpropagation, while the loss function guides the training process by quantifying the model's performance.
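For intuition about what `CrossEntropyLoss` computes for next-token prediction, the pure-Python sketch below takes a log-softmax over a made-up 4-token vocabulary, scores the negative log-probability of the true next token, and averages over positions. Perplexity, which the training loop reports later, is the exponential of that average. The function names are hypothetical.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target token at one position."""
    # Subtracting the max logit keeps exp() stable and does not change the softmax
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(all_logits, targets):
    """Average over positions, like CrossEntropyLoss with its default reduction='mean'."""
    losses = [cross_entropy(l, t) for l, t in zip(all_logits, targets)]
    return sum(losses) / len(losses)

# Uniform logits over 4 tokens: the model has no idea which token comes next
uniform = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
loss = mean_cross_entropy(uniform, targets=[2, 1])
print(loss)            # log(4), about 1.386
print(math.exp(loss))  # perplexity, about 4.0
```

A perplexity of 4 means the model is as uncertain as if it picked uniformly among 4 tokens; a perfectly confident, correct model approaches a loss of 0 and a perplexity of 1.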

In [40]:
# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()
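To see what "decoupled" weight decay means in practice, here is a scalar sketch of one AdamW update following the algorithm described in the PyTorch `AdamW` documentation (without `amsgrad`): the weight is first shrunk by `lr * weight_decay`, then the usual bias-corrected Adam step is applied. The function `adamw_step` is hypothetical and operates on a single float, not on tensors; the defaults match this notebook's `learning_rate` and PyTorch's default `betas`, `eps`, and `weight_decay`.

```python
import math

def adamw_step(p, g, m, v, t, lr=5e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter p with gradient g.

    m, v: running first/second moment estimates; t: 1-based step count.
    Returns the updated (p, m, v).
    """
    beta1, beta2 = betas
    # Decoupled weight decay: applied to the weight itself, not added to the gradient
    p = p - lr * weight_decay * p
    # Standard Adam moment estimates with bias correction
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v

# With a zero gradient, only the weight decay term moves the parameter
p, m, v = adamw_step(p=1.0, g=0.0, m=0.0, v=0.0, t=1)
print(p)  # approximately 0.999995 (= 1 - 5e-4 * 0.01)
```

If the decay were instead folded into the gradient (plain Adam with L2), it would be divided by `sqrt(v_hat)` and could end up much larger or smaller than intended; AdamW keeps the two effects separate.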

### Learning Rate Schedule

In this section, we define a learning rate schedule using a custom function and associate it with the optimizer. A learning rate schedule is a technique used to adjust the learning rate during training to control the convergence of the model.

1. **Learning Rate Schedule Function (`lr_schedule`)**:
   - We define a custom function named `lr_schedule` that takes a single argument `step`. This function calculates the learning rate for a given training step.
   - The schedule is designed to include a warm-up phase followed by a linear decay phase. The learning rate is gradually increased during the warm-up phase and then linearly decreased during the decay phase.
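   - Specifically, `lr_schedule` returns a multiplier that `LambdaLR` applies to the base `learning_rate`. While `step` is less than `warmup_steps`, the multiplier increases linearly from 0 to 1, so the learning rate ramps up from 0 to `learning_rate`.
   - After the warm-up phase, the multiplier decreases linearly from 1 to 0 over the remaining `total_steps - warmup_steps` steps.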

2. **Scheduler Definition (`scheduler`)**:
   - We create a scheduler using the `LambdaLR` class, which allows us to define a custom learning rate schedule function (`lr_schedule`) for the optimizer.
   - The `optimizer` parameter specifies the optimizer to which the learning rate schedule will be applied.
   - The `lr_lambda` parameter is set to `lr_schedule`, indicating that the custom function defines the learning rate schedule.
   
This learning rate schedule helps control the pace of learning during training, typically starting with a small learning rate to stabilize training and gradually decreasing it to fine-tune the model's parameters as training progresses. It's an important technique for achieving better convergence and model performance.
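The sketch below re-evaluates the same `lr_schedule` formula with small, made-up step counts (`warmup_steps=4`, `total_steps=10`) so that both phases are visible. The wrapper `make_lr_schedule` is hypothetical; it only passes the step counts in explicitly instead of reading them from globals. The last lines show the situation in this notebook, where `warmup_steps == total_steps` (here with, say, 3 batches per epoch), so the multiplier never reaches 1 and never decays.

```python
def make_lr_schedule(warmup_steps, total_steps):
    """Same warmup + linear-decay formula as lr_schedule, with the step counts passed in."""
    def lr_schedule(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        return max(0.0, float(total_steps - step) / float(max(1, total_steps - warmup_steps)))
    return lr_schedule

base_lr = 5e-4
schedule = make_lr_schedule(warmup_steps=4, total_steps=10)
multipliers = [schedule(s) for s in range(11)]
print([round(x, 3) for x in multipliers])
# [0.0, 0.25, 0.5, 0.75, 1.0, 0.833, 0.667, 0.5, 0.333, 0.167, 0.0]
print(base_lr * schedule(2))  # 0.00025, the learning rate used at step 2

# Warmup over the whole (single) epoch: training ends while still warming up
one_epoch = make_lr_schedule(warmup_steps=3, total_steps=3)
print([one_epoch(s) for s in range(3)])  # roughly [0.0, 0.333, 0.667]
```

Also note that `LambdaLR` evaluates the schedule at step 0 when it is created, so the very first optimizer update uses a multiplier of 0.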


In [41]:
from torch.optim.lr_scheduler import LambdaLR

# Linear warmup followed by linear decay; returns a multiplier applied to learning_rate
def lr_schedule(step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    return max(0.0, float(total_steps - step) / float(max(1, total_steps - warmup_steps)))

scheduler = LambdaLR(optimizer, lr_lambda=lr_schedule)

### Checkpoint Naming

In this section, we define and format checkpoint names for saving and loading model checkpoints during training and evaluation. Checkpoints are snapshots of the model's parameters that can be saved and restored to continue training or perform inference.

1. **Checkpoint Name (`checkpoint_name`)**:
   - `checkpoint_name` is a descriptive name for the checkpoint, such as "autoregressive-105m." It typically includes information about the model architecture, dataset size, or any other relevant details.
   
2. **Checkpoint Epoch (`checkpoint_epoch`)**:
   - `checkpoint_epoch` is a prefix built by appending "_checkpoint_" to `checkpoint_name`; the training loop appends the epoch index and a `.pt` extension to it. Because the loop uses the 0-based `epoch` variable, "autoregressive-105m_checkpoint_0.pt" is the checkpoint saved after the first epoch, and "autoregressive-105m_checkpoint_2.pt" the one saved after the third.
   
3. **Final Checkpoint (`final_checkpoint`)**:
   - `final_checkpoint` is a string constructed by appending "_final.pth" to `checkpoint_name`. It names the file that holds the model weights saved after training is complete. In this notebook that is the most recent state of the model, not necessarily the one with the best validation loss.

These checkpoint names help keep track of different model snapshots during training and are valuable for resuming training, performing model evaluation, or deploying trained models in production.
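A quick way to check which end-of-epoch files the training loop writes is to rebuild the same f-strings for a hypothetical 3-epoch run (`example_num_epochs` is made up for illustration and deliberately does not overwrite `num_epochs`). Because the loop uses the 0-based `epoch` index, the first file ends in `_0.pt`.

```python
checkpoint_name = "autoregressive-105m"
checkpoint_epoch = f"{checkpoint_name}_checkpoint_"
final_checkpoint = f"{checkpoint_name}_final.pth"

example_num_epochs = 3  # hypothetical, for illustration only
epoch_files = [f"{checkpoint_epoch}{epoch}.pt" for epoch in range(example_num_epochs)]
print(epoch_files)
# ['autoregressive-105m_checkpoint_0.pt', 'autoregressive-105m_checkpoint_1.pt', 'autoregressive-105m_checkpoint_2.pt']
print(final_checkpoint)  # autoregressive-105m_final.pth
```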


In [42]:
checkpoint_name = "autoregressive-105m"
checkpoint_epoch = f"{checkpoint_name}_checkpoint_"
final_checkpoint = f"{checkpoint_name}_final.pth"

### Training Loop

In this section, we define the training loop for training a Transformer model. The training loop consists of several steps, including forward and backward passes, loss calculation, and checkpoint saving.

1. **Epoch Loop**:
   - We iterate over the specified number of epochs (`num_epochs`) to train the model.

2. **Model Training Mode**:
   - Before starting the epoch, we switch the model to training mode with `model.train()`. This enables training-time behavior of layers such as dropout; the parameters themselves are updated by the optimizer, not by this call.

3. **Training Progress Bar**:
   - We create a progress bar (`train_progress_bar`) to monitor the progress of training within the current epoch. This progress bar provides updates on the training loss and displays progress information.

4. **Batch Iteration**:
   - We iterate over batches of data from the training loader (`train_loader`). Each batch contains input IDs and attention masks for a batch of text samples.

5. **Forward Pass**:
   - For each batch, we perform a forward pass through the model to obtain logits (raw predictions) for the next tokens in the sequence.

6. **Loss Calculation**:
   - We calculate the loss by comparing the predicted logits with the shifted labels (ground truth). The shift is applied to predict the next token in the sequence.

7. **Backward Pass and Optimization**:
   - We perform a backward pass to compute gradients and then update model parameters using the optimizer (`optimizer`).
   - Gradient clipping is applied to prevent excessively large gradients that can disrupt training stability.

8. **Learning Rate Scheduler**:
   - The learning rate scheduler (`scheduler`) is used to adjust the learning rate during training, following a specified schedule.

9. **Checkpoint Saving**:
   - Optionally, we save model checkpoints at regular intervals (determined by `checkpoint_steps`). These checkpoints include the model's state, optimizer state, loss, and current step.

10. **Progress Reporting**:
    - We update the progress bar with the current loss and also print training progress every `report_steps` steps.

11. **Epoch-wise Metrics**:
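    - After completing an epoch, we calculate and report the average training loss and the training perplexity, i.e., the exponential of the average cross-entropy loss. Lower perplexity means the model is less uncertain about the next token.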

12. **Validation Loop**:
    - After training for an epoch, we enter the validation phase.
    - We set the model to evaluation mode (`model.eval()`), which disables dropout, and wrap the loop in `torch.no_grad()` so that no gradients are computed.
    - We iterate over batches from the validation loader (`val_loader`) and calculate the validation loss.
    - Validation loss and perplexity are reported.

13. **Model Checkpointing (End of Epoch)**:
    - After the validation loop, we save a checkpoint for the current epoch, even if it's not a scheduled checkpoint. This represents the full state of the model at the end of the epoch.

14. **Early Stopping**:
    - We track the validation loss and apply early stopping if the validation loss does not improve for a certain number of epochs (`patience`). This helps prevent overfitting.

15. **Break on Early Stopping**:
    - If early stopping criteria are met, we print a message indicating early stopping and break out of the training loop.

This training loop is a fundamental part of training Transformer models and includes key steps for training, monitoring, and saving model checkpoints.
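The shift used in the loss is easiest to see on a toy sequence. In the sketch below, plain Python lists stand in for tensors, and the token IDs and the pad ID `0` are made up. The prediction made at position `t` is scored against the token at `t + 1`. As one way to keep padding out of the loss, targets whose attention-mask entry is 0 are replaced with `-100`, the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so they would be skipped. The helper `next_token_pairs` is hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def next_token_pairs(input_ids, attention_mask):
    """Pair each input position with the token it should predict (plain-list sketch)."""
    # Mirrors logits[..., :-1, :]: predictions come from positions 0 .. n-2
    inputs = input_ids[:-1]
    # Mirrors input_ids[..., 1:]: targets are the tokens at positions 1 .. n-1
    targets = input_ids[1:]
    target_mask = attention_mask[1:]
    # Padded targets become IGNORE_INDEX so a loss with ignore_index=-100 skips them
    targets = [t if m == 1 else IGNORE_INDEX for t, m in zip(targets, target_mask)]
    return list(zip(inputs, targets))

# Hypothetical sequence: 4 real tokens followed by 2 padding tokens (pad id 0)
toy_ids = [15, 27, 8, 42, 0, 0]
toy_mask = [1, 1, 1, 1, 0, 0]
for inp, tgt in next_token_pairs(toy_ids, toy_mask):
    print(inp, "->", tgt)
# 15 -> 27
# 27 -> 8
# 8 -> 42
# 42 -> -100
# 0 -> -100
```

A sequence of `n` tokens yields `n - 1` training pairs, and only the pairs with a real (non-padding) target would contribute to the loss.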


In [43]:
# Training loop
import math

from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

for epoch in range(num_epochs):

    model.train()  # Enable training-time behavior such as dropout
    total_loss = 0

    train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", position=0, leave=True)

    for batch in train_progress_bar:
        # Move batch tensors to the same device as the model.
        # custom_collate_fn already returns padded tensors, so no extra padding is needed here.
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Forward pass
        logits = model(input_ids)

        # Language modeling: the logits at position t are scored against the token at t+1.
        # Drop the last logit (it has no next token) and the first label (no position predicts it),
        # i.e., the labels are the inputs shifted one position to the left.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        # Exclude padded targets from the loss (assumes attention_mask is 1 for real tokens, 0 for padding).
        # CrossEntropyLoss ignores targets equal to -100 by default (ignore_index=-100).
        shift_labels = shift_labels.masked_fill(attention_mask[..., 1:] == 0, -100)

        # Calculate loss
        loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        total_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), grad_clip)  # Gradient clipping
        optimizer.step()
        scheduler.step()

        # Increment step
        step += 1

        # Save checkpoint every checkpoint_steps steps
        # (the step number is part of the file name so it doesn't overwrite the end-of-epoch checkpoint)
        if (checkpoint_steps > 0) and (step % checkpoint_steps == 0):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),  # store a plain float, not the graph-attached tensor
                'step': step  # saving the step can also be useful
            }, f"{checkpoint_epoch}{epoch}_step{step}.pt")

        # Update tqdm description with current loss
        train_progress_bar.set_description(f"Epoch {epoch+1}/{num_epochs} Loss: {loss.item():.4f}")

        # Report every 'report_steps'
        if step % report_steps == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Step {step}, Training Loss: {loss.item():.4f}")

    avg_train_loss = total_loss / len(train_loader)
    train_perplexity = math.exp(avg_train_loss)

    # Evaluation loop on validation set
    model.eval()  # Disable dropout for evaluation
    total_eval_loss = 0
    num_eval_batches = 0  # Batches actually evaluated (some may be skipped below)

    with torch.no_grad():  # No gradients are needed for validation
        val_progress_bar = tqdm(val_loader, desc=f"Validating Epoch {epoch+1}/{num_epochs}", position=0, leave=True)

        for batch in val_progress_bar:
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                logits = model(input_ids)
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                shift_labels = shift_labels.masked_fill(attention_mask[..., 1:] == 0, -100)

                loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                total_eval_loss += loss.item()
                num_eval_batches += 1

                val_progress_bar.set_description(f"Validation Loss: {loss.item():.4f}")
            except RuntimeError as e:
                print(f"Skipped batch due to error: {e}")

    # Average only over the batches that were actually evaluated
    avg_val_loss = total_eval_loss / max(1, num_eval_batches)
    val_perplexity = math.exp(avg_val_loss)

    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
    print(f"Training Perplexity: {train_perplexity:.4f}, Validation Perplexity: {val_perplexity:.4f}")

    # After each epoch
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_val_loss,  # average validation loss for this epoch
        'step': 0  # This means that the checkpoint is for the full epoch
    }, f"{checkpoint_epoch}{epoch}.pt")

    # Early stopping: count consecutive epochs without improvement in validation loss
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping at epoch {epoch+1}. Best validation loss: {best_val_loss:.4f}")
        break

Epoch 1/1:   0%|          | 0/3 [00:00<?, ?it/s]

### Saving the Final Model Checkpoint

In this section, we save the final model checkpoint after completing the training process. A model checkpoint is a snapshot of the model's parameters that can be used for further training, evaluation, or deployment.

1. **Checkpoint Save Command (`torch.save()`)**:
   - We use the `torch.save()` function to serialize the model's state dictionary (`model.state_dict()`) and write it to the path stored in the `final_checkpoint` variable (`autoregressive-105m_final.pth`).
   - The `model.state_dict()` contains all the learnable parameters of the model, such as weights and biases.

2. **File Naming (`final_checkpoint`)**:
   - The filename comes from the `final_checkpoint` variable defined earlier; the `_final` suffix indicates that the file holds the state of the model after training. Unlike the checkpoints saved inside the training loop, it contains only the model weights, not the optimizer state.

By saving the final model checkpoint, we can load this checkpoint later to perform inference, fine-tuning, or further evaluation without the need to retrain the model from scratch.
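Loading works the other way around: instantiate a model with the same architecture and call `load_state_dict` on the saved dictionary. The sketch below uses a tiny `torch.nn.Linear` (named `tiny_model`, so it does not clobber the notebook's `model`) as a stand-in for the Transformer, writes to a temporary directory with hypothetical file names, and also shows how one of the dictionary-style checkpoints from the training loop could be used to resume training.

```python
import os
import tempfile

import torch

# Stand-in model and optimizer; in the notebook you would rebuild the same Transformer class
tiny_model = torch.nn.Linear(4, 2)
tiny_optimizer = torch.optim.AdamW(tiny_model.parameters(), lr=5e-4)

with tempfile.TemporaryDirectory() as tmp:
    # 1) Weights only, like torch.save(model.state_dict(), final_checkpoint)
    final_path = os.path.join(tmp, "tiny_final.pth")
    torch.save(tiny_model.state_dict(), final_path)

    restored = torch.nn.Linear(4, 2)  # must match the saved architecture
    restored.load_state_dict(torch.load(final_path))
    restored.eval()  # evaluation mode for inference (disables dropout, if any)

    # 2) Full training checkpoint, shaped like the dicts saved inside the training loop
    ckpt_path = os.path.join(tmp, "tiny_checkpoint_0.pt")
    torch.save({'epoch': 0,
                'model_state_dict': tiny_model.state_dict(),
                'optimizer_state_dict': tiny_optimizer.state_dict(),
                'loss': 1.23,  # hypothetical value
                'step': 0}, ckpt_path)

    checkpoint = torch.load(ckpt_path)
    tiny_model.load_state_dict(checkpoint['model_state_dict'])
    tiny_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # resume with the next epoch

    same = all(torch.equal(a, b) for a, b in zip(tiny_model.state_dict().values(),
                                                  restored.state_dict().values()))
print(same)  # True
```

Restoring the optimizer state matters when resuming training, because AdamW's moment estimates would otherwise start again from zero.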



In [None]:
torch.save(model.state_dict(), final_checkpoint)

## Conclusion

In this notebook, we embarked on a comprehensive journey to explore the inner workings of autoregressive Transformers from scratch. 

In Part 1 we went deep into the fundamentals of the Transformer architecture and took a detailed look at its key components, including positional encoding, self-attention mechanisms, multi-head attention, feed-forward networks, and the Transformer block. By the end of Part 1, we successfully assembled these components to construct a complete autoregressive Transformer model.

In Part 2, we reviewed in detail the practical aspects of training our model. We covered various essential topics, including data preprocessing, tokenization, dataset creation, and data loading. We implemented a custom dataset class and collate function, setting the stage for efficient data handling during training.

Then we explored critical hyperparameters to drive the training of our autoregressive Transformer model. Hyperparameters such as the number of epochs, learning rate, gradient clipping, and learning rate scheduling were defined to optimize the training process. We also included a condition for early stopping to prevent overfitting during training.

Our training loop was detailed and included key steps such as forward and backward passes, loss computation, gradient clipping, and checkpoint saving. We monitored training progress using progress bars, reported losses, and calculated perplexity for both training and validation datasets.

### What's Next

As we conclude this second installment, we've laid the groundwork for training an autoregressive Transformer model. However, there's more to discover in the world of Transformers. In the next and final installment of our "Transformers from Scratch" series, we will explore the actual generation of text.

In Part 3, we'll learn about the generation loop used with autoregressive Transformers. This step involves using the trained model to generate coherent and contextually relevant text, an essential skill for natural language processing tasks such as text completion and dialogue generation. Keep in mind that training a model like this to produce coherent text would require a huge dataset (on the order of billions of tokens) and substantial computing power. In Part 3 you will learn how text is generated in models like ChatGPT, but please don't expect high-quality results, as this model is trained on very little data.

Stay tuned for Part 3, where we'll uncover the secrets of autoregressive text generation. Together, we'll leave no stone unturned in our quest to master Transformers from the ground up.

Happy coding and exploring!


Developed by [Juan Olano](https://www.linkedin.com/in/juan-olano-b9a330112/) Sept. 2023