# Lab 4: Memory-Efficient Transformer Training Techniques

The goal of this assignment was to compare several memory-optimization techniques used during Transformer training: BF16 mixed precision, FlashAttention, windowed attention, and gradient checkpointing. The compared metrics are GPU memory usage, the maximum batch size that fits into memory, training speed (time per step and total time for one epoch), and final model performance (perplexity after one epoch).


### Datasets and Tokenizer

The datasets were downloaded from Speakleash. Only high-quality documents were used, resulting in files of ~25 MB each:
- `wolne_lektury_corpus` - for training  
- `1000_novels_corpus_CLARIN-PL` - for validation 

A SentencePiece tokenizer with a vocabulary size of 12000 was trained on the training corpus.

The sequence length was set to 256 and the whole datasets were used, resulting in:
- Number of training sequences: 28049
- Number of validation sequences: 24965

The datasets and splits were kept identical across all techniques.
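The fixed-length sequences described above could be produced along the following lines. This is a minimal sketch, not the lab's actual preprocessing: the function name `chunk_tokens` is hypothetical, and padding the final chunk (rather than dropping it, or chunking per document) is an assumption. The pad ID of 3 is taken from the model constructor shown later in this report.

```python
from typing import List

SEQ_LEN = 256  # sequence length stated in the report
PAD_ID = 3     # pad ID passed to the model constructor (assumed to match the tokenizer)


def chunk_tokens(
    token_ids: List[int], seq_len: int = SEQ_LEN, pad_id: int = PAD_ID
) -> List[List[int]]:
    """Split a flat token stream into non-overlapping sequences of `seq_len`.

    The last sequence is right-padded with `pad_id` (an assumption; the
    report does not say how the tail of the corpus is handled).
    """
    chunks = []
    for start in range(0, len(token_ids), seq_len):
        chunk = token_ids[start:start + seq_len]
        chunk = chunk + [pad_id] * (seq_len - len(chunk))
        chunks.append(chunk)
    return chunks


# Toy example with a short stream and a small sequence length.
toy_chunks = chunk_tokens(list(range(10)), seq_len=4, pad_id=3)
print(toy_chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 3, 3]]
```

Under this scheme, 28049 training sequences of length 256 correspond to at most 28049 × 256 = 7,180,544 training tokens, including any padding.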

### Model architecture
The decoder-only language model written from scratch in Lab 1 was used.

The cell below initialises the model and displays its architecture.

```python
import sys
from pathlib import Path

# Make the repository root importable so that `config` and `model` resolve.
repo_root = Path("..").resolve()
sys.path.insert(0, str(repo_root))

import config
from model import TransformerDecoderOnly

# All architecture hyperparameters come from config.py.
model = TransformerDecoderOnly(
    vocab_size=config.VOCAB_SIZE,
    d_model=config.TX_D_MODEL,
    n_layer=config.TX_N_LAYER,
    n_head=config.TX_N_HEAD,
    d_ff=config.TX_D_FF,
    dropout=config.TX_DROPOUT,
    pad_id=3,
)

model
```

```
TransformerDecoderOnly(
  (embed): Embedding(12000, 128, padding_idx=3)
  (posenc): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (blocks): ModuleList(
    (0-3): 4 x DecoderBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=128, bias=True)
      )
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.1, inplace=False)
    )
  )
  (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=128, out_features=12000, bias=False)
)
```

```python
print(f"Total Parameters: {sum(p.numel() for p in model.parameters()):,}")
```

```
Total Parameters: 2,855,680
```
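The reported total can be cross-checked by hand from the layer shapes in the printed module tree. The sketch below makes two assumptions that are not stated in the report: that `PositionalEncoding` has no trainable parameters, and that `lm_head` shares its weight matrix with `embed` (weight tying), so the matrix is counted once by `model.parameters()`. Under these assumptions the arithmetic reproduces the printed number; without weight tying the count would be larger.

```python
# Shapes read off the printed module tree; weight tying is an assumption.
vocab, d_model, d_ff, n_layer = 12000, 128, 1024, 4

embed = vocab * d_model                      # token embedding matrix

# nn.MultiheadAttention: fused Q/K/V input projection (weight + bias) and out_proj.
attn = 3 * d_model * d_model + 3 * d_model
attn += d_model * d_model + d_model

# Feed-forward: Linear(128 -> 1024) and Linear(1024 -> 128), both with bias.
ff = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

layer_norm = 2 * d_model                     # scale and shift vectors
block = attn + ff + 2 * layer_norm           # one DecoderBlock (ln1 and ln2)

lm_head = d_model * vocab                    # bias=False

total_tied = embed + n_layer * block + layer_norm  # final ln_f, lm_head shared with embed
total_untied = total_tied + lm_head

print(f"Tied:   {total_tied:,}")    # Tied:   2,855,680
print(f"Untied: {total_untied:,}")  # Untied: 4,391,680
```

Roughly half of the parameters sit in the 12000 × 128 embedding matrix, and each of the four decoder blocks contributes 329,856 parameters.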


Training hyperparameters (the setup can be found in `config.py`):
- Learning rate: 1e-3
- Optimizer: Adam
- Number of epochs: 1 (sufficient for profiling the different training regimes)

### Hardware used

- Experiments were conducted on the Athena HPC cluster at AGH University of Science and Technology. 
- Jobs were submitted through the SLURM scheduler; the script `training.sh` runs the experiment setup from `run_baseline.py`.
- All experiments were run on a single NVIDIA A100-SXM4 GPU with 40 GB HBM2 memory.
- The training stack used PyTorch 2.5.1 with CUDA 12.1 and FlashAttention 2.8.3.
- The GPU natively supports BF16 mixed-precision arithmetic.

```
=== Hardware summary ===

GPU: NVIDIA A100-SXM4-40GB
GPU memory (GB): 39.7
PyTorch: 2.5.1+cu121
CUDA (torch): 12.1
BF16 supported: True

Name: flash_attn
Version: 2.8.3
```


# Experiment setup
For each technique, two configurations were run: (1) using the same batch size as the FP32 baseline, and (2) using the maximum batch size that fits into GPU memory. Peak memory usage and average step time were measured over 20 training steps, while training time and perplexity were measured after one full epoch. All other variables were kept identical.
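The report does not state how the maximum batch size was determined. One common approach is to double the batch size until a training step runs out of memory and then binary-search between the last working and the first failing value. The following is a minimal sketch of that idea, not the procedure used in `run_baseline.py`: `find_max_batch_size` and the `fits` callback are hypothetical names, and the search assumes that if a batch size fails, every larger one fails too.

```python
from typing import Callable, Tuple, Type


def find_max_batch_size(
    fits: Callable[[int], None],
    start: int = 1,
    limit: int = 1 << 20,
    oom_errors: Tuple[Type[BaseException], ...] = (MemoryError,),
) -> int:
    """Return the largest batch size in [start, limit] for which `fits` succeeds.

    `fits(batch_size)` should run one full training step and raise one of
    `oom_errors` when memory is exhausted. Returns 0 if even `start` fails.
    """

    def ok(batch_size: int) -> bool:
        try:
            fits(batch_size)
            return True
        except oom_errors:
            return False

    if not ok(start):
        return 0

    # Doubling phase: `lo` is always a batch size known to fit.
    lo, hi = start, start * 2
    while hi <= limit and ok(hi):
        lo, hi = hi, hi * 2
    hi = min(hi, limit + 1)  # exclusive upper bound: known to fail or above the limit

    # Binary search between the last success and the upper bound.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if ok(mid):
            lo = mid
        else:
            hi = mid
    return lo


# Toy stand-in for a training step: pretend that 100 samples fit into memory.
def fake_step(batch_size: int) -> None:
    if batch_size > 100:
        raise MemoryError("simulated out-of-memory")


print(find_max_batch_size(fake_step))  # 100
```

With PyTorch, `fits` would run a forward and backward pass plus an optimizer step, and `oom_errors` would be `(torch.cuda.OutOfMemoryError,)`. Calling `torch.cuda.empty_cache()` between attempts helps keep earlier failures from affecting later ones.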

# Results