In [None]:
# =========================
# wandb + (optional) Google Drive for downloading saves of models to it
# =========================

# SET GoogleDrive DIR TO SAVE TO FOR THE EXPERIMENT!!
# ---------------------------------------------------
# ---------------------------------------------------
GDRIVE_BASE_DIR = "/content/drive/MyDrive/sort_with_duplicate_mixing_2"
# ---------------------------------------------------
# ---------------------------------------------------
# ---------------------------------------------------

!pip -q install wandb

import os
import wandb

# ---- W&B ----
wandb.login()

# ---- Google Drive (optional) ----
# Set this to False if you want to skip Drive entirely.
ENABLE_GDRIVE = True

# Globals used by Cell 2 + Cell 3
GDRIVE_MOUNTED = False
GDRIVE_MODEL_DIR = None  # destination where .pt files get copied as they are saved

if ENABLE_GDRIVE:
    try:
        from google.colab import drive
        drive.mount("/content/drive")
        GDRIVE_MOUNTED = True

        # Set the base folder ONCE here. Cell 3 can optionally set a per-grid subfolder.
        os.makedirs(GDRIVE_BASE_DIR, exist_ok=True)

        # Default: copy into base; Cell 3 may overwrite to a per-grid folder.
        GDRIVE_MODEL_DIR = GDRIVE_BASE_DIR

        print("‚úÖ Drive mounted.")
        print(f"   GDRIVE_BASE_DIR  = {GDRIVE_BASE_DIR}")
        print(f"   GDRIVE_MODEL_DIR = {GDRIVE_MODEL_DIR}")
    except Exception as e:
        print(f"‚ö†Ô∏è Drive mount skipped/failed: {e}")
        GDRIVE_MOUNTED = False
