# Freeze the best base
What's the point? <br/>
- Fix the starting weights so that any gains (or bugs) during SFT (supervised fine-tuning) come from SFT itself, not from drift in later pretrain checkpoints.
- Pretraining can keep improving train loss while valid loss gets worse. Freezing the best, fully validated checkpoint prevents us from accidentally starting SFT from a worse step.
- Keep SFT separate from the pretrain runs (so we don't overwrite them) and give later tasks a stable baseline (yeah, we'll do SFT/LoRA and DPO/LoRA later on).
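The second point is easy to see with a toy log. The sketch below uses made-up `(step, train_loss, valid_loss)` tuples (not numbers from this run) and a hypothetical `pick_best_step` helper: train loss keeps falling, but the checkpoint worth freezing is the one with the lowest *valid* loss.

```python
from typing import List, Tuple

# hypothetical log rows: (step, train_loss, valid_loss); illustrative values only
LogRow = Tuple[int, float, float]

def pick_best_step(log: List[LogRow]) -> int:
    """Return the step with the lowest valid loss (ties go to the earliest step)."""
    best_step, _, _ = min(log, key=lambda row: (row[2], row[0]))
    return best_step

toy_log = [
    (1000, 5.80, 5.85),
    (2000, 5.20, 5.30),
    (3000, 4.90, 5.10),  # lowest valid loss
    (4000, 4.70, 5.15),  # train loss still improves, valid gets worse
    (5000, 4.55, 5.22),
]

print("lowest train loss at step:", min(toy_log, key=lambda r: r[1])[0])  # 5000
print("lowest valid loss at step:", pick_best_step(toy_log))              # 3000
```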

So, basically, this notebook will: <br/>
1. Pick a true base checkpoint out of the multiple .pt files in `checkpoints/`. (I decided to just work with best.pt, since the pretraining log already shows it's the best one. Iterating over every .pt instead would take a significant amount of time.)
2. Evaluate best.pt on the full validation set (an AMP fp16 pass, then an fp32 pass)
3. Export the evaluation results to BEST_BASE_README.md (it's just for showing the results; it won't be used later on)
4. Copy best.pt to best_base.pt (the starting point for SFT/LoRA)

# NOTE 
The whole thing took around an hour with the small pretrain <br/>
and nearly 6 hours with the full pretrain. <br/>
If you want to skip it, just run the `# NEVER RUN THIS IF YOU'VE DONE ABOVE` cell at the very bottom. <br/>
It does the same freeze (copies best.pt to best_base.pt and writes a README) but skips the evaluation entirely, <br/>
meaning we're going to completely trust that `best.pt` is the best checkpoint.

## **If you're testing the full run, set this to True**

In [1]:
isfull = True

# Configure the necessary stuff

In [5]:
import os, json, math, shutil, warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import sys
sys.path.append("../python_files")
from Dummy_Model import DummyModel
from npy_datasets import NpyTokensDataset

if isfull:
    RUN_DIR       = Path("../pretrain_checkpoint")
    VALID_BLOCKS  = Path("../materials/valid_blocks.npy")
else:
    RUN_DIR       = Path("../pretrain_checkpoint_small")
    VALID_BLOCKS  = Path("../materials_small/valid_blocks.npy")

CONFIG_PATH   = Path("../configs/model_config_124M.json")
BEST_CKPT     = RUN_DIR / "checkpoints" / "best.pt"

BEST_BASE_PT  = RUN_DIR / "best_base.pt"
README_PATH   = RUN_DIR / "BEST_BASE_README.md"

RUN_DIR.mkdir(parents=True, exist_ok=True)

SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
try:
    torch.set_float32_matmul_precision("high")  # allows TF32 for float32 matmuls on Ampere+: faster, but slightly less precise than full FP32
except Exception:
    pass

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE, "| Torch:", torch.__version__)


Device: cuda | Torch: 2.6.0+cu124


# Evaluation
What's happening here is:
1. We have rows of token IDs (a 2-D grid)
2. Pick a window length T (in our case 512, since that's the context size)
3. Cut each row into non-overlapping chunks of size T
4. For each chunk, make: `x = tokens[s:s+T]`, `y = tokens[s+1:s+T+1]`
    - (shifted by 1 so each target is the next token)
5. Number of samples: `rows × floor((L-1)/T)`, where L is the row length
6. Order is fixed (no shuffle) -> reproducible results
7. Tokens are cast to int64, the integer index type the model's embedding layer expects
8. The DataLoader just batches them (12 at a time here) for faster eval

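For example, with T=4: <br/>
Row: `[t0 t1 t2 t3 t4 t5 t6 t7 t8 t9]` (L=10, so floor(9/4) = 2 windows) <br/>
window 1: x: `[t0 t1 t2 t3]`, y: `[t1 t2 t3 t4]` <br/>
window 2: x: `[t4 t5 t6 t7]`, y: `[t5 t6 t7 t8]` (y is always the next tokens) <br/>
`t9` is left over, since there's no room for another full window.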
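The same T=4 example can be reproduced with NumPy. `make_windows` is a hypothetical helper written for illustration, separate from the notebook's `EvalTokensDataset`: it cuts every row of a 2-D token grid into non-overlapping windows of length T and builds the shifted targets. It orders windows row by row, whereas `EvalTokensDataset` interleaves rows, but the set of windows is the same.

```python
import numpy as np

def make_windows(grid: np.ndarray, T: int):
    """Split each row of a (rows, L) token grid into non-overlapping (x, y) windows.

    Returns X, Y with shape (rows * ((L - 1) // T), T) and dtype int64,
    ordered row by row (all of row 0's windows first).
    """
    rows, L = grid.shape
    per_row = (L - 1) // T          # "- 1": the last input token still needs a target
    xs, ys = [], []
    for r in range(rows):
        for k in range(per_row):
            s = k * T
            xs.append(grid[r, s:s + T])
            ys.append(grid[r, s + 1:s + T + 1])   # shifted by one: next-token targets
    X = np.asarray(xs, dtype=np.int64).reshape(-1, T)
    Y = np.asarray(ys, dtype=np.int64).reshape(-1, T)
    return X, Y

# one row t0..t9 (token ids 0..9), window length T=4
row = np.arange(10).reshape(1, 10)
X, Y = make_windows(row, T=4)
print(X)  # [[0 1 2 3]
          #  [4 5 6 7]]
print(Y)  # [[1 2 3 4]
          #  [5 6 7 8]]
```

Token 9 never shows up as an input, matching the leftover `t9` above.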

In [6]:
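class EvalTokensDataset(NpyTokensDataset):
    """Deterministic, non-overlapping (x, y) windows over the whole token file."""

    def __len__(self):
        # 1-D token stream: one window per full T tokens (minus 1 for the shifted target)
        if self._ndim == 1:
            return max(0, (self._shape[0] - 1) // self.seq_len)
        # 2-D grid: rows × floor((L - 1) / T) windows
        rows, L = self._shape
        return max(0, rows * ((L - 1) // self.seq_len))

    def __getitem__(self, idx):
        arr = self._arr
        T = self.seq_len
        if self._ndim == 1:
            L = self._shape[0]
            s = min(idx * T, L - T - 1)      # window start (the min is only a safety clamp)
            x = arr[s:s+T].astype(np.int64, copy=False)
            y = arr[s+1:s+T+1].astype(np.int64, copy=False)  # targets = inputs shifted by 1
        else:
            rows, L = self._shape
            per_row = (L - 1) // T
            r = idx % rows                   # row: consecutive indices walk across rows
            k = (idx // rows) % per_row      # window number within that row
            s = min(k * T, L - T - 1)
            row = arr[r]
            x = row[s:s+T].astype(np.int64, copy=False)
            y = row[s+1:s+T+1].astype(np.int64, copy=False)
        return torch.from_numpy(x), torch.from_numpy(y)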

# build dataloader
with open(CONFIG_PATH, "r") as f:
    cfg = json.load(f)
SEQ_LEN = int(cfg.get("context_length", 512))

# deterministic, full-valid dataset
eval_ds = EvalTokensDataset(str(VALID_BLOCKS), SEQ_LEN)

# since I'm working on Windows, workers=0 avoids multiprocessing spawn issues
NUM_WORKERS = 0
BATCH_SIZE  = 12

full_dl = DataLoader(
    eval_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE.type == "cuda"),  # pinned memory only helps host-to-GPU copies
    persistent_workers=(NUM_WORKERS > 0)
)

print(f"Valid windows: {len(eval_ds)} | Seq len: {SEQ_LEN} | Batch size: {BATCH_SIZE}")


Valid windows: 360053 | Seq len: 512 | Batch size: 12
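`EvalTokensDataset.__getitem__` maps a flat index to a (row, window) pair with `r = idx % rows` and `k = (idx // rows) % per_row`, so consecutive indices walk across rows before moving to the next window. The standalone check below re-implements only that index arithmetic (it doesn't touch `NpyTokensDataset`, whose internals aren't shown here) and confirms that every window is visited exactly once. The grid shapes are made up; the last two lines work out the batch count implied by the printed output.

```python
import math

def window_index(idx: int, rows: int, L: int, T: int):
    """Same index arithmetic as EvalTokensDataset for a 2-D (rows, L) grid."""
    per_row = (L - 1) // T
    r = idx % rows
    k = (idx // rows) % per_row
    s = min(k * T, L - T - 1)
    return r, s

def check_coverage(rows: int, L: int, T: int) -> bool:
    """True if indices 0..n-1 hit every (row, start) window exactly once."""
    per_row = (L - 1) // T
    n = rows * per_row
    seen = {window_index(i, rows, L, T) for i in range(n)}
    expected = {(r, k * T) for r in range(rows) for k in range(per_row)}
    return seen == expected and len(seen) == n

print(check_coverage(rows=7, L=2049, T=512))       # True

# 360053 windows in batches of 12 -> number of batches, size of the last batch
n_windows, bs = 360053, 12
print(math.ceil(n_windows / bs), n_windows % bs)  # 30005 5
```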


# Evaluation 2
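Here's the plan for evaluation:
- First pass: AMP fp16 autocast (weights stay in fp32; autocast runs matmuls in fp16)
- Second pass: fp32 with TF32 allowed

So in the second pass, matmuls/convs will execute in TF32 on Ampere+ GPUs for speed (pass `strict_fp32=True` to forbid that).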
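To get a feel for why the two passes can give slightly different losses, here is a NumPy emulation of reduced-precision matmul inputs. It's an approximation under stated assumptions, not what the GPU kernels literally do: the hypothetical `round_to_tf32` keeps 10 mantissa bits of each float32 input (TF32 has float32's 8-bit exponent but only 10 mantissa bits), and the products are then accumulated in float32. Float16 inputs are emulated with a round trip through `np.float16`.

```python
import numpy as np

def round_to_tf32(x: np.ndarray) -> np.ndarray:
    """Round float32 values to 10 mantissa bits (TF32-like), keeping float32 storage.

    Rounds to nearest with ties away from zero (real hardware may round differently).
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # add half of the dropped range (2**12), then clear the low 13 mantissa bits
    rounded = (bits + np.uint32(0x1000)) & np.uint32(0xFFFFE000)
    return rounded.view(np.float32)

def rel_err(approx: np.ndarray, ref: np.ndarray) -> float:
    """Max absolute error, relative to the largest reference entry."""
    return float(np.abs(approx - ref).max() / np.abs(ref).max())

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)
ref = a.astype(np.float64) @ b.astype(np.float64)      # high-precision reference

err_fp32 = rel_err(a @ b, ref)
err_tf32 = rel_err(round_to_tf32(a) @ round_to_tf32(b), ref)
a16 = a.astype(np.float16).astype(np.float32)
b16 = b.astype(np.float16).astype(np.float32)
err_fp16 = rel_err(a16 @ b16, ref)

print(f"fp32 {err_fp32:.1e} | tf32-like {err_tf32:.1e} | fp16 inputs {err_fp16:.1e}")
```

You should see the TF32-like and fp16 errors come out orders of magnitude larger than the fp32 one. Errors of that size likely average out over a full-validation loss, which fits the tiny gap between the two passes.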

In [9]:
# ---------------- Safe checkpoint loader (robust) ----------------
import pickle
from pathlib import Path

import torch

# container keys that commonly wrap a state_dict inside a checkpoint
_CONTAINER_KEYS = ("model", "state_dict", "module", "ema", "model_ema")

def _looks_like_state_dict(d) -> bool:
    """Non-empty dict whose keys are all strings (parameter/buffer names)."""
    return isinstance(d, dict) and len(d) > 0 and all(isinstance(k, str) for k in d.keys())

def safe_load_state_dict(ckpt_path: Path):
    """
    Loads a model state_dict from common checkpoint layouts:
    - plain state_dict
    - {"model": state_dict}
    - {"state_dict": state_dict}
    - {"module": state_dict}  (wrapper key; note that DDP instead prefixes keys with "module.")
    - {"ema": state_dict} / {"model_ema": state_dict}  (if you saved EMA weights)
    - a whole pickled nn.Module
    Tries weights_only=True first (restricted, safer unpickling). If the file holds
    objects the restricted unpickler rejects, it retries with weights_only=False,
    which can run arbitrary pickle code: only do that for checkpoints you trust.
    """
    try:
        obj = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except TypeError:
        # very old PyTorch without the weights_only kwarg
        obj = torch.load(ckpt_path, map_location="cpu")
    except pickle.UnpicklingError:
        # checkpoint contains non-allowlisted objects (e.g. custom classes)
        obj = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    if isinstance(obj, torch.nn.Module):
        return obj.state_dict()

    if isinstance(obj, dict):
        # dive into common containers first
        for key in _CONTAINER_KEYS:
            inner = obj.get(key, None)
            if _looks_like_state_dict(inner):
                return inner
        if _looks_like_state_dict(obj):
            return obj  # already looks like a plain state_dict

    raise RuntimeError(f"Unrecognized checkpoint format at {ckpt_path}")
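One layout the loader above doesn't handle: a model wrapped in `DistributedDataParallel` saves its keys with a `module.` prefix, and a `torch.compile` wrapper adds `_orig_mod.`, so `strict=True` would fail on them. If that ever happens, a key-renaming pass like the hypothetical `strip_prefixes` below could be applied to the loaded dict before `load_state_dict`. The demo uses plain dicts with placeholder values and made-up key names, so it needs no torch.

```python
from collections import OrderedDict

def strip_prefixes(state_dict, prefixes=("module.", "_orig_mod.")):
    """Return a copy of state_dict with leading wrapper prefixes removed from every key."""
    out = OrderedDict()
    for key, value in state_dict.items():
        changed = True
        while changed:                  # handles stacked wrappers, e.g. "module._orig_mod."
            changed = False
            for p in prefixes:
                if key.startswith(p):
                    key = key[len(p):]
                    changed = True
        out[key] = value
    return out

# toy example: placeholder values instead of tensors, hypothetical key names
ddp_style = {"module.tok_emb.weight": 0, "module._orig_mod.lm_head.weight": 1}
print(list(strip_prefixes(ddp_style)))  # ['tok_emb.weight', 'lm_head.weight']
```

In the notebook this would look like `model.load_state_dict(strip_prefixes(sd), strict=True)`.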


In [10]:
# --------------- Best-practice full-valid scorer ----------------
import math, torch, torch.nn.functional as F
from contextlib import nullcontext

IGNORE_INDEX = -100  # keep consistent with your project

def _extract_logits_and_loss(out):
    """Accepts (logits, loss), logits-only, or obj/dict with .logits/.loss."""
    logits, loss = None, None
    if isinstance(out, (tuple, list)):
        logits = out[0]
        loss   = out[1] if len(out) > 1 else None
    elif isinstance(out, dict):
        logits = out.get("logits", None)
        loss   = out.get("loss", None)
    else:
        logits = getattr(out, "logits", out)
        loss   = getattr(out, "loss", None)
    return logits, loss

@torch.inference_mode()
def forward_with_guaranteed_loss(model, xb, yb):
    """
    Always returns (logits, loss). Prefers model(..., labels=yb).
    Falls back to manual CE if loss is None. Respects IGNORE_INDEX.
    """
    out = model(xb, labels=yb)  # explicit kwarg avoids positional API pitfalls
    logits, loss = _extract_logits_and_loss(out)
    if loss is None:
        V = logits.size(-1)
        loss = F.cross_entropy(
            logits.reshape(-1, V),
            yb.reshape(-1),
            ignore_index=IGNORE_INDEX,
            reduction="mean",
        )
    return logits, loss

@torch.inference_mode()
def score_model_full_valid(model: torch.nn.Module,
                           loader,
                           device: torch.device,
                           precision: str = "fp32",     # {"fp32","amp_fp16"}
                           strict_fp32: bool = False,
                           show_pbar: bool = True):
    """
    Evaluates masked average loss (ignores targets == IGNORE_INDEX) and perplexity.
    Restores TF32 flags and training mode afterwards.
    """
    # preserve train/eval and TF32 flags
    was_training = model.training
    prev_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32 if device.type == "cuda" else None
    prev_cudnn_tf32  = torch.backends.cudnn.allow_tf32        if device.type == "cuda" else None

    model.eval().to(device)

    try:
        # TF32 policy
        if device.type == "cuda":
            if precision == "fp32" and strict_fp32:
                torch.backends.cuda.matmul.allow_tf32 = False
                torch.backends.cudnn.allow_tf32 = False
            else:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        use_amp = (precision == "amp_fp16" and device.type == "cuda")
        amp_ctx = torch.autocast(device_type="cuda", dtype=torch.float16) if use_amp else nullcontext()

        total_loss = 0.0
        total_kept = 0

        it = loader
        if show_pbar:
            try:
                from tqdm.auto import tqdm
                it = tqdm(loader, total=len(loader),
                          desc=f"full-valid ({precision})",
                          unit="batch", dynamic_ncols=True, leave=False)
            except Exception:
                pass

        for xb, yb in it:
            xb = xb.to(device, non_blocking=True).long()
            yb = yb.to(device, non_blocking=True).long()

            with amp_ctx:
                _, loss = forward_with_guaranteed_loss(model, xb, yb)

            kept = (yb != IGNORE_INDEX).sum().item()
            if kept == 0:  # every target is masked: the mean CE would be NaN, so skip this batch
                continue

            total_loss += float(loss) * kept
            total_kept += kept

            if show_pbar and hasattr(it, "set_postfix"):
                avg = total_loss / max(1, total_kept)
                it.set_postfix(avg_loss=f"{avg:.4f}", ppl=f"{math.exp(min(20.0, avg)):.2f}")

        avg_loss = total_loss / max(1, total_kept)
        ppl = math.exp(min(20.0, avg_loss))  # clamp to avoid inf on spikes
        return avg_loss, ppl

    finally:
        # restore flags and mode
        model.train(was_training)
        if device.type == "cuda":
            torch.backends.cuda.matmul.allow_tf32 = prev_matmul_tf32
            torch.backends.cudnn.allow_tf32 = prev_cudnn_tf32
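The scorer's reported loss is a token-weighted average of per-batch masked means: each batch's mean loss (over targets that aren't `IGNORE_INDEX`) is multiplied by its kept-token count before dividing by the total. That matters because batches can differ in size (the last one is usually partial) and masking can change how many tokens each batch keeps. The NumPy sketch below uses made-up data and hypothetical helper names (`masked_mean_ce`, `token_weighted_average`); `masked_mean_ce` mirrors what `F.cross_entropy(..., ignore_index=-100, reduction="mean")` computes in the manual fallback. It checks that the weighted accumulation matches the mean over all kept tokens, which a plain mean of batch means generally does not.

```python
import numpy as np

IGNORE_INDEX = -100

def masked_mean_ce(logits, targets, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over targets != ignore_index (NaN if all are ignored).

    logits: (N, V) floats, targets: (N,) ints.
    """
    keep = targets != ignore_index
    if not keep.any():
        return float("nan")                        # a "mean" over zero elements
    z = logits[keep].astype(np.float64)
    z = z - z.max(axis=1, keepdims=True)           # numerically stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(z)), targets[keep]].mean())

def token_weighted_average(batches):
    """Accumulate like the scorer: sum(batch_mean * kept) / sum(kept), skipping empty batches."""
    total_loss, total_kept = 0.0, 0
    for logits, targets in batches:
        kept = int((targets != IGNORE_INDEX).sum())
        if kept == 0:
            continue
        total_loss += masked_mean_ce(logits, targets) * kept
        total_kept += kept
    return total_loss / max(1, total_kept)

# made-up data: 3 batches of unequal size over a vocab of 5, with ~20% masked targets
rng = np.random.default_rng(0)
batches = []
for n in (24, 24, 10):                             # the last batch is short
    logits = rng.standard_normal((n, 5))
    targets = rng.integers(0, 5, size=n)
    targets[rng.random(n) < 0.2] = IGNORE_INDEX
    batches.append((logits, targets))

all_logits = np.concatenate([b[0] for b in batches])
all_targets = np.concatenate([b[1] for b in batches])

weighted = token_weighted_average(batches)
global_mean = masked_mean_ce(all_logits, all_targets)          # mean over every kept token
mean_of_means = float(np.mean([masked_mean_ce(l, t) for l, t in batches]))

print(f"{weighted:.6f} {global_mean:.6f} {mean_of_means:.6f}")
```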


# Full validation & rescore

In [11]:
# making sure that we do have best.pt 
assert BEST_CKPT.is_file(), f"Missing BEST_CKPT: {BEST_CKPT}"

# build model and move to eval mode
model = DummyModel(cfg).to(DEVICE).eval()

# load serialized weights
sd = safe_load_state_dict(BEST_CKPT)

# enforce an exact key match to catch config/arch mismatches
# with strict=True, load_state_dict raises a RuntimeError on any mismatch;
# it returns (missing_keys, unexpected_keys), so both lists should be empty here
missing, unexpected = model.load_state_dict(sd, strict=True)
assert not missing and not unexpected, f"State mismatch: missing: {missing}, unexpected: {unexpected}"

# container for results across precisions
results = {}

# fast pass (AMP fp16 if cuda, otherwise fall back to fp32)
prec = "amp_fp16" if DEVICE.type == "cuda" else "fp32"
val_loss_fp16, val_ppl_fp16 = score_model_full_valid(model, full_dl, DEVICE, precision=prec, show_pbar=True)

# on CPU, the key will be fp32 here and will be overwritten by gold pass below
# but that's fine since both runs are fp32 in that case
results["amp_fp16" if DEVICE.type == "cuda" else "fp32"] = (float(val_loss_fp16), float(val_ppl_fp16))

# gold pass (fp32)
# set strict_fp32=True if you want to forbid TF32 on Ampere+,
# but since I'm running on an RTX 3060 / RTX 3080, I'm leaving it False (it's faster)
val_loss_fp32, val_ppl_fp32 = score_model_full_valid(model, full_dl, DEVICE, precision="fp32", strict_fp32=False, show_pbar=True)
results["fp32"] = (float(val_loss_fp32), float(val_ppl_fp32))

print("\n[best.pt] full-valid results:")
for k, (L, P) in results.items():
    print(f"  {k:<8} — loss={L:.6f}  ppl={P:.2f}")



[best.pt] full-valid results:
  amp_fp16 — loss=4.896934  ppl=133.88
  fp32     — loss=4.896926  ppl=133.88
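Perplexity is just exp(average loss), so the two results can be cross-checked by hand. Using the losses printed above, the AMP fp16 and fp32 passes differ by 8e-6 in loss, which is a relative perplexity gap of exp(8e-6) − 1 ≈ 8e-6 and vanishes at two decimals:

```python
import math

# losses copied from the printed full-valid results above
loss_amp_fp16 = 4.896934
loss_fp32 = 4.896926

ppl_amp_fp16 = math.exp(loss_amp_fp16)
ppl_fp32 = math.exp(loss_fp32)
rel_gap = math.exp(loss_amp_fp16 - loss_fp32) - 1.0   # ppl ratio minus 1

print(f"{ppl_amp_fp16:.2f} {ppl_fp32:.2f} {rel_gap:.1e}")  # 133.88 133.88 8.0e-06
```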


# Save results

In [12]:
# copy original .pt as canonical base (keeps training metadata layout)
shutil.copy2(BEST_CKPT, BEST_BASE_PT)

# write a tiny README (provenance + both metrics)
loss_fp16, ppl_fp16 = results.get("amp_fp16", results["fp32"])
loss_fp32, ppl_fp32 = results["fp32"]

readme = f"""# Best Base (Frozen)
- Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
- Source checkpoint: {BEST_CKPT.name}
- Full-valid (amp_fp16): loss={loss_fp16:.6f}, ppl={ppl_fp16:.2f}
- Full-valid (fp32)    : loss={loss_fp32:.6f}, ppl={ppl_fp32:.2f}
- Seed: {SEED}
- Context length: {SEQ_LEN}
- Batch size (eval): {BATCH_SIZE}
- Notes: Deterministic non-overlapping-window evaluation; amp_fp16 for speed + fp32 (TF32 allowed) as the reference.
- Artifacts: {BEST_BASE_PT.name}
"""
with open(README_PATH, "w", encoding="utf-8") as f:
    f.write(readme)

print(f"✅ Frozen best base at: {BEST_BASE_PT}")
print(f"📝 Wrote: {README_PATH}")


✅ Frozen best base at: ..\pretrain_checkpoint\best_base.pt
📝 Wrote: ..\pretrain_checkpoint\BEST_BASE_README.md
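`shutil.copy2` copies the file contents plus metadata such as timestamps, so `best_base.pt` should be byte-identical to `best.pt`. If you want to confirm that (and record a fingerprint of the frozen base), a SHA-256 comparison is a cheap check. The `sha256_of` helper is hypothetical and the demo runs on throwaway temp files; for the real files you'd call it on `BEST_CKPT` and `BEST_BASE_PT`.

```python
import hashlib
import shutil
import tempfile
from pathlib import Path

def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 of a file, read in 1 MiB chunks so a large checkpoint isn't loaded at once."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "best.pt"
    dst = Path(tmp) / "best_base.pt"
    src.write_bytes(b"stand-in for checkpoint bytes")
    shutil.copy2(src, dst)
    same = sha256_of(src) == sha256_of(dst)
    print("identical copy:", same)  # identical copy: True
```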


# NEVER RUN THIS IF YOU'VE DONE ABOVE

In [None]:
# superfast: trust best.pt and freeze now (no evaluation); run the config cell first so the paths are defined
from pathlib import Path
import shutil
from datetime import datetime

assert BEST_CKPT.is_file(), f"Missing BEST_CKPT: {BEST_CKPT}"
shutil.copy2(BEST_CKPT, BEST_BASE_PT)

readme = f"""# Best Base (Frozen) — Quick Freeze (no eval)
- Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
- Source checkpoint: {BEST_CKPT.name}
- Notes: Quick dev freeze (no deterministic validation run).
- Artifacts: {BEST_BASE_PT.name}
"""
with open(README_PATH, "w", encoding="utf-8") as f:
    f.write(readme)

print(f"✅ Frozen best base at: {BEST_BASE_PT}")
print(f"📝 Wrote: {README_PATH}")
