# 📈 Training

This notebook explores some aspects of training a model and optimising it for performance.

## Setup 

In [1]:
import autorootcwd

In [2]:
import math
import numpy as np
import matplotlib.pyplot as plt

import time
import torch
from transformers import AutoModelForCausalLM

Let's load a reasonably sized model for our small benchmark experiments and create a dummy batch whose size is representative of a micro batch in an actual training run.

In [None]:
gpt2 : AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained("mikasenghaas/gpt2-124m-fresh").to("cuda")

print(f"Loaded model with {gpt2.num_parameters()/1e6:.2f}M parameters")

In [None]:
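# Micro batch of random token ids: B sequences of length L from a vocabulary of size V
B, L, V = 8, 1024, 50257
x = torch.randint(0, V, (B, L), dtype=torch.long, device="cuda")
# Hugging Face causal LM heads shift `labels` by one position internally,
# so labels are simply the input ids (pre-shifting them would target two tokens ahead)
batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}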

## Gradient Accumulation

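Modern LLM training recipes often use large batches, e.g. 512 sequences of length 1024 per optimizer step. That is ~0.5M tokens (512 × 1024 = 524,288) per step, far more than fits into memory in a single forward and backward pass, even on powerful GPUs. We therefore use gradient accumulation to simulate the larger batch: we choose a `micro_batch_size` that fits into memory, run `batch_size // micro_batch_size` forward and backward passes while the gradients accumulate, and only then take an optimizer step. Dividing each micro-batch loss by the number of accumulation steps makes the accumulated gradient equal to the gradient of the mean loss over the full batch (for equal-sized micro batches and a mean-reduced loss).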

In [None]:
# Gradient accumulation
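A minimal sketch of gradient accumulation, assuming a hypothetical toy `nn.Linear` classifier and random data on the CPU in place of GPT-2 (the model, sizes and data are illustrative only). Each micro-batch loss is divided by the number of accumulation steps, so the gradients summed in `.grad` should match the gradient of the mean loss over the full batch; the sketch checks this against a single full-batch pass before taking the one optimizer step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup standing in for GPT-2 and a real batch
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()  # mean over the samples it receives

batch_size, micro_batch_size = 32, 8
accum_steps = batch_size // micro_batch_size  # 4 micro batches per optimizer step
inputs = torch.randn(batch_size, 16)
targets = torch.randint(0, 4, (batch_size,))

optimizer.zero_grad()
for i in range(accum_steps):
    sl = slice(i * micro_batch_size, (i + 1) * micro_batch_size)
    loss = loss_fn(model(inputs[sl]), targets[sl])
    # backward() adds into .grad; dividing by accum_steps turns the sum into a mean
    (loss / accum_steps).backward()
accum_grad = model.weight.grad.clone()

# Reference: gradient of the mean loss over the full batch in one pass
# (torch.autograd.grad returns the gradient without touching .grad)
(full_grad,) = torch.autograd.grad(loss_fn(model(inputs), targets), model.weight)
print(torch.allclose(accum_grad, full_grad, atol=1e-6))  # → True

optimizer.step()  # a single optimizer step for the whole simulated batch
optimizer.zero_grad()
```

In the notebook, the same loop would iterate over micro batches shaped like `batch`; `batch_size // micro_batch_size` assumes `batch_size` is divisible by `micro_batch_size`.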

## Compile Model

PyTorch 2.x offers `torch.compile`, which captures your model's computation as a graph and compiles it into optimized (e.g. fused) kernels. This can lead to significant speed-ups.

In [None]:
# Compile model

## Mixed Precision

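We can speed up training by running more of the computation in lower precision. PyTorch offers a one-liner, `torch.set_float32_matmul_precision`, that sets the internal precision of `float32` matrix multiplications (the most common operation in Transformers), e.g. allowing TensorFloat-32 on GPUs that support it. It also offers a context manager, `torch.amp.autocast`, that runs eligible operations such as matmuls in `float16` or `bfloat16`, so activations and their gradients are computed in lower precision while the model weights stay in `float32`.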

Let's measure the speed-ups we get from both techniques for a forward and backward pass of GPT-2 (124M).
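To make concrete what autocast changes, here is a small sketch using a hypothetical `nn.Linear` layer on the CPU (where autocast supports `bfloat16`): inside the context manager the layer's output is produced in `bfloat16`, while the weights and their gradients remain `float32`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(32, 32)  # parameters are created in float32
x = torch.randn(4, 32)

with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)  # linear is on autocast's lower-precision op list

loss = y.float().pow(2).mean()  # reduce in float32 outside the autocast region
loss.backward()

print(layer.weight.dtype)       # torch.float32: weights are not converted in place
print(y.dtype)                  # torch.bfloat16: activation computed under autocast
print(layer.weight.grad.dtype)  # torch.float32: gradient matches the parameter dtype
```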

### Highest Precision

In [6]:
torch.set_float32_matmul_precision("highest")

In [None]:
%%timeit -n 5 -r 5
outputs = gpt2(**batch)
outputs.loss.backward()

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    outputs = gpt2(**batch)
outputs.loss.backward()

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = gpt2(**batch)
outputs.loss.backward()
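CUDA kernels are launched asynchronously, so a Python timer may stop before the GPU has finished the queued work unless something forces a synchronization. A minimal sketch of a timing helper that synchronizes explicitly (the name `time_per_step` is hypothetical); it calls `torch.cuda.synchronize()` only when CUDA is available, so it also runs on the CPU.

```python
import time
import torch

def time_per_step(step_fn, n_warmup=2, n_iters=5):
    """Average wall-clock seconds per call of `step_fn`, synchronizing CUDA around the timed region."""
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(n_warmup):  # warm-up: CUDA context, cuBLAS handles, allocator caches
        step_fn()
    sync()  # make sure warm-up work is done before starting the clock
    start = time.perf_counter()
    for _ in range(n_iters):
        step_fn()
    sync()  # wait for all queued kernels before stopping the clock
    return (time.perf_counter() - start) / n_iters

# Usage sketch on a toy matmul; in the notebook, step_fn would run
# gpt2(**batch).loss.backward() inside the desired autocast context
a = torch.randn(512, 512)
seconds = time_per_step(lambda: a @ a)
print(f"{seconds * 1e3:.3f} ms per step")
```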

### High Precision

In [10]:
torch.set_float32_matmul_precision("high")

In [None]:
%%timeit -n 5 -r 5
outputs = gpt2(**batch)
outputs.loss.backward()

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    outputs = gpt2(**batch)
outputs.loss.backward()

In [None]:
%%timeit -n 5 -r 5
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = gpt2(**batch)
outputs.loss.backward()

- Slowest run (`precision="highest"`, no autocast, i.e. full `float32`): 256ms
- Fastest run (`precision="high"`, `bfloat16` autocast): 143ms

This is a ~1.8x speedup (256 / 143 ≈ 1.79). The difference between `float16` and `bfloat16` is very small, so we prefer `bfloat16`: it keeps the same exponent range as `float32` (at the cost of fewer mantissa bits than `float16`), so it avoids the overflow and underflow problems of `float16` and does not need gradient scaling. Further, `bfloat16` with `precision="highest"` already reaches 145ms, so in this case `"high"` is largely redundant: under autocast the matmuls already run in `bfloat16`, and the `float32` matmul precision setting only affects matmuls that still run in `float32`.
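Converting the step times into throughput makes the comparison easier to carry over to full training runs. The arithmetic below uses only numbers from this notebook: a micro batch of `8 * 1024` tokens, 256ms vs 143ms per forward and backward pass, and the ~0.5M-token batch from the gradient accumulation section.

```python
B, L = 8, 1024
tokens_per_step = B * L  # 8192 tokens per micro batch

slowest_s, fastest_s = 0.256, 0.143  # measured step times from the runs above

speedup = slowest_s / fastest_s
print(f"speedup: {speedup:.2f}x")                              # → speedup: 1.79x
print(f"slowest: {tokens_per_step / slowest_s:,.0f} tokens/s")  # → slowest: 32,000 tokens/s
print(f"fastest: {tokens_per_step / fastest_s:,.0f} tokens/s")  # → fastest: 57,287 tokens/s

# A 512 x 1024 token batch needs 64 accumulated micro batches of this size
micro_batches = (512 * 1024) // tokens_per_step
print(f"{micro_batches * slowest_s:.1f}s vs {micro_batches * fastest_s:.1f}s per optimizer step")  # → 16.4s vs 9.2s per optimizer step
```

These estimates ignore the cost of the optimizer step itself.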