In [None]:
import subprocess, sys, os

# --- Step 1: pin the Hugging Face stack used by the rest of the notebook ---
_BANNER = "=" * 60
print(_BANNER)
print("STEP 1: Installing correct dependencies")
print(_BANNER)

# Every install goes through the kernel's own interpreter so the packages land
# in the environment this notebook is actually running in.
_PIP_INSTALL = [sys.executable, "-m", "pip", "install"]

# transformers is force-reinstalled on its own, without its dependency tree ...
subprocess.check_call(
    _PIP_INSTALL + ["transformers==4.44.2", "--quiet", "--force-reinstall", "--no-deps"]
)
# ... and the runtime dependencies are then installed with explicit version ranges.
subprocess.check_call(
    _PIP_INSTALL + [
        "tokenizers>=0.19,<0.20", "huggingface-hub>=0.23,<0.25",
        "safetensors>=0.4", "regex", "tiktoken",
        "--quiet",
    ]
)
print("Dependencies installed.")

print("\n" + "=" * 60)
print("STEP 2: Imports & Verification")
print("=" * 60)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint_util
import numpy as np
import psutil
import gc
import math
import re
import time

import transformers
# Report the interpreter / library versions actually loaded in this kernel.
print(f"Python:       {sys.version.split()[0]}")
print(f"PyTorch:      {torch.__version__}")
print(f"Transformers: {transformers.__version__}")
# Fail fast if the pinned transformers release from Step 1 is not the one imported.
assert transformers.__version__.startswith("4.44")

# Importing the GPT-2 classes doubles as a check that the pinned stack is importable.
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
print("GPT2 imports: ✓")
print(f"CUDA avail:   {torch.cuda.is_available()} (expected: False)")

# Global PyTorch settings for the whole notebook: 2 intra-op threads, fixed seed.
# NOTE(review): these literals duplicate CONFIG['num_threads'] / CONFIG['seed'],
# which are defined further down in this cell — keep them in sync.
torch.set_num_threads(2)
torch.manual_seed(42)

def read_container_memory_limit(candidate_paths=(
        '/sys/fs/cgroup/memory/memory.limit_in_bytes',  # cgroup v1 location
        '/sys/fs/cgroup/memory.max',                    # cgroup v2 location
)):
    """Return a human-readable container memory limit.

    The candidate files are tried in order; the first one that can be read and
    parsed wins.

    Args:
        candidate_paths: iterable of file paths to probe.

    Returns:
        ``"<x.xx> GB"`` for a numeric limit, ``"unlimited"`` when the file holds
        ``max`` or an effectively unbounded number, and ``"unknown"`` when no
        candidate could be read or parsed.
    """
    for path in candidate_paths:
        try:
            with open(path, 'r') as f:
                raw = f.read().strip()
        except OSError:
            # Missing file, permission denied, etc. — try the next location
            # instead of aborting the whole cell.
            continue
        if raw == 'max':
            return 'unlimited'
        try:
            limit_bytes = int(raw)
        except ValueError:
            continue
        # An "unset" limit can be reported as a number close to 2**63; showing
        # that as billions of GB would be misleading.
        if limit_bytes >= 1 << 60:
            return 'unlimited'
        return f"{limit_bytes / 1e9:.2f} GB"
    return 'unknown'


print(f"Container memory limit: {read_container_memory_limit()}")

print(f"\n" + "=" * 60)
print("STEP 3: FlashLM v3 Configuration (d_model=256)")
print("=" * 60)

# Single source of truth for every hyper-parameter used by the later cells.
# Keys are read by name elsewhere in the notebook, so do not rename them.
CONFIG = {
    # Architecture — d_model reduced from 512 to 256 for CPU speed
    'd_model': 256,
    'vocab_size': 50257,            # size of the GPT-2 tokenizer vocabulary
    'n_recursions': 2,              # how many times the shared block is applied
    'glu_expansion': 2.67,          # GLU inner width as a multiple of d_model
    'deep_supervision_steps': [2],  # NOTE(review): presumably recursion indices to supervise — confirm against the training loop
    
    # Training schedule
    # Each phase covers optimizer steps [start, end); seq_len doubles from one
    # phase to the next (64 -> 512) and only phase 1a freezes the embeddings.
    'total_steps': 10000,
    'phase_schedule': {
        '1a': {'start': 0,    'end': 1000,  'seq_len': 64,  'grad_accum': 8,  'freeze_embed': True},
        '1b': {'start': 1000, 'end': 4000,  'seq_len': 128, 'grad_accum': 16, 'freeze_embed': False},
        '1c': {'start': 4000, 'end': 7000,  'seq_len': 256, 'grad_accum': 32, 'freeze_embed': False},
        '1d': {'start': 7000, 'end': 10000, 'seq_len': 512, 'grad_accum': 32, 'freeze_embed': False},
    },
    'batch_size': 4,
    
    # Optimizers
    # NOTE(review): the normuon_* / adamw_* split suggests two optimizers over
    # different parameter groups — the grouping is not defined in this cell.
    'normuon_lr': 0.02,
    'normuon_momentum': 0.95,
    'normuon_beta2': 0.999,
    'adamw_lr': 3e-4,
    'adamw_betas': (0.9, 0.95),
    'weight_decay': 0.1,
    'max_grad_norm': 1.0,
    
    # LR schedule (WSD)
    'warmup_steps': 200,
    'decay_start_step': 9000,
    
    # Regularization
    'label_smoothing': 0.1,
    'zloss_coeff': 1e-4,
    'dropout': 0.0,
    'ema_decay': 0.999,
    'ewa_checkpoints': [8000, 9000, 10000],
    
    # Rho-1
    'rho1_top_fraction': 0.4,
    'rho1_refresh_interval': 200,
    'rho1_start_step': 1000,
    
    # Dataset
    'dataset_name': 'HuggingFaceFW/fineweb-edu',
    'dataset_subset': 'sample-10BT',
    'n_docs': 30000,
    'train_split': 0.95,            # fraction of tokens used for training
    'reasoning_shift_step': 7000,
    
    # DBO
    'dbo_max_steps': 300,
    'dbo_blend': 0.7,
    'dbo_patience': 50,
    'dbo_eval_interval': 25,
    
    # Router
    'USE_ROUTER': True,
    'router_threshold': 0.5,
    
    # Hardware
    'device': 'cpu',
    'num_threads': 2,
    'seed': 42,
    'checkpoint_dir': './checkpoints',
    'log_interval': 50,
    'memory_log_interval': 100,
}
# Derived value: GLU hidden width, truncated to an integer (256 * 2.67 -> 683).
CONFIG['glu_inner_dim'] = int(CONFIG['d_model'] * CONFIG['glu_expansion'])

# Echo the headline settings so the cell output records what was used.
print(f"  d_model:        {CONFIG['d_model']}  (reduced from 512 for CPU speed)")
print(f"  vocab_size:     {CONFIG['vocab_size']}")
print(f"  n_recursions:   {CONFIG['n_recursions']}")
print(f"  glu_inner_dim:  {CONFIG['glu_inner_dim']}")
print(f"  total_steps:    {CONFIG['total_steps']}")
print(f"  deep_sup:       {CONFIG['deep_supervision_steps']}")
for pk, pv in CONFIG['phase_schedule'].items():
    print(f"  phase {pk}:      {pv}")

# Load the GPT-2 tokenizer and verify that encode -> decode round-trips a
# simple string unchanged.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
enc = tokenizer.encode("Hello world")
dec = tokenizer.decode(enc)
assert dec == "Hello world"
print(f"\nTokenizer: ✓")

# Resident set size of this kernel process, in MB (bytes / 1e6).
rss = psutil.Process().memory_info().rss / 1e6
print(f"Process RSS: {rss:.0f} MB")
print("=" * 60)
print("CELL 1 COMPLETE ✓")


STEP 1: Installing correct dependencies

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
  from .autonotebook import tqdm as notebook_tqdm
Dependencies installed.

STEP 2: Imports & Verification
Python:       3.11.10
PyTorch:      2.1.2+cu121
Transformers: 4.44.2
GPT2 imports: ✓
CUDA avail:   False (expected: False)
Container memory limit: 5.37 GB

STEP 3: FlashLM v3 Configuration (d_model=256)
  d_model:        256  (reduced from 512 for CPU speed)
  vocab_size:     50257
  n_recursions:   2
  glu_inner_dim:  683
  total_steps:    10000
  d

In [None]:
import torch
import psutil
import gc
import os
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Extract GPT-2's token embedding matrix, then drop the rest of the model.
print("Loading GPT-2 model for embedding extraction...")
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
W_full = gpt2.transformer.wte.weight.data.clone()  # (50257, 768) per the output below
print(f"Original embedding shape: {W_full.shape}")

del gpt2
gc.collect()
print("GPT-2 model deleted.")

# --- SVD truncation: keep the top d_model singular directions ---
print("\nComputing SVD...")
U, S, Vt = torch.linalg.svd(W_full, full_matrices=False)

d_target = CONFIG['d_model']
# The truncation rank must be a valid index range into the singular values.
assert 0 < d_target <= S.numel(), (
    f"d_model={d_target} must be in [1, {S.numel()}] to truncate this embedding"
)
# Coordinates of every token in the basis of the first d_target right-singular vectors.
W_projected = U[:, :d_target] * S[:d_target].unsqueeze(0)  # (vocab, d_target)

# Fraction of the total squared singular-value mass kept by the truncation.
total_variance = (S ** 2).sum()
retained_variance = (S[:d_target] ** 2).sum()
variance_ratio = (retained_variance / total_variance).item()

print(f"\nSVD Projection Results:")
print(f"  Projected shape:    {W_projected.shape}")
print(f"  Variance retained:  {variance_ratio:.4f} ({variance_ratio*100:.2f}%)")
# Last kept vs first dropped singular value. Indexed from d_target so the line
# stays correct (and cannot go out of range) when d_model is changed.
if d_target < S.numel():
    print(f"  SV {d_target} vs {d_target + 1}:      "
          f"{S[d_target - 1].item():.4f} vs {S[d_target].item():.4f}")

assert W_projected.shape == (CONFIG['vocab_size'], d_target)
assert variance_ratio > 0.70, f"Variance too low: {variance_ratio:.4f}"
print(f"  ✓ Shape correct")
print(f"  ✓ Variance > 70%")

# --- Cosine similarity preservation ---
# NOTE(review): only the FIRST token id of each word is compared; a word that
# the tokenizer splits into several pieces is represented by its first piece.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
cos = torch.nn.CosineSimilarity(dim=0)
print("\nCosine similarity test:")
for w1, w2 in [("king", "queen"), ("cat", "dog"), ("python", "code")]:
    id1, id2 = tokenizer.encode(w1)[0], tokenizer.encode(w2)[0]
    sim_o = cos(W_full[id1], W_full[id2]).item()
    sim_p = cos(W_projected[id1], W_projected[id2]).item()
    print(f"  '{w1}' vs '{w2}': orig={sim_o:.4f}, proj={sim_p:.4f}, delta={abs(sim_o-sim_p):.4f}")

