# Scratchpad 3 - Pre-train LLM

## 1. Calculate text generation loss: cross-entropy and perplexity

### 1.1 Cross-entropy loss (average negative log probability value)

The model outputs a logits tensor of shape (batch_size, seq_length, vocab_size). For each position, the last dimension holds one unnormalized score per token in the vocabulary, indicating how likely the model thinks that token is to come next. Applying `softmax` over that dimension turns the scores into probabilities between 0 and 1 that sum to 1.

For the target tokens (the input tokens shifted by one position: drop the first token and append the real next token from the training data), we want the model to assign them the highest probability. Which means we want to
- **maximize the target tokens' softmax-ed values in the output tensor toward 1**.

It's easier to work with the logarithms of these probabilities: pushing a probability toward 1 is the same as pushing its log value (a negative number) toward 0.
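A quick sketch of that relationship, using a handful of made-up probability values (they are illustrative, not model outputs): the log of a probability is never positive, it grows as the probability grows, and it reaches 0 exactly when the probability is 1.

```python
import torch

# Hypothetical probabilities a model might assign to a target token.
probs = torch.tensor([0.001, 0.1, 0.5, 0.9, 1.0])

# Log probabilities: large negative for unlikely tokens, 0 for probability 1.
log_probs = torch.log(probs)
print(log_probs)

# Negating them gives non-negative values that shrink as the model improves,
# which is the quantity a loss function can minimize.
neg_log_probs = -log_probs
print(neg_log_probs)
```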

In ML we need to define a loss function and minimize the single loss value it produces. So instead of maximizing multiple log values across all the output token sequences in the batch, we define the loss to be the
- **averaged negative log probability value** of all target tokens in the output tensor, taken over all sequences in the batch concatenated together.

This is our **cross-entropy** loss. PyTorch provides a `cross_entropy` function which combines the following operations:
- apply softmax
- calculate log of softmax
- calculate negative log-likelihood loss
- average the loss over all token positions in the flattened batch

To use it, we need to flatten the logits tensor and the target token IDs batch tensor, merging the batch and sequence dimensions (aka concatenating all sequences in the batch):

```
loss = torch.nn.functional.cross_entropy(logits_flat, target_idx_batch_flat)
```
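To check that `cross_entropy` really is the averaged negative log probability of the target tokens, here is a minimal sketch that computes the loss by hand and compares it with the built-in function. The toy shapes and the names `toy_logits` and `toy_targets` are assumptions made up for this illustration; they are not part of the GPT model used in these notes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (assumed for illustration): 2 sequences, 3 tokens each, vocabulary of 5.
batch_size, seq_len, vocab_size = 2, 3, 5
toy_logits = torch.randn(batch_size, seq_len, vocab_size)
toy_targets = torch.randint(0, vocab_size, (batch_size, seq_len))

# Flatten so that every token position becomes one row.
toy_logits_flat = toy_logits.flatten(0, 1)    # (batch_size * seq_len, vocab_size)
toy_targets_flat = toy_targets.flatten(0, 1)  # (batch_size * seq_len,)

# Manual version: log-softmax, pick each target's log probability, negate, average.
log_probs = torch.log_softmax(toy_logits_flat, dim=-1)
target_log_probs = log_probs.gather(1, toy_targets_flat.unsqueeze(1)).squeeze(1)
manual_loss = -target_log_probs.mean()

# Built-in version: the same steps in a single call.
builtin_loss = F.cross_entropy(toy_logits_flat, toy_targets_flat)

print(manual_loss, builtin_loss)
```

The two printed values should agree up to floating-point rounding.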

Let's try it out.

In [20]:
from gpt.gpt_model import GPTModel
import torch

GPT_CONFIG_124M = {
    "vocab_size": 50257,  # Vocabulary size
    "context_length": 1024,  # Context length
    "embedding_dim": 768,  # Embedding dimension
    "n_heads": 12,  # Number of attention heads
    "n_layers": 12,  # Number of layers
    "dropout_rate": 0.1,  # Dropout rate
    "qkv_bias": False,  # Query-Key-Value bias
}

model = GPTModel(GPT_CONFIG_124M)
model.eval();

In [9]:
full_text = "The quick brown fox jumps over the lazy dog"
full_text_tokens = tokenizer.encode(full_text)
print(f"Full text tokens: {full_text_tokens}")

Full text tokens: [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290]


In [14]:
input_idx = full_text_tokens[:4]
input_text = tokenizer.decode(input_idx)
print(f"Input text: {input_text}")
input_idx_batch = torch.tensor(input_idx).unsqueeze(0)
print(f"Input idx batch tensor: {input_idx_batch}")

Input text: The quick brown fox
Input idx batch tensor: tensor([[  464,  2068,  7586, 21831]])


In [17]:
with torch.no_grad():
    logits = model(input_idx_batch)
print(f"Output logits shape: {logits.shape}")

Output logits shape: torch.Size([1, 4, 50257])


In [19]:
target_idx = full_text_tokens[1:5]
target_text = tokenizer.decode(target_idx)
print(f"Target text: {target_text}")
target_idx_batch = torch.tensor(target_idx).unsqueeze(0)
print(f"Target idx batch tensor: {target_idx_batch}")

Target text:  quick brown fox jumps
Target idx batch tensor: tensor([[ 2068,  7586, 21831, 18045]])


In [24]:
# logits is the model's output of shape (batch_size, seq_len, vocab_size) -> (batch_size * seq_len, vocab_size)
logits_flat = logits.flatten(0, 1)
# targets is the target sequence's token IDs shape (batch_size, seq_len) -> (batch_size * seq_len)
targets_idx_flat = target_idx_batch.flatten(0, 1)
# Compute the cross entropy loss
loss = torch.nn.functional.cross_entropy(logits_flat, targets_idx_flat)
print(loss)

tensor(10.7479)


With our randomly initialized GPT-2 model, the cross-entropy loss is 10.7479. That is close to ln(50257) ≈ 10.82, the loss of a uniform guess over the whole vocabulary.

### 1.2 Perplexity

**Perplexity** is another metric that measures how well the model predicts a sequence of tokens. It is the **exponential of the average negative log-likelihood** of the predicted probabilities for each token in a sequence, aka
- the **exponential of cross-entropy**
  
It tells us how "surprised" the model is by the test data. It can be interpreted as the effective number of equally likely choices the model thinks it has when predicting the next token.

$$\text{Perplexity} = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(x_i)\right)$$

```
cross_entropy_loss = torch.nn.functional.cross_entropy(logits_flat, target_idx_batch_flat)
perplexity = torch.exp(cross_entropy_loss)
```

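- A perplexity of 1 means the model is completely certain and correct (it assigns probability 1 to the correct next token ID at every position).
- A high perplexity means the model is uncertain about the next token; a perplexity equal to the vocabulary size corresponds to guessing uniformly at random.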
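The two extremes can be reproduced with hand-built logits. This is a sketch under stated assumptions: the logits below are constructed by hand rather than produced by the model, and the logit value `50.0` is an arbitrary large number chosen so that softmax puts essentially all probability on the target.

```python
import torch
import torch.nn.functional as F

vocab_size = 50257
targets = torch.tensor([2068, 7586, 21831, 18045])  # target token IDs from the example sentence

# Case 1: identical logits everywhere, i.e. a uniform distribution over the vocabulary.
uniform_logits = torch.zeros(len(targets), vocab_size)
uniform_loss = F.cross_entropy(uniform_logits, targets)
uniform_perplexity = torch.exp(uniform_loss)

# Case 2: one very large logit on each target token, i.e. a confident and correct model.
confident_logits = torch.zeros(len(targets), vocab_size)
confident_logits[torch.arange(len(targets)), targets] = 50.0  # arbitrary large value
confident_loss = F.cross_entropy(confident_logits, targets)
confident_perplexity = torch.exp(confident_loss)

print(f"uniform:   loss={uniform_loss.item():.4f}, perplexity={uniform_perplexity.item():.1f}")
print(f"confident: loss={confident_loss.item():.4f}, perplexity={confident_perplexity.item():.4f}")
```

The uniform case should give a perplexity close to the vocabulary size, and the confident case a perplexity close to 1.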


In [25]:
perplexity = torch.exp(loss)
print(perplexity)

tensor(46531.8984)


Our untrained model has a perplexity of about 46532. With a vocabulary size of 50257, that means it is basically guessing at random.