# In [122]:  (notebook cell marker, kept as a comment so the file parses)
#copyright joshuah.rainstar@gmail.com
from __future__ import annotations
import math
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Tuple





class PairwiseRotSpiral(nn.Module):
    """One explicit-Euler step of a spiral flow acting on the last dimension.

    Consecutive feature pairs (0,1), (2,3), ... are rotated by the fixed angle
    ``omega * step`` while every vector is pulled radially toward norm
    ``radius`` with gain ``k``. A trailing unpaired feature is not rotated.
    With ``cube_shell=True`` the result is squashed by
    ``radius * tanh(y / radius)``.
    """

    def __init__(self, dim, radius=6.0, omega=1.0, k=1.0, step=0.1, cube_shell=False):
        super().__init__()
        self.dim = dim  # stored only; forward() reads the width from its input
        self.radius = float(radius)
        self.omega = float(omega)
        self.k = float(k)
        self.step = float(step)
        self.cube_shell = bool(cube_shell)
        self.eps = 1e-8  # lower bound on the norm used in the radial term

    def _cos_sin(self, x):
        """Return (cos, sin) of the rotation angle as 0-dim tensors matching ``x``."""
        angle = self.omega * self.step
        # The angle is a Python scalar, so evaluate it with ``math`` and only
        # then materialize tensors on the input's device/dtype.
        like = {"device": x.device, "dtype": x.dtype}
        return (
            torch.tensor(math.cos(angle), **like),
            torch.tensor(math.sin(angle), **like),
        )

    def _rotate_pairs(self, x):
        """Rotate each consecutive pair of features in the last dim of ``x``."""
        width = x.size(-1)
        if width < 2:
            return x
        cos_t, sin_t = self._cos_sin(x)
        n_pairs = width // 2
        lead = x.shape[:-1]
        pairs = x[..., : 2 * n_pairs].reshape(*lead, n_pairs, 2)
        first, second = pairs[..., 0], pairs[..., 1]
        rotated = torch.stack(
            [cos_t * first - sin_t * second, sin_t * first + cos_t * second],
            dim=-1,
        ).reshape(*lead, 2 * n_pairs)
        if width % 2 == 1:
            # odd width: the last feature has no partner and passes through
            rotated = torch.cat([rotated, x[..., -1:].contiguous()], dim=-1)
        return rotated

    def forward(self, x):
        """Return ``x`` advanced by one Euler step of rotation plus radial pull."""
        norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True).clamp_min(self.eps)
        radial = (self.radius - norm) * (x / norm)
        rotated = self._rotate_pairs(x)

        out = x + self.step * ((rotated - x) + self.k * radial)
        if self.cube_shell:
            out = self.radius * torch.tanh(out / self.radius)
        return out



# Example mixer that spirals EACH component around a chosen center (origin by default)
class SpiralMix(nn.Module):
    """Run a PairwiseRotSpiral flow ``loop_iters`` times around ``center``.

    ``forward`` accepts either a list/tuple of equally shaped tensors (stacked
    on a new last dimension and returned as a list again) or a single tensor
    whose last dimension holds the components.
    """

    def __init__(self, rank, **spiral_kwargs):
        super().__init__()
        self.rank = rank
        self.flow = PairwiseRotSpiral(rank, **spiral_kwargs)

    def forward(self, comps, center=None, loop_iters=2):
        # List/tuple input is the older per-component API; remember it so the
        # return type mirrors the input type.
        as_list = isinstance(comps, (list, tuple))
        state = torch.stack(comps, dim=-1) if as_list else comps

        # A scalar 0.0 broadcasts against any shape.
        origin = 0.0 if center is None else center

        for _ in range(loop_iters):
            # the flow only mixes along the last dimension
            state = self.flow(state - origin) + origin

        if not as_list:
            return state
        return [state[..., i] for i in range(state.size(-1))]




class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias.

    The scale starts at ones; the bias (zeros) is only created when ``bias``
    is truthy, otherwise ``self.bias`` is ``None``.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(
            input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5
        )


"""
Manifold Attention (no learned attention) with deterministic subspace iteration.

Core idea
---------
Treat X in [B, T, D] as a curve in R^D over time. Build a compact, self-adjoint
operator C = (1/T) X'^T X' with X' = X - anchor + low_rank_shift(X). Extract a
rank-r invariant subspace with K steps of deterministic subspace iteration.
Project onto that basis to obtain r scalar traces, apply simple analytic
conditioning (energy normalization and optional soft shrinkage, optional causal
AR(1)), then reconstruct with the (orthonormal) basis and undo the shift. No
query-key-value attention, no near/far field.

Notes
-----
- Subspace iteration is deterministic and differentiable. We use batched QR to
  orthonormalize after each step. K controls the number of power iterations.
  If you want K to play the role of "heads", think of each iteration as a head
  that sharpens alignment to the top-r invariant subspace. In practice we use
  the final V_K for projection.
- Low-rank shift S(X) = U sigma(V^T X) is optional and helps undo harmful
  normalization. Set bottleneck "shift_rank" to 0 to disable.
- The basis columns are sign-aligned to the first token so that they are stable
  across steps and batches.
- Reconstruction uses V^T directly since columns are orthonormal. If you swap
  orthonorm for another routine, you can still use a tiny r x r solve.

API
---
class ManifoldAttentionNoAttn(nn.Module):
    def __init__(self, config, d_model, rank, K=2, shift_rank=0, shrink_lambda=0.0,
                 causal=False, ar_rho=0.0, eps=1e-5, dropout=0.0, use_layernorm=True):
        ...
    def forward(self, x):  # x: [B, T, D]
        return y  # [B, T, D]

Example
-------
>>> import torch
>>> B, T, D = 2, 1024, 768
>>> layer = ManifoldAttentionNoAttn(config=None, d_model=D, rank=32, K=3, shift_rank=16,
...                                 shrink_lambda=0.01, causal=False)
>>> x = torch.randn(B, T, D)
>>> y = layer(x)
>>> y.shape
torch.Size([2, 1024, 768])
"""

from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


