<a href="https://colab.research.google.com/github/kiankyars/Ultra-Scale-Playbook-Series/blob/main/1_scaling_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Video 1: What is Scaling in LLM Training?
Welcome to the first notebook of the Ultra-Scale Playbook series!

Objective
- Estimate memory usage of a Transformer

Exercise 1: Estimate Transformer Memory Usage (4 bytes per param)

In [None]:
import torch
import torch.nn as nn

# Define a basic Transformer block
class MiniTransformer(nn.Module):
    """A single Transformer block: self-attention followed by a two-layer MLP.

    Both sub-layers use a residual connection, and LayerNorm is applied to
    the result of each residual add.

    Args:
        embed_dim: Feature width of the input and output.
        num_heads: Number of heads given to ``nn.MultiheadAttention``.
        ff_dim: Hidden width between the two feed-forward linear layers.
    """

    def __init__(self, embed_dim=768, num_heads=12, ff_dim=3072):
        super().__init__()
        # Sub-modules are registered in this order, which is also the order
        # in which ``named_children()`` reports them.
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, batch_first=True
        )
        self.linear1 = nn.Linear(in_features=embed_dim, out_features=ff_dim)
        self.linear2 = nn.Linear(in_features=ff_dim, out_features=embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape.

        The attention layer is built with ``batch_first=True``, so batched
        input is expected as ``(batch, seq, embed_dim)``.
        """
        # Self-attention: the input is used as query, key and value; the
        # returned attention weights are not needed.
        context, _weights = self.attn(query=x, key=x, value=x)
        hidden = self.norm1(x + context)

        # Feed-forward: expand to ff_dim, apply ReLU, project back.
        expanded = self.linear1(hidden)
        projected = self.linear2(torch.relu(expanded))
        return self.norm2(hidden + projected)

# Instantiate the block with its default sizes (embed_dim=768, num_heads=12, ff_dim=3072).
model = MiniTransformer()

Show the calculation below (memory for the weights ≈ number of parameters × 4 bytes, assuming fp32 parameters):

In [None]:
# Report the parameter count of each direct sub-module of `model` and keep a
# running total, so every module's parameters are traversed only once.
total_params_manual = 0
print("Parameter breakdown per module:")
for name, module in model.named_children():
    num_params = sum(p.numel() for p in module.parameters())
    total_params_manual += num_params
    print(f"- {name}: {num_params} parameters")

# The total is the sum over child modules. At 4 bytes per parameter (the
# exercise's assumption), the weights alone need total_params_manual * 4 bytes.
print(f"\nTotal parameters: {total_params_manual}")

Parameter breakdown per module:
- attn: 2362368 parameters
- linear1: 2362368 parameters
- linear2: 2360064 parameters
- norm1: 1536 parameters
- norm2: 1536 parameters

Total parameters: 7087872


Exercise 2: *GPT2* Parameters (Optional)

In [None]:
from transformers import GPT2Model
# Load the pretrained checkpoint named 'gpt2'.
# NOTE: this rebinds `model`, replacing the MiniTransformer instance from Exercise 1.
model = GPT2Model.from_pretrained('gpt2')
def count_params(model):
    """Return the total parameter count of ``model`` as a string in millions.

    Every tensor yielded by ``model.parameters()`` contributes its
    ``numel()``; the sum is divided by 1e6 and formatted with two decimals
    and an ``M`` suffix (for example ``"124.44M"``).
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    millions = total / 1e6
    return f"{millions:.2f}M"

# Show the module tree, then the total parameter count formatted in millions.
print(model)
print("Total # of params:", count_params(model))

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
Total # of params: 124.44M