# --- Save ---
# The file name encodes the rank; FlashLM looks for exactly this name.
embed_path = f"gpt2_embed_{d_target}.pt"
torch.save(W_projected, embed_path)
print(f"\nSaved to {embed_path} ({os.path.getsize(embed_path)/1e6:.1f} MB)")

# Release the large SVD factors before the next cell.
del W_full, U, S, Vt
gc.collect()

rss = psutil.Process().memory_info().rss / 1e6
print(f"Process RSS: {rss:.0f} MB")
print("=" * 60)
print("CELL 2 COMPLETE ✓")


Loading GPT-2 model for embedding extraction...
Original embedding shape: torch.Size([50257, 768])
GPT-2 model deleted.

Computing SVD...

SVD Projection Results:
  Projected shape:    torch.Size([50257, 256])
  Variance retained:  0.7114 (71.14%)
  SV 256 vs 257:      27.4301 vs 27.4088
  ✓ Shape correct
  ✓ Variance > 70%

Cosine similarity test:
  'king' vs 'queen': orig=0.2666, proj=0.4846, delta=0.2180
  'cat' vs 'dog': orig=0.3815, proj=0.5415, delta=0.1600
  'python' vs 'code': orig=0.3337, proj=0.4915, delta=0.1578

Saved to gpt2_embed_256.pt (51.5 MB)
Process RSS: 2223 MB
CELL 2 COMPLETE ✓


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import psutil
import time

D = CONFIG['d_model']  # 256

# ----- 3a: RMSNorm -----
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each feature vector is divided by ``sqrt(mean(x**2) + eps)`` and multiplied
    by a learnable per-feature gain ``scale`` (initialised to ones). There is
    no mean subtraction and no bias.
    """

    def __init__(self, d_model: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last axis; the output has the same shape."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        # eps sits inside the square root, so an all-zero vector stays finite.
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.scale

# Shape smoke test. `_x` is kept as a global because the BitLinear check below reuses it.
_x = torch.randn(2, 16, D)
assert RMSNorm(D)(_x).shape == (2, 16, D)
# RMSNorm's only parameter is the per-feature gain, hence D parameters.
print(f"RMSNorm:          params={D:,}")


# ----- 3b: BitLinear (Ternary Weights) -----
class BitLinear(nn.Module):
    """Linear layer whose weights are ternarised to {-1, 0, +1} in the forward pass.

    A full-precision ``weight`` of shape ``(out_features, in_features)`` is kept
    for the optimizer. On every forward call it is quantised with the threshold
    ``mean(|weight|)``: entries above the threshold become +1, entries below
    the negated threshold become -1, everything else becomes 0. No rescaling is
    applied to the ternary weights. Gradients reach ``weight`` through a
    straight-through estimator (the quantisation is treated as the identity in
    the backward pass).

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: when True, add a learnable bias initialised to zeros.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        nn.init.kaiming_normal_(self.weight, nonlinearity='linear')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the ternarised linear map to the last dimension of ``x``."""
        w = self.weight
        # The quantisation itself carries no gradient, so build it outside autograd.
        with torch.no_grad():
            threshold = w.abs().mean()
            w_ternary = torch.where(w.abs() > threshold, torch.sign(w), torch.zeros_like(w))
        # Straight-through estimator. `w - w.detach()` is exactly zero in the
        # forward pass, so the effective weights are *exactly* ternary (the
        # `w + (w_ternary - w)` form can be off by a rounding error), while the
        # backward pass still sees d(w_quantized)/d(w) = 1.
        w_quantized = w_ternary + (w - w.detach())
        return F.linear(x, w_quantized, self.bias)

# Shape smoke test on the shared `_x` batch; a bias-free D x D layer has D*D parameters.
_bl = BitLinear(D, D)
assert _bl(_x).shape == (2, 16, D)
print(f"BitLinear:        params={sum(p.numel() for p in _bl.parameters()):,}")


# ----- 3c: CausalConv -----
class CausalConv(nn.Module):
    """Grouped, dilated 1-D convolution along the sequence axis that never looks ahead.

    The input is laid out as ``(batch, seq, d_model)``. It is left-padded with
    ``(kernel_size - 1) * dilation`` zeros before an unpadded ``nn.Conv1d``, so
    the output has the same sequence length and position ``t`` depends only on
    positions ``<= t``.

    Args:
        d_model: number of input and output channels.
        kernel_size: number of taps per filter.
        dilation: spacing between taps.
        groups: number of channel groups (must divide ``d_model``, as required
            by ``nn.Conv1d``).
    """

    def __init__(self, d_model: int, kernel_size: int = 4, dilation: int = 1, groups: int = 16):
        super().__init__()
        # Amount of left padding that keeps the convolution causal.
        self.causal_pad = dilation * (kernel_size - 1)
        conv_options = dict(kernel_size=kernel_size, dilation=dilation, groups=groups, bias=False)
        self.conv = nn.Conv1d(d_model, d_model, **conv_options)
        nn.init.kaiming_normal_(self.conv.weight, nonlinearity='linear')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` of shape (batch, seq, d_model); the shape is preserved."""
        channels_first = x.transpose(1, 2)                 # Conv1d wants (batch, channels, seq)
        left_padded = F.pad(channels_first, (self.causal_pad, 0))
        convolved = self.conv(left_padded)
        return convolved.transpose(1, 2)

# Causality test: replace positions 5..7 of an 8-step input with fresh noise and
# check that the outputs at positions 0..4 are unaffected.
_cc = CausalConv(D, kernel_size=4, dilation=1, groups=16)
_t1 = torch.randn(1, 8, D)
_o1 = _cc(_t1)
_t2 = _t1.clone(); _t2[:, 5:, :] = torch.randn(1, 3, D)
_o2 = _cc(_t2)
assert torch.allclose(_o1[:, :5, :], _o2[:, :5, :], atol=1e-6), "CAUSALITY VIOLATION!"
print(f"CausalConv:       params={sum(p.numel() for p in _cc.parameters()):,}, causal=True")


# ----- 3d: ConvMixer -----
class ConvMixer(nn.Module):
    """Stack of three gated, causal, dilated convolution layers used as the token mixer.

    The layers use dilations 1, 4 and 64. Each one applies a pre-norm, then two
    parallel ``CausalConv`` branches: a sigmoid gate and a value. Their product
    is added back onto the layer input. A final ``BitLinear`` projection
    (initialised with std 0.01) maps the result to the output.

    ``receptive_field`` records how many positions (current one included) can
    influence an output: ``(kernel_size - 1) * (1 + 4 + 64) + 1``, i.e. 208 for
    the default ``kernel_size=4``.
    """

    def __init__(self, d_model: int, kernel_size: int = 4, groups: int = 16):
        super().__init__()
        dilations = (1, 4, 64)
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        for dilation in dilations:
            gate_conv = CausalConv(d_model, kernel_size, dilation=dilation, groups=groups)
            value_conv = CausalConv(d_model, kernel_size, dilation=dilation, groups=groups)
            self.layers.append(nn.ModuleDict({
                'conv_gate': gate_conv,
                'conv_value': value_conv,
            }))
            self.norms.append(RMSNorm(d_model))
        self.out_proj = BitLinear(d_model, d_model)
        nn.init.normal_(self.out_proj.weight, mean=0.0, std=0.01)
        self.receptive_field = sum(dilation * (kernel_size - 1) for dilation in dilations) + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix information along the sequence axis of ``x`` (batch, seq, d_model)."""
        hidden = x
        for convs, norm in zip(self.layers, self.norms):
            normed = norm(hidden)
            gate = torch.sigmoid(convs['conv_gate'](normed))
            value = convs['conv_value'](normed)
            # Gated residual update around the (un-normalised) layer input.
            hidden = hidden + gate * value
        return self.out_proj(hidden)

# `_mixer` and `_mixer_params` are reused by the benchmark / summary at the end of this cell.
_mixer = ConvMixer(D)
assert _mixer(torch.randn(2, 16, D)).shape == (2, 16, D)
# Causality: re-randomise positions 20..31 and check positions 0..19 are unchanged.
# With only 32 positions, the past taps of the dilation-64 layer fall entirely
# inside the zero padding, so this exercises the dilation-1 and -4 layers only.
_t1 = torch.randn(1, 32, D); _o1 = _mixer(_t1)
_t2 = _t1.clone(); _t2[:, 20:, :] = torch.randn(1, 12, D)
assert torch.allclose(_o1[:, :20, :], _mixer(_t2)[:, :20, :], atol=1e-6), "MIXER CAUSALITY VIOLATION!"
_mixer_params = sum(p.numel() for p in _mixer.parameters())
print(f"ConvMixer:        params={_mixer_params:,}, receptive_field={_mixer.receptive_field}, causal=True")


# ----- 3e: TernaryGLU -----
class TernaryGLU(nn.Module):
    """Gated feed-forward block built from three ``BitLinear`` projections.

    The hidden width is ``int(d_model * expansion)``. The gate branch is passed
    through a squared ReLU, multiplied elementwise with the up branch and
    projected back to ``d_model`` by ``down_proj`` (initialised with std 0.01).
    """

    def __init__(self, d_model: int, expansion: float = 2.67):
        super().__init__()
        hidden_dim = int(d_model * expansion)
        self.gate_proj = BitLinear(d_model, hidden_dim)
        self.up_proj = BitLinear(d_model, hidden_dim)
        self.down_proj = BitLinear(hidden_dim, d_model)
        nn.init.normal_(self.down_proj.weight, mean=0.0, std=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward map to the last dimension of ``x``."""
        gate_activation = torch.square(torch.relu(self.gate_proj(x)))  # ReLU(.)**2
        lifted = self.up_proj(x)
        return self.down_proj(gate_activation * lifted)

# `_glu` and `_glu_params` are reused by the benchmark / summary at the end of this cell.
_glu = TernaryGLU(D)
assert _glu(torch.randn(2, 16, D)).shape == (2, 16, D)
_glu_params = sum(p.numel() for p in _glu.parameters())
# NOTE(review): 2.67 is repeated here as a literal; it matches TernaryGLU's
# default expansion and CONFIG['glu_expansion'] and must be kept in sync by hand.
print(f"TernaryGLU:       params={_glu_params:,}, inner_dim={int(D*2.67)}")


