**Tokenizer BPE normal**

In [39]:
from __future__ import annotations
import os, json
from pathlib import Path
from typing import List, Union

try:
    from tokenizers import ByteLevelBPETokenizer, Tokenizer
except Exception:
    ByteLevelBPETokenizer = None

class BPETokenizer:
    """Minimal byte-level BPE wrapper around HuggingFace ``tokenizers``.

    Trains on a single text file or on every ``*.txt`` file under a folder, and
    saves/loads the trained tokenizer to/from a directory containing
    ``vocab.json``, ``merges.txt``, ``tokenizer.json`` and ``bpe_meta.json``.
    """

    # Files written by save() that must never be mistaken for a vocab/merges pair.
    _NON_VOCAB_FILES = frozenset({"tokenizer.json", "bpe_meta.json"})

    def __init__(self, vocab_size: int = 32000, special_tokens: List[str] | None = None):
        if ByteLevelBPETokenizer is None:
            raise ImportError("Please `pip install tokenizers` for BPETokenizer.")
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens or ["<s>", "</s>", "<pad>", "<unk>", "<mask>"]
        self._tok = None

    def _require_tok(self):
        """Return the underlying tokenizer, or raise if train()/load() has not run."""
        if self._tok is None:
            raise RuntimeError("Tokenizer is not trained or loaded; call train() or load() first.")
        return self._tok

    def train(self, data_path: Union[str, Path], min_frequency: int = 2):
        """Train a fresh byte-level BPE model.

        Args:
            data_path: a text file, or a directory searched recursively for ``*.txt``.
            min_frequency: minimum pair frequency for a merge (default 2, as before).

        Raises:
            FileNotFoundError: ``data_path`` is a directory without any ``.txt`` file.
        """
        p = Path(data_path)
        if p.is_dir():
            # Sorted so the training corpus order is deterministic across filesystems.
            files = sorted(str(fp) for fp in p.glob("**/*.txt"))
            if not files:
                raise FileNotFoundError(f"No .txt files found under {p}")
        else:
            files = [str(p)]
        tok = ByteLevelBPETokenizer()
        tok.train(files=files, vocab_size=self.vocab_size, min_frequency=min_frequency,
                  special_tokens=self.special_tokens)
        self._tok = tok

    def save(self, out_dir: Union[str, Path]):
        """Write vocab/merges, the full ``tokenizer.json`` and ``bpe_meta.json`` to ``out_dir``."""
        assert self._tok is not None, "Train or load before save()."
        out = Path(out_dir)
        out.mkdir(parents=True, exist_ok=True)
        if hasattr(self._tok, "save_model"):
            self._tok.save_model(str(out))
        else:
            # A raw `tokenizers.Tokenizer` (what load() returns) has no save_model();
            # its underlying model writes vocab.json / merges.txt instead.
            self._tok.model.save(str(out))
        self._tok.save(str(out / "tokenizer.json"))
        meta = {"vocab_size": self.vocab_size, "special_tokens": self.special_tokens}
        (out / "bpe_meta.json").write_text(json.dumps(meta))

    @classmethod
    def _find_vocab_merges(cls, dirp: Path):
        """Locate a (vocab, merges) file pair in ``dirp``.

        Prefers ``vocab.json`` / ``merges.txt``; otherwise falls back to other
        ``*.json`` / ``*.txt`` files (names ending in ``vocab.json`` / ``merges.txt``
        first), never picking ``tokenizer.json`` or ``bpe_meta.json``.
        """
        vocab = dirp / "vocab.json"
        merges = dirp / "merges.txt"
        if vocab.exists() and merges.exists():
            return vocab, merges
        vs = sorted(
            (fp for fp in dirp.glob("*.json") if fp.name not in cls._NON_VOCAB_FILES),
            key=lambda fp: (not fp.name.endswith("vocab.json"), fp.name),
        )
        ms = sorted(dirp.glob("*.txt"), key=lambda fp: (not fp.name.endswith("merges.txt"), fp.name))
        if not vs or not ms:
            raise FileNotFoundError(f"Could not find vocab.json/merges.txt in {dirp}")
        return vs[0], ms[0]

    def load(self, dir_path: Union[str, Path]):
        """Load a tokenizer saved by :meth:`save`.

        Uses ``tokenizer.json`` (full pipeline incl. special tokens) when present,
        otherwise rebuilds a byte-level BPE tokenizer from a vocab/merges pair.
        Restores ``vocab_size`` / ``special_tokens`` from ``bpe_meta.json`` if present.
        """
        dirp = Path(dir_path)
        tokenizer_file = dirp / "tokenizer.json"
        if tokenizer_file.exists():
            self._tok = Tokenizer.from_file(str(tokenizer_file))
        else:
            vocab, merges = self._find_vocab_merges(dirp)
            self._tok = ByteLevelBPETokenizer(str(vocab), str(merges))
        meta_file = dirp / "bpe_meta.json"
        if meta_file.exists():
            meta = json.loads(meta_file.read_text())
            self.vocab_size = meta.get("vocab_size", self.vocab_size)
            self.special_tokens = meta.get("special_tokens", self.special_tokens)

    def encode(self, text: str):
        """Return the list of token ids for ``text``."""
        return self._require_tok().encode(text).ids

    def decode(self, ids):
        """Return the text for a sequence of token ids."""
        return self._require_tok().decode(ids)

**LR_scheduler**

In [40]:
import math

