# Making train pipeline (No-CLI) Notebook
**NOTE**: in the actual `train.py`, i'm gonna implement a CLI + some modifications to speed up the computation <br/>
but the basic ideas are the same.

## Steps
1. Imports & global setups
2. Config & runtime knobs


# OVERALL DIAGRAM
<img src="./train_diagram.png" width="800"/>

## **If you're testing on the full dataset, set this to `True`**

In [1]:
isfull = True

# 1. Imports & global setups

In [2]:
# imports
import os, math, time, json, random, csv, warnings
from pathlib import Path
import datetime as dt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.amp import GradScaler

# global setups
warnings.filterwarnings("ignore", category=UserWarning, module="tqdm") # just to remove tqdm warnings

torch.backends.cuda.matmul.allow_tf32 = True
# this has defaulted to False since PyTorch 1.12, so we turn it on explicitly
# it allows matrix multiplications to run in TF32 mode on Ampere-or-newer GPUs (in my case, an RTX 30-series)

torch.backends.cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass
# Values:
#   'highest' → full FP32 precision for matmuls (slowest).
#   'high' → allows TF32 (fast, slightly less precise).
#   'medium' → also allows bfloat16 internally when possible (even faster, least precise).

# fixed seeds so runs are reproducible; i've done this mainly for debugging
SEED = 1234
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# setting up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device) # make sure you see "device: cuda"

device: cuda


# 2. Config, params, runtime knobs

In [3]:

# model config (later, we'll load this from model.json instead)
cfg = {
    "vocab_size": 60000,
    "context_length": 128,      # real one uses 1024
    "emb_dim": 256,             # real one uses 768
    "n_heads": 4,               # real one uses 12
    "n_layers": 4,              # real one uses 12
    "drop_rate": 0.0,
    "qkv_bias": True,
    "activation": "gelu",
    "layer_norm_eps": 1e-5,
    "initializer_range": 0.02,
    "intermediate_size": 1024,  # real one uses 3072
    "attention_probs_dropout_prob": 0.0,
    "grad_ckpt": False,
}

import shutil
from pathlib import Path

# paths
RUN_DIR = Path("../pretrain_checkpoint_notebook/")

# DIR THINGS: I'VE HANDLED THIS EXPLICITLY FOR THE NOTEBOOK (it won't be in train.py)
# remove RUN_DIR if it exists (clear contents)
if RUN_DIR.exists():
    shutil.rmtree(RUN_DIR)

# recreate empty RUN_DIR
RUN_DIR.mkdir(parents=True, exist_ok=True)
# END OF DIR THINGS

# scheduling knobs (small for a quick test; these get scaled up before the real pretraining run)
TOTAL_UPDATES   = 200
PILOT_FRAC      = 0.05
CHUNKS          = 2
EVAL_EVERY      = 50
PERIODIC_EVERY  = 100
LOG_EVERY       = 10
ACCUM           = 2
MICRO_BSZ       = 8
BASE_LR         = 5e-4
WEIGHT_DECAY    = 0.10

# AMP mode: "fp16" | "bf16" | "none"
# during the real run, i tested both bf16 and fp16:
# the estimated completion time was ~8 days with bf16 vs ~5 days with fp16,
# so i went with fp16
AMP_MODE = "fp16" # feel free to try "bf16" or "none" (any value other than fp16/bf16 means no AMP)

# derived knobs
CTX = int(cfg["context_length"])
EFFECTIVE_TOKENS_PER_UPDATE = CTX * MICRO_BSZ * ACCUM
WARMUP_STEPS = max(1, int(0.10 * TOTAL_UPDATES))

print(f"Plan → TOTAL_UPDATES={TOTAL_UPDATES} (pilot={int(PILOT_FRAC*TOTAL_UPDATES)}, "
      f"{CHUNKS}×chunks). eff tokens/update ≈ {EFFECTIVE_TOKENS_PER_UPDATE:,}. warmup={WARMUP_STEPS}")


Plan → TOTAL_UPDATES=200 (pilot=10, 2×chunks). eff tokens/update ≈ 2,048. warmup=20


later on (in `train.py`), some of these constants will be controlled by command-line params

# 3. Helpers (functions)

In [4]:
# before saving/loading state_dict, we want to unwrap a model to its underlying nn.Module
def _unwrap_model_for_state_dict(model: nn.Module) -> nn.Module:
    if hasattr(model, "_orig_mod"): # unwrap torch.compile wrapper
        return model._orig_mod
    if hasattr(model, "module"): # unwrap DDP / DataParallel
        return model.module
    return model


In [5]:
# we also want to remove prefix from all keys in a state_dict
# removes a given prefix (e.g., "module." or "_orig_mod.") from 
# all parameter keys inside a model's state_dict (i.e. ckpt["model"]). 
# This is useful when loading checkpoints saved with wrappers like 
# DataParallel (which adds "module.") or torch.compile (which adds "_orig_mod.").
def _strip_prefix_in_state_dict(sd: dict, prefix: str) -> dict:
    if not any(k.startswith(prefix) for k in sd.keys()):
        return sd
    return {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in sd.items()}

In [6]:
# ensures each DataLoader worker gets a unique, deterministic random seed
# DataLoader calls this once inside each worker process (pass it as worker_init_fn=...);
# it only matters when num_workers > 0
def worker_init_fn(worker_id: int):
    seed = (torch.initial_seed() + worker_id) % (2**31 - 1)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

#### Weight decay
<img src="./weight_decay.png" width="500"/>

that's an example graph for weight decay. <br/>
concept: at every update, weight decay shrinks each weight a little toward zero (by an amount proportional to the weight itself), so weights can't grow without limit.

- without weight decay (blue)
    - keeps moving strictly down -> nothing holds the params back, so they can grow large -> more prone to overfitting and unstable learning
- with weight decay (orange)
    - moves down too, but at some point it moves down more slowly (never hitting the horizontal asymptote)
        -> params get pulled from large back toward small -> helps prevent overfitting and unstable learning
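A minimal sketch of the weight-decay term on its own, assuming plain Python floats and made-up values for the starting weight, `lr` and `wd` (`decay_only` and these names are hypothetical, not from the notebook). It uses the decoupled form AdamW applies, `w <- w - lr * wd * w`, and leaves the loss gradient out on purpose, so the only thing moving the weight is the decay: it shrinks by the same factor `(1 - lr * wd)` on every update.

```python
def decay_only(w0: float, lr: float, wd: float, steps: int) -> list:
    """Apply only the decoupled weight-decay term w <- w - lr * wd * w for `steps` updates.

    The gradient term is left out on purpose to isolate what weight decay does.
    """
    history = [w0]
    w = w0
    for _ in range(steps):
        w = w - lr * wd * w  # same as w *= (1 - lr * wd)
        history.append(w)
    return history


# hypothetical values, chosen only to make the shrinkage easy to see
ws = decay_only(w0=2.0, lr=0.1, wd=0.5, steps=5)
for step, w in enumerate(ws):
    print(f"step {step}: w = {w:.4f}")
```

With this notebook's `BASE_LR = 5e-4` and `WEIGHT_DECAY = 0.10`, the factor at peak learning rate is `1 - 5e-5`: a gentle pull on each update that adds up over thousands of updates.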

In [7]:
# call this before building the optimizer: it splits params into decay / no-decay groups
def param_groups_weight_decay(module: nn.Module, weight_decay: float):
    decay = [] # params that should have weight decay applied
    no_decay = [] # params that should not have weight decay applied

    for n, p in module.named_parameters():
        if not p.requires_grad:
            continue
        is_bias = n.endswith("bias")
        is_norm = ("norm" in n.lower()) or ("ln" in n.lower())

        # biases, norms, and 1D tensors -> no decay
        (no_decay if (is_bias or is_norm or p.ndim <= 1) else decay).append(p)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]

#### Cosine warmup
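let's focus on the warmup phase (the vertical dotted lines mark where it ends) and on what happens right after it <br/>
(strictly speaking, the warmup itself is linear; the "cosine" part is the decay that follows it) <br/>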
##### Without Cosine warmup
<img src="./without_cosine_warmup.png" width="500"><br/>
- after warmup, the learning rate stays flat at the maximum value
- the updates never get smaller, which makes overfitting more likely
- this may cause unstable learning (the model keeps bouncing around instead of settling)

##### With Cosine warmup
<img src="./cosine_warmup.png" width="500"/><br/>
- that nice down-curve makes the learning rate drop smoothly
- updates get smaller over time, letting the model settle instead of bouncing around
- this makes overfitting and unstable learning less likely
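To see the shape without a GPU, here is a minimal sketch that re-implements the multiplier computed by `lr_lambda` inside `build_cosine_with_warmup` (next cell) in plain Python: linear warmup, then cosine decay down to `min_lr_ratio`. `lr_multiplier` is a hypothetical standalone helper, and the numbers are this notebook's section-2 values (`BASE_LR = 5e-4`, 20 warmup steps, 200 total updates).

```python
import math


def lr_multiplier(step: int, warmup_steps: int, total_steps: int, min_lr_ratio: float = 0.10) -> float:
    """Linear warmup up to 1.0, then cosine decay from 1.0 down to min_lr_ratio."""
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    t = min(max(t, 0.0), 1.0)
    return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (1 + math.cos(math.pi * t))


base_lr, warmup_steps, total_updates = 5e-4, 20, 200
for step in (0, 10, 19, 20, 110, 200):
    lr = base_lr * lr_multiplier(step, warmup_steps, total_updates)
    print(f"step {step:>3}: lr = {lr:.2e}")
```

The multiplier reaches 1.0 on the last warmup step, is still 1.0 on the first post-warmup step, is halfway (0.55) at the midpoint of the decay, and ends at `min_lr_ratio` (0.1 × `BASE_LR`) at the final update.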

In [8]:
def build_cosine_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.10):
    # lr_lambda returns a multiplier on BASE_LR: linear warmup first, then that nice cosine curve
    def lr_lambda(step):
        # linear part (from 0 to warmup)
        if step < warmup_steps:
            return (step + 1) / max(1, warmup_steps)
        
        # normalize step to [0, 1] after warmup
        t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        t = min(max(t, 0.0), 1.0)

        # apply cosine curve using that "t"
        return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (1 + math.cos(math.pi * t))
    
    # LambdaLR multiplies each param group's initial lr by lr_lambda(step)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

#### Exponential Moving Average
<img src="./ema.png" width="500"/><br/>
so, what it's doing here is:
1. take the fluctuating values that are moving downward (e.g. the training loss; in this notebook we also smooth tokens/sec the same way)
2. smooth out the noise
3. end up with a nice curve that's heading downward
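A minimal sketch of the same smoothing rule `EMAMeter` (next cell) uses, `ema = beta * ema + (1 - beta) * x`, seeded with the first value. `ema_series` is a hypothetical helper, and the noisy "loss" numbers are made up purely for illustration.

```python
def ema_series(values, beta=0.9):
    """Return the exponential moving average of `values`, seeded with the first value."""
    smoothed = []
    ema = None
    for x in values:
        ema = x if ema is None else beta * ema + (1 - beta) * x
        smoothed.append(ema)
    return smoothed


# made-up noisy, downward-trending "loss" values
raw = [5.0, 4.2, 4.9, 3.8, 4.4, 3.5, 4.0, 3.2, 3.7, 3.0]
smooth = ema_series(raw, beta=0.9)
for r, s in zip(raw, smooth):
    print(f"raw={r:.2f}  ema={s:.3f}")
```

A larger `beta` (0.9, like `loss_ema` in the train loop) gives a smoother but laggier curve; `tps_ema` uses 0.8, so it reacts faster to throughput changes.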

In [9]:
class EMAMeter:
    def __init__(self, beta=0.9): 
        self.beta = beta # smoothing factor (closer to 1 = smoother; but slower updates)
        self.val = 0.0 # current EMA value
        self.inited = False # flag to check if first value has been set
        
    def update(self, x):
        # x is the new raw value (loss, tokens/sec, etc.)
        x = float(x)

        # if first update, just take x as initial value
        if not self.inited:
            self.val, self.inited = x, True
        # if not, new_value = smoothing_factor * old_value + (1 - smoothing_factor) * x
        else:
            self.val = self.beta*self.val + (1-self.beta)*x

        return self.val

#### Top-K
idea: 
- track the top-K best results (e.g. lowest validation loss) using a heap
- later on, this will be used when we reload checkpoints, fine-tune, etc.
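A stripped-down sketch of the heap trick `TopK` (next cell) relies on: Python's `heapq` is a min-heap, so pushing `-loss` keeps the *worst* kept loss at `heap[0]`, where it can be compared and replaced in O(log k). `keep_top_k` is a hypothetical helper, and the losses, update numbers and file names below are made up.

```python
import heapq


def keep_top_k(results, k=3):
    """Keep the k lowest-loss (loss, update, path) results and return them sorted by loss."""
    heap = []  # entries are (-loss, update, path); heap[0] holds the highest kept loss
    for loss, update, path in results:
        item = (-loss, update, path)
        if len(heap) < k:
            heapq.heappush(heap, item)
        elif item > heap[0]:  # bigger -loss => smaller loss => better than the worst kept
            heapq.heapreplace(heap, item)
    return sorted(((-neg, u, p) for neg, u, p in heap), key=lambda r: r[0])


# hypothetical validation results: (val_loss, update, checkpoint path)
results = [
    (3.10, 50, "u0000050.pt"),
    (2.80, 100, "u0000100.pt"),
    (2.95, 150, "u0000150.pt"),
    (2.60, 200, "u0000200.pt"),
    (2.70, 250, "u0000250.pt"),
]
best = keep_top_k(results, k=3)
print(best)
```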

In [10]:
class TopK:
    def __init__(self, k=5):
        self.k = k
        self.heap = []  # min-heap on -loss, so heap[0] is the worst (highest-loss) kept entry

    # adding a new result into the heap
    # args:
    #   loss (float): validation loss
    #   update (int): training step number
    #   path (str): checkpoint file path
    def add(self, loss, update, path):
        import heapq
        item = (-float(loss), int(update), str(path))

        # if heap is not full, push new item
        if len(self.heap) < self.k:
            heapq.heappush(self.heap, item)
        
        # if heap is full, replace the worst
        else:
            if item > self.heap[0]: # better than current worst?
                heapq.heapreplace(self.heap, item) # yeah, replace the worst item with the current item
    
    # return top-K items sorted by ascending loss
    def best(self):
        return sorted([(-l, u, p) for (l, u, p) in self.heap], key=lambda x: x[0])

#### ByteTensor, Tensor, Byte, Sequences to tensor
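some Q&A
- why cpu instead of GPU?
    - these byte tensors aren't for computation; they're raw RNG state bytes, and `torch.random.set_rng_state` / `torch.cuda.set_rng_state` expect that state as a CPU `torch.ByteTensor`
- why uint8?
    - because `torch.random.get_rng_state()` / `torch.cuda.get_rng_state_all()` return RNG states as `torch.ByteTensor`s (uint8), and the setters expect the same type back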

In [11]:
def _to_byte_tensor(x):

    # just in case x is nothing (does NOT mean undefined)
    if x is None: 
        return None
    
    # if type of x is ByteTensor (8-bit unsigned)
    if isinstance(x, torch.ByteTensor): 
        return x.cpu() # just put that into cpu
    
    # if type of x is any other kind of Tensor
    if isinstance(x, torch.Tensor):     
        # detach from the graph -> cast to uint8 -> move to cpu
        return x.detach().to(dtype=torch.uint8, device="cpu")
    
    # if type of x is raw bytes or bytearray
    if isinstance(x, (bytes, bytearray)): 
        # convert to list of integers -> cast to uint8
        return torch.tensor(list(x), dtype=torch.uint8)
    
    # if type of x is list, tuple, ndarray
    if isinstance(x, (list, tuple, np.ndarray)): 
        # just cast to uint8
        return torch.tensor(x, dtype=torch.uint8)
    return None

# 4. Logs, visualization, directories, etc.

In [12]:
import shutil
from pathlib import Path

def make_run_dirs(base):
    # later we'll pass something like "../pretrain_results_test" here for smoke testing (in this notebook it's RUN_DIR)
    base = Path(base)

    # inside base path,
    ckpt = base / "checkpoints" # create checkpoints folder
    plots = base / "plots" # create plots folder
    logs = base / "logs" # create logs folder

    # just in case if we're rerunning the whole thing 
    # (with different config)
    # clear only the three subfolders
    for d in (ckpt, plots, logs):
        if d.exists() and d.is_dir():
            shutil.rmtree(d)

    # recreate fresh
    for d in (ckpt, plots, logs):
        d.mkdir(parents=True, exist_ok=True)

    return base, ckpt, plots, logs

In [13]:
# about the manifest:
# manifest.json lives in the run's base dir and tracks run-level progress
# (chunks_completed, last_update, best_val) plus the model cfg
def manifest_load_or_init(pretrain_dir: Path):
    # define path
    path = pretrain_dir / "manifest.json"

    # if there's no manifest, create one
    if not path.exists():
        manifest = {
            "created_utc": dt.datetime.now(dt.timezone.utc).isoformat().replace("+00:00", "Z"),
            "chunks_completed": 0,
            "last_update": 0,
            "best_val": None,
            "cfg": cfg,
        }
        path.write_text(json.dumps(manifest, indent=2))
    
    # if we have one, load that since we're gonna be updating this (chunks_completed, last_update, best_val)
    else:
        manifest = json.loads(path.read_text())
    return manifest, path

In [14]:
# creates history.csv with its header row once per training run (NOT per chunk)
def history_init(csv_path: Path):
    if not csv_path.exists():
        with open(csv_path, "w", newline="") as f:
            csv.writer(f).writerow(
                ["update","split","loss","ppl","lr","tps","gnorm","scale","tokens_seen","utc"]
            )

#### The 2 functions below log the rows used to plot training & validation loss

In [15]:
# appends one training row to history.csv (called every update)
def log_train_row(csv_path, update, loss, lr, tps, gnorm, scale, tokens_seen):
    with open(csv_path, "a", newline="") as f:
        csv.writer(f).writerow([
            int(update),"train",float(loss),"",float(lr),float(tps),
            float(gnorm),float(scale),int(tokens_seen),dt.datetime.now(dt.timezone.utc).isoformat().replace("+00:00", "Z")
        ])

In [16]:
# appends one validation-result row to history.csv
def log_val_row(csv_path, update, loss, ppl, tokens_seen):
    with open(csv_path, "a", newline="") as f:
        csv.writer(f).writerow([
            int(update),"val",float(loss),float(ppl),"","","","",
            int(tokens_seen),dt.datetime.now(dt.timezone.utc).isoformat().replace("+00:00", "Z")
        ])

#### Generating graph for training loss & validation loss

In [17]:
# this is optional
# i'm using this to visualize results later on
def save_plot(history_csv: Path, plots_dir: Path, eff_tokens_per_update: int, save_tag="latest"):
    try:
        import pandas as pd, matplotlib.pyplot as plt
        if not history_csv.exists(): return
        df = pd.read_csv(history_csv)
        if df.empty: return

        df["tokens_seen"] = df["update"] * eff_tokens_per_update
        tr = df[df["split"]=="train"].copy()
        vl = df[df["split"]=="val"].copy()

        fig, ax = plt.subplots(figsize=(7,3.2), dpi=120)
        ax.plot(tr["update"], tr["loss"], label="Training loss")
        if not vl.empty:
            ax.plot(vl["update"], vl["loss"], linestyle="--", label="Validation loss")
        ax.set_xlabel("Updates"); ax.set_ylabel("Loss"); ax.legend(loc="best")

        ax2 = ax.twiny()
        xticks = ax.get_xticks()
        ax2.set_xticks(xticks)
        ax2.set_xlim(ax.get_xlim())  # set_xticks can widen the limits, so re-sync after it
        ax2.set_xticklabels([f"{int(x*eff_tokens_per_update/1000):,}k" for x in xticks])
        ax2.set_xlabel("Tokens seen")

        out = plots_dir / f"loss_curve_{save_tag}.png"
        fig.tight_layout()
        fig.savefig(out)
        plt.close(fig)
        print(f"🖼️ saved plot → {out}")

    # if plotting fails (e.g. pandas/matplotlib not installed), just warn and move on; the graph is optional
    except Exception as e:
        print(f"[warn] plot failed: {e}")


# 5. Checkpoint
here's the thing<br/>
the actual run takes around 6 days WITHOUT PAUSING<br/>
and i don't really want to fry my computer by running it non-stop<br/>
so i decided to save checkpoints<br/>
that way, when we stop and resume (or crash), training picks up from the latest checkpoint instead of starting over
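The resume idea boils down to: periodically write the training state to disk, and at startup continue from the last saved step if one exists. Below is a minimal stdlib-only sketch of that loop, assuming a toy JSON "state" instead of real model/optimizer tensors (`save_state`, `load_state` and `train` are hypothetical names). It also writes to a temporary file and swaps it in with `os.replace`, a common pattern that `save_checkpoint` below does *not* use, so that a crash mid-write can't leave a half-written latest file behind.

```python
import json
import os
import tempfile
from pathlib import Path


def save_state(path: Path, state: dict) -> None:
    """Write `state` to a temp file next to `path`, then atomically swap it in."""
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp, path)  # atomic rename when both paths are on the same filesystem


def load_state(path: Path) -> dict:
    """Return the last saved state, or a fresh one if nothing was saved yet."""
    if not path.exists():
        return {"update": 0, "loss": None}
    return json.loads(path.read_text())


def train(path: Path, total_updates: int, save_every: int, crash_at=None) -> dict:
    """Toy training loop that resumes from `path` and saves every `save_every` updates."""
    state = load_state(path)
    for update in range(state["update"] + 1, total_updates + 1):
        if crash_at is not None and update == crash_at:
            raise RuntimeError(f"simulated crash at update {update}")
        state = {"update": update, "loss": 1.0 / update}  # stand-in for a real train step
        if update % save_every == 0:
            save_state(path, state)
    return state


with tempfile.TemporaryDirectory() as d:
    ckpt = Path(d) / "latest.json"
    try:
        train(ckpt, total_updates=10, save_every=3, crash_at=8)
    except RuntimeError as e:
        print(e)  # simulated crash at update 8
    resumed_from = load_state(ckpt)["update"]
    print("resuming from update", resumed_from)
    final = train(ckpt, total_updates=10, save_every=3)
    print("finished at update", final["update"])
```

Anything after the last save (here update 7) is simply redone after resuming, which is why checkpoint frequency trades disk/time cost against how much work a crash throws away.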

In [18]:
def save_checkpoint(ckpt_dir: Path, manifest: dict, manifest_path: Path,
                    model, optimizer, scheduler, scaler, update, val_loss=None, tag="latest"):
    
    base_model = _unwrap_model_for_state_dict(model)

    if tag == "latest":
        path = ckpt_dir / "latest.pt"
    elif tag == "best":
        path = ckpt_dir / "best.pt"
    else:
        path = ckpt_dir / f"{tag}.pt"

    payload = {
        "update": int(update),
        "val_loss": (None if val_loss is None else float(val_loss)),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict() if (scaler is not None and getattr(scaler, "is_enabled", lambda: False)()) else None,
        "cfg": cfg,
        "rng": {
            "torch_cpu": torch.random.get_rng_state(),
            "torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        },
        "timestamp": time.time(),
        "model": base_model.state_dict()
    }

    # you'll see several uXXXXXXX.pt files in the checkpoints folder
    # this is what generates those .pt files
    torch.save(payload, path)

    manifest["last_update"] = int(update)
    if val_loss is not None:
        m_best = manifest.get("best_val")
        if (m_best is None) or (val_loss < m_best):
            manifest["best_val"] = float(val_loss)
    manifest_path.write_text(json.dumps(manifest, indent=2))
    return str(path)

In [19]:
def load_checkpoint(path: Path, model, optimizer, scheduler, scaler, map_location):
    # no latest.pt yet?
    # ⭐ fresh start ⭐ 
    # (without checkpoints, a crash or OOM mid-run means running everything all over again...)
    if not path.exists():
        print("no checkpoint found, starting fresh")
        return 1, None
    
    # load that checkpoint.pt
    ckpt = torch.load(path, map_location=map_location, weights_only=False)

    # restore model / optimizer / scheduler / AMP scaler / RNG states
    raw_sd = ckpt["model"]
    raw_sd = _strip_prefix_in_state_dict(raw_sd, "_orig_mod.")
    raw_sd = _strip_prefix_in_state_dict(raw_sd, "module.")
    model.load_state_dict(raw_sd)
    
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and ckpt.get("scaler") is not None:
        try: scaler.load_state_dict(ckpt["scaler"])
        except Exception as e: print(f"[warn] AMP scaler restore failed: {e}")

    rng = ckpt.get("rng", {})
    try:
        cpu_state = _to_byte_tensor(rng.get("torch_cpu"))
        if cpu_state is not None: torch.random.set_rng_state(cpu_state)
    except Exception as e:
        print(f"[warn] CPU RNG restore skipped: {e}")

    try:
        cuda_states = rng.get("torch_cuda")
        if torch.cuda.is_available() and cuda_states is not None:
            if isinstance(cuda_states, (list, tuple)):
                cuda_states = [_to_byte_tensor(s) for s in cuda_states]
                torch.cuda.set_rng_state_all(cuda_states)
            else:
                torch.cuda.set_rng_state(_to_byte_tensor(cuda_states))
    except Exception as e:
        print(f"[warn] CUDA RNG restore skipped: {e}")

    start = int(ckpt.get("update", 0)) + 1
    best  = ckpt.get("val_loss", None)
    print(f"🔄 resumed from {path} @ update {start-1} (best_val={best})")
    return start, best


# 6. Evaluation

In [20]:
import torch
import torch.nn.functional as F

# masked value
IGNORE_INDEX = -100

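# THE GREEDIEST METHOD
# always return a loss (never None): try model(xb, labels=yb), then model(xb, yb);
# if the model hands back no loss, compute masked cross-entropy from the logits ourselves
# I never did this in train.py (i got lazy about fixing that bug when testing the pretraining cell)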
def forward_with_guaranteed_loss(model, xb, yb):
    try:
        out = model(xb, labels=yb)
    except TypeError:
        out = model(xb, yb)

    if isinstance(out, (tuple, list)):
        logits, maybe_loss = out[0], (out[1] if len(out) > 1 else None)
    elif hasattr(out, "logits") or hasattr(out, "loss"):
        logits = getattr(out, "logits", out)
        maybe_loss = getattr(out, "loss", None)
    else:
        logits, maybe_loss = out, None

    if maybe_loss is None:
        V = logits.shape[-1]
        loss = F.cross_entropy(
            logits.reshape(-1, V),
            yb.reshape(-1),
            ignore_index=IGNORE_INDEX,
            reduction="mean"
        )
    else:
        loss = maybe_loss

    return logits, loss


In [21]:
from contextlib import nullcontext

@torch.no_grad() # disables gradient tracking
def evaluate(model, loader, device, amp_dtype=None, max_batches=None):
    model_was_training = model.training
    model.eval()

    # amp/autocast context
    if amp_dtype in ("bf16", "bfloat16", torch.bfloat16):
        autocast_ctx = torch.autocast(device_type=device.type, dtype=torch.bfloat16)
    elif amp_dtype in ("fp16", "float16", torch.float16):
        autocast_ctx = torch.autocast(device_type=device.type, dtype=torch.float16)
    else:
        autocast_ctx = nullcontext()

    total_loss_weighted = 0.0  # sum of (loss * valid_tokens)
    total_valid_tokens  = 0

    with torch.no_grad():
        with autocast_ctx:
            for b_idx, batch in enumerate(loader):
                if isinstance(batch, (tuple, list)):
                    xb, yb = batch[0], batch[1]
                elif isinstance(batch, dict):
                    xb, yb = batch["input_ids"], batch.get("labels", batch.get("targets"))
                else:
                    raise TypeError(f"Unexpected batch type: {type(batch)}")

                xb = xb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True).long()

                _, loss = forward_with_guaranteed_loss(model, xb, yb)

                # count only valid tokens
                valid = (yb != IGNORE_INDEX).sum().item()
                if valid == 0:
                    # skip completely masked batches
                    continue

                total_loss_weighted += loss.item() * valid
                total_valid_tokens  += valid

                if max_batches is not None and (b_idx + 1) >= max_batches:
                    break

    # Avoid div by zero
    if total_valid_tokens == 0:
        mean_loss = float("nan")
        ppl = float("inf")
    else:
        mean_loss = total_loss_weighted / total_valid_tokens
        ppl = float("inf") if not (mean_loss == mean_loss) else float(torch.exp(torch.tensor(mean_loss)).item())

    if model_was_training:
        model.train()

    return mean_loss, ppl
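`evaluate()` weights each batch's mean loss by that batch's number of non-masked tokens before averaging, then reports perplexity as `exp(mean loss)`. A minimal pure-Python sketch of just that arithmetic, with made-up batch losses and token counts (`token_weighted_eval` is a hypothetical helper):

```python
import math


def token_weighted_eval(batch_losses, valid_tokens):
    """Combine per-batch mean losses into one per-token mean loss and its perplexity."""
    total_loss = sum(loss * n for loss, n in zip(batch_losses, valid_tokens) if n > 0)
    total_tokens = sum(n for n in valid_tokens if n > 0)
    if total_tokens == 0:
        return float("nan"), float("inf")
    mean_loss = total_loss / total_tokens
    return mean_loss, math.exp(mean_loss)


# hypothetical: the second batch has 3x more unmasked tokens than the first
mean_loss, ppl = token_weighted_eval([2.0, 3.0], [100, 300])
naive = sum([2.0, 3.0]) / 2
print(f"token-weighted loss = {mean_loss:.2f}, naive batch mean = {naive:.2f}, ppl = {ppl:.2f}")
```

The naive average of batch means would under-count the batch with more real tokens; weighting by valid tokens gives the same answer as averaging over every token directly.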


# 7. Model & dataset wiring

In [22]:
# grabbing the model and dataset
# they live in a different directory (../python_files),
# so we add that to sys.path first
import sys
sys.path.append("../python_files")
from npy_datasets import NpyTokensDataset
from Dummy_Model import DummyModel

if isfull:
    TRAIN_BLOCKS = "../materials/train_blocks.npy"
    VALID_BLOCKS = "../materials/valid_blocks.npy"
else:
    TRAIN_BLOCKS = "../materials_small/train_blocks.npy"
    VALID_BLOCKS = "../materials_small/valid_blocks.npy"

# datasets
train_ds = NpyTokensDataset(TRAIN_BLOCKS, cfg["context_length"])
valid_ds = NpyTokensDataset(VALID_BLOCKS, cfg["context_length"])

base_kwargs = dict(
    batch_size=MICRO_BSZ,
    num_workers=0,
    drop_last=True,
    pin_memory=True,
    shuffle=True,
)

valid_kwargs = dict(base_kwargs)
valid_kwargs["shuffle"] = False

# tell the DataLoader to pin batches for the GPU, but only when CUDA is actually available
if torch.cuda.is_available():
    base_kwargs["pin_memory_device"] = "cuda"
    valid_kwargs["pin_memory_device"] = "cuda"

# loader
train_loader = DataLoader(train_ds, **base_kwargs)
valid_loader = DataLoader(valid_ds, **valid_kwargs)

# model
# if you ran every cell from the start, cfg and device are already defined
model = DummyModel(cfg).to(device)
print("params:", sum(p.numel() for p in model.parameters()))

# check batch sanity
xb, yb = next(iter(train_loader))
assert xb.dtype == torch.int64 and yb.dtype == torch.int64
assert xb.shape[1] == cfg["context_length"] and yb.shape == xb.shape
print("batch OK:", xb.shape, yb.shape)


using sdpa_kernel
params: 18552320
batch OK: torch.Size([8, 128]) torch.Size([8, 128])


# 8. Optimizer (AdamW)
list of torch.optim:
- torch.optim.AdamW
- torch.optim.Adadelta
- torch.optim.Adafactor
- torch.optim.Adagrad
- torch.optim.Adam
- torch.optim.Adamax
- torch.optim.ASGD
- torch.optim.LBFGS
- torch.optim.NAdam
- torch.optim.Optimizer (the base class the others inherit from, not an optimizer itself)
- torch.optim.RAdam
- torch.optim.SGD
- torch.optim.RMSprop
- torch.optim.Rprop
- torch.optim.SparseAdam

(Adam is short for Adaptive Moment Estimation. Not a person's name)

##### TOP 2 Optimizers often used (besides AdamW):
- Adam (this was later improved into AdamW, which is what we're going to use)
    - How it works: adaptive learning rates per parameter (tracks the 1st and 2nd moments of the gradients)
    - Pros:
        - it's FAST
        - minimal hyperparameter tuning needed
    - Cons:
        - may overfit / generalize worse (L2 weight decay gets tangled up with the adaptive update)
        - memory hungry (since it keeps 2 extra states per parameter)
    - Example: GPT-1
- LAMB
    - How it works: builds on Adam(W), but scales the learning rate per layer by the ratio of parameter norm to update norm
    - Pros:
        - enables very large batch training without divergence
        - designed for scaling BERT/GPT pretraining across hundreds of accelerators
    - Cons:
        - more complex and slower per step than AdamW
        - gains are mostly relevant only at extreme batch sizes
    - Example: large-batch BERT pretraining (the original BERT itself used Adam)
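A minimal sketch of LAMB's per-layer trust ratio, assuming each layer is a plain list of floats and that `adam_update` is already the Adam(W)-style update direction for that layer (the real optimizer computes it from the moment estimates and often clips the ratio; both are left out here). `l2_norm` and `lamb_layer_step` are hypothetical helpers.

```python
import math


def l2_norm(xs):
    return math.sqrt(sum(x * x for x in xs))


def lamb_layer_step(weights, adam_update, lr):
    """Scale one layer's update by the LAMB trust ratio ||w|| / ||update||."""
    w_norm = l2_norm(weights)
    u_norm = l2_norm(adam_update)
    # fall back to a ratio of 1 when either norm is zero (a common convention)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    new_weights = [w - lr * trust_ratio * u for w, u in zip(weights, adam_update)]
    return new_weights, trust_ratio


# hypothetical layers: identical raw updates, very different weight scales
big_layer, r_big = lamb_layer_step([3.0, 4.0], [0.3, 0.4], lr=0.01)        # ||w|| = 5
small_layer, r_small = lamb_layer_step([0.03, 0.04], [0.3, 0.4], lr=0.01)  # ||w|| = 0.05
print(f"trust ratios: big layer {r_big:.2f}, small layer {r_small:.2f}")
```

Either way, each layer moves by `lr * ||w||` in norm, i.e. the same *relative* change `lr`, which keeps large batch sizes (and the large learning rates that come with them) from blowing up layers with small weights.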

##### AdamW (we're going to use this guy)
<img src="./AdamW.png" width="500"/><br/>
- How it works: in the big picture it's the same as Adam, but it decouples weight decay from the gradient update <br/>
    (there's more to it, but i'm trimming it here)
- Pros:
    - better generalization than Adam (yup. that decoupled weight decay)
    - stable convergence with large models
- Cons:
    - memory cost, since it still keeps the 1st and 2nd moment states (like Adam)
- Example: GPT-3 (Adam with decoupled weight decay), Chinchilla (DeepMind), LLaMA (Meta)
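A minimal sketch of what "decoupled" means, on a single scalar weight with plain Python floats (`adam_like_step` and `run` are hypothetical helpers, and the hyperparameters are made up). The loss gradient is fixed at zero so only the weight-decay mechanism moves the weight. With Adam + L2, the decay term is added to the gradient and then divided by `sqrt(v_hat)` like any other gradient, so its strength no longer matches `weight_decay`; with AdamW, the weight shrinks by exactly `(1 - lr * weight_decay)` per step.

```python
import math


def adam_like_step(w, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                   weight_decay=0.0, decoupled=True):
    """One Adam step on a single scalar weight; returns (w, m, v).

    decoupled=False -> Adam + L2: the decay term is added to the gradient,
                       so it also gets divided by sqrt(v_hat) like any gradient.
    decoupled=True  -> AdamW: the weight is shrunk directly by lr * weight_decay * w.
    """
    b1, b2 = betas
    if not decoupled:
        grad = grad + weight_decay * w        # L2 penalty folded into the gradient
    m = b1 * m + (1 - b1) * grad              # 1st moment: running mean of gradients
    v = b2 * v + (1 - b2) * grad * grad       # 2nd moment: running mean of squared gradients
    m_hat = m / (1 - b1 ** t)                 # bias correction for the zero init
    v_hat = v / (1 - b2 ** t)
    if decoupled:
        w = w - lr * weight_decay * w         # decoupled weight decay
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


def run(decoupled, steps=100, lr=0.01, weight_decay=0.1):
    w, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        # loss gradient is zero on purpose: only the weight-decay mechanism moves w
        w, m, v = adam_like_step(w, 0.0, m, v, t, lr=lr,
                                 weight_decay=weight_decay, decoupled=decoupled)
    return w


w_adam_l2 = run(decoupled=False)
w_adamw = run(decoupled=True)
print(f"Adam + L2: w = {w_adam_l2:.4f}")
print(f"AdamW    : w = {w_adamw:.4f}  (0.999 ** 100 = {0.999 ** 100:.4f})")
```

In PyTorch, `torch.optim.Adam(..., weight_decay=...)` applies the decay the first (L2) way by default, while `torch.optim.AdamW` applies it the second (decoupled) way.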

In [23]:
# remember the weight decay groups & cosine warmup we built earlier?
# yup. the weight decay groups are used here (the cosine schedule comes in the next cell)

# weight decay
opt_groups = param_groups_weight_decay(model, WEIGHT_DECAY)

# if torch supports fused, use that
try:
    # fused=True makes computation faster
    optimizer = torch.optim.AdamW(opt_groups, lr=BASE_LR, betas=(0.9,0.95), eps=1e-8, fused=True)

# if not (old torch, or params on a device fused doesn't support), use the default implementation
# computation is slower, but results match up to floating-point rounding
except (TypeError, RuntimeError):
    optimizer = torch.optim.AdamW(opt_groups, lr=BASE_LR, betas=(0.9,0.95), eps=1e-8)



##### (TL;DR)
<img src="./AdamW_actual.png" width="300"/><br/>
the actual graph would look something like this

# 9. Scheduler and AMP

In [24]:
# linear warmup + cosine decay schedule (build_cosine_with_warmup from section 3)
scheduler = build_cosine_with_warmup(
    optimizer, warmup_steps=WARMUP_STEPS, total_steps=TOTAL_UPDATES, min_lr_ratio=0.10
)

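# pick the autocast dtype from AMP_MODE; only fp16 needs an enabled GradScaler
# (bf16 keeps fp32's exponent range, so small gradients don't underflow the same way)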
if AMP_MODE == "fp16":
    amp_dtype = torch.float16
    scaler = GradScaler(enabled=True)
elif AMP_MODE == "bf16":
    amp_dtype = torch.bfloat16
    scaler = GradScaler(enabled=False)
else:
    amp_dtype = None
    scaler = GradScaler(enabled=False)
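Why does only `fp16` get an enabled `GradScaler`? fp16 has a 5-bit exponent, so very small gradient values underflow to zero, while bf16 keeps fp32's 8-bit exponent range. A minimal sketch of the loss-scaling idea using Python's `struct` half-precision format (`'e'`), assuming a made-up gradient of `1e-8` and a scale of `2**16` (which happens to be `GradScaler`'s default `init_scale`; the real scaler keeps adjusting it). `to_fp16` is a hypothetical helper.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8          # hypothetical tiny gradient value
scale = 2.0 ** 16    # hypothetical loss scale

unscaled = to_fp16(grad)          # stored directly in fp16
scaled = to_fp16(grad * scale)    # scale first, then store in fp16
recovered = scaled / scale        # unscale in higher precision before the optimizer step

print(unscaled)   # 0.0 (underflow: smaller than half of fp16's tiniest subnormal)
print(recovered)  # close to 1e-8 again
```

That is roughly what `scaler.scale(loss).backward()` followed by `scaler.unscale_(optimizer)` does in the train loop, with the extra step of skipping updates whose scaled gradients overflow.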

# 10. Define constants for checkpoints
note that if you run this, it will wipe the checkpoints/, plots/ and logs/ subfolders inside RUN_DIR (`../pretrain_checkpoint_notebook/`)

In [25]:

# --- Run dirs & manifest & history ---
PRETRAIN_DIR, CKPT_DIR, PLOTS_DIR, LOGS_DIR = make_run_dirs(RUN_DIR)
manifest, MANIFEST_PATH = manifest_load_or_init(PRETRAIN_DIR)
HISTORY_CSV = LOGS_DIR / "history.csv"
history_init(HISTORY_CSV)

LATEST = CKPT_DIR / "latest.pt"
BEST   = CKPT_DIR / "best.pt"


# 11. Loading checkpoints
once we've run everything down to the pre-test, we'll run this cell again<br/>
to verify that resuming works

In [26]:
start_update, best_val = load_checkpoint(LATEST, model, optimizer, scheduler, scaler, device)
if best_val is None: best_val = float("inf")

if start_update > 1:
    print(f"🔄 RESUMING TRAINING from update {start_update - 1} (best={best_val}) in {PRETRAIN_DIR}")
else:
    print("🚀 Starting fresh training run.")

no checkpoint found, starting fresh
🚀 Starting fresh training run.


##### **IF YOU'RE TESTING RESUME, RUN THE CODE ALL THE WAY DOWN AGAIN**

# 12. Train helpers & alerts
<img src="./periodic_register_val.png" width="800"/><br/>
it's basically this part (plus the blocks connected to it) of the TRAIN block in the diagram

In [27]:
topk = TopK(k=5)

def save_periodic(update, val_loss, every=PERIODIC_EVERY):
    if update % every == 0:
        tag = f"u{update:07d}" # for example, u0000100
        save_checkpoint(CKPT_DIR, manifest, MANIFEST_PATH, model, optimizer, scheduler, scaler, update, val_loss, tag)
        print(f"💾 saved periodic checkpoint: {tag}")

def register_val_result(update, val_loss):
    # always save latest
    save_checkpoint(CKPT_DIR, manifest, MANIFEST_PATH, model, optimizer, scheduler, scaler, update, val_loss, "latest")

    # track and save best so far
    current_best = getattr(register_val_result, "_best", float("inf"))
    if val_loss < current_best:
        save_checkpoint(CKPT_DIR, manifest, MANIFEST_PATH, model, optimizer, scheduler, scaler, update, val_loss, "best")
        register_val_result._best = val_loss
        print(f"⭐ new best: {val_loss:.4f} @ upd {update}")
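    # top-K remembers the u{update}.pt path; note that this file only exists
    # if save_periodic() also wrote a checkpoint at this update
    tag = f"u{update:07d}"
    ckpt_path = CKPT_DIR / f"{tag}.pt"
    topk.add(val_loss, update, str(ckpt_path))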

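# alerts
# later, we'll be using these constants in the train loop
BASELINE_WINDOW       = 50    # updates into the run at which the baseline tps/loss are recorded
TOKENS_SEC_ALERT_DROP = 0.30  # flag an update if smoothed tps < (1 - 0.30) * baseline tps
LOSS_SPIKE_X          = 2.0   # flag an update if smoothed loss > 2.0 * baseline loss
CONSEC_FOR_ALERT      = 8     # alert once a condition is flagged this many updates in a row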


# 13. Train loop

In [28]:

from tqdm.auto import tqdm # for progress bar

def train_for_updates(num_updates, tag):
    global start_update, best_val
    end_update = start_update + num_updates - 1

    # smoothing meters (using Exponential Moving Average)
    tps_ema  = EMAMeter(0.8) # (closer to 1 = smoother; but slower updates)
    loss_ema = EMAMeter(0.9)

    # baselines used for alerting
    base_tps = None
    base_loss= None
    tps_bad  = 0 # consecutive bad throughput updates
    loss_bad = 0 # consecutive bad loss spikes

    # set model to training mode
    model.train()
    loader_iter = iter(train_loader)

    # progress bar
    pbar = tqdm(range(start_update, end_update + 1), desc=f"[{tag}]", unit="upd")
    
    for update in pbar:
        # forward & backward pass
        # reset gradients each step
        optimizer.zero_grad(set_to_none=True)

        # start timer
        upd_t0 = time.time()

        # accumulated loss (for logging)
        loss_accum = 0.0 # track accumulated loss across ACCUM steps

        # gradient accumulation
        for _ in range(ACCUM):
            try:
                xb, yb = next(loader_iter) # get batch
            except StopIteration:
                # restart the dataloader when the epoch is exhausted
                loader_iter = iter(train_loader); xb, yb = next(loader_iter)

            # move to device (cuda)
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).long()

            # forward pass
            # if we're using fp16 or bf16
            if amp_dtype is not None:
                with torch.amp.autocast(device_type=device.type, dtype=amp_dtype):
                    logits, loss = forward_with_guaranteed_loss(model, xb, yb)
            # if "none"
            else:
                logits, loss = forward_with_guaranteed_loss(model, xb, yb)

            # normalize loss for accumulation
            loss = loss / ACCUM

            # backward pass
            if scaler and scaler.is_enabled():
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # accumulate for logging (not for backprop)
            loss_accum += float(loss)

        # optimizer step
        # unscale gradients if AMP is on
        try:
            if scaler and scaler.is_enabled():
                scaler.unscale_(optimizer)
        except RuntimeError:
            pass # if grads already unscaled or there's an issue, just pass
        
        # gradient clipping (prevent exploding gradients)
        gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)


        try:
            if scaler and scaler.is_enabled():
                scaler.step(optimizer) # AMP step
            else:
                optimizer.step() # regular step
        except AssertionError:
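            # scaler.step() already skips optimizer.step() by itself when grads contain inf/NaN;
            # this only guards against GradScaler's internal assertion (e.g. "No inf checks were recorded for this optimizer")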
            if scaler and scaler.is_enabled(): scaler.update()
            scheduler.step()
            continue
        
        # update scale if AMP
        if scaler and scaler.is_enabled():
            scaler.update()

        # update learning rate scheduler
        scheduler.step()

        # logging & metrics
        dt_upd  = max(time.time() - upd_t0, 1e-6) # elapsed time
        tps     = EFFECTIVE_TOKENS_PER_UPDATE / dt_upd # throughput
        tps_sm  = tps_ema.update(tps) # smoothed tps
        loss_sm = loss_ema.update(loss_accum * ACCUM) # smoothed loss
        gpu_gb  = (torch.cuda.max_memory_allocated() / 1e9) if torch.cuda.is_available() else 0.0
        amp_sc  = float(getattr(scaler, "get_scale", lambda: 1.0)()) # AMP scale factor

        # track tokens processed
        tokens_seen = update * EFFECTIVE_TOKENS_PER_UPDATE

        # log training row to CSV
        log_train_row(HISTORY_CSV, update, loss_sm, optimizer.param_groups[0]['lr'], tps_sm,
                      float(gnorm), float(amp_sc), tokens_seen)

        # establish baseline after warmup window
        seen = update - start_update + 1
        if seen == BASELINE_WINDOW:
            base_tps  = tps_sm
            base_loss = loss_sm

        # detect throughput drop
        if base_tps is not None and base_tps > 0:
            tps_bad = tps_bad + 1 if tps_sm < (1.0 - TOKENS_SEC_ALERT_DROP) * base_tps else 0

        # detect loss spike
        if base_loss is not None and base_loss > 0:
            loss_bad = loss_bad + 1 if loss_sm > LOSS_SPIKE_X * base_loss else 0

        # trigger alert if bad condition sustained
        if (tps_bad >= CONSEC_FOR_ALERT) or (loss_bad >= CONSEC_FOR_ALERT):
            print(f"\n⚠️  ALERT @ upd {update}: "
                  f"{'tps drop' if tps_bad>=CONSEC_FOR_ALERT else ''}"
                  f"{' and ' if tps_bad>=CONSEC_FOR_ALERT and loss_bad>=CONSEC_FOR_ALERT else ''}"
                  f"{'loss spike' if loss_bad>=CONSEC_FOR_ALERT else ''}. "
                  f"tps≈{tps_sm:,.0f} (base≈{base_tps:,.0f}), loss≈{loss_sm:.4f} (base≈{base_loss:.4f})")
            
            # run validation eval
            val_loss, val_ppl = evaluate(model, valid_loader, device, amp_dtype, max_batches=15)
            print(f"   immediate eval: val_loss={val_loss:.4f}  ppl={val_ppl:.2f}")

            # save emergency checkpoint & plot
            save_checkpoint(CKPT_DIR, manifest, MANIFEST_PATH, model, optimizer, scheduler, scaler, update, val_loss,
                            tag=f"alert_u{update:07d}")
            save_plot(HISTORY_CSV, PLOTS_DIR, EFFECTIVE_TOKENS_PER_UPDATE, save_tag=f"alert_u{update:07d}")
            tps_bad = loss_bad = 0 # reset counters

        # periodic logging
        if update % LOG_EVERY == 0:
            pbar.set_postfix(
                loss=f"{loss_sm:.4f}",
                lr=f"{optimizer.param_groups[0]['lr']:.2e}",
                tps=f"{tps_sm:,.0f}",
                gpu=f"{gpu_gb:.2f}GB",
                gnorm=f"{float(gnorm):.2f}",
                scale=f"{amp_sc:.0f}",
            )
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats(device)
        
        # periodic evaluation
        if update % EVAL_EVERY == 0:
            val_loss, val_ppl = evaluate(model, valid_loader, device, amp_dtype, max_batches=15)
            print(f"\n[upd {update}] val_loss={val_loss:.4f}  ppl={val_ppl:.2f}")

            # log validation row to CSV
            log_val_row(HISTORY_CSV, update, val_loss, val_ppl, tokens_seen)

            # periodic checkpoint (not every step, but spaced)
            save_periodic(update, val_loss, every=PERIODIC_EVERY)
            register_val_result(update, val_loss)

            # update training/validation plot
            save_plot(HISTORY_CSV, PLOTS_DIR, EFFECTIVE_TOKENS_PER_UPDATE, save_tag=f"u{update:07d}")

        # advance the global resume pointer (next update to run)
        start_update = update + 1

    # final save (boundary)
    save_checkpoint(CKPT_DIR, manifest, MANIFEST_PATH, model, optimizer, scheduler, scaler,
                    start_update - 1, None, tag=f"u{start_update-1:07d}")
    
    # mark chunk completion in manifest
    if tag.startswith("chunk"):
        manifest["chunks_completed"] = int(manifest.get("chunks_completed", 0)) + 1
        MANIFEST_PATH.write_text(json.dumps(manifest, indent=2))

    print(f"✅ Finished {tag}. Reached update {start_update-1}.")

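The alert logic inside the loop works in two stages. First it freezes a baseline once `BASELINE_WINDOW` updates have been seen. After that it counts consecutive updates where smoothed throughput falls below `(1 - TOKENS_SEC_ALERT_DROP) * base` or smoothed loss rises above `LOSS_SPIKE_X * base`. This logic can be pulled out into a standalone helper, which makes it testable without a GPU. Below is a minimal sketch under that assumption. The class `AlertMonitor` and its method name are hypothetical and not part of this notebook. The default thresholds are placeholders, not the values used above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AlertMonitor:
    """Hypothetical standalone version of the loop's baseline + consecutive-alert logic."""
    baseline_window: int = 50      # placeholder for BASELINE_WINDOW
    tps_drop: float = 0.3          # placeholder for TOKENS_SEC_ALERT_DROP
    loss_spike_x: float = 1.5      # placeholder for LOSS_SPIKE_X
    consec_for_alert: int = 3      # placeholder for CONSEC_FOR_ALERT
    seen: int = 0
    base_tps: Optional[float] = None
    base_loss: Optional[float] = None
    tps_bad: int = 0
    loss_bad: int = 0

    def step(self, tps_sm: float, loss_sm: float) -> list[str]:
        """Feed one update's smoothed metrics; return alert reasons (empty list if none)."""
        self.seen += 1
        # freeze the baseline once the warmup window is complete
        if self.seen == self.baseline_window:
            self.base_tps, self.base_loss = tps_sm, loss_sm
        # counters grow on consecutive bad updates and reset to 0 on a healthy one
        if self.base_tps is not None and self.base_tps > 0:
            slow = tps_sm < (1.0 - self.tps_drop) * self.base_tps
            self.tps_bad = self.tps_bad + 1 if slow else 0
        if self.base_loss is not None and self.base_loss > 0:
            spiked = loss_sm > self.loss_spike_x * self.base_loss
            self.loss_bad = self.loss_bad + 1 if spiked else 0
        reasons = []
        if self.tps_bad >= self.consec_for_alert:
            reasons.append("tps drop")
        if self.loss_bad >= self.consec_for_alert:
            reasons.append("loss spike")
        if reasons:
            # like the loop above: reset both counters after an alert fires
            self.tps_bad = self.loss_bad = 0
        return reasons
```

Feed it the smoothed `tps_sm`/`loss_sm` once per update. A non-empty return marks the point where the loop runs its immediate eval, emergency checkpoint and plot.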
##### **RUN THIS IF YOU'RE TESTING RESUME**

In [29]:
TOTAL_UPDATES   = 300 # before it was 200

#### **ONLY FOR RESUME TEST ↑**

In [30]:

# plan driver (demo)
# how many updates for pilot phase
pilot_updates = max(1, int(PILOT_FRAC * TOTAL_UPDATES))

# how many updates per chunk (remaining updates split evenly into CHUNKS;
# integer division, so any remainder is not scheduled here)
chunk_updates = max(1, (TOTAL_UPDATES - pilot_updates) // CHUNKS)

# pilot phase
if start_update <= pilot_updates:
    # if we haven't finished pilot yet, calculate remaining work
    todo = pilot_updates - (start_update - 1)
    train_for_updates(todo, tag="pilot")
else:
    # skip if we're already done
    print("Pilot already completed (resume detected).")

# chunked training phase
for i in range(1, CHUNKS + 1):
    # target update at the end of this chunk
    target_end = pilot_updates + i * chunk_updates

    
    if start_update <= target_end:
        # still work left in this chunk
        todo = target_end - (start_update - 1)
        train_for_updates(todo, tag=f"chunk{i}")
    else:
        # skip this chunk if already completed
        print(f"chunk{i} already completed (resume detected).")

# final logs
print("🏁 Full run plan completed (demo schedule).")

print("Artifacts:")
print(" - History CSV:", HISTORY_CSV)
print(" - Plots dir:  ", PLOTS_DIR)
print(" - Ckpts dir:  ", CKPT_DIR)


[pilot]:   0%|          | 0/15 [00:00<?, ?upd/s]

✅ Finished pilot. Reached update 15.


[chunk1]:   0%|          | 0/142 [00:00<?, ?upd/s]


[upd 50] val_loss=8.0667  ppl=3186.71
⭐ new best: 8.0667 @ upd 50
🖼️ saved plot → ..\pretrain_checkpoint_notebook\plots\loss_curve_u0000050.png

[upd 100] val_loss=7.8978  ppl=2691.40
💾 saved periodic checkpoint: u0000100
⭐ new best: 7.8978 @ upd 100
🖼️ saved plot → ..\pretrain_checkpoint_notebook\plots\loss_curve_u0000100.png

[upd 150] val_loss=7.6175  ppl=2033.39
⭐ new best: 7.6175 @ upd 150
🖼️ saved plot → ..\pretrain_checkpoint_notebook\plots\loss_curve_u0000150.png
✅ Finished chunk1. Reached update 157.


[chunk2]:   0%|          | 0/142 [00:00<?, ?upd/s]


[upd 200] val_loss=7.6171  ppl=2032.59
💾 saved periodic checkpoint: u0000200
⭐ new best: 7.6171 @ upd 200
🖼️ saved plot → ..\pretrain_checkpoint_notebook\plots\loss_curve_u0000200.png

[upd 250] val_loss=7.5308  ppl=1864.62
⭐ new best: 7.5308 @ upd 250
🖼️ saved plot → ..\pretrain_checkpoint_notebook\plots\loss_curve_u0000250.png
✅ Finished chunk2. Reached update 299.
🏁 Full run plan completed (demo schedule).
Artifacts:
 - History CSV: ..\pretrain_checkpoint_notebook\logs\history.csv
 - Plots dir:   ..\pretrain_checkpoint_notebook\plots
 - Ckpts dir:   ..\pretrain_checkpoint_notebook\checkpoints


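##### IF YOU'RE TESTING RESUME, YOU SHOULD SEE
```
Pilot already completed (resume detected).
chunk1 already completed (resume detected).
```
or something similar, <br/>
and your ```CKPT_DIR``` (```..\pretrain_checkpoint_notebook\checkpoints``` in the run above) should now contain ```u0000299.pt``` <br/>
(299 rather than 300 because ```chunk_updates``` uses integer division: (300 - 15) // 2 = 142, so chunk2 ends at 15 + 2 * 142 = 299)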
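You can also check this from code instead of by eye. A small helper can scan the checkpoint folder for files named like ```u0000299.pt``` (the ```u{update:07d}``` tag pattern used above) and return the highest update number. This is a hypothetical sketch, not a function from this notebook. It assumes regular boundary checkpoints end in ```.pt``` and deliberately ignores other tags such as ```alert_u...```.

```python
import re
from pathlib import Path
from typing import Optional

# matches regular boundary checkpoints such as "u0000299.pt" (tag format u{update:07d})
CKPT_RE = re.compile(r"^u(\d{7})\.pt$")


def latest_update_checkpoint(ckpt_dir: Path) -> Optional[tuple[int, Path]]:
    """Hypothetical helper: return (update, path) of the newest u*******.pt file, or None."""
    best: Optional[tuple[int, Path]] = None
    for p in Path(ckpt_dir).glob("u*.pt"):
        m = CKPT_RE.match(p.name)
        if m is None:
            continue  # skip names that only look similar (e.g. wrong digit count)
        upd = int(m.group(1))
        if best is None or upd > best[0]:
            best = (upd, p)
    return best
```

After the resume test, calling it on ```CKPT_DIR``` should report update 299.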

#### **GO BACK TO "11. Loading checkpoints" TO TEST RESUME**




# (TL;DR)
Here's what I did for checkpoints: <br/>
suppose CHUNKS = 4, total = 1050 (pilot = 50) <br/>
first run: <br/>
pilot: 50 <br/>
chunk 1: 250 <br/>
chunk 2: 250 <br/>
chunk 3: 250 <br/>
chunk 4: 250 <br/>
 <br/>
now increase total to 2050 (suppose pilot is still 50) <br/>
assuming the first run had reached update 1000, the new plan is <br/>
pilot: 50 <br/>
chunk 1: 500 <br/>
chunk 2: 500 <- we resume somewhere in chunk 2 (it now covers updates 551-1050) <br/>
chunk 3: 500 <br/>
chunk 4: 500 <br/>
 <br/>
So if you rerun with a different total and the chunk boundaries change, THAT'S EXPECTED, <br/>
since the remaining updates are reallocated across the chunks. <br/>
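The arithmetic above can be reproduced with a small pure-Python sketch that mirrors the plan driver cell: ```pilot_updates``` first, then ```chunk_updates = (TOTAL_UPDATES - pilot_updates) // CHUNKS```. The function names ```plan_boundaries``` and ```phase_of``` are hypothetical. The pilot size is passed in directly here rather than derived from ```PILOT_FRAC```.

```python
def plan_boundaries(total_updates: int, pilot_updates: int, chunks: int) -> list[tuple[str, int, int]]:
    """Return (tag, first_update, last_update) for the pilot and each chunk.

    Same split as the plan driver: integer division, so any remainder is left unscheduled.
    """
    chunk_updates = max(1, (total_updates - pilot_updates) // chunks)
    phases = [("pilot", 1, pilot_updates)]
    for i in range(1, chunks + 1):
        end = pilot_updates + i * chunk_updates  # same target_end as the driver loop
        phases.append((f"chunk{i}", end - chunk_updates + 1, end))
    return phases


def phase_of(update: int, phases: list[tuple[str, int, int]]) -> str:
    """Name of the phase a given update falls in (e.g. where a resume lands)."""
    for tag, first, last in phases:
        if first <= update <= last:
            return tag
    return "unscheduled"
```

With the demo settings (300 total, 15 pilot, 2 chunks) this gives boundaries 15, 157 and 299, matching the run output above. Update 300 comes back as ```"unscheduled"```. The UPDATES list mentions a tail phase in ```train.py```, which is presumably where such leftovers go.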


Under the ```python_files``` folder, you'll find ```train.py``` <br/>
It's basically
- **everything in this notebook**
- **some modifications** for ```args```, since we'll be training from the terminal <br/>
- grabbing ```model.json``` from the directory to use that configuration
- a configurable ctx (context length), so we can pretrain and then run another pretraining pass at a different ctx for better results
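The args / ```model.json``` / ctx bullets above could look roughly like the following stdlib-only sketch. Only ```--amp {fp16,bf16,none}``` comes from the notes in this notebook. ```--model-config```, ```--ctx```, the ```bf16``` default and both function names are hypothetical, not ```train.py```'s actual interface.

```python
import argparse
import json
from pathlib import Path
from typing import Optional


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    # hypothetical flags; only --amp {fp16,bf16,none} is taken from the UPDATES notes
    p = argparse.ArgumentParser(description="pretraining (sketch)")
    p.add_argument("--model-config", type=Path, default=Path("model.json"),
                   help="JSON file with the model cfg (vocab_size, context_length, ...)")
    p.add_argument("--ctx", type=int, default=None,
                   help="override context_length from the model config")
    p.add_argument("--amp", choices=["fp16", "bf16", "none"], default="bf16")
    return p.parse_args(argv)


def load_cfg(args: argparse.Namespace) -> dict:
    """Read the model cfg from JSON and apply the optional ctx override."""
    cfg = json.loads(Path(args.model_config).read_text())
    if args.ctx is not None:
        # e.g. a second pretraining pass at a longer context
        cfg["context_length"] = args.ctx
    return cfg
```

Keeping the architecture in ```model.json``` and only overriding ctx on the command line means a second pretraining pass can change the context length without editing the config file.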

(P.S.) Hey, I've done some experiments with ```train.py```, and the new one takes less than a minute! (it used to take 10 minutes)

# UPDATES
- **Fast attention kernels**: sdpa_ctx_fast() forces Flash/Mem-Efficient SDPA when available.
- **Allocator & math speedups**: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, TF32 enabled, cuDNN benchmark, optional torch.compile.
- **AMP modes**: runtime switch --amp {fp16,bf16,none} with proper GradScaler handling.
- **Better DataLoader**: worker seeding, pin_memory_device='cuda', prefetch_factor.
- **Optimizer & decay**: AdamW (fused if possible), 2-group weight-decay (no decay on norms/bias/1D).
- **Scheduler**: cosine with warmup (10%) + min_lr_ratio floor.
- **Resume safety**: strip module prefixes, fuse old q/k/v weights → qkv on load, resize pos_emb if ctx changes, rebuild LR tail if scheduler state missing.
- **Run integrity**: per run-dir config immutability (raises on mismatch).
- **Checkpointing**: light vs full payloads + manifest.json; Top-K best tracking.
- **Monitoring**: tokens/sec & loss EMAs, dataloader latency in tqdm, alerts on TPS drop or loss spike → immediate eval + save + plot.
- **Plotting & logs**: dual-axis loss plots (updates & tokens), CSV history.csv with train/val rows.
- **Plan execution**: pilot + chunks + tail schedule; per-N-update grad clipping; quick CUDA timing probe.

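One part of the resume-safety bullet is easy to show without a GPU. Checkpoints saved from a wrapped model carry key prefixes such as ```module.``` (from ```DataParallel```/```DistributedDataParallel```) or ```_orig_mod.``` (from ```torch.compile```), and loading them into a bare model fails on the key names. Below is a minimal sketch of stripping those prefixes from a state dict's keys. The function name is hypothetical, and the prefix list is an assumption about which wrappers were used, not necessarily what ```train.py``` handles.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")

# prefixes added by common wrappers (assumed list; extend as needed)
WRAPPER_PREFIXES = ("module.", "_orig_mod.")


def strip_wrapper_prefixes(state_dict: Mapping[str, V]) -> "OrderedDict[str, V]":
    """Return a copy of state_dict with leading wrapper prefixes removed from every key."""
    out: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        changed = True
        # keep stripping: wrappers can stack, e.g. "_orig_mod.module.<name>"
        while changed:
            changed = False
            for prefix in WRAPPER_PREFIXES:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        out[key] = value
    return out
```

With PyTorch you would apply it to the loaded model weights before ```model.load_state_dict(...)```. Values are passed through untouched, so tensors are not copied.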
# DONE!
If you went through every cell (including the resume test) without any errors, <br/>
you're good to go!