# ----- 3f: Router -----
class Router(nn.Module):
    """Per-position scalar score in (0, 1) from a single linear unit + sigmoid.

    The weight starts at zero and the bias at 1.0, so a freshly constructed
    router returns ``sigmoid(1.0)`` (about 0.73) for every input.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.linear = nn.Linear(d_model, 1)
        # Overwrite nn.Linear's default init: input-independent output at start.
        nn.init.zeros_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (..., d_model) to scores of shape (..., 1)."""
        score_logits = self.linear(x)
        return score_logits.sigmoid()

# Shape smoke test; `_router_params` (D weights + 1 bias) feeds the summary below.
_router = Router(D)
assert _router(torch.randn(2, 16, D)).shape == (2, 16, 1)
_router_params = sum(p.numel() for p in _router.parameters())
print(f"Router:           params={_router_params:,}")


# ----- Speed benchmark -----
print("\n" + "=" * 60)
print(f"SPEED BENCHMARK (d_model={D})")
print("=" * 60)

_test = torch.randn(4, 64, D)
_N_BENCH_ITERS = 50


def _mean_forward_seconds(module, inputs, n_iters=_N_BENCH_ITERS):
    """Average wall-clock seconds of one ``module(inputs)`` call.

    One untimed warm-up call is made first so one-off costs (lazy
    initialisation, first-touch allocations) do not skew the average, and
    ``time.perf_counter`` is used because it is monotonic and high-resolution.
    Autograd is left enabled, as in a training forward pass.
    """
    module(inputs)
    started = time.perf_counter()
    for _ in range(n_iters):
        module(inputs)
    return (time.perf_counter() - started) / n_iters


mixer_t = _mean_forward_seconds(_mixer, _test)
glu_t = _mean_forward_seconds(_glu, _test)

combined = mixer_t + glu_t
print(f"  ConvMixer forward:  {mixer_t*1000:.1f}ms")
print(f"  TernaryGLU forward: {glu_t*1000:.1f}ms")
print(f"  Combined (1 rec):   {combined*1000:.1f}ms")
print(f"  Est. 2 recursions:  {combined*2*1000:.1f}ms")
# Rough rule of thumb used here: backward costs about twice the forward pass.
print(f"  Est. fwd+bwd:       {combined*2*3*1000:.1f}ms")

# ----- Summary -----
print("\n" + "=" * 60)
# One block = mixer + GLU + router + two RMSNorm gains (D parameters each).
total_block = _mixer_params + _glu_params + _router_params + D * 2
_embedding_params = CONFIG['vocab_size'] * D
print(f"Block total: {total_block:,} params")
print(f"  + embedding: {_embedding_params:,}")
print(f"  = ~{total_block + _embedding_params:,} unique params")

rss = psutil.Process().memory_info().rss / 1e6
print(f"\nProcess RSS: {rss:.0f} MB")
print("=" * 60)
print("CELL 3 COMPLETE ✓")


RMSNorm:          params=256
BitLinear:        params=65,536
CausalConv:       params=16,384, causal=True
ConvMixer:        params=164,608, receptive_field=208, causal=True
TernaryGLU:       params=524,544, inner_dim=683
Router:           params=257

SPEED BENCHMARK (d_model=256)
  ConvMixer forward:  3.2ms
  TernaryGLU forward: 4.6ms
  Combined (1 rec):   7.8ms
  Est. 2 recursions:  15.6ms
  Est. fwd+bwd:       46.8ms

Block total: 689,921 params
  + embedding: 12,865,792
  = ~13,555,713 unique params

Process RSS: 2715 MB
CELL 3 COMPLETE ✓


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint_util
import psutil
import gc
import os
import time


class RecursiveBlock(nn.Module):
    """One pre-norm residual block: token mixing, then a gated feed-forward, then an optional router.

    Args:
        d_model: feature width of the residual stream.
        glu_expansion: hidden-width multiplier forwarded to ``TernaryGLU``.
        use_router: when False no ``Router`` is created and the second return
            value of ``forward`` is ``None``.
    """

    def __init__(self, d_model: int, glu_expansion: float, use_router: bool):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.token_mixer = ConvMixer(d_model, kernel_size=4, groups=16)
        self.norm2 = RMSNorm(d_model)
        self.glu = TernaryGLU(d_model, expansion=glu_expansion)
        self.router = Router(d_model) if use_router else None

    def forward(self, x: torch.Tensor) -> tuple:
        """Return ``(block_output, router_scores_or_None)`` for input ``x``."""
        after_mixer = x + self.token_mixer(self.norm1(x))
        block_out = after_mixer + self.glu(self.norm2(after_mixer))
        if self.router is None:
            return block_out, None
        # The router scores the block output; it does not modify it.
        return block_out, self.router(block_out)