class WarmupCosineLR:
    """Linear warmup followed by cosine decay, advanced once per optimizer step.

    For step ``s`` (1-based, counted by :meth:`step`):
      * ``s <= warmup_steps``: ``lr = base_lr * s / warmup_steps``
      * afterwards: ``lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * progress))``
        with ``progress = (s - warmup) / (total - warmup)`` clamped to ``[0, 1]``, so the
        LR stays at ``min_lr`` once ``total_steps`` is exceeded instead of rising again.

    ``state_dict()`` / ``load_state_dict()`` let checkpoints resume the schedule.
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float,
                 min_lr: float = 0.0):
        self.optimizer = optimizer
        self.warmup_steps = max(1, warmup_steps)
        self.total_steps = max(self.warmup_steps + 1, total_steps)
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.step_num = 0

    def step(self):
        """Advance one step, write the new LR into every param group and return it."""
        self.step_num += 1
        if self.step_num <= self.warmup_steps:
            lr = self.base_lr * self.step_num / self.warmup_steps
        else:
            progress = (self.step_num - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            # Clamp: without it cos() turns back up past total_steps and the LR grows again.
            progress = min(progress, 1.0)
            lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1.0 + math.cos(math.pi * progress))
        for g in self.optimizer.param_groups:
            g['lr'] = lr
        return lr

    def state_dict(self):
        """Serializable schedule state (picked up by checkpoint savers via ``hasattr``)."""
        return {
            "step_num": self.step_num,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
            "base_lr": self.base_lr,
            "min_lr": self.min_lr,
        }

    def load_state_dict(self, state):
        """Restore the step counter only.

        The schedule shape (warmup/total/base/min LR) keeps the values passed to the
        constructor, so a resumed run may deliberately change them.
        """
        self.step_num = int(state.get("step_num", 0))

**dataset_bpe**

In [41]:
from __future__ import annotations
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Tuple
# from tokenizer_bpe import BPETokenizer

class TextBPEBuffer(Dataset):
    """Single-file next-token dataset: the whole text is tokenized once into a 1-D long tensor.

    Item ``i`` is the pair ``(x, y)`` with ``x = ids[i : i+block_size]`` and ``y`` the same
    window shifted right by one token (the shift happens here, not in a collate function).
    """
    def __init__(self, path: str, tokenizer: BPETokenizer, block_size: int = 256):
        super().__init__()
        self.block_size = block_size
        text = Path(path).read_text(encoding='utf-8')
        self.ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)

    def __len__(self):
        # Window i needs ids[i : i+block_size+1], so valid starts are 0 .. numel-block_size-1,
        # i.e. numel - block_size of them (the old "- 1" dropped the last window).
        return max(0, self.ids.numel() - self.block_size)

    def __getitem__(self, i: int):
        x = self.ids[i:i+self.block_size]
        y = self.ids[i+1:i+self.block_size+1]
        return x, y

def make_loader(path: str, tokenizer: BPETokenizer, block_size: int, batch_size: int, shuffle=True) -> DataLoader:
    """Build a DataLoader of (x, y) next-token windows over a single text file.

    Incomplete final batches are dropped (``drop_last=True``).
    """
    return DataLoader(
        TextBPEBuffer(path, tokenizer, block_size),
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
    )

**amp_accum**

In [42]:
import torch

class AmpGrad:
    """AMP + gradient accumulation wrapper.

    Usage:
        amp = AmpGrad(optimizer, accum=4, amp=True)
        amp.backward(loss)
        if amp.should_step():
            amp.unscale_()            # optional, before clip_grad_norm_
            amp.step(); amp.zero_grad()

    AMP is only enabled when requested AND CUDA is available; otherwise the plain
    (unscaled) backward/step path is used. Each loss is divided by ``accum`` so
    accumulated gradients average over the micro-batches.
    """

    def __init__(self, optimizer, accum: int = 1, amp: bool = True):
        self.optim = optimizer
        self.accum = max(1, accum)
        self.amp = amp and torch.cuda.is_available()
        self.scaler = self._make_scaler(self.amp)
        self._n = 0

    @staticmethod
    def _make_scaler(enabled: bool):
        """Create a CUDA GradScaler with the non-deprecated API when available."""
        amp_mod = getattr(torch, "amp", None)
        if amp_mod is not None and hasattr(amp_mod, "GradScaler"):
            # torch >= 2.3: device-generic scaler; torch.cuda.amp.GradScaler warns on newer versions.
            return amp_mod.GradScaler("cuda", enabled=enabled)
        return torch.cuda.amp.GradScaler(enabled=enabled)

    def backward(self, loss: torch.Tensor):
        """Backprop ``loss / accum`` (scaled when AMP is active) and count the micro-step."""
        loss = loss / self.accum
        if self.amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()
        self._n += 1

    def should_step(self):
        """True after every ``accum``-th backward() call."""
        return (self._n % self.accum) == 0

    def unscale_(self):
        """Unscale accumulated grads in place (call at most once per step, before clipping)."""
        if self.amp:
            self.scaler.unscale_(self.optim)

    def step(self):
        """Optimizer step (via the scaler when AMP is active)."""
        if self.amp:
            self.scaler.step(self.optim)
            self.scaler.update()
        else:
            self.optim.step()

    def zero_grad(self):
        self.optim.zero_grad(set_to_none=True)

**checkpointing**

In [43]:
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import sys
# sys.path.append(str(Path(__file__).resolve().parents[1]/'part_3'))
import time
import torch
import shutil
import torch.nn as nn

DEF_NAME: str = "model_last.pt"  # file that save_checkpoint() (over)writes on every save

# ----------------------------- TB-only helpers (safe no-ops otherwise) ----------------------------- #
def _is_tb(logger) -> bool:
    return getattr(logger, "w", None) is not None


# checkpointing._log_hparams_tb
def _log_hparams_tb(logger, args, total_steps):
    if not _is_tb(logger): return
    try:
        h = dict(
            vocab_size=args.vocab_size, block_size=args.block_size, n_layer=args.n_layer,
            n_head=args.n_head, n_embd=args.n_embd, dropout=args.dropout, lr=args.lr,
            warmup_steps=args.warmup_steps, batch_size=args.batch_size, grad_accum=args.grad_accum_steps,
            mixed_precision=args.mixed_precision, steps=args.steps, epochs=args.epochs,
        )
        logger.hparams(h, {"meta/total_steps": float(total_steps)})
    except Exception:
        pass

def _maybe_log_graph_tb(logger, model, xb, yb):
    if not hasattr(logger, "graph"):
        return
    try:
        class _TensorOnly(nn.Module):
            def __init__(self, m):
                super().__init__(); self.m = m.eval()
            def forward(self, x, y=None):
                out = self.m(x, y) if y is not None else self.m(x)
                if isinstance(out, (list, tuple)):
                    for o in out:
                        if torch.is_tensor(o):
                            return o
                    return out[0]
                return out
        wrapped = _TensorOnly(model).to(xb.device)
        logger.graph(wrapped, (xb, yb))
    except Exception:
        pass

def _log_model_stats(logger, model, step: int, do_hists: bool = False):
    if not _is_tb(logger): return
    try:
        params = [p for p in model.parameters() if p.requires_grad]
        total_param_norm = torch.norm(torch.stack([p.detach().norm(2) for p in params]), 2).item()
        grads = [p.grad for p in params if p.grad is not None]
        total_grad_norm = float('nan')
        if grads:
            total_grad_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2).item()
        logger.log(step=step, **{
            "train/param_global_l2": total_param_norm,
            "train/grad_global_l2": total_grad_norm,
        })
        if do_hists:
            for name, p in model.named_parameters():
                logger.hist(f"params/{name}", p, step)
                if p.grad is not None:
                    logger.hist(f"grads/{name}", p.grad, step)
    except Exception:
        pass

def _maybe_log_attention(logger, model, xb, step: int, every: int = 100):
    """
    Log per-block Q/K/V histograms and RMS scalars every `every` steps (TensorBoard only).

    Recomputes the pre-attention path under no_grad from the current minibatch `xb`;
    the model's code is not modified and no hooks are registered.
    - Only batch element 0 and head 0 of each projection are logged (q[:1, :1] etc.).
    - Q/K/V are taken before RoPE is applied.
    - Between blocks the hidden state is advanced with the FFN residual only (attention
      is skipped), so inputs to block > 0 are an approximation of the real forward.
    Skipped when the logger has no `w` writer, at step 0, or when step % every != 0.
    Failures are printed, not raised.

    Args:
        logger: needs `w`, `hist(tag, values, step)` and `log(step=..., **kv)`.
        model: needs `tok_emb`, `drop` and `blocks`; each block needs `ln1`, `ln2`, `ffn`
            and `attn` with `wq`/`wk`/`wv`, `n_head`, `n_kv_head`, `d_head`.
            NOTE(review): assumes no learned absolute position embedding is added after
            `tok_emb` (RoPE-only model) — TODO confirm against the model definition.
        xb: token ids; `tok_emb(xb)` must be 3-D (B, T, C), so xb is presumably (B, T).
        step: global step used for the gating and as the logged step.
        every: logging period in steps.
    """
    if not _is_tb(logger) or step == 0 or (step % every):
        return
    try:
        import torch
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            # Recreate inputs seen by blocks
            x = model.tok_emb(xb)           # (B,T,C)
            # NOTE(review): no_grad does not disable dropout; if the model is in train mode this
            # applies random dropout to the logged activations.
            x = model.drop(x)

            B, T, _ = x.shape
            for li, blk in enumerate(getattr(model, "blocks", [])):
                h = blk.ln1(x)              # pre-attn normalized hidden

                attn = blk.attn
                # Project to Q/K/V exactly like the module (pre-RoPE for simplicity)
                q = attn.wq(h).view(B, T, attn.n_head,   attn.d_head).transpose(1, 2)      # (B,H,T,D)
                k = attn.wk(h).view(B, T, attn.n_kv_head, attn.d_head).transpose(1, 2)     # (B,Hk,T,D)
                v = attn.wv(h).view(B, T, attn.n_kv_head, attn.d_head).transpose(1, 2)     # (B,Hk,T,D)

                # Take a tiny slice to keep logs light: batch 0, head 0, all T*D values, flattened
                q1 = q[:1, :1].contiguous().view(-1).float().cpu()
                k1 = k[:1, :1].contiguous().view(-1).float().cpu()
                v1 = v[:1, :1].contiguous().view(-1).float().cpu()

                # Drop non-finite (defensive)
                q1 = q1[torch.isfinite(q1)]
                k1 = k1[torch.isfinite(k1)]
                v1 = v1[torch.isfinite(v1)]

                if q1.numel() > 0: logger.hist(f"qkv/block{li}/q_hist", q1, step)
                if k1.numel() > 0: logger.hist(f"qkv/block{li}/k_hist", k1, step)
                if v1.numel() > 0: logger.hist(f"qkv/block{li}/v_hist", v1, step)

                # Optional small scalars that show up on Time Series.
                # Despite the "_l2_mean" key, the value is the RMS: sqrt(mean(x^2)).
                if q1.numel(): logger.log(step=step, **{f"qkv/block{li}/q_l2_mean": float(q1.square().mean().sqrt())})
                if k1.numel(): logger.log(step=step, **{f"qkv/block{li}/k_l2_mean": float(k1.square().mean().sqrt())})
                if v1.numel(): logger.log(step=step, **{f"qkv/block{li}/v_l2_mean": float(v1.square().mean().sqrt())})

                # Advance x to next block with a CHEAP approximation to avoid doubling full compute:
                # use the model's own FFN path only; skip re-running attention (we're only logging pre-attn stats).
                x = x + blk.ffn(blk.ln2(x))

    except Exception as e:
        print(f"[qkv] logging failed: {e}")


def _log_runtime(logger, step: int, it_t0: float, xb, device):
    try:
        dt = time.time() - it_t0
        toks = int(xb.numel())
        toks_per_s = toks / max(dt, 1e-6)
        mem = torch.cuda.memory_allocated()/(1024**2) if torch.cuda.is_available() else 0.0
        logger.log(step=step, **{
            "sys/throughput_tokens_per_s": toks_per_s,
            "sys/step_time_s": dt,
            "sys/gpu_mem_alloc_mb": mem
        })
    except Exception:
        pass

def _log_samples_tb(logger, model, tok, xb, device, step: int, max_new_tokens: int = 64):
    if not _is_tb(logger): return
    if tok is None: return
    try:
        model.eval()
        with torch.no_grad():
            out = model.generate(xb[:1].to(device), max_new_tokens=max_new_tokens, temperature=1.0, top_k=50)
        model.train()
        text = tok.decode(out[0].tolist())
        logger.text("samples/generation", text, step)
    except Exception:
        pass
# ---------------------------------------------------------------------- #

def _extract_config_from_model(model) -> dict:
    """
    Best-effort extraction of GPTModern-like config including GQA fields.
    """
    cfg = {}
    try:
        tok_emb = getattr(model, "tok_emb", None)
        blocks = getattr(model, "blocks", None)
        if tok_emb is None or not blocks:
            return cfg

        try:
            from swiglu import SwiGLU  # optional
        except Exception:
            class SwiGLU: pass

        cfg["vocab_size"] = int(tok_emb.num_embeddings)
        cfg["block_size"]  = int(getattr(model, "block_size", 0) or 0)
        cfg["n_layer"]     = int(len(blocks))

        first_blk = blocks[0]
        attn = getattr(first_blk, "attn", None)
        if attn is None:
            return cfg

        # Heads & dims
        cfg["n_head"]   = int(getattr(attn, "n_head"))
        d_head          = int(getattr(attn, "d_head"))
        cfg["n_embd"]   = int(cfg["n_head"] * d_head)
        cfg["n_kv_head"]= int(getattr(attn, "n_kv_head", cfg["n_head"]))  # default to MHA

        # Dropout (if present)
        drop = getattr(attn, "dropout", None)
        cfg["dropout"] = float(getattr(drop, "p", 0.0)) if drop is not None else 0.0

        # Norm/FFN style
        cfg["use_rmsnorm"] = isinstance(getattr(model, "ln_f", None), nn.Identity)
        cfg["use_swiglu"]  = isinstance(getattr(first_blk, "ffn", None), SwiGLU)

        # Positional / attention tricks
        for k in ("rope", "max_pos", "sliding_window", "attention_sink"):
            if hasattr(attn, k):
                val = getattr(attn, k)
                cfg[k] = int(val) if isinstance(val, bool) else val
    except Exception:
        return {}
    return cfg

def _verify_model_matches(model, cfg: Dict[str, Any]) -> Tuple[bool, str]:
    """Return (ok, message)."""
    expected = {
        "block_size": cfg.get("block_size"),
        "n_layer":    cfg.get("n_layer"),
        "n_head":     cfg.get("n_head"),
        "n_embd":     cfg.get("n_embd"),
        "vocab_size": cfg.get("vocab_size"),
        "n_kv_head":  cfg.get("n_kv_head", cfg.get("n_head")),
    }
    got = {
        "block_size": int(getattr(model, "block_size", -1)),
        "n_layer":    int(len(model.blocks)),
        "vocab_size": int(model.tok_emb.num_embeddings),
    }
    first_blk = model.blocks[0]
    got.update({
        "n_head":     int(first_blk.attn.n_head),
        "n_embd":     int(first_blk.attn.n_head * first_blk.attn.d_head),
        "n_kv_head":  int(getattr(first_blk.attn, "n_kv_head", first_blk.attn.n_head)),
    })
    diffs = [f"{k}: ckpt={expected[k]} vs model={got[k]}" for k in expected if expected[k] != got[k]]
    if diffs:
        return False, "Architecture mismatch:\n  " + "\n  ".join(diffs)
    return True, "ok"


def save_checkpoint(model, optimizer, scheduler, amp, step: int, out_dir: str,
                    tokenizer_dir: str | None = None, config: dict | None = None):
    """Save model/optimizer/scheduler/AMP-scaler state plus an architecture config.

    The checkpoint goes to ``out_dir/DEF_NAME``. It is first written to a temporary
    sibling file and then moved into place with ``os.replace``, so an interrupted
    save never leaves a truncated checkpoint behind.

    Config source: ``model.config`` when the model has that attribute (a dict is
    copied; for other objects their ``__dict__`` is copied, falling back to
    ``_extract_config_from_model`` when empty). The ``config`` argument is used
    only when the model has no ``config`` attribute.

    If ``tokenizer_dir`` is given (str or path-like), it is written to
    ``out_dir/tokenizer_dir.txt``.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Prefer the model's own config if available (e.g., a dict or dataclass with __dict__)
    if hasattr(model, "config"):
        cfg_obj = model.config
        if isinstance(cfg_obj, dict):
            cfg = dict(cfg_obj)
        else:
            attrs = getattr(cfg_obj, "__dict__", None)
            cfg = dict(attrs) if attrs else _extract_config_from_model(model)
    else:
        cfg = config if config is not None else _extract_config_from_model(model)

    payload = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if hasattr(scheduler, "state_dict") else None,
        "amp_scaler": amp.scaler.state_dict() if amp and getattr(amp, "scaler", None) else None,
        "step": int(step),
        "config": cfg,   # ← always write config
        "version": "part4-v2",
    }

    final_path = out / DEF_NAME
    tmp_path = final_path.with_name(final_path.name + ".tmp")
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, final_path)  # atomic rename on the same filesystem
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise

    if tokenizer_dir is not None:
        (out / "tokenizer_dir.txt").write_text(str(tokenizer_dir))