else:
    print("Drive mount disabled (ENABLE_GDRIVE=False).")


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnathan-henry[0m ([33mnathan-henry-uc-berkeley-electrical-engineering-computer[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


⚠️ Drive mount skipped/failed: Error: credential propagation was unsuccessful


In [2]:
# =========================
# setup (model, batching, training, grid runner)
# =========================
import os, time, copy, math, shutil
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, Iterable, Dict, Any
from fractions import Fraction
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange
import wandb

# Optional (for pretty summary grid)
try:
    import pandas as pd
except Exception:
    pd = None

# -------------------------
# Speed knobs
# -------------------------
def enable_tf32():
    """Opt in to TF32 math for float32 matmuls and cuDNN convolutions.

    On CUDA machines, prefer the per-backend ``fp32_precision`` attributes when
    this PyTorch build exposes them, otherwise fall back to the legacy
    ``allow_tf32`` flags. Independently of CUDA, request "high" float32 matmul
    precision; any failure there is ignored (best-effort).
    """
    if torch.cuda.is_available():
        cuda_matmul = torch.backends.cuda.matmul
        if not hasattr(cuda_matmul, "fp32_precision"):
            cuda_matmul.allow_tf32 = True
        else:
            cuda_matmul.fp32_precision = "tf32"

        # Walk backends.cudnn.conv defensively; any missing link -> legacy flag.
        cudnn_mod = getattr(torch.backends, "cudnn", None)
        cudnn_conv = getattr(cudnn_mod, "conv", None)
        if cudnn_conv is not None and hasattr(cudnn_conv, "fp32_precision"):
            cudnn_conv.fp32_precision = "tf32"
        else:
            torch.backends.cudnn.allow_tf32 = True

    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

enable_tf32()

# -------------------------
# Console clear (Jupyter + terminal-safe)
# -------------------------
def clear_console():
    """Clear the visible output area.

    Uses IPython's ``clear_output(wait=True)`` when IPython can be imported and
    the call succeeds; otherwise writes the ANSI "clear screen + cursor home"
    escape sequence to stdout.
    """
    cleared_by_ipython = False
    try:
        from IPython.display import clear_output as _ipython_clear
        _ipython_clear(wait=True)
        cleared_by_ipython = True
    except Exception:
        cleared_by_ipython = False

    if not cleared_by_ipython:
        print("\033[2J\033[H", end="")

# -------------------------
# Robust context helpers (NO device_type kwarg)
# -------------------------
def get_sdpa_context():
    """Context manager that limits scaled_dot_product_attention to the
    FLASH, memory-EFFICIENT and MATH backends.

    Returns a no-op ``nullcontext()`` when ``torch.nn.attention`` (or any of
    these backend names) is unavailable in this PyTorch build.
    """
    try:
        from torch.nn.attention import sdpa_kernel, SDPBackend
        allowed_backends = [
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]
        ctx = sdpa_kernel(allowed_backends)
    except Exception:
        ctx = nullcontext()
    return ctx

def get_autocast_context(device: torch.device, dtype: Optional[torch.dtype]):
    """CUDA autocast context for `dtype`, or a no-op context.

    A ``nullcontext()`` is returned for non-CUDA devices or when `dtype` is None.
    Uses ``torch.amp.autocast("cuda", ...)`` and falls back to the legacy
    ``torch.cuda.amp.autocast`` if the new API raises.
    """
    wants_autocast = (device.type == "cuda") and (dtype is not None)
    if not wants_autocast:
        return nullcontext()

    try:
        ctx = torch.amp.autocast("cuda", dtype=dtype)
    except Exception:
        ctx = torch.cuda.amp.autocast(dtype=dtype)
    return ctx

def make_grad_scaler(enabled: bool):
    """Return a GradScaler when `enabled`, else a pass-through stand-in.

    The stand-in exposes the subset of the GradScaler API used by the training
    loop (is_enabled / scale / step / update) without any loss scaling.
    """
    if enabled:
        try:
            return torch.amp.GradScaler()
        except Exception:
            return torch.cuda.amp.GradScaler()

    class _NoScaler:
        """No-op scaler: identity scale, plain optimizer step."""

        def is_enabled(self):
            return False

        def scale(self, x):
            return x

        def step(self, opt):
            opt.step()

        def update(self):
            pass

    return _NoScaler()

# -------------------------
# Model components
# -------------------------
class MLP(nn.Module):
    """Position-wise feed-forward sublayer: Linear(C -> 3C) -> tanh-GELU -> Linear(3C -> C)."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 3 * width
        # Attribute names and registration order (fc_1, gelu, fc_2) are part of
        # the state_dict layout and of the order modules are initialized in.
        self.fc_1 = nn.Linear(width, hidden)
        self.gelu = nn.GELU(approximate="tanh")
        self.fc_2 = nn.Linear(hidden, width)

    def forward(self, x):
        h = self.fc_1(x)
        h = self.gelu(h)
        return self.fc_2(h)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention over a (B, T, C) input."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_heads == 0
        self.n_embd = config.n_embd
        self.n_heads = config.n_heads
        self.head_dim = config.n_embd // config.n_heads

        # One fused projection whose output holds q, k and v side by side.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # Flag checked by GPT._init_weights to shrink this layer's init std.
        self.c_proj.NANOGPT_SCALE_INIT = True

    def _to_heads(self, t, batch, steps):
        # (B, T, C) -> (B, n_heads, T, head_dim)
        return t.view(batch, steps, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch, steps, channels = x.size()
        fused = self.c_attn(x)
        q, k, v = [self._to_heads(part, batch, steps) for part in fused.split(self.n_embd, dim=2)]

        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        # (B, n_heads, T, head_dim) -> (B, T, C)
        merged = attended.transpose(1, 2).contiguous().view(batch, steps, channels)
        return self.c_proj(merged)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention sublayer plus an optional MLP sublayer."""

    def __init__(self, config):
        super().__init__()
        self.use_mlp = bool(getattr(config, "use_mlp", True))

        self.attn = CausalSelfAttention(config)
        self.ln_1 = nn.LayerNorm(config.n_embd)

        # The attributes exist either way; they are None when the MLP is disabled,
        # which is what forward() checks.
        self.mlp = MLP(config) if self.use_mlp else None
        self.ln_2 = nn.LayerNorm(config.n_embd) if self.use_mlp else None

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        if self.mlp is None:
            return x
        return x + self.mlp(self.ln_2(x))

class GPTConfig:
    """Plain hyperparameter container for GPT; every field is coerced on construction.

    max_seq_len defaults to 2 * block_size + 1 (K inputs, one separator, K outputs).
    """

    def __init__(
        self,
        block_size: int,
        vocab_size: int,
        n_layers: int = 2,
        n_heads: int = 1,
        n_embd: int = 64,
        without_pos: bool = False,
        use_mlp: bool = True,
        max_seq_len: Optional[int] = None,
    ):
        integer_fields = (
            ("block_size", block_size),
            ("vocab_size", vocab_size),
            ("n_layers", n_layers),
            ("n_heads", n_heads),
            ("n_embd", n_embd),
        )
        for field_name, raw in integer_fields:
            setattr(self, field_name, int(raw))

        self.without_pos = bool(without_pos)
        self.use_mlp = bool(use_mlp)

        if max_seq_len is None:
            max_seq_len = 2 * self.block_size + 1
        self.max_seq_len = int(max_seq_len)

class GPT(nn.Module):
    """Small decoder-only transformer used as a sorting model.

    forward() expects rows of length 2*K+1 (K input tokens, one separator, then
    the K target tokens) and computes the loss only over the K target positions.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.n_layers = config.n_layers  # read by _init_weights for residual-projection scaling

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.max_seq_len, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layers)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight = self.transformer.wte.weight  # weight tying
        self.apply(self._init_weights)

        # Non-persistent: not stored in state_dict, rebuilt on construction.
        self.register_buffer("pos_idx", torch.arange(config.max_seq_len), persistent=False)

        # Positional encodings fixed to 0 (and frozen) if without_pos=True
        if self.config.without_pos:
            with torch.no_grad():
                self.transformer.wpe.weight.zero_()
            self.transformer.wpe.weight.requires_grad_(False)

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) weights and zero biases; Linear layers
        flagged NANOGPT_SCALE_INIT get std scaled by (2 * n_layers) ** -0.5."""
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                std *= (2 * self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, idx, return_full_logits: bool = False, block_size: Optional[int] = None):
        """
        Args:
            idx: (B, T) token ids with T == 2 * block_size + 1.
            return_full_logits: if True, logits for all T positions are returned;
                otherwise only the K positions that feed the loss.
            block_size: K for this batch (defaults to config.block_size); may differ
                from the training K as long as 2K+1 <= config.max_seq_len.
        Returns:
            (logits, loss): logits is (B, T, V) or (B, K, V); loss is the mean
            cross-entropy of predicting idx[:, K+1:] from positions K .. T-2.
        """
        _, T = idx.size()
        if block_size is None:
            block_size = self.config.block_size
        block_size = int(block_size)

        expected_T = 2 * block_size + 1
        assert T == expected_T, f"Expected T={expected_T} for block_size={block_size}, got T={T}"
        assert T <= self.config.max_seq_len, f"T={T} exceeds max_seq_len={self.config.max_seq_len}"

        x = self.transformer.wte(idx)
        if not self.config.without_pos:
            # Skipped when without_pos: wpe is zeroed and frozen in that mode, and
            # its output was never used, so the lookup was wasted work.
            x = x + self.transformer.wpe(self.pos_idx[:T])

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)

        targets = idx[:, block_size + 1 :]  # (B, K)

        if return_full_logits:
            logits = self.lm_head(x)                       # (B, T, V)
            logits_for_loss = logits[:, block_size:T-1, :] # (B, K, V)
        else:
            # Project only the positions that predict targets (cheaper head matmul).
            x_for_loss = x[:, block_size:T-1, :]           # (B, K, C)
            logits_for_loss = self.lm_head(x_for_loss)     # (B, K, V)
            logits = logits_for_loss

        loss = F.cross_entropy(
            logits_for_loss.reshape(-1, logits_for_loss.size(-1)),
            targets.reshape(-1),
        )
        return logits, loss

# -------------------------
# Ratio + embedding dim helpers
# -------------------------
def compute_effective_n_embd(vocab_n: int, n_heads: int, vocab_over_embd: Optional[float], fallback_n_embd: int) -> int:
    """
    Pick the model width from a vocab/embedding ratio.

    Implements: (vocab_size / embd_dim) = r  =>  embd_dim = round(vocab_size / r),
    where vocab_size includes SEP (total_vocab = vocab_n + 1), clamped to >= 1.
    When vocab_over_embd is None, fallback_n_embd is used instead.
    The result is then rounded UP to the next multiple of n_heads so that the
    attention heads split the width evenly.

    Raises:
        ValueError: if n_heads < 1 or vocab_over_embd <= 0.
    """
    n_heads = int(n_heads)
    if n_heads < 1:
        # previously surfaced as ZeroDivisionError (0) or a nonsense width (< 0)
        raise ValueError(f"n_heads must be >= 1, got {n_heads}")

    if vocab_over_embd is None:
        n_embd = int(fallback_n_embd)
    else:
        r = float(vocab_over_embd)
        if r <= 0:
            raise ValueError(f"vocab_over_embd must be > 0, got {vocab_over_embd}")
        total_vocab_size = int(vocab_n) + 1
        n_embd = max(1, int(round(total_vocab_size / r)))

    # ceil(n_embd / n_heads) * n_heads; unchanged when already divisible
    return int(-(-n_embd // n_heads) * n_heads)

# -------------------------
# Data batching (p-way duplicate mixing, controlled by with_mixing boolean)
# -------------------------
# Display label for the separator token (default for TrainConfig.sep_token).
# The integer id actually placed in batches is vocab_n (see _build_batch_from_x).
SEP_TOKEN = "SEP"

def _sample_no_duplicates(batch_size: int, vocab_n: int, block_size: int, device: torch.device) -> torch.Tensor:
    if block_size > vocab_n:
        raise ValueError(
            f"Cannot sample {block_size} unique tokens from vocab_n={vocab_n}. "
            f"Need block_size <= vocab_n for 'no-duplicates' sampling."
        )
    scores = torch.rand(batch_size, vocab_n, device=device)
    return scores.topk(block_size, dim=1).indices.to(torch.long)

def _sample_with_duplicates(batch_size: int, vocab_n: int, block_size: int, device: torch.device) -> torch.Tensor:
    return torch.randint(0, vocab_n, (batch_size, block_size), device=device, dtype=torch.long)

def _sample_from_random_subset_with_replacement(
    batch_size: int,
    vocab_n: int,
    block_size: int,
    device: torch.device,
    effective_vocab_n: int,
) -> torch.Tensor:
    effective_vocab_n = int(effective_vocab_n)
    if not (1 <= effective_vocab_n <= vocab_n):
        raise ValueError(f"effective_vocab_n must be in [1, {vocab_n}], got {effective_vocab_n}")

    if effective_vocab_n == vocab_n:
        return _sample_with_duplicates(batch_size, vocab_n, block_size, device)

    scores = torch.rand(batch_size, vocab_n, device=device)
    allowed = scores.topk(effective_vocab_n, dim=1).indices  # (B, effective_vocab_n)

    pick = torch.randint(0, effective_vocab_n, (batch_size, block_size), device=device)
    return allowed.gather(1, pick).to(torch.long)

# rotation offset so remainder distribution is balanced over time
_DUP_MIX_OFFSET = 0

def _sample_numbers_mixed_duplicate_curriculum(
    batch_size: int,
    vocab_n: int,
    block_size: int,
    device: torch.device,
    p: int,
) -> torch.Tensor:
    """
    Even split across t=0..p-1 (as evenly as possible):
      t=0: guaranteed 0 duplicates (no replacement, full vocab)
      t>0: remove floor(t/p * vocab_n) items (per-sample), then sample with replacement
    """
    global _DUP_MIX_OFFSET
    p = int(p)
    if p <= 0:
        raise ValueError(f"p must be >= 1, got {p}")
    if p == 1:
        return _sample_with_duplicates(batch_size, vocab_n, block_size, device)

    base = batch_size // p
    rem = batch_size % p

    # distribute remainder in a rotating way across calls
    extra = [0] * p
    start = int(_DUP_MIX_OFFSET % p)
    for i in range(rem):
        extra[(start + i) % p] += 1
    _DUP_MIX_OFFSET += rem

    groups = []
    for t in range(p):
        n_t = base + extra[t]
        if n_t <= 0:
            continue

        if t == 0:
            x_t = _sample_no_duplicates(n_t, vocab_n, block_size, device)
        else:
            n_removed = (vocab_n * t) // p
            eff_vocab = vocab_n - n_removed  # >= 1 since t < p
            x_t = _sample_from_random_subset_with_replacement(n_t, vocab_n, block_size, device, eff_vocab)

        groups.append(x_t)

    x = torch.cat(groups, dim=0)
    perm = torch.randperm(x.size(0), device=device)
    return x[perm]

def _sample_numbers(
    batch_size: int,
    vocab_n: int,
    block_size: int,
    device: torch.device,
    allow_duplicates: bool,
    *,
    dup_mixture_p: Optional[int] = None,
    use_dup_mixture: bool = False,
) -> torch.Tensor:
    """Sample raw (unsorted) input lists, one row per example.

    - allow_duplicates falsy: every row has block_size distinct tokens.
    - allow_duplicates with use_dup_mixture and dup_mixture_p > 1: p-way
      duplicate-mixing curriculum.
    - otherwise: i.i.d. uniform tokens with replacement.
    """
    if allow_duplicates:
        wants_mixture = use_dup_mixture and (dup_mixture_p is not None) and int(dup_mixture_p) > 1
        if wants_mixture:
            return _sample_numbers_mixed_duplicate_curriculum(
                batch_size, vocab_n, block_size, device, int(dup_mixture_p)
            )
        return _sample_with_duplicates(batch_size, vocab_n, block_size, device)

    return _sample_no_duplicates(batch_size, vocab_n, block_size, device)

def _build_batch_from_x(x: torch.Tensor, vocab_n: int) -> torch.Tensor:
    vals = x.sort(dim=1).values
    sep_id = vocab_n
    sep = torch.full((x.size(0), 1), sep_id, device=x.device, dtype=torch.long)
    return torch.cat([x, sep, vals], dim=1)

def get_batch(
    batch_size: int,
    device: torch.device,
    vocab_n: int,
    block_size: int,
    allow_duplicates: bool,
    *,
    dup_mixture_p: Optional[int] = None,
    use_dup_mixture: bool = False,
) -> torch.Tensor:
    """Sample one batch of model input rows.

    Each row is ``[x_1..x_K, SEP, sorted(x)_1..sorted(x)_K]`` with SEP encoded as
    vocab_n, giving 2 * block_size + 1 columns. Duplicate mixing is used only when
    allow_duplicates and use_dup_mixture are set and dup_mixture_p > 1.
    """
    raw = _sample_numbers(
        batch_size,
        vocab_n,
        block_size,
        device,
        allow_duplicates,
        dup_mixture_p=dup_mixture_p,
        use_dup_mixture=use_dup_mixture,
    )
    return _build_batch_from_x(raw, vocab_n)

# -------------------------
# LR + plateau helpers
# -------------------------
def create_optimizer(model, weight_decay: float, lr: float):
    """
    Build AdamW (betas=(0.9, 0.95), eps=1e-8) with two parameter groups.

    Only parameters with requires_grad=True are included (e.g. a frozen wpe is
    skipped). Tensors with dim > 1 (weight matrices, embeddings) get
    `weight_decay`; tensors with dim <= 1 (biases, LayerNorm params) get none.

    The fused AdamW implementation is tried first. If it is unavailable (unknown
    kwarg -> TypeError) or unsupported for these parameters (e.g. non-CUDA
    tensors on PyTorch releases whose fused kernel is CUDA-only -> RuntimeError),
    the default implementation is used instead.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in params if p.dim() > 1]
    nondecay_params = [p for p in params if p.dim() <= 1]

    def _param_groups():
        # Fresh dicts per attempt: the optimizer fills each group dict with its
        # defaults (including `fused`) while constructing, so reusing the dicts
        # from a failed fused attempt would silently keep fused=True.
        return [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nondecay_params, "weight_decay": 0.0},
        ]

    adamw_kwargs = dict(lr=lr, betas=(0.9, 0.95), eps=1e-8)
    try:
        return torch.optim.AdamW(_param_groups(), fused=True, **adamw_kwargs)
    except (TypeError, RuntimeError):
        return torch.optim.AdamW(_param_groups(), **adamw_kwargs)

def get_lr(itr: int, cfg) -> float:
    """
    Learning rate at iteration `itr`: linear warmup, then cosine decay to min_lr.

    - itr < warmup_iters: learning_rate * (itr + 1) / (warmup_iters + 1)
    - warmup_iters <= itr <= max_iters: cosine from learning_rate (at warmup_iters)
      down to min_lr (at max_iters)
    - itr > max_iters: min_lr

    Reads cfg.learning_rate, cfg.min_lr, cfg.warmup_iters and cfg.max_iters.
    """
    if itr < cfg.warmup_iters:
        return cfg.learning_rate * (itr + 1) / (cfg.warmup_iters + 1)

    decay_iters = cfg.max_iters - cfg.warmup_iters
    if itr > cfg.max_iters or decay_iters <= 0:
        # decay_iters <= 0 only reaches here when itr == warmup_iters == max_iters;
        # that is the end of a zero-length decay window (and would divide by zero).
        return cfg.min_lr

    progress = (itr - cfg.warmup_iters) / decay_iters
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    return cfg.min_lr + coeff * (cfg.learning_rate - cfg.min_lr)

def _safe_log10(x: float, eps: float = 1e-30) -> float:
    return math.log10(max(float(x), eps))

def _linreg_slope(xs, ys) -> float:
    n = len(xs)
    if n < 2:
        return 0.0
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    cov = 0.0
    var = 0.0
    for x, y in zip(xs, ys):
        dx = x - x_mean
        dy = y - y_mean
        cov += dx * dy
        var += dx * dx
    return cov / var if var > 0 else 0.0

# -------------------------
# Accuracy helpers
# -------------------------
def acc_from_logits(logits: torch.Tensor, idx: torch.Tensor, block_size: int):
    """Token- and sequence-level accuracy of argmax predictions.

    Targets are idx[:, block_size + 1:]; logits must line up with them position
    by position. Returns (token_acc, sample_acc) as 0-dim float tensors, where
    sample_acc counts a row as correct only if every position matches.
    """
    hits = logits.argmax(dim=-1) == idx[:, block_size + 1 :]
    token_acc = hits.float().mean()
    sample_acc = hits.all(dim=1).float().mean()
    return token_acc, sample_acc

@torch.no_grad()
def eval_batch_metrics(model: nn.Module, device: torch.device, cfg, block_size: int, use_amp: bool, amp_dtype: Optional[torch.dtype], batch_size: Optional[int] = None):
    """
    Loss, token accuracy and sample accuracy on one freshly sampled batch.

    The batch follows cfg.allow_duplicates but never uses duplicate mixing, so it
    comes from the "original" (unmixed) distribution. The model runs in eval mode,
    and its previous train/eval mode is restored afterwards, even if the forward
    pass raises. This matches eval_sample_acc_by_unique_count.

    Args:
        batch_size: rows to evaluate; None => cfg.micro_batch_size.
    Returns:
        (loss, token_acc, sample_acc) as Python floats.
    """
    was_training = model.training
    model.eval()
    try:
        bs = int(batch_size) if batch_size is not None else int(cfg.micro_batch_size)
        k = int(block_size)

        batch = get_batch(
            batch_size=bs,
            device=device,
            vocab_n=cfg.vocab_n,
            block_size=k,
            allow_duplicates=cfg.allow_duplicates,
            # Eval stays "original" distribution (no mixing) by default
        )

        with (get_autocast_context(device, amp_dtype) if use_amp else nullcontext()):
            logits, loss = model(batch, return_full_logits=False, block_size=k)

        token_acc_t, sample_acc_t = acc_from_logits(logits, batch, k)
        return float(loss.item()), float(token_acc_t.item()), float(sample_acc_t.item())
    finally:
        # Restore the caller's mode instead of unconditionally switching to train().
        if was_training:
            model.train()

# -------------------------
# NEW: evaluation by exact number of unique tokens in the input list
# -------------------------
def _sample_exact_unique_count(
    batch_size: int,
    vocab_n: int,
    block_size: int,
    num_unique: int,
    device: torch.device,
) -> torch.Tensor:
    num_unique = int(num_unique)
    if not (1 <= num_unique <= min(block_size, vocab_n)):
        raise ValueError(f"num_unique must be in [1, min(K,vocab_n)] got {num_unique} (K={block_size}, vocab_n={vocab_n})")

    scores = torch.rand(batch_size, vocab_n, device=device)
    uniq = scores.topk(num_unique, dim=1).indices.to(torch.long)  # (B, U)

    if block_size == num_unique:
        x = uniq
    else:
        extra_idx = torch.randint(0, num_unique, (batch_size, block_size - num_unique), device=device)
        extra = uniq.gather(1, extra_idx)
        x = torch.cat([uniq, extra], dim=1)
        perm = torch.rand(batch_size, block_size, device=device).argsort(dim=1)
        x = x.gather(1, perm)

    return x

@torch.no_grad()
def eval_sample_acc_by_unique_count(
    model: nn.Module,
    device: torch.device,
    cfg,
    block_size: int,
    use_amp: bool,
    amp_dtype: Optional[torch.dtype],
    batch_size: Optional[int] = None,
):
    """
    Exact-sequence accuracy broken down by V, the number of distinct tokens in the input.

    The eval batch (batch_size, or cfg.micro_batch_size when None) is split as
    evenly as possible over V = 1..min(K, vocab_n); leftover rows go to the
    smallest V values. Each slice is sampled with exactly V distinct tokens, all
    slices are shuffled together, and everything is scored in one forward pass.

    Returns:
        dict {V: sample accuracy}; a V that received no rows maps to NaN.
    """
    was_training = model.training
    model.eval()

    k = int(block_size)
    vmax = min(k, int(cfg.vocab_n))
    bs = int(batch_size) if batch_size is not None else int(cfg.micro_batch_size)
    per_v, leftover = divmod(bs, vmax)

    chunks, labels = [], []
    for V in range(1, vmax + 1):
        n_v = per_v + (1 if V - 1 < leftover else 0)
        if n_v > 0:
            chunks.append(_sample_exact_unique_count(n_v, cfg.vocab_n, k, V, device))
            labels.append(torch.full((n_v,), V, device=device, dtype=torch.long))

    x_all = torch.cat(chunks, dim=0)
    v_all = torch.cat(labels, dim=0)

    shuffle = torch.randperm(x_all.size(0), device=device)
    x_all, v_all = x_all[shuffle], v_all[shuffle]

    batch = _build_batch_from_x(x_all, cfg.vocab_n)

    amp_ctx = get_autocast_context(device, amp_dtype) if use_amp else nullcontext()
    with amp_ctx:
        logits, _ = model(batch, return_full_logits=False, block_size=k)

    row_correct = (logits.argmax(dim=-1) == batch[:, k + 1 :]).all(dim=1)

    result = {}
    for V in range(1, vmax + 1):
        in_bucket = v_all == V
        result[V] = float(row_correct[in_bucket].float().mean().item()) if in_bucket.any() else float("nan")

    if was_training:
        model.train()
    return result

# -------------------------
# Config (NOTE: with_mixing boolean replaces "p" hyperparam)
# with_mixing=False => effective p=1
# with_mixing=True  => effective p=mixing_bins (default 8)
# -------------------------
@dataclass
class TrainConfig:
    """Hyperparameters for one train_sorting_gpt run.

    Field order defines the positional constructor signature, so append new
    fields instead of inserting them in the middle.
    """

    # ---- task / data ----
    vocab_n: int = 1024  # sortable token ids are 0..vocab_n-1; SEP is encoded as id vocab_n
    block_size: int = 32  # K: length of each training list to sort
    allow_duplicates: bool = True  # False => each list holds K distinct tokens (requires K <= vocab_n)
    sep_token: str = SEP_TOKEN  # label for the separator; batches encode it as id vocab_n

    # ---- duplicate mixing for training batches ----
    # with_mixing=False => effective p = 1 (plain sampling with replacement)
    with_mixing: bool = False
    mixing_bins: int = 8  # when with_mixing=True, effective p = mixing_bins; when False, effective p=1

    # ---- per-unique-count evaluation ----
    # Log sample accuracy by exact number of unique tokens V in the input.
    log_unique_count_metrics: bool = True
    unique_count_eval_batch_size: Optional[int] = None  # None => use micro_batch_size

    # K values to evaluate on; None => (block_size,). The largest K (together with
    # block_size) also sizes the model's max_seq_len as 2*K + 1.
    test_block_sizes: Optional[Tuple[int, ...]] = None

    # ---- model ----
    n_layers: int = 2
    n_heads: int = 1
    n_embd: int = 64  # used only when vocab_over_embd is None
    # If set, width = round((vocab_n + 1) / vocab_over_embd), rounded up to a
    # multiple of n_heads (see compute_effective_n_embd).
    vocab_over_embd: Optional[float] = None

    without_pos: bool = False  # True => positional embeddings zeroed and frozen
    use_mlp: bool = True  # False => attention-only blocks

    # ---- optimization (see get_lr / create_optimizer) ----
    warmup_iters: int = 200
    max_iters: int = 120000
    learning_rate: float = 1e-4  # peak LR, reached at the end of warmup
    min_lr: float = 1e-6  # floor reached at max_iters by the cosine decay
    weight_decay: float = 0.0  # applied only to params with dim > 1

    # effective_batch_size must be a multiple of micro_batch_size; the ratio is
    # the number of gradient-accumulation steps per optimizer step.
    micro_batch_size: int = 1024
    effective_batch_size: int = 4096

    # ---- logging / checkpoints ----
    log_interval: int = 250  # iterations between train-metric logs
    ckpt_interval: int = 20000  # NOTE(review): presumably checkpoint cadence; its use is outside this view
    save_dir: str = "./saved_models"  # created if missing

    # ---- plateau detection (on log iterations, log10(loss) vs log10(iter)) ----
    plateau_window_logs: int = 40  # number of most recent log points in the regression window
    plateau_slope_threshold: float = 0.02  # plateau requires slope > -threshold
    plateau_log10_loss_delta: float = 0.02  # ...and a log10-loss drop over the window in [0, delta)
    plateau_patience_logs: int = 2  # consecutive plateau windows needed to declare a plateau
    plateau_extra_logs: int = 1  # stop is scheduled this many log intervals after detection
    plateau_min_iters: int = 20000  # no plateau checks before this many steps

    seed: int = 1337
    use_compile: bool = False  # try torch.compile(mode="max-autotune"); falls back if it fails

    # ---- Weights & Biases (passed to wandb.init) ----
    wandb_project: str = "sortgpt"
    wandb_entity: Optional[str] = None
    wandb_group: Optional[str] = None
    wandb_mode: Optional[str] = None

def _mix_p_effective(cfg: TrainConfig) -> int:
    """Number of duplicate-mixing groups p: cfg.mixing_bins when mixing is on, else 1."""
    if not bool(cfg.with_mixing):
        return 1
    return int(cfg.mixing_bins)

# -------------------------
# Naming / saving + Drive copy hook (NEW: interpretable names)
# -------------------------
def make_wandb_run_name(cfg: TrainConfig) -> str:
    """
    Human-readable W&B run name that encodes the main hyperparameters.

    NOTE: the `_pos` tag encodes cfg.without_pos, so `_pos1` means positional
    embeddings are DISABLED. `_embd` is the effective width returned by
    compute_effective_n_embd.
    """
    eff_n_embd = compute_effective_n_embd(cfg.vocab_n, cfg.n_heads, cfg.vocab_over_embd, cfg.n_embd)
    tags = (
        ("vocab", cfg.vocab_n),
        ("_blockSize", cfg.block_size),
        ("_layers", cfg.n_layers),
        ("_pos", cfg.without_pos),
        ("_mlp", cfg.use_mlp),
        ("_embd", eff_n_embd),
        ("_dup", cfg.allow_duplicates),
        ("_mix", cfg.with_mixing),
    )
    return "".join(f"{label}{int(value)}" for label, value in tags)

def make_save_filename(prefix: str, cfg: TrainConfig, iters_done: int) -> str:
    """
    Checkpoint file name: ``{prefix}_{make_wandb_run_name(cfg)}_iters{iters_done}.pt``.

    Reuses the W&B run-name tag instead of rebuilding it field by field, so
    checkpoint files and W&B runs always carry the same hyperparameter tag and
    cannot drift apart when a field is added to one but not the other.
    """
    return f"{prefix}_{make_wandb_run_name(cfg)}_iters{int(iters_done)}.pt"

def maybe_copy_to_drive(local_path: str):
    """
    Best-effort copy of local_path into GDRIVE_MODEL_DIR (set by the Drive setup cell).

    This is a no-op when GDRIVE_MOUNTED is falsy or GDRIVE_MODEL_DIR is unset or
    empty in this module's globals. Any error is printed and swallowed, so a
    Drive problem never interrupts training.
    """
    try:
        module_globals = globals()
        is_mounted = bool(module_globals.get("GDRIVE_MOUNTED", False))
        target_dir = module_globals.get("GDRIVE_MODEL_DIR", None)
        if not (is_mounted and target_dir):
            return

        os.makedirs(target_dir, exist_ok=True)
        dst_path = os.path.join(target_dir, os.path.basename(local_path))
        shutil.copy2(local_path, dst_path)
        print(f"‚úì Copied to Drive: {dst_path}")
    except Exception as e:
        print(f"‚úó Drive copy failed for {local_path}: {e}")

# -------------------------
# Training
# -------------------------
def train_sorting_gpt(cfg: TrainConfig) -> Dict[str, Any]:
    t0 = time.time()
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"using device: {device}")

    test_block_sizes = tuple(int(k) for k in (cfg.test_block_sizes if cfg.test_block_sizes else (cfg.block_size,)))
    primary_test_k = int(max(test_block_sizes))

    # Validate strict no-duplicates mode
    if not cfg.allow_duplicates:
        for k in (int(cfg.block_size),) + tuple(test_block_sizes):
            if k > cfg.vocab_n:
                raise ValueError(f"allow_duplicates=False requires block_size <= vocab_n (got block_size={k}, vocab_n={cfg.vocab_n})")

    total_vocab_size = cfg.vocab_n + 1
    sep_id = cfg.vocab_n

    grad_accum_steps = cfg.effective_batch_size // cfg.micro_batch_size
    assert cfg.effective_batch_size % cfg.micro_batch_size == 0

    # AMP
    use_amp = (device.type == "cuda")
    if use_amp:
        bf16_ok = getattr(torch.cuda, "is_bf16_supported", lambda: False)()
        amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
    else:
        amp_dtype = None
    scaler = make_grad_scaler(enabled=(use_amp and amp_dtype == torch.float16))

    # Effective embd + max_seq_len for largest tested K
    eff_n_embd = compute_effective_n_embd(cfg.vocab_n, cfg.n_heads, cfg.vocab_over_embd, cfg.n_embd)
    max_k_for_model = max([int(cfg.block_size)] + [int(k) for k in test_block_sizes])
    max_seq_len = 2 * max_k_for_model + 1

    model_cfg = GPTConfig(
        block_size=cfg.block_size,
        vocab_size=total_vocab_size,
        n_layers=cfg.n_layers,
        n_heads=cfg.n_heads,
        n_embd=eff_n_embd,
        without_pos=cfg.without_pos,
        use_mlp=cfg.use_mlp,
        max_seq_len=max_seq_len,
    )
    model = GPT(model_cfg).to(device)

    if cfg.use_compile and hasattr(torch, "compile"):
        try:
            model = torch.compile(model, mode="max-autotune")
            print("torch.compile enabled")
        except Exception as e:
            print(f"torch.compile failed, continuing uncompiled: {e}")

    optimizer = create_optimizer(model, weight_decay=cfg.weight_decay, lr=cfg.learning_rate)
    os.makedirs(cfg.save_dir, exist_ok=True)

    # W&B init
    if wandb.run is not None:
        wandb.finish()

    mix_p_eff = _mix_p_effective(cfg)

    wandb_cfg = asdict(cfg)
    wandb_cfg.update(
        dict(
            total_vocab_size=total_vocab_size,
            sep_id=sep_id,
            max_seq_len=max_seq_len,
            max_k_for_model=max_k_for_model,
            grad_accum_steps=grad_accum_steps,
            amp_dtype=str(amp_dtype) if amp_dtype is not None else "none",
            device=str(device),
            test_block_sizes=list(test_block_sizes),
            primary_test_k=int(primary_test_k),
            n_embd_effective=int(eff_n_embd),
            mix_bins_effective=int(mix_p_eff),  # 1 if mix disabled, else 8 by default
        )
    )

    run = wandb.init(
        project=cfg.wandb_project,
        entity=cfg.wandb_entity,
        group=cfg.wandb_group,
        name=make_wandb_run_name(cfg),
        config=wandb_cfg,
        mode=cfg.wandb_mode,
    )

    run.define_metric("iter")
    run.define_metric("train/*", step_metric="iter")
    run.define_metric("test/*", step_metric="iter")
    run.define_metric("lr", step_metric="iter")
    run.define_metric("plateau/*", step_metric="iter")
    run.define_metric("ll/iter")
    run.define_metric("ll/*", step_metric="ll/iter")

    # Plateau state
    ll_iters_hist, ll_loss_hist = [], []
    plateau_hits = 0
    plateau_reached = False
    plateau_reached_iter = None
    stop_after_log_iter = None

    thresholds = [0.9, 0.99, 0.999, 0.9999, 1.0]
    first_iter_at = {thr: None for thr in thresholds}

    last_test_metrics = {}
    iters_done = cfg.max_iters
    last_log_t = time.time()

    with get_sdpa_context():
        for itr in trange(cfg.max_iters, desc="training"):
            optimizer.zero_grad(set_to_none=True)
            loss_accum = torch.zeros((), device=device)

            do_log = (itr % cfg.log_interval == 0)
            if do_log:
                token_correct = torch.zeros((), device=device)
                sample_correct = torch.zeros((), device=device)
                token_total = 0
                sample_total = 0

            for _ in range(grad_accum_steps):
                # TRAINING BATCH uses mixing if cfg.with_mixing=True
                mix_p_eff = _mix_p_effective(cfg)
                batch = get_batch(
                    batch_size=cfg.micro_batch_size,
                    device=device,
                    vocab_n=cfg.vocab_n,
                    block_size=cfg.block_size,
                    allow_duplicates=cfg.allow_duplicates,
                    dup_mixture_p=mix_p_eff,
                    use_dup_mixture=True,
                )

                if use_amp:
                    with get_autocast_context(device, amp_dtype):
                        logits, loss = model(batch, return_full_logits=False, block_size=cfg.block_size)
                else:
                    logits, loss = model(batch, return_full_logits=False, block_size=cfg.block_size)

                if do_log:
                    with torch.no_grad():
                        targets = batch[:, cfg.block_size + 1 :]
                        preds = logits.detach().argmax(dim=-1)
                        token_correct += (preds == targets).sum()
                        sample_correct += (preds == targets).all(dim=1).sum()
                        token_total += targets.numel()
                        sample_total += targets.size(0)

                loss_to_back = loss / grad_accum_steps
                if scaler.is_enabled():
                    scaler.scale(loss_to_back).backward()
                else:
                    loss_to_back.backward()

                loss_accum += loss.detach()

            lr = get_lr(itr, cfg)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            if scaler.is_enabled():
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            # ---- log + plateau check ----
            if do_log:
                train_loss = float((loss_accum / grad_accum_steps).item())
                train_token_acc = float((token_correct / max(token_total, 1)).item())
                train_sample_acc = float((sample_correct / max(sample_total, 1)).item())

                steps_done = itr + 1
                ll_iter = _safe_log10(steps_done)
                ll_loss = _safe_log10(train_loss)
                ll_lr = _safe_log10(lr)

                ll_iters_hist.append(ll_iter)
                ll_loss_hist.append(ll_loss)

                slope = None
                delta_ll_loss = None
                plateau_now = 0

                if (not plateau_reached) and (len(ll_iters_hist) >= cfg.plateau_window_logs) and (steps_done >= cfg.plateau_min_iters):
                    xs = ll_iters_hist[-cfg.plateau_window_logs:]
                    ys = ll_loss_hist[-cfg.plateau_window_logs:]
                    slope = _linreg_slope(xs, ys)
                    delta_ll_loss = ys[0] - ys[-1]
                    is_improving = (delta_ll_loss >= 0.0)
                    plateau_now = int(
                        is_improving and (delta_ll_loss < cfg.plateau_log10_loss_delta) and (slope > -cfg.plateau_slope_threshold)
                    )
                    plateau_hits = (plateau_hits + 1) if plateau_now else 0
                    if plateau_hits >= cfg.plateau_patience_logs:
                        plateau_reached = True
                        plateau_reached_iter = itr
                        stop_after_log_iter = itr + cfg.log_interval * max(int(cfg.plateau_extra_logs), 0)
                        print(
                            f"üü® plateau detected @ itr={itr} (slope={slope:.4f}, Œîlog10(loss)={delta_ll_loss:.4f}); "
                            f"will stop after logging itr={stop_after_log_iter}"
                        )

                now = time.time()
                dt = now - last_log_t
                last_log_t = now
                print(
                    f"itr: {itr} lr: {lr:.3e} train loss: {train_loss:.6f} "
                    f"train token_acc: {train_token_acc:.4f} train sample_acc: {train_sample_acc:.4f} (dt={dt:.2f}s)"
                )

                log_dict = {
                    "iter": steps_done,
                    "train/loss": train_loss,
                    "train/token_acc": train_token_acc,
                    "train/sample_acc": train_sample_acc,
                    "lr": lr,
                    "ll/iter": ll_iter,
                    "ll/train_loss": ll_loss,
                    "ll/lr": ll_lr,
                    "plateau/now": plateau_now,
                    "plateau/hits": plateau_hits,
                    "plateau/reached": int(plateau_reached),
                    "mix/bins_effective": int(_mix_p_effective(cfg)),
                    "mix/enabled": int(cfg.with_mixing),
                }
                if slope is not None:
                    log_dict["plateau/loglog_slope"] = float(slope)
                if delta_ll_loss is not None:
                    log_dict["plateau/delta_log10_loss_window"] = float(delta_ll_loss)

                # Test metrics (original eval distribution)
                for k in test_block_sizes:
                    test_loss, test_token_acc, test_sample_acc = eval_batch_metrics(
                        model=model,
                        device=device,
                        cfg=cfg,
                        block_size=int(k),
                        use_amp=use_amp,
                        amp_dtype=amp_dtype,
                        batch_size=cfg.micro_batch_size,
                    )
                    last_test_metrics[int(k)] = {"loss": test_loss, "token_acc": test_token_acc, "sample_acc": test_sample_acc}
                    log_dict[f"test/K{int(k)}/loss"] = test_loss
                    log_dict[f"test/K{int(k)}/token_acc"] = test_token_acc
                    log_dict[f"test/K{int(k)}/sample_acc"] = test_sample_acc

                # NEW: sample accuracy by exact unique-count V on PRIMARY test K
                if getattr(cfg, "log_unique_count_metrics", True):
                    ubs = cfg.unique_count_eval_batch_size if cfg.unique_count_eval_batch_size is not None else cfg.micro_batch_size
                    uniq_acc = eval_sample_acc_by_unique_count(
                        model=model,
                        device=device,
                        cfg=cfg,
                        block_size=int(primary_test_k),
                        use_amp=use_amp,
                        amp_dtype=amp_dtype,
                        batch_size=int(ubs),
                    )
                    for V, sa in uniq_acc.items():
                        log_dict[f"test/K{int(primary_test_k)}/uniqueV{int(V)}/sample_acc"] = float(sa)

                # Threshold first-iter tracking using PRIMARY test K
                pk_metrics = last_test_metrics.get(primary_test_k, None)
                if pk_metrics is not None:
                    pk_sa = float(pk_metrics["sample_acc"])
                    for thr in thresholds:
                        if first_iter_at[thr] is None and pk_sa >= thr:
                            first_iter_at[thr] = int(steps_done)

                run.log(log_dict, step=steps_done)

                if plateau_reached and (stop_after_log_iter is not None) and (itr >= stop_after_log_iter):
                    iters_done = steps_done
                    break

            # ---- checkpointing ----
            if cfg.ckpt_interval and (itr > 0) and (itr % cfg.ckpt_interval == 0):
                steps_done = itr + 1
                ckpt_name = make_save_filename("Checkpoint", cfg, steps_done)
                ckpt_path = os.path.join(cfg.save_dir, ckpt_name)
                torch.save(
                    {
                        "itr": itr,
                        "iters_done": steps_done,
                        "plateau_reached": plateau_reached,
                        "plateau_reached_iter": plateau_reached_iter,
                        "train_config": wandb_cfg,
                        "model_config": vars(model_cfg),
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                    },
                    ckpt_path,
                )
                print(f"saved checkpoint: {ckpt_path}")
                maybe_copy_to_drive(ckpt_path)

    # Final eval snapshot
    final_train_loss, final_train_token_acc, final_train_sample_acc = eval_batch_metrics(
        model=model,
        device=device,
        cfg=cfg,
        block_size=int(cfg.block_size),
        use_amp=use_amp,
        amp_dtype=amp_dtype,
        batch_size=cfg.micro_batch_size,
    )

    final_test_snapshot = {}
    for k in test_block_sizes:
        tl, tta, tsa = eval_batch_metrics(
            model=model,
            device=device,
            cfg=cfg,
            block_size=int(k),
            use_amp=use_amp,
            amp_dtype=amp_dtype,
            batch_size=cfg.micro_batch_size,
        )
        final_test_snapshot[int(k)] = {"loss": tl, "token_acc": tta, "sample_acc": tsa}

    # Final save
    final_name = make_save_filename("Final", cfg, iters_done)
    final_path = os.path.join(cfg.save_dir, final_name)
    torch.save(
        {
            "iters_done": iters_done,
            "plateau_reached": plateau_reached,
            "plateau_reached_iter": plateau_reached_iter,
            "train_config": wandb_cfg,
            "model_config": vars(model_cfg),
            "model": model.state_dict(),
        },
        final_path,
    )
    print(f"saved final model: {final_path}")
    maybe_copy_to_drive(final_path)

    run.summary["iters_done"] = iters_done
    run.summary["plateau_reached"] = bool(plateau_reached)
    run.summary["plateau_reached_iter"] = int(plateau_reached_iter) if plateau_reached_iter is not None else None
    run.finish()

    # Build result dict for grid summary
    res: Dict[str, Any] = dict(
        vocab_n=int(cfg.vocab_n),
        block_size=int(cfg.block_size),
        n_layers=int(cfg.n_layers),
        n_heads=int(cfg.n_heads),
        without_pos=bool(cfg.without_pos),
        use_mlp=bool(cfg.use_mlp),
        allow_duplicates=bool(cfg.allow_duplicates),
        with_mixing=bool(cfg.with_mixing),
        mix_bins_effective=int(_mix_p_effective(cfg)),
        vocab_over_embd=float(cfg.vocab_over_embd) if cfg.vocab_over_embd is not None else None,
        n_embd_effective=int(eff_n_embd),
        test_block_sizes=",".join(str(int(k)) for k in test_block_sizes),
        primary_test_k=int(primary_test_k),
        iters_done=int(iters_done),
        plateau_reached=bool(plateau_reached),
        plateau_reached_iter=int(plateau_reached_iter) if plateau_reached_iter is not None else None,
        final_model_path=str(final_path),
        wall_time_sec=float(time.time() - t0),

        final_train_loss=float(final_train_loss),
        final_train_token_acc=float(final_train_token_acc),
        final_train_sample_acc=float(final_train_sample_acc),
    )

    for k, m in final_test_snapshot.items():
        res[f"final_test_loss_K{k}"] = float(m["loss"])
        res[f"final_test_token_acc_K{k}"] = float(m["token_acc"])
        res[f"final_test_sample_acc_K{k}"] = float(m["sample_acc"])

    for thr in thresholds:
        res[f"iter_to_sample_acc_{thr}"] = first_iter_at[thr]

    return res

# -------------------------
# Grid runner (NEW: with_mixing is part of the grid)
# -------------------------
def _grid_axis(values: Optional[Iterable[Any]], default: Any) -> list:
    """Materialize one optional grid axis into a list.

    ``None`` means "this axis is not swept": it collapses to the single value
    ``bool(default)`` (taken from the base config). Any other iterable is copied
    into a list so it can be iterated more than once.
    """
    return [bool(default)] if values is None else list(values)


def _failed_grid_result(cfg: Any, test_block_sizes: tuple, err: BaseException) -> Dict[str, Any]:
    """Build the placeholder summary row for a grid point whose training raised ``err``.

    The row has the same keys as a failed row always had: identifying config
    values, ``None`` for every metric, and ``error=repr(err)``.
    """

    def _best_effort(fn):
        # The failure may originate in the config itself, so derived fields are
        # computed best-effort: a second exception here must not abort the grid.
        try:
            return fn()
        except Exception:
            return None

    res: Dict[str, Any] = {
        "vocab_n": int(cfg.vocab_n),
        "block_size": int(cfg.block_size),
        "n_layers": int(cfg.n_layers),
        "n_heads": int(cfg.n_heads),
        "without_pos": bool(cfg.without_pos),
        "use_mlp": bool(cfg.use_mlp),
        "allow_duplicates": bool(cfg.allow_duplicates),
        "with_mixing": bool(cfg.with_mixing),
        "mix_bins_effective": _best_effort(lambda: int(_mix_p_effective(cfg))),
        "vocab_over_embd": float(cfg.vocab_over_embd) if cfg.vocab_over_embd is not None else None,
        "n_embd_effective": _best_effort(
            lambda: int(compute_effective_n_embd(cfg.vocab_n, cfg.n_heads, cfg.vocab_over_embd, cfg.n_embd))
        ),
        "test_block_sizes": ",".join(str(int(k)) for k in test_block_sizes),
        "primary_test_k": int(max(test_block_sizes)) if test_block_sizes else None,
        "iters_done": None,
        "plateau_reached": None,
        "plateau_reached_iter": None,
        "final_model_path": None,
        "final_train_loss": None,
        "final_train_token_acc": None,
        "final_train_sample_acc": None,
        "error": repr(err),
    }
    for thr in (0.9, 0.99, 0.999, 0.9999, 1.0):
        res[f"iter_to_sample_acc_{thr}"] = None
    return res


def run_grid(
    base_cfg: TrainConfig,
    vocab_sizes: Iterable[int],
    layer_counts: Iterable[int],
    block_sizes: Iterable[int],
    without_pos_flags: Iterable[bool],
    vocab_over_embd_list: Iterable[float],
    allow_duplicates_flags: Optional[Iterable[bool]] = None,
    use_mlp_flags: Optional[Iterable[bool]] = None,
    with_mixing_flags: Optional[Iterable[bool]] = None,
    print_summary: bool = True,
):
    """Train one model per point of a hyper-parameter grid and collect a summary.

    For every combination of the axes, a deep copy of ``base_cfg`` gets the axis
    values written into it and is passed to ``train_sorting_gpt``. A run that
    raises an ``Exception`` does not stop the grid. It is logged and recorded as a
    row whose ``error`` field is ``repr(exc)`` and whose metrics are ``None``.
    ``KeyboardInterrupt`` is not an ``Exception``, so it still stops the grid.

    Args:
        base_cfg: template config; copied per run, never mutated.
        vocab_sizes, block_sizes, layer_counts: swept largest-first.
        without_pos_flags: values for ``cfg.without_pos``.
        vocab_over_embd_list: values for ``cfg.vocab_over_embd``; swept ascending.
        allow_duplicates_flags, use_mlp_flags, with_mixing_flags: optional axes.
            ``None`` means "use the single value from ``base_cfg``".
        print_summary: clear the console and display a summary table at the end.

    Returns:
        A sorted pandas DataFrame when ``print_summary`` is True and pandas is
        available. Otherwise, the list of per-run result dicts.
    """
    import itertools
    import traceback

    # Materialize every axis up front. Each axis is iterated many times, so a
    # one-shot iterator (e.g. a generator) must not be consumed after the first
    # outer iteration. The sort order also fixes the run order.
    vocab_sizes = sorted(vocab_sizes, reverse=True)
    block_sizes = sorted(block_sizes, reverse=True)
    layer_counts = sorted(layer_counts, reverse=True)
    vocab_over_embd_list = sorted(vocab_over_embd_list)
    without_pos_flags = list(without_pos_flags)
    allow_duplicates_flags = _grid_axis(allow_duplicates_flags, base_cfg.allow_duplicates)
    use_mlp_flags = _grid_axis(use_mlp_flags, base_cfg.use_mlp)
    with_mixing_flags = _grid_axis(with_mixing_flags, base_cfg.with_mixing)

    # Default test block sizes = largest trained K in this grid.
    if base_cfg.test_block_sizes is not None and len(base_cfg.test_block_sizes) > 0:
        default_test_block_sizes = tuple(int(k) for k in base_cfg.test_block_sizes)
    elif block_sizes:
        default_test_block_sizes = (max(int(k) for k in block_sizes),)
    else:
        default_test_block_sizes = ()  # empty grid: no runs will happen

    results: list[Dict[str, Any]] = []

    # itertools.product varies the LAST axis fastest, which matches the order of
    # the equivalent nested for-loops (vocab outermost, vocab_over_embd innermost).
    grid = itertools.product(
        vocab_sizes,
        block_sizes,
        layer_counts,
        without_pos_flags,
        use_mlp_flags,
        allow_duplicates_flags,
        with_mixing_flags,
        vocab_over_embd_list,
    )
    for N, K, L, npos, mlp_on, dup, mix_on, r in grid:
        cfg = copy.deepcopy(base_cfg)
        cfg.vocab_n = int(N)
        cfg.block_size = int(K)
        cfg.n_layers = int(L)
        cfg.without_pos = bool(npos)
        cfg.use_mlp = bool(mlp_on)
        cfg.allow_duplicates = bool(dup)
        cfg.with_mixing = bool(mix_on)
        cfg.vocab_over_embd = float(r)
        cfg.test_block_sizes = default_test_block_sizes

        try:
            res = train_sorting_gpt(cfg)
            res["error"] = None
        except Exception as e:
            # Keep sweeping the rest of the grid, but leave a visible trace of
            # the failure instead of silently dropping it.
            print(
                f"[run_grid] run failed (vocab_n={cfg.vocab_n}, block_size={cfg.block_size}, "
                f"n_layers={cfg.n_layers}, use_mlp={cfg.use_mlp}, with_mixing={cfg.with_mixing}, "
                f"vocab_over_embd={cfg.vocab_over_embd}): {e!r}"
            )
            traceback.print_exc()
            res = _failed_grid_result(cfg, default_test_block_sizes, e)

        results.append(res)

    if print_summary:
        clear_console()
        print("===== GRID SUMMARY =====")
        if pd is not None:
            try:
                from IPython.display import display
            except Exception:
                display = print

            df = pd.DataFrame(results)

            # Show "never reached" / "not available" explicitly instead of NaN.
            null_cols = [c for c in df.columns if c.startswith("iter_to_sample_acc_")] + ["iters_done", "plateau_reached_iter"]
            for c in null_cols:
                if c in df.columns:
                    df[c] = df[c].where(~df[c].isna(), "NULL")

            pd.set_option("display.max_rows", None)
            pd.set_option("display.max_columns", None)
            pd.set_option("display.width", 250)

            sort_cols = [c for c in ["vocab_n", "block_size", "n_layers", "without_pos", "use_mlp", "allow_duplicates", "with_mixing", "n_embd_effective"] if c in df.columns]
            if sort_cols:
                df = df.sort_values(sort_cols).reset_index(drop=True)

            display(df)
            return df
        else:
            for row in results:
                print(row)
            return results

    return results


  _C._set_float32_matmul_precision(precision)


In [3]:
# =========================
# choose config + run grid
# =========================
import os, time

STAMP = time.strftime("%Y%m%d_%H%M%S")
PROJECT = globals().get("PROJECT", "sortgpt")
GROUP = f"grid_{STAMP}"

GRID_ROOT = f"./grid_outputs_{GROUP}"
WANDB_DIR = os.path.join(GRID_ROOT, "wandb")
SAVE_DIR  = os.path.join(GRID_ROOT, "saved_models")

os.makedirs(WANDB_DIR, exist_ok=True)
os.makedirs(SAVE_DIR, exist_ok=True)

# Put all wandb local artifacts for the whole grid under one directory
os.environ["WANDB_DIR"] = WANDB_DIR
os.environ["WANDB_CACHE_DIR"] = os.path.join(GRID_ROOT, "wandb_cache")
os.environ["WANDB_CONFIG_DIR"] = os.path.join(GRID_ROOT, "wandb_config")

print("GRID_ROOT =", GRID_ROOT)
print("WANDB_DIR =", WANDB_DIR)
print("SAVE_DIR  =", SAVE_DIR)
print("PROJECT   =", PROJECT)
print("GROUP     =", GROUP)

# OPTIONAL: if Drive is mounted, copy into a per-grid folder
if globals().get("GDRIVE_MOUNTED", False) and globals().get("GDRIVE_BASE_DIR", None):
    GDRIVE_MODEL_DIR = os.path.join(GDRIVE_BASE_DIR, GROUP)
    os.makedirs(GDRIVE_MODEL_DIR, exist_ok=True)
    print("‚úÖ Drive destination set for this grid:")
    print("   GDRIVE_MODEL_DIR =", GDRIVE_MODEL_DIR)
else:
    print("Drive not mounted (or skipped). No Drive copying will occur.")

base_cfg = TrainConfig(
    vocab_n=256,              # overridden by grid
    block_size=16,            # overridden by grid

    allow_duplicates=True,    # MUST BE TRUE if using mixing flags (checked below)
    with_mixing=False,        # overridden by grid axis below
    mixing_bins=8,            # when with_mixing=True -> effective p=8

    n_layers=2,               # overridden by grid
    n_heads=1,

    n_embd=64,                # used only if vocab_over_embd=None (grid sets vocab_over_embd)
    vocab_over_embd=None,

    without_pos=False,        # overridden by grid
    use_mlp=True,             # overridden by grid

    test_block_sizes=None,    # default => largest K in the grid

    max_iters=60000,
    warmup_iters=200,
    learning_rate=1e-4,
    min_lr=1e-6,
    weight_decay=0.0,

    micro_batch_size=4096,
    effective_batch_size=4096,

    log_interval=250,
    ckpt_interval=1000000,    # huge => usually no checkpoints; only finals
    save_dir=SAVE_DIR,

    plateau_window_logs=40,
    plateau_min_iters=20000,
    plateau_slope_threshold=0.02,
    plateau_log10_loss_delta=0.02,
    plateau_patience_logs=2,
    plateau_extra_logs=1,

    # Unique-count metrics
    log_unique_count_metrics=True,
    unique_count_eval_batch_size=None,   # set smaller (e.g. 2048) if you want faster logs

    wandb_project=PROJECT,
    wandb_group=GROUP,
)

# mix0 => with_mixing=False => effective p=1
# mix1 => with_mixing=True  => effective p=8
WITH_MIXING_FLAGS = [False, True]

# Duplicate-mixing runs require duplicates to be allowed. run_grid takes
# allow_duplicates from base_cfg (no allow_duplicates_flags axis is passed
# below). Fail fast here rather than after part of the grid has trained.
if any(WITH_MIXING_FLAGS) and not base_cfg.allow_duplicates:
    raise ValueError(
        "with_mixing=True is in the grid but base_cfg.allow_duplicates is False; "
        "set allow_duplicates=True or drop True from WITH_MIXING_FLAGS."
    )

df_summary = run_grid(
    base_cfg,
    vocab_sizes=[128],
    layer_counts=[1, 2, 3],
    block_sizes=[16],
    without_pos_flags=[False],
    vocab_over_embd_list=[2, 4, 8],

    use_mlp_flags=[True, False],

    with_mixing_flags=WITH_MIXING_FLAGS,

    print_summary=True,
)

GRID_ROOT = ./grid_outputs_grid_20260107_223822
WANDB_DIR = ./grid_outputs_grid_20260107_223822/wandb
SAVE_DIR  = ./grid_outputs_grid_20260107_223822/saved_models
PROJECT   = sortgpt
GROUP     = grid_20260107_223822
Drive not mounted (or skipped). No Drive copying will occur.
using device: cuda


training:   0%|          | 0/60000 [00:00<?, ?it/s]

itr: 0 lr: 4.975e-07 train loss: 4.848498 train token_acc: 0.0437 train sample_acc: 0.0000 (dt=1.22s)


training:   0%|          | 264/60000 [00:05<14:51, 67.01it/s]

itr: 250 lr: 1.000e-04 train loss: 3.954088 train token_acc: 0.2191 train sample_acc: 0.0000 (dt=3.86s)


training:   1%|          | 512/60000 [00:08<14:27, 68.56it/s]

itr: 500 lr: 9.999e-05 train loss: 2.686578 train token_acc: 0.8075 train sample_acc: 0.0320 (dt=3.65s)


training:   1%|‚ñè         | 763/60000 [00:12<14:28, 68.22it/s]

itr: 750 lr: 9.998e-05 train loss: 1.670355 train token_acc: 0.9392 train sample_acc: 0.4102 (dt=3.65s)


training:   2%|‚ñè         | 1013/60000 [00:16<14:24, 68.25it/s]

itr: 1000 lr: 9.996e-05 train loss: 0.940036 train token_acc: 0.9726 train sample_acc: 0.6702 (dt=3.65s)


training:   2%|‚ñè         | 1263/60000 [00:19<14:20, 68.23it/s]

itr: 1250 lr: 9.992e-05 train loss: 0.500542 train token_acc: 0.9841 train sample_acc: 0.7917 (dt=3.65s)


training:   3%|‚ñé         | 1513/60000 [00:23<14:16, 68.26it/s]

itr: 1500 lr: 9.988e-05 train loss: 0.288355 train token_acc: 0.9884 train sample_acc: 0.8469 (dt=3.65s)


training:   3%|‚ñé         | 1741/60000 [00:26<14:56, 65.00it/s]


KeyboardInterrupt: 