# Pretraining from unlabeled data (chapter 5)

This notebook explores the pretraining process of LLMs, based on Chapter 5 of Sebastian Raschka's book. In particular, it covers the following:

1. Computing the **training** and **validation set losses** to assess the quality of LLM-generated text during training
2. Implementing a **training function** and pretraining the LLM
3. **Saving and loading model weights** to continue training an LLM
4. **Loading pretrained weights** from OpenAI

## Acknowledgment

All concepts, architectures, and implementation approaches are credited to Sebastian Raschka's work. This repository serves as my personal implementation and notes while working through the book's content.

## Resources

- [Sebastian Raschka's GitHub](https://github.com/rasbt)
- [Book Information](https://www.manning.com/books/build-a-large-language-model-from-scratch)
    - [Chapter 5](https://livebook.manning.com/book/build-a-large-language-model-from-scratch/chapter-5)
- [PyTorch Lightning - great tutorial collection](https://lightning.ai/docs/pytorch/stable/levels/core_skills.html#)

![Topic overview](https://drek4537l1klr.cloudfront.net/raschka/Figures/5-1.png)

In [5]:
# This installs the ipynb package which enables importing functions defined in other notebooks.
%pip install ipynb

Note: you may need to restart the kernel to use updated packages.


In [21]:
from typing import Optional

import torch
import torch.nn as nn

import tiktoken

# Import previous chapter dependencies.
# See https://stackoverflow.com/questions/44116194/import-a-function-from-another-ipynb-file
# NOTE: Importing these functions seems to run the entire cell the symbol is defined in, which would
#       suggest that symbols should be defined in separate cells from the test code.
from ipynb.fs.full.chapter_04_gpt_from_scratch import (
    GPTConfig,
    GPTModel,
    generate_text_simple,
)
from ipynb.fs.full.chapter_02_dataset_creation import create_dataloader_v1

tiktoken version: 0.7.0
Total number of character: 20479
Total number of tokens: 5145
x: [290, 4920, 2241, 287]
y:      [4920, 2241, 287, 257]
 and ---->  established
 and established ---->  himself
 and established himself ---->  in
 and established himself in ---->  a
[tensor([[  40,  367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]
Parameter containing:
tensor([[ 0.3374, -0.1778, -0.1690],
        [ 0.9178,  1.5810,  1.3010],
        [ 1.2753, -0.2010, -0.1606],
        [-0.4015,  0.9666, -1.1481],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096]], requires_grad=True)
torch.Size([6, 3])
tensor([[-0.4015,  0.9666, -1.1481]], grad_fn=<EmbeddingBackward0>)
tensor([[ 1.2753, -0.2010, -0.1606],
        [-0.4015,  0.9666, -1.1481],
        [-2.8400, -0.7849, -1.4096],
        [ 0.9178,  1.5810,  1.3010]], grad_fn=<EmbeddingBackward0>)
Token IDs: tensor([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],


In [7]:
# Instantiate the GPT-2 (124M) configuration with the context length shortened from 1,024 to 256 tokens.
GPT_CONFIG_124M = GPTConfig(
    vocab_size=50257,  # as used by the BPE tokenizer for GPT-2.
    context_length=256,
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    dropout_rate=0.1,
    qkv_bias=False,
)

In [8]:
# Create two training examples in a batch.
tokenizer = tiktoken.get_encoding("gpt2")

batch = []
txt1 = "Every effort moves you"
txt2 = "Every day holds a"

batch.append(torch.tensor(tokenizer.encode(txt1)))
batch.append(torch.tensor(tokenizer.encode(txt2)))
batch = torch.stack(batch, dim=0)

In [9]:
# Test the GPT model.
torch.manual_seed(123)

# Run the model on the batch.
model = GPTModel(GPT_CONFIG_124M)
model.eval()
out = model(batch)

print(f"Input batch: {batch}")
print(f"Output shape: {out.shape}")

Input batch: tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])
Output shape: torch.Size([2, 4, 50257])


# Text encoder and decoder utilities

![Text encoding and decoding](https://drek4537l1klr.cloudfront.net/raschka/Figures/5-3.png)

## Text to token conversion

In [10]:
def text_to_token_ids(
    text: str, tokenizer: Optional[tiktoken.Encoding] = None
) -> torch.Tensor:
    """Convert a text string to a tensor of token IDs.

    Args:
        text: The text to convert to token IDs.
        tokenizer: The tokenizer to use (defaults to the GPT-2 tokenizer).

    Returns:
        torch.Tensor: A tensor of token IDs.
    """
    # Instantiate a default GPT-2 tokenizer (if none was provided).
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding("gpt2")

    # Tokenize the input text.
    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})

    # Convert the tokenized text to a tensor.
    # NOTE: .unsqueeze(0) adds the batch dimension.
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)
    return encoded_tensor

## Token to text conversion

In [11]:
def token_ids_to_text(
    token_ids: torch.Tensor, tokenizer: Optional[tiktoken.Encoding] = None
) -> str:
    """Convert a tensor of token IDs to a text string.

    Args:
        token_ids: The tensor of token IDs to convert to text.
        tokenizer: The tokenizer to use (defaults to the GPT-2 tokenizer).

    Returns:
        str: The text string.
    """
    # Instantiate a default GPT-2 tokenizer (if none was provided).
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding("gpt2")

    # NOTE: .squeeze(0) removes the batch dimension.
    flat = token_ids.squeeze(0)
    return tokenizer.decode(flat.tolist())

In [12]:
# Test both conversion utilities by generating text with the (still untrained) model.
start_context = "Every effort moves you"
tokenizer = tiktoken.get_encoding("gpt2")

token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids(start_context, tokenizer),
    max_new_tokens=10,
    context_size=GPT_CONFIG_124M.context_length,
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 Every effort moves you rentingetic wasnم refres RexMeCHicular stren


# Loss function

Generating text from the model's outputs involves the five steps shown in the following figure; the loss is computed from the same probability scores. The figure uses a seven-word vocabulary for illustration purposes.

For each of the three input tokens, shown on the left, we compute a vector containing probability scores corresponding to each token in the vocabulary. The index position of the highest probability score in each vector represents the most likely next token ID. These token IDs are selected and mapped back into text, which represents the text generated by the model.

![Text generation loss](https://drek4537l1klr.cloudfront.net/raschka/Figures/5-4.png)
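The five steps can be sketched in plain Python on a toy example. The seven-word vocabulary and the hand-picked logits below are made up for illustration (they are not GPT-2's tokenizer or real model outputs), and `softmax` and `greedy_decode` are hypothetical helpers, not functions from the book.

```python
import math

# Hypothetical seven-word vocabulary (illustration only, not the GPT-2 tokenizer).
vocab = {"a": 0, "effort": 1, "every": 2, "forward": 3, "moves": 4, "you": 5, "zoo": 6}
inverse_vocab = {idx: word for word, idx in vocab.items()}


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_decode(input_text, logits_per_position):
    """Walk through the five steps for made-up logits (one row per input token)."""
    # Step 1: map the input words to token IDs (a real LLM would consume these IDs).
    input_ids = [vocab[word] for word in input_text.split()]
    predicted_ids = []
    for logits in logits_per_position:
        # Step 2: turn the logits at this position into probability scores.
        probas = softmax(logits)
        # Steps 3 and 4: keep the index with the highest probability.
        predicted_ids.append(probas.index(max(probas)))
    # Step 5: map the predicted token IDs back to words.
    text = " ".join(inverse_vocab[idx] for idx in predicted_ids)
    return input_ids, predicted_ids, text


# Made-up logits for the three input tokens "every effort moves".
toy_logits = [
    [0.1, 2.0, 0.3, 0.0, 0.5, 0.2, -1.0],  # highest score at index 1 ("effort")
    [0.0, 0.1, 0.2, 0.4, 1.8, 0.3, -0.5],  # highest score at index 4 ("moves")
    [0.2, 0.0, 0.1, 0.3, 0.4, 2.2, 0.1],   # highest score at index 5 ("you")
]
print(greedy_decode("every effort moves", toy_logits))
# ([2, 1, 4], [1, 4, 5], 'effort moves you')
```

Because softmax preserves the ordering of the scores, taking the argmax of the logits directly would select the same token IDs; the probabilities matter for the loss rather than for greedy selection.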

## Example - step by step

In [13]:
# Develop the loss function using a batch of two simple examples.
# Step 1: The inputs are given directly as token IDs (i.e., already tokenized).
inputs = torch.tensor(
    [[16833, 3626, 6100], [40, 1107, 588]],  # ["every effort moves", "I really like"]
)

# Define the targets, which are the next tokens in the sequences.
targets = torch.tensor(
    [
        [3626, 6100, 345],
        [1107, 588, 11311],
    ]  # [" effort moves you", " really like chocolate"]
)

# Compute the logits for the inputs.
# NOTE: We disable gradient computation since gradients are only used for training.
with torch.no_grad():
    logits = model(inputs)

# Step 2: Compute the probabilities of each token in the vocabulary.
# NOTE: The shape of probas is [B, T, V] where
#
# B is the batch size
# T is the sequence length
# V is the vocabulary size.
probas = torch.softmax(logits, dim=-1)
print(f"Probas shape: {probas.shape}")

# Steps 3 and 4: Convert the probabilities to token IDs via a greedy decoding strategy.
token_ids = torch.argmax(probas, dim=-1, keepdim=True)
# Print the predicted token IDs for both examples in the batch.
print("Token IDs:\n", token_ids)

# Step 5: Convert the token IDs back to text.
print(f"Targets batch 1: {token_ids_to_text(targets[0], tokenizer)}")
print(f"Outputs batch 1: {token_ids_to_text(token_ids[0].flatten(), tokenizer)}")

Probas shape: torch.Size([2, 3, 50257])
Token IDs:
 tensor([[[16657],
         [  339],
         [42826]],

        [[49906],
         [29669],
         [41751]]])
Targets batch 1:  effort moves you
Outputs batch 1:  Armed heNetflix


In [14]:
# For each of the two input texts, we can print the initial softmax probability scores
# corresponding to the target tokens using the following code:

batch_idx = 0
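# NOTE: probas[batch_idx, :, targets[batch_idx]] does not work here: combining the slice `:`
#       with an index tensor selects every target ID at every position (a 3x3 result).
#       Indexing with the positions [0, 1, 2] pairs position i with target i instead.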
target_probas_1 = probas[batch_idx, [0, 1, 2], targets[batch_idx]]
print(f"probas.shape: {probas.shape}")
print("Text 1:", target_probas_1)

batch_idx = 1
target_probas_2 = probas[batch_idx, [0, 1, 2], targets[batch_idx]]
print("Text 2:", target_probas_2)

probas.shape: torch.Size([2, 3, 50257])
Text 1: tensor([    0.0001,     0.0000,     0.0000])
Text 2: tensor([    0.0000,     0.0001,     0.0000])
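The per-example indexing above can be done for the whole batch at once with `torch.gather`. The sketch below uses random toy tensors (the shapes are assumptions chosen for illustration) to check that the gather-based selection matches the per-example indexing, and that the slice-based variant `probas[b, :, targets[b]]` returns a T x T grid whose diagonal holds the wanted values.

```python
import torch

torch.manual_seed(0)
B, T, V = 2, 3, 7  # toy batch size, sequence length, and vocabulary size
probas = torch.softmax(torch.randn(B, T, V), dim=-1)
targets = torch.randint(0, V, (B, T))

# Per-example indexing, as in the cell above: position i is paired with target i.
per_example = torch.stack(
    [probas[b, torch.arange(T), targets[b]] for b in range(B)]
)

# The same selection for the whole batch: gather picks probas[b, t, targets[b, t]]
# along the vocabulary dimension (index must have as many dims as probas).
gathered = probas.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(per_example, gathered))  # True

# Mixing a slice with an index tensor selects every target at every position instead.
grid = probas[0, :, targets[0]]
print(grid.shape)  # torch.Size([3, 3])
```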


## Computing the loss step by step

![Loss computation](https://drek4537l1klr.cloudfront.net/raschka/Figures/5-7.png)

In [15]:
# Compute the log probabilities of the target tokens.
# NOTE: Working with logarithms of probability scores is more manageable in mathematical
#       optimization than handling the scores directly.
log_probas = torch.log(torch.cat((target_probas_1, target_probas_2)))
print(f"log_probas: {log_probas}")

# Compute the average log probability of the target tokens.
avg_log_probas = torch.mean(log_probas)
print(f"avg_log_probas: {avg_log_probas}")

# The goal is to get the average log probability as close to 0 as possible by updating the model’s
# weights as part of the training process. However, in deep learning, the common practice isn’t to
# push the average log probability up to 0 but rather to bring the negative average log probability
# down to 0. The negative average log probability is simply the average log probability multiplied
# by –1.
neg_avg_log_probas = avg_log_probas * -1
print(f"neg_avg_log_probas: {neg_avg_log_probas}")

log_probas: tensor([ -9.5042, -10.3796, -11.3677, -11.4798,  -9.7764, -12.2561])
avg_log_probas: -10.793964385986328
neg_avg_log_probas: 10.793964385986328


In [16]:
# The logits tensor has three dimensions: batch size, number of tokens, and vocabulary size.
# The targets tensor has two dimensions: batch size and number of tokens. PyTorch's
# cross_entropy expects logits of shape [N, V] and targets of shape [N], so we flatten both
# tensors by merging the batch and token dimensions:
print("Logits shape:", logits.shape)
print("Targets shape:", targets.shape)

logits_flat = logits.flatten(0, 1)
targets_flat = targets.flatten()
print("Flattened logits:", logits_flat.shape)
print("Flattened targets:", targets_flat.shape)

loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)
print(loss)

Logits shape: torch.Size([2, 3, 50257])
Targets shape: torch.Size([2, 3])
Flattened logits: torch.Size([6, 50257])
Flattened targets: torch.Size([6])
tensor(10.7940)
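To see that `cross_entropy` performs the same computation as the manual steps (log-probabilities of the targets, mean, negation), the sketch below compares both routes on random toy logits. It uses `torch.log_softmax` rather than `torch.log(torch.softmax(...))`, which is numerically more stable for very small probabilities; the PyTorch documentation describes `cross_entropy` as combining log-softmax with the negative log-likelihood loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 3, 7  # toy shapes for illustration
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

# Manual route: log-probabilities of the target tokens, averaged, then negated.
log_probas = torch.log_softmax(logits, dim=-1)
target_log_probas = log_probas.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
manual_loss = -target_log_probas.mean()

# Built-in route: raw logits of shape [B*T, V] and class indices of shape [B*T].
builtin_loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())

print(manual_loss, builtin_loss)  # the two values agree up to floating-point rounding
```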


## The difference between cross-entropy, perplexity, and KL-divergence

### Cross-entropy

Cross-entropy measures how well a predicted probability distribution $q$ matches a true distribution $p$. It’s defined as:

$$
H(p, q) = -\sum_{x} p(x) \log q(x)
$$

where $x$ runs over all possible events. When the logarithm is taken in base 2, it’s the average number of bits needed to encode samples from $p$ using a code optimized for $q$. For a fixed $p$, the lower the cross-entropy, the closer $q$ is to $p$.

[According to Wikipedia](https://en.wikipedia.org/wiki/Cross-entropy), in information theory, the cross-entropy between two probability distributions $p$ and $q$, over the same underlying set of events, measures the average number of bits needed to identify an event drawn from the set when the coding scheme used for the set is optimized for an estimated probability distribution $q$, rather than the true distribution $p$.

This statement reflects a fundamental idea from information theory: cross-entropy measures the cost of encoding data from one distribution $p$ under the assumptions of another distribution $q$. The unit “bits” arises when the logarithm is taken in base 2, i.e., in the context of binary encoding; with the natural logarithm, which PyTorch’s `cross_entropy` uses, the unit is nats instead. Intuitively, each bit represents a yes/no choice, and the cross-entropy tells us, on average, how many such choices we’d need to make to encode the true outcomes from $p$, given that our model assigns probabilities according to $q$.

- If $q$ perfectly matches $p$, the encoding is as efficient as possible—this is essentially the entropy $H(p)$ of the true distribution.  
- If $q$ differs from $p$, the encoder based on $q$ gives overly long codes to outcomes that $q$ underestimates, so the average code length grows.  
- A lower cross-entropy means we’re closer to the ideal scenario where $q \approx p$, which indicates our model (represented by $q$) is doing a better job of approximating the true distribution $p$.  
- Conversely, a higher cross-entropy indicates that $q$ diverges more from $p$, increasing the average number of bits needed.

So, the cross-entropy not only quantifies the difference between two distributions, but also translates that difference into the practical costs of encoding data.

**Example**:  
- True distribution: $p = [0.7, 0.2, 0.1]$
- Predicted distribution 1: $q_1 = [0.6, 0.3, 0.1]$
- Predicted distribution 2: $q_2 = [0.9, 0.05, 0.05]$
  - $H(p, q_1)$ will be lower than $H(p, q_2)$, because $q_1$ is closer to $p$ than $q_2$.
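The example can be checked numerically. The `cross_entropy` helper below is a hypothetical plain-Python implementation of the formula above (not PyTorch's function); it defaults to the natural logarithm, and `base=2` gives bits.

```python
import math


def cross_entropy(p, q, base=math.e):
    """H(p, q) = -sum_x p(x) log q(x); base=2 gives bits, base=math.e gives nats."""
    return -sum(px * math.log(qx, base) for px, qx in zip(p, q) if px > 0)


p = [0.7, 0.2, 0.1]
q1 = [0.6, 0.3, 0.1]
q2 = [0.9, 0.05, 0.05]

print(f"H(p, q1) = {cross_entropy(p, q1):.4f} nats")  # 0.8286
print(f"H(p, q2) = {cross_entropy(p, q2):.4f} nats")  # 0.9725
print(f"H(p)     = {cross_entropy(p, p):.4f} nats")   # 0.8018, the lower bound
print(f"H(p, q1) = {cross_entropy(p, q1, base=2):.4f} bits")
```

Setting $q = p$ yields the entropy $H(p)$, the smallest cross-entropy any $q$ can achieve for this $p$.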

### Perplexity

Perplexity is often used in language modeling and other probabilistic models to measure how well a model predicts a sample. It’s defined as the exponentiated average negative log-probability:

$$
\text{Perplexity}(p, q) = 2^{H(p, q)}
$$

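This represents the effective number of choices the model assigns to each outcome. A lower perplexity means the model assigns higher probability to the outcomes that actually occur. Perplexity is often viewed as a normalized measure of cross-entropy, expressed in terms of the equivalent branching factor. For instance, if a language model’s perplexity is 10, the model is, on average, as uncertain as if it had to choose among 10 equally likely outcomes. The base of the exponent must match the logarithm used in $H(p, q)$: $2^{H(p, q)}$ when $H$ is measured in bits, and $e^{H(p, q)}$ when it is measured in nats, which is why perplexity is computed as `torch.exp(loss)` for PyTorch’s natural-log cross-entropy loss.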

[According to Wikipedia](https://en.wikipedia.org/wiki/Perplexity), in information theory, perplexity is a measure of uncertainty in the value of a sample from a discrete probability distribution. The larger the perplexity, the less likely it is that an observer can guess the value which will be drawn from the distribution.

[From Sebastian Raschka's book:](https://www.manning.com/books/build-a-large-language-model-from-scratch)

Perplexity is a measure often used alongside cross entropy loss to evaluate the performance of models in tasks like language modeling. It can provide a more interpretable way to understand the uncertainty of a model in predicting the next token in a sequence.

Perplexity measures how well the probability distribution predicted by the model matches the actual distribution of the words in the dataset. Similar to the loss, a lower perplexity indicates that the model predictions are closer to the actual distribution. Perplexity can be calculated as `perplexity = torch.exp(loss)`, which returns `tensor(48725.8203)` when applied to the previously calculated loss.

Perplexity is often considered more interpretable than the raw loss value because it signifies the effective vocabulary size about which the model is uncertain at each step. In the given example, this would translate to the model being unsure about which among 48,725 tokens in the vocabulary to generate as the next token.
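The perplexity above can be reproduced with plain Python from the negative average log probability printed earlier in this notebook (10.793964385986328, which matches the `cross_entropy` loss). The sketch also shows that converting the loss from nats to bits and exponentiating base 2 gives the same perplexity, since both routes undo the same logarithm.

```python
import math

# Negative average log probability printed earlier (natural log, so in nats).
loss_nats = 10.793964385986328

perplexity = math.exp(loss_nats)

# The same perplexity via bits: convert nats to bits, then exponentiate base 2.
loss_bits = loss_nats / math.log(2)
perplexity_from_bits = 2 ** loss_bits

print(f"perplexity (e^loss): {perplexity:.1f}")
print(f"loss in bits:        {loss_bits:.2f}")
print(f"perplexity (2^bits): {perplexity_from_bits:.1f}")
```

Both routes give about 48,725.8, in line with the `tensor(48725.8203)` reported above.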

[ChatGPT](https://chatgpt.com/c/67f355f7-0c14-800f-84ad-1fa039a6025d) provides a similar intuitive explanation. If we consider a language model predicting the next word in a sentence, perplexity provides a numerical summary of how uncertain or "perplexed" the model is, on average, when choosing among possible outcomes. A perplexity value of 10, for example, indicates that the model’s uncertainty is equivalent to having 10 equally likely choices for each word it predicts. In other words, lower perplexity means the model is more confident in its predictions, as it can narrow down the possible outcomes to a smaller, more focused set. Higher perplexity indicates greater uncertainty or poorer model performance, since the model must spread its probability mass across more outcomes, essentially "considering" a larger range of possibilities before making a prediction.

This interpretation of perplexity as a kind of "average branching factor" makes it particularly useful in evaluating the quality of language models. Instead of dealing with abstract bits or logarithms (as in cross-entropy), perplexity translates the model’s predictive efficiency into a form that’s more intuitive.

**Example**:  
- Suppose a language model predicts a sentence like “The cat sat on the ____” with probabilities for possible words:  
  - $p(\text{mat}) = 0.8$, $p(\text{floor}) = 0.15$, $p(\text{roof}) = 0.05$
  - If the true word is “mat”, the perplexity of this single prediction is $1/p(\text{mat}) = 1/0.8 = 1.25$, which is low.  
  - If the model instead assigned a much lower probability to “mat” (say $0.05$) and more to the other options, the perplexity would rise to $1/0.05 = 20$, indicating worse predictions.
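A small sketch of the per-token arithmetic: for a sequence, perplexity is the exponential of the average negative log-probability the model assigned to the tokens that actually occurred, so for a single prediction it reduces to $1/p$. The `perplexity` helper is hypothetical, and the probabilities are the illustrative values from the example plus one made-up alternative.

```python
import math


def perplexity(target_probs):
    """exp of the average negative log-probability assigned to the actual next tokens."""
    avg_nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(avg_nll)


# Single prediction "The cat sat on the ____" where the true word is "mat".
print(perplexity([0.8]))   # about 1.25: the model puts 0.8 on "mat"
print(perplexity([0.05]))  # about 20: a hypothetical model putting only 0.05 on "mat"

# Spreading probability evenly over 10 options gives perplexity 10 (the branching factor).
print(perplexity([0.1] * 5))  # about 10
```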


### KL Divergence (Kullback-Leibler Divergence)  

KL divergence measures how one probability distribution $q$ diverges from a reference distribution $p$. It’s given by:

$$
D_{KL}(p \parallel q) = \sum_{x} p(x) \log\frac{p(x)}{q(x)}
$$

KL divergence is always non-negative and equals zero only when $p = q$. Unlike cross-entropy, it explicitly quantifies the “distance” (in an information-theoretic sense) between the two distributions, although it is not a true distance metric because it is not symmetric: in general, $D_{KL}(p \parallel q) \neq D_{KL}(q \parallel p)$. While cross-entropy tells us how many bits are needed to encode samples from $p$ using a code optimized for $q$, KL divergence tells us how many extra bits are needed compared to using the true distribution $p$ itself.

**Example**:  
- True distribution: $p = [0.5, 0.5]$
- Predicted distribution 1: $q_1 = [0.6, 0.4]$
- Predicted distribution 2: $q_2 = [0.9, 0.1]$
  - $D_{KL}(p \parallel q_1)$ is smaller than $D_{KL}(p \parallel q_2)$, because $q_1$ is closer to $p$.  
  - If $q_1$ becomes equal to $p$, the KL divergence will be zero.
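The example can be verified with a hypothetical plain-Python `kl_divergence` helper that implements the formula above with the natural logarithm (so values are in nats). The last line swaps the arguments to show the asymmetry.

```python
import math


def kl_divergence(p, q):
    """D_KL(p || q) = sum_x p(x) log(p(x) / q(x)), in nats."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)


p = [0.5, 0.5]
q1 = [0.6, 0.4]
q2 = [0.9, 0.1]

print(f"D_KL(p || q1) = {kl_divergence(p, q1):.4f}")  # 0.0204
print(f"D_KL(p || q2) = {kl_divergence(p, q2):.4f}")  # 0.5108
print(f"D_KL(p || p)  = {kl_divergence(p, p):.4f}")   # 0.0000
# Swapping the arguments changes the value: KL divergence is not symmetric.
print(f"D_KL(q2 || p) = {kl_divergence(q2, p):.4f}")  # 0.3681
```

PyTorch also provides `torch.nn.functional.kl_div`, but note its argument convention: the first argument is expected to hold log-probabilities of $q$ and the second the probabilities of $p$.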


### Comparing the Concepts

1. **Cross-Entropy vs. KL Divergence**:  
   - Cross-entropy combines the entropy of $p$, which is fixed for a given $p$, and the KL divergence from $p$ to $q$:  
     $$
     H(p, q) = H(p) + D_{KL}(p \parallel q)
     $$
   - While cross-entropy measures the total coding cost under $q$, KL divergence isolates the inefficiency due to $q$’s divergence from $p$.

2. **Perplexity and Cross-Entropy**:  
   - Perplexity is derived directly from cross-entropy, converting the measure into an interpretable “average number of choices.” It essentially provides a more human-readable version of the model’s performance.  
   - Both low perplexity and low cross-entropy indicate a better model fit, but perplexity is the exponential form and gives a more intuitive sense of the model’s uncertainty.

3. **Perplexity and KL Divergence**:  
   - While perplexity is connected to cross-entropy, KL divergence is a more nuanced measure that focuses on how much $q$ deviates from $p$ rather than the raw efficiency of encoding.  
   - Perplexity doesn’t directly measure divergence; instead, it measures how well the model predicts, which can be related to divergence indirectly through the cross-entropy.
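The identity $H(p, q) = H(p) + D_{KL}(p \parallel q)$ can be checked numerically with the cross-entropy example distributions from earlier. The helpers are hypothetical and redefined here so the snippet runs on its own (natural log, so all values are in nats). Because $H(p)$ does not depend on $q$, minimizing the cross-entropy with respect to the model’s $q$ also minimizes $D_{KL}(p \parallel q)$.

```python
import math


def entropy(p):
    """H(p) = -sum_x p(x) log p(x)."""
    return -sum(px * math.log(px) for px in p if px > 0)


def cross_entropy(p, q):
    """H(p, q) = -sum_x p(x) log q(x)."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)


def kl_divergence(p, q):
    """D_KL(p || q) = sum_x p(x) log(p(x) / q(x))."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)


p = [0.7, 0.2, 0.1]
for q in ([0.6, 0.3, 0.1], [0.9, 0.05, 0.05]):
    lhs = cross_entropy(p, q)
    rhs = entropy(p) + kl_divergence(p, q)
    # Perplexity in base e (matching the natural log) is exp of the cross-entropy.
    print(f"H(p, q) = {lhs:.4f}   H(p) + D_KL = {rhs:.4f}   perplexity = {math.exp(lhs):.4f}")
```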

In summary, cross-entropy and perplexity are practical metrics for evaluating how well a predictive model matches a true distribution, with perplexity offering a more intuitive interpretation. KL divergence, on the other hand, is a more fundamental information-theoretic measure that quantifies how much one distribution differs from another, forming a building block for understanding the inefficiencies captured by cross-entropy.

# Training and validation set losses

When preparing the data loaders, we split the input text into training and validation set portions. Then we tokenize the text (only shown for the training set portion for simplicity) and divide the tokenized text into chunks of a user-specified length (6 in the figure; the code below uses the context length of 256). Finally, we shuffle the rows and organize the chunked text into batches (here, batch size 2), which we can use for model training.

![Data splits](https://drek4537l1klr.cloudfront.net/raschka/Figures/5-9.png)
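The chunking and batching step can be sketched as follows. `chunk_token_ids` is a hypothetical helper that only illustrates the idea (the notebook itself uses `create_dataloader_v1` from chapter 2, whose internals are not shown here), and the 20 consecutive integers stand in for a tokenized text. Each target row is its input row shifted by one position, so the model learns to predict the next token at every position.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def chunk_token_ids(token_ids, max_length, stride):
    """Hypothetical helper: slide a window over token IDs to build (input, target) rows."""
    inputs, targets = [], []
    for start in range(0, len(token_ids) - max_length, stride):
        inputs.append(token_ids[start : start + max_length])
        # The target row is the input row shifted one token to the right.
        targets.append(token_ids[start + 1 : start + max_length + 1])
    return torch.tensor(inputs), torch.tensor(targets)


torch.manual_seed(123)
token_ids = list(range(100, 120))  # 20 fake token IDs standing in for a tokenized text

# Chunk length 6 with stride 6 (non-overlapping chunks), as in the figure.
inputs, targets = chunk_token_ids(token_ids, max_length=6, stride=6)
print(inputs.shape)  # torch.Size([3, 6])

# Shuffle the rows and group them into batches of 2; drop_last discards the leftover row.
loader = DataLoader(TensorDataset(inputs, targets), batch_size=2, shuffle=True, drop_last=True)
for x, y in loader:
    print("inputs:\n", x)
    print("targets:\n", y)
```

Choosing a stride equal to the chunk length avoids overlap between rows; a smaller stride would produce more (overlapping) training rows from the same text.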

In [22]:
# Load example dataset.
file_path = "data/the_verdict.txt"
with open(file_path, "r", encoding="utf-8") as file:
    text_data = file.read()

# Print statistics.
total_characters = len(text_data)
total_tokens = len(tokenizer.encode(text_data))
print("Characters:", total_characters)
print("Tokens:", total_tokens)

# Divide the dataset into training and validation sets.
# NOTE: This is a simple and naive approach to splitting the dataset and should be replaced with
#       tooling from pytorch (e.g. https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split)
train_ratio = 0.90
split_idx = int(train_ratio * len(text_data))
train_data = text_data[:split_idx]
val_data = text_data[split_idx:]
print("Train data (chars):", len(train_data))
print("Validation data (chars):", len(val_data))

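# Create the dataloaders.
# NOTE: stride == max_length yields non-overlapping chunks. Only the training loader shuffles
#       its samples and drops the last (possibly incomplete) batch.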
torch.manual_seed(123)
train_loader = create_dataloader_v1(
    train_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M.context_length,
    stride=GPT_CONFIG_124M.context_length,
    drop_last=True,
    shuffle=True,
    num_workers=0,
)
val_loader = create_dataloader_v1(
    val_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M.context_length,
    stride=GPT_CONFIG_124M.context_length,
    drop_last=False,
    shuffle=False,
    num_workers=0,
)

Characters: 20479
Tokens: 5145
Train data (chars): 18431
Validation data (chars): 2048