def load_checkpoint(model, path: str, optimizer=None, scheduler=None, amp=None, strict: bool = True):
    """Load a checkpoint written by save_checkpoint into ``model`` (and optional training state).

    When the checkpoint carries a config, the model architecture is verified first:
    a mismatch raises ``RuntimeError`` if ``strict`` and only prints a warning
    otherwise. Previously it raised even with ``strict=False``, contradicting its own
    hint. Optimizer, scheduler and AMP-scaler states are restored only when both the
    object and its saved state exist.

    Returns:
        The saved step (0 if absent).
    """
    ckpt = torch.load(path, map_location="cpu")

    cfg = ckpt.get("config")
    if cfg:
        ok, msg = _verify_model_matches(model, cfg)
        if not ok:
            if strict:
                raise RuntimeError(msg + "\nRebuild the model with this config, or load with strict=False.")
            print(f"[compat] Warning (strict=False): {msg}")
    else:
        # Legacy checkpoint without config: strongly encourage a rebuild step elsewhere
        print("[compat] Warning: checkpoint has no config; cannot verify architecture.")

    missing, unexpected = model.load_state_dict(ckpt["model"], strict=strict)
    if strict and (missing or unexpected):
        raise RuntimeError(f"State dict mismatch:\n  missing: {missing}\n  unexpected: {unexpected}")

    if optimizer is not None and ckpt.get("optimizer") is not None:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and ckpt.get("scheduler") is not None and hasattr(scheduler, "load_state_dict"):
        scheduler.load_state_dict(ckpt["scheduler"])
    if amp is not None and ckpt.get("amp_scaler") is not None and getattr(amp, "scaler", None):
        amp.scaler.load_state_dict(ckpt["amp_scaler"])

    return ckpt.get("step", 0)


# ----------------------------- checkpoint/save utils ----------------------------- #
def checkpoint_paths(out_dir: Path, step: int):
    """Return ``(per_step_path, last_path)`` inside ``out_dir`` (str or Path) for ``step``.

    The per-step name zero-pads the step to 7 digits, so names sort in step order
    for steps below 10**7. The "last" path uses DEF_NAME, the same file
    save_checkpoint() writes, so the two cannot drift apart.
    """
    out_dir = Path(out_dir)
    return out_dir / f"model_step{step:07d}.pt", out_dir / DEF_NAME

def atomic_save_all(model, optim, sched, amp, step: int, out_dir: Path,
                    tok_dir: str | None, keep_last_k: int, config: dict):
    """Write model_last.pt (with config) plus a per-step copy, then prune old per-step copies.

    Args:
        out_dir: checkpoint directory (str or Path).
        tok_dir: tokenizer directory recorded next to the checkpoint (optional).
        keep_last_k: number of ``model_step*.pt`` files to keep; ``<= 0`` keeps all.
            Previously a negative value deleted the oldest files instead.
        config: config dict forwarded to save_checkpoint.

    Copy/pruning failures are reported with a warning but never abort training.
    """
    out_dir = Path(out_dir)
    save_checkpoint(model, optim, sched, amp, step, str(out_dir), tok_dir, config=config)  # writes model_last.pt
    per_step, last = checkpoint_paths(out_dir, step)
    try:
        shutil.copy2(last, per_step)
    except Exception as e:
        print(f"[ckpt] Warning: could not copy {last} -> {per_step}: {e}")

    if keep_last_k <= 0:
        return
    # GC old per-step checkpoints; zero-padded step numbers make name order == step order.
    try:
        ckpts = sorted(out_dir.glob("model_step*.pt"))
        for old in ckpts[:-keep_last_k]:
            old.unlink(missing_ok=True)
    except Exception as e:
        print(f"[ckpt] Warning: pruning old checkpoints in {out_dir} failed: {e}")

**logger**

In [44]:
from __future__ import annotations
import time
from pathlib import Path

