# 📈 Training

This notebook explores some aspects of training a model, and optimising for performance.

## Setup 

In [1]:
import autorootcwd

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import math
import numpy as np
import matplotlib.pyplot as plt

import time
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from datasets import Dataset

from src.utils import get_device, get_model, get_tokenizer, get_dataloader, get_micro_dataloader, tokenize, get_optimizer, get_scheduler
from src.model import Model
from src.train.baseline import train, sample

In [4]:
device = get_device()
print(f"Using device: {device}")

Using device: cuda


Let's load a small model for debugging and running some tests. We'll also create a dummy batch that is representative in size of the actual workflows during training. 

## Overfitting

In [5]:
# Create a dataset with a single example
sentence = "I am a large language model and I can memorize this sentence."
dataset = Dataset.from_dict({"text": [sentence]})

### Llama2 (9M)

In [6]:
# Load model
model_name = "mikasenghaas/llama2-9m-fresh"
model : Model = get_model(model_name)

# Load tokenizer
tokenizer : AutoTokenizer = get_tokenizer(model_name)
bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

# Create dataloader
seq_length = 32
data = dataset.map(lambda examples: tokenize(examples["text"], tokenizer, max_length=seq_length+1, return_tensors=None))
dataloader = get_dataloader(data, batch_size=1, shuffle=False, cycle=True)

# Train
model.train()
model.to(device)
step, max_steps = 0, 100
optimizer = get_optimizer(model, lr=3e-3, weight_decay=0.0, betas=(0.9, 0.95))
scheduler = get_scheduler(optimizer, num_steps=None, warmup_steps=None, num_cycles=None, min_lr_factor=None, last_epoch=None, enable=False)
loss_fn = torch.nn.CrossEntropyLoss()
while step < max_steps:
    batch = next(dataloader)
    batchloader = get_micro_dataloader(batch, micro_batch_size=1)
    outputs = train(step, model, batchloader, loss_fn, optimizer, scheduler, device, "float32", 1.0)

    if (step+1) % 10 == 0 or step == 0:
        print(f"Step: {step+1}, Loss: {outputs.loss:.4f} ({sample(model, tokenizer, 'I am', 50, device=device)})")
    step += 1

Step: 1, Loss: 10.3590 (I am I I I I I I I IHD model lying model lying model lyingmaz lyingmaz model lyingmaz model lyingmaz model lyingmaz model lyingmazmazmazmazmazmazmazmazmazmazmazmazmazmazsuppmazsuppmazsuppmazsupp)
Step: 10, Loss: 4.7743 (I am a large model and I large and I can memorize this sentence.)
Step: 20, Loss: 0.8941 (I am a large language model and I can memorize this sentence.)
Step: 30, Loss: 0.0897 (I am a large language model and I can memorize this sentence.)
Step: 40, Loss: 0.0101 (I am a large language model and I can memorize this sentence.)
Step: 50, Loss: 0.0027 (I am a large language model and I can memorize this sentence.)
Step: 60, Loss: 0.0013 (I am a large language model and I can memorize this sentence.)
Step: 70, Loss: 0.0009 (I am a large language model and I can memorize this sentence.)
Step: 80, Loss: 0.0007 (I am a large language model and I can memorize this sentence.)
Step: 90, Loss: 0.0006 (I am a large language model and I can memorize this sente

### GPT-2 (124M)

In [7]:
# Load model
model_name = "mikasenghaas/gpt2-124m-fresh"
model : Model = get_model(model_name)

# Load tokenizer
tokenizer : AutoTokenizer = get_tokenizer(model_name)
bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

# Create dataloader
seq_length = 32
data = dataset.map(lambda examples: tokenize(examples["text"], tokenizer, max_length=seq_length+1, return_tensors=None))
dataloader = get_dataloader(data, batch_size=1, shuffle=False, cycle=True)

# Train
model.train()
model.to(device)
step, max_steps = 0, 100
optimizer = get_optimizer(model, lr=3e-4, weight_decay=0.0, betas=(0.9, 0.95))
scheduler = get_scheduler(optimizer, num_steps=None, warmup_steps=None, num_cycles=None, min_lr_factor=None, last_epoch=None, enable=False)
loss_fn = torch.nn.CrossEntropyLoss()
while step < max_steps:
    batch = next(dataloader)
    batchloader = get_micro_dataloader(batch, micro_batch_size=1)
    outputs = train(step, model, batchloader, loss_fn, optimizer, scheduler, device, "float32", 1.0)

    if (step+1) % 10 == 0 or step == 0:
        print(f"Step: {step+1}, Loss: {outputs.loss:.4f} ({sample(model, tokenizer, 'I am', 50, device=device)})")
    step += 1

Step: 1, Loss: 10.6368 (I am  a a a aizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeizeize)
Step: 10, Loss: 0.1532 (I am  a large language model and I can memorize this sentence.)
Step: 20, Loss: 0.0044 (I am  a large language model and I can memorize this sentence.)
Step: 30, Loss: 0.0013 (I am  a large language model and I can memorize this sentence.)
Step: 40, Loss: 0.0006 (I am  a large language model and I can memorize this sentence.)
Step: 50, Loss: 0.0004 (I am  a large language model and I can memorize this sentence.)
Step: 60, Loss: 0.0003 (I am  a large language model and I can memorize this sentence.)
Step: 70, Loss: 0.0002 (I am  a large language model and I can memorize this sentence.)
Step: 80, Loss: 0.0002 (I am  a large language model and I can memorize this sentence.)
Step: 90, Loss: 0.0001 (I am  a large language model and I can memorize this sentence.)
Step: 100, Loss: 0.0001 (I am  

## Mixed Precision

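We can speed up training by running parts of the computation in lower precision. PyTorch offers a one-liner, `torch.set_float32_matmul_precision`, to set the internal precision used in `float32` matrix multiplications (the most common operation in Transformers), as well as a context manager, `torch.amp.autocast`, that runs eligible operations in a lower-precision dtype. Under autocast the model weights themselves stay in `float32`; only the activations and the intermediate gradients of the autocast operations use the lower precision.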
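As a minimal sketch of how the two mechanisms fit together, the snippet below uses a toy `nn.Linear` layer instead of the notebook's models. The layer, the tensor shapes and the squared-output loss are illustrative assumptions, not taken from the notebook. It falls back to the CPU when no GPU is available, where autocast also supports `bfloat16`.

```python
import torch
import torch.nn as nn

# One-liner: allow faster, lower-precision kernels for float32 matmuls.
# This only has an effect on hardware that supports them.
torch.set_float32_matmul_precision("high")

device_type = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 4).to(device_type)   # toy stand-in for a real model
x = torch.randn(8, 16, device=device_type)

# Context manager: eligible ops (here the linear layer) run in bfloat16.
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    out = model(x)
    loss = out.float().pow(2).mean()       # illustrative loss, computed in float32

# The backward pass runs outside the autocast context.
loss.backward()

print(out.dtype)                # activations produced under autocast
print(model.weight.dtype)       # parameters keep their original dtype
print(model.weight.grad.dtype)  # parameter gradients match the parameter dtype
```

The printed dtypes show which tensors are actually affected: the activation comes out in the autocast dtype, while the parameter and its gradient remain in `float32`.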

Let's measure the speed-ups we get from both of these techniques on a GPT-2 (124M) model.

In [8]:
# Micro batch
B, L, V = 4, 1024, 50257
# Create the tokens directly on the selected device instead of hard-coding "cuda"
x = torch.randint(0, V, (B, L+1), dtype=torch.long, device=device)
batch = {
    "input_ids": x[:, :-1],
    "attention_mask": torch.ones_like(x[:, :-1]),
    "labels": x[:, 1:]
}
batch = {k: v.to(device) for k, v in batch.items()}

### Llama2 (9M)

In [9]:
# Not relevant

### GPT-2 (124M)

In [10]:
# Load GPT2 model
model_name = "mikasenghaas/gpt2-124m-fresh"
# Named `gpt2` because the timing cells below call `gpt2(**batch)`
gpt2 : Model = get_model(model_name).to(device)

#### Highest Precision + Float32

In [11]:
torch.set_float32_matmul_precision("highest")

In [12]:
%%timeit -n 5 -r 5
logits = gpt2(**batch)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### Highest Precision + Float16

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    logits = gpt2(**batch)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### Highest Precision + BFloat16

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = gpt2(**batch)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### High Precision + Float32

In [17]:
torch.set_float32_matmul_precision("high")

In [None]:
%%timeit -n 5 -r 5
logits = gpt2(**batch)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### High Precision + Float16

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    logits = gpt2(**batch)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### High Precision + BFloat16

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = gpt2(**batch)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()
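
CUDA kernels are launched asynchronously, so wall-clock timings of GPU code can be skewed unless the device is synchronised before the clock is read. The helper below is a hedged sketch of a timing loop with warm-up and an optional synchronisation hook. The name `time_step` and its parameters are hypothetical and not part of the notebook or of PyTorch, and the workload is a pure-Python stand-in so the sketch runs anywhere.

```python
import time
from typing import Callable, Optional


def time_step(
    step: Callable[[], None],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 2,
    iters: int = 5,
) -> float:
    """Return the mean wall-clock time in seconds per call of `step`."""
    for _ in range(warmup):
        step()  # warm-up calls absorb one-off costs such as allocations
    if sync is not None:
        sync()  # drain any work queued during warm-up
    start = time.perf_counter()
    for _ in range(iters):
        step()
    if sync is not None:
        sync()  # wait for queued device work before stopping the clock
    return (time.perf_counter() - start) / iters


# Stand-in workload; replace with a forward/backward pass in practice.
elapsed = time_step(lambda: sum(i * i for i in range(10_000)))
print(f"{elapsed * 1e3:.3f} ms per step")
```

For the cells above, `step` would wrap the forward and backward pass of `gpt2`, and `sync` would be `torch.cuda.synchronize`.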

Slowest Run (`precision="highest"` and `dtype=torch.float32`): 256ms
Fastest Run (`precision="high"` and `dtype=torch.bfloat16`): 143ms

This is a ~1.8x speedup. The difference between `float16` and `bfloat16` is very small, so we prefer `bfloat16`: it keeps the exponent range of `float32` at the cost of fewer mantissa bits, which makes it far less prone to overflow and underflow and usually removes the need for loss scaling. We also reach 145ms with dtype `bfloat16` and `precision="highest"`, so in this case `precision="high"` adds almost nothing. The setting only affects `float32` matrix multiplications, and under autocast those already run in `bfloat16`.
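
To make the trade-off between the two 16-bit formats concrete, the sketch below round-trips a few values through each format using only the standard library. `float16` is handled by the `struct` module's half-precision format. `bfloat16` is emulated by keeping the top 16 bits of the `float32` encoding, which is an assumption made for illustration: it truncates, whereas real hardware typically rounds to nearest. The helper names are hypothetical.

```python
import struct


def to_float16(x: float) -> float:
    """Round-trip x through IEEE half precision (5 exponent, 10 mantissa bits)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return float("inf")  # value lies outside the float16 range


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 (8 exponent, 7 mantissa bits) by truncating float32."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


for value in (1.001, 70000.0, 1e-8):
    print(f"{value:>10}: float16={to_float16(value)!r}, bfloat16={to_bfloat16(value)!r}")
```

Close to 1, `float16` resolves finer steps than `bfloat16`, but values beyond its range (the largest finite `float16` is 65504) overflow, while `bfloat16` still represents them coarsely.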

## Compile Model

PyTorch offers `torch.compile`, which captures the model's computation as a graph and generates optimised kernels for it. This can lead to significant performance improvements.

In [None]:
# Compile model