In [102]:
#copyright joshuah.rainstar@gmail.com
from __future__ import annotations
import math
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Tuple

# ===========================================================
# Utilities
# ===========================================================

def _norm(v, eps: float = 1e-12):
    return torch.linalg.vector_norm(v, dim=-1, keepdim=True).clamp_min(eps)


def _unit(v, eps: float = 1e-12):
    return v / _norm(v, eps)

    
@torch.no_grad()
def phase_transport_between(
    curr: torch.Tensor,
    prev: torch.Tensor,
    tau: float = 1e-6,          # semantic threshold for the alignment / small-norm masks
    eps: float = 1e-12          # numeric floor for norms and for 1 + c (decoupled from tau)
) -> torch.Tensor:
    """
    Transport the difference ``w = curr - prev`` through the rotation that takes the
    direction of ``curr`` onto the direction of ``prev``, independently at every (b, t).

    With u = curr/|curr|, v = prev/|prev|, c = u·v and the skew operator
    K x = u (v·x) - v (u·x), the general branch returns (I - K + K^2/(1+c)) w.
    That operator maps u to v and leaves vectors orthogonal to both u and v unchanged.

    Per-position special cases:
      * c > 1 - tau (already aligned), |curr| < tau or |prev| < tau: w is returned as-is.
      * c < -1 + tau (antipodal; the general formula would divide by ~0): w is rotated
        by pi in the plane spanned by v and a unit vector p orthogonal to v, built from
        the coordinate axis where |v| is smallest. For C == 1 this is simply -w.
      The antipodal case is applied last, so it takes precedence.

    Args:
        curr, prev: tensors of identical shape (B, T, C) (asserted).
        tau: threshold used by the masks above.
        eps: floor applied to the norms and to (1 + c).

    Returns:
        (B, T, C) tensor. Decorated with ``torch.no_grad()``: no gradient flows back.
    """
    assert curr.shape == prev.shape and curr.dim() == 3
    B, T, C = curr.shape

    # Units (reuse norms) — clamp with eps (NOT tau)
    nu = torch.linalg.vector_norm(curr, dim=-1, keepdim=True).clamp_min(eps)   # (B,T,1)
    nv = torch.linalg.vector_norm(prev, dim=-1, keepdim=True).clamp_min(eps)   # (B,T,1)
    u = curr / nu
    v = prev / nv

    w = curr - prev
    c = (u * v).sum(dim=-1, keepdim=True)                                      # (B,T,1) cosine

    # Masks (semantic thresholds use tau)
    near_pos = (c >  1.0 - tau)                                                # (B,T,1)
    near_neg = (c < -1.0 + tau)                                                # (B,T,1)
    small_u  = (nu < tau)                                                      # (B,T,1)
    small_v  = (nv < tau)                                                      # (B,T,1)
    trivial  = near_pos | small_u | small_v                                    # (B,T,1)

    # General branch: y = (I - K + K^2/(1+c)) w with K x = u (v·x) - v (u·x)
    denom = (1.0 + c).clamp_min(eps)                                           # (B,T,1)
    a = (v * w).sum(dim=-1, keepdim=True)                                      # (B,T,1) v·w
    b = (u * w).sum(dim=-1, keepdim=True)                                      # (B,T,1) u·w
    Kw  = u * a - v * b                                                        # (B,T,C)
    K2w = u * (a * c - b) + v * (b * c - a)                                    # (B,T,C) K(Kw) expanded
    y_gen = w - Kw + (K2w / denom)                                             # (B,T,C)

    # Antipodal candidate: rotation by pi in span{v, p}
    if C == 1:
        y_neg = -w
    else:
        # p = normalize(e_idx - (v·e_idx) v): orthogonal to v, using the axis where |v|
        # is smallest so the subtraction is well conditioned. Normalization clamps with eps.
        idx = torch.argmin(v.abs().reshape(-1, C), dim=1, keepdim=True)        # (B*T,1)
        s = v.reshape(-1, C).gather(1, idx)                                    # (B*T,1) = v·e_idx
        p = -s * v.reshape(-1, C)
        onehot = F.one_hot(idx.squeeze(-1), num_classes=C).to(s.dtype)
        p = p + onehot
        n = torch.linalg.vector_norm(p, dim=1, keepdim=True).clamp_min(eps)
        p = (p / n).view(B, T, C)
        proj_v = (v * w).sum(dim=-1, keepdim=True) * v                         # (B,T,C)
        proj_p = (p * w).sum(dim=-1, keepdim=True) * p                         # (B,T,C)
        y_neg = w - 2.0 * proj_v - 2.0 * proj_p                                # negate the in-plane part

    # Fuse selections (antipodal applied last -> overrides "trivial")
    y = torch.where(trivial, w, y_gen)
    y = torch.where(near_neg, y_neg, y)
    return y