class FlashLM(nn.Module):
    """FlashLM v3: weight-shared recursive language model with position subsampling.

    A single ``RecursiveBlock`` is applied ``config['n_recursions']`` times to
    the token embeddings, followed by a final ``RMSNorm``. The output
    projection ``lm_head`` shares its weight tensor with ``embedding``.

    Training (module in train mode and ``targets`` given): the LM head and the
    loss are evaluated only on a random subset of the ``B*T`` positions
    (``train_pos_fraction``, 25% by default). The original author notes this
    gives a ~3-4x speedup on the LM-head matmul and cross-entropy.

    Eval (``targets`` given, eval mode): full logits and loss for every position.
    Generation (no ``targets``): full logits only.

    ``forward`` returns a dict with keys ``logits``, ``loss``, ``aux_losses``,
    ``router_scores`` and, when a loss is computed, ``main_loss`` / ``z_loss``
    (Python floats) plus ``n_sampled_positions`` in the training branch.

    NOTE(review): the eval branch applies the same label smoothing as training,
    so ``exp(main_loss)`` in eval mode is not a plain perplexity — confirm this
    is intended.

    The EMA shadow / backup weights live in plain dicts, so they are not part
    of ``state_dict()``.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        d = config['d_model']
        vocab_size = config['vocab_size']

        self.embedding = nn.Embedding(vocab_size, d)
        embed_path = f"gpt2_embed_{d}.pt"
        if os.path.exists(embed_path):
            # weights_only=True restricts unpickling to tensors / plain
            # containers, so a tampered cache file cannot execute code.
            W_proj = torch.load(embed_path, map_location='cpu', weights_only=True)
            if tuple(W_proj.shape) != (vocab_size, d):
                # A stale or foreign file would otherwise silently change the
                # shape of the tied embedding / LM head.
                raise ValueError(
                    f"{embed_path} has shape {tuple(W_proj.shape)}, "
                    f"expected {(vocab_size, d)}"
                )
            self.embedding.weight = nn.Parameter(W_proj)
            print(f"  Loaded pretrained embeddings from {embed_path}")
        else:
            print(f"  WARNING: {embed_path} not found, random init")

        self.block = RecursiveBlock(d, config['glu_expansion'], config['USE_ROUTER'])
        self.final_norm = RMSNorm(d)
        self.lm_head = nn.Linear(d, vocab_size, bias=False)
        # Weight tying: both modules now hold the very same Parameter object.
        self.lm_head.weight = self.embedding.weight

        # Empty placeholder: nothing in this class adds to it or reads from it.
        self.deep_sup_norms = nn.ModuleDict()

        self.ema_decay = config['ema_decay']
        self._ema_shadow = {}
        self._ema_backup = {}
        self.n_recursions = config['n_recursions']
        self.label_smoothing = config['label_smoothing']
        self.zloss_coeff = config['zloss_coeff']

        # Position subsampling ratio (only used during training)
        self.train_pos_fraction = 0.25  # compute loss on 25% of positions

    # ------------------------------------------------------------------ EMA
    def _init_ema(self):
        """(Re)start the EMA shadow from the current trainable parameters."""
        self._ema_shadow = {}
        for name, param in self.named_parameters():
            if param.requires_grad:
                self._ema_shadow[name] = param.data.clone()

    def _update_ema(self):
        """Blend the current parameters into the shadow: s = decay*s + (1-decay)*p."""
        decay = self.ema_decay
        for name, param in self.named_parameters():
            if param.requires_grad and name in self._ema_shadow:
                self._ema_shadow[name].mul_(decay).add_(param.data, alpha=1.0 - decay)

    def _apply_ema(self):
        """Swap the EMA weights into the model, keeping a backup of the live weights.

        If the EMA weights are already applied (a backup exists), the existing
        backup is kept: overwriting it would replace the real training weights
        with EMA weights and make ``_restore_from_ema`` a no-op.
        """
        take_backup = not self._ema_backup
        for name, param in self.named_parameters():
            if name in self._ema_shadow:
                if take_backup:
                    self._ema_backup[name] = param.data.clone()
                param.data.copy_(self._ema_shadow[name])

    def _restore_from_ema(self):
        """Undo ``_apply_ema``: copy the backed-up weights back and drop the backup."""
        for name, param in self.named_parameters():
            if name in self._ema_backup:
                param.data.copy_(self._ema_backup[name])
        self._ema_backup = {}

    # -------------------------------------------------------------- forward
    def _lm_losses(self, logits_2d, targets_1d):
        """Return ``(cross_entropy, z_loss)`` for ``(N, vocab)`` logits and ``(N,)`` targets."""
        main_loss = F.cross_entropy(
            logits_2d, targets_1d,
            label_smoothing=self.label_smoothing,
            reduction='mean'
        )
        # z-loss: penalise the squared log-partition function of each position.
        z_loss = self.zloss_coeff * logits_2d.float().logsumexp(dim=-1).pow(2).mean()
        return main_loss, z_loss

    def forward(self, input_ids, targets=None, use_checkpointing=True):
        """Run the model.

        Args:
            input_ids: integer tensor of shape (B, T).
            targets: optional integer tensor of shape (B, T); enables the loss.
            use_checkpointing: recompute the block in the backward pass instead
                of storing its activations (only effective in train mode).

        Returns:
            dict as described in the class docstring.
        """
        B, T = input_ids.shape
        x = self.embedding(input_ids)

        router_scores = []
        checkpoint_block = use_checkpointing and self.training
        for _ in range(self.n_recursions):
            if checkpoint_block:
                x, r = checkpoint_util.checkpoint(self.block, x, use_reentrant=False)
            else:
                x, r = self.block(x)
            if r is not None:
                router_scores.append(r)

        x = self.final_norm(x)  # (B, T, D)

        result = {
            'logits': None,
            'loss': None,
            'aux_losses': [],
            'router_scores': router_scores,
        }

        if targets is None:
            # === GENERATION MODE (no targets) ===
            result['logits'] = self.lm_head(x)
            return result

        n_sample = None
        if self.training:
            # === POSITION SUBSAMPLING (training only) ===
            # Flatten to B*T positions and keep a random subset; the LM head
            # is evaluated on that subset only.
            n_positions = B * T
            n_sample = max(1, int(n_positions * self.train_pos_fraction))
            indices = torch.randperm(n_positions, device=x.device)[:n_sample]
            logits_2d = self.lm_head(x.reshape(n_positions, -1)[indices])  # (n_sample, vocab)
            targets_1d = targets.reshape(n_positions)[indices]             # (n_sample,)
        else:
            # === FULL COMPUTATION (eval) ===
            logits = self.lm_head(x)
            result['logits'] = logits
            logits_2d = logits.reshape(-1, logits.size(-1))
            targets_1d = targets.reshape(-1)

        main_loss, z_loss = self._lm_losses(logits_2d, targets_1d)
        result['loss'] = main_loss + z_loss
        result['main_loss'] = main_loss.item()
        result['z_loss'] = z_loss.item()
        if n_sample is not None:
            # Number of positions the loss was computed on, for logging.
            result['n_sampled_positions'] = n_sample
        return result

    def count_parameters(self) -> dict:
        """Parameter counts per component; tied weights are counted once in the total."""
        embed_params = self.embedding.weight.numel()
        block_params = sum(p.numel() for p in self.block.parameters())
        norm_params = sum(p.numel() for p in self.final_norm.parameters())
        total = sum(p.numel() for p in self.parameters())
        return {
            'embedding (tied)': embed_params,
            'block (shared x2)': block_params,
            'final_norm': norm_params,
            'lm_head': '(tied)',
            'total_unique': total,
        }


# ============================================================
# Build the model and run forward / backward smoke tests
# ============================================================
print("Building FlashLM v3 (d=256, position subsampling)...")
model = FlashLM(CONFIG)

pc = model.count_parameters()
print("\nParameters:")
for k, v in pc.items():
    # Integer counts get thousands separators; the '(tied)' marker is printed as-is.
    if isinstance(v, int):
        print(f"  {k}: {v:,}")
    else:
        print(f"  {k}: {v}")
total_params = sum(p.numel() for p in model.parameters())
print(f"  Total: {total_params:,} ({total_params*4/1e6:.1f} MB float32)")
print(f"  Train position fraction: {model.train_pos_fraction} (25%)")

# --- Train-mode forward: loss on a random subset of positions ---
print("\nTrain mode forward...")
model.train()
_dummy_shape = (4, 64)
dummy_in = torch.randint(0, CONFIG['vocab_size'], _dummy_shape)
dummy_tgt = torch.randint(0, CONFIG['vocab_size'], _dummy_shape)
res = model(dummy_in, targets=dummy_tgt, use_checkpointing=False)
_n_total = _dummy_shape[0] * _dummy_shape[1]
_n_sampled = res['n_sampled_positions']
print(f"  Loss: {res['loss'].item():.4f}, CE: {res['main_loss']:.4f}")
print(f"  Positions sampled: {_n_sampled} / {_n_total} = {_n_sampled/_n_total*100:.0f}%")
assert res['loss'].dim() == 0
print("  ✓ Train forward OK")

# --- Eval-mode forward: full logits for every position ---
model.eval()
res_eval = model(dummy_in, targets=dummy_tgt, use_checkpointing=False)
print(f"\nEval mode forward...")
print(f"  Logits shape: {res_eval['logits'].shape}")
print(f"  Loss: {res_eval['loss'].item():.4f}")
assert res_eval['logits'].shape == (_dummy_shape[0], _dummy_shape[1], CONFIG['vocab_size'])
print("  ✓ Eval forward OK")

# --- Backward: a few representative parameters must receive a non-zero gradient ---
model.train()
res = model(dummy_in, targets=dummy_tgt, use_checkpointing=False)
res['loss'].backward()
_params_by_name = dict(model.named_parameters())
_must_have_grad = ('embedding.weight', 'block.glu.gate_proj.weight', 'final_norm.scale')
grads_ok = all(
    _params_by_name[n].grad is not None and _params_by_name[n].grad.norm().item() > 0
    for n in _must_have_grad
)
assert grads_ok
print("  ✓ Backward OK")

# ============================================================
# SPEED TEST
# ============================================================
print("\n" + "=" * 60)
print("SPEED TEST (with position subsampling)")
print("=" * 60)

model.train()

# Number of scheduled steps spent at each sequence length, taken from the
# phase schedule so the estimate follows CONFIG instead of hard-coded weights.
steps_per_seq_len = {}
for _phase in CONFIG['phase_schedule'].values():
    steps_per_seq_len[_phase['seq_len']] = (
        steps_per_seq_len.get(_phase['seq_len'], 0) + _phase['end'] - _phase['start']
    )
scheduled_steps = sum(steps_per_seq_len.values())

results_table = {}
for seq_len in sorted(steps_per_seq_len):
    _inp = torch.randint(0, CONFIG['vocab_size'], (CONFIG['batch_size'], seq_len))
    _tgt = torch.randint(0, CONFIG['vocab_size'], (CONFIG['batch_size'], seq_len))
    times = []
    for _ in range(10):
        model.zero_grad(set_to_none=True)
        t0 = time.perf_counter()  # monotonic, high-resolution timer
        r = model(_inp, targets=_tgt, use_checkpointing=False)
        r['loss'].backward()
        times.append(time.perf_counter() - t0)
    # Trimmed mean: drop the fastest and the slowest of the 10 runs.
    avg = sum(sorted(times)[1:-1]) / (len(times) - 2)
    results_table[seq_len] = avg
    n_pos = CONFIG['batch_size'] * seq_len
    # Same formula the model uses to pick how many positions to sample.
    n_sub = max(1, int(n_pos * model.train_pos_fraction))
    print(f"  seq={seq_len:>3d}: {avg:.3f}s ({avg*1000:.0f}ms) "
          f"[{n_sub}/{n_pos} positions]")

# Average time per step, weighted by how many steps run at each length.
# NOTE(review): each timing is ONE forward+backward on a batch of
# CONFIG['batch_size']; if a scheduled step performs `grad_accum` such
# micro-batches, the wall-clock estimate below must be scaled accordingly —
# confirm against the training loop.
weighted = sum(
    n_steps * results_table[length] for length, n_steps in steps_per_seq_len.items()
) / scheduled_steps
_steps_label = f"{scheduled_steps / 1000:g}k steps:"
print(f"\n  Weighted avg:  {weighted:.3f}s/step")
print(f"  {_steps_label:<15}~{scheduled_steps * weighted / 3600:.1f}h")
print(f"  +10% overhead: ~{scheduled_steps * weighted * 1.1 / 3600:.1f}h")

# EMA smoke test.
# The check has to move the parameters away from the shadow, so it perturbs
# them with noise. Snapshot the real weights first and put them back afterwards;
# otherwise every later cell would train on a model (including the pretrained
# embeddings) with N(0, 0.01^2) noise baked in.
model.zero_grad()
_pre_ema_test_params = {n: p.data.clone() for n, p in model.named_parameters()}
try:
    model._init_ema()
    for p in model.parameters():
        if p.requires_grad:
            p.data.add_(torch.randn_like(p) * 0.01)
    model._update_ema()
    # After one update the shadow must lag behind the perturbed parameters.
    assert sum((p.data - model._ema_shadow[n]).abs().sum().item()
               for n, p in model.named_parameters() if n in model._ema_shadow) > 0
finally:
    for n, p in model.named_parameters():
        p.data.copy_(_pre_ema_test_params[n])
    del _pre_ema_test_params
    # Re-seed the shadow from the restored weights so it carries no test noise.
    model._init_ema()
print(f"\nEMA: ✓")

model.zero_grad()
gc.collect()
rss = psutil.Process().memory_info().rss / 1e6
print(f"Process RSS: {rss:.0f} MB")
print("=" * 60)
print("CELL 4 COMPLETE ✓")


Building FlashLM v3 (d=256, position subsampling)...
  Loaded pretrained embeddings from gpt2_embed_256.pt

Parameters:
  embedding (tied): 12,865,792
  block (shared x2): 689,921
  final_norm: 256
  lm_head: (tied)
  total_unique: 13,555,969
  Total: 13,555,969 (54.2 MB float32)
  Train position fraction: 0.25 (25%)

Train mode forward...
  Loss: 14.5717, CE: 14.5509
  Positions sampled: 64 / 256 = 25%
  ✓ Train forward OK

Eval mode forward...
  Logits shape: torch.Size([4, 64, 50257])
  Loss: 14.6872
  ✓ Eval forward OK
  ✓ Backward OK

SPEED TEST (with position subsampling)
  seq= 64: 0.229s (229ms) [64/256 positions]
  seq=128: 0.294s (294ms) [128/512 positions]
  seq=256: 0.530s (530ms) [256/1024 positions]
  seq=512: 0.999s (999ms) [512/2048 positions]

  Weighted avg:  0.570s/step
  10k steps:     ~1.6h
  +10% overhead: ~1.7h

EMA: ✓
Process RSS: 2549 MB
CELL 4 COMPLETE ✓


In [None]:
import torch
import psutil
import gc
import time

print("=" * 60)
print("SYNTHETIC OVERFIT TEST (with position subsampling)")
print("=" * 60)

# Snapshot everything this test mutates on the shared model, so it can be put
# back even if the cell is interrupted or raises half-way through.
saved_state = {k: v.clone() for k, v in model.state_dict().items()}
original_frac = model.train_pos_fraction

# Tiny random "corpus": N_SAMPLES sequences of SEQ_LEN + 1 tokens, split into
# inputs and next-token targets.
torch.manual_seed(CONFIG['seed'])
SEQ_LEN = 32
N_SAMPLES = 5
syn_data = torch.randint(0, CONFIG['vocab_size'], (N_SAMPLES, SEQ_LEN + 1))
syn_in = syn_data[:, :-1]
syn_tgt = syn_data[:, 1:]
del syn_data

N_STEPS = 100
losses = []
optimizer = None
try:
    # Temporarily set position fraction higher for overfit test
    # (with only 5*32=160 positions, 25% = 40 positions — too few to overfit)
    model.train_pos_fraction = 1.0  # use all positions for overfit test
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)

    print(f"Training {N_STEPS} steps ({N_SAMPLES} samples, seq={SEQ_LEN})...")
    print(f"{'Step':>6s} | {'Loss':>10s} | {'CE':>10s} | {'t/step':>8s}")
    print("-" * 45)

    start = time.time()
    for step in range(1, N_STEPS + 1):
        optimizer.zero_grad(set_to_none=True)
        result = model(syn_in, targets=syn_tgt, use_checkpointing=False)
        result['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        losses.append(result['loss'].item())
        if step % 10 == 0 or step == 1:
            print(f"{step:>6d} | {result['loss'].item():>10.4f} | {result['main_loss']:>10.4f} | "
                  f"{(time.time()-start)/step:>6.2f}s")
        if step % 20 == 0:
            gc.collect()

    total_t = time.time() - start
    final, initial = losses[-1], losses[0]
    print("-" * 45)
    print(f"  Initial: {initial:.4f}, Final: {final:.4f}, Min: {min(losses):.4f}")
    print(f"  Reduction: {(1-final/initial)*100:.1f}%")
    print(f"  Time: {total_t:.1f}s ({total_t/N_STEPS:.2f}s/step)")

    # Pass criterion: the loss must fall by at least 30% over the run.
    if final < initial * 0.7:
        print(f"  ✓ PASS: Model is learning ({(1-final/initial)*100:.0f}% reduction)")
    else:
        print(f"  ✗ FAIL: Insufficient learning")
finally:
    # Restore — always, so a failed or interrupted run cannot leave the shared
    # model overfit to random tokens or stuck at train_pos_fraction = 1.0.
    model.train_pos_fraction = original_frac
    model.load_state_dict(saved_state)
    del saved_state, optimizer, syn_in, syn_tgt
    gc.collect()

rss = psutil.Process().memory_info().rss / 1e6
print(f"\nProcess RSS: {rss:.0f} MB")
print("=" * 60)
print("CELL 5 COMPLETE ✓")


SYNTHETIC OVERFIT TEST (with position subsampling)
Training 100 steps (5 samples, seq=32)...
  Step |       Loss |         CE |   t/step
---------------------------------------------
     1 |    14.4979 |    14.4752 |   0.41s
    10 |     1.8551 |     1.8299 |   0.40s
    20 |     1.7439 |     1.7249 |   0.39s
    30 |     1.6578 |     1.6442 |   0.39s
    40 |     1.5943 |     1.5859 |   0.39s
    50 |     1.5420 |     1.5365 |   0.39s
    60 |     1.5102 |     1.5069 |   0.39s
    70 |     1.4939 |     1.4919 |   0.40s
    80 |     1.4847 |     1.4826 |   0.40s
    90 |     1.4682 |     1.4670 |   0.40s
   100 |     1.4629 |     1.4619 |   0.40s
---------------------------------------------
  Initial: 14.4979, Final: 1.4629, Min: 1.4629
  Reduction: 89.9%
  Time: 39.7s (0.40s/step)
  ✓ PASS: Model is learning (90% reduction)

Process RSS: 2665 MB
CELL 5 COMPLETE ✓


In [None]:
import torch
import numpy as np
import psutil
import gc
import os
import time
import subprocess, sys

print("=" * 60)
print("DATA LOADING")
print("=" * 60)

# Raw uint16 token stream (despite the .npy suffix it is written and read with
# np.memmap, not np.save / np.load) plus a small metadata archive.
MEMMAP_PATH = "fineweb_tokens.npy"
META_PATH = "fineweb_meta.npz"

if os.path.exists(MEMMAP_PATH) and os.path.exists(META_PATH):
    # Cached run: both files exist, so just map the token file read-only.
    print(f"Found existing memmap: {MEMMAP_PATH}")
    meta = np.load(META_PATH)
    total_tokens = int(meta['total_tokens'])
    n_docs = int(meta['n_docs'])
    print(f"  Documents: {n_docs:,}")
    print(f"  Total tokens: {total_tokens:,}")
    tokens_mmap = np.memmap(MEMMAP_PATH, dtype=np.uint16, mode='r')
    print(f"  Memmap shape: {tokens_mmap.shape}")
else:
    # Install compatible datasets + fix huggingface-hub version
    print("Installing datasets with compatible dependencies...")
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "datasets>=2.19,<3.0",
        "huggingface-hub>=0.23,<1.0",
        "fsspec>=2024.2.0,<2025.0.0",
        "--quiet"
    ])

    # Drop already-imported copies of the packages that were just (re)installed
    # so the imports below pick up the new versions, and refresh the import
    # system's directory caches so freshly installed modules are found.
    import importlib
    for mod_name in list(sys.modules.keys()):
        if 'huggingface_hub' in mod_name or 'datasets' in mod_name or 'fsspec' in mod_name:
            del sys.modules[mod_name]
    importlib.invalidate_caches()

    import huggingface_hub
    print(f"  huggingface_hub: {huggingface_hub.__version__}")

    import datasets
    print(f"  datasets: {datasets.__version__}")

    from datasets import load_dataset
    from transformers import GPT2TokenizerFast
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    # Token ids are stored as uint16, so the vocabulary must fit in 16 bits.
    assert CONFIG['vocab_size'] <= np.iinfo(np.uint16).max + 1

    print(f"\n  Streaming {CONFIG['n_docs']:,} documents from {CONFIG['dataset_name']}...")
    t0 = time.time()

    ds = load_dataset(
        CONFIG['dataset_name'],
        CONFIG['dataset_subset'],
        split="train",
        streaming=True
    )

    # One compact uint16 array per document (2 bytes per token). Collecting the
    # ids in a single Python list of ints instead costs tens of bytes per token,
    # which is a large share of the container's memory for millions of tokens.
    # NOTE(review): documents are concatenated without a separator token —
    # confirm that is intended.
    token_chunks = []
    n_docs = 0
    total_tokens = 0
    total_chars = 0

    for doc in ds:
        if n_docs >= CONFIG['n_docs']:
            break

        text = doc.get('text', '')
        if len(text) < 50:
            continue  # skip very short documents

        doc_tokens = np.asarray(tokenizer.encode(text), dtype=np.uint16)
        token_chunks.append(doc_tokens)
        total_tokens += int(doc_tokens.size)
        n_docs += 1
        total_chars += len(text)

        if n_docs % 5000 == 0:
            elapsed = time.time() - t0
            rss = psutil.Process().memory_info().rss / 1e6
            print(f"    {n_docs:,} docs | {total_tokens:,} tokens | "
                  f"{elapsed:.0f}s | RSS: {rss:.0f} MB")

    elapsed = time.time() - t0
    if total_tokens == 0:
        # Nothing usable was streamed; fail here with a clear message rather
        # than with a division by zero or an empty-memmap error below.
        raise RuntimeError(
            f"No documents with at least 50 characters were read from {CONFIG['dataset_name']}"
        )
    print(f"\n  Downloaded {n_docs:,} docs in {elapsed:.0f}s")
    print(f"  Total tokens: {total_tokens:,}")
    print(f"  Avg tokens/doc: {total_tokens/n_docs:.0f}")

    # Save as uint16 memmap
    print(f"\n  Saving memmap to {MEMMAP_PATH}...")
    tokens_array = np.concatenate(token_chunks)
    del token_chunks
    tokens_mmap = np.memmap(MEMMAP_PATH, dtype=np.uint16, mode='w+', shape=tokens_array.shape)
    tokens_mmap[:] = tokens_array[:]
    tokens_mmap.flush()
    # The metadata file is written last: the cache check above requires both
    # files, so an interrupted write leads to a rebuild on the next run.
    np.savez(META_PATH, total_tokens=total_tokens, n_docs=n_docs)

    file_size_mb = os.path.getsize(MEMMAP_PATH) / 1e6
    print(f"  Memmap size: {file_size_mb:.1f} MB")

    del tokens_array
    gc.collect()

# --- Verify ---
# Re-create the tokenizer here because the cached-memmap branch above does not load one.
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# The memmap length is authoritative; it overrides the count read from the metadata file.
total_tokens = len(tokens_mmap)
assert tokens_mmap.dtype == np.uint16
assert total_tokens > 100000, f"Too few tokens: {total_tokens}"
# Every stored id must be a valid row of the embedding table (scans the whole file once).
assert tokens_mmap.max() < CONFIG['vocab_size']

# Decode the first 50 tokens as a human-readable spot check.
sample_text = tokenizer.decode(tokens_mmap[:50].tolist())
print(f"\n  First 50 tokens: '{sample_text[:100]}...'")

# --- Train/Val split ---
# Contiguous split by token position: the first `train_split` fraction is
# training data, the remainder is validation. The boundary is a token index,
# so it can fall in the middle of a document.
train_size = int(total_tokens * CONFIG['train_split'])
val_size = total_tokens - train_size
print(f"  Train: {train_size:,} tokens ({CONFIG['train_split']*100:.0f}%)")
print(f"  Val:   {val_size:,} tokens ({(1-CONFIG['train_split'])*100:.0f}%)")

# --- Batch sampler ---
class TokenDataset:
    """Draws random fixed-length (input, target) windows from a token array.

    Windows start inside ``[start_idx, end_idx - seq_len - 1)`` so no read
    ever reaches ``end_idx``. Each target row is the matching input row
    shifted one token to the right.
    """
    def __init__(self, tokens_mmap, start_idx: int, end_idx: int, seq_len: int):
        self.tokens = tokens_mmap
        self.start, self.end = start_idx, end_idx
        self.seq_len = seq_len

    def get_batch(self, batch_size: int) -> tuple:
        """Return ``(input_ids, targets)``, two int64 tensors of shape
        ``(batch_size, seq_len)`` sampled uniformly at random via torch's RNG.
        """
        span = self.seq_len
        upper = self.end - span - 1
        offsets = torch.randint(self.start, upper, (batch_size,))

        def _window(lo):
            # Slice `span` tokens starting at `lo` and widen to int64.
            return torch.from_numpy(self.tokens[lo:lo + span].astype(np.int64))

        inputs = [_window(off) for off in offsets]
        shifted = [_window(off + 1) for off in offsets]
        return torch.stack(inputs), torch.stack(shifted)

# Train and validation samplers share one token array and cover disjoint,
# contiguous index ranges. seq_len is fixed at 64 for both.
train_dataset = TokenDataset(tokens_mmap, 0, train_size, seq_len=64)
val_dataset = TokenDataset(tokens_mmap, train_size, total_tokens, seq_len=64)

# Batch test
# Draw one small batch and check the shapes and the one-token shift between
# inputs and targets.
_inp, _tgt = train_dataset.get_batch(4)
assert _inp.shape == (4, 64)
assert _tgt.shape == (4, 64)
assert (_inp[:, 1:] == _tgt[:, :-1]).all(), "Shift incorrect!"
print(f"\n  Batch test: ✓ (shape={_inp.shape}, shift verified)")

# Report resident memory of this process in MB.
rss = psutil.Process().memory_info().rss / 1e6
print(f"Process RSS: {rss:.0f} MB")
print("=" * 60)
print("CELL 6 COMPLETE ✓")


DATA LOADING
Installing datasets with compatible dependencies...
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
deepnote-toolkit 2.1.2 requires pyarrow<=17.0.0,>=13.0.0; python_version == "3.11" and sys_platform != "darwin", but you have pyarrow 23.0.1 which is incompatible.[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
  huggingface_hub: 0.36.2
  datasets: 2.21.0

  Streaming 30,000 documents from HuggingFaceFW/fineweb-edu...
Downloading readme: 100%|██████████| 26.4k/26.4k [00:00<00:00, 220kB/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (1055 > 1024). Running this sequence through the

In [None]:
import torch
import torch.nn as nn
import math
import psutil

print("=" * 60)
print("OPTIMIZERS & SCHEDULER")
print("=" * 60)

# --- Debug: print all CONFIG keys so we can find the right names ---
# Scalars are printed verbatim; lists and dicts are summarised by type and length.
print("\nCONFIG keys:")
for k, v in sorted(CONFIG.items()):
    if not isinstance(v, (list, dict)):
        print(f"  {k} = {v}")
    else:
        print(f"  {k} = {type(v).__name__}({len(v)} items)")

# --- Safely resolve beta2 (check multiple possible key names) ---
# BETA2 feeds the NorMuon optimizer, so 'normuon_beta2' is tried first. It was
# previously missing from this list, which made the lookup always fall through
# to the hard-coded default. A (beta1, beta2) pair contributes its second
# element.
BETA2 = None
for key in ['normuon_beta2', 'beta2', 'adamw_beta2', 'optimizer_beta2', 'betas']:
    if key in CONFIG:
        val = CONFIG[key]
        if isinstance(val, (list, tuple)):
            BETA2 = val[1]
        else:
            BETA2 = val
        print(f"\n  Found beta2 via CONFIG['{key}'] = {BETA2}")
        break
if BETA2 is None:
    BETA2 = 0.999
    print(f"\n  beta2 not found in CONFIG — using default {BETA2}")

# ------------------------------------------------------------------
# NorMuon Optimizer
# ------------------------------------------------------------------
class NorMuon(torch.optim.Optimizer):
    """Momentum optimizer for 2-D parameters with an orthogonalised update.

    Per parameter, each step:
      1. folds the gradient into an exponential moving average (``momentum``);
      2. runs ``ns_steps`` Newton-Schulz iterations on that average;
      3. divides each row by the square root of a running mean of that row's
         squared entries (``beta2``);
      4. rescales the result so its norm is ``0.2 * sqrt(rows * cols)`` and
         applies it with step size ``lr``.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, beta2=0.999,
                 ns_steps=5, eps=1e-8):
        super().__init__(params, {
            'lr': lr,
            'momentum': momentum,
            'beta2': beta2,
            'ns_steps': ns_steps,
            'eps': eps,
        })

    @staticmethod
    def newton_schulz5(G, steps=5):
        """Run ``steps`` quintic Newton-Schulz iterations on the 2-D tensor ``G``.

        The input is cast to float32 and divided by its norm first. Tall
        matrices are transposed for the iteration and transposed back, so
        the result always has the shape of ``G``.
        """
        assert G.ndim == 2
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.float()
        X = X / (X.norm() + 1e-7)
        tall = X.size(0) > X.size(1)
        if tall:
            X = X.T
        for _ in range(steps):
            gram = X @ X.T
            poly = b * gram + c * gram @ gram
            X = a * X + poly @ X
        return X.T if tall else X

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        ``closure``, when given, is re-evaluated with gradients enabled and
        its result is returned; otherwise the method returns ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            step_size = group['lr']
            mom, second = group['momentum'], group['beta2']
            iters, eps = group['ns_steps'], group['eps']
            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue
                state = self.state[p]
                if not state:
                    # Lazy state creation on the first step for this parameter.
                    state['step'] = 0
                    state['momentum_buffer'] = torch.zeros_like(grad)
                    state['v'] = torch.zeros(grad.size(0), device=grad.device,
                                             dtype=grad.dtype)
                state['step'] += 1
                buf, row_var = state['momentum_buffer'], state['v']
                buf.mul_(mom).add_(grad, alpha=1.0 - mom)
                ortho = self.newton_schulz5(buf, steps=iters)
                # Running per-row mean of squared entries of the orthogonalised update.
                row_var.mul_(second).add_((ortho * ortho).mean(dim=1),
                                          alpha=1.0 - second)
                normed = ortho / (row_var.unsqueeze(1).expand_as(ortho).sqrt() + eps)
                rows, cols = normed.shape
                scale = 0.2 * math.sqrt(rows * cols) / (normed.norm() + eps)
                p.add_(normed, alpha=-step_size * scale)
        return loss

