# 📈 Training

This notebook explores some aspects of training a model and optimising it for performance.

## Setup 

In [1]:
import autorootcwd

In [2]:
%load_ext autoreload
%autoreload 2

In [14]:
import math
import numpy as np
import matplotlib.pyplot as plt

import time
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from datasets import Dataset

from src.train.baseline import train, sample
from src.utils import get_device, get_model, get_tokenizer, get_dataset, get_dataloader, get_micro_dataloader, tokenize, get_optimizer, get_scheduler
from src.config import ModelConfig, TrainConfig, DataConfig, SampleConfig

In [None]:
device = get_device()
print(f"Using device: {device}")

Let's load a small model for debugging and running some tests. We'll also create a dummy batch whose size is representative of the batches used during actual training.

## Overfitting

In [None]:
# Configuration
model_config = ModelConfig(n_layer=12, n_head=12, n_embd=768) # GPT-2 Small (124M)
train_config = TrainConfig(max_steps=100, max_epochs=-1, micro_batch_size=1, batch_size=1)
data_config = DataConfig(path="mikasenghaas/memorize", seq_length=32)
sample_config = SampleConfig(num_samples=1)

# Load model
model = get_model(model_config)

# Load tokenizer and reuse the EOS token for padding
tokenizer = get_tokenizer()
tokenizer.pad_token = tokenizer.eos_token

# Create dataloader: tokenize to seq_length+1 tokens so that the inputs and
# the shifted next-token labels are each seq_length long
dataset = get_dataset(data_config, split="train")
data = dataset.map(lambda examples: tokenize(examples["text"], tokenizer, max_length=data_config.seq_length+1, return_tensors=None))
dataloader = get_dataloader(data, batch_size=train_config.batch_size, shuffle=False, cycle=True)

# Train
model.train()
model.to(device)
optimizer = get_optimizer(model, train_config.optimizer)
scheduler = get_scheduler(optimizer, train_config.scheduler)
loss_fn = torch.nn.CrossEntropyLoss()
step = 0
while step < train_config.max_steps:
    batch = next(dataloader)
    batchloader = get_micro_dataloader(batch, micro_batch_size=train_config.micro_batch_size)
    outputs = train(step, model, batchloader, loss_fn, optimizer, scheduler, train_config, device)

    if (step+1) % 10 == 0 or step == 0:
        print(f"Step: {step+1}, Loss: {outputs.loss:.4f} ({sample(model, tokenizer, sample_config, device=device)[0]})")
    step += 1

## Mixed Precision

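We can speed up training by computing in lower precision. PyTorch offers a one-liner, `torch.set_float32_matmul_precision`, that sets the internal precision used for `float32` matrix multiplications (which dominate the compute in Transformers), e.g. allowing TF32 tensor cores on recent NVIDIA GPUs. It also offers a context manager, `torch.amp.autocast`, that runs eligible operations such as matmuls in `float16` or `bfloat16`: activations (and the matching backward computations) use the lower precision, while the weights and their gradients stay in `float32`.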
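A minimal CPU sketch of the two knobs, assuming a toy `torch.nn.Linear` layer in place of GPT-2 and `device_type="cpu"` so it runs without a GPU (the benchmarks below use `"cuda"`). The `toy_*` names are illustrative. It shows that autocast lowers the precision of the activation while the weights and their gradients remain `float32`.

```python
import torch

# Knob 1: allow faster, reduced-precision algorithms for float32 matmuls
# ("highest" is the default; "high" may use e.g. TF32 on supported GPUs).
torch.set_float32_matmul_precision("high")

toy_layer = torch.nn.Linear(16, 8)  # parameters are created in float32
toy_x = torch.randn(4, 16)

# Knob 2: inside autocast, eligible ops (like the matmul in Linear) run in bfloat16.
# CPU is used so this runs anywhere; the benchmarks use device_type="cuda".
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    toy_out = toy_layer(toy_x)

toy_loss = toy_out.float().pow(2).mean()  # reduce in float32 outside autocast
toy_loss.backward()

print(toy_out.dtype)                # torch.bfloat16 (activation)
print(toy_layer.weight.dtype)       # torch.float32  (weights are unchanged)
print(toy_layer.weight.grad.dtype)  # torch.float32  (parameter gradients)
```

Because autocast casts weights on the fly rather than converting the model, no separate `float32` master copy of the weights has to be managed by hand.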

Let's measure the speed-ups both techniques give for a forward and backward pass of a GPT-2 (124M) model.

In [21]:
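# Micro batch of B sequences with L tokens each, over the GPT-2 vocabulary (V tokens)
B, L, V = 4, 1024, 50257
x = torch.randint(0, V, (B, L+1), dtype=torch.long)
batch = {
    "input_ids": x[:, :-1],                        # tokens 0..L-1
    "attention_mask": torch.ones_like(x[:, :-1]),
    "labels": x[:, 1:]                             # next-token targets 1..L
}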
batch = {k: v.to(device) for k, v in batch.items()}
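To put this micro batch in perspective, here is a back-of-the-envelope size of just the `(B, L, V)` logits tensor at each width. It is a sketch assuming the logits are materialised once in the stated dtype; it ignores every other activation and any temporary upcasts. `BYTES_PER_ELEMENT` and `logits_bytes` are illustrative names.

```python
# Shape of the benchmark micro batch (same values as above)
B, L, V = 4, 1024, 50257

# Bytes per element for each dtype compared in the benchmarks
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def logits_bytes(dtype: str) -> int:
    """Bytes needed to hold one (B, L, V) logits tensor in `dtype`."""
    return B * L * V * BYTES_PER_ELEMENT[dtype]


for dtype in BYTES_PER_ELEMENT:
    print(f"{dtype}: {logits_bytes(dtype) / 2**30:.2f} GiB")
```

For `float32` this is about 0.77 GiB, halving to about 0.38 GiB in `float16`/`bfloat16`, which also halves the memory traffic for that tensor.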

In [22]:
# Load GPT2 model
model_config = ModelConfig(n_layer=12, n_head=12, n_embd=768)
model = get_model(model_config).to(device)
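`%%timeit` measures host wall-clock time, but CUDA kernels are launched asynchronously, so a timed cell can return before the GPU has finished its work. Below is a sketch of a timing helper that synchronizes before reading the clock; the name `benchmark`, its defaults and the toy layer are illustrative, not part of the original code. `torch.utils.benchmark.Timer` is a built-in alternative that also handles warm-up and CUDA synchronization.

```python
import time
import torch


def benchmark(fn, n_warmup: int = 2, n_iters: int = 5) -> float:
    """Mean wall-clock time of fn() in milliseconds, synchronizing CUDA if present."""
    # CUDA kernels run asynchronously: wait for them before reading the clock
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(n_warmup):  # exclude one-off costs (lazy init, allocator warm-up)
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    sync()  # make sure all queued GPU work has finished
    return (time.perf_counter() - start) / n_iters * 1e3


# Toy usage on a small layer (illustrative, not the notebook's model)
toy_model = torch.nn.Linear(256, 256)
toy_input = torch.randn(64, 256)


def toy_step():
    loss = toy_model(toy_input).pow(2).mean()
    loss.backward()


print(f"{benchmark(toy_step):.3f} ms per step")
```

In the notebook, the body of each `%%timeit` cell could be wrapped in a function like `toy_step` and passed to `benchmark` instead.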

#### Highest Precision + Float32

In [23]:
torch.set_float32_matmul_precision("highest")

In [None]:
%%timeit -n 5 -r 5
logits = model.forward(input_ids=batch["input_ids"])
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### Highest Precision + Float16

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    logits = model.forward(input_ids=batch["input_ids"])
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### Highest Precision + BFloat16

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model.forward(input_ids=batch["input_ids"])
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### High Precision + Float32

In [29]:
torch.set_float32_matmul_precision("high")

In [None]:
%%timeit -n 5 -r 5
logits = model.forward(input_ids=batch["input_ids"])
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### High Precision + Float16

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    logits = model.forward(input_ids=batch["input_ids"])
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

#### High Precision + BFloat16

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model.forward(input_ids=batch["input_ids"])
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
loss.backward()

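Going from the slowest to the fastest configuration below is a ~2.8x speedup (207ms → 74ms). The timing difference between `float16` and `bfloat16` is very small, so we prefer `bfloat16`. Note that `bfloat16` is actually *less* precise than `float16` (7 vs. 10 mantissa bits); its advantage is that it keeps the exponent range of `float32`, so it is far less prone to overflow and underflow and typically needs no gradient scaling. Further, `bfloat16` with `precision="highest"` already gets us to 145ms. The matmul precision setting only affects matmuls that still run in `float32`, and under autocast most matmuls already run in `bfloat16`, so `precision="high"` lowers the precision only where it has not already been lowered.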

- Slowest Run (`precision="highest"` and `dtype=torch.float32`): 207ms
- Fastest Run (`precision="high"` and `dtype=torch.bfloat16`): 74ms
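Why `bfloat16` is preferred despite being less precise can be seen by rounding a few values to each format. This pure-Python sketch uses only the stdlib `struct` module: the `'e'` format packs IEEE half precision, and `bfloat16` is emulated by rounding the `float32` encoding to its top 16 bits. The helper names are illustrative, and this only demonstrates the number formats; it is not how PyTorch implements the casts.

```python
import struct


def to_float16(x: float) -> float:
    """Round x to IEEE float16; raises OverflowError if x is too large (max 65504)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    """Round x (via float32) to bfloat16, i.e. keep the top 16 bits of the
    float32 encoding with round-to-nearest-even. NaN/inf handling omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Precision: bfloat16 has 7 explicit mantissa bits, float16 has 10
print(to_float16(1 + 2**-9), to_bfloat16(1 + 2**-9))  # 1.001953125 1.0

# Range: float16 tops out at 65504, bfloat16 shares float32's exponent range
print(to_bfloat16(70000.0))  # 70144.0
try:
    to_float16(70000.0)
except OverflowError:
    print("float16 overflows")

# Tiny values (e.g. small gradients): float16 underflows to zero, bfloat16 does not
print(to_float16(1e-8), to_bfloat16(1e-8) > 0)  # 0.0 True
```

The last case is why `float16` autocast is usually paired with gradient scaling (PyTorch's `GradScaler`), whereas `bfloat16` generally trains without it.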


## Compile Model

PyTorch offers `torch.compile`, which captures your model as a graph (via TorchDynamo) and compiles it into optimised, fused kernels (via the default TorchInductor backend). This can lead to significant speed-ups.
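A minimal sketch of the API on a toy MLP, assuming PyTorch 2.x; the helper `compile_and_check` and `toy_mlp` are hypothetical, not from the notebook. The default backend (TorchInductor) needs a working compiler toolchain, while `backend="eager"` only exercises graph capture and is handy for debugging.

```python
import torch


def compile_and_check(module: torch.nn.Module, example: torch.Tensor,
                      backend: str = "inductor") -> torch.nn.Module:
    """Compile `module` and check that it matches eager mode on `example`.

    torch.compile is lazy: graph capture and code generation happen on the
    first call, so that call is slow and should be excluded from any timing.
    """
    compiled = torch.compile(module, backend=backend)
    with torch.no_grad():
        expected = module(example)
        actual = compiled(example)  # first call triggers compilation
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)
    return compiled


# Illustrative toy module (not the notebook's GPT-2 model)
toy_mlp = torch.nn.Sequential(
    torch.nn.Linear(32, 64), torch.nn.GELU(), torch.nn.Linear(64, 32)
)
```

Usage: `compiled_mlp = compile_and_check(toy_mlp, torch.randn(8, 32))`. When benchmarking a compiled model, run a few warm-up steps first so compilation time is not counted.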

In [None]:
# Compile model