# ===========================================================
# Multi-scale features (vectorized pyramid)
# ===========================================================
class CausalCentroidPyramid(nn.Module):
    """Causal multi-scale phase-transport (PT) features for a whole (B, T, C) sequence.

    Output group 0 is the token-level delta PT(x[t], x[t-1]). Output group j
    (j = 1..K-1) uses the causal window mean mu_j[t] = mean(x[t-W+1 .. t]) with
    W = 2**j and returns PT(mu_j[t], mu_j[t-W]). ``CausalPyramidState.step`` is
    intended to produce the same values one token at a time.

    Implementation notes:
    - All dyadic window means are built directly from one cumulative sum (no
      sequential dependency between scales).
    - All cluster deltas are computed in a single batched call to
      phase_transport_between.
    - ``forward`` runs under ``torch.no_grad()``: these features carry no gradient.
    """
    def __init__(self, num_scales: int, tau: float = 1e-6):
        super().__init__()
        assert num_scales >= 1
        self.K = num_scales          # output groups: token PT + (K-1) cluster scales
        self.tau = float(tau)        # threshold forwarded to phase_transport_between

    @torch.no_grad()
    def forward(self, x: torch.Tensor, mask_early: bool = True) -> torch.Tensor:
        """
        Args:
            x: (B, T, C) sequence.
            mask_early: zero every output position whose window or lookback would reach
                before t = 0 (token group: t = 0; group j: t < 2**j).

        Returns:
            (B, T, K, C) stacked feature groups.
        """
        B, T, C = x.shape
        device = x.device
        dtype = x.dtype
    
        # token-level PT (scale-1): previous token, zeros before t = 0
        prev_tok = torch.zeros_like(x)
        if T > 1:
            prev_tok[:, 1:, :] = x[:, :-1, :].contiguous()
        d1 = phase_transport_between(x, prev_tok, tau=self.tau)  # (B,T,C)
        if mask_early:
            d1[:, :1, :].zero_()
        if self.K == 1:
            return d1.unsqueeze(2)
    
        # constants (avoid .item() / data-dependent Python ints)
        K1 = self.K - 1
        W_vec = (2 ** torch.arange(1, self.K, device=device, dtype=torch.long))  # (K1,) window sizes 2..2^(K-1)
        Wmax = (1 << (self.K - 1)) if self.K > 1 else 1  # Python int (K > 1 always holds here)
    
        # dyadic centroids via windowed means (vectorized)
        csum = torch.cumsum(x, dim=1)  # (B,T,C)
        csum_pad = torch.cat([torch.zeros(B, 1, C, device=device, dtype=dtype), csum], dim=1)  # (B,T+1,C)
    
        # window for output position t is (t+1-W, t+1] in csum_pad coordinates, start clamped at 0
        t_end = torch.arange(1, T + 1, device=device, dtype=torch.long)                         # (T,)
        idx_start_jt = (t_end.unsqueeze(0) - W_vec.unsqueeze(1)).clamp_min(0)                  # (K1,T)
        idx_start_tk = idx_start_jt.transpose(0, 1).contiguous()                                # (T,K1)
        idx_end_tk = t_end.unsqueeze(1).expand(T, K1).contiguous()                              # (T,K1)
    
        csum_ext = csum_pad.unsqueeze(2).expand(B, T + 1, K1, C)                                # (B,T+1,K1,C)
    
        gather_shape = (B, T, K1, C)
        idx_start = idx_start_tk.unsqueeze(0).unsqueeze(-1).expand(gather_shape)                # (B,T,K1,C)
        idx_end = idx_end_tk.unsqueeze(0).unsqueeze(-1).expand(gather_shape)                    # (B,T,K1,C)
    
        start_vals = torch.gather(csum_ext, dim=1, index=idx_start)
        end_vals = torch.gather(csum_ext, dim=1, index=idx_end)
        window_sums = end_vals - start_vals                                                     # (B,T,K1,C)
        # NOTE: early (truncated) windows are still divided by the full W; they are zeroed below
        # when mask_early is set.
        mu_all = window_sums / W_vec.to(dtype).view(1, 1, -1, 1)                                # (B,T,K1,C)
    
        if mask_early:
            t_idx = torch.arange(T, device=device).unsqueeze(1)                                 # (T,1)
            valid_mu = (t_idx >= (W_vec - 1).view(1, -1))                                       # (T,K1) full window available
            mu_all = mu_all * valid_mu.view(1, T, -1, 1)
    
        # previous centroids (shift by W per scale), vectorized with padding
        mu_pad = torch.cat([torch.zeros(B, Wmax, K1, C, device=device, dtype=dtype), mu_all], dim=1)  # (B,Wmax+T,K1,C)
        idx_prev_tk = torch.arange(T, device=device).unsqueeze(1) - W_vec.view(1, -1) + Wmax          # (T,K1) -> t - W (zeros if < 0)
        idx_prev = idx_prev_tk.unsqueeze(0).unsqueeze(-1).expand(gather_shape)                        # (B,T,K1,C)
        prev_mu_all = torch.gather(mu_pad, dim=1, index=idx_prev)                                     # (B,T,K1,C)
    
        # all cluster deltas in one batched PT call
        # NOTE: (B,T,K1,C) -> (B*K1,T,C) is a plain memory reinterpretation, NOT a transpose of
        # T and K1. It is still correct: both operands are reinterpreted identically,
        # phase_transport_between works independently per C-vector, and view(B,T,K1,C) below
        # undoes the reinterpretation exactly.
        mu_flat = mu_all.reshape(B * K1, T, C).contiguous()
        prev_flat = prev_mu_all.reshape(B * K1, T, C).contiguous()
        d_flat = phase_transport_between(mu_flat, prev_flat, tau=self.tau)                            # (B*K1,T,C)
        d_clusters = d_flat.view(B, T, K1, C)
    
        if mask_early:
            valid_d = (torch.arange(T, device=device).unsqueeze(1) >= W_vec.view(1, -1))              # (T,K1) lookback t-W exists
            d_clusters = d_clusters * valid_d.view(1, T, -1, 1)
    
        return torch.cat([d1.unsqueeze(2), d_clusters], dim=2)  # (B,T,K,C)

# ----- STREAMING STATE FOR INFERENCE -----
class CausalPyramidState:
    """
    Streaming cache that computes CausalCentroidPyramid-style features one token at a time.

    O(K) step-time updates, no recompute.
    For level ℓ we keep a ring buffer of length 2^ℓ storing μ_ℓ (with μ_0 = x).
    That suffices both to:
      - build μ_{ℓ+1}(t) from μ_ℓ(t) and μ_ℓ(t-2^ℓ)
      - compute deltas at scale s=ℓ via μ_s(t-2^s)

    Args:
        num_scales: K; ``step`` returns K feature groups (token PT + K-1 cluster PTs).
        C: channel width of each token vector.
        device: device on which the ring buffers live.
        batch_size: B; every ``step`` input must be exactly (B, C).
        tau: threshold forwarded to ``phase_transport_between``.
        dtype: dtype of the ring buffers and of the zero "no history yet" lookbacks.
            ``None`` (default) keeps the previous behaviour, ``torch.get_default_dtype()``.
            Pass the model's activation dtype (e.g. bf16 / fp64) so cached means are not
            silently re-cast and downstream layers receive matching dtypes.
    """
    def __init__(self, num_scales: int, C: int, device, batch_size: int = 1, tau: float = 1e-6,
                 dtype=None):
        self.K = num_scales
        self.C = C
        self.B = batch_size
        self.device = device
        self.tau = float(tau)
        self.dtype = torch.get_default_dtype() if dtype is None else dtype
        self.t = 0  # number of tokens processed so far

        # ring buffers: one per level ℓ = 0..K-1, each [B, L=2^ℓ, C];
        # ptrs[ℓ] is the slot the next _push writes (= the oldest stored entry).
        self.buffers = []
        self.ptrs = []
        for l in range(self.K):
            L = 1 << l
            self.buffers.append(torch.zeros(self.B, L, C, device=device, dtype=self.dtype))
            self.ptrs.append(0)

    def _read_lookback(self, level: int, r: int):
        """Return μ_level(t - r) as (B, C); zeros (buffer dtype/device) if fewer than r tokens seen.

        Only meaningful for r <= 2**level (the buffer length). Every caller in ``step`` uses
        r == 2**level, which reads the slot about to be overwritten, i.e. the oldest entry.
        The result is a view into the buffer, so it must be consumed before the next _push.
        """
        if self.t < r:
            return self.buffers[level].new_zeros(self.B, self.C)
        L = self.buffers[level].size(1)
        idx = (self.ptrs[level] - r) % L
        return self.buffers[level][:, idx, :]

    def _push(self, level: int, value: torch.Tensor):
        """Write the current μ_level(t) (B, C) into the ring and advance the pointer."""
        L = self.buffers[level].size(1)
        self.buffers[level][:, self.ptrs[level], :] = value
        self.ptrs[level] = (self.ptrs[level] + 1) % L

    @torch.no_grad()
    def step(self, x_t: torch.Tensor) -> torch.Tensor:
        """
        Consume one token and return its features.

        Args:
            x_t: (B, C) current token vectors; B and C must match the constructor.

        Returns:
            d(t): (B, K, C)  [token PT + (K-1) cluster PTs]

        Raises:
            ValueError: if ``x_t`` is not shaped (B, C). (Without this check a batch of 1
                would silently broadcast into every cached row on push.)
        """
        if tuple(x_t.shape) != (self.B, self.C):
            raise ValueError(
                f"x_t must have shape ({self.B}, {self.C}), got {tuple(x_t.shape)}"
            )
        B, C = x_t.shape
        feats = []

        # ------- token PT (read BEFORE any push) -------
        prev_x = self._read_lookback(level=0, r=1)  # μ0(t-1)
        d1 = phase_transport_between(x_t[:, None, :], prev_x[:, None, :], tau=self.tau).squeeze(1)
        if self.t == 0:
            d1.zero_()
        feats.append(d1)

        # ------- (A) compute all μ_s(t) with pre-push lookbacks -------
        mu_curr = [None] * self.K
        mu_curr[0] = x_t                      # μ0(t)
        mu_prev = x_t
        for s in range(1, self.K):
            W1 = 1 << (s - 1)
            W  = 1 << s
            mu_back = self._read_lookback(level=s-1, r=W1)   # μ_{s-1}(t - 2^{s-1})  (pre-push!)
            mu_s_t  = 0.5 * (mu_prev + mu_back)              # μ_s(t): mean of the two half-windows
            if self.t < (W - 1):                             # early mask (global t): window incomplete
                mu_s_t.zero_()
            mu_curr[s] = mu_s_t
            mu_prev = mu_s_t

        # ------- (B) compute all deltas d_s using μ_s(t−W) (pre-push) -------
        for s in range(1, self.K):
            W = 1 << s
            mu_prevW = self._read_lookback(level=s, r=W)     # μ_s(t - 2^s)  (pre-push!)
            d_s = phase_transport_between(mu_curr[s][:, None, :], mu_prevW[:, None, :], tau=self.tau).squeeze(1)
            if self.t + 1 <= W:                              # lookback t - W does not exist yet
                d_s.zero_()
            feats.append(d_s)

        # ------- (C) push μ_ℓ(t) for all levels, exactly once -------
        self._push(level=0, value=mu_curr[0])
        for s in range(1, self.K):
            self._push(level=s, value=mu_curr[s])

        self.t += 1
        return torch.stack(feats, dim=1)  # (B, K, C)