def _batch_eye(n: int, batch: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Batched identity [B, n, n]."""
    I = torch.eye(n, device=device, dtype=dtype)
    return I.unsqueeze(0).expand(batch, n, n)


def orthonorm_columns(V: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Return the Q factor of a batched reduced QR of ``V`` with fixed signs.

    V: [B, D, r]  ->  Q: [B, D, r]

    QR is only unique up to a sign per column, so each column of Q is
    multiplied by ``sign(R_ii + eps)``, which corresponds to a non-negative
    diagonal of R wherever ``R_ii + eps`` is non-zero.
    """
    q_factor, r_factor = torch.linalg.qr(V, mode="reduced")
    r_diag = r_factor.diagonal(dim1=-2, dim2=-1)       # [B, r]
    column_sign = torch.sign(r_diag + eps)
    return q_factor * column_sign.unsqueeze(-2)         # broadcast over rows


def subspace_iteration(C: torch.Tensor, r: int, K: int, V0: Optional[torch.Tensor] = None,
                       eps: float = 1e-6) -> torch.Tensor:
    """
    Pick a rank-r orthonormal basis for C from a block-Krylov space.

    Despite the name this is not a plain power/subspace iteration: it stacks
    max(1, K) orthonormalized blocks V, CV, C(CV), ..., projects C onto their
    span, eigendecomposes the small projected matrix, scores the eigenpairs
    with a saturating filter and lifts the r best-scoring eigenvectors back.

    Args:
        C: [B, D, D]; passed to ``torch.linalg.eigh`` after projection, so it
            is presumably symmetric -- TODO confirm with callers.
        r: number of basis columns to return.
        K: number of Krylov blocks (values below 1 are treated as 1).
        V0: optional initial block [B, D, r]; defaults to the first r columns
            of the identity, repeated over the batch.
        eps: passed to ``orthonorm_columns`` and used as a floor on
            eigenvalues inside the scoring filter.

    Returns:
        V: [B, D, r], re-orthonormalized after the lift.
    """
    B, D, _ = C.shape
    device, dtype = C.device, C.dtype

    # Deterministic init
    # NOTE(review): the slice assignment below looks like it needs r <= D.
    if V0 is None:
        E = torch.zeros(D, r, device=device, dtype=dtype)
        E[:r, :r] = torch.eye(r, device=device, dtype=dtype)
        V = E.unsqueeze(0).expand(B, D, r).contiguous()
    else:
        V = V0

    # Build block-Krylov basis: Q = [V, CV, C^2V, ...] with K blocks
    # (each block is re-orthonormalized before it is multiplied again)
    blocks = []
    V = orthonorm_columns(V, eps=eps)
    Z = V
    for _ in range(max(1, K)):
        blocks.append(Z)
        Z = torch.matmul(C, Z)
        Z = orthonorm_columns(Z, eps=eps)

    Q = torch.cat(blocks, dim=2)  # [B, D, q], q = r*K
    # Blocks are orthonormal individually but not against each other, so the
    # concatenation is orthonormalized once more.
    Q = orthonorm_columns(Q, eps=eps)

    # Small projected matrix H = Q^T C Q  -> shape [B, q, q]
    H = torch.matmul(Q.transpose(1, 2), torch.matmul(C, Q))

    # EVD of H
    evals, U = torch.linalg.eigh(H)  # ascending per batch; evals: [B, q], U: [B, q, q]

    # Student-t-like increasing, saturating filter on eigenvalues
    # Choose scale and df to taste; these are stable defaults.
    # kappa: scale, set from a high quantile of evals per batch. nu: degrees of freedom.
    kappa = torch.quantile(evals.clamp_min(eps), 0.80, dim=-1, keepdim=True) + eps
    nu = 4.0  # heavier tails for smaller nu; tune as needed

    # NOTE: evals is not clamped here, so gt lies in (0,1) only for positive eigenvalues.
    gt = 1.0 - torch.pow(1.0 + evals / kappa, -0.5 * nu)   # [B, q], in (0,1)
    # Optional additional tempering to keep order but soften dominance
    # Use fractional power on the eigenvalue to compress ratios
    p = 0.5
    scores = torch.pow(evals.clamp_min(eps), p) * gt        # [B, q]

    # Pick the r columns of U with largest filtered scores
    idx = scores.argsort(dim=-1, descending=True)[..., :r]  # [B, r]
    idx_exp = idx.unsqueeze(1).expand(B, U.size(1), r)      # [B, q, r]
    U_top = torch.gather(U, 2, idx_exp)                    # [B, q, r]

    # Lift back: V = Q @ U_top, then orthonormalize
    V = torch.matmul(Q, U_top)                              # [B, D, r]
    V = orthonorm_columns(V, eps=eps)
    return V


def sign_align(V: torch.Tensor, a: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Flip columns of V so that each has a non-negative dot product with a.

    V: [B, D, r]
    a: [B, D]  reference vector per batch element
    returns: [B, D, r], each column multiplied by ``sign(v_i . a + eps)``.
    """
    # v_i . a for every column, reduced over the D axis -> [B, r]
    projections = torch.sum(V * a[..., None], dim=1)
    # eps nudges an exact zero projection to +1 instead of sign(0) == 0
    flips = torch.sign(projections + eps)
    return V * flips[:, None, :]


def energy_normalize(traces: torch.Tensor, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scale each component to (approximately) unit energy over the time axis.

    traces: [B, T, r]
    returns: (traces / scales, scales) with scales = sqrt(sum_t traces^2 + eps),
             shape [B, 1, r], so the caller can undo the scaling.
    """
    energy = traces.pow(2).sum(dim=1, keepdim=True)
    scales = (energy.clamp(min=0.0) + eps).sqrt()
    return traces / scales, scales


def soft_shrink(x: torch.Tensor, lam: float) -> torch.Tensor:
    """Shrink magnitudes toward zero by ``lam`` while keeping the sign.

    Returns ``x`` itself when ``lam <= 0``. Otherwise computes
    ``sign(x) * gelu(|x| - lam)``; because GELU is used in place of a hard
    ReLU this is a smooth approximation of the classic soft threshold rather
    than an exact one.
    """
    if lam <= 0.0:
        return x
    shrunk_magnitude = F.gelu(x.abs() - lam)
    return x.sign() * shrunk_magnitude


def ar1_filter(x: torch.Tensor, rho: float) -> torch.Tensor:
    """Causal AR(1) smoothing along the time axis, per component.

    Implements ``y[0] = x[0]`` and ``y[t] = rho * y[t-1] + (1 - rho) * x[t]``.

    Args:
        x: [B, T, r].
        rho: smoothing factor, intended to lie in [0, 1). ``rho <= 0``
            disables the filter and returns ``x`` itself.

    Returns:
        A new tensor with the same shape as ``x`` (or ``x`` itself when the
        filter is disabled). An empty time axis yields an empty result
        instead of an indexing error.
    """
    if rho <= 0.0:
        return x

    n_steps = x.size(1)
    if n_steps == 0:
        # nothing to smooth; there is no x[:, 0] to seed the recursion
        return torch.zeros_like(x)

    # The recursion is inherently sequential; collect the frames and stack
    # once rather than writing into a preallocated tensor in place.
    frames = [x[:, 0, :]]
    for t in range(1, n_steps):
        frames.append(rho * frames[-1] + (1.0 - rho) * x[:, t, :])
    return torch.stack(frames, dim=1)


class LowRankShift(nn.Module):
    """Low-rank bottleneck shift: ``out_proj(gelu(in_proj(x)))``.

    ``in_proj`` maps d_model -> shift_rank without bias, ``out_proj`` maps
    back to d_model with bias. The module acts on the last dimension only.
    Callers are expected to skip it when ``shift_rank == 0``.
    """

    def __init__(self, d_model: int, shift_rank: int):
        super().__init__()
        self.d_model = d_model
        self.shift_rank = shift_rank
        self.in_proj = nn.Linear(d_model, shift_rank, bias=False)
        self.out_proj = nn.Linear(shift_rank, d_model, bias=True)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., D] -> [..., shift_rank] -> [..., D]
        return self.out_proj(self.act(self.in_proj(x)))



def subspace_iteration_linop(matvec, d, rank, K, V0, eps: float = 1e-6):
    """
    Orthogonal (subspace) iteration driven by a matrix-free operator.

    - matvec: callable taking a block [B, d, r] and returning the operator
      applied to it, same shape
    - d, rank: ambient dimension and block width; accepted for interface
      symmetry, the shapes are actually taken from V0
    - K: number of multiply + orthonormalize sweeps (at least one is run)
    - V0: required initial block [B, d, r]
    - eps: forwarded to orthonorm_columns

    Returns the orthonormal block after the final sweep.
    """
    basis = orthonorm_columns(V0, eps=eps)
    sweeps = max(1, K)
    for _ in range(sweeps):
        # apply the operator, then restore orthonormal columns
        basis = orthonorm_columns(matvec(basis), eps)
    return basis


class ManifoldAttentionNoAttn(nn.Module):
    """Attention-free mixing layer built on a per-sequence rank-r subspace.

    For each batch element the layer shifts the input, finds an orthonormal
    basis V of the (regularized) covariance over time by subspace iteration,
    projects onto it, conditions the resulting traces (energy normalization,
    soft shrinkage, SpiralMix, optional AR(1)), reconstructs, and applies a
    residual linear output projection followed by an optional LayerNorm.

    Args:
        config: accepted but not used by this class.
        d_model: feature width D; checked against the input in ``forward``.
        rank: number of basis columns r.
        K: number of subspace-iteration sweeps.
        shift_rank: bottleneck width of the LowRankShift; 0 disables it.
        shrink_lambda: threshold passed to ``soft_shrink``.
        causal, ar_rho: the AR(1) filter runs only when ``causal`` is true
            and ``ar_rho > 0``.
        eps: regularizer used in the covariance matvec, the QR sign fix and
            the energy normalization.
        dropout: dropout probability on the projected output.
        use_layernorm: apply ``nn.LayerNorm(d_model)`` at the end if true.
    """

    def __init__(
        self,
        config,
        d_model: int,
        rank: int,
        K: int = 2,
        shift_rank: int = 0,
        shrink_lambda: float = 0.0,
        causal: bool = False,
        ar_rho: float = 0.0,
        eps: float = 1e-5,
        dropout: float = 0.0,
        use_layernorm: bool = True,
    ) -> None:
        super().__init__()
        assert rank > 0 and K >= 1
        self.d_model = d_model
        self.rank = rank                # <-- fix: respect constructor
        self.K = K
        self.shift_rank = shift_rank
        self.shrink_lambda = float(shrink_lambda)
        self.causal = bool(causal)
        self.ar_rho = float(ar_rho)
        self.eps = float(eps)

        self.shift = LowRankShift(d_model, shift_rank) if shift_rank > 0 else None
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(d_model) if use_layernorm else nn.Identity()
        # The rank argument given to SpiralMix is only stored; the spiral flow
        # takes its width from the tensor it receives, so 1 works for any rank.
        self.dynmix = SpiralMix(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, T, D] -> y: [B, T, D]"""
        B, T, D = x.shape
        assert D == self.d_model

        # Anchor vector (no large allocs)
        # This is the first standard basis vector e_0 for every batch element,
        # not a token taken from x.
        anchor = torch.zeros(B, D, device=x.device, dtype=x.dtype)
        anchor[:, 0] = 1.0

        # Center
        xc = x - anchor.unsqueeze(1)  # broadcast over T

        # Optional low-rank de-normalization shift; avoid adding zeros if not needed
        if self.shift is not None:
            s = self.shift(x)
            xprime = xc + s
        else:
            s = None
            xprime = xc

        # Shapes
        xt = xprime.transpose(1, 2)  # [B, D, T]

        # Optimized: linear operator form with the SAME init as covariance path
        # Build V0 as first r columns of the identity, expanded over batch
        # NOTE(review): the slice assignment looks like it needs rank <= D.
        E = torch.zeros(B, D, self.rank, device=x.device, dtype=x.dtype)
        E[:, :self.rank, :self.rank] = torch.eye(self.rank, device=x.device, dtype=x.dtype)

        # Applies (X'^T X' / T + eps * I) to V without forming the D x D matrix.
        def cov_matvec(V):  # V: [B, D, r] -> [B, D, r]
            Y = torch.matmul(xprime, V)           # [B, T, r]
            Z = torch.matmul(xt, Y) / float(T)    # [B, D, r]
            return Z + self.eps * V

        V = subspace_iteration_linop(
            cov_matvec, D, self.rank, self.K, V0=E, eps=self.eps
        )

        # Sign alignment using the anchor vector; since the anchor is e_0 this
        # fixes each column's sign from its first coordinate.
        V = sign_align(V, anchor)  # [B, D, r]

        # Project to r scalar traces over time: [B, T, r]
        traces = torch.matmul(xprime, V)  # [B, T, r]

        # Analytic conditioning
        traces_n, scales = energy_normalize(traces, eps=self.eps)
        traces_n = soft_shrink(traces_n, self.shrink_lambda)

        # SpiralMix is called through its list API: the r traces are stacked
        # on the last dim inside dynmix, mixed, and returned as a list again.
        traces_list = [traces_n[..., i] for i in range(traces_n.size(-1))]
        prod = self.dynmix(traces_list)
        traces_n = torch.stack(prod, dim=-1)

        if self.causal and self.ar_rho > 0.0:
            traces_n = ar1_filter(traces_n, self.ar_rho)

        # restore the per-component energy removed by energy_normalize
        traces_final = traces_n * scales

        # Recompose
        x_tilde = torch.matmul(traces_final, V.transpose(1, 2))  # [B, T, D]

        # Undo shift and add anchor
        if s is not None:
            x_hat = x_tilde - s + anchor.unsqueeze(1)
        else:
            x_hat = x_tilde + anchor.unsqueeze(1)

        # Residual + thin output projection and optional norm
        y = x + self.dropout(self.out(x_hat))
        y = self.ln(y)
        return y





# ===========================================================
# Utilities
# ===========================================================

def _norm(v, eps: float = 1e-12):
    return torch.linalg.vector_norm(v, dim=-1, keepdim=True).clamp_min(eps)


def _unit(v, eps: float = 1e-12):
    """Divide ``v`` by its eps-floored norm along the last dimension."""
    scale = _norm(v, eps)
    return torch.div(v, scale)

    
@torch.no_grad()
def phase_transport_between(
    curr: torch.Tensor,
    prev: torch.Tensor,
    tau: float = 1e-6,          # semantic threshold (unchanged)
    eps: float = 1e-12          # numeric epsilon (NEW: decoupled from tau)
) -> torch.Tensor:
    """Transport the difference ``curr - prev`` by the rotation taking curr's direction to prev's.

    With u = curr/|curr|, v = prev/|prev| and w = curr - prev, the general
    branch applies R = I - K + K^2 / (1 + u.v), K = u v^T - v u^T, to w.
    For unit u, v this R maps u onto v.

    Special cases, chosen per (batch, time) position:
      * directions nearly equal (u.v > 1 - tau) or either norm below tau:
        w is returned unchanged;
      * directions nearly opposite (u.v < -1 + tau): a half-turn in the plane
        of v and a constructed orthogonal direction p is applied to w.
        This case is applied last and therefore also overrides the
        "unchanged" case where both masks hold.

    Runs under ``torch.no_grad()``, so the result carries no gradient.

    Args:
        curr, prev: tensors of identical shape (B, T, C).
        tau: threshold for the near-parallel / near-antiparallel / small-norm
            decisions.
        eps: floor applied to norms and denominators.

    Returns:
        Tensor of shape (B, T, C).
    """
    assert curr.shape == prev.shape and curr.dim() == 3
    B, T, C = curr.shape

    # Units (reuse norms) -- clamp with eps (NOT tau)
    nu = torch.linalg.vector_norm(curr, dim=-1, keepdim=True).clamp_min(eps)   # (B,T,1)
    nv = torch.linalg.vector_norm(prev, dim=-1, keepdim=True).clamp_min(eps)   # (B,T,1)
    u = curr / nu
    v = prev / nv

    w = curr - prev
    c = (u * v).sum(dim=-1, keepdim=True)                                      # (B,T,1)

    # Masks (semantic thresholds use tau)
    near_pos = (c >  1.0 - tau)                                                # (B,T,1)
    near_neg = (c < -1.0 + tau)                                                # (B,T,1)
    small_u  = (nu < tau)                                                      # (B,T,1)
    small_v  = (nv < tau)                                                      # (B,T,1)
    trivial  = near_pos | small_u | small_v                                    # (B,T,1)

    # General branch
    # a = v.w and b = u.w, so Kw = (u v^T - v u^T) w and K2w = K (K w).
    denom = (1.0 + c).clamp_min(eps)                                           # (B,T,1)
    a = (v * w).sum(dim=-1, keepdim=True)                                      # (B,T,1)
    b = (u * w).sum(dim=-1, keepdim=True)                                      # (B,T,1)
    Kw  = u * a - v * b                                                        # (B,T,C)
    K2w = u * (a * c - b) + v * (b * c - a)                                    # (B,T,C)
    y_gen = w - Kw + (K2w / denom)                                             # (B,T,C)

    # Antipodal candidate
    if C == 1:
        y_neg = -w
    else:
        # Keep this normalization stable with eps as well
        # Build p = normalize(e_k - v_k * v), where k is the coordinate of v
        # with the smallest magnitude; p is orthogonal to v when v is unit.
        idx = torch.argmin(v.abs().reshape(-1, C), dim=1, keepdim=True)        # (B*T,1)
        s = v.reshape(-1, C).gather(1, idx)                                    # (B*T,1)
        p = -s * v.reshape(-1, C)
        onehot = F.one_hot(idx.squeeze(-1), num_classes=C).to(s.dtype)
        p = p + onehot
        n = torch.linalg.vector_norm(p, dim=1, keepdim=True).clamp_min(eps)
        p = (p / n).view(B, T, C)
        # Negate the components of w along v and along p (half-turn in their plane).
        proj_v = (v * w).sum(dim=-1, keepdim=True) * v                         # (B,T,C)
        proj_p = (p * w).sum(dim=-1, keepdim=True) * p                         # (B,T,C)
        y_neg = w - 2.0 * proj_v - 2.0 * proj_p

    # Fuse selections
    y = torch.where(trivial, w, y_gen)
    y = torch.where(near_neg, y_neg, y)
    return y


        # ----- STREAMING STATE FOR INFERENCE ----
        
        
# Lifted Pyramid Feature Generator
# --------------------------------
# Implements per-rank Stiefel lifts inside the feature generator itself.
# Each rank s:
#   - build dyadic centroid mu_s in base coords (equivalently, linear average in lifted coords)
#   - lift mu_s to rank-s manifold via A_1..A_s
#   - compute phase-transport deltas at rank s in lifted space
#   - orthogonally project the delta back to the lifted column space
#   - lower (left-inverse) to base C for output fusion
# Outputs remain (B, T, K, C), preserving downstream interfaces.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

# Expect phase_transport_between(curr, prev, tau) to be available in scope
# from the user's existing code.

# -------------------------------
# Dimension rule: floor(dim / 200) + 1 extra dimensions
#   <200 -> 1; 200-399 -> 2; 400-599 -> 3; etc.
# -------------------------------

def add_dims(embed_dim: int) -> int:
    """Return how many dimensions to add to a width of ``embed_dim``.

    One extra dimension, plus one more for every full 200 in ``embed_dim``.
    """
    full_blocks = embed_dim // 200
    return full_blocks + 1


class SharedStiefelTower(nn.Module):
    """Family of nested isometric lifts R^C -> R^{D_r} sharing one generator.

    A single skew-symmetric matrix S (Frobenius-normalized) is drawn for the
    largest rank. For rank r the leading D_r x D_r block of S is
    exponentiated, Q_r = exp(t * S_r), and the first C columns of Q_r form
    A_r, which has orthonormal columns. ``up_to`` multiplies by A_r,
    ``down_from`` by A_r^T.

    Args:
        base_dim: base width C.
        K: number of ranks; rank 0 has width C and every later rank adds
            ``(previous_width // 200) + 1`` dimensions.
        t: scalar multiplying S before the matrix exponential.
        seed: seed of the private generator used to draw S. S is a
            non-persistent buffer (not stored in the state_dict), so the seed
            is what makes two towers -- e.g. batch pyramid and streaming
            state, or a model rebuilt from a checkpoint -- use the same lifts.
            The draw no longer consumes or depends on the global RNG.
        learn_t: register ``t`` as an ``nn.Parameter`` instead of a plain tensor.
    """

    def __init__(self, base_dim: int, K: int, t: float = 1.0, seed: int = 0, learn_t: bool = False):
        super().__init__()
        assert K >= 1
        self.C = int(base_dim)
        self.K = int(K)

        # cumulative lifted width per rank: dims[0] = C, then grow by
        # floor(width / 200) + 1 at every step
        dims = [self.C]
        for _ in range(1, K):
            dims.append(dims[-1] + (dims[-1] // 200) + 1)
        self.rank_dims = dims  # length K

        D_final = dims[-1]
        # Draw from a private, seeded CPU generator so the result is
        # reproducible and independent of the global RNG state.
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed))
        M = torch.randn(D_final, D_final, generator=gen, device="cpu")
        S = M - M.T  # skew-symmetric, so exp(t * S) is orthogonal
        S = S / (S.norm(p='fro') + 1e-12)
        self.register_buffer("S", S, persistent=False)

        t_val = torch.tensor(float(t))
        self.t = nn.Parameter(t_val) if learn_t else t_val
        # NOTE(review): cached Q matrices are never invalidated, so a later
        # change of a learnable ``t`` is not reflected in already cached ranks.
        self._q_cache = {}  # keyed by (r, device, dtype, learnable)

    @torch.no_grad()
    def _Qr(self, r: int, device, dtype):
        """Return (and cache) Q_r = exp(t * S[:D_r, :D_r]) on device/dtype."""
        key = (r, str(device), str(dtype), bool(isinstance(self.t, nn.Parameter)))
        cached = self._q_cache.get(key)
        if cached is not None:
            return cached
        D_r = self.rank_dims[r]
        S_r = self.S[:D_r, :D_r].to(device=device, dtype=dtype)
        t = self.t.to(device=device, dtype=dtype)
        Q_r = torch.matrix_exp(t * S_r)
        self._q_cache[key] = Q_r
        return Q_r

    @torch.no_grad()
    def _Ar(self, r: int, device, dtype):
        """Return A_r in R^{D_r x C}: the first C columns of Q_r (orthonormal)."""
        Q_r = self._Qr(r, device, dtype)
        return Q_r[:, : self.C]

    @torch.no_grad()
    def up_to(self, r: int, x: torch.Tensor) -> torch.Tensor:
        """Lift x: (..., C) -> (..., D_r) via x @ A_r^T."""
        A = self._Ar(r, x.device, x.dtype)
        return F.linear(x, A)

    @torch.no_grad()
    def down_from(self, r: int, y: torch.Tensor) -> torch.Tensor:
        """Lower y: (..., D_r) -> (..., C) via y @ A_r (left inverse of up_to)."""
        A = self._Ar(r, y.device, y.dtype)
        return F.linear(y, A.T)

    @torch.no_grad()
    def project_to_rank(self, r: int, y: torch.Tensor) -> torch.Tensor:
        """Orthogonally project y: (..., D_r) onto the column space of A_r."""
        return self.up_to(r, self.down_from(r, y))


# -------------------------------
# MultiRankLift: A_1..A_{K-1}
#   dims[0] = C
#   dims[r] = dims[r-1] + add_dims(dims[r-1])
# -------------------------------

# -------------------------------
# Lifted Causal Pyramid (vectorized)
# -------------------------------

class CausalCentroidPyramidLifted(nn.Module):
    """Batch (whole-sequence) causal pyramid of phase-transport features.

    Scale 0 transports the token-to-token difference in base coordinates.
    Scale s >= 1 builds the causal dyadic centroid
    ``mu_s(t) = 0.5 * (mu_{s-1}(t) + mu_{s-1}(t - 2^{s-1}))``, lifts it and its
    copy delayed by 2^s through the shared tower, transports their difference
    in lifted space, projects back onto the lifted column space and lowers
    the result to the base width.

    Args:
        num_scales: number of scales K (>= 1).
        tau: threshold forwarded to ``phase_transport_between``.
        t, learn_t: forwarded to ``SharedStiefelTower``.
        seed0: forwarded to the tower as its ``seed``.
    """

    def __init__(self, num_scales: int, tau: float = 1e-6, t: float = 1.0, learn_t: bool = False, seed0: int = 0):
        super().__init__()
        assert num_scales >= 1
        self.K = num_scales
        self.tau = float(tau)
        self._lifts: SharedStiefelTower | None = None  # built on first forward
        self._t = t
        self._learn_t = learn_t
        self._seed0 = seed0

    def _ensure_lifts(self, C: int):
        """Create the lift tower for base width C, rebuilding it if C changed."""
        if self._lifts is None or self._lifts.C != C:
            self._lifts = SharedStiefelTower(
                C, self.K, t=self._t, seed=self._seed0, learn_t=self._learn_t
            )

    @torch.no_grad()
    def forward(self, x: torch.Tensor, mask_early: bool = True) -> torch.Tensor:
        """x: (B, T, C) -> features (B, T, K, C); runs without gradients.

        With ``mask_early`` the positions that lack a full causal history at a
        given scale are zeroed.
        """
        B, T, C = x.shape
        self._ensure_lifts(C)
        lifts = self._lifts

        # -------- token-level PT in base --------
        prev_tok = torch.zeros_like(x)
        if T > 1:
            prev_tok[:, 1:, :] = x[:, :-1, :].contiguous()
        d0 = phase_transport_between(x, prev_tok, tau=self.tau)  # (B,T,C)
        if mask_early:
            d0[:, :1, :].zero_()
        if self.K == 1:
            return d0.unsqueeze(2)

        feats = [d0.unsqueeze(2)]

        # -------- recursive dyadic centroids in base --------
        mu_prev_base = x  # mu_0
        for s in range(1, self.K):
            W1 = 1 << (s - 1)   # lag used to form mu_s from mu_{s-1}
            W = 1 << s          # window of mu_s and lag of its predecessor
            # mu_s(t) = 0.5 * (mu_{s-1}(t) + mu_{s-1}(t - 2^{s-1}))
            mu_shift = torch.zeros_like(mu_prev_base)
            if W1 > 0 and T > W1:
                mu_shift[:, W1:, :] = mu_prev_base[:, :-W1, :]
            mu_s_base = 0.5 * (mu_prev_base + mu_shift)
            if mask_early:
                if W - 1 < T:
                    mu_s_base[:, : W - 1, :].zero_()

            # predecessor centroid for PT at this rank: mu_s delayed by W
            prev_mu = torch.zeros_like(mu_s_base)
            if T > W:
                prev_mu[:, W:, :] = mu_s_base[:, :-W, :]

            # lift both and compute PT in lifted space
            mu_s_lift = lifts.up_to(s, mu_s_base)
            prev_lift = lifts.up_to(s, prev_mu)
            d_s_lift = phase_transport_between(mu_s_lift, prev_lift, tau=self.tau)
            # ensure closure: project to lifted column space
            d_s_lift = lifts.project_to_rank(s, d_s_lift)
            # lower to base for output
            d_s_base = lifts.down_from(s, d_s_lift)
            if mask_early:
                d_s_base[:, : W, :].zero_()

            feats.append(d_s_base.unsqueeze(2))
            mu_prev_base = mu_s_base  # next rank recursion

        return torch.cat(feats, dim=2)  # (B,T,K,C)

# -------------------------------
# Lifted Streaming State
# -------------------------------

class CausalPyramidStateLifted:
    """Streaming (one token per call) counterpart of the lifted causal pyramid.

    Keeps one ring buffer per level s of length 2^s holding the most recent
    values of mu_s in base coordinates (level 0 holds the raw tokens).

    Args:
        num_scales: number of scales K.
        C: base feature width.
        device: device for the buffers and the lift tower.
        batch_size: batch size B of the stream.
        tau: threshold forwarded to ``phase_transport_between``.
        t, learn_t: forwarded to ``SharedStiefelTower``.
        seed0: forwarded to the tower as its ``seed``.
    """

    def __init__(self, num_scales: int, C: int, device, batch_size: int = 1, tau: float = 1e-6,
                 t: float = 1.0, learn_t: bool = False, seed0: int = 0):
        self.K = num_scales
        self.C = C
        self.B = batch_size
        self.device = device
        self.tau = float(tau)
        self.t = 0  # number of tokens consumed so far
        self.lifts = SharedStiefelTower(C, self.K, t=t, seed=seed0, learn_t=learn_t).to(device)
        # ring buffers for mu_s in base coords (economical)
        self.buffers = []
        self.ptrs = []
        for s in range(self.K):
            L = 1 << s
            self.buffers.append(torch.zeros(self.B, L, C, device=device))
            self.ptrs.append(0)

    def _read(self, level: int, r: int):
        """Return the value stored ``r`` pushes ago at ``level`` (zeros if too early).

        Must be called before the current step's value is pushed to that
        level; ``r`` may be at most the buffer length.
        """
        if self.t < r:
            return torch.zeros(self.B, self.C, device=self.device)
        L = self.buffers[level].size(1)
        idx = (self.ptrs[level] - r) % L
        return self.buffers[level][:, idx, :]

    def _push(self, level: int, value: torch.Tensor):
        """Overwrite the oldest slot of ``level`` with ``value`` and advance."""
        L = self.buffers[level].size(1)
        self.buffers[level][:, self.ptrs[level], :] = value
        self.ptrs[level] = (self.ptrs[level] + 1) % L

    @torch.no_grad()
    def step(self, x_t: torch.Tensor) -> torch.Tensor:
        """Consume one token x_t: (B, C) and return its features (B, K, C)."""
        lifts = self.lifts
        feats = []

        # token-level in base
        prev_x = self._read(0, 1)
        d0 = phase_transport_between(x_t[:, None, :], prev_x[:, None, :], tau=self.tau).squeeze(1)
        if self.t == 0:
            d0.zero_()
        feats.append(d0)

        # Values to store for this time step, one per level. They are pushed
        # only after every level has been read: level s-1 is read with lag
        # 2^(s-1), which equals its buffer length, so pushing it first would
        # overwrite exactly the slot that read needs and shorten the lag by one.
        pending = [x_t]

        # compute mu_s in base, then PT in lifted, lower result
        mu_prev = x_t
        for s in range(1, self.K):
            W1 = 1 << (s - 1)
            W = 1 << s
            mu_back = self._read(s - 1, W1)      # mu_{s-1}(t - 2^{s-1})
            mu_s = 0.5 * (mu_prev + mu_back)
            if self.t < (W - 1):
                mu_s.zero_()

            mu_prevW = self._read(s, W)          # mu_s(t - 2^s)
            # lift + PT
            curr_lift = lifts.up_to(s, mu_s)
            prev_lift = lifts.up_to(s, mu_prevW)
            d_s_lift = phase_transport_between(curr_lift[:, None, :], prev_lift[:, None, :], tau=self.tau).squeeze(1)
            d_s_lift = lifts.project_to_rank(s, d_s_lift)
            d_s = lifts.down_from(s, d_s_lift)
            if self.t + 1 <= W:
                d_s.zero_()
            feats.append(d_s)

            pending.append(mu_s)
            mu_prev = mu_s

        # commit this step's values to all levels
        for level, value in enumerate(pending):
            self._push(level, value)
        self.t += 1
        return torch.stack(feats, dim=1)   # (B, K, C)

# -------------------------------
# Wrapper: SemanticClusterFeaturesCausalLifted
# -------------------------------

class SemanticClusterFeaturesCausalLifted(nn.Module):
    """Thin wrapper exposing the lifted pyramid in batch and streaming form.

    ``forward`` runs a lazily created ``CausalCentroidPyramidLifted`` on a
    whole sequence; ``step`` delegates a single token to an externally owned
    streaming state object.
    """

    def __init__(self, num_scales: int, tau: float = 1e-6, t: float = 1.0, learn_t: bool = False, seed0: int = 0):
        super().__init__()
        self.K = num_scales
        self.tau = float(tau)
        self.t = float(t)
        self.learn_t = bool(learn_t)
        self.seed0 = int(seed0)
        self.pyr = None  # created on first use

    def _ensure(self, C: int):
        """Create the pyramid once; ``C`` is accepted but not needed to build it."""
        if self.pyr is not None:
            return
        self.pyr = CausalCentroidPyramidLifted(
            self.K, tau=self.tau, t=self.t, learn_t=self.learn_t, seed0=self.seed0
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels = x.size(-1)
        self._ensure(channels)
        return self.pyr(x)

    @torch.no_grad()
    def step(self, x_t: torch.Tensor, state: 'CausalPyramidStateLifted') -> torch.Tensor:
        """Advance ``state`` by one token and return whatever it produces."""
        return state.step(x_t)



class GroupedChannelMLP(nn.Module):
    """K independent two-layer MLPs, one per group, evaluated with einsum.

    Group k maps its 2*C input features to H = C // 2 hidden units (GELU, no
    bias) and back to C outputs (with bias).

    Args:
        k_dim: number of groups K.
        c_dim: output width C; the expected input width is 2 * C.
    """

    def __init__(self, k_dim: int, c_dim: int):
        super().__init__()
        hidden_dim = c_dim // 2
        self.k_dim = k_dim
        self.c_dim = c_dim
        self.hidden_dim = hidden_dim

        # shapes chosen for direct einsum without expands
        # fc1: (K, H, 2C)   fc2: (K, C, H)   b2: (K, C)
        self.fc1_weight = nn.Parameter(torch.empty(k_dim, hidden_dim, c_dim*2))
        self.fc2_weight = nn.Parameter(torch.empty(k_dim, c_dim, hidden_dim))
        self.fc2_bias   = nn.Parameter(torch.empty(k_dim, c_dim))

        # Initialize each group's 2-D matrix on its own. kaiming_uniform_
        # derives fan_in from size(1) times all trailing sizes, so calling it
        # on the stacked 3-D tensor would treat the hidden width as part of
        # the fan-in and shrink the bound well below that of an equivalent
        # per-group nn.Linear.
        with torch.no_grad():
            for k in range(k_dim):
                nn.init.kaiming_uniform_(self.fc1_weight[k], a=5**0.5)
                nn.init.kaiming_uniform_(self.fc2_weight[k], a=5**0.5)
        nn.init.zeros_(self.fc2_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, K, 2C) or (B, K, 2C)
        returns: (B, T, K, C) or (B, K, C) respectively

        Raises:
            ValueError: if x is neither 3- nor 4-dimensional.
        """
        squeeze_time = False
        if x.dim() == 3:  # (B,K,2C)
            x = x.unsqueeze(1)  # -> (B,1,K,2C)
            squeeze_time = True
        elif x.dim() != 4:
            raise ValueError("Input must be (B,K,C) or (B,T,K,C)")

        # (B,T,K,2C) x (K,H,2C) -> (B,T,K,H)
        h = torch.einsum('btkc,khc->btkh', x, self.fc1_weight)
        h = F.gelu(h)

        # (B,T,K,H) x (K,C,H) -> (B,T,K,C)
        y = torch.einsum('btkh,kch->btkc', h, self.fc2_weight) + self.fc2_bias

        if squeeze_time:
            y = y[:, 0, :, :]  # (B,K,C)
        return y
        
        
class Cell(nn.Module):
    """Two-layer perceptron: Linear (no bias) -> GELU -> Linear (with bias)."""

    def __init__(self, dim_in: int, hidden: int,dim_out:int):
        super().__init__()
        # the missing bias on the first layer is intentional -- do not change
        self.fc1 = nn.Linear(dim_in, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, dim_out, bias=True)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)