class NoopLogger:
    """Logger that accepts every call and does nothing.

    Mirrors TBLogger's optional helpers (hist/text/image/hparams/flush) so training
    code can call them regardless of which backend init_logger() returned.
    ``graph`` is deliberately NOT provided: graph-logging helpers probe
    ``hasattr(logger, "graph")`` to decide whether to trace the model at all.
    """
    def log(self, *args, **kwargs):
        pass

    def hist(self, *args, **kwargs):
        pass

    def text(self, *args, **kwargs):
        pass

    def image(self, *args, **kwargs):
        pass

    def hparams(self, *args, **kwargs):
        pass

    def flush(self):
        pass

    def close(self):
        pass

class TBLogger(NoopLogger):
    """
    TensorBoard logger built on torch.utils.tensorboard.SummaryWriter. Every method
    is a no-op when TensorBoard could not be imported (``self.w is None``).

    Backward compatible:
      - logger.log(step=..., loss=..., lr=...)
    Extras you can optionally use:
      - logger.hist("params/wte.weight", tensor, step)
      - logger.text("samples/generation", text, step)
      - logger.image("attn/heatmap", HW / HWC / CHW tensor or np array, step)
      - logger.graph(model, example_batch)
      - logger.hparams(dict_of_config, dict_of_metrics_once)
      - logger.flush()
    Auto-behavior in .log(...):
      - A tensor/ndarray with exactly 1 element is logged as a scalar.
      - A tensor/ndarray with >1 element is logged as a histogram when it has at most
        2048 elements, otherwise as "<key>/mean" and "<key>/std" scalars.
      - If key starts with "text/", the value is logged as text (tag without the prefix).
    """
    def __init__(self, out_dir: str, flush_secs: int = 10, run_name: str | None = None):
        self.w = None
        self.hparams_logged = False
        run_name = run_name or time.strftime("%Y%m%d-%H%M%S")
        run_dir = Path(out_dir) / run_name
        run_dir.mkdir(parents=True, exist_ok=True)
        try:
            from torch.utils.tensorboard import SummaryWriter
            self.w = SummaryWriter(log_dir=str(run_dir), flush_secs=flush_secs)
        except Exception as e:
            print(f"[TBLogger] TensorBoard not available: {e}. Logging disabled.")
        self._auto_hist_max_elems = 2048
        self.run_dir = str(run_dir)  # handy for prints/debug

    # ---------- backwards-compatible ----------
    def log(self, step: Optional[int] = None, **kv: Any):
        if not self.w: return
        for k, v in kv.items():
            # text channel (opt-in via key prefix)
            if isinstance(k, str) and k.startswith("text/"):
                try:
                    self.w.add_text(k[5:], str(v), global_step=step)
                except Exception:
                    pass
                continue

            # scalar vs histogram auto-route
            try:
                import torch, numpy as np  # lazy
                is_torch = isinstance(v, torch.Tensor)
                is_np = isinstance(v, np.ndarray)
                if is_torch or is_np:
                    numel = int(v.numel() if is_torch else v.size)
                    if numel == 1:
                        val = (v.item() if is_torch else float(v))
                        self.w.add_scalar(k, float(val), global_step=step)
                    elif numel <= self._auto_hist_max_elems:
                        # small-ish tensors => histogram
                        self.w.add_histogram(k, v.detach().cpu() if is_torch else v, global_step=step)
                    else:
                        # fall back to scalar summary stats
                        arr = v.detach().cpu().flatten().numpy() if is_torch else v.flatten()
                        self.w.add_scalar(k + "/mean", float(arr.mean()), global_step=step)
                        self.w.add_scalar(k + "/std", float(arr.std()), global_step=step)
                    continue
            except Exception:
                pass

            # number-like
            try:
                self.w.add_scalar(k, float(v), global_step=step)
            except Exception:
                # swallow non-numeric junk silently (same behavior as before)
                pass

    # ---------- nice-to-have helpers ----------
    def hist(self, tag: str, values: Any, step: Optional[int] = None, bins: str = "tensorflow"):
        if not self.w: return
        try:
            import torch
            if isinstance(values, torch.Tensor):
                values = values.detach().cpu()
            self.w.add_histogram(tag, values, global_step=step, bins=bins)
        except Exception:
            pass

    def text(self, tag: str, text: str, step: Optional[int] = None):
        if not self.w: return
        try:
            self.w.add_text(tag, text, global_step=step)
        except Exception:
            pass

    def image(self, tag: str, img, step: Optional[int] = None):
        """
        img: torch.Tensor or numpy array shaped [H,W], [C,H,W] (C in {1,3}) or [H,W,C]
        """
        if not self.w: return
        try:
            ndim = getattr(img, "ndim", 0)
            if ndim == 2:
                fmt = "HW"  # grayscale; previously mislabeled as HWC and rejected
            elif ndim == 3 and img.shape[0] in (1, 3):
                fmt = "CHW"
            else:
                fmt = "HWC"
            self.w.add_image(tag, img, global_step=step, dataformats=fmt)
        except Exception:
            pass

    def graph(self, model, example_input):
        if not self.w: return
        try:
            # example_input: a Tensor batch or a tuple
            if not isinstance(example_input, tuple):
                example_input = (example_input,)
            self.w.add_graph(model, example_input)
        except Exception:
            pass  # graph tracing can fail depending on model control flow; don't crash

    def hparams(self, hparams: Dict[str, Any], metrics_once: Optional[Dict[str, float]] = None):
        """Log hparams once per logger; unsupported values are stringified."""
        if not self.w or self.hparams_logged:
            return
        try:
            import torch
            # add_hparams accepts only int/float/str/bool/Tensor values and raises ValueError
            # on anything else (e.g. None for an unset CLI flag), which made the whole call
            # fail silently. Stringify everything else instead.
            clean = {
                k: v if isinstance(v, (bool, int, float, str, torch.Tensor)) else str(v)
                for k, v in hparams.items()
            }
            # Single, stable sub-run so it doesn't spam the left pane
            self.w.add_hparams(clean, metrics_once or {}, run_name="_hparams")
            self.hparams_logged = True
        except Exception:
            pass

    def flush(self):
        if self.w:
            try: self.w.flush()
            except Exception: pass

    def close(self):
        if self.w:
            try: self.w.close()
            except Exception: pass

class WBLogger(NoopLogger):
    """Weights & Biases logger; becomes a no-op (with a warning) if wandb is missing or init fails.

    ``log(step=..., **metrics)`` passes ``step`` to ``wandb.log`` as the x-axis step
    instead of logging it as a metric. Otherwise every separate log call would
    advance wandb's own step counter and misalign metrics logged at the same step.
    """
    def __init__(self, project: str, run_name: str | None = None):
        try:
            import wandb
            wandb.init(project=project, name=run_name)
            self.wb = wandb
        except Exception as e:
            print(f"[WBLogger] wandb not available: {e}. Logging disabled.")
            self.wb = None

    def log(self, step=None, **kv):
        if not self.wb:
            return
        if step is None:
            self.wb.log(kv)
        else:
            self.wb.log(kv, step=int(step))

    def close(self):
        """Finish the wandb run so buffered data is uploaded."""
        if self.wb:
            try:
                self.wb.finish()
            except Exception:
                pass


def init_logger(which: str, out_dir: str = "runs/part4"):
    """Return a logger for ``which`` ('tensorboard' | 'wandb'; anything else -> NoopLogger).

    A TensorBoard logger whose writer could not be created is replaced by a NoopLogger.
    """
    if which == 'wandb':
        return WBLogger(project='llm-part4')
    if which != 'tensorboard':
        return NoopLogger()
    tb = TBLogger(out_dir)
    if tb.w is None:
        return NoopLogger()
    return tb

In [45]:
from __future__ import annotations
import torch

def top_k_top_p_filtering(logits: torch.Tensor, top_k: int | None = None, top_p: float | None = None):
    """Filter logits with top-k and/or nucleus (top-p) filtering over the last dimension.

    Args:
        logits: (..., vocab). Any leading dims are supported; (B, vocab) is the common case.
        top_k: keep the k highest logits (ties with the k-th are kept). Disabled when
            None, <= 0 or >= vocab.
        top_p: keep the smallest set of most-likely tokens whose cumulative probability
            reaches top_p. The token that crosses the threshold is kept, and at least
            one token always survives. Disabled unless 0 < top_p < 1.

    Returns:
        A new tensor with the same shape; filtered entries are -inf.
    """
    vocab = logits.size(-1)
    filtered = logits.clone()

    if top_k is not None and 0 < top_k < vocab:
        kth = torch.topk(filtered, top_k, dim=-1).values[..., -1:]
        filtered = filtered.masked_fill(filtered < kth, float('-inf'))

    if top_p is not None and 0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(filtered, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumsum = torch.cumsum(probs, dim=-1)
        # Shift the mask right by one: token i is removed only if the tokens BEFORE it
        # already exceed top_p, so the token that crosses the threshold is kept.
        remove = cumsum > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False  # keep at least 1 token
        sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))
        # Scatter back to the original vocabulary order
        filtered = torch.full_like(filtered, float('-inf')).scatter(-1, sorted_idx, sorted_logits)

    return filtered

