# Part I: Training a GPT-2 model from scratch

The objective of this notebook is to train an LLM from scratch on a short text. The model will be trained on the *next token prediction* (or *causal generation*) task. The main topics we will cover are:
1. Data preparation
2. Building the LLM architecture
3. Training an LLM

This lab project is heavily inspired by [S. Raschka's YouTube lecture](https://www.youtube.com/watch?v=quh7z1q7-uc&t=16s) on the topic.

## 0. Library imports

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tiktoken


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## 1. LLM input data
The objective of this section is to prepare the data for training an LLM. We will use a short book of poems, *Lou Catounet Gascoun*, written by [Guilhèm Adèr](https://en.wikipedia.org/wiki/Guilh%C3%A8m_Ad%C3%A8r) in the Gascon dialect of the Occitan language. Preparing the data for the LLM is done in two steps:
1. Tokenizing the data.
2. Preparing the input-output data batches for training.

### 1.1. Using a pre-trained tokenizer
A tokenizer has two main roles:
- Breaking text into smaller chunks of characters called tokens.
- Mapping the text expressed in these chunks to a sequence of integers.

Each token is assigned a unique integer, so the token-integer mapping is bijective: one can both encode a text into a sequence of integers and decode a sequence of integers back into a text.

The most common option is to use a pre-trained tokenizer. In this notebook we use a tokenizer based on the Byte Pair Encoding (BPE) tokenization method. The BPE algorithm builds its set of tokens iteratively from a corpus of texts:
- It starts with an initial vocabulary of all characters present in the corpus.
- It creates new tokens by merging existing ones: at each step, the most frequent pair of existing tokens is merged.
- The algorithm stops when the desired vocabulary size is reached.
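The three steps above can be sketched in plain Python. This is a toy, character-level illustration on a hypothetical mini-corpus, and the helper names `count_pairs`, `merge_pair` and `train_bpe` are ours; GPT-2's actual tokenizer operates on bytes and was trained on a far larger corpus, so its merges differ.

```python
from collections import Counter

def count_pairs(words):
    """Count adjacent token pairs over all words (each word is a tuple of tokens)."""
    pairs = Counter()
    for word, freq in words.items():
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += freq
    return pairs

def merge_pair(words, pair):
    """Replace every occurrence of `pair` inside each word by one merged token."""
    merged = {}
    for word, freq in words.items():
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                new_word.append(word[i] + word[i + 1])
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        key = tuple(new_word)
        merged[key] = merged.get(key, 0) + freq
    return merged

def train_bpe(corpus, vocab_size):
    # Initial vocabulary: every character appearing in the corpus
    words = Counter(tuple(w) for w in corpus.split())
    vocab = sorted({ch for word in words for ch in word})
    merges = []
    while len(vocab) < vocab_size:
        pairs = count_pairs(words)
        if not pairs:
            break  # every word is already a single token
        best = max(pairs, key=pairs.get)  # most frequent adjacent pair
        words = merge_pair(words, best)
        merges.append(best)
        vocab.append(best[0] + best[1])
    return vocab, merges

vocab, merges = train_bpe("lo gat e lo can e lo gat", vocab_size=12)
print(merges)  # the first merge is ('l', 'o'), the most frequent pair
```

Stopping on `vocab_size` mirrors the last step of the description; real implementations also record the merge order, since encoding new text replays the merges in that order.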

Below we use a tokenizer that has been pre-trained for the GPT-2 model on a large corpus of texts. We use the `tiktoken` library for its particularly efficient implementation.

In [None]:
tokenizer = tiktoken.get_encoding("gpt2")

**Exercise.** Write a few sentences in the language of your choice and use the tokenizer to encode and decode them.

In [None]:
# TODO: Write text in the strings below and encode it using the tokenizer
text = (
    "..."
    "..."
)

integers = ...  # TODO: encode the text using the tokenizer
# TODO: print the sequence of integers

In [None]:
# %load solutions/training/encode.py

In [None]:
# TODO: decode the sequence of integers using the tokenizer and print the resulting string

In [None]:
# %load solutions/training/decode.py

### 1.2. Loading the data
The objective of this section is to load the text we will be using and split it into batches suitable for training. Since we are training our LLM on the *next token prediction* task, we train it on input-output pairs where the output is the input shifted to the left by one token (as we did for RNNs).

To do so, we split the text into chunks of length `max_length`. The starting point of each chunk is obtained by adding `stride` to the starting point of the previous chunk, so consecutive chunks overlap whenever `stride` is smaller than `max_length`.
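The sliding window can be traced on a hypothetical toy list of token IDs (plain integers standing in for tokenizer output). With `max_length=4` and `stride=2`, consecutive windows share two tokens, and each target chunk covers the same window moved one token further into the text.

```python
token_ids = list(range(10, 20))  # hypothetical token IDs: 10, 11, ..., 19
max_length, stride = 4, 2        # stride < max_length, so the chunks overlap

pairs = []
for i in range(0, len(token_ids) - max_length, stride):
    input_chunk = token_ids[i:i + max_length]
    target_chunk = token_ids[i + 1:i + max_length + 1]  # same window, one token later
    pairs.append((input_chunk, target_chunk))

for inp, tgt in pairs:
    print(inp, "->", tgt)
# [10, 11, 12, 13] -> [11, 12, 13, 14]
# [12, 13, 14, 15] -> [13, 14, 15, 16]
# [14, 15, 16, 17] -> [15, 16, 17, 18]
```

Setting `stride` equal to `max_length` gives non-overlapping chunks, which is what the training loaders in Section 3 use.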

**Exercise.** Complete the `GPTDataset` class below by filling in the `TODO` tags.

In [None]:
class GPTDataset(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # TODO: Tokenize the entire text
        token_ids = ...

        # Use a sliding window to chunk the book into overlapping sequences of max_length by moving the window by stride
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = ...  # TODO: choose the input tokens
            target_chunk = ... # TODO: choose the target tokens
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

In [None]:
# %load solutions/training/dataset.py

The class above creates a dataset from the given text `txt` by splitting it into chunks and building input-output pairs. Next, we write a helper function `create_dataloader` that uses `GPTDataset` to create the dataset and arranges it into batches for training.

**Exercise.** Complete the function `create_dataloader` below by filling in the `TODO` tags.

In [None]:
def create_dataloader(txt, batch_size=4, max_length=256, 
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):

    # Initialize the tokenizer
    tokenizer = ... # TODO: initialize the GPT2 tokenizer

    # Create dataset
    dataset = ... # TODO: initialize the GPTDataset

    # Create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers
    )

    return dataloader

In [None]:
# %load solutions/training/dataloader.py

In [None]:
with open("catounet.txt", "r", encoding="utf-8") as f:
    text = f.read()

print(f"Total number of characters in the text:\n {len(text)}\n")

print(f"An excerpt from the book:\n {text[1_000:1_500]}\n")

**Exercise.** Load `text` with the `create_dataloader` function using the parameters below, and inspect the first batch of inputs and targets.

In [None]:
# Creating batched input-output pairs
dataloader = create_dataloader(text, batch_size=8, max_length=4, stride=4, shuffle=False)

data_iter = iter(dataloader)
inputs, targets = next(data_iter)
print("Inputs:\n", inputs)
print("\nTargets:\n", targets)

## 2. Building the LLM architecture
The LLM uses the transformer architecture. The main layer types present in this architecture are:
- Multi-head attention
- Layer Norm
- GELU activation
- Feed-forward

These layer types are combined in a *transformer block*, and several transformer blocks are stacked on top of each other to form the whole GPT-2 architecture.

### 2.1. Multi-head Attention
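The *multi-head* attention layer projects the input into queries, keys and values, splits each of them along the embedding dimension into `num_heads` chunks of equal size, and lets each head apply the attention mechanism independently: it computes attention scores from the queries and keys, turns them into attention weights, and uses these weights to combine the values. The outputs of all heads are then concatenated and passed through a final linear projection.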
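The causal mask used in the class below can be checked in isolation. This sketch uses random scores for a hypothetical sequence of 4 tokens and a hypothetical head dimension of 8; it builds the same upper-triangular mask with `torch.triu`, fills the masked scores with `-inf`, and shows that after the scaled softmax each token only attends to itself and to earlier tokens.

```python
import torch

torch.manual_seed(0)
num_tokens, head_dim = 4, 8
scores = torch.randn(num_tokens, num_tokens)  # hypothetical raw scores (queries @ keys^T)

# 1s strictly above the diagonal mark "future" positions
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
masked_scores = scores.masked_fill(mask, -torch.inf)

# Scale by sqrt(head_dim), then normalize each row
weights = torch.softmax(masked_scores / head_dim**0.5, dim=-1)
print(weights)
# Each row sums to 1 and every entry above the diagonal is exactly 0
```

Because `exp(-inf) = 0`, masked positions receive zero weight without any renormalization step; the first token can only attend to itself, so its weight is exactly 1.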

**Exercise.** Implement the `MultiHeadAttention` class by filling in the `TODO` flags below.

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # TODO: use the 'assert' statement to check that d_out is divisible by num_heads

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = ...  # TODO: compute the per-head dimension in terms of d_out and num_heads

        self.W_query = ...  # TODO: initialize the linear layer for the query with the appropriate dimensions and the optional bias term
        self.W_key = ...  # TODO: initialize the linear layer for the key with the appropriate dimensions and the optional bias term
        self.W_value = ...  # TODO: initialize the linear layer for the value with the appropriate dimensions and the optional bias term
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = ...  # TODO: apply the key matrix to the input; shape: (b, num_tokens, d_out)
        queries = ...  # TODO: apply the query matrix to the input
        values = ...  # TODO: apply the value matrix to the input

        # We implicitly split the matrices by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

In [None]:
# %load solutions/training/multihead.py

### 2.2. Layer Norm

You might have heard of BatchNorm, where the inputs of a layer are normalized across the batch: the mean and standard deviation of each input feature are computed across the batch dimension, and the inputs are then normalized by subtracting the mean and dividing by the standard deviation.

*LayerNorm* works similarly, except that the mean and standard deviation are computed across the feature (embedding) dimension of each individual input, rather than across the batch dimension.

Once the feature mean $\mu$ and standard deviation $\sigma$ are computed, the inputs are normalized as follows:
$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}},$$
where $\epsilon$ is a small constant typically taken to be equal to $10^{-5}$ to avoid division by numbers close to zero.

The output of the `LayerNorm` is not $\hat{x}_i$ though, but
$$y_i = \gamma \hat{x}_i + \beta$$
where $\gamma$ (scale parameter) and $\beta$ (shift parameter) are learnable parameters.
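The difference between the two normalizations is only the axis along which the statistics are computed. The sketch below uses PyTorch's built-in `torch.nn.functional.layer_norm` (without $\gamma$ and $\beta$) on a hypothetical batch of 2 tokens with 5 features each, and checks that every token (row) ends up with mean close to 0 and variance close to 1. The variance is the biased one (`unbiased=False`), matching $\sigma^2$ in the formula above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 5) * 3 + 7  # 2 tokens (rows), 5 features (columns)

# LayerNorm: one mean/variance pair per token, computed over the last dimension
y = F.layer_norm(x, normalized_shape=(5,), eps=1e-5)
print(y.mean(dim=-1))                 # close to 0 for each token
print(y.var(dim=-1, unbiased=False))  # close to 1 for each token

# BatchNorm-style statistics would instead be taken over dim=0 (the batch)
batch_stats_mean = x.mean(dim=0)  # one mean per feature, shape (5,)
```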

**Exercise.** Complete the `LayerNorm` class below by computing the feature mean and variance and performing the appropriate normalization.

In [None]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5  # small value to avoid division by zero
        self.scale = nn.Parameter(torch.ones(emb_dim)) # scale parameter (learnable)
        self.shift = nn.Parameter(torch.zeros(emb_dim)) # shift parameter (learnable)

    def forward(self, x):
        mean = ... # TODO: compute the mean of the input tensor over the last dimension
        var = ... # TODO: compute the variance of the input tensor over the last dimension
        norm_x = ... # TODO: normalize the input tensor
        y = ... # TODO: apply the learned scale and shift parameters
        return y

In [None]:
# %load solutions/training/layernorm.py

### 2.3. GELU
The *GELU* or Gaussian Error Linear Unit is an activation function defined as
$$\text{GELU}(x) = x\Phi(x)$$
where $\Phi(x)$ is the CDF of the standard normal distribution.
In order to avoid computationally expensive calculations, the GELU is often approximated using
$$\text{GELU}(x)\simeq 0.5 x \Bigg(
    1 + \text{tanh}\bigg(
        \sqrt{\frac{2}{\pi}}\big(x + 0.044715 x^3\big)
    \bigg)
    \Bigg).$$
We will use this approximation in our definition.
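To get a feel for the quality of the approximation, one can compare the exact definition, writing the standard normal CDF as $\Phi(x) = \tfrac{1}{2}\big(1 + \operatorname{erf}(x/\sqrt{2})\big)$, with PyTorch's built-in tanh approximation `torch.nn.functional.gelu(x, approximate="tanh")` (available in recent PyTorch versions). This sketch only relies on built-ins.

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, steps=1001)

# Exact GELU: x * Phi(x), with the standard normal CDF written via erf
gelu_exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# PyTorch's tanh-based approximation (the formula above)
gelu_tanh = F.gelu(x, approximate="tanh")

max_err = (gelu_exact - gelu_tanh).abs().max().item()
print(f"max |exact - tanh approx| on [-5, 5]: {max_err:.2e}")
```

The printed error is small compared with the values of the activation, which is why the cheaper tanh form was acceptable in practice.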

**Exercise.** Code the GELU activation function using the above approximation. Use the `torch` implementation of *tanh*, square root, $\pi$ and the *power* function.

In [None]:
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        gelu = ... # TODO: implement the GELU activation function
        return gelu

In [None]:
# %load solutions/training/gelu.py

### 2.4. The Feed Forward blocks
The *feed forward* layers in our *transformer blocks* will consist of:
1. A linear layer
2. A GELU activation
3. A linear layer

The feed forward layer has the same input and output dimensions, and the hidden dimension is equal to 4 times the input dimension.
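Because of the 4× expansion, the feed-forward layers hold a large share of the model's weights. A quick count, assuming both linear layers have bias terms (the `nn.Linear` default) and that GELU has no parameters; `ffn_param_count` is a hypothetical helper:

```python
def ffn_param_count(emb_dim, expansion=4):
    hidden = expansion * emb_dim
    first = emb_dim * hidden + hidden    # Linear(emb_dim -> hidden): weights + biases
    second = hidden * emb_dim + emb_dim  # Linear(hidden -> emb_dim): weights + biases
    return first + second                # GELU adds no parameters

print(ffn_param_count(768))  # 4722432 for emb_dim = 768
```

For comparison, the four projection matrices of the attention layer (query, key, value and output) hold about $4 \times 768^2 \approx 2.4$M weights, roughly half as many.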

**Exercise.** Complete the `FeedForward` class by filling in the `TODO` flags.

In [None]:
class FeedForward(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        # TODO: add the two linear layers with the intermediate GELU activation function to the following sequential model
        self.layers = nn.Sequential(
            ... # TODO: add the layers
        )

    def forward(self, x):
        return self.layers(x)

In [None]:
# %load solutions/training/feedforward.py

### 2.5. The Transformer block
The next step is to define our *Transformer Block*, which is composed of:
1. A *LayerNorm* normalization
2. A *Multi-head attention* layer
3. A *Dropout* layer
4. A *LayerNorm* normalization
5. A *Feed Forward* layer
6. A *Dropout* layer

Moreover, two skip-connections are present in the *Transformer Block*:
- *(A)* A skip connection that adds the original input to the output of the operation *3* above.
- *(B)* A skip connection that adds the output of *(A)* to the output of the operation *6* above.
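Both skip connections follow the same "pre-LayerNorm" pattern: normalize, apply the sub-layer, apply dropout, then add the result back to the sub-layer's input. A minimal sketch of that pattern, using PyTorch's built-in `nn.LayerNorm` and plain `nn.Linear` layers as stand-ins for the attention and feed-forward layers (the `PreNormResidual` class is hypothetical and not part of this notebook):

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """Computes x + Dropout(sublayer(LayerNorm(x)))."""
    def __init__(self, emb_dim, sublayer, drop_rate=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(emb_dim)
        self.sublayer = sublayer
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        shortcut = x                                # keep the input for the skip connection
        x = self.drop(self.sublayer(self.norm(x)))  # normalize, transform, dropout
        return x + shortcut                         # shapes must match for the addition

emb_dim = 8
block = nn.Sequential(
    PreNormResidual(emb_dim, nn.Linear(emb_dim, emb_dim)),  # stand-in for steps 1-3 and (A)
    PreNormResidual(emb_dim, nn.Linear(emb_dim, emb_dim)),  # stand-in for steps 4-6 and (B)
)
x = torch.randn(2, 3, emb_dim)  # (batch, num_tokens, emb_dim)
print(block(x).shape)           # torch.Size([2, 3, 8])
```

The additions are the reason every sub-layer must map `emb_dim` back to `emb_dim`.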

In order to initialize all the necessary layers, a configuration dictionary `cfg` will be passed to the `TransformerBlock` at initialization. This configuration dictionary will contain the information about the different parameters necessary to define the transformer architecture, under the following keywords:
- `vocab_size`: number of tokens in the vocabulary
- `emb_dim`: embedding dimension
- `context_length`: context length
- `n_heads`: number of heads in multi-head attention
- `drop_rate`: dropout rate
- `qkv_bias`: boolean determining whether to add a bias to the key, query and value matrices
- `n_layers`: number of transformer blocks to stack

**Exercise.** Implement the `forward` method of the `TransformerBlock`.

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"], 
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg["emb_dim"])
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # TODO: implement the forward pass of the TransformerBlock by performing the necessary operations on the input x
        return x

In [None]:
# %load solutions/training/transformerblock.py

### 2.6. The GPT-2 architecture
We next implement the GPT architecture by specifying the necessary model parameters in the `GPT_CONFIG_124M` dictionary, and by stacking transformer blocks along with other necessary layers in the `GPTModel` class.

In [None]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Number of tokens in the GPT-2 vocabulary
    "context_length": 128,   # Number of tokens in the context window
    "emb_dim": 768,          # Dimension of token embeddings
    "n_heads": 12,           # Number of attention heads in the multi-head attention layers
    "n_layers": 12,          # Number of transformer layers
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": False,       # Whether to include bias in the Q, K, V linear layers
}

In [None]:
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
        
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(
            cfg["emb_dim"], cfg["vocab_size"], bias=False
        )

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
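As a sanity check, the number of parameters can be counted by hand from the configuration. The sketch below copies the values of `GPT_CONFIG_124M` and assumes the layer shapes described in this notebook: Q, K, V projections without bias (`qkv_bias` is `False`), an output projection with bias, a feed-forward block with a 4× hidden layer and biases, two LayerNorms per block, and an output head without bias. `count_gpt_params` is a hypothetical helper. With these choices the total is about 162M rather than 124M: the "124M" label refers to the original GPT-2, which reuses the token-embedding matrix as the output head (weight tying), uses a context length of 1024 instead of 128, and has biases in Q, K, V.

```python
def count_gpt_params(cfg):
    d, V, T = cfg["emb_dim"], cfg["vocab_size"], cfg["context_length"]
    qkv = 3 * (d * d + (d if cfg["qkv_bias"] else 0))  # W_query, W_key, W_value
    out_proj = d * d + d                                # projection combining the heads
    ff = (d * 4 * d + 4 * d) + (4 * d * d + d)          # two linear layers; GELU has none
    norms = 2 * (2 * d)                                 # scale + shift of 2 LayerNorms
    per_block = qkv + out_proj + ff + norms

    embeddings = V * d + T * d  # token + positional embeddings
    final_norm = 2 * d
    out_head = d * V            # no bias
    return embeddings + cfg["n_layers"] * per_block + final_norm + out_head

# Same values as GPT_CONFIG_124M above
cfg = dict(vocab_size=50257, context_length=128, emb_dim=768,
           n_heads=12, n_layers=12, drop_rate=0.0, qkv_bias=False)
total = count_gpt_params(cfg)
print(f"{total:,}")  # 162,321,408
```

If `GPTModel` is completed as described, `sum(p.numel() for p in model.parameters())` should give the same number.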

### 2.7. Checking the GPT-2 Model

In [None]:
batch = []

txt1 = "Le chat noir est sur la"
txt2 = "Le soleil brille dans"

batch.append(torch.tensor(tokenizer.encode(txt1)))
batch.append(torch.tensor(tokenizer.encode(txt2)))
batch = torch.stack(batch, dim=0)
print(batch)

In [None]:
torch.manual_seed(44)
model = GPTModel(GPT_CONFIG_124M)

out = model(batch)
print("Input batch:\n", batch)
print("\nOutput shape:", out.shape)
print(out)

### 2.8. Generating text with the GPT-2 model
In this section we will write a `generate_text` function that uses a GPT-2 model to generate text given an input sequence of tokens.
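The function below uses *greedy decoding*: at each step it keeps the single most likely next token. A small illustration on hypothetical logits for a batch of 2 sequences and a 5-token vocabulary. Since softmax is monotonic, taking the `argmax` of the probabilities or of the raw logits selects the same token, and `keepdim=True` keeps the `(batch, 1)` shape needed to append the new token to the running sequence.

```python
import torch

logits = torch.tensor([[1.0, 3.0, 0.5, -1.0, 2.0],    # sequence 1: token 1 scores highest
                       [0.1, 0.2, 0.3, 4.0, 0.0]])    # sequence 2: token 3 scores highest
probas = torch.softmax(logits, dim=-1)                 # (batch, vocab_size)
idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)

idx = torch.tensor([[7, 8], [9, 4]])     # hypothetical running sequences
idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens + 1)
print(idx)  # rows [7, 8, 1] and [9, 4, 3]
```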

**Exercise.** Complete the `generate_text` function below by filling in the `TODO` flags.

In [None]:
def generate_text(model, idx, max_new_tokens, context_size):
    # idx is (batch, n_tokens) array of indices in the current context
    for _ in range(max_new_tokens):
        
        # TODO: Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = ... # TODO: crop the current context
        
        # Get the predictions
        with torch.no_grad():
            logits = ...  # TODO: get the logits from the model
        
        # Focus only on the last time step
        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]  

        # TODO: Apply softmax to get probabilities
        probas = ...  # (batch, vocab_size)

        # TODO: Get the idx of the vocab entry with the highest probability value
        idx_next = ...  # (batch, 1)

        # TODO: Append the selected index to the running sequence
        idx = ... # (batch, n_tokens+1)

    return idx

In [None]:
# %load solutions/training/generate.py

In [None]:
start_context = "Bonjour, je suis"

encoded = tokenizer.encode(start_context)
print("encoded:", encoded)

encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print("encoded_tensor.shape:", encoded_tensor.shape)

In [None]:
out = generate_text(
    model=model,
    idx=encoded_tensor, 
    max_new_tokens=6, 
    context_size=GPT_CONFIG_124M["context_length"]
)

print("Output:", out)
print("Output length:", len(out[0]))

In [None]:
decoded_text = tokenizer.decode(out.squeeze(0).tolist())
print(decoded_text)

## 3. Training the LLM
The objective of this section is to train the LLM on the text `catounet.txt`. In order to check that everything is working properly, we will reload the data.

In [None]:
# Helper functions to tokenize and detokenize text

def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0) # remove batch dimension
    return tokenizer.decode(flat.tolist())

### 3.1. Checking the data

In [None]:
# Load the book data
with open("catounet.txt", "r", encoding="utf-8") as f:
    book = f.read()

print(f"An excerpt from the book:\n {book[1_000:1_500]}\n")

**Exercise.** Print the total number of characters and the total number of tokens in the book.

In [None]:
total_characters = ...  # TODO: compute the total number of characters in the book
total_tokens = ...  # TODO: compute the total number of tokens in the book

print(f"Total number of characters in the book: {total_characters}")
print(f"Total number of tokens in the book: {total_tokens}")

In [None]:
# %load solutions/training/nb_tokens.py

### 3.2. Train and validation loaders
Next, we split the text into a training set and a validation set, taking the first 90% of the characters for training and the last 10% for validation, and load each of them into its own data loader.

**Exercise.** Create the training and validation data loaders by completing the `TODO` flags below.

In [None]:
# TODO: Split the book into training and validation data
train_ratio = ...  # TODO: define the proper ratio
split_idx = ...  # TODO: find the index of the character where the split should occur
train_data = ...  # TODO: choose the training characters
val_data = ...  # TODO: choose the validation characters
print(len(train_data), len(val_data))

# We set a seed for reproducibility
torch.manual_seed(44)

train_loader = create_dataloader(
    train_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=True,
    shuffle=True,
    num_workers=0
)

val_loader = create_dataloader(
    val_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=False,
    shuffle=False,
    num_workers=0
)

In [None]:
# %load solutions/training/train_val.py

In [None]:
# An optional check that the data loaders are working as expected
print("Train loader:")
for x, y in train_loader:
    print(x.shape, y.shape)

print("\nValidation loader:")
for x, y in val_loader:
    print(x.shape, y.shape)

In [None]:
train_tokens = 0
for input_batch, target_batch in train_loader:
    train_tokens += input_batch.numel()

val_tokens = 0
for input_batch, target_batch in val_loader:
    val_tokens += input_batch.numel()

print("Training tokens:", train_tokens)
print("Validation tokens:", val_tokens)
print("All tokens:", train_tokens + val_tokens)

### 3.3. Computing the loss
Since we are training the GPT-2 model on the *next token prediction* task, the model outputs, at each position, a vector of logits over the whole token vocabulary: each position is a *classification* problem with `vocab_size` classes. We can therefore train the model with the usual *cross-entropy loss*.
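`cross_entropy` expects one row of logits per prediction, so the `(batch, num_tokens, vocab_size)` logits are flattened to `(batch * num_tokens, vocab_size)` and the targets to `(batch * num_tokens,)`, as in `calc_loss_batch` below. This also gives a useful reference value: a model assigning equal probability to every token has a loss of $\ln(50257) \approx 10.8$ on the GPT-2 vocabulary, so an untrained model is expected to start roughly in that range. A sketch with hypothetical all-zero (uniform) logits:

```python
import math
import torch
import torch.nn.functional as F

batch, num_tokens, vocab_size = 2, 3, 50257
logits = torch.zeros(batch, num_tokens, vocab_size)          # uniform prediction
targets = torch.randint(0, vocab_size, (batch, num_tokens))  # arbitrary target tokens

flat_logits = logits.flatten(0, 1)  # (batch * num_tokens, vocab_size)
flat_targets = targets.flatten()    # (batch * num_tokens,)
loss = F.cross_entropy(flat_logits, flat_targets)
print(loss.item(), math.log(vocab_size))  # both approximately 10.825
```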

In [None]:
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss

def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

We next compute the initial training and validation losses for the model.

In [None]:
model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes

with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet
    train_loss = calc_loss_loader(train_loader, model, device)
    val_loss = calc_loss_loader(val_loader, model, device)

print("Training loss:", train_loss)
print("Validation loss:", val_loss)

### 3.4. Training and monitoring functions
Before training the model, we define three helper functions to train it and to monitor its progress during training.

**Exercise.** Complete the `evaluate_model` function below by filling in the `TODO` tags.

In [None]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = ...  # TODO: calculate the training loss
        val_loss = ...  # TODO: calculate the validation loss
    model.train()
    return train_loss, val_loss

In [None]:
# %load solutions/training/eval.py

The `generate_and_print_sample` function will allow us to generate and print sample text during training to monitor the evolution of the model capabilities.

In [None]:
# Generate and print a sample text
def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
        decoded_text = token_ids_to_text(token_ids, tokenizer)
        print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()

**Exercise.** Complete the `train_model` function by filling in the `TODO` flags.

In [None]:
def train_model(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # TODO: Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = ... # TODO: initialize to three empty lists
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        # TODO: Set model to training mode
        
        for input_batch, target_batch in train_loader:
            # TODO:  Reset loss gradients from previous batch iteration
            loss = ...  # TODO: calculate the loss for the current batch
            # TODO: Calculate loss gradients using backpropagation
            # TODO: Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = ... # TODO: evaluate the model
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # TODO: Print a sample text after each epoch

    return train_losses, val_losses, track_tokens_seen

In [None]:
# %load solutions/training/train.py

### 3.5. Training the model

In [None]:
torch.manual_seed(44)
model = GPTModel(GPT_CONFIG_124M)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

num_epochs = 10
train_losses, val_losses, tokens_seen = train_model(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=5, eval_iter=5,
    start_context=val_data[:20], tokenizer=tokenizer
)

### 3.6. Plotting the loss histories

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))  # only show integer labels on x-axis

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig("loss-plot.pdf")
    plt.show()

In [None]:
epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)

## 4. Questions
Think about the following questions.
1. Did the model learn?
2. What does the loss plot tell you? What is happening?