class SemanticClusterFeaturesCausal(nn.Module):
    """
    Causal multi-scale phase-transport features behind a single module.

    * ``forward(x)`` — whole-sequence path used in training; delegates to the owned
      ``CausalCentroidPyramid`` and returns (B, T, K, C).
    * ``step(x_t, state)`` — one-token inference path; delegates to the caller-owned
      streaming ``state`` and returns (B, K, C).
    """
    def __init__(self, num_scales: int, tau: float = 1e-6):
        super().__init__()
        tau_value = float(tau)
        self.pyramid = CausalCentroidPyramid(num_scales=num_scales, tau=tau)
        self.K = num_scales
        self.tau = tau_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.pyramid(x)
        return features  # (B, T, K, C)

    @torch.no_grad()
    def step(self, x_t: torch.Tensor, state: "CausalPyramidState") -> torch.Tensor:
        features_t = state.step(x_t)
        return features_t  # (B, K, C)


class GroupedChannelMLP(nn.Module):
    """
    K independent two-layer GELU MLPs (C -> C//2 -> C), one per feature group,
    evaluated together with einsum.

    Each group k behaves like ``Linear(C, H, bias=False) -> GELU -> Linear(H, C)``
    with H = C // 2, and is initialised exactly like those nn.Linear weights
    (kaiming_uniform_ with a=sqrt(5) on the 2-D per-group matrix, i.e. fan_in = C for
    fc1 and H for fc2). The output bias starts at zero.
    """
    def __init__(self, k_dim: int, c_dim: int):
        super().__init__()
        hidden_dim = c_dim // 2
        self.k_dim = k_dim
        self.c_dim = c_dim
        self.hidden_dim = hidden_dim

        # shapes chosen for direct einsum without expands
        # fc1: (K, H, C)   fc2: (K, C, H)   b2: (K, C)
        self.fc1_weight = nn.Parameter(torch.empty(k_dim, hidden_dim, c_dim))
        self.fc2_weight = nn.Parameter(torch.empty(k_dim, c_dim, hidden_dim))
        self.fc2_bias   = nn.Parameter(torch.empty(k_dim, c_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialise all groups like independent nn.Linear layers.

        kaiming_uniform_ on the full 3-D tensor would compute fan_in = out*in (it treats
        trailing dims as a receptive field), shrinking the weights by sqrt(out). Initialising
        each 2-D group slice gives the intended per-layer fan_in.
        """
        with torch.no_grad():
            for k in range(self.k_dim):
                nn.init.kaiming_uniform_(self.fc1_weight[k], a=5**0.5)
                nn.init.kaiming_uniform_(self.fc2_weight[k], a=5**0.5)
            nn.init.zeros_(self.fc2_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, K, C) or (B, K, C).

        Returns:
            Tensor with the same shape as ``x``; group k only sees and writes slot k.

        Raises:
            ValueError: if ``x`` is neither 3-D nor 4-D.
        """
        squeeze_time = False
        if x.dim() == 3:  # (B,K,C)
            x = x.unsqueeze(1)  # -> (B,1,K,C)
            squeeze_time = True
        elif x.dim() != 4:
            raise ValueError("Input must be (B,K,C) or (B,T,K,C)")

        # (B,T,K,C) x (K,H,C) -> (B,T,K,H)
        h = torch.einsum('btkc,khc->btkh', x, self.fc1_weight)
        h = F.gelu(h)

        # (B,T,K,H) x (K,C,H) -> (B,T,K,C)
        y = torch.einsum('btkh,kch->btkc', h, self.fc2_weight) + self.fc2_bias

        if squeeze_time:
            y = y[:, 0, :, :]  # (B,K,C)
        return y
        
        
class Cell(nn.Module):
    """Position-wise two-layer GELU MLP mapping ``dim_in -> hidden -> dim_in``.

    ``fc1`` is intentionally bias-free; ``fc2`` carries a bias.
    """
    def __init__(self, dim_in: int, hidden: int):
        super().__init__()
        # Creation order (fc1, fc2, act) is preserved so parameter init consumes RNG identically.
        self.fc1 = nn.Linear(dim_in, hidden, bias=False)  # no bias on purpose — do not change
        self.fc2 = nn.Linear(hidden, dim_in, bias=True)
        self.act = nn.GELU()

    def forward(self, x):
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)

class GPTSemanticBlock(nn.Module):
    """
    One residual block.

    The K multi-scale causal PT feature groups are passed through one small MLP per group,
    summed over the group axis, and added to the input. The result goes through an MLP and
    a LayerNorm, dropout is applied, and it is added back to the input:

        out = x + drop(ln(mlp(x + sum_k bottleneck_k(features_k(x)))))

    ``forward`` processes a whole sequence (B, T, C). ``step`` processes one token (B, C)
    against a caller-owned streaming feature state and mirrors ``forward``.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        C = config.n_embd
        self.C = C
        self.K = config.n_scales
        # NOTE(review): not read by forward/step. The feature tensor actually has K groups
        # (token PT + K-1 cluster scales); kept only for attribute compatibility.
        self.L = 1 + self.K
        # Submodules are created in the original order so RNG-driven init is unchanged.
        self.features = SemanticClusterFeaturesCausal(num_scales=self.K, tau=1e-6)
        self.drop = nn.Dropout(config.dropout)
        self.ln = nn.LayerNorm(self.C)
        self.mlp = Cell(self.C, self.C * 2)
        # One C -> C//2 -> C MLP per feature group.
        self.bottleneck = GroupedChannelMLP(self.K, self.C)

    def _residual_update(self, x: torch.Tensor, feats: torch.Tensor) -> torch.Tensor:
        """Shared tail of forward/step: per-group bottleneck, sum over the group axis (-2),
        then the residual MLP + LayerNorm + dropout."""
        pooled = self.bottleneck(feats).sum(dim=-2)
        mixed = self.mlp(x + pooled)
        return x + self.drop(self.ln(mixed))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _b, _t, _c = x.shape  # expects (B, T, C)
        return self._residual_update(x, self.features(x))  # features: (B, T, K, C)

    @torch.no_grad()
    def step(self, x_t: torch.Tensor, feat_state: CausalPyramidState) -> torch.Tensor:
        _b, _c = x_t.shape  # expects (B, C)
        return self._residual_update(x_t, self.features.step(x_t, feat_state))  # features: (B, K, C)



import torch
import torch.nn as nn

import math
import torch
import torch.nn as nn

def _is_prime(n: int) -> bool:
    if n < 2: return False
    if n % 2 == 0: return n == 2
    r = int(n**0.5)
    for f in range(3, r+1, 2):
        if n % f == 0: return False
    return True

def _factorize(n: int):
    f, cnt = [], {}
    d = 2
    while d * d <= n:
        while n % d == 0:
            cnt[d] = cnt.get(d, 0) + 1
            n //= d
        d += 1 if d == 2 else 2
    if n > 1: cnt[n] = cnt.get(n, 0) + 1
    return list(cnt.keys())

def _primitive_root(p: int) -> int:
    # p must be prime
    phi = p - 1
    factors = _factorize(phi)
    for g in range(2, p):
        ok = True
        for q in factors:
            if pow(g, phi // q, p) == 1:
                ok = False
                break
        if ok:
            return g
    raise RuntimeError("no primitive root found")

def _welch_costas_perm(V: int, device=None):
    """
    Welch Costas permutation σ on {0..V-1}, where V = p-1 for prime p.
    σ[i] = g^(i+1) mod p, mapped to 0..V-1 by subtracting 1 (g = smallest primitive root).

    Returns None when V + 1 is not prime.

    The powers are built incrementally in Python and materialised with one
    ``torch.tensor`` call. The previous per-element ``sigma[i] = ...`` issued V separate
    device writes, which is slow when ``device`` is a GPU.
    """
    p = V + 1
    if not _is_prime(p):
        return None
    g = _primitive_root(p)
    values = []
    power = 1
    for _ in range(V):
        power = (power * g) % p  # g^(i+1) mod p
        values.append(power - 1)
    return torch.tensor(values, dtype=torch.long, device=device)  # permutation of 0..V-1

def _coprime_mul_perm(V: int, device=None):
    """
    Fallback: σ[i] = (a*i + b) % V with gcd(a, V)=1 and a not ≡ ±1 mod V.
    Not Costas, but non-monotone and well-distributed.
    """
    # pick a
    a = None
    for cand in range(2, V):
        if math.gcd(cand, V) == 1 and cand % V not in (1, V-1):
            a = cand
            break
    if a is None:
        a = 1  # degenerate small V
    b = V // 3
    i = torch.arange(V, device=device)
    return ((a * i + b) % V).long()

def _perm_inverse(sigma: torch.Tensor) -> torch.Tensor:
    inv = torch.empty_like(sigma)
    inv[sigma] = torch.arange(sigma.numel(), device=sigma.device)
    return inv

class FlatRollEmbed(nn.Module):
    """
    Drop-in replacement for nn.Embedding with a fixed (frozen by default) V x V table.
    Requires V == n_embd.

    Table construction:
      1. ``_make_weight``: a circulant matrix whose row r is ``roll(x, r)`` for a real base
         vector x with zero DC and unit-magnitude FFT bins k = 1..V-1 (random phases
         seeded by ``seed``), rescaled according to ``scale``.
      2. Mix with the identity: ``0.5 * circulant + 0.5 * eye``.
      3. Permute rows by the inverse of σ, where σ is a Welch–Costas permutation when
         V + 1 is prime, otherwise the affine permutation i -> (a*i + b) mod V.

    After step 3, token j's row is the mixed row σ^{-1}(j), so its 0.5 identity
    component sits in column σ^{-1}(j).

    Args:
        config: needs ``vocab_size`` and ``n_embd`` (must be equal; asserted).
        scale: "box" (max |x_i| = 1) or "unit" (||x||_2 = 1) for the base vector.
        seed: seed for the random spectral phases.
        freeze: forwarded to ``nn.Embedding.from_pretrained``.
        dtype: table dtype (float32 when None).
        device: device for the table and permutation buffers.
    """
    def __init__(self, config, scale: str = "box", seed: int = 0,
                 freeze: bool = True, dtype=None, device=None):
        super().__init__()
        assert config.n_embd == config.vocab_size, (
            f"Expected n_embd == vocab_size, got {config.n_embd} != {config.vocab_size}"
        )
        V = int(config.vocab_size)
        dtype = dtype or torch.float32

        eye = torch.eye(V, dtype=dtype, device=device)
        weight = self._make_weight(V, scale=scale, seed=seed,
                                   dtype=dtype, device=device)  # [V, V]
        mixed = 0.5 * weight + 0.5 * eye  # add identity towers

        # --- compute a strong-scramble row order (Costas if possible) ---
        sigma = _welch_costas_perm(V, device=device)
        if sigma is None:
            sigma = _coprime_mul_perm(V, device=device)
        # We want ones at (row = σ[i], col = i). index_select with r_idx = σ^{-1} makes
        # new row j = old row σ^{-1}(j), whose identity 1 is in column σ^{-1}(j);
        # i.e. for j = σ[i] the 1 lands in column i.
        r_idx = _perm_inverse(sigma)

        # keep for reference / decoding (non-persistent: not saved in state_dict)
        self.register_buffer("row_perm", r_idx, persistent=False)
        self.register_buffer("sigma", sigma, persistent=False)

        mixed = mixed.index_select(0, r_idx)
        self.embed = nn.Embedding.from_pretrained(mixed, freeze=freeze)


    @staticmethod
    def _row_perm_max_equidistant(V: int, device=None) -> torch.Tensor:
        """
        Row permutation that evenly offsets the identity's '1' away from the diagonal.
        Uses a single cyclic shift by k = floor(V/2).

        NOTE(review): not called anywhere in this file; the constructor uses the
        Costas / affine permutation instead.
        """
        if V <= 1:
            return torch.arange(V, device=device, dtype=torch.long)
        k = V // 2
        if k == 0:  # only happens when V == 1, handled above; keep for safety
            k = 1
        return ((torch.arange(V, device=device) + k) % V).long()

    @staticmethod
    def _make_weight(V: int, scale: str = "box", seed: int = 0,
                     dtype=torch.float32, device=None) -> torch.Tensor:
        """
        Returns a (V, V) tensor whose rows are cyclic rolls of a base vector x in R^V
        with |FFT(x)|^2 flat for k=1..V-1 and DC=0.
        scale:
          - "unit": ||x||_2 = 1
          - "box":  max|x_i| = 1
        Raises ValueError for any other ``scale`` (checked after the spectrum is built).
        Phases come from a CPU generator seeded with ``seed``, so the table is reproducible;
        the draw order below must not change or the table changes.
        """
        # build on CPU, move at end; complex64 for float32 tables, complex128 otherwise
        complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128
        g = torch.Generator().manual_seed(seed)

        X = torch.zeros(V, dtype=complex_dtype)
        # DC bin
        X[0] = torch.tensor(0, dtype=complex_dtype)

        if V % 2 == 0:
            # bins 1..V/2-1 are complex-conjugate pairs; Nyquist bin must be real
            for k in range(1, V // 2):
                phi = torch.rand((), generator=g) * (2 * math.pi)
                val = torch.cos(phi) + 1j * torch.sin(phi)  # unit magnitude, random phase
                X[k] = val
                X[V - k] = torch.conj(val)
            X[V // 2] = 1.0 if torch.rand((), generator=g) < 0.5 else -1.0
        else:
            # odd V: bins 1..(V-1)/2 paired with their conjugates, no Nyquist bin
            for k in range(1, (V - 1) // 2 + 1):
                phi = torch.rand((), generator=g) * (2 * math.pi)
                val = torch.cos(phi) + 1j * torch.sin(phi)
                X[k] = val
                X[V - k] = torch.conj(val)

        # Hermitian symmetry makes the inverse FFT real up to rounding; .real drops the residue.
        x = torch.fft.ifft(X).real  # real length-V base vector

        if scale == "unit":
            x = x / (x.norm() + 1e-12)
        elif scale == "box":
            x = x / (x.abs().max() + 1e-12)
        else:
            raise ValueError("scale must be 'unit' or 'box'")

        rows = [torch.roll(x, shifts=r, dims=0) for r in range(V)]
        W = torch.stack(rows, dim=0).to(dtype=dtype)
        if device is not None:
            W = W.to(device)
        return W

    def forward(self, input_ids: torch.LongTensor):
        # (batch, seq_len) ids -> (batch, seq_len, V) rows of the fixed table
        return self.embed(input_ids)


        
@dataclass
class GPTConfig:
    """Hyperparameters read by GPT, GPTSemanticBlock and FlatRollEmbed.

    NOTE(review): with these defaults n_embd (128) != vocab_size (66), so
    ``GPT(GPTConfig())`` trips FlatRollEmbed's n_embd == vocab_size assertion;
    callers are expected to override both (the training cell sets n_embd=vocab_size).
    """
    block_size: int = 1024  # training sequence length; GPT itself only asserts it is set
    vocab_size: int = 66 # character-level vocabulary size (66 for the tokenizer built in this notebook)
    n_layer: int = 6  # number of GPTSemanticBlocks
    n_head:int = 6  # NOTE(review): not read by any module in this file
    n_embd: int = 128  # channel width C; FlatRollEmbed requires n_embd == vocab_size
    n_scales:int = 9  # K feature groups per block: token PT + (K-1) dyadic cluster scales
    dropout: float = 0.1  # residual dropout inside each GPTSemanticBlock


class GPT(nn.Module):
    """
    Character-level LM: FlatRollEmbed -> n_layer x GPTSemanticBlock -> lm_head.

    ``lm_head.weight`` is tied to the embedding table. When the embedding is frozen
    (FlatRollEmbed's default), the output projection is frozen too and only the blocks
    are trained.
    """
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.n_embd = config.n_embd
        # NOTE(review): not applied in forward(); kept so attribute access is unchanged.
        self.drop = nn.Dropout(0.6)

        self.transformer = nn.ModuleDict(dict(
            wte = FlatRollEmbed(config),
            h = nn.ModuleList([GPTSemanticBlock(config) for _ in range(config.n_layer)]),
        ))

        self.lm_head = nn.Linear(self.config.n_embd, self.config.vocab_size, bias=False)
        # Weight tying: the embedding table doubles as the output projection.
        self.lm_head.weight = self.transformer.wte.embed.weight
        #TODO- ADD A LEARNED, tied gate that regulate both LMhead and flatrollembed,
        #such that the mixture between eye and convolution-even generation is
        #conditionally between 100% eye- strong individuality, risk of neighbors
        #100% even-convolution- strong generalization, risk of chaos
        #model may benefit from initially learning with one, then moving to the other
        #thus more rapidly acquiring meaningful structure

    # ---------- forward ----------
    def forward(self, idx, targets=None, eprint=False):
        """
        Args:
            idx: (B, T) token ids.
            targets: optional (B, T) target ids; entries equal to -1 are ignored.
            eprint: unused; kept for call compatibility.

        Returns:
            (logits, loss). With targets: logits (B, T, V) and mean cross-entropy loss.
            Without targets: logits (B, 1, V) for the last position and loss None.
        """
        b, t = idx.size()  # also validates that idx is 2-D
        x = self.transformer.wte(idx)
        x = x.detach()                 # sever any stale history just in case
        # NOTE(review): makes x a grad-requiring leaf ("for τ at layer 0"). Nothing in this
        # file reads x.grad, so this only costs an extra gradient buffer — confirm before removing.
        x.requires_grad_(True)

        for block in self.transformer.h:
            x = block(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1
            )
        else:
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss

    @torch.no_grad()
    def generate_greedy(model: nn.Module, idx: torch.LongTensor, max_new_tokens: int, block_size: int):
        """
        Greedy autoregressive decoding using per-block streaming feature caches.

        Every prompt token is fed through the blocks exactly once. The hidden state after
        the last prompt token yields the first prediction, and each generated token is
        then fed exactly once. The previous version re-fed the last prompt token at the
        start of the rollout, corrupting every cache and the first prediction.

        The model is put in eval mode (dropout off) for decoding and its previous
        training/eval mode is restored afterwards.

        Args:
            idx: (B, T0) prompt token ids, T0 >= 1.
            max_new_tokens: number of tokens to append.
            block_size: kept for API compatibility. The streaming caches have no context
                window, so truncating the token history (as before) never affected the output.

        Returns:
            (B, T0 + max_new_tokens) token ids: the prompt followed by the generated tokens.

        Raises:
            ValueError: if ``idx`` is not a non-empty 2-D tensor.
        """
        if idx.dim() != 2 or idx.size(1) == 0:
            raise ValueError("idx must be a non-empty (B, T0) tensor of prompt token ids")
        device = next(model.parameters()).device
        B = idx.size(0)
        was_training = model.training
        model.eval()
        try:
            # per-block feature caches
            feat_states = [CausalPyramidState(model.config.n_scales, model.config.n_embd, device, batch_size=B)
                           for _ in model.transformer.h]

            def advance(tok_emb):
                # Push one (B, C) token embedding through every block, updating its cache.
                h = tok_emb
                for blk, st in zip(model.transformer.h, feat_states):
                    h = blk.step(h, st)
                return h

            # 1) prime caches with the prompt (causal, one step at a time)
            x_all = model.transformer.wte(idx)  # (B, T0, C)
            x_t = None
            for t in range(idx.size(1)):
                x_t = advance(x_all[:, t, :])

            # 2) roll out: predict from the current state, then feed the prediction once
            out = [idx]
            for i in range(max_new_tokens):
                logits = model.lm_head(x_t)                            # (B, V)
                next_idx = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1) greedy
                out.append(next_idx)
                if i + 1 < max_new_tokens:
                    x_t = advance(model.transformer.wte(next_idx.squeeze(-1)))  # (B, C)
            return torch.cat(out, dim=1)
        finally:
            model.train(was_training)


In [74]:
import requests, os

base_url = "https://huggingface.co/datasets/cambridge-climb/BabyLM/resolve/main/clean/10M/"
target_dir = "./babylm_10m_cleaned"
os.makedirs(target_dir, exist_ok=True)

file_names = [
    "aochildes.txt",
    "cbt.txt",
    "children_stories.txt",
    "gutenberg.txt",
    "qed.txt",
    "simple_wikipedia.txt",
    "switchboard.txt",
    "wikipedia.txt"
]

# Optional addition: Shakespeare from another dataset
shakespeare_url = "https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt"
shakespeare_fname = "shakespeare.txt"

# Combined download logic
all_files = [(base_url + fname, fname) for fname in file_names]
all_files.append((shakespeare_url, shakespeare_fname))  # Add Shakespeare

# Files already on disk are reused so "Run All" does not re-download ~100 MB each time.
FORCE_DOWNLOAD = False
REQUEST_TIMEOUT_S = 60  # without a timeout, requests.get can block forever on a stalled connection

# Download loop
for url, fname in all_files:
    out_path = os.path.join(target_dir, fname)
    if not FORCE_DOWNLOAD and os.path.isfile(out_path) and os.path.getsize(out_path) > 0:
        print(f"✔️ {fname} already present, skipping")
        continue
    print(f"📥 Downloading {fname}...")
    try:
        resp = requests.get(url, timeout=REQUEST_TIMEOUT_S)
    except requests.RequestException as exc:
        # Same best-effort policy as HTTP errors: report and continue with the next file.
        print(f"❌ Failed to download {fname} ({exc})")
        continue
    if resp.status_code == 200:
        # Write the raw bytes. resp.text would re-decode with requests' guessed charset
        # (which can mangle UTF-8), and text mode translates newlines on Windows.
        # Writing to a temp file then renaming prevents half-written files from being
        # mistaken for complete ones on the next run.
        tmp_path = out_path + ".part"
        with open(tmp_path, "wb") as f:
            f.write(resp.content)
        os.replace(tmp_path, out_path)
    else:
        print(f"❌ Failed to download {fname} ({resp.status_code})")

print(f"✅ Done. Files saved to {target_dir}")

📥 Downloading aochildes.txt...
📥 Downloading cbt.txt...
📥 Downloading children_stories.txt...
📥 Downloading gutenberg.txt...
📥 Downloading qed.txt...
📥 Downloading simple_wikipedia.txt...
📥 Downloading switchboard.txt...
📥 Downloading wikipedia.txt...
📥 Downloading shakespeare.txt...
✅ Done. Files saved to ./babylm_10m_cleaned


In [97]:
import os
import pickle
import numpy as np

# === Paths ===
source_dir = "./babylm_10m_cleaned"
out_dir    = "./babylm_char_tokenized"
os.makedirs(out_dir, exist_ok=True)

# Corpora to tokenize. Other downloaded files ("aochildes.txt", "cbt.txt",
# "children_stories.txt", "gutenberg.txt", "qed.txt", "simple_wikipedia.txt",
# "switchboard.txt", "wikipedia.txt") can be added to this list.
file_names = [
    "shakespeare.txt",
]

# === Load and split: first 90% of each file's lines -> train, remainder -> val ===
train_texts, val_texts = [], []
char_set = set()

for fname in file_names:
    with open(os.path.join(source_dir, fname), encoding="utf-8") as f:
        lines = f.readlines()
    n = len(lines)
    split = int(0.9 * n)
    train_part = "".join(lines[:split])
    val_part   = "".join(lines[split:])
    train_texts.append(train_part)
    val_texts.append(val_part)
    # NOTE: the vocab is built from train AND val, so no val character maps to <unk>.
    char_set.update(train_part)
    char_set.update(val_part)

full_train = "\n".join(train_texts)
full_val   = "\n".join(val_texts)

# === Final vocab: id 0 is <unk>, then characters in sorted order ===
char_set = sorted(char_set)
vocab_chars = ["<unk>"] + [c for c in char_set if c != "<unk>"]

stoi = {ch: i for i, ch in enumerate(vocab_chars)}
itos = {i: ch for ch, i in stoi.items()}

# === Encode function ===
def encode(text):
    """Map each character to its id; characters outside the vocab map to 0 (<unk>)."""
    return [stoi.get(c, 0) for c in text]

# Token ids are stored as uint16 (here and in the training memmap). Fail loudly instead of
# wrapping or overflowing if a large corpus ever yields more symbols than uint16 can index.
if len(stoi) - 1 > np.iinfo(np.uint16).max:
    raise ValueError(f"vocab of {len(stoi)} symbols does not fit in uint16 token ids")

train_ids = np.array(encode(full_train), dtype=np.uint16)
val_ids   = np.array(encode(full_val),   dtype=np.uint16)

# === Save ===
train_ids.tofile(os.path.join(out_dir, "train.bin"))
val_ids.tofile(os.path.join(out_dir, "val.bin"))

with open(os.path.join(out_dir, "meta.pkl"), "wb") as f:
    pickle.dump({
        "vocab_size": len(stoi),
        "stoi": stoi,
        "itos": itos
    }, f)

print(f"✅ Char tokenizer finalized.")
print(f"🧾 Train tokens: {len(train_ids)} | Val tokens: {len(val_ids)}")
print(f"🔤 Vocab size: {len(stoi)}")

✅ Char tokenizer finalized.
🧾 Train tokens: 1016242 | Val tokens: 99152
🔤 Vocab size: 66


In [103]:
import os  # was commented out; the cell only worked because an earlier cell had imported os
import pickle
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# === Config ===
data_dir = "./babylm_char_tokenized"  # <- char-tokenized data
block_size = 2048
batch_size = 8

# === Load tokenizer metadata ===
# meta.pkl is written by the tokenization cell above. Only unpickle files you produced
# yourself: pickle.load can execute arbitrary code.
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# === Load memory-mapped data (char-level tokens, uint16) ===
train_ids = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_ids   = np.memmap(os.path.join(data_dir, 'val.bin'),   dtype=np.uint16, mode='r')

# === Efficient GPU Batch Sampler ===
class GPUBatchDataset(Dataset):
    """
    Dataset whose every item is an independently sampled random BATCH already on ``device``.

    Each of the ``batch_size`` rows starts at a random aligned offset
    ``k * sample_len``. With probability ``1 - p_aligned`` a random jitter in
    [0, jitter] is added. X rows have ``sample_len = block_size + pad_len`` tokens;
    Y rows hold the ``block_size`` next-token targets for the last ``block_size``
    positions of X.

    Use with ``DataLoader(batch_size=1, num_workers=0)``. Items are already batched
    device tensors, so the loader only adds a leading axis of size 1.

    Raises:
        ValueError: if the data is too short for at least one aligned sample
            (needs ``len(mmap_file) >= 2 * sample_len + 1``).
    """
    def __init__(self, mmap_file, block_size, batch_size, device, jitter=63, p_aligned=0.5, pad_len=0):
        self.data = mmap_file
        self.block_size = block_size
        self.batch_size = batch_size
        self.device = device
        self.pad_len = int(pad_len)
        self.sample_len = self.block_size + self.pad_len  # X length
        self.total = len(self.data) - self.sample_len - 1   # largest valid start offset
        self.n_blocks = self.total // self.sample_len
        self.jitter = int(jitter)          # small random offset added to aligned start
        self.p_aligned = float(p_aligned)  # mix aligned and jittered
        if self.n_blocks <= 0:
            # Previously this surfaced later as an opaque np.random.randint "low >= high" error.
            raise ValueError(
                f"data of {len(self.data)} tokens is too short for sample_len={self.sample_len} "
                f"(need at least {2 * self.sample_len + 1})"
            )

    def __len__(self):
        # Batches per "epoch". Every item is a fresh random draw (with replacement), so this
        # is a step budget of total // batch_size batches, not a single pass over the data.
        return self.total // self.batch_size

    def __getitem__(self, idx):
        # ``idx`` is ignored: sampling uses the global NumPy RNG (seed it for reproducibility).
        X = np.empty((self.batch_size, self.sample_len), dtype=np.int64)
        Y = np.empty((self.batch_size, self.block_size), dtype=np.int64)

        for i in range(self.batch_size):
            # choose a base aligned block
            base_block = np.random.randint(0, self.n_blocks)
            start = base_block * self.sample_len

            # with probability, add a small jitter (keeps cache-friendly contiguous reads)
            if np.random.rand() > self.p_aligned:
                j = np.random.randint(0, self.jitter + 1)
                start = min(start + j, self.total)  # stay in range

            X[i] = self.data[start : start + self.sample_len]
            # targets correspond to the final block_size visible steps
            Y[i] = self.data[start + 1 + self.pad_len : start + 1 + self.pad_len + self.block_size]

        # NOTE: the source arrays are not pinned, so non_blocking copies are effectively synchronous.
        return (
            torch.from_numpy(X).to(self.device, non_blocking=True),
            torch.from_numpy(Y).to(self.device, non_blocking=True)
        )


# vocab_size comes from meta.pkl (loaded above) rather than `stoi` from the tokenizer cell,
# so this cell no longer depends on another cell's leftover kernel state.
config = GPTConfig(
    vocab_size=vocab_size,
    n_layer=1,
    n_embd=vocab_size,   # FlatRollEmbed requires n_embd == vocab_size
    block_size=block_size,
    dropout = 0.1,
)
train_dataset = GPUBatchDataset(train_ids, block_size, batch_size, device, pad_len=0)
# === DataLoader ===
# batch_size=1 because each dataset item is already a full batch; num_workers must stay 0
# since items are created directly on `device`.
train_loader  = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)

# Move to the target device first, then compile (the order recommended for torch.compile).
model = GPT(config).to(device)
model = torch.compile(model, mode="max-autotune")

In [104]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)
losses = []
def train_epoch(log_every: int = 100) -> float:
    """Run one pass over ``train_loader`` and return the mean training loss.

    Uses the notebook globals ``model``, ``optimizer``, ``train_loader`` and
    ``losses``. Each step:

    - clears the gradients,
    - runs ``model(xb, yb)``, which returns ``(logits, loss)``,
    - back-propagates,
    - clips the global gradient norm to 1.0,
    - steps the optimizer.

    Every per-step loss is appended to ``losses``.

    Args:
        log_every: print the current loss every ``log_every`` steps (``<= 0``
            disables printing). Defaults to 100, the previous hard-coded period.

    Returns:
        Mean loss over the steps actually run, or ``nan`` if the loader yielded
        nothing.
    """
    model.train()
    total_loss = 0.0
    it = 0
    for xb, yb in train_loader:
        # Unwrap the outer batch dimension added by the DataLoader
        # (assumes the loader uses batch_size=1 — confirm).
        xb, yb = xb[0], yb[0]
        optimizer.zero_grad()
        _, loss = model(xb, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        it += 1
        # .item() forces a device sync; do it once per step rather than three times.
        loss_val = loss.item()
        total_loss += loss_val
        losses.append(loss_val)
        if log_every > 0 and it % log_every == 0:
            print(loss_val)

    # Average over the steps actually taken (equals len(train_loader) on a full
    # pass) and avoid ZeroDivisionError on an empty loader.
    return total_loss / it if it else float("nan")

# === Run Training ===
# Repeat train_epoch() num_epochs times, reporting each pass's mean loss.
num_epochs = 10
epoch = 0
while epoch < num_epochs:
    epoch += 1
    train_loss = train_epoch()
    print(f"Epoch {epoch:2d} | Train loss: {train_loss:.4f}")

2.4248459339141846
2.267223834991455
2.215237855911255
2.226175308227539
2.156221389770508
2.1747803688049316
2.101888418197632
2.0438454151153564
2.010363817214966
2.011732578277588
2.026224136352539
2.0078017711639404
1.97121262550354
1.9109714031219482
1.9306507110595703
1.8956797122955322
1.898237943649292
1.9327778816223145
1.9002705812454224
1.9405431747436523
1.9158666133880615
1.8525692224502563
1.8828603029251099
1.874499797821045
1.7933162450790405
1.8483089208602905
1.8752532005310059
1.850698709487915
1.8211157321929932
1.812084674835205
1.8197087049484253
1.8460789918899536
1.7804791927337646
1.8103175163269043
1.8493680953979492
1.8176980018615723
1.735107183456421
1.8106229305267334
1.8291242122650146
1.762618064880371
1.7982959747314453
1.7748730182647705
1.808021903038025
1.8309447765350342
1.8393099308013916
1.7504090070724487
1.7045764923095703
1.7763054370880127
1.8377323150634766
1.7397518157958984
1.8211332559585571
1.7997101545333862
1.7955690622329712
1.79905796

KeyboardInterrupt: 

In [87]:
# Continue training for another num_epochs passes (same loop as the "Run Training" cell).
num_epochs = 10
epoch = 0
while epoch < num_epochs:
    epoch += 1
    train_loss = train_epoch()
    print(f"Epoch {epoch:2d} | Train loss: {train_loss:.4f}")

1.7179797887802124
1.840389370918274
1.7476829290390015
1.789048194885254
1.762560248374939
1.7769278287887573
1.7655577659606934
1.7587710618972778
1.7428169250488281
1.7512731552124023
1.6718738079071045
1.8342667818069458
1.7842442989349365
1.73983633518219
1.7651057243347168
1.7372163534164429


KeyboardInterrupt: 

In [105]:
import pickle
def decode_chars(token_ids, itos):
    """
    Turn a sequence of character token ids back into a string.

    Each id is looked up in ``itos`` (a dict or list indexable by id) and the
    characters are concatenated in order. An unknown id raises the lookup's own
    KeyError/IndexError.
    """
    lookup = itos.__getitem__
    return "".join(map(lookup, token_ids))

def encode_chars(text, stoi, unk_id: int = 0):
    """
    Encode a string into a list of token ids, one per character.

    Args:
        text: string to encode.
        stoi: mapping from character to token id.
        unk_id: id used for characters missing from ``stoi``. Defaults to 0,
            the previously hard-coded fallback.

    Returns:
        A list with ``len(text)`` ids.
    """
    return [stoi.get(ch, unk_id) for ch in text]


from collections import deque


@torch.no_grad()
def decode_sequence_char_rolling(
    model, stoi, itos, prompt,
    max_new_tokens=100,
    block_size=1024,
    temperature=1.0,
    space_fallback=' ',
    strict_window=False,          # if True, periodically re-prime caches on the last block
    reprime_every=None            # if strict_window, how often to re-prime (int). Default: block_size
):
    """
    Rolling-block character generator.

    - Keeps the ENTIRE generated text (the output is never trimmed).
    - Maintains a rolling window (deque) of the last ``block_size`` token ids.
    - Optionally re-primes the per-layer feature caches on that window to mimic
      the block-window semantics seen during training.

    The prompt is encoded with ``encode_chars``; an empty prompt becomes a single
    ``space_fallback`` token. The prompt is left-padded with that token up to
    ``block_size`` and streamed through every block's ``step`` to prime the
    caches. A prompt longer than ``block_size`` is fed in full, not truncated.
    Tokens are then produced one at a time from ``model.lm_head`` applied to the
    last block output.

    Args:
        model: must expose ``config.n_scales``, ``config.n_embd``,
            ``transformer.wte``, ``transformer.h`` (blocks with
            ``.step(x, state)``) and ``lm_head``. It is put into eval mode and
            left there.
        stoi / itos: char -> id and id -> char lookups.
        prompt: text to condition on.
        max_new_tokens: number of tokens to generate.
        block_size: priming pad length and rolling-window length.
        temperature: softmax temperature. Values ``<= 0`` select the argmax token
            (greedy decoding) instead of sampling.
        space_fallback: character used for padding and for an empty prompt
            (id 0 if it is not in ``stoi``).
        strict_window: if False (default), caches stream forever, which is the
            fastest path. If True, every ``reprime_every`` generated tokens (once
            the window holds ``block_size`` tokens) fresh states are rebuilt from
            the window, giving exact sliding-window behavior.
        reprime_every: re-prime period used by ``strict_window``; defaults to
            ``block_size``.

    Returns:
        The decoded prompt followed by every generated character (padding excluded).
    """
    device = next(model.parameters()).device
    model.eval()
    B = 1

    # ---- encode prompt (fallback to space if empty) ----
    space_id = stoi.get(space_fallback, 0)
    prompt_ids = encode_chars(prompt, stoi)
    if len(prompt_ids) == 0:
        prompt_ids = [space_id]

    def _fresh_states():
        # One new CausalPyramidState per block. The same settings are used for the
        # initial priming and for every strict-window re-prime.
        return [
            CausalPyramidState(
                num_scales=model.config.n_scales,
                C=model.config.n_embd,
                device=device,
                batch_size=B,
                tau=1e-6,
            )
            for _ in model.transformer.h
        ]

    def _advance(tok, states):
        # Embed one step of token ids `tok` (shape (1,)) and run it through every
        # block, updating `states` in place; returns the last block's output (1, C).
        x = model.transformer.wte(tok)
        for blk, st in zip(model.transformer.h, states):
            x = blk.step(x, st)
        return x

    def _prime(tok_ids, states):
        # Stream a list of token ids through _advance; returns the output for the
        # final token (None if tok_ids is empty).
        ids_t = torch.tensor([tok_ids], dtype=torch.long, device=device)  # (1, T)
        x_last = None
        for t in range(ids_t.size(1)):
            x_last = _advance(ids_t[:, t], states)
        return x_last

    def _reprime_with_ids(tok_ids):
        # Rebuild the caches from scratch on tok_ids, left-padded to block_size.
        if len(tok_ids) < block_size:
            tok_ids = [space_id] * (block_size - len(tok_ids)) + tok_ids
        new_states = _fresh_states()
        return new_states, _prime(tok_ids, new_states)

    # ---- initial priming: prompt left-padded ONCE to block_size ----
    # (padding is only used for priming; it is not part of the returned text)
    pad_len = max(0, block_size - len(prompt_ids))
    priming_ids = [space_id] * pad_len + prompt_ids
    feat_states = _fresh_states()
    x_t = _prime(priming_ids, feat_states)

    # ---- FULL output accumulator (never trimmed) ----
    out_full = list(prompt_ids)

    # ---- rolling window of the last block_size tokens (prompt + generated) ----
    window = deque(prompt_ids, maxlen=block_size)

    # strict-window settings
    if reprime_every is None:
        reprime_every = block_size
    steps_since_reprime = 0

    # ---- incremental rollout ----
    for _ in range(max_new_tokens):
        logits = model.lm_head(x_t)  # (1,V)
        if temperature <= 0:
            # Greedy decoding. Dividing by a zero temperature would give inf/nan
            # probabilities and make torch.multinomial fail.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (1,1)
        else:
            if temperature != 1.0:
                logits = logits / float(temperature)
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1,1)
        next_id = int(next_token.item())

        out_full.append(next_id)
        window.append(next_id)

        # step the caches by one token
        x_t = _advance(next_token.squeeze(-1), feat_states)

        # optionally re-prime to strict sliding-window semantics
        if strict_window:
            steps_since_reprime += 1
            if steps_since_reprime >= reprime_every and len(window) == block_size:
                feat_states, x_t = _reprime_with_ids(list(window))
                steps_since_reprime = 0

    # decode full continuation (prompt + all generated)
    return decode_chars(out_full, itos)
    
# Load the character vocabulary saved alongside the tokenized dataset.
# NOTE: pickle.load can execute arbitrary code; only load a meta.pkl you produced/trust.
# NOTE(review): this cell rebinds the globals `stoi`/`itos`. The config cell earlier
# in the notebook already uses `stoi`, so on a fresh kernel it must be defined first.
with open("./babylm_char_tokenized/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi = meta["stoi"]
itos = meta["itos"]

import time
# perf_counter is a monotonic, high-resolution clock intended for measuring elapsed
# time; time.time() can jump if the system clock is adjusted.
then = time.perf_counter()
prompt = "ROMEO: Juliet, do you love me?  JULIET:"
# NOTE(review): block_size is hard-coded to 2048 here; confirm it matches the
# block_size the model was configured and trained with.
generated = decode_sequence_char_rolling(
    model=model,
    stoi=stoi,
    itos=itos,
    prompt=prompt,
    max_new_tokens=4096,
    block_size=2048,
    temperature=0.8
)

print(generated)
print(time.perf_counter() - then)

ROMEO: Juliet, do you love me?  JULIET:
Yet God world whantss
Wher's chour both thBxecurry my shall ing
hones the gelanct our enot our Poress tore
Heavy the fairly Apelinatlessnaarly;
for your dommeast th3 
Thereast by my my orrorey basely,
By thy my trotlyss, lootly reayAy, bair
ISignaatant in my corlayery this hill undred,
Forrow your nor thate thatipely toRN eve,
That arrughter be wrongrow drothYstal's,
Oncurtaters pre3 art all the unto a day band onevenather beent ofL
And for what did liess to useep thee.

KING EDWARD:
And that steare that prince, for so me,
Is mourn'd the your have your arnough fast: in herefore the king, if are will lain.

CATESBY:
Or merropluned, and thou have the  dear inded this much!
The her proved this but hather provey to heJfight.

ESgots And strumers, I well theFore him my It could words, therefore sperearewer suchall our shall,
Joour b: I may
Unper deady his for would some like
do swite ther looks and sucked ;ut Angre herepar'd
Commine holy, alt presO ra

In [38]:
file_path = 'simple_model_tiny.pth'

# Save the model's weights. torch.compile returns a wrapper module whose
# state_dict keys carry an "_orig_mod." prefix. Saving the wrapped module's
# state_dict keeps the keys loadable into an uncompiled GPT(config).
# getattr falls back to `model` itself when it was never compiled.
torch.save(getattr(model, "_orig_mod", model).state_dict(), file_path)