In [46]:
from __future__ import annotations
import torch
import math

class RoPECache:
    """Precompute RoPE cos/sin tables for positions ``0 .. max_pos-1`` (head_dim must be even).

    Frequencies are ``base ** (-2i / head_dim)`` for ``i = 0 .. head_dim/2 - 1``.
    Previously ``base`` was stored but ignored (10000 was hard-coded). Tables grow
    automatically (at least doubling) when a larger position is requested.
    """
    def __init__(self, head_dim: int, max_pos: int, base: float = 10000.0, device: torch.device | None = None):
        assert head_dim % 2 == 0, "RoPE head_dim must be even"
        self.head_dim = head_dim
        self.base = base
        self.device = device
        self._build(max_pos)

    def get(self, positions: torch.Tensor):
        """Return ``(cos, sin)``, each (T, head_dim/2), for positions shaped (T,) or (1, T)."""
        if positions.dim() == 2:
            positions = positions[0]
        need = int(positions.max().item()) + 1 if positions.numel() > 0 else 1
        if need > self.max_pos:
            # grow tables
            self._build(max(need, int(self.max_pos * 2)))
        cos = self.cos[positions]  # (T, D/2)
        sin = self.sin[positions]
        return cos, sin

    def _build(self, max_pos: int):
        """(Re)build cos/sin tables for a new max_pos."""
        self.max_pos = max_pos
        exponents = torch.arange(0, self.head_dim, 2, device=self.device).float() / self.head_dim
        inv_freq = 1.0 / (float(self.base) ** exponents)
        t = torch.arange(max_pos, device=self.device).float()
        freqs = torch.outer(t, inv_freq)  # (max_pos, head_dim/2)
        self.cos = torch.cos(freqs)
        self.sin = torch.sin(freqs)

def apply_rope_single(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to interleaved pairs of the last dimension.

    Each pair (x[..., 2i], x[..., 2i+1]) is rotated by the angle whose cosine/sine
    are cos[t, i] / sin[t, i].
    x: (B, H, T, D) with D even; cos, sin: (T, D/2). The result has x's shape and dtype.
    """
    assert x.size(-1) % 2 == 0
    c = cos[None, None]  # (1, 1, T, D/2): broadcast over batch and heads
    s = sin[None, None]
    even, odd = x[..., 0::2], x[..., 1::2]
    rot_even = even * c - odd * s
    rot_odd = even * s + odd * c
    # Re-interleave: stacking on a new last axis and flattening it puts rot_even at the
    # even positions and rot_odd at the odd ones; cast back to x's dtype.
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2).to(x.dtype)

In [47]:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization over the last dimension.

    y = x / rms(x) * g,   rms(x) = sqrt(mean(x^2) + eps)

    For float16/bfloat16 inputs the statistic is computed in float32. Squaring fp16
    values overflows once |x| exceeds ~255 (256^2 > 65504), which turned rms into inf
    and the output into zeros. The normalized value is cast back to the input dtype
    before the gain, so output dtypes follow the same promotion as before.
    """
    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xf = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
        rms = xf.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return (xf / rms).to(x.dtype) * self.weight

In [48]:
import torch.nn as nn

class SwiGLU(nn.Module):
    """Gated feed-forward block: ``dropout(W3 · (W1·x ⊙ SiLU(W2·x)))``.

    The hidden width is ``mult * dim``; all projections are bias-free.
    """
    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        hidden = dim * mult
        # Creation order (w1, w2, w3) fixes RNG use at init and state_dict key order.
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(hidden, dim, bias=False)
        self.act = nn.SiLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        gated = self.w1(x) * self.act(self.w2(x))
        return self.drop(self.w3(gated))

In [49]:
from __future__ import annotations
import torch
from dataclasses import dataclass

@dataclass
class KVCache:
    """Cached key/value tensors for one attention layer.

    Both tensors are laid out as (B, H, T, D); ``T`` reports the cached length.
    """
    k: torch.Tensor  # (B, H, T, D)
    v: torch.Tensor  # (B, H, T, D)

    @property
    def T(self):
        """Number of cached time steps (size of dim 2 of ``k``)."""
        return self.k.shape[2]

class RollingKV:
    """Rolling K/V buffer with optional attention sink.

    Keeps the first ``sink`` tokens ever seen plus the most recent ``window`` tokens,
    concatenating along dim 2 (the time axis of (B, H, T, D) tensors).
    """
    def __init__(self, window: int, sink: int = 0):
        self.window = window
        self.sink = sink
        self.k = None
        self.v = None

    def step(self, k_new: torch.Tensor, v_new: torch.Tensor):
        """Append new keys/values, crop to sink + window tokens and return ``(k, v)``."""
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        total = self.k.size(2)
        if total > self.window + self.sink:
            # Explicit start index: `[-self.window:]` with window == 0 would be `[-0:]`,
            # i.e. the whole buffer, and the cache would grow without bound.
            tail_start = total - max(self.window, 0)
            self.k = torch.cat([self.k[:, :, :self.sink, :], self.k[:, :, tail_start:, :]], dim=2)
            self.v = torch.cat([self.v[:, :, :self.sink, :], self.v[:, :, tail_start:, :]], dim=2)
        return self.k, self.v

In [50]:
from __future__ import annotations
import math, torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttentionModern(nn.Module):
    """Causal self-attention with RoPE, grouped-query attention, KV caching and
    an optional sliding window with attention-sink keys.

    Args:
        n_embd: model width C; must be divisible by ``n_head``.
        n_head: number of query heads H.
        dropout: dropout probability on attention weights (training mode only).
        rope: apply rotary position embeddings to q and k.
        max_pos: passed through to ``RoPECache``.
        sliding_window: if set (>= 1), each query attends only to the most
            recent ``sliding_window`` keys (itself included) plus the sink keys.
        attention_sink: number of leading key positions that every query may
            still attend to while a sliding window is active.
        n_kv_head: number of key/value heads Hk (GQA); defaults to ``n_head``
            (plain multi-head attention). ``n_head`` must be a multiple of it.

    Raises:
        ValueError: if ``sliding_window < 1`` or ``attention_sink < 0``.
    """
    def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0,
                 rope: bool = True, max_pos: int = 4096,
                 sliding_window: int | None = None, attention_sink: int = 0,
                 n_kv_head: int | None = None):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        if sliding_window is not None and sliding_window < 1:
            raise ValueError("sliding_window must be >= 1 (or None to disable it)")
        if attention_sink < 0:
            raise ValueError("attention_sink must be >= 0")
        self.n_head = n_head
        self.n_kv_head = n_kv_head or n_head      # GQA; defaults to MHA
        assert self.n_head % self.n_kv_head == 0, "n_head must be multiple of n_kv_head (GQA grouping)"
        self.group_size = self.n_head // self.n_kv_head
        self.d_head = n_embd // n_head

        # Separate projections: Q has H heads, K/V have Hk heads under GQA.
        self.wq = nn.Linear(n_embd, self.n_head * self.d_head, bias=False)
        self.wk = nn.Linear(n_embd, self.n_kv_head * self.d_head, bias=False)
        self.wv = nn.Linear(n_embd, self.n_kv_head * self.d_head, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)

        self.use_rope = rope
        self.rope_cache: RoPECache | None = None
        self.max_pos = max_pos
        self.sliding_window = sliding_window
        self.attention_sink = attention_sink

    def _maybe_init_rope(self, device):
        # Lazily build the RoPE table on the input's device the first time it is needed.
        if self.use_rope and self.rope_cache is None:
            self.rope_cache = RoPECache(self.d_head, self.max_pos, device=device)

    def _select_context(self, k_all, v_all, T):
        # Pick the keys/values to attend over and the masking mode for the T newest queries.
        """Return ``(k, v, attn_mask, is_causal)`` for the attention call.

        The T queries correspond to the last T positions of ``k_all`` (dim 2).
        """
        Tk = k_all.size(2)
        window, sink = self.sliding_window, self.attention_sink

        if T == 1:
            # A single query sits at the last key position, so every key is in
            # its past; a sliding window reduces to an exact crop (cheap decode).
            if window is not None and Tk > window + sink:
                keep_from = Tk - window
                k_all = torch.cat([k_all[:, :, :sink], k_all[:, :, keep_from:]], dim=2)
                v_all = torch.cat([v_all[:, :, :sink], v_all[:, :, keep_from:]], dim=2)
            return k_all, v_all, None, False

        if window is None and Tk == T:
            # Plain full-sequence pass: SDPA's built-in causal mask is exact.
            return k_all, v_all, None, True

        # Several queries with cached history and/or a sliding window: neither
        # `is_causal` (top-left aligned) nor a crop is correct here, so build an
        # explicit boolean mask (True = may attend), with the queries aligned to
        # the end of the key sequence.
        q_idx = torch.arange(Tk - T, Tk, device=k_all.device).unsqueeze(1)  # (T, 1)
        k_idx = torch.arange(Tk, device=k_all.device).unsqueeze(0)          # (1, Tk)
        allowed = k_idx <= q_idx
        if window is not None:
            allowed = allowed & (((q_idx - k_idx) < window) | (k_idx < sink))
        return k_all, v_all, allowed, False

    def forward(self, x: torch.Tensor, kv_cache: KVCache | None = None, start_pos: int = 0):
        """Attend the T new tokens in ``x`` (B, T, C) over cached plus new keys.

        Args:
            x: input of shape (B, T, C).
            kv_cache: keys/values of earlier positions (stored with Hk heads,
                RoPE already applied), or None.
            start_pos: absolute position of ``x[:, 0]`` (used for RoPE).

        Returns:
            ``(y, new_cache)`` where ``y`` is (B, T, C) and ``new_cache`` holds
            the full, uncropped key/value history including the new tokens.
            Sliding-window cropping only affects which keys are attended.
        """
        B, T, C = x.shape
        self._maybe_init_rope(x.device)

        q = self.wq(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)     # (B,H, T,D)
        k = self.wk(x).view(B, T, self.n_kv_head, self.d_head).transpose(1, 2)  # (B,Hk,T,D)
        v = self.wv(x).view(B, T, self.n_kv_head, self.d_head).transpose(1, 2)  # (B,Hk,T,D)

        # RoPE on the current tokens only (cached keys are already rotated).
        if self.use_rope:
            pos = torch.arange(start_pos, start_pos + T, device=x.device)
            cos, sin = self.rope_cache.get(pos)
            q = apply_rope_single(q, cos, sin)
            k = apply_rope_single(k, cos, sin)

        # Full history in compact Hk heads; this is also the returned cache.
        if kv_cache is not None:
            k_all = torch.cat([kv_cache.k, k], dim=2)  # (B,Hk,Tpast+T,D)
            v_all = torch.cat([kv_cache.v, v], dim=2)
        else:
            k_all, v_all = k, v
        new_cache = KVCache(k_all, v_all)

        k_ctx, v_ctx, attn_mask, is_causal = self._select_context(k_all, v_all, T)

        # GQA: repeat each K/V head so every query head has a partner.
        if self.n_kv_head != self.n_head:
            k_ctx = k_ctx.repeat_interleave(self.group_size, dim=1)  # (B,H,Tk,D)
            v_ctx = v_ctx.repeat_interleave(self.group_size, dim=1)

        # PyTorch applies the 1/sqrt(D) scaling internally.
        y = F.scaled_dot_product_attention(q, k_ctx, v_ctx,
                                           attn_mask=attn_mask,
                                           dropout_p=self.dropout.p if self.training else 0.0,
                                           is_causal=is_causal)  # (B,H,T,D)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y), new_cache