# ------------------------------------------------------------------
# Build optimizers
# ------------------------------------------------------------------
# Parameter routing:
#   * 2-D tensors whose name contains neither 'embedding' nor 'lm_head' -> NorMuon
#   * 1-D tensors and anything with 'bias' in its name -> AdamW, no weight decay
#   * everything else -> AdamW with weight decay
normuon_params = []
adamw_decay_params = []
adamw_nodecay_params = []

for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if p.ndim == 2 and 'embedding' not in name and 'lm_head' not in name:
        normuon_params.append(p)
    elif p.ndim == 1 or 'bias' in name:
        adamw_nodecay_params.append(p)
    else:
        adamw_decay_params.append(p)

print(f"\nBuilding optimizers...")
print(f"  NorMuon params: {sum(p.numel() for p in normuon_params):,} "
      f"({len(normuon_params)} tensors)")
print(f"  AdamW params (decay): {sum(p.numel() for p in adamw_decay_params):,} "
      f"({len(adamw_decay_params)} tensors)")
print(f"  AdamW params (no decay): {sum(p.numel() for p in adamw_nodecay_params):,} "
      f"({len(adamw_nodecay_params)} tensors)")

normuon_opt = NorMuon(
    normuon_params,
    lr=CONFIG['normuon_lr'],
    momentum=CONFIG['normuon_momentum'],
    beta2=BETA2,
)

