#### The SGD Optimizer

In [1]:
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math
class SGD(torch.optim.Optimizer):
    """Plain SGD whose step size decays as 1/sqrt(t + 1).

    For the t-th update of a given parameter (t counted from 0, tracked per
    parameter in ``self.state[p]["t"]``), the update is::

        p <- p - lr / sqrt(t + 1) * p.grad

    Args:
        params: Iterable of tensors or parameter-group dicts, as accepted by
            ``torch.optim.Optimizer``.
        lr: Base learning rate; must be non-negative.

    Raises:
        ValueError: If ``lr`` is negative.
    """

    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Apply one update to every parameter that currently has a gradient.

        Parameters whose ``.grad`` is ``None`` are skipped and their step
        counter is not advanced.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It runs with gradient tracking re-enabled, because the
                rest of this method executes under ``torch.no_grad()``.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]  # Base learning rate for this parameter group.
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                t = state.get("t", 0)  # Number of updates already applied to p.
                # In-place update under no_grad (rather than through ``.data``) so
                # autograd's version counter still records the modification.
                p.add_(p.grad, alpha=-lr / math.sqrt(t + 1))
                state["t"] = t + 1
        return loss

In [2]:
def train_with_lr(lr=1, steps: int = 100, seed: Optional[int] = None, log_every: int = 1):
    """Minimize mean(w**2) over a random 10x10 weight matrix with the SGD class above.

    The loss is printed before each optimizer step, so the effect of ``lr`` on
    convergence or divergence can be compared across calls.

    Args:
        lr: Base learning rate passed to ``SGD``.
        steps: Number of optimizer steps to run (default 100).
        seed: If given, the initial weights are drawn from a dedicated
            ``torch.Generator`` seeded with this value, making the printed loss
            trajectory reproducible. ``None`` draws from the global RNG.
        log_every: Print the loss on every ``log_every``-th step (default 1,
            i.e. every step). Larger values keep the cell output short.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    weights = torch.nn.Parameter(5 * torch.randn((10, 10), generator=generator))
    opt = SGD([weights], lr=lr)
    for step in range(steps):
        opt.zero_grad()  # Reset the gradients for all learnable parameters.
        loss = (weights ** 2).mean()  # Scalar loss; minimized at weights == 0.
        if step % log_every == 0:
            print(loss.item())
        loss.backward()  # Compute d(loss)/d(weights).
        opt.step()  # Apply the decaying-step SGD update.


train_with_lr()

23.231822967529297
22.311845779418945
21.685230255126953
21.18732261657715
20.765697479248047
20.395889282226562
20.064184188842773
19.761991500854492
19.483503341674805
19.224586486816406
18.9821834564209
18.75394058227539
18.538013458251953
18.332921981811523
18.13745880126953
17.950620651245117
17.771562576293945
17.599571228027344
17.434032440185547
17.274415969848633
17.120250701904297
16.971139907836914
16.826719284057617
16.686668395996094
16.550697326660156
16.418556213378906
16.290010452270508
16.164852142333984
16.042888641357422
15.923945426940918
15.807865142822266
15.694502830505371
15.583721160888672
15.4753999710083
15.369423866271973
15.265680313110352
15.164079666137695
15.064526557922363
14.966931343078613
14.871219635009766
14.777314186096191
14.68514633178711
14.594646453857422
14.505756378173828
14.418414115905762
14.332568168640137
14.248164176940918
14.165155410766602
14.083488464355469
14.00312614440918
13.924025535583496
13.84614372253418
13.769445419311523
13.

In [3]:
# lr = 10: loss falls steadily (to about 0.11 after ~50 steps in the recorded run).
train_with_lr(lr=1e1)

23.613693237304688
15.112763404846191
11.140484809875488
8.716239929199219
7.060154914855957
5.853676795959473
4.936800003051758
4.2186360359191895
3.6431238651275635
3.173565626144409
2.784832239151001
2.4590954780578613
2.183340549468994
1.947838544845581
1.745171308517456
1.5695844888687134
1.4165500402450562
1.2824575901031494
1.1643961668014526
1.059995174407959
0.9673063158988953
0.8847154378890991
0.8108751773834229
0.7446537017822266
0.6850941181182861
0.6313827633857727
0.5828243494033813
0.5388219952583313
0.49886059761047363
0.46249425411224365
0.4293350875377655
0.3990447521209717
0.37132683396339417
0.3459210693836212
0.32259804010391235
0.30115506052970886
0.2814127206802368
0.26321133971214294
0.24640899896621704
0.23087894916534424
0.21650773286819458
0.2031938135623932
0.19084596633911133
0.17938199639320374
0.16872793436050415
0.15881691873073578
0.14958851039409637
0.14098793268203735
0.13296547532081604
0.12547598779201508
0.11847838759422302
0.11193519830703735
0.1

In [4]:
# lr = 100: grad of mean(w**2) over 100 entries is 2w/100, so step t multiplies w by
# 1 - lr_t/50 with lr_t = 100/sqrt(t+1). Step 0 gives factor -1 (sign flip, loss unchanged),
# and step 3 gives factor 0, collapsing the loss to ~1e-16 (rounding residue) in the recorded run.
train_with_lr(lr=1e2)

21.831697463989258
21.831697463989258
3.7457268238067627
0.0896436795592308
1.4007038418120563e-16
1.5611711606934078e-18
5.2570132587009346e-20
3.1316387522597062e-21
2.686522232062097e-22
2.9850247373517037e-23
4.032438827449519e-24
6.354760727088093e-25
1.1351687472012565e-25
2.2509475706860925e-26
4.8771132872294616e-27
1.1406159280483048e-27
2.851539820120762e-28
7.5609018406698e-29
2.112615901037936e-29
6.1870789105247994e-30
1.89060323583994e-30
6.004643315751757e-31
1.975619501189715e-31
6.714248583238634e-32
2.3511288793459106e-32
8.464065523175293e-33
3.126469671721731e-33
1.1828929857347599e-33
4.576947048689891e-34
1.808579218359819e-34
7.289237025617899e-35
2.993041649765045e-35
1.250771907022993e-35
5.3145495787203145e-36
2.294039393955804e-36
1.0051615674922361e-36
4.467385242649392e-37
2.0126116683740813e-37
9.185106738652679e-38
4.2439924274923344e-38
1.98425516080582e-38
9.38286610546929e-39
4.4852410921492664e-39
2.166501312843277e-39
1.0570064381325311e-39
5.2068607

In [5]:
# lr = 1000: step 0 multiplies w by 1 - 1000/50 = -19 (loss x361). The decaying step
# lr_t = 1000/sqrt(t+1) stays above the stability limit of 100 for all 100 steps, so the
# loss keeps growing and overflows float32 to inf after ~25 steps in the recorded run.
train_with_lr(lr=1e3)

27.975845336914062
10099.279296875
1744304.5
194035248.0
15716853760.0
991913443328.0
50921597829120.0
2190864830955520.0
8.075059925509734e+16
2.592991740801581e+18
7.351361131589403e+19
1.8601283687547887e+21
4.238549728860823e+22
8.763297269289533e+23
1.6545960435473895e+25
2.8688559006199424e+26
4.590169440991908e+27
6.8062972269119495e+28
9.388694720528329e+29
1.208888584699385e+31
1.4574030165905482e+32
1.6496205584273464e+33
1.7574709159504212e+34
1.766386139213741e+35
1.6783671763136253e+36
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf
inf


#### AdamW: Memory and Compute Accounting

In [None]:
import torch
from torchinfo import summary
from transformer import TransformerLM
from perf import *

vocab_size = 50257
context_length = 1024
num_layers = 48
d_model = 1600
num_heads = 25

d_ff = 4 * d_model

# Estimated training memory for batch sizes 0..5 (batch 0 isolates the batch-independent
# terms). NOTE(review): CalcMemory comes from `perf` (not shown here); its units are
# presumably GB and its tuple order is assumed to match the header columns.
print('params\t\tgradients\tadamw\t\tactivation\tsum')
for batch_size in range(0, 6):
    memory = CalcMemory(batch_size, vocab_size, context_length, num_layers, d_model, num_heads)

    print('\t\t'.join(f'{x:.3f}' for x in memory))

# Wall-clock estimate for a full training run on a single A100.
A100_PEAK_FLOPS = 19.5 * 1000**4  # A100 peak FP32 throughput, FLOP/s
MFU = 0.5                         # assumed fraction of peak throughput actually achieved
TRAIN_BATCH_SIZE = 1024           # presumably CalcFlops' first arg is batch size, as for CalcMemory
TRAIN_STEPS = 400000              # number of optimizer steps in the run

a100 = A100_PEAK_FLOPS * MFU  # effective FLOP/s
flops = CalcFlops(TRAIN_BATCH_SIZE, vocab_size, context_length, num_layers, d_model, num_heads) * TRAIN_STEPS
print(f'days: {flops/a100/3600/24}')



params		gradients	adamw		activation	sum
7.624		7.624		15.249		0.000		30.497
7.624		7.624		15.249		18.075		48.572
7.624		7.624		15.249		36.150		66.647
7.624		7.624		15.249		54.225		84.722
7.624		7.624		15.249		72.300		102.797
7.624		7.624		15.249		90.375		120.872
days: 6583.556412243875


In [7]:
# Build the configuration defined in the previous cell and inspect per-layer shapes,
# parameter counts and mult-adds. theta is passed by keyword, matching the memory
# measurement cell below, so it cannot bind to a different positional parameter.
# NOTE(review): no device= is given, so the full fp32 model is built on the default device.
model = TransformerLM(vocab_size, context_length, num_layers, d_model, num_heads, d_ff, theta=10000, dtype=torch.float32)
batch_size = 4
summary(
    model,
    input_size=(batch_size, context_length),
    dtypes=[torch.long],  # token ids for the embedding layer
    col_names=["input_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
    depth=2
)

Layer (type (var_name))                                 Input Shape               Output Shape              Param #                   Mult-Adds
TransformerLM (TransformerLM)                           [4, 1024]                 [4, 1024, 50257]          --                        --
├─Embedding (token_embeddings)                          [4, 1024]                 [4, 1024, 1600]           80,411,200                321,644,800
├─ModuleList (layers)                                   --                        --                        --                        --
│    └─TransformerBlock (0)                             [4, 1024, 1600]           [4, 1024, 1600]           40,963,200                163,852,800
│    └─TransformerBlock (1)                             [4, 1024, 1600]           [4, 1024, 1600]           40,963,200                163,852,800
│    └─TransformerBlock (2)                             [4, 1024, 1600]           [4, 1024, 1600]           40,963,200                163,852,80

In [None]:
import torch
import gc
from perf import *
from transformer import TransformerLM  # explicit, so this cell does not depend on an earlier cell

# Hyperparameters
vocab_size = 10000
context_length = 128
num_layers = 48
d_model = 160
num_heads = 4
d_ff = 4 * d_model
batch_size = 4

BYTES_PER_GB = 1024**3  # measured figures below are reported in units of 1024**3 bytes

# Display theoretical estimation
print('params\t\tgradients\tadamw\t\tactivation\tsum')
memory = CalcMemory(batch_size, vocab_size, context_length, num_layers, d_model, num_heads)
print('\t\t'.join(f'{x:.3f}' for x in memory))

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == 'cpu':
    print("GPU not detected. Memory measurement is only supported on CUDA devices.")
else:
    # --- Step 0: Drop references left over from a previous run of this cell ---
    # gc.collect()/empty_cache() cannot free tensors that are still bound to these
    # globals. Without this, a re-run would count the previous run's tensors in the figures.
    model = outputs = loss = input_ids = None
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # CUDA memory still held by anything else is excluded from every figure below.
    baseline_bytes = torch.cuda.memory_allocated()

    # --- Step 1: Initialize Model and Measure Static Weights ---
    model = TransformerLM(
        vocab_size, context_length, num_layers, d_model, num_heads, d_ff,
        theta=10000, device=device, dtype=torch.float32
    )

    weight_memory = (torch.cuda.memory_allocated() - baseline_bytes) / BYTES_PER_GB
    print(f"[1] Model weights memory: {weight_memory:.4f} GB")

    # --- Step 2: Measure Forward Pass (Activations) ---
    # Dummy token ids (int64, as the embedding layer expects)
    input_ids = torch.randint(0, vocab_size, (batch_size, context_length), device=device)

    # Reset peak stats so the peak below covers only forward + backward
    torch.cuda.reset_peak_memory_stats()

    # Forward pass: allocates the intermediate tensors kept for backward
    outputs = model(input_ids)

    forward_memory = (torch.cuda.memory_allocated() - baseline_bytes) / BYTES_PER_GB
    print(f"[2] Post-forward memory usage (Weights + Activations): {forward_memory:.4f} GB")

    # --- Step 3: Measure Backward Pass (Gradients & Peak) ---
    loss = outputs.mean()  # Dummy scalar loss
    loss.backward()

    # --- Step 4: Final Statistics ---
    # max_memory_allocated() is the highest point since the last reset (forward + backward)
    peak_memory = (torch.cuda.max_memory_allocated() - baseline_bytes) / BYTES_PER_GB
    # After backward: weights, gradients, and tensors still referenced (outputs, loss, input_ids)
    current_memory = (torch.cuda.memory_allocated() - baseline_bytes) / BYTES_PER_GB

    print("--------------------------------------------------")
    print(f"Actual Training Peak Memory: {peak_memory:.4f} GB")
    print("--------------------------------------------------")
    print("Detailed Breakdown:")
    print(f"  - Model weights: {weight_memory:.4f} GB")
    print(f"  - Activations (est.): {forward_memory - weight_memory:.4f} GB (intermediate variables from forward pass)")
    print(f"  - Gradients & temp buffers: {peak_memory - forward_memory:.4f} GB (additional overhead from backward pass)")
    print(f"  - Post-backward allocated: {current_memory:.4f} GB (weights + gradients + tensors still referenced)")
    print("--------------------------------------------------")

params		gradients	adamw		activation	sum
0.079		0.079		0.159		0.351		0.668
Using device: cuda
[1] Model weights memory: 0.0862 GB
[2] Post-forward memory usage (Weights + Activations): 0.7241 GB
--------------------------------------------------
Actual Training Peak Memory: 0.7578 GB
--------------------------------------------------
Detailed Breakdown:
  - Model weights: 0.0862 GB
  - Activations (est.): 0.6379 GB (intermediate variables from forward pass)
  - Gradients & temp buffers: 0.0337 GB (additional overhead from backward pass)
--------------------------------------------------