In [51]:
import torch.nn as nn
# from rmsnorm import RMSNorm
# from swiglu import SwiGLU
# from attn_modern import CausalSelfAttentionModern

class TransformerBlockModern(nn.Module):
    """Pre-norm transformer block.

    ``h = x + Attn(Norm1(x))`` followed by ``out = h + FFN(Norm2(h))``.
    Norm is RMSNorm or LayerNorm; the FFN is SwiGLU (mult=4) or a
    Linear-GELU-Linear-Dropout MLP with hidden width ``4 * n_embd``.
    """
    def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0,
                 use_rmsnorm: bool = True, use_swiglu: bool = True,
                 rope: bool = True, max_pos: int = 4096,
                 sliding_window: int | None = None, attention_sink: int = 0, n_kv_head: int | None = None):
        super().__init__()
        norm_cls = RMSNorm if use_rmsnorm else nn.LayerNorm
        # Submodules are registered in the original order (ln1, attn, ln2, ffn)
        # so parameter initialisation and state_dict keys are unchanged.
        self.ln1 = norm_cls(n_embd)
        self.attn = CausalSelfAttentionModern(
            n_embd,
            n_head,
            dropout=dropout,
            rope=rope,
            max_pos=max_pos,
            sliding_window=sliding_window,
            attention_sink=attention_sink,
            n_kv_head=n_kv_head,
        )
        self.ln2 = norm_cls(n_embd)
        if use_swiglu:
            self.ffn = SwiGLU(n_embd, mult=4, dropout=dropout)
        else:
            hidden = 4 * n_embd
            self.ffn = nn.Sequential(
                nn.Linear(n_embd, hidden),
                nn.GELU(),
                nn.Linear(hidden, n_embd),
                nn.Dropout(dropout),
            )

    def forward(self, x, kv_cache=None, start_pos: int = 0):
        attn_out, new_cache = self.attn(self.ln1(x), kv_cache=kv_cache, start_pos=start_pos)
        h = x + attn_out
        out = h + self.ffn(self.ln2(h))
        return out, new_cache

In [52]:
from __future__ import annotations
import torch
import torch.nn as nn
# from block_modern import TransformerBlockModern
# from tokenizer import ByteTokenizer

# Get the absolute path to the folder that contains part_2 and part_3
import os, sys
# parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
# sys.path.insert(0, parent_dir)

class GPTModern(nn.Module):
    """Decoder-only language model built from ``TransformerBlockModern`` layers.

    There is no learned absolute position embedding; position information comes
    only from RoPE inside the attention layers when ``rope=True``.
    ``block_size`` caps the number of tokens accepted per forward call.
    """
    def __init__(self, vocab_size: int = 256, block_size: int = 256,
                 n_layer: int=4, n_head: int=4, n_embd: int=256, dropout: float=0.0,
                 use_rmsnorm: bool = True, use_swiglu: bool = True, rope: bool = True,
                 max_pos: int = 4096, sliding_window: int | None = None, attention_sink: int = 0, n_kv_head: int | None = None):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlockModern(n_embd, n_head, dropout, use_rmsnorm, use_swiglu, rope, max_pos, sliding_window, attention_sink, n_kv_head)
            for _ in range(n_layer)
        ])
        # NOTE(review): with RMSNorm enabled there is no final norm (Identity).
        # Kept as-is so existing checkpoints still load with strict=True.
        self.ln_f = nn.Identity() if use_rmsnorm else nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache_list=None, start_pos: int = 0):
        """Run the model on token ids ``idx`` of shape (B, T).

        Args:
            idx: token ids, (B, T) with T <= block_size.
            targets: optional next-token ids with the same number of elements
                as ``idx``; when given, mean cross-entropy loss is returned.
            kv_cache_list: optional per-layer caches (one entry per block).
            start_pos: absolute position of ``idx[:, 0]`` (forwarded to the blocks).

        Returns:
            ``(logits, loss, new_caches)``: logits (B, T, vocab), loss or None,
            and the list of per-layer caches returned by the blocks.
        """
        _, T = idx.shape
        assert T <= self.block_size, f"sequence length {T} exceeds block_size {self.block_size}"
        x = self.drop(self.tok_emb(idx))

        new_caches = []
        for i, blk in enumerate(self.blocks):
            cache = None if kv_cache_list is None else kv_cache_list[i]
            x, cache = blk(x, kv_cache=cache, start_pos=start_pos)
            new_caches.append(cache)
        logits = self.head(self.ln_f(x))

        loss = None
        if targets is not None:
            import torch.nn.functional as F
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss, new_caches

    @torch.no_grad()
    def generate(self,
                 prompt: torch.Tensor,
                 max_new_tokens=200,
                 temperature=1.0,
                 top_k=50,
                 top_p=None,
                 eos_id=1, # addition from part 6 for early stopping
                 sliding_window: int | None = None,
                 attention_sink: int = 0):
        """Sample up to ``max_new_tokens`` tokens after ``prompt`` using the KV cache.

        The prompt (its last ``block_size`` tokens) is fed once; afterwards only
        the newest token is fed together with the per-layer caches.
        ``temperature == 0`` selects greedily. Generation stops early once every
        sequence in the batch emitted ``eos_id`` on the same step (disable with
        ``eos_id=None``). ``sliding_window``/``attention_sink`` are accepted
        for API compatibility but unused; the per-layer settings apply.

        The model's previous train/eval mode is restored on return, so this is
        safe to call in the middle of training.
        """
        was_training = self.training
        self.eval()
        try:
            idx = prompt
            kvs = [None] * len(self.blocks)
            cached_len = None  # number of positions in the caches; None before the prompt is fed

            for _ in range(max_new_tokens):
                if cached_len is None:
                    idx_cond = idx[:, -self.block_size:]
                    start_pos = 0
                else:
                    idx_cond = idx[:, -1:]
                    start_pos = cached_len  # absolute position of the new token

                logits, _, kvs = self(idx_cond, kv_cache_list=kvs, start_pos=start_pos)
                cached_len = start_pos + idx_cond.size(1)

                next_logits = logits[:, -1, :] / max(temperature, 1e-6)
                next_logits = top_k_top_p_filtering(next_logits, top_k=top_k, top_p=top_p)
                probs = torch.softmax(next_logits, dim=-1)
                if temperature == 0.0:
                    next_id = torch.argmax(probs, dim=-1, keepdim=True)
                else:
                    next_id = torch.multinomial(probs, 1)
                idx = torch.cat([idx, next_id], dim=1)

                # early stopping (added in part 6)
                if eos_id is not None and (next_id == eos_id).all():
                    break

            return idx
        finally:
            self.train(was_training)

    @torch.no_grad()
    def generate_nocache(self, prompt: torch.Tensor, max_new_tokens=200, temperature=1.0, top_k=50, top_p=None,
                sliding_window: int | None = None, attention_sink: int = 0, verbose: bool = False):
        """Reference sampler that re-runs the full (cropped) context every step, without a cache.

        Same sampling rule as ``generate`` but without early stopping.
        ``verbose=True`` prints the top (up to 10) token ids and probabilities
        at each step. The model's previous train/eval mode is restored on return.
        """
        was_training = self.training
        self.eval()
        try:
            idx = prompt
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.block_size:]
                # absolute position of the first token in the window
                start_pos = idx.size(1) - idx_cond.size(1)

                logits, _, _ = self(idx_cond, kv_cache_list=None, start_pos=start_pos)

                next_logits = logits[:, -1, :] / max(temperature, 1e-6)
                next_logits = top_k_top_p_filtering(next_logits, top_k=top_k, top_p=top_p)
                probs = torch.softmax(next_logits, dim=-1)
                if verbose:
                    # clamp k so small vocabularies don't make topk fail
                    topv, topi = torch.topk(probs, min(10, probs.size(-1)))
                    print("top ids:", topi.tolist())
                    print("top vs:", topv.tolist())
                if temperature == 0.0:
                    next_id = torch.argmax(probs, dim=-1, keepdim=True)
                else:
                    next_id = torch.multinomial(probs, 1)
                idx = torch.cat([idx, next_id], dim=1)
            return idx
        finally:
            self.train(was_training)