# AdamW takes its betas from CONFIG['adamw_betas'] when that key exists.
# The previous (0.9, BETA2) reused the beta2 resolved for NorMuon and silently
# ignored the configured AdamW pair; it is kept only as the fallback.
adamw_opt = torch.optim.AdamW([
    {'params': adamw_decay_params, 'lr': CONFIG['adamw_lr'],
     'weight_decay': CONFIG['weight_decay']},
    {'params': adamw_nodecay_params, 'lr': CONFIG['adamw_lr'],
     'weight_decay': 0.0},
], betas=CONFIG.get('adamw_betas', (0.9, BETA2)))

# ------------------------------------------------------------------
# WSD Schedule
# ------------------------------------------------------------------
def get_lr_scale(step):
    """Warmup-stable-decay multiplier for the base learning rates.

    Rises linearly from 0 to 1 over ``CONFIG['warmup_steps']``, holds at 1
    until ``CONFIG['decay_start_step']``, then falls linearly to 0 at
    ``CONFIG['total_steps']`` and is clamped at 0 afterwards.
    """
    warmup_steps = CONFIG['warmup_steps']
    decay_begin = CONFIG['decay_start_step']
    n_total = CONFIG['total_steps']
    if step < warmup_steps:
        return step / warmup_steps
    if step < decay_begin:
        return 1.0
    # max(1, ...) guards against a zero-length decay window.
    decayed = (step - decay_begin) / max(1, n_total - decay_begin)
    return max(0.0, 1.0 - decayed)

# Preview the schedule at a few sample steps.
# NOTE(review): the sample steps are hard-coded and look chosen for a
# 10000-step run with decay starting at 9000 — adjust if CONFIG changes.
print(f"\nWSD Schedule (total={CONFIG['total_steps']}, "
      f"warmup={CONFIG['warmup_steps']}, decay_start={CONFIG['decay_start_step']}):")
for s in [0, 100, 200, 1000, 5000, 8999, 9000, 9500, 9999]:
    sc = get_lr_scale(s)
    print(f"  step {s:5d}: normuon_lr={CONFIG['normuon_lr']*sc:.6f}, "
          f"adamw_lr={CONFIG['adamw_lr']*sc:.6f}")

# ------------------------------------------------------------------
# Quick test
# ------------------------------------------------------------------
# One full optimisation step on random token ids: forward, backward, gradient
# clipping, both optimizers, then a second forward pass to print the new loss.
# This updates the weights of `model` in place.
print(f"\nTest optimizer step...")
x_test = torch.randint(0, CONFIG['vocab_size'], (2, 32))
y_test = torch.randint(0, CONFIG['vocab_size'], (2, 32))
model.train()
out1 = model(x_test, targets=y_test)
loss1 = out1['loss']
print(f"  Loss before step: {loss1.item():.4f}")
loss1.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
normuon_opt.step()
adamw_opt.step()
normuon_opt.zero_grad(set_to_none=True)
adamw_opt.zero_grad(set_to_none=True)
out2 = model(x_test, targets=y_test)
print(f"  Loss after step:  {out2['loss'].item():.4f}")
print(f"  ✓ Optimizer step works")

print(f"\nProcess RSS: {psutil.Process().memory_info().rss / 1e6:.0f} MB")
print("=" * 60)
print("CELL 7 COMPLETE ✓")


OPTIMIZERS & SCHEDULER

CONFIG keys:
  USE_ROUTER = True
  adamw_betas = (0.9, 0.95)
  adamw_lr = 0.0003
  batch_size = 4
  checkpoint_dir = ./checkpoints
  d_model = 256
  dataset_name = HuggingFaceFW/fineweb-edu
  dataset_subset = sample-10BT
  dbo_blend = 0.7
  dbo_eval_interval = 25
  dbo_max_steps = 300
  dbo_patience = 50
  decay_start_step = 9000
  deep_supervision_steps = list(1 items)
  device = cpu
  dropout = 0.0
  ema_decay = 0.999
  ewa_checkpoints = list(3 items)
  glu_expansion = 2.67
  glu_inner_dim = 683
  label_smoothing = 0.1
  log_interval = 50
  max_grad_norm = 1.0
  memory_log_interval = 100
  n_docs = 30000
  n_recursions = 2
  normuon_beta2 = 0.999
  normuon_lr = 0.02
  normuon_momentum = 0.95
  num_threads = 2
  phase_schedule = dict(4 items)
  reasoning_shift_step = 7000
  rho1_refresh_interval = 200
  rho1_start_step = 1000
  rho1_top_fraction = 0.4
  router_threshold = 0.5
  seed = 42
  total_steps = 10000
  train_split = 0.95
  vocab_size = 50257
  warmup_s

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
import time, gc, psutil, os, copy, math, inspect
import numpy as np
import sys

print("=" * 60)
print("STAGE 1: PRE-TRAINING")
print("=" * 60)

# ------------------------------------------------------------------
# CRITICAL FIX: Kill ALL gradient checkpointing BEFORE model creation
# ------------------------------------------------------------------
import torch.utils.checkpoint

def _passthrough(fn, *args, use_reentrant=None, **kwargs):
    return fn(*args, **kwargs)

