# Optimization!


In [None]:
!pip install torch tiktoken transformers

In [None]:
from transformers import GPT2LMHeadModel
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import tiktoken
import math
import inspect
import os
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
device  # echo the selected device as the cell output

'cuda'

In [None]:
class DataLoaderLite:
  """Minimal sequential data loader over a single tokenized text file.

  Reads ``input.txt`` once, tokenizes it with tiktoken's ``gpt2`` encoding and
  keeps every token id in one CPU tensor. ``next_batch`` walks through that
  tensor in contiguous chunks of ``B * T`` tokens.

  Args:
    B: batch size (number of sequences per batch).
    T: sequence length (tokens per sequence).

  Raises:
    ValueError: if the file holds fewer than ``B * T + 1`` tokens, i.e. not
      even one (inputs, targets) batch can be formed.
  """

  def __init__(self, B, T):
    self.B = B
    self.T = T

    # at init, load tokens from disk and store them in memory. An explicit
    # encoding keeps decoding independent of the platform's locale default.
    with open('input.txt', 'r', encoding='utf-8') as f:
      text = f.read()
    enc = tiktoken.get_encoding('gpt2')
    tokens = enc.encode(text)
    self.tokens = torch.tensor(tokens)
    print(f"Loaded {len(self.tokens)} tokens")
    print(f"1 epoch = {len(self.tokens) // (B*T)} batches")
    # fail early with a clear message instead of an opaque .view() error later
    if len(self.tokens) < B*T + 1:
      raise ValueError(
          f"need at least {B*T + 1} tokens for one batch of B={B}, T={T}, "
          f"got {len(self.tokens)}")
    self.current_size = 0  # read position (index into self.tokens) of the next batch

  def next_batch(self):
    """Return the next ``(x, y)`` pair, each of shape ``(B, T)``.

    ``y`` is ``x`` shifted left by one token (the next-token targets). Both
    stay on the CPU; the caller moves them to the device. When the following
    batch would run past the end of the data, the position wraps back to 0,
    so a trailing partial chunk is never used.
    """
    B, T = self.B, self.T
    # B*T + 1 tokens: the extra token is the target of the last input position
    buf = self.tokens[self.current_size:self.current_size + B*T + 1]
    # buf is deliberately not moved to the GPU here, to save GPU memory
    x = buf[:-1].view(B, T)  # inputs
    y = buf[1:].view(B, T)   # targets
    self.current_size += B*T  # advance position in tensor
    # if loading the next batch would be out of bounds, reset
    if self.current_size + B*T + 1 > len(self.tokens):
      self.current_size = 0
    return x, y

In [None]:

class CausalSelfAttention(nn.Module):
  """Multi-head causal (masked) self-attention.

  Q, K and V for all heads come from one fused linear layer (``c_attn``); the
  per-head outputs are concatenated and mixed by ``c_proj``. The attention
  itself is delegated to ``F.scaled_dot_product_attention`` with
  ``is_causal=True`` (the flash-attention path).
  """

  def __init__(self, config):
    super().__init__()
    # every head gets an equal slice of the embedding dimension
    assert config.n_embed % config.n_head == 0

    # fused projection producing Q, K and V in one matmul: n_embed -> 3 * n_embed
    self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
    # projection applied after the heads are concatenated back together
    self.c_proj = nn.Linear(config.n_embed, config.n_embed)
    # marker read by the model's weight init: this layer adds into the residual
    # stream, so its init std is scaled down (0.02 / sqrt(num layers))
    self.c_proj.NANOGPT_SCALE_INIT = 1

    self.n_head = config.n_head
    self.n_embed = config.n_embed

    # lower-triangular causal mask, named "bias" after the OpenAI checkpoint
    # naming. forward() uses is_causal=True instead, so the mask is not read
    # there; it is still registered so the module's buffers stay the same.
    n = config.block_size
    causal_mask = torch.tril(torch.ones(n, n)).view(1, 1, n, n)
    self.register_buffer("bias", causal_mask)

  def _split_heads(self, t, B, T, head_size):
    """Reshape a (B, T, C) projection into (B, n_head, T, head_size)."""
    return t.view(B, T, self.n_head, head_size).transpose(1, 2)

  def forward(self, x):
    B, T, C = x.size()  # batch size, sequence length, n_embed
    head_size = C // self.n_head  # C = n_head * head_size

    q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
    q = self._split_heads(q, B, T, head_size)
    k = self._split_heads(k, B, T, head_size)
    v = self._split_heads(v, B, T, head_size)

    # equivalent to softmax(q @ k^T / sqrt(head_size)) with future positions
    # masked to -inf, followed by @ v -- but computed by one fused kernel
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, nh, T, hs)

    # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C): heads side by side again
    out = out.transpose(1, 2).contiguous().view(B, T, C)
    return self.c_proj(out)


In [None]:
class MLP(nn.Module):
  """Position-wise feed-forward network: Linear -> GELU (tanh approx.) -> Linear.

  The hidden layer is 4x the embedding width; a wider hidden layer gives the
  block more capacity (the factor 4 is a design choice).
  """

  def __init__(self, config):
    super().__init__()
    hidden = 4 * config.n_embed
    self.c_fc = nn.Linear(config.n_embed, hidden)    # expand n_embed -> 4 * n_embed
    self.gelu = nn.GELU(approximate='tanh')          # activation
    self.c_proj = nn.Linear(hidden, config.n_embed)  # project back to n_embed

  def forward(self, x):
    return self.c_proj(self.gelu(self.c_fc(x)))

In [None]:
class Block(nn.Module):
  """One pre-LayerNorm transformer layer: causal attention, then an MLP.

  Each sub-layer reads a LayerNorm-ed copy of the residual stream and adds its
  output back onto it.
  """

  def __init__(self, config):
    super().__init__()
    self.ln_1 = nn.LayerNorm(config.n_embed)  # normalizes the attention input
    self.attn = CausalSelfAttention(config)
    self.ln_2 = nn.LayerNorm(config.n_embed)  # normalizes the MLP input
    self.mlp = MLP(config)

  def forward(self, x):
    attn_out = self.attn(self.ln_1(x))
    x = x + attn_out  # first residual connection
    mlp_out = self.mlp(self.ln_2(x))
    return x + mlp_out  # second residual connection

@dataclass
class GPTConfig:
  """Model hyperparameters; the defaults describe the 124M-parameter GPT-2."""
  block_size: int = 1024  # maximum sequence length (context window)
  vocab_size: int = 50257  # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
  n_layer: int = 12  # number of transformer blocks
  n_head: int = 12  # attention heads per block
  n_embed: int = 768  # embedding / hidden dimension

class GPT(nn.Module):
  """GPT-2 style decoder-only transformer language model.

  Token + position embeddings feed a stack of ``config.n_layer`` transformer
  ``Block``s, followed by a final LayerNorm and a linear LM head. The LM head
  weight is the same tensor as the token-embedding weight (weight tying).
  """

  def __init__(self, config):
    super().__init__()
    self.config = config

    self.transformer = nn.ModuleDict(dict(
        wte=nn.Embedding(config.vocab_size, config.n_embed),  # token embeddings
        wpe=nn.Embedding(config.block_size, config.n_embed),  # positional embeddings
        h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),  # one Block per layer
        ln_f=nn.LayerNorm(config.n_embed),  # final layer normalisation
    ))
    # final classifier: hidden state -> logits over the vocabulary
    self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

    # weight-sharing scheme: input embedding and output projection share one matrix
    self.transformer.wte.weight = self.lm_head.weight

    # initialize parameters (see _init_weights)
    self.apply(self._init_weights)

  def _init_weights(self, module):
    """Initialize ``module`` in place following the GPT-2 scheme.

    - Linear weights ~ N(0, 0.02). Layers flagged with ``NANOGPT_SCALE_INIT``
      (the ``c_proj`` output projections that add into the residual stream)
      use std 0.02 / sqrt(2 * n_layer) instead.
    - Linear biases (when present) are set to zero.
    - Embedding weights ~ N(0, 0.02).
    """
    if isinstance(module, nn.Linear):
      std = 0.02
      if hasattr(module, 'NANOGPT_SCALE_INIT'):
        # every layer adds to the residual stream twice (attention, then MLP),
        # so scale each contribution by 1/sqrt(2 * n_layer) to keep the
        # residual-stream activations from growing with depth
        std *= (2 * self.config.n_layer) ** -0.5
      torch.nn.init.normal_(module.weight, mean=0.0, std=std)
      if module.bias is not None:  # lm_head is created with bias=False
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
    """Run the model on a batch of token ids.

    Args:
      idx: (B, T) tensor of token ids, with T <= config.block_size.
      targets: optional (B, T) tensor of next-token ids.

    Returns:
      (logits, loss): logits of shape (B, T, vocab_size); loss is the
      cross-entropy between logits and targets, or None if targets is None.
    """
    B, T = idx.size()
    assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
    # forward the token and position embeddings
    pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # (T,)
    pos_emb = self.transformer.wpe(pos)  # (T, n_embed), identical for every row of the batch
    tok_emb = self.transformer.wte(idx)  # (B, T, n_embed)
    x = tok_emb + pos_emb  # pos_emb broadcasts over the batch dimension
    # forward the blocks of the transformer
    for block in self.transformer.h:
      x = block(x)
    # forward the final layernorm and the classifier
    x = self.transformer.ln_f(x)
    logits = self.lm_head(x)  # (B, T, vocab_size)
    loss = None
    if targets is not None:
      # flatten to (B*T, vocab_size) logits vs (B*T,) targets
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return logits, loss

  @classmethod
  def from_pretrained(cls, model_type):
    """Build a model and load pretrained GPT-2 weights from Hugging Face.

    Args:
      model_type: one of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.

    Returns:
      A model whose parameters are copied from
      ``GPT2LMHeadModel.from_pretrained(model_type)``.
    """
    assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
    print("loading weights from pretrained gpt: %s" % model_type)

    # n_layer, n_head and n_embed are determined from model_type
    config_args = {
        'gpt2':         dict(n_layer=12, n_head=12, n_embed=768),  # 124M params
        'gpt2-medium':  dict(n_layer=24, n_head=16, n_embed=1024), # 350M params
        'gpt2-large':   dict(n_layer=36, n_head=20, n_embed=1280), # 774M params
        'gpt2-xl':      dict(n_layer=48, n_head=25, n_embed=1600), # 1558M params
    }[model_type]
    config_args['vocab_size'] = 50257  # always 50257 for GPT model checkpoints
    config_args['block_size'] = 1024  # always 1024 for GPT model checkpoints
    # create a from-scratch initialized model
    config = GPTConfig(**config_args)
    model = cls(config)
    sd = model.state_dict()
    sd_keys = [k for k in sd.keys() if not k.endswith('.attn.bias')]  # discard the mask buffer, not a param

    # init a huggingface/transformers model
    model_hf = GPT2LMHeadModel.from_pretrained(model_type)
    sd_hf = model_hf.state_dict()

    # copy while ensuring all of the parameters are aligned and match in names and shapes
    sd_keys_hf = [k for k in sd_hf.keys() if not k.endswith('.attn.masked_bias')]  # ignore, just a buffer
    sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]  # same, just the mask (buffer)
    # the OpenAI checkpoints use a "Conv1D" module, but we use a vanilla Linear,
    # so these weights have to be transposed when we import them
    transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
    assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
    for k in sd_keys_hf:
      if any(k.endswith(w) for w in transposed):
        # special treatment for the Conv1D weights we need to transpose
        assert sd_hf[k].shape[::-1] == sd[k].shape
        with torch.no_grad():
          sd[k].copy_(sd_hf[k].t())
      else:
        # vanilla copy over the other parameters
        assert sd_hf[k].shape == sd[k].shape
        with torch.no_grad():
          sd[k].copy_(sd_hf[k])

    return model

  def configure_optimizers(self, weight_decay, learning_rate, device_type):
    """Create an AdamW optimizer that decays only parameters with dim >= 2.

    Weight tensors of matmuls and embeddings (dim >= 2) get ``weight_decay``;
    biases and LayerNorm gains/offsets (dim < 2) get none. The fused AdamW
    implementation (all parameter updates in a single kernel) is used when this
    PyTorch version supports it and ``device_type == "cuda"``.

    Returns:
      torch.optim.AdamW with betas=(0.9, 0.95) and eps=1e-8.
    """
    # start with all of the candidate parameters (name -> tensor) that require grad
    param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
    decay_params = [p for p in param_dict.values() if p.dim() >= 2]
    nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    # use the fused AdamW if this PyTorch version has it and we are on CUDA
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and device_type == "cuda"
    print(f"using fused AdamW: {use_fused}")
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
    return optimizer


By default, all numerical values—such as weights, biases, and logits—are stored in `float32` (32-bit floating point). However, deep learning workloads can often tolerate much lower numerical precision.

In [None]:
# Baseline: plain FP32 training for 50 steps, timing each step.
train_loader = DataLoaderLite(B=8, T=1024)
model = GPT(GPTConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  # CUDA work is queued asynchronously: wait for it before reading the clock.
  # torch.cuda.synchronize() raises without a GPU, so only call it on CUDA.
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {i}, loss : {loss.item()}, dt : {dt:.2f}ms, tok/sec : {tokens_per_sec}")

  # first iteration is usually slower since the model performs various one-time initializations, memory allocations etc.

Loaded 338025 tokens
1 epoch = 41 batches
step : 0, loss : 10.942951202392578, dt : 858.21ms, tok/sec : 9545.42972068293
step : 1, loss : 9.644050598144531, dt : 455.96ms, tok/sec : 17966.477501992224
step : 2, loss : 9.0220947265625, dt : 455.22ms, tok/sec : 17995.591344033703
step : 3, loss : 8.714418411254883, dt : 455.65ms, tok/sec : 17978.7269478367
step : 4, loss : 8.611956596374512, dt : 455.68ms, tok/sec : 17977.598134836153
step : 5, loss : 8.491798400878906, dt : 454.40ms, tok/sec : 18028.109777123434
step : 6, loss : 8.387723922729492, dt : 453.86ms, tok/sec : 18049.560166271545
step : 7, loss : 8.011007308959961, dt : 454.82ms, tok/sec : 18011.458209534194
step : 8, loss : 7.754319190979004, dt : 455.32ms, tok/sec : 17991.624284269827
step : 9, loss : 7.754043102264404, dt : 454.47ms, tok/sec : 18025.253576749554
step : 10, loss : 7.741122722625732, dt : 453.89ms, tok/sec : 18048.460363209644
step : 11, loss : 7.565569877624512, dt : 454.51ms, tok/sec : 18023.891998506042
s

## TF32 (TensorFloat-32)
FP32 - sign : 1, exponent : 8, mantissa : 23

TF32 - sign : 1, exponent : 8, mantissa : 10

The 13 low-order mantissa bits (23 − 10 = 13) are dropped, which makes matrix multiplication much faster.
The speedup is achieved at the cost of precision.
Inputs and outputs are still FP32; only internally are the numbers truncated so the multiplications run faster.

In [None]:
# Same loop as the baseline, with TF32 enabled for float32 matmuls.
train_loader = DataLoaderLite(B=8, T=1024)
torch.set_float32_matmul_precision('high') # utilises tf-32 precision for all matmul in linear layers
model = GPT(GPTConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  # wait for queued GPU work before timing; the call raises on a CPU-only setup
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {i}, loss : {loss.item()}, dt : {dt:.2f}ms, tok/sec : {tokens_per_sec}")



Loaded 338025 tokens
1 epoch = 41 batches


  _C._set_float32_matmul_precision(precision)


step : 0, loss : 11.05107307434082, dt : 385.26ms, tok/sec : 21263.516990211654
step : 1, loss : 9.797891616821289, dt : 363.87ms, tok/sec : 22513.835937361786
step : 2, loss : 9.275678634643555, dt : 363.43ms, tok/sec : 22540.805248190194
step : 3, loss : 8.857524871826172, dt : 363.49ms, tok/sec : 22537.049898694007
step : 4, loss : 8.780006408691406, dt : 363.57ms, tok/sec : 22531.80337638628
step : 5, loss : 8.659692764282227, dt : 363.63ms, tok/sec : 22528.582778747437
step : 6, loss : 8.526681900024414, dt : 363.31ms, tok/sec : 22548.12734145317
step : 7, loss : 8.185690879821777, dt : 363.52ms, tok/sec : 22535.21703297604
step : 8, loss : 7.856198310852051, dt : 363.48ms, tok/sec : 22537.81860853492
step : 9, loss : 7.79470682144165, dt : 363.48ms, tok/sec : 22537.774258464848
step : 10, loss : 7.773467063903809, dt : 363.66ms, tok/sec : 22526.780827701594
step : 11, loss : 7.627633571624756, dt : 363.79ms, tok/sec : 22518.778922367295
step : 12, loss : 7.5837082862854, dt : 364

## BF16

**FP16**, the most common lower-precision alternative to FP32, reduces numerical precision much more aggressively. <br>
sign : 1, exponent : 5, mantissa : 10 <br>
The exponent determines the numeric range the format can represent, while the mantissa controls how precisely numbers inside that range can be expressed. Because FP16 has a much shorter exponent than FP32 (5 vs 8 bits), its numeric range is far smaller, and its precision is reduced as well (10 vs 23 mantissa bits).


FP32 : exponent = 8 bits -> range = -126 to +127  <br>
max value = (1.11111...binary)* 2^127 ~ 3.4 * 10^38 <br>
min value approx = 2^-126 ~ 1.18 * 10^-38 <br>

FP16 : exponent = 5 bits -> range = -14 to +15 <br>
max value = (1.111111.... binary) * 2^15  ~ 6.55 * 10^4 <br>
min value = 2^-14 = 6.1*10^-5 <br>

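Gradients are often extremely small: during backpropagation it is common to see values like 1e−6, 1e−7 or 1e−8. Below about 6.1e−5 (2^−14), FP16 can only store numbers as subnormals with progressively fewer significant bits, and anything below about 6e−8 (2^−24) rounds to zero. As a result, these tiny gradients lose precision or underflow to zero entirely.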


So why GRADIENT SCALING? <br>
The idea is to multiply the loss (and therefore all gradients) by a large constant, such as 2^10: <br>

scaled_loss = loss * 2^10

This makes tiny gradients large enough to fit into FP16 without being rounded to zero.

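Then, after backpropagation and before the optimizer step, the gradients are divided by the same factor, so the actual update is unchanged:
scaled_gradients / 2^10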

HOW BF16 AVOIDS THE PROBLEM ?
BF16 keeps FP32’s exponent size (8 bits), so it has the same numeric range, just lower precision. <br>

**BF16** - sign : 1, exponent : 8, mantissa : 7 <br>


When using `torch.autocast` with BF16, eligible operations (chiefly matrix multiplications such as the linear layers) run in BF16, so the activations they produce are BF16, while the model weights themselves stay stored in FP32. This is known as **mixed-precision training**. Operations that are more numerically sensitive, such as layer normalization, continue to run in FP32.

In [None]:
# TF32 matmuls plus BF16 autocast for the forward pass (mixed precision).
train_loader = DataLoaderLite(B=8, T=1024)
torch.set_float32_matmul_precision('high') # utilises tf-32 precision for all matmul in linear layers
model = GPT(GPTConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  # only the forward pass (and the loss) runs under autocast; backward follows it
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  # wait for queued GPU work before timing; the call raises on a CPU-only setup
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {i}, loss : {loss.item()}, dt : {dt:.2f}ms, tok/sec : {tokens_per_sec}")


Loaded 338025 tokens
1 epoch = 41 batches
step : 0, loss : 10.88330078125, dt : 367.87ms, tok/sec : 22268.58632156576
step : 1, loss : 9.643840789794922, dt : 271.62ms, tok/sec : 30160.2811423074
step : 2, loss : 8.933135986328125, dt : 271.85ms, tok/sec : 30134.25312614561
step : 3, loss : 8.750102996826172, dt : 271.79ms, tok/sec : 30141.099639198816
step : 4, loss : 8.590248107910156, dt : 271.76ms, tok/sec : 30143.87613895766
step : 5, loss : 8.452657699584961, dt : 271.74ms, tok/sec : 30145.91256390482
step : 6, loss : 8.377922058105469, dt : 271.71ms, tok/sec : 30150.11852000049
step : 7, loss : 8.070959091186523, dt : 271.73ms, tok/sec : 30148.002168988034
step : 8, loss : 7.802304267883301, dt : 271.73ms, tok/sec : 30147.790550227688
step : 9, loss : 7.761746406555176, dt : 271.78ms, tok/sec : 30142.183735378385
step : 10, loss : 7.7245330810546875, dt : 271.68ms, tok/sec : 30153.187756799878
step : 11, loss : 7.534664154052734, dt : 271.70ms, tok/sec : 30151.38847323649
step :

## torch.compile() - TC

Compilation takes some time up front, but execution afterwards is much faster. The speedup comes from reducing Python overhead and minimizing unnecessary GPU memory reads and writes. Because TC knows ahead of time which operations will run, it can optimize across them: it takes the Python interpreter out of the forward pass entirely and compiles the whole neural net as a single object.

Reads/writes: between the GPU's compute cores (with their small, fast on-chip SRAM and registers) and the GPU's off-chip main memory, HBM.
The GPU's main memory is HBM (High Bandwidth Memory); the CPU's is RAM.
Without `torch.compile()`, PyTorch runs eagerly, one operation (kernel) at a time. Every intermediate result is therefore written out to HBM and read back by the next kernel, even when it is reused immediately. With TC, the compiler sees the entire operation graph ahead of time and fuses such operations into single kernels, avoiding these round trips.

If a chain of operations would otherwise bounce data back and forth between the compute cores and HBM, TC keeps the intermediate values on-chip until all dependent computations are complete. This reduces memory traffic and improves performance.

In [None]:
# BF16 autocast plus torch.compile (the first step pays the compilation cost).
train_loader = DataLoaderLite(B=8, T=1024)
model = GPT(GPTConfig())
model.to(device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  # wait for queued GPU work before timing; the call raises on a CPU-only setup
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {i}, loss : {loss.item()}, dt : {dt:.2f}ms, tok/sec : {tokens_per_sec}")


Loaded 338025 tokens
1 epoch = 41 batches
step : 0, loss : 10.927841186523438, dt : 35776.88ms, tok/sec : 228.97470031541704
step : 1, loss : 9.605266571044922, dt : 166.52ms, tok/sec : 49195.32722157395
step : 2, loss : 8.896303176879883, dt : 165.55ms, tok/sec : 49484.82444419161
step : 3, loss : 8.641592979431152, dt : 165.81ms, tok/sec : 49404.91977822256
step : 4, loss : 8.505861282348633, dt : 165.40ms, tok/sec : 49528.83587321905
step : 5, loss : 8.390480995178223, dt : 165.72ms, tok/sec : 49431.288932112024
step : 6, loss : 8.291648864746094, dt : 165.38ms, tok/sec : 49533.119929880406
step : 7, loss : 7.981864929199219, dt : 165.75ms, tok/sec : 49424.818421385746
step : 8, loss : 7.707769870758057, dt : 165.38ms, tok/sec : 49533.97683014734
step : 9, loss : 7.687943458557129, dt : 165.70ms, tok/sec : 49437.6900049064
step : 10, loss : 7.670999526977539, dt : 165.43ms, tok/sec : 49518.55716006894
step : 11, loss : 7.490860939025879, dt : 165.72ms, tok/sec : 49434.13364870134
st

## Flash Attention
Flash Attention replaces the attention implementation inside the `CausalSelfAttention` class. The four lines of PyTorch attention code are fused into a single highly optimized kernel. This fusion goes far beyond what `torch.compile()` can detect automatically, because it requires an algorithmic rewrite of how attention is computed. Flash Attention organizes the computation around the GPU memory hierarchy, working on tiles that fit in fast on-chip SRAM, so far fewer reads/writes to HBM are needed.

The key idea is that Flash Attention **never materializes the full attention matrix in HBM**. Instead, it uses the “online softmax” trick introduced in the paper *Online Normalizer Calculation for Softmax* (2018). This allows the softmax to be computed correctly without storing all logits at once.

The trick is to maintain two running values for each query row:

- `m`: the running maximum of logits seen so far  
- `s`: the running sum of `exp(logits − m)`

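Whenever a new block of logits $x_j$ arrives, the running maximum becomes $m' = \max(m, \max_j x_j)$. The running sum is then rescaled to the new maximum and extended:

$$s' = s \cdot e^{m - m'} + \sum_{j} e^{x_j - m'}$$

Here the sum runs only over the newly arrived logits; the earlier ones are already accounted for in $s$.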

`(logits - m)` keeps the exponentials numerically stable and prevents overflow.  
By the end of the streaming pass over the logits, Flash Attention has:
- the correct maximum
- the correct sum
So we can compute the correct softmax, even though we never saw the whole row at once.


In [None]:
# Compiled BF16 training with the flash-attention (scaled_dot_product_attention) model.
train_loader = DataLoaderLite(B=8, T=1024)
model = GPT(GPTConfig())
model.to(device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  # wait for queued GPU work before timing; the call raises on a CPU-only setup
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {i}, loss : {loss.item()}, dt : {dt:.2f}ms, tok/sec : {tokens_per_sec}")


Loaded 338025 tokens
1 epoch = 41 batches
step : 0, loss : 10.986957550048828, dt : 1649.81ms, tok/sec : 4965.431513590949
step : 1, loss : 9.674814224243164, dt : 126.05ms, tok/sec : 64991.31494546816
step : 2, loss : 9.05832576751709, dt : 125.59ms, tok/sec : 65226.46887481586
step : 3, loss : 8.783764839172363, dt : 125.67ms, tok/sec : 65185.13852516552
step : 4, loss : 8.659647941589355, dt : 125.44ms, tok/sec : 65308.045662760065
step : 5, loss : 8.559056282043457, dt : 125.65ms, tok/sec : 65195.15694202786
step : 6, loss : 8.460760116577148, dt : 125.52ms, tok/sec : 65265.868061649744
step : 7, loss : 8.097240447998047, dt : 125.69ms, tok/sec : 65176.23585023768
step : 8, loss : 7.821047782897949, dt : 125.62ms, tok/sec : 65214.58412669491
step : 9, loss : 7.800475120544434, dt : 125.70ms, tok/sec : 65169.06601927016
step : 10, loss : 7.793295860290527, dt : 125.67ms, tok/sec : 65184.272880593875
step : 11, loss : 7.582406997680664, dt : 125.79ms, tok/sec : 65126.327021566154
ste




## Numbers

CUDA kernels and GPU hardware are often optimized around powers of two, and many low-level GPU kernels process data in blocks whose dimensions are powers of two. Because of this, rounding certain model dimensions (such as the vocabulary size) up to a "nice" number that is divisible by a large power of two can lead to measurable performance improvements.

For example, GPT-2 has a fixed vocabulary size of 50,257, an odd number. Rounding it up to 50,304 = 393 × 128 (a multiple of 128, not itself a power of two) gives the embedding and output-layer matrices GPU-friendly shapes. However, the tokenizer will never produce the new token ids 50,257 to 50,303, so we waste a small amount of embedding space.

The model must also implicitly learn to assign extremely low probabilities (effectively negative infinity logits) to these unused token IDs, since they should never be predicted during inference.


In [None]:
# Same as before, with the vocabulary padded to a GPU-friendly size.
train_loader = DataLoaderLite(B=8, T=1024)
model = GPT(GPTConfig(vocab_size=50304)) # padded up to a multiple of 128 (50304 = 393 * 128)
model.to(device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  # wait for queued GPU work before timing; the call raises on a CPU-only setup
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {i}, loss : {loss.item()}, dt : {dt:.2f}ms, tok/sec : {tokens_per_sec}")


Loaded 338025 tokens
1 epoch = 41 batches
step : 0, loss : 10.941413879394531, dt : 1396.59ms, tok/sec : 5865.723813466514
step : 1, loss : 9.538830757141113, dt : 130.61ms, tok/sec : 62723.373155798305
step : 2, loss : 8.87597942352295, dt : 123.07ms, tok/sec : 66564.51706067762
step : 3, loss : 8.73884105682373, dt : 123.35ms, tok/sec : 66410.51591854406
step : 4, loss : 8.543574333190918, dt : 123.03ms, tok/sec : 66586.83362854496
step : 5, loss : 8.380556106567383, dt : 123.33ms, tok/sec : 66421.42816436915
step : 6, loss : 8.306436538696289, dt : 123.03ms, tok/sec : 66587.86597055073
step : 7, loss : 8.031896591186523, dt : 123.30ms, tok/sec : 66437.35363851694
step : 8, loss : 7.7480974197387695, dt : 123.02ms, tok/sec : 66592.77003225006
step : 9, loss : 7.698884010314941, dt : 123.32ms, tok/sec : 66428.4909404459
step : 10, loss : 7.647953510284424, dt : 123.25ms, tok/sec : 66466.01303022332
step : 11, loss : 7.520112037658691, dt : 123.32ms, tok/sec : 66431.44491039614
step : 

# Hyperparameters, AdamW, Gradient Clipping

GPT-2 → GPT-3 changes: GPT-3 uses a 2048-token context window instead of 1024, and was trained for longer on a larger dataset.

Following GPT-3, AdamW uses beta1 = 0.9, beta2 = 0.95 (and eps = 1e-8), and the code now adds gradient clipping (the global gradient norm is clipped to 1.0).

Gradient norms are high at the beginning and then stabilize during training, because at the start the model is completely random.

In [None]:
# GPT-3 AdamW hyperparameters (betas, eps) plus global gradient-norm clipping at 1.0.
train_loader = DataLoaderLite(B=8, T=1024)
model = GPT(GPTConfig(vocab_size=50304)) # padded up to a multiple of 128 (50304 = 393 * 128)
model.to(device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
  loss.backward()
  # rescale all gradients so their global norm is at most 1.0; returns the pre-clip norm
  norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  optimizer.step()
  # wait for queued GPU work before timing; the call raises on a CPU-only setup
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {i} | loss : {loss.item()} | dt : {dt:.2f}ms | tok/sec : {tokens_per_sec} | norm : {norm:.4f}")


Loaded 338025 tokens
1 epoch = 41 batches
step : 0 | loss : 10.999530792236328 | dt : 197.48ms | tok/sec : 41483.290495110945 | norm : 14.0023
step : 1 | loss : 9.577607154846191 | dt : 125.98ms | tok/sec : 65024.8923996374 | norm : 4.3634
step : 2 | loss : 8.831929206848145 | dt : 125.31ms | tok/sec : 65374.52384875462 | norm : 2.9763
step : 3 | loss : 8.84640884399414 | dt : 125.59ms | tok/sec : 65229.19307304738 | norm : 4.2062
step : 4 | loss : 8.569517135620117 | dt : 124.85ms | tok/sec : 65615.09534503568 | norm : 3.7140
step : 5 | loss : 8.39919376373291 | dt : 125.07ms | tok/sec : 65497.77231153997 | norm : 2.7486
step : 6 | loss : 8.354333877563477 | dt : 124.81ms | tok/sec : 65636.40358672438 | norm : 2.1312
step : 7 | loss : 8.035881042480469 | dt : 125.17ms | tok/sec : 65448.24267082674 | norm : 2.6961
step : 8 | loss : 7.711540222167969 | dt : 125.41ms | tok/sec : 65323.81422495822 | norm : 1.7410
step : 9 | loss : 7.650970458984375 | dt : 125.64ms | tok/sec : 65202.579977

# Learning Rate :  Cosine decay LR schedule with Warm-up

In [None]:
# Learning-rate schedule hyperparameters
max_lr = 6e-4  # as per GPT-paper
min_lr = max_lr * 0.1  # floor of the cosine decay: 10% of the peak
warmup_steps = 10
max_steps = 50


def get_lr(it):
  """Return the learning rate for optimizer step ``it`` (0-based).

  - ``it < warmup_steps``: linear warmup, reaching ``max_lr`` at ``it = warmup_steps - 1``.
  - ``warmup_steps <= it <= max_steps``: cosine decay from ``max_lr`` down to ``min_lr``.
  - ``it > max_steps``: constant ``min_lr``.
  """
  if it < warmup_steps:
    return max_lr * (it + 1) / warmup_steps
  elif it > max_steps:
    return min_lr
  else:
    progress = (it - warmup_steps) / (max_steps - warmup_steps)  # 0 -> 1 across the decay
    assert 0 <= progress <= 1
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 -> 0
    return min_lr + cosine_weight * (max_lr - min_lr)


In [None]:
# Same setup as before, now driven by the warmup + cosine-decay learning-rate schedule.
train_loader = DataLoaderLite(B=8, T=1024)
model = GPT(GPTConfig(vocab_size=50304)) # padded up to a multiple of 128 (50304 = 393 * 128)
model.to(device)
model = torch.compile(model)
# the lr passed here is only a placeholder: get_lr() overwrites it before every step
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
for step in range(max_steps):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
  loss.backward()
  norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  # set this step's learning rate on every parameter group before stepping
  lr = get_lr(step)
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr
  optimizer.step()
  # wait for queued GPU work before timing; the call raises on a CPU-only setup
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {step} | loss : {loss.item()} | dt : {dt:.2f}ms | tok/sec : {tokens_per_sec} | lr : {lr:.6e}")


Loaded 338025 tokens
1 epoch = 41 batches
step : 49 | loss : 10.987662315368652 | dt : 135.39ms | tok/sec : 60506.558510076284 | lr : 6.000000e-05
step : 49 | loss : 9.78203010559082 | dt : 125.44ms | tok/sec : 65304.94251182663 | lr : 1.200000e-04
step : 49 | loss : 9.17325496673584 | dt : 124.80ms | tok/sec : 65643.30039298399 | lr : 1.800000e-04
step : 49 | loss : 9.601394653320312 | dt : 125.23ms | tok/sec : 65416.84204834715 | lr : 2.400000e-04
step : 49 | loss : 9.002758026123047 | dt : 124.93ms | tok/sec : 65573.64661337018 | lr : 3.000000e-04
step : 49 | loss : 8.708261489868164 | dt : 125.20ms | tok/sec : 65431.79096706867 | lr : 3.600000e-04
step : 49 | loss : 8.647960662841797 | dt : 125.23ms | tok/sec : 65414.35122595728 | lr : 4.200000e-04
step : 49 | loss : 8.216434478759766 | dt : 125.61ms | tok/sec : 65219.16419372289 | lr : 4.800000e-04
step : 49 | loss : 7.806668758392334 | dt : 125.28ms | tok/sec : 65390.32453335769 | lr : 5.400000e-04
step : 49 | loss : 7.6600837707

All models use weight decay of 0.1 to provide a small amount of regularization

In [None]:
# Continue training the `model` and `train_loader` from the previous cell, now
# with the weight-decay-aware (and, on CUDA, fused) AdamW from configure_optimizers.
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)
for step in range(max_steps):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
  loss.backward()
  norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  # set this step's learning rate on every parameter group before stepping
  lr = get_lr(step)
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr
  optimizer.step()
  # wait for queued GPU work before timing; the call raises on a CPU-only setup
  if device == 'cuda':
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1-t0)*1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step : {step} | loss : {loss.item()} | dt : {dt:.2f}ms | tok/sec : {tokens_per_sec} | lr : {lr:.6e}")


num decayed parameter tensors: 50, with 124,354,560 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
step : 49 | loss : 6.1445417404174805 | dt : 136.97ms | tok/sec : 59810.67647504243 | lr : 0.0001
step : 49 | loss : 6.166140556335449 | dt : 124.93ms | tok/sec : 65571.64437608182 | lr : 0.0001
step : 49 | loss : 6.219522953033447 | dt : 122.36ms | tok/sec : 66950.77126083136 | lr : 0.0002
step : 49 | loss : 6.533832550048828 | dt : 121.79ms | tok/sec : 67263.09617756339 | lr : 0.0002
step : 49 | loss : 6.195084571838379 | dt : 121.84ms | tok/sec : 67233.3507509999 | lr : 0.0003
step : 49 | loss : 6.20240592956543 | dt : 121.53ms | tok/sec : 67407.59021523118 | lr : 0.0004
step : 49 | loss : 6.032421588897705 | dt : 121.75ms | tok/sec : 67284.43432080586 | lr : 0.0004
step : 49 | loss : 6.130808353424072 | dt : 121.64ms | tok/sec : 67345.75796502933 | lr : 0.0005
step : 49 | loss : 6.22491455078125 | dt : 121.91ms | tok/sec : 67199.16287

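The original GPT-3 (small model) used a batch size of 0.5M tokens, which is roughly 500 rows of 1024 tokens each (2^19 / 1024 = 512). We can't do that directly: a batch that large would not fit in the memory of our small GPUs.

Gradient accumulation to the rescue!

The idea: run several smaller micro-batches forward and backward, letting their gradients add up, and call `optimizer.step()` only once the total number of processed tokens reaches the desired batch size.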



In [None]:
# Desired optimizer batch size, counted in tokens: 2**19 = 524288 (~0.5M).
total_batch_size = 2**19
# Micro-batch that fits in GPU memory: B rows of T tokens each (T = sequence length).
B, T = 16, 1024
# The full batch must split evenly into micro-batches of B*T tokens.
assert total_batch_size % (B * T) == 0
grad_accum_steps = total_batch_size // (B * T)
print(
  f"total desired batch size : {total_batch_size}\n"
  f"calculated gradient accumulation steps : {grad_accum_steps}"
)

total desired batch size : 524288
calculated gradient accumulation steps : 32


In [None]:
train_loader = DataLoaderLite(B=B, T=T)
# 50304 = 128 * 393: the GPT-2 vocab size (50257) padded up to a multiple of 128.
# It is NOT a power of 2; the padding presumably makes tensor shapes friendlier
# to CUDA kernels.
model = GPT(GPTConfig(vocab_size=50304))
model.to(device)
model = torch.compile(model)

optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)
for step in range(max_steps):
  t0 = time.time()
  optimizer.zero_grad()
  # Accumulate gradients over grad_accum_steps micro-batches of B*T tokens each,
  # so that one optimizer step covers total_batch_size tokens.
  loss_accum = 0.0
  for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
      logits, loss = model(x, y)
    # Gradients from successive backward() calls are added into .grad, which
    # corresponds to a SUM over micro-batches in the objective. We want the MEAN,
    # so each micro-batch loss is scaled by 1/grad_accum_steps
    # (assumes the model returns a mean loss per micro-batch -- TODO confirm).
    loss = loss / grad_accum_steps
    loss_accum += loss.detach()  # logging only: keep it out of the autograd graph
    loss.backward()
  # clip the global gradient norm to 1.0; clip_grad_norm_ returns the total norm before clipping
  norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

  # determine and set lr for this iteration
  lr = get_lr(step)
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr
  optimizer.step()
  if device == 'cuda':
    torch.cuda.synchronize()  # wait for queued GPU work so the timing below is accurate
  t1 = time.time()
  dt = (t1 - t0) * 1000  # milliseconds
  tokens_processed = train_loader.B * train_loader.T * grad_accum_steps
  tokens_per_sec = tokens_processed / (t1 - t0)
  print(f"step : {step} | loss : {loss_accum.item():.6f} | norm : {norm:.4f} | tok/sec : {tokens_per_sec} | lr : {lr:.4e} | dt : {dt:.2f}ms")


Loaded 338025 tokens
1 epoch = 20 batches
num decayed parameter tensors: 50, with 124,354,560 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
step : 0 | loss : 10.971720 | tok/sec : 18023.857577590145 | lr : 0.0001 | dt : 29088.56ms
step : 1 | loss : 9.763161 | tok/sec : 77665.71249227236 | lr : 0.0001 | dt : 6750.57ms
step : 2 | loss : 9.382488 | tok/sec : 77690.94549170464 | lr : 0.0002 | dt : 6748.38ms
step : 3 | loss : 9.651525 | tok/sec : 77468.55231203312 | lr : 0.0002 | dt : 6767.75ms
step : 4 | loss : 9.089612 | tok/sec : 77395.60917608687 | lr : 0.0003 | dt : 6774.13ms
step : 5 | loss : 8.640585 | tok/sec : 77361.57720653008 | lr : 0.0004 | dt : 6777.11ms
step : 6 | loss : 8.344876 | tok/sec : 77302.65912190381 | lr : 0.0004 | dt : 6782.28ms
step : 7 | loss : 8.052789 | tok/sec : 77265.79840663375 | lr : 0.0005 | dt : 6785.51ms
step : 8 | loss : 7.710036 | tok/sec : 76974.8246172897 | lr : 0.0005 | dt : 6811.16ms
step : 9 | los

# Distributed Data Parallel

DDP: we have 4 GPUs, so we launch 4 processes and assign one GPU to each process. Each GPU processes a slightly different part of the data, and then we average the gradients across all 4.

There will be 4 Python interpreters running the same script in parallel; the only difference between them is that each one has a unique DDP rank.