In [None]:
from __future__ import annotations
import argparse, time, signal
from pathlib import Path
import sys

import torch
import torch.nn as nn

# so we can import Part 3 model
from pathlib import Path as _P
# sys.path.append(str(_P(__file__).resolve().parents[1] / 'part_3'))
# # from model_modern import GPTModern

# from tokenizer_bpe import BPETokenizer
# from dataset_bpe import make_loader
# from lr_scheduler import WarmupCosineLR
# from amp_accum import AmpGrad
# from checkpointing import (
#     load_checkpoint,
#     _log_hparams_tb,
#     _maybe_log_graph_tb,
#     _is_tb,
#     _log_model_stats,
#     _maybe_log_attention,
#     _log_samples_tb,
#     _log_runtime,
#     atomic_save_all,
# )
# from logger import init_logger


def run_cfg_from_args(args, vocab_size: int) -> dict:
    """Build GPTModern constructor kwargs for a fresh (non-resumed) run.

    Model size comes from the parsed CLI ``args``; the architectural switches
    (RMSNorm, SwiGLU, RoPE, no sliding window, no sink) are fixed here.
    """
    cfg = {"vocab_size": vocab_size}
    for field in ("block_size", "n_layer", "n_head", "n_embd", "dropout"):
        cfg[field] = getattr(args, field)
    cfg.update(
        use_rmsnorm=True,
        use_swiglu=True,
        rope=True,
        max_pos=4096,
        sliding_window=None,
        attention_sink=0,
    )
    return cfg