# Replace checkpoint() on the module object so attribute-style lookups
# (`torch.utils.checkpoint.checkpoint(...)`) hit the passthrough.
# NOTE(review): names already bound via `from torch.utils.checkpoint import
# checkpoint` before this cell ran would not be affected — confirm none exist.
torch.utils.checkpoint.checkpoint = _passthrough
# checkpoint_sequential replacement: applies each function in order to the
# input and returns the last result, or the input itself when `functions` is
# empty. `segments` and extra keyword arguments are ignored.
torch.utils.checkpoint.checkpoint_sequential = lambda functions, segments, input, **kw: \
    (lambda x: [x := f(x) for f in functions][-1] if functions else x)(input)
# Same assignment again through sys.modules.
if 'torch.utils.checkpoint' in sys.modules:
    sys.modules['torch.utils.checkpoint'].checkpoint = _passthrough
print(f"  checkpoint function is now: {torch.utils.checkpoint.checkpoint.__name__}")

# ------------------------------------------------------------------
# SPEED FIX: Reduce grad_accum to fit in 2.5h
# ------------------------------------------------------------------
# Override gradient-accumulation steps per phase. This mutates the shared
# CONFIG dict in place, so the change persists for later cells.
CONFIG['phase_schedule']['1a']['grad_accum'] = 4
CONFIG['phase_schedule']['1b']['grad_accum'] = 4
CONFIG['phase_schedule']['1c']['grad_accum'] = 8
CONFIG['phase_schedule']['1d']['grad_accum'] = 8

# ------------------------------------------------------------------
# 0. Rebuild model fresh (AFTER patching)
# ------------------------------------------------------------------
# Rebinds the global `model`; optimizers built in earlier cells still
# reference the previous model's parameters and are rebuilt further below.
model = FlashLM(CONFIG)

# Switch off any checkpointing flag found under one of these candidate names,
# on the model itself and on every submodule.
for attr in ['use_checkpointing', '_use_checkpointing', 'gradient_checkpointing',
             'checkpointing', 'use_checkpoint']:
    if hasattr(model, attr):
        setattr(model, attr, False)
        print(f"  Disabled model.{attr}")
    for name, mod in model.named_modules():
        if hasattr(mod, attr):
            setattr(mod, attr, False)

model.train()
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Process RSS: {psutil.Process().memory_info().rss / 1e6:.0f} MB")

# Show which keyword arguments forward() accepts.
_fwd_params = inspect.signature(model.forward).parameters
print(f"FlashLM.forward() accepts: {list(_fwd_params.keys())}")

# ------------------------------------------------------------------
# Quick forward+backward test
# ------------------------------------------------------------------
def to_scalar(x):
    """Convert ``x`` to a plain Python number.

    Objects exposing ``.item()`` (tensors, NumPy scalars) are unwrapped with
    it; anything else goes through ``float()``.
    """
    if hasattr(x, 'item'):
        return x.item()
    return float(x)

# Run one forward and backward pass on random token ids to confirm the
# rebuilt model works with use_checkpointing=False. Gradients are cleared
# afterwards and no optimizer step is taken.
print("\nQuick forward+backward test...")
_test_x = torch.randint(0, CONFIG['vocab_size'], (2, 32))
_test_y = torch.randint(0, CONFIG['vocab_size'], (2, 32))
_test_out = model(_test_x, targets=_test_y, use_checkpointing=False)
_test_out['loss'].backward()
model.zero_grad(set_to_none=True)
print(f"  PASSED — loss={to_scalar(_test_out['loss']):.4f}")
# Drop the temporaries so they do not linger in the kernel.
del _test_x, _test_y, _test_out
gc.collect()

# ------------------------------------------------------------------
# 1. Build optimizers
# ------------------------------------------------------------------
# Parameter routing for the freshly built model:
#   * 2-D tensors whose name contains neither 'embedding' nor 'lm_head' -> NorMuon
#   * 1-D tensors and anything with 'bias' in its name -> AdamW, no weight decay
#   * everything else -> AdamW with weight decay
normuon_params = []
adamw_decay_params = []
adamw_nodecay_params = []

for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if p.ndim == 2 and 'embedding' not in name and 'lm_head' not in name:
        normuon_params.append(p)
    elif p.ndim == 1 or 'bias' in name:
        adamw_nodecay_params.append(p)
    else:
        adamw_decay_params.append(p)

# beta2 falls back to 0.999 when CONFIG has no 'normuon_beta2' entry.
normuon_opt = NorMuon(
    normuon_params,
    lr=CONFIG['normuon_lr'],
    momentum=CONFIG['normuon_momentum'],
    beta2=CONFIG.get('normuon_beta2', 0.999),
)

# Two AdamW groups share the base learning rate and differ only in weight decay.
adamw_opt = torch.optim.AdamW([
    {'params': adamw_decay_params, 'lr': CONFIG['adamw_lr'],
     'weight_decay': CONFIG['weight_decay']},
    {'params': adamw_nodecay_params, 'lr': CONFIG['adamw_lr'],
     'weight_decay': 0.0},
], betas=CONFIG['adamw_betas'])

print(f"  NorMuon params: {sum(p.numel() for p in normuon_params):,}")
print(f"  AdamW decay params: {sum(p.numel() for p in adamw_decay_params):,}")
print(f"  AdamW no-decay params: {sum(p.numel() for p in adamw_nodecay_params):,}")

# ------------------------------------------------------------------
# 2. Learning-rate schedule (WSD)
# ------------------------------------------------------------------
def get_lr_scale(step):
    """Return the learning-rate multiplier for ``step`` (warmup, stable, decay).

    Linear ramp ``step / warmup_steps`` during warmup, ``1.0`` until
    ``decay_start_step``, then a linear fall to ``0.0`` at ``total_steps``;
    never negative. All three boundaries are read from ``CONFIG``.
    """
    ramp_len = CONFIG['warmup_steps']
    plateau_end = CONFIG['decay_start_step']
    horizon = CONFIG['total_steps']
    if step < ramp_len:
        return step / ramp_len
    if step < plateau_end:
        return 1.0
    # A zero-length decay window is treated as length 1 to avoid dividing by zero.
    window = max(1, horizon - plateau_end)
    return max(0.0, 1.0 - (step - plateau_end) / window)

def set_lr(step):
    """Write the scheduled learning rate for ``step`` into both optimizers.

    Every param group of ``normuon_opt`` receives
    ``CONFIG['normuon_lr'] * scale`` and every group of ``adamw_opt``
    receives ``CONFIG['adamw_lr'] * scale``.
    """
    factor = get_lr_scale(step)
    for optimizer, base_key in ((normuon_opt, 'normuon_lr'),
                                (adamw_opt, 'adamw_lr')):
        for group in optimizer.param_groups:
            group['lr'] = CONFIG[base_key] * factor

# ------------------------------------------------------------------
# 3. Phase schedule (speed-optimized grad_accum)
# ------------------------------------------------------------------
# Phase dicts ordered by their 'start' step; the dict keys of
# CONFIG['phase_schedule'] are discarded.
_phases_sorted = sorted(CONFIG['phase_schedule'].values(), key=lambda p: p['start'])

def get_phase(step):
    """Return the phase dict active at ``step``.

    That is the last entry of ``_phases_sorted`` whose ``'start'`` is
    ``<= step``. Steps before the first phase's start fall back to the first
    phase.
    """
    chosen = _phases_sorted[0]
    idx = 0
    while idx < len(_phases_sorted) and step >= _phases_sorted[idx]['start']:
        chosen = _phases_sorted[idx]
        idx += 1
    return chosen

# One summary line per phase; 'freeze_embed' is optional and shown as False
# when a phase does not define it.
print("\nPhase schedule (speed-optimized):")
for p in _phases_sorted:
    print(f"  steps {p['start']}-{p['end']}: seq={p['seq_len']}, "
          f"grad_accum={p['grad_accum']}, freeze={p.get('freeze_embed', False)}")

# ------------------------------------------------------------------
# 4. EMA
# ------------------------------------------------------------------
# EMA shadow weights start as an independent copy (clone) of every
# state_dict entry of the current model.
ema_state = {k: v.clone() for k, v in model.state_dict().items()}

def update_ema():
    """Blend the model's current state into ``ema_state`` in place.

    Floating-point and complex entries are updated as
    ``ema = decay * ema + (1 - decay) * current`` with
    ``decay = CONFIG['ema_decay']``. Integer and boolean entries (counters
    such as ``num_batches_tracked``, masks) cannot be averaged: multiplying
    them in place by a float raises ``RuntimeError``. They are overwritten
    with the current value instead.
    """
    decay = CONFIG['ema_decay']
    with torch.no_grad():
        for key, value in model.state_dict().items():
            avg = ema_state[key]
            if avg.is_floating_point() or avg.is_complex():
                avg.mul_(decay).add_(value, alpha=1 - decay)
            else:
                # Non-averagable dtype: mirror the live value.
                avg.copy_(value)

def apply_ema():
    """Load ``ema_state`` into the model and return a copy of the old weights.

    The returned dict holds clones of every state_dict entry taken before
    the swap, suitable for passing to ``restore_from_backup``.
    """
    snapshot = {}
    for key, tensor in model.state_dict().items():
        snapshot[key] = tensor.clone()
    model.load_state_dict(ema_state)
    return snapshot

def restore_from_backup(backup):
    """Load ``backup`` (a state dict, e.g. from ``apply_ema``) back into the model."""
    model.load_state_dict(state_dict=backup)

# ------------------------------------------------------------------
# 5. Dataset access
# ------------------------------------------------------------------
# Re-open the token file read-only as a flat uint16 array (raw bytes, no
# header is parsed despite the .npy suffix).
# NOTE(review): the path is hard-coded here, while the data-loading cell
# writes to MEMMAP_PATH — confirm both refer to the same file.
tokens_mmap = np.memmap('fineweb_tokens.npy', dtype=np.uint16, mode='r')
total_tokens = len(tokens_mmap)
# Tokens before train_end are training data; the rest is validation data.
train_end = int(total_tokens * CONFIG['train_split'])

def get_batch(split, seq_len, batch_size):
    """Sample ``batch_size`` random windows of ``seq_len`` tokens.

    ``split == 'train'`` draws from ``[0, train_end)``; any other value draws
    from ``[train_end, total_tokens)``. Returns ``(x, y)`` as int64 tensors
    of shape ``(batch_size, seq_len)``, where ``y`` is ``x`` shifted one
    token ahead. Window starts come from NumPy's global RNG.
    """
    lo, hi = (0, train_end) if split == 'train' else (train_end, total_tokens)
    offsets = np.random.randint(lo, hi - seq_len - 1, size=batch_size)
    inputs, labels = [], []
    for off in offsets:
        inputs.append(tokens_mmap[off:off + seq_len])
        labels.append(tokens_mmap[off + 1:off + seq_len + 1])
    x = torch.from_numpy(np.stack(inputs).astype(np.int64))
    y = torch.from_numpy(np.stack(labels).astype(np.int64))
    return x, y