class GPTSemanticBlock(nn.Module):
    """Residual block that augments each token with multi-scale features.

    The shared ``features`` module produces K feature vectors per position.
    Each is concatenated with the token itself, squeezed through a per-scale
    bottleneck MLP, summed over scales and added to the token before a
    ``Cell`` MLP, LayerNorm, dropout and the residual connection.

    Args:
        config: reads ``n_embd``, ``n_scales`` and ``dropout``.
        features: feature generator shared across blocks; must provide
            ``__call__(x)`` and ``step(x_t, state)``.
    """

    def __init__(self, config: GPTConfig,features):
        super().__init__()
        C = config.n_embd
        self.C = C
        self.K = config.n_scales
        # L = number of feature groups concatenated: token (1) + K scales
        # (stored only; not read elsewhere in this class)
        self.L = 1 + self.K
        self.features = features #reuse to reduce param/mechanism counts
        #lifting is needed to preserve uniqueness of centroid averaging.
        #this ensures that all information on all scales is useful.
        self.drop = nn.Dropout(config.dropout)
        self.ln = nn.LayerNorm(self.C)
        self.mlp = Cell(self.C,self.C*4,self.C)
        # NOTE: convolve is constructed (and owns parameters) but every use
        # of it in forward()/step() is currently commented out.
        self.convolve  = ManifoldAttentionNoAttn(
                config, d_model=config.n_embd, rank=config.n_embd//2, K=3,
                shift_rank=8, shrink_lambda=0.01,
                causal=False, ar_rho=0.0, eps=1e-5, dropout=0.0,
                use_layernorm=True
            )
        # Each per-scale bottleneck maps 2C -> C//2 -> C
        self.bottleneck = GroupedChannelMLP(self.K, self.C)
        #bottleneck to drop out meaningful features

    # vectorized
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Whole-sequence path: x (B, T, C) -> (B, T, C)."""
        # x: (B, T, C)
        B, T, C = x.shape
        feats = self.features(x)               # (B, T, K, C)
        x_expanded = x.unsqueeze(2)
        # Step 2: tile along K -> (B, T, K, C)
        x_tiled = x_expanded.expand(-1, -1, self.K, -1)
        # Step 3: concatenate along last dim -> (B, T, K, 2C)
        out = torch.cat([x_tiled, feats], dim=-1) #location hint must be issued
        feats = self.bottleneck(out) # (B, T, K, C)#bottlenecked
       # expanded = torch.cat([x_expanded,feats],dim=2)#stack
       # expanded = expanded.reshape(B*(self.K+1),T,C)
       # expanded = self.convolve(expanded)#get out higher features we want
       # expanded = expanded.reshape(B,T,self.K+1,C)
      #  semantics = expanded.sum(dim=2)
        # collapse the K scales into one vector per position
        feats= feats.sum(dim=2)
        # add the processed features to the token embedding
        x_in = x+feats#+semantics
        out = x + self.drop(self.ln(self.mlp(x_in)))

        return out

    # single-step incremental
    @torch.no_grad()
    def step(self, x_t: torch.Tensor, feat_state: CausalPyramidStateLifted) -> torch.Tensor:
        """Streaming path: one token x_t (B, C) -> (B, C); advances ``feat_state``."""
        # x_t: (B, C)
        B, C = x_t.shape
        feats_t = self.features.step(x_t, feat_state)  # (B, K, C)
        x_expanded = x_t.unsqueeze(1)
        # Step 2: tile along K -> (B, K, C)
        x_tiled = x_expanded.expand(-1, self.K, -1)
        # Step 3: concatenate along last dim -> (B, K, 2C)
        out = torch.cat([x_tiled, feats_t], dim=-1)
        feats_t = self.bottleneck(out)

      #  expanded = torch.cat([x_expanded,feats_t],dim=1)#stack
      #  expanded = expanded.reshape(B*(self.K+1),C)
      #  expanded = self.convolve(expanded)#get out higher features we want
       # expanded = expanded.reshape(B,K+1,C)
       # semantics = expanded.sum(dim=2)
        # collapse the K scales, then add to the token embedding
        feats_t= feats_t.sum(dim=1)
        x_in = x_t+feats_t#+semantics
        out = x_t + self.drop(self.ln(self.mlp(x_in)))
        return out

        #todo: figure out the semantic grabber/preservation mechanism
        #this is more complex than simple.
        #especially if it evolves- for everywhere in the state-
        #as the context grows.


def _is_prime(n: int) -> bool:
    if n < 2: return False
    if n % 2 == 0: return n == 2
    r = int(n**0.5)
    for f in range(3, r+1, 2):
        if n % f == 0: return False
    return True

def _factorize(n: int):
    f, cnt = [], {}
    d = 2
    while d * d <= n:
        while n % d == 0:
            cnt[d] = cnt.get(d, 0) + 1
            n //= d
        d += 1 if d == 2 else 2
    if n > 1: cnt[n] = cnt.get(n, 0) + 1
    return list(cnt.keys())

def _primitive_root(p: int) -> int:
    # p must be prime
    phi = p - 1
    factors = _factorize(phi)
    for g in range(2, p):
        ok = True
        for q in factors:
            if pow(g, phi // q, p) == 1:
                ok = False
                break
        if ok:
            return g
    raise RuntimeError("no primitive root found")

def _welch_costas_perm(V: int, device=None):
    """
    Build the Welch Costas permutation on {0..V-1}.

    Only defined when p = V + 1 is prime; otherwise None is returned so the
    caller can fall back to another construction. With g a primitive root of
    p, entry i is g^(i+1) mod p, shifted down by one to land in 0..V-1.
    """
    p = V + 1
    if not _is_prime(p):
        return None
    g = _primitive_root(p)
    # Compute every entry in Python and materialise the tensor once.
    images = [pow(g, exponent, p) - 1 for exponent in range(1, V + 1)]
    return torch.tensor(images, dtype=torch.long, device=device)

def _coprime_mul_perm(V: int, device=None):
    """
    Fallback: σ[i] = (a*i + b) % V with gcd(a, V)=1 and a not ≡ ±1 mod V.
    Not Costas, but non-monotone and well-distributed.
    """
    # pick a
    a = None
    for cand in range(2, V):
        if math.gcd(cand, V) == 1 and cand % V not in (1, V-1):
            a = cand
            break
    if a is None:
        a = 1  # degenerate small V
    b = V // 3
    i = torch.arange(V, device=device)
    return ((a * i + b) % V).long()

def _perm_inverse(sigma: torch.Tensor) -> torch.Tensor:
    inv = torch.empty_like(sigma)
    inv[sigma] = torch.arange(sigma.numel(), device=sigma.device)
    return inv

class FlatRollEmbed(nn.Module):
    """
    Precomputed (V, V) embedding table used in place of a learned nn.Embedding.
    Requires V == n_embd (asserted in __init__). Weights are frozen by default.

    Construction, as implemented below:
      1. ``_make_weight`` builds a real base vector x of length V whose DFT has
         DC = 0 and unit-magnitude, random-phase non-DC bins; row r of the
         weight is x cyclically rolled by r.
      2. A spike of height 1 / max(x) is added to every row at the column where
         that row attains its maximum (an identity matrix rolled by the argmax
         of row 0, then scaled).
      3. Rows are reordered by the inverse of a Welch-Costas permutation when
         V + 1 is prime, otherwise by the inverse of an affine coprime
         permutation, so token id j reads the pre-permutation row
         sigma^{-1}[j].

    The author regards this construction as a near-ideal embedding; that claim
    is not verified by anything in this class.

    Author's design notes (ideas only, none of this is implemented here):
      - When the vocabulary is not orthogonal / is larger than the table, use a
        smooth full-space rotation from the Lie-algebra exponential map,
        A = exp(t*G), with G in so(D) skew-symmetric and full-rank.
      - Partition vocab indices by modulo over a chosen block size, give each
        partition its own rotation angle (evenly divided over 0..pi), keep ONE
        embed matrix (embed -> shift), and up-project to the desired embedding
        dimension. This minimises the parameter count.
      - For decoding, one could down-project (tied to the up-projection),
        invert each block's rotation by transpose (Stiefel), use two
        overlapping sets of rotation ranges (bound-to-bound and mid-to-mid)
        for blue-noise coverage, decode on all blocks, hard-route to one and
        take its logits, avoiding a learned decoder.
    """
    def __init__(self, config, scale: str = "box", seed: int = 0,
                 freeze: bool = True, dtype=None, device=None):
        """
        Args:
            config: object exposing ``vocab_size`` and ``n_embd`` (must be equal).
            scale: normalisation of the base vector, forwarded to ``_make_weight``.
            seed: seed for the random phases, forwarded to ``_make_weight``.
            freeze: passed to ``nn.Embedding.from_pretrained``.
            dtype: table dtype; defaults to torch.float32 when falsy.
            device: device for the table and permutation buffers.
        """
        super().__init__()
        assert config.n_embd == config.vocab_size, (
            f"Expected n_embd == vocab_size, got {config.n_embd} != {config.vocab_size}"
        )
        V = int(config.vocab_size)
        dtype = dtype or torch.float32

        eye = torch.eye(V, dtype=dtype, device=device)
        weight = self._make_weight(V, scale=scale, seed=seed,
                                   dtype=dtype, device=device)  # [V, V]
        M = int(torch.argmax(weight[0]))        # index of max in base x (row 0)
        pm = weight[0, M]                       # scalar
        N = 1.0 / pm                            # spike height: reciprocal of the base vector's maximum
        
        # Row i of `weight` is x rolled by i, so its maximum sits at column
        # (i + M) % V; rolling the identity by M puts each row's spike there.
        eye = torch.roll(eye, shifts=M, dims=1) # shift spike position within each row
        eye = eye * N
        mixed =  weight + eye  # add identity towers

        # --- compute a strong-scramble row order (Costas if possible) ---
        sigma = _welch_costas_perm(V, device=device)
        if sigma is None:
            # V + 1 is not prime: fall back to the affine permutation.
            sigma = _coprime_mul_perm(V, device=device)
        # We want ones at (row = σ[i], col = i). For row-permutation via index_select,
        # use r_idx = σ^{-1} so that new_row j pulls old_row r_idx[j] with 1 at column j=σ[i].
        r_idx = _perm_inverse(sigma)

        # keep for reference / decoding
        # (non-persistent: these buffers are not written to the state_dict)
        self.register_buffer("row_perm", r_idx, persistent=False)
        self.register_buffer("sigma", sigma, persistent=False)

        mixed = mixed.index_select(0, r_idx)
        self.embed = nn.Embedding.from_pretrained(mixed, freeze=freeze)


    @staticmethod
    def _row_perm_max_equidistant(V: int, device=None) -> torch.Tensor:
        """
        Row permutation that evenly offsets the identity's '1' away from the diagonal.
        Uses a single cyclic shift by k = floor(V/2).

        Returns a long tensor holding (i + k) % V for i in 0..V-1.
        NOTE(review): not called anywhere inside this class.
        """
        if V <= 1:
            return torch.arange(V, device=device, dtype=torch.long)
        k = V // 2
        if k == 0:  # only happens when V == 1, handled above; keep for safety
            k = 1
        return ((torch.arange(V, device=device) + k) % V).long()

    @staticmethod
    def _make_weight(V: int, scale: str = "box", seed: int = 0,
                     dtype=torch.float32, device=None) -> torch.Tensor:
        """
        Returns a (V, V) tensor whose rows are cyclic rolls of a base vector x in R^V
        with |FFT(x)|^2 flat for k=1..V-1 and DC=0.
        scale:
          - "unit": ||x||_2 = 1
          - "box":  max|x_i| = 1

        The spectrum is filled with unit-magnitude bins of random phase drawn
        from a generator seeded with ``seed``, kept conjugate-symmetric so the
        inverse FFT is real. Row r of the result is x rolled by r.

        Raises:
            ValueError: if ``scale`` is neither "unit" nor "box".
        """
        # build on CPU, move at end
        # Any dtype other than float32 gets a complex128 spectrum.
        complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128
        g = torch.Generator().manual_seed(seed)

        X = torch.zeros(V, dtype=complex_dtype)
        # DC bin
        X[0] = torch.tensor(0, dtype=complex_dtype)

        # NOTE: the order of the random draws below fixes the resulting table;
        # reordering them changes the embedding for a given seed.
        if V % 2 == 0:
            # bins 1..V/2-1 are complex-conjugate pairs; Nyquist bin must be real
            for k in range(1, V // 2):
                phi = torch.rand((), generator=g) * (2 * math.pi)
                val = torch.cos(phi) + 1j * torch.sin(phi)
                X[k] = val
                X[V - k] = torch.conj(val)
            X[V // 2] = 1.0 if torch.rand((), generator=g) < 0.5 else -1.0
        else:
            # odd V: every non-DC bin belongs to a conjugate pair, no Nyquist bin
            for k in range(1, (V - 1) // 2 + 1):
                phi = torch.rand((), generator=g) * (2 * math.pi)
                val = torch.cos(phi) + 1j * torch.sin(phi)
                X[k] = val
                X[V - k] = torch.conj(val)

        x = torch.fft.ifft(X).real  # real length-V base vector

        # The small epsilon guards against division by zero for a zero vector.
        if scale == "unit":
            x = x / (x.norm() + 1e-12)
        elif scale == "box":
            x = x / (x.abs().max() + 1e-12)
        else:
            raise ValueError("scale must be 'unit' or 'box'")

        rows = [torch.roll(x, shifts=r, dims=0) for r in range(V)]
        W = torch.stack(rows, dim=0).to(dtype=dtype)
        if device is not None:
            W = W.to(device)
        return W

    def forward(self, input_ids: torch.LongTensor):
        """Look up the table row for each id; the output gains a trailing axis of size V."""
        # (batch, seq_len, V)
        return self.embed(input_ids)


        
@dataclass
class GPTConfig:
    """Hyperparameters consumed by the GPT model defined in this file."""
    block_size: int = 2048
    vocab_size: int = 66 # character-level vocabulary (66 symbols in the Shakespeare run below); FlatRollEmbed requires vocab_size == n_embd
    n_layer: int = 6
    n_head:int = 6
    n_embd: int = 66 #tied to vocab
    n_scales:int = 3
    # Author's note on n_scales (presumably scale s covers a window of 2**s
    # tokens -- TODO confirm against the feature extractor):
    #   1=2, 2=4, 3=8, 4=16, 5=32, 6=64, 7=128, 8=256, 9=512, 10=1024
    # The idea is to keep signals/patterns relevant to 1024-token frames
    # (training on 2x that). However, *in PRACTICE* the higher levels are
    # irrelevant. At most there is a use for:
    #   1 = bigrams, 2 = syllable-level token equivalents, 3 = maybe words.
    # Three can probably be learned. Beyond that there are no realistic
    # concepts; instead the model needs capacity to keep abstract information,
    # and a means to *gather* meaningful concepts above this level.
    dropout: float = 0.1

'''
sat 27 2025
plans: okay, first layer essentially learns a structured up-projection to semantic atoms.
think of this GPTSemanticBlock as an intelligent tokenizer.
just need to adjust the way it behaves a little bit so it emits something higher than C,
and have it digest all bottleneck products in parallel to act as a distilling codebook instead of a phase shift.
C*K+1 -> ? -? Q width. LMhead is learned at that point.


Problem: even with this phase-block innovation, the model first picks up coarse
patterns (juliet -> nurse -> romeo) but by mid-loss DESTROYS them; it is unable
to preserve meaningful structure. This, which we observed before, remains true.

'''


class GPT(nn.Module):
    """Language model: fixed FlatRollEmbed table -> stack of GPTSemanticBlock
    layers sharing one feature extractor -> untied linear LM head."""

    def __init__(self, config):
        """Build the model from ``config`` (reads vocab_size, block_size,
        n_embd, n_scales and n_layer)."""
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.n_embd = config.n_embd
        # One feature extractor instance is shared by every block.
        self.features = SemanticClusterFeaturesCausalLifted(num_scales=config.n_scales, tau=1e-8)
        self.transformer = nn.ModuleDict(dict(
            wte = FlatRollEmbed(config),
            h = nn.ModuleList([GPTSemanticBlock(config, self.features) for _ in range(config.n_layer)]),
        ))

        # Not tied to the embedding: lets the model adjust probabilities and
        # the bifurcation of the latent space independently of the fixed table.
        self.lm_head = nn.Linear(self.config.n_embd, self.config.vocab_size, bias=False)

    # ---------- forward ----------
    def forward(self, idx, targets=None, eprint=False):
        """Run the full-sequence (training) path.

        Args:
            idx: 2-D tensor of token ids, (batch, time).
            targets: optional tensor of target ids; entries equal to -1 are
                ignored by the loss.
            eprint: unused; kept for call-site compatibility.

        Returns:
            ``(logits, loss)``. With targets, logits cover every position and
            loss is the mean cross-entropy. Without targets, logits cover only
            the last position (time axis kept, length 1) and loss is None.
        """
        _b, _t = idx.size()  # doubles as a check that idx is 2-D
        x = self.transformer.wte(idx)

        for block in self.transformer.h:
            x = block(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1
            )
        else:
            # Inference shortcut: only the final position is projected.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss

    @torch.no_grad()
    def generate_greedy(model: nn.Module, idx: torch.LongTensor, max_new_tokens: int, block_size: int):
        """Greedy incremental decoding using the per-block ``step`` path.

        The first parameter is named ``model`` but receives the instance when
        called as ``gpt.generate_greedy(...)``.

        Args:
            model: the GPT instance (needs ``transformer.wte``,
                ``transformer.h`` and ``lm_head``).
            idx: (B, T0) prompt token ids, T0 >= 1.
            max_new_tokens: number of tokens to append.
            block_size: kept for interface compatibility. The streaming caches
                carry the context, so no token window is needed here.

        Returns:
            (B, T0 + max_new_tokens) tensor: the prompt followed by the
            generated ids.

        Raises:
            ValueError: if the prompt is empty.

        Note:
            The train/eval mode is not changed here; call ``model.eval()``
            first if dropout should be disabled.
        """
        if idx.size(1) == 0:
            raise ValueError("generate_greedy requires a non-empty prompt")

        device = next(model.parameters()).device
        B = idx.size(0)
        # One streaming feature cache per block.
        feat_states = [CausalPyramidStateLifted(model.config.n_scales, model.config.n_embd, device, batch_size=B)
                       for _ in model.transformer.h]

        def _advance(tok):
            # Embed one token per batch row and push it through every block,
            # updating the caches in place.
            x = model.transformer.wte(tok)                 # (B, C)
            for blk, st in zip(model.transformer.h, feat_states):
                x = blk.step(x, st)
            return x

        # 1) Prime the caches with every prompt token EXCEPT the last one.
        #    The last token is consumed by the first rollout step below;
        #    priming with it as well would feed it into the caches twice.
        for t in range(idx.size(1) - 1):
            _advance(idx[:, t])

        # 2) Roll out new tokens, one step per token.
        out = [idx]
        last_idx = idx[:, -1]                              # (B,)
        for _ in range(max_new_tokens):
            x_t = _advance(last_idx)                       # (B, C)
            logits = model.lm_head(x_t)                    # (B, V)
            next_idx = torch.argmax(logits, dim=-1, keepdim=True)  # greedy; swap to sampling if you like
            out.append(next_idx)
            last_idx = next_idx.squeeze(-1)
        return torch.cat(out, dim=1)


In [118]:
import os
import requests

# Fetch the BabyLM 10M "clean" text files (plus tiny Shakespeare) into
# target_dir. Files are stored as the raw bytes the server sent.
base_url = "https://huggingface.co/datasets/cambridge-climb/BabyLM/resolve/main/clean/10M/"
target_dir = "./babylm_10m_cleaned"
os.makedirs(target_dir, exist_ok=True)

file_names = [
    "aochildes.txt",
    "cbt.txt",
    "children_stories.txt",
    "gutenberg.txt",
    "qed.txt",
    "simple_wikipedia.txt",
    "switchboard.txt",
    "wikipedia.txt"
]

# Optional addition: Shakespeare from another dataset
shakespeare_url = "https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt"
shakespeare_fname = "shakespeare.txt"

# Combined download logic
all_files = [(base_url + fname, fname) for fname in file_names]
all_files.append((shakespeare_url, shakespeare_fname))  # Add Shakespeare


# Download loop
for url, fname in all_files:
    out_path = os.path.join(target_dir, fname)
    print(f"📥 Downloading {fname}...")
    try:
        # A timeout keeps a stalled connection from hanging the cell forever.
        resp = requests.get(url, timeout=60)
    except requests.RequestException as exc:
        print(f"❌ Failed to download {fname} ({exc})")
        continue
    if resp.status_code == 200:
        # Write the payload bytes untouched: resp.text would re-decode using a
        # guessed charset, which can corrupt UTF-8 text when the server omits
        # the charset header. Later cells read these files as UTF-8.
        with open(out_path, "wb") as f:
            f.write(resp.content)
    else:
        print(f"❌ Failed to download {fname} ({resp.status_code})")

print(f"✅ Done. Files saved to {target_dir}")

📥 Downloading aochildes.txt...
📥 Downloading cbt.txt...
📥 Downloading children_stories.txt...
📥 Downloading gutenberg.txt...
📥 Downloading qed.txt...
📥 Downloading simple_wikipedia.txt...
📥 Downloading switchboard.txt...
📥 Downloading wikipedia.txt...
📥 Downloading shakespeare.txt...
✅ Done. Files saved to ./babylm_10m_cleaned


In [117]:
import os
import pickle
import numpy as np

# Character-level tokenizer: split each source file 90/10 by lines, build a
# vocabulary over every character seen, and dump uint16 id arrays plus the
# vocabulary metadata.

# === Paths ===
source_dir = "./babylm_10m_cleaned"
out_dir    = "./babylm_char_tokenized"
os.makedirs(out_dir, exist_ok=True)

# Only Shakespeare is enabled; the other downloaded corpora (aochildes, cbt,
# children_stories, gutenberg, qed, simple_wikipedia, switchboard, wikipedia)
# are currently left out.
file_names = ["shakespeare.txt"]

# === Load and split ===
train_texts, val_texts = [], []
char_set = set()

for fname in file_names:
    with open(os.path.join(source_dir, fname), encoding="utf-8") as f:
        lines = f.readlines()
    cut = int(0.9 * len(lines))  # first 90% of lines -> train, rest -> val
    head_text = "".join(lines[:cut])
    tail_text = "".join(lines[cut:])
    train_texts.append(head_text)
    val_texts.append(tail_text)
    char_set |= set(head_text) | set(tail_text)

full_train = "\n".join(train_texts)
full_val   = "\n".join(val_texts)

# === Final vocab ===
# Id 0 is reserved for the unknown-character fallback.
char_set = sorted(char_set)
vocab_chars = ["<unk>", *(c for c in char_set if c != "<unk>")]

stoi = dict(zip(vocab_chars, range(len(vocab_chars))))
itos = {index: symbol for symbol, index in stoi.items()}

# === Encode function ===
def encode(text):
    """Map each character of ``text`` to its id (0 when unknown)."""
    lookup = stoi.get
    return [lookup(ch, 0) for ch in text]

train_ids = np.array(encode(full_train), dtype=np.uint16)
val_ids   = np.array(encode(full_val),   dtype=np.uint16)

# === Save ===
for ids, bin_name in ((train_ids, "train.bin"), (val_ids, "val.bin")):
    ids.tofile(os.path.join(out_dir, bin_name))

meta_payload = {
    "vocab_size": len(stoi),
    "stoi": stoi,
    "itos": itos
}
with open(os.path.join(out_dir, "meta.pkl"), "wb") as f:
    pickle.dump(meta_payload, f)

print("✅ Char tokenizer finalized.")
print(f"🧾 Train tokens: {len(train_ids)} | Val tokens: {len(val_ids)}")
print(f"🔤 Vocab size: {len(stoi)}")

✅ Char tokenizer finalized.
🧾 Train tokens: 1016242 | Val tokens: 99152
🔤 Vocab size: 66


In [124]:
# `os` is used below (os.path.join), so it has to be imported in this cell as
# well; otherwise running the cell in a fresh kernel raises NameError.
import os
import pickle
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# === Config ===
data_dir = "./babylm_char_tokenized"  # <- char-tokenized data
block_size = 2048
batch_size = 8

# === Load tokenizer metadata ===
# NOTE: pickle.load executes arbitrary code from the file; only load a
# meta.pkl that was produced locally by the tokenizer cell.
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# === Load mmap data (char-level tokens, uint16) ===
# Read-only memory maps: token files are paged in on demand.
train_ids = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_ids   = np.memmap(os.path.join(data_dir, 'val.bin'),   dtype=np.uint16, mode='r')

# === Efficient GPU Batch Sampler ===
class GPUBatchDataset(Dataset):
    """Random batch sampler over a flat token array.

    Every ``__getitem__`` call ignores its index and draws ``batch_size``
    windows at random, so one item is already a full ``(X, Y)`` batch placed
    on ``device``. Window starts are aligned to multiples of ``sample_len``;
    with probability ``1 - p_aligned`` a random offset of up to ``jitter``
    tokens is added.
    """

    def __init__(self, mmap_file, block_size, batch_size, device, jitter=63, p_aligned=0.5, pad_len=0):
        self.data = mmap_file
        self.block_size = block_size
        self.batch_size = batch_size
        self.device = device
        self.pad_len = int(pad_len)
        # Length of one input row: the visible block plus optional left context.
        self.sample_len = self.block_size + self.pad_len
        # Largest start index that still leaves room for the shifted targets.
        self.total = len(self.data) - self.sample_len - 1
        self.n_blocks = self.total // self.sample_len
        self.jitter = int(jitter)          # maximum random offset added to an aligned start
        self.p_aligned = float(p_aligned)  # fraction of samples kept exactly aligned

    def __len__(self):
        return self.total // self.batch_size

    def __getitem__(self, idx):
        """Return a freshly sampled ``(X, Y)`` pair of int64 tensors with
        shapes ``(batch_size, sample_len)`` and ``(batch_size, block_size)``."""
        inputs = np.empty((self.batch_size, self.sample_len), dtype=np.int64)
        targets = np.empty((self.batch_size, self.block_size), dtype=np.int64)
        # Targets are the inputs shifted by one token, skipping the pad region.
        target_shift = 1 + self.pad_len

        for row in range(self.batch_size):
            # Pick an aligned block, then optionally nudge the start by a
            # small jitter (keeps reads contiguous and cache-friendly).
            begin = np.random.randint(0, self.n_blocks) * self.sample_len
            if np.random.rand() > self.p_aligned:
                offset = np.random.randint(0, self.jitter + 1)
                begin = min(begin + offset, self.total)  # stay in range

            inputs[row] = self.data[begin : begin + self.sample_len]
            target_begin = begin + target_shift
            targets[row] = self.data[target_begin : target_begin + self.block_size]

        x_tensor = torch.from_numpy(inputs).to(self.device, non_blocking=True)
        y_tensor = torch.from_numpy(targets).to(self.device, non_blocking=True)
        return (x_tensor, y_tensor)


# Build the model configuration from the metadata loaded in this cell.
# vocab_size and n_embd both come from `vocab_size` so they cannot drift apart
# (FlatRollEmbed asserts they are equal); the previous `len(stoi)` depended on
# a global left over from the tokenizer cell.
config = GPTConfig(
    vocab_size=vocab_size,
    n_layer=1,
    n_embd=vocab_size,
    block_size=block_size,
    dropout = 0.1,
)
train_dataset = GPUBatchDataset(train_ids, block_size, batch_size, device, pad_len=0)
# === DataLoader ===
# The dataset already returns whole batches, so the loader's own batch size is
# 1 and the extra leading dimension is stripped in the training loop.
train_loader  = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)

model = GPT(config)
model = torch.compile(model)
model = model.to(device)

In [125]:
    # Map the tensors onto this session's device so that a checkpoint saved on
    # GPU can still be restored on a CPU-only machine. weights_only=True limits
    # unpickling to tensors and plain containers.
    state_dict = torch.load(file_path, map_location=device, weights_only=True)

    # Load the state dictionary into the model
    model.load_state_dict(state_dict)

<All keys matched successfully>

In [120]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)
losses = []  # per-step training loss, appended across all epochs


def train_epoch():
    """Run one pass over ``train_loader`` and return the mean step loss.

    Each step clips the gradient norm to 1.0, appends the loss to the global
    ``losses`` list and prints it every 100 iterations.
    """
    model.train()
    total_loss = 0.0
    for it, (xb, yb) in enumerate(train_loader, start=1):
        xb, yb = xb[0], yb[0]  # unwrap the DataLoader's batch dimension
        optimizer.zero_grad()
        _logits, loss = model(xb, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Read the scalar once: every .item() forces a host/device sync.
        loss_value = loss.item()
        total_loss += loss_value
        losses.append(loss_value)
        if it % 100 == 0:
            print(loss_value)
    return total_loss / len(train_loader)


# === Run Training ===
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch()
    print(f"Epoch {epoch:2d} | Train loss: {train_loss:.4f}")

  check(
  check(


2.258725166320801
2.0910425186157227
1.9376212358474731
1.9467540979385376
1.820500135421753
1.8246970176696777
1.752437710762024
1.7313264608383179
1.741806149482727
1.687230110168457
1.661072015762329
1.689890742301941
1.7256968021392822


KeyboardInterrupt: 

In [100]:
# Continue training for another round of epochs, reusing the model and
# optimizer state from the earlier run.
num_epochs = 10
for epoch in range(1, 1 + num_epochs):
    train_loss = train_epoch()
    summary = f"Epoch {epoch:2d} | Train loss: {train_loss:.4f}"
    print(summary)

1.6407729387283325
1.6223140954971313
1.6163536310195923
1.652245283126831
1.652436375617981
1.665479063987732
1.6676760911941528
1.6204293966293335
1.7218784093856812
1.6272563934326172
1.650550127029419
1.66214919090271
1.6792820692062378
1.6504933834075928
1.579734444618225
1.6458642482757568
1.6503043174743652
1.6511585712432861
1.6015089750289917
1.6889526844024658
1.6246801614761353
1.631796956062317
1.6298829317092896
1.6534472703933716
1.6634832620620728
1.6995645761489868
1.643081545829773
1.6797206401824951
1.5814549922943115
1.6392793655395508
1.6286720037460327
1.664374589920044
1.6667989492416382
1.5867314338684082
1.5816336870193481
1.6698397397994995
1.597549557685852
1.5694206953048706
1.711871862411499
1.6672877073287964
1.6869760751724243
1.6014810800552368
1.6567672491073608
1.6341407299041748
1.56685471534729
1.6138548851013184
1.6154268980026245
1.6824930906295776
1.6624200344085693
1.6102018356323242
1.6616630554199219
1.6840475797653198
1.684403419494629
1.674499

KeyboardInterrupt: 

In [126]:
import pickle
def decode_chars(token_ids, itos):
    """
    Turn a sequence of character token ids back into text by looking each id
    up in ``itos`` and concatenating the results.
    """
    symbols = map(itos.__getitem__, token_ids)
    return ''.join(symbols)

def encode_chars(text, stoi):
    """
    Convert ``text`` to a list of token ids, one per character; characters
    missing from ``stoi`` map to id 0.
    """
    return list(map(lambda ch: stoi.get(ch, 0), text))


from collections import deque


@torch.no_grad()
def decode_sequence_char_rolling(
    model, stoi, itos, prompt,
    max_new_tokens=100,
    block_size=1024,
    temperature=1.0,
    space_fallback=' ',
    strict_window=False,          # if True, periodically re-prime caches on the last block
    reprime_every=None            # if strict_window, how often to re-prime (int). Default: block_size
):
    """
    Sample a continuation of ``prompt`` one character at a time using the
    blocks' incremental ``step`` path.

    The ENTIRE generated text is kept (the output is never trimmed), while a
    rolling window of the most recent ``block_size`` tokens is maintained
    internally.

    Args:
        model: GPT-like model exposing ``config``, ``transformer.wte``,
            ``transformer.h`` and ``lm_head``. It is switched to eval mode.
        stoi, itos: character <-> id tables.
        prompt: seed text; an empty prompt is replaced by a single fallback
            token.
        max_new_tokens: number of characters to sample.
        block_size: priming sequences shorter than this are left-padded with
            the fallback token; also the size of the rolling window.
        temperature: logits are divided by this before sampling. A value of 0
            selects greedy (argmax) decoding instead of sampling.
        space_fallback: character whose id is used for padding / empty prompt
            (id 0 if the character is not in ``stoi``).
        strict_window: if False (default), the caches stream forever, which is
            the fastest path. If True, the per-layer states are periodically
            rebuilt from the last ``block_size`` tokens to mimic a sliding
            window.
        reprime_every: number of generated tokens between re-primes when
            ``strict_window`` is set; defaults to ``block_size``.

    Returns:
        The prompt followed by all generated characters, as a string.
    """
    device = next(model.parameters()).device
    model.eval()
    B = 1

    # ---- encode prompt (fallback to space if empty) ----
    space_id = stoi.get(space_fallback, 0)
    prompt_ids = encode_chars(prompt, stoi)
    if len(prompt_ids) == 0:
        prompt_ids = [space_id]

    def _fresh_states():
        # One empty streaming feature cache per block.
        return [
            CausalPyramidStateLifted(
                num_scales=model.config.n_scales,
                C=model.config.n_embd,
                device=device,
                batch_size=B,
                tau=1e-6
            ) for _ in model.transformer.h
        ]

    def _advance(token, states):
        # Embed one token (shape (1,)) and push it through every block,
        # updating ``states`` in place. Returns the final (1, C) activation.
        x = model.transformer.wte(token)
        for blk, st in zip(model.transformer.h, states):
            x = blk.step(x, st)
        return x

    def _prime(tok_ids):
        # Build fresh caches and feed ``tok_ids`` through them, left-padding
        # with the fallback token up to block_size. The padding is used only
        # for priming and never appears in the returned text.
        if len(tok_ids) < block_size:
            tok_ids = [space_id] * (block_size - len(tok_ids)) + list(tok_ids)
        ids_t = torch.tensor([tok_ids], dtype=torch.long, device=device)  # (1, T)
        states = _fresh_states()
        x_last = None
        for t in range(ids_t.size(1)):
            x_last = _advance(ids_t[:, t], states)
        return states, x_last

    # ---- initial priming with the left-padded prompt ----
    feat_states, x_t = _prime(prompt_ids)

    # ---- FULL output accumulator (never trimmed) ----
    out_full = list(prompt_ids)  # store ints

    # ---- rolling window buffer of last block_size tokens (prompt + generated) ----
    window = deque(prompt_ids, maxlen=block_size)

    # strict-window settings
    if reprime_every is None:
        reprime_every = block_size
    steps_since_reprime = 0

    # ---- incremental rollout ----
    for _ in range(max_new_tokens):
        logits = model.lm_head(x_t)  # (1,V)
        if temperature == 0:
            # Dividing by zero would turn the probabilities into NaN; the
            # zero-temperature limit is greedy decoding.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (1,1)
        else:
            if temperature != 1.0:
                logits = logits / float(temperature)
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1,1)
        next_id = int(next_token.item())

        # record full output and advance the rolling window
        out_full.append(next_id)
        window.append(next_id)

        # step one token
        x_t = _advance(next_token.squeeze(-1), feat_states)

        # optionally re-prime to strict sliding-window semantics
        if strict_window:
            steps_since_reprime += 1
            if steps_since_reprime >= reprime_every and len(window) == block_size:
                feat_states, x_t = _prime(list(window))
                steps_since_reprime = 0

    # decode full continuation (prompt + all generated)
    return decode_chars(out_full, itos)
    
# Reload the tokenizer tables, sample 4096 characters from a fixed prompt and
# report the wall-clock time taken.
with open("./babylm_char_tokenized/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
import time
then = time.time()
prompt = "ROMEO: ROMEO: ROMEO: Juliet! My Juliet! come forth to me! "
generated = decode_sequence_char_rolling(
    model, stoi, itos, prompt,
    max_new_tokens=4096,
    block_size=2048,
    temperature=1.0,
)

print(generated)
print(time.time()-then)

ValueError: not enough values to unpack (expected 3, got 2)

In [None]:
import numpy as np
# Scratch cell: the integers 0..127.
a = np.arange(0, 128)



In [123]:
# Destination for the model weights; the load cell reads the same path.
file_path = 'simple_model_tiny.pth'

# Persist only the parameters (state_dict), not the module object itself.
torch.save(obj=model.state_dict(), f=file_path)