def main(argv=None):
    """Train a GPTModern language model, or strictly resume an existing run.

    A run lives in ``--out``. If ``model_last.pt`` exists there, the model is
    rebuilt from the checkpoint's saved ``config`` and the tokenizer from the
    directory recorded in ``tokenizer_dir.txt``. Otherwise a BPE tokenizer is
    trained (``--bpe``) or byte-level ids (vocab 256) are used.

    Training stops after ``--steps`` optimizer steps or ``--epochs`` passes over
    the data, whichever comes first. Epochs are only enforced when the loader
    has a length. SIGINT/SIGTERM request a checkpoint save and a clean exit.
    The previously installed signal handlers are restored before returning.

    Args:
        argv: argument list to parse; defaults to ``sys.argv[1:]``.

    Raises:
        RuntimeError: on an incompatible checkpoint/tokenizer, or when the
            training loader yields no batches.
    """
    if argv is None:
        argv = sys.argv[1:]
    p = argparse.ArgumentParser()
    p.add_argument('--data', type=str, required=True)
    p.add_argument('--out', type=str, default='runs/part4')

    # tokenizer / model dims
    p.add_argument('--bpe', action='store_true', help='train and use a BPE tokenizer (recommended)')
    p.add_argument('--vocab_size', type=int, default=32000)
    p.add_argument('--block_size', type=int, default=256)
    p.add_argument('--n_layer', type=int, default=6)
    p.add_argument('--n_head', type=int, default=8)
    p.add_argument('--n_embd', type=int, default=512)
    p.add_argument('--dropout', type=float, default=0.0)

    # train
    p.add_argument('--batch_size', type=int, default=32)
    p.add_argument('--epochs', type=int, default=1)
    p.add_argument('--steps', type=int, default=300, help='max optimizer steps for this run')
    p.add_argument('--lr', type=float, default=3e-4)
    p.add_argument('--warmup_steps', type=int, default=20)
    p.add_argument('--mixed_precision', action='store_true')
    p.add_argument('--grad_accum_steps', type=int, default=4)

    # misc
    p.add_argument('--log', choices=['wandb', 'tensorboard', 'none'], default='tensorboard')
    p.add_argument('--save_every', type=int, default=50, help='save checkpoint every N optimizer steps')
    p.add_argument('--keep_last_k', type=int, default=2, help='keep last K step checkpoints (plus model_last.pt)')
    args = p.parse_args(argv)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---- output dir and (possible) checkpoint ----
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = out_dir / "model_last.pt"
    have_ckpt = ckpt_path.exists()

    # ---- load checkpoint meta if present ----
    ckpt = None
    saved_tok_dir = None
    if have_ckpt:
        ckpt = torch.load(str(ckpt_path), map_location=device)
        if "config" not in ckpt:
            raise RuntimeError(
                "Checkpoint is missing 'config'. "
                "Please re-save a checkpoint that includes the model config."
            )
        tok_file = ckpt_path.with_name("tokenizer_dir.txt")
        saved_tok_dir = tok_file.read_text().strip() if tok_file.exists() else None

    # ---- tokenizer ----
    tok = None
    tok_dir = None
    if have_ckpt:
        if not saved_tok_dir:
            raise RuntimeError(
                "Checkpoint was found but tokenizer_dir.txt is missing. "
                "Resume requires the original tokenizer."
            )
        tok = BPETokenizer()
        tok.load(saved_tok_dir)
        tok_dir = saved_tok_dir
        vocab_size = tok.vocab_size
        print(f"[resume] Loaded tokenizer from {tok_dir} (vocab={vocab_size})")
    elif args.bpe:
        tok = BPETokenizer(vocab_size=args.vocab_size)
        tok.train(args.data)
        tok_dir = str(out_dir / 'tokenizer')
        Path(tok_dir).mkdir(parents=True, exist_ok=True)
        tok.save(tok_dir)
        vocab_size = tok.vocab_size
        print(f"[init] Trained tokenizer to {tok_dir} (vocab={vocab_size})")
    else:
        vocab_size = 256  # byte-level fallback (not recommended for Part 4)

    # ---- dataset ----
    train_loader = make_loader(args.data, tok, args.block_size, args.batch_size, shuffle=True)
    try:
        print(f"Number of batches in train_loader: {len(train_loader)}")
        if hasattr(train_loader, "dataset"):
            print(f"Number of samples in dataset: {len(train_loader.dataset)}")
    except TypeError:
        print("train_loader has no __len__ (likely an iterable or streaming dataset).")
    try:
        batches_per_epoch = len(train_loader)
    except TypeError:
        batches_per_epoch = None  # streaming loader: only --steps bounds the run

    # ---- build model config ----
    if have_ckpt:
        cfg_build = ckpt["config"]
        if cfg_build.get("vocab_size") != vocab_size:
            raise RuntimeError(
                f"Tokenizer vocab ({vocab_size}) != checkpoint config vocab ({cfg_build.get('vocab_size')}). "
                "This deterministic script forbids vocab changes on resume."
            )
    else:
        cfg_build = run_cfg_from_args(args, vocab_size)

    # ---- init model/opt/sched/amp ----
    model = GPTModern(**cfg_build).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1)

    # The scheduler is stepped once per optimizer step (below), and AmpGrad is
    # built with accum=grad_accum_steps, so presumably one optimizer step
    # consumes that many batches. Convert the epoch budget from batches to
    # optimizer steps; the same bound drives the training loop.
    if batches_per_epoch is None:
        total_steps = args.steps
    else:
        accum = max(args.grad_accum_steps, 1)
        total_steps = min(args.steps, max(args.epochs * batches_per_epoch // accum, 1))
    warmup = min(args.warmup_steps, max(total_steps // 10, 1))
    sched = WarmupCosineLR(optim, warmup_steps=warmup, total_steps=total_steps, base_lr=args.lr)

    amp = AmpGrad(optim, accum=args.grad_accum_steps, amp=args.mixed_precision)

    # ---- strict resume ----
    step = 0
    if have_ckpt:
        step = load_checkpoint(model, str(ckpt_path), optimizer=optim, scheduler=sched, amp=amp, strict=True)
        print(f"[resume] Loaded checkpoint at step {step}")

    # ---- logging ----
    logger = init_logger(args.log, out_dir=str(out_dir))
    _log_hparams_tb(logger, args, total_steps)
    if _is_tb(logger):
        try:
            ex_x, ex_y = next(iter(train_loader))
            _maybe_log_graph_tb(logger, model, ex_x.to(device), ex_y.to(device))
        except Exception:
            pass  # graph logging is best-effort only

    # ---- graceful save on SIGINT/SIGTERM ----
    save_requested = {"flag": False}
    def _on_term(sig, frame): save_requested["flag"] = True
    previous_handlers = {sig: signal.signal(sig, _on_term) for sig in (signal.SIGTERM, signal.SIGINT)}

    try:
        # ---- train loop ----
        model.train()
        while step < total_steps:
            got_batch = False
            for xb, yb in train_loader:
                got_batch = True
                if step >= total_steps:
                    break
                if save_requested["flag"]:
                    atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build)
                    print(f"[signal] Saved checkpoint at step {step} to {out_dir}. Exiting.")
                    return

                it_t0 = time.time()
                xb, yb = xb.to(device), yb.to(device)
                # Non-deprecated form of torch.cuda.amp.autocast(enabled=...);
                # autocast only applies on CUDA here, as before.
                with torch.autocast(device_type="cuda", enabled=amp.amp and device.type == "cuda"):
                    _, loss, _ = model(xb, yb)
                amp.backward(loss)

                if not amp.should_step():
                    continue
                amp.step()
                amp.zero_grad()
                lr = sched.step()
                step += 1

                # periodic checkpoint
                if step % args.save_every == 0:
                    atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build)
                    if _is_tb(logger):
                        logger.text("meta/checkpoint", f"Saved at step {step}", step)

                # logging
                if step % 50 == 0:
                    train_loss = float(loss.item())
                    print(f"[step {step}] train_loss={train_loss:.4f}, lr={lr:.6f}")
                    logger.log(step=step, loss=train_loss, lr=float(lr))
                    _log_runtime(logger, step, it_t0, xb, device)
                    _log_model_stats(logger, model, step, do_hists=False)
                    _maybe_log_attention(logger, model, xb, step, every=100)
                    _log_samples_tb(logger, model, tok, xb, device, step, max_new_tokens=64)

            if not got_batch:
                # Without this guard an empty loader would spin the while-loop forever.
                raise RuntimeError("train_loader yielded no batches; nothing to train on.")
    finally:
        # Put back the previous handlers (e.g. so notebook interrupts work again).
        # signal.signal() returns None for handlers not installed from Python.
        for sig, handler in previous_handlers.items():
            if handler is not None:
                signal.signal(sig, handler)

    # ---- final save ----
    atomic_save_all(model, optim, sched, amp, step, out_dir, tok_dir, args.keep_last_k, cfg_build)
    print(f"Saved checkpoint to {out_dir}/model_last.pt")


if __name__ == '__main__':
    # Small demo run on /content/tiny_hi.txt: 1000-token BPE, 2-layer / 2-head /
    # 128-wide model, 300 optimizer steps, mixed precision, 2-step accumulation.
    demo_argv = [
        '--data', '/content/tiny_hi.txt',
        '--out', 'runs/part4-demo',
        '--bpe',
        '--vocab_size', '1000',
        '--epochs', '1',
        '--steps', '300',
        '--batch_size', '16',
        '--block_size', '128',
        '--n_layer', '2',
        '--n_head', '2',
        '--n_embd', '128',
        '--mixed_precision',
        '--grad_accum_steps', '2',
        '--log', 'tensorboard',
    ]
    main(demo_argv)


[init] Trained tokenizer to runs/part4-demo/tokenizer (vocab=1000)


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
  assert T <= self.block_size
  need = int(positions.max().item()) + 1 if positions.numel() > 0 else 1
  need = int(positions.max().item()) + 1 if positions.numel() > 0 else 1
  assert x.size(-1) % 2 == 0
  with torch.cuda.amp.autocast(enabled=amp.amp):
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):


Saved checkpoint to runs/part4-demo/model_last.pt


In [57]:
# !pip install tensorboard




In [58]:
# %load_ext tensorboard


In [65]:
# %tensorboard --logdir runs/part4-demo


In [64]:
from __future__ import annotations
import argparse, torch
from pathlib import Path

# load Part 3 model
import sys
from pathlib import Path as _P
# sys.path.append(str(_P(__file__).resolve().parents[1]/'part_3'))
# from model_modern import GPTModern  # noqa: E402

# from tokenizer_bpe import BPETokenizer


def main(argv=None):
    """Load a checkpoint and print a sampled continuation of ``--prompt``.

    The model config comes from the checkpoint's ``config`` entry when present.
    Older checkpoints without one get a config inferred from the weight shapes.
    If ``tokenizer_dir.txt`` sits next to the checkpoint, the BPE tokenizer in
    that directory is used. Otherwise prompt and output are raw UTF-8 bytes.

    Args:
        argv: argument list to parse; defaults to ``sys.argv[1:]``.
    """
    if argv is None:
        argv = sys.argv[1:]
    p = argparse.ArgumentParser()
    p.add_argument('--ckpt', type=str, required=True)
    p.add_argument('--prompt', type=str, default='')
    p.add_argument('--tokens', type=int, default=100)
    p.add_argument('--cpu', action='store_true')
    args = p.parse_args(argv)

    device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')

    ckpt = torch.load(args.ckpt, map_location='cpu')  # load on CPU first; move the model later
    sd = ckpt['model']
    cfg = ckpt.get('config') or {}

    # ---- tokenizer (if present) ----
    tok = None
    vocab_from_tok = None
    tok_dir_file = Path(args.ckpt).with_name('tokenizer_dir.txt')
    if tok_dir_file.exists():
        tok = BPETokenizer()
        tok.load(tok_dir_file.read_text().strip())  # the file holds the tokenizer directory path
        vocab_from_tok = tok.vocab_size

    # ---- config: prefer the saved one; otherwise infer essentials from weights ----
    if not cfg:
        import re
        # tok_emb.weight: [V, C] where C == n_embd
        V, C = sd['tok_emb.weight'].shape
        # pos_emb.weight: [block_size, C] if present
        block_size = sd['pos_emb.weight'].shape[0] if 'pos_emb.weight' in sd else 256
        # count transformer blocks present
        layer_ids = {int(m.group(1)) for k in sd.keys() if (m := re.match(r"blocks\.(\d+)\.", k))}
        n_layer = max(layer_ids) + 1 if layer_ids else 1
        # pick an n_head that divides C (with MHA, head count doesn't change weight shapes)
        n_head = 8 if C % 8 == 0 else 4 if C % 4 == 0 else 2 if C % 2 == 0 else 1
        cfg = dict(
            vocab_size=vocab_from_tok or V,
            block_size=block_size,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=C,
            dropout=0.0,
            use_rmsnorm=True,
            use_swiglu=True,
            rope=True,
            max_pos=4096,
            sliding_window=None,
            attention_sink=0,
        )

    # ---- build & load model (on CPU), then move once ----
    model = GPTModern(**cfg)
    model.load_state_dict(sd)
    model.to(device).eval()

    # ---- prompt ids (id/byte 10 as a fallback when the prompt is empty) ----
    if tok is not None:
        ids = tok.encode(args.prompt)
        if len(ids) == 0:
            ids = [10]
    else:
        ids = [10] if args.prompt == '' else list(args.prompt.encode('utf-8'))
    idx = torch.tensor([ids], dtype=torch.long, device=device)

    with torch.no_grad():
        out = model.generate(idx, max_new_tokens=args.tokens)
    out_ids = out[0].tolist()
    if tok is not None:
        print(tok.decode(out_ids))
    else:
        # Without a tokenizer ids are bytes; skip ids outside 0..255 (possible
        # when the model vocab exceeds 256) instead of crashing in bytes().
        print(bytes(i for i in out_ids if 0 <= i < 256).decode('utf-8', errors='ignore'))

if __name__ == '__main__':
    # Sample 10 tokens from the demo checkpoint, continuing the given prompt.
    sample_argv = [
        '--ckpt', 'runs/part4-demo/model_last.pt',
        '--tokens', '10',
        '--prompt', 'जो सुमिरत',
    ]
    main(sample_argv)

जो सुमिरत्रगट तीनान बि न