# ------------------------------------------------------------------
# 6. Forward kwargs builder
# ------------------------------------------------------------------
# Parameters accepted by model.forward, captured once so the kwargs builder
# below can test for optional arguments by name.
_fwd_sig = inspect.signature(model.forward).parameters

def _build_fwd_kwargs(targets):
    kwargs = {}
    if 'targets' in _fwd_sig:
        kwargs['targets'] = targets
    if 'use_checkpointing' in _fwd_sig:
        kwargs['use_checkpointing'] = False
    return kwargs

# ------------------------------------------------------------------
# 7. Evaluation
# ------------------------------------------------------------------
@torch.no_grad()
def evaluate(seq_len=256, n_batches=5, batch_size=4):
    model.eval()
    old_frac = getattr(model, 'train_pos_fraction', None)
    if old_frac is not None:
        model.train_pos_fraction = 1.0
    losses = []
    for _ in range(n_batches):
        x, y = get_batch('val', seq_len, batch_size)
        out = model(x, **_build_fwd_kwargs(y))
        losses.append(to_scalar(out['main_loss']))
    model.train()
    if old_frac is not None:
        model.train_pos_fraction = old_frac
    return sum(losses) / len(losses)

# ------------------------------------------------------------------
# 8. Checkpoint saving
# ------------------------------------------------------------------
# Create the checkpoint directory up front; no error if it already exists.
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)

def save_checkpoint(step, train_loss, val_loss):
    """Write a full training checkpoint for ``step`` and print its size.

    The file ``flashlm_step{step}.pt`` is created under
    ``CONFIG['checkpoint_dir']`` and contains the model weights, both
    optimizer states, the EMA weights, the given losses and ``CONFIG``.
    """
    target = os.path.join(CONFIG['checkpoint_dir'], f'flashlm_step{step}.pt')
    payload = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'normuon_state_dict': normuon_opt.state_dict(),
        'adamw_state_dict': adamw_opt.state_dict(),
        'ema_state': ema_state,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'config': CONFIG,
    }
    torch.save(payload, target)
    megabytes = os.path.getsize(target) / 1e6
    print(f"  Checkpoint saved: {target} ({megabytes:.1f} MB)")

# ------------------------------------------------------------------
# 9. Training loop
# ------------------------------------------------------------------
# Run-level bookkeeping: best validation loss so far, per-step training
# losses, (step, val_loss) pairs, and per-step wall-clock durations.
best_val_loss = float('inf')
train_losses = []
val_losses = []
step_times = []
total_start = time.time()

print(f"\nStarting training: {CONFIG['total_steps']} steps")
print(f"  Phases: {len(CONFIG['phase_schedule'])}")
print(f"  Position subsampling: {getattr(model, 'train_pos_fraction', 1.0)}")
print(f"  Gradient checkpointing: DISABLED (global patch)")
print()

# One iteration = one optimizer update (grad_accum micro-batches).
# Steps are 1-based and run through CONFIG['total_steps'] inclusive.
for step in range(1, CONFIG['total_steps'] + 1):
    step_start = time.time()

    # The active phase decides sequence length, accumulation count and
    # whether the embedding is frozen for this step.
    phase = get_phase(step)
    seq_len = phase['seq_len']
    grad_accum = phase['grad_accum']
    freeze_embed = phase.get('freeze_embed', False)

    # Freezing is done by toggling requires_grad on the embedding weight;
    # models without an `embedding` attribute are left untouched.
    if hasattr(model, 'embedding'):
        model.embedding.weight.requires_grad = not freeze_embed

    set_lr(step)

    normuon_opt.zero_grad(set_to_none=True)
    adamw_opt.zero_grad(set_to_none=True)
    accum_loss = 0.0
    accum_main = 0.0

    # Gradient accumulation: each micro-batch loss is divided by grad_accum
    # before backward, so the summed gradient is the mean over micro-batches.
    # accum_loss / accum_main hold the matching means for logging.
    for micro in range(grad_accum):
        x, y = get_batch('train', seq_len, CONFIG['batch_size'])
        out = model(x, **_build_fwd_kwargs(y))
        loss = out['loss'] / grad_accum
        loss.backward()
        accum_loss += to_scalar(out['loss']) / grad_accum
        accum_main += to_scalar(out['main_loss']) / grad_accum

    # Clip the global gradient norm across all parameters, step both
    # optimizers, then fold the new weights into the EMA copy.
    torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
    normuon_opt.step()
    adamw_opt.step()
    update_ema()

    step_time = time.time() - step_start
    train_losses.append(accum_loss)
    step_times.append(step_time)

    # Periodic progress line; ETA extrapolates the average time per step so far.
    if step % CONFIG['log_interval'] == 0 or step == 1:
        lr_scale = get_lr_scale(step)
        elapsed = time.time() - total_start
        eta = (elapsed / step) * (CONFIG['total_steps'] - step)
        print(f"step {step:5d} | loss {accum_loss:.4f} | CE {accum_main:.4f} | "
              f"lr_s {lr_scale:.4f} | seq {seq_len:3d} | "
              f"ga {grad_accum:2d} | {step_time:.2f}s/step | "
              f"elapsed {elapsed/3600:.2f}h | ETA {eta/3600:.2f}h")

    # Validation at step 1 and every 500 steps, with the evaluation sequence
    # length capped at 256.
    if step % 500 == 0 or step == 1:
        val_loss = evaluate(seq_len=min(seq_len, 256), n_batches=5)
        val_losses.append((step, val_loss))
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
        print(f"  >>> VAL loss: {val_loss:.4f} {'(BEST)' if is_best else ''}")

    # Steps listed in CONFIG['ewa_checkpoints'] get a longer evaluation and a
    # full checkpoint on disk.
    if step in CONFIG.get('ewa_checkpoints', []):
        val_loss = evaluate(seq_len=256, n_batches=10)
        save_checkpoint(step, accum_loss, val_loss)

    if step % 200 == 0:
        gc.collect()

    # Memory watchdog: warn when resident memory exceeds 4500 MB.
    if step % CONFIG.get('memory_log_interval', 100) == 0:
        rss = psutil.Process().memory_info().rss / 1e6
        if rss > 4500:
            print(f"  ⚠ RSS high: {rss:.0f} MB")

# ------------------------------------------------------------------
# 10. Final evaluation & save
# ------------------------------------------------------------------
# Summary of the finished run.
total_time = time.time() - total_start
print("\n" + "=" * 60)
print("TRAINING COMPLETE")
print("=" * 60)
print(f"Total time: {total_time/3600:.2f} h ({total_time:.0f} s)")
print(f"Average step time: {sum(step_times)/len(step_times):.3f} s")
print(f"Final train loss: {train_losses[-1]:.4f}")
print(f"Best val loss: {best_val_loss:.4f}")

# Evaluate the raw (non-EMA) weights first.
final_val = evaluate(seq_len=256, n_batches=20)
print(f"Final val loss (20 batches): {final_val:.4f}")

# Swap in the EMA weights and evaluate again, then keep whichever set scored
# lower. The two evaluations draw different random batches, so the
# comparison is noisy.
backup = apply_ema()
ema_val = evaluate(seq_len=256, n_batches=20)
print(f"EMA val loss: {ema_val:.4f}")
if ema_val < final_val:
    print("  EMA is better — keeping EMA weights")
else:
    restore_from_backup(backup)
    print("  Base weights are better — restoring")

# Save the selected weights together with the EMA state and run metadata.
# No optimizer state is included in this file.
final_path = os.path.join(CONFIG['checkpoint_dir'], 'flashlm_final.pt')
torch.save({
    'model_state_dict': model.state_dict(),
    'ema_state': ema_state,
    'config': CONFIG,
    'val_loss': min(final_val, ema_val),
    'total_steps': CONFIG['total_steps'],
    'total_time_h': total_time / 3600,
}, final_path)
final_size = os.path.getsize(final_path) / 1e6
print(f"\nFinal model saved: {final_path} ({final_size:.1f} MB)")
print(f"Process RSS: {psutil.Process().memory_info().rss / 1e6:.0f} MB")
print("=" * 60)
print("CELL 8 COMPLETE ✓")


STAGE 1: PRE-TRAINING
  checkpoint function is now: _passthrough
  Loaded pretrained embeddings from gpt2_embed_256.pt
Model parameters: 13,555,969
Process RSS: 3280 MB
FlashLM.forward() accepts: ['input_ids', 'targets', 'use_checkpointing']

Quick forward+backward test...
  PASSED — loss=14.4360
  NorMuon params: 590,336
  AdamW decay params: 12,964,096
  AdamW no-decay params: 1,537

Phase schedule (speed-optimized):
  steps 0-1000: seq=64, grad_accum=4, freeze=True
  steps 1000-4000: seq=128, grad_accum=4, freeze=False
  steps 4000-7000: seq=256, grad_accum=8, freeze=False
  steps 7000-10000: seq=512, grad_accum=8, freeze=False

Starting training: 10000 steps
  Phases: 4
  Position subsampling: 0.25
  Gradient checkpointing: DISABLED (global patch)

step     1 | loss 13.5547 | CE 13.5350 | lr_s 0.0050 | seq  64 | ga  4 | 0.39s/step | elapsed 0.00h | ETA 1.09h
  >>> VAL loss: 13.5829 (BEST)
step    50 | loss 8.5744 | CE 8.5672 | lr_s 0.2500 | seq  64 | ga  4 | 0.39s/step | elapsed 0.

KeyboardInterrupt: 

In [None]:
import torch, os
# Manual save of the in-memory model after the training cell was interrupted.
# NOTE(review): this writes to the absolute path '/checkpoints', not to
# CONFIG['checkpoint_dir'] used by the training cell — confirm this is the
# intended location.
os.makedirs('/checkpoints', exist_ok=True)
# NOTE(review): 'val_loss', 'last_step' and 'note' are typed in by hand, not
# computed from the run state — verify them against the training log.
torch.save({
    'model_state_dict': model.state_dict(),
    'ema_state': ema_state,
    'config': CONFIG,
    'val_loss': 6.8019,
    'last_step': 4050,
    'note': 'stopped at phase 1c - val_loss 6.80 best',
}, '/checkpoints/flashlm_final.pt')
size = os.path.getsize('/checkpoints/flashlm_final.pt') / 1e6
print(f"Model saved! /checkpoints/flashlm_final.pt ({size:.1f} MB)")

Model saved! /checkpoints/flashlm_final.pt (159.9 MB)


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=7a3227a4-aad7-4fa2-bd88-175f07a980b0' target="_blank">

Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>