<a href="https://colab.research.google.com/github/digitaldaimyo/ASA/blob/main/building_blocks/ASATrainTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Addressed State Model

## Define

In [None]:
#@title Addressed State Attention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple

# -------------------------
# RoPE helper (rotate-half)
# -------------------------
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    The angle table has shape ``[T, dim]`` and is laid out as two concatenated
    copies ``[freqs, freqs]``, so channel ``i`` and channel ``i + dim/2`` share
    a frequency. Tables are always computed in float32 from ``dim`` and
    ``base`` (not from the ``inv_freq`` buffer), so casting the module to
    fp16/bf16 cannot collapse large position indices or degrade the
    frequencies. The result is cast to the requested dtype only on return.

    Args:
        dim: rotary dimension; must be even.
        base: frequency base of the geometric progression.
    """

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        assert dim % 2 == 0, "RoPE requires even dim"
        self.dim = int(dim)
        self.base = float(base)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Kept for backward compatibility with code that reads `inv_freq`;
        # get_cos_sin recomputes the frequencies in float32 itself.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos_cached = None
        self._sin_cached = None
        self._t_cached = None       # number of positions held in the cache
        self._device_cached = None

    def get_cos_sin(self, T: int, device, dtype):
        """Return ``(cos, sin)`` tables of shape ``[1, 1, T, dim]`` in ``dtype``.

        The cache is reused whenever it lives on ``device`` and already covers
        at least ``T`` positions; row ``t`` of the table does not depend on
        ``T``, so slicing a longer table is exact.
        """
        if device is not None:
            device = torch.device(device)

        cache_hit = (
            self._cos_cached is not None
            and self._t_cached is not None
            and self._t_cached >= T
            and self._device_cached == device
        )
        if not cache_hit:
            # float32 throughout: bf16 cannot represent integers above 256
            # exactly (fp16: above 2048), which would alias distinct positions.
            t = torch.arange(T, device=device, dtype=torch.float32)
            exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
            inv_freq = 1.0 / (self.base ** exponents)
            freqs = torch.einsum("t,f->tf", t, inv_freq)       # [T, d/2]
            emb = torch.cat([freqs, freqs], dim=-1)            # [T, d]

            self._t_cached = T
            self._device_cached = device
            self._cos_cached = emb.cos()[None, None, :, :]     # [1,1,T,d]
            self._sin_cached = emb.sin()[None, None, :, :]     # [1,1,T,d]

        cos = self._cos_cached[:, :, :T, :]
        sin = self._sin_cached[:, :, :T, :]
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply a rotary embedding to ``x`` using precomputed ``cos``/``sin`` tables.

    Computes ``x * cos + _rotate_half(x) * sin``; ``cos`` and ``sin`` must
    broadcast against ``x``.
    """
    in_phase = x * cos
    quadrature = _rotate_half(x) * sin
    return in_phase + quadrature

# -------------------------
# ALiBi slopes helper
# -------------------------
def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """Return the per-head ALiBi slopes as a ``[num_heads]`` tensor.

    For a power-of-two head count the slopes form a geometric sequence whose
    first term and ratio are both ``2 ** -(2 ** -(log2(n) - 3))``. For other
    head counts the sequence for the largest power of two below ``n`` is
    extended with every other slope of the next power of two.
    """
    def _geometric(count):
        first = 2.0 ** (-(2.0 ** -(math.log2(count) - 3)))
        step = first
        return [first * (step ** j) for j in range(count)]

    def _slopes(count):
        exponent = math.log2(count)
        if exponent.is_integer():
            return _geometric(count)
        lower_pow2 = 2 ** math.floor(exponent)
        filler = _slopes(2 * lower_pow2)[0::2][: count - lower_pow2]
        return _geometric(lower_pow2) + filler

    values = _slopes(num_heads)
    return torch.tensor(values, device=device, dtype=dtype)  # [H]

# -------------------------
# softplus init helpers
# -------------------------
def _inv_softplus(y: torch.Tensor) -> torch.Tensor:
    return torch.log(torch.expm1(y))

# -------------------------
# Linear attention feature map: elu(x) + 1 (as in Katharopoulos et al., 2020, "Transformers are RNNs")
# -------------------------
def phi(x: torch.Tensor) -> torch.Tensor:
    """Feature map ``elu(x) + 1`` used by the slot-space linear attention.

    The output is non-negative, which keeps the linear-attention normaliser
    from changing sign.
    """
    activated = F.elu(x)
    return activated + 1.0


class AddressedStateAttention(nn.Module):
    """
    Addressed State Attention (ASA):
      - prefix-softmax WRITE into slots (O(T))
      - READ routing from tokens -> slots (softmax over slots)
      - optional content-conditioned READ term (gamma)
      - RoPE on write keys (geometry)
      - ALiBi bias on write logits (prefix-friendly)

    Optional slot-space refinement (formerly "k-space"):
      - causal linear attention in a low-dim slot-address coordinate space
      - produces per-token signed weights over slots
      - decoded through the same streaming slot-state basis
      - gated by learnable slotspace_gate (softplus)

    PERF (behavior-preserving):
      - Streaming prefix write states in chunks (no [B,H,K,T,d] materialization)
      - Slot-space prefix scan is chunked (exact)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 8,
        num_slots: int = 8,
        dropout: float = 0.1,

        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,

        # positions (write geometry)
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,

        # write bias (ALiBi)
        use_alibi_write: bool = True,
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,

        # content-conditioned read term
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,

        # optional slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -4.0,
        slotspace_dropout: float = 0.05,
        slotspace_signed_weights: bool = True,

        # RoPE in slot-space matcher (Q/K only)
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,

        # perf knobs (no behavior change)
        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
    ):
        """Build the projections, slot keys and optional sub-modules.

        Args:
            embed_dim: model width; must be divisible by ``num_heads``.
            num_heads: number of heads H (``head_dim = embed_dim // num_heads``).
            num_slots: number of learned slots K per head.
            dropout: dropout probability applied to the final projected output.
            read_temperature: divisor applied to read logits before softmax.
            write_temperature: divisor applied to write logits.
            state_fp32: if True, ``forward`` runs the prefix-softmax state math
                in float32 when the input is not already float32.
            slot_dropout: training-time probability of zeroing a whole slot key
                (sampled per head and slot).
            normalize_k: if True, write keys are L2-normalised.
            use_rope_keys: apply RoPE to write keys (requires even ``head_dim``).
            rope_base: frequency base for the write-key RoPE.
            use_alibi_write: add a per-head linear position bias to write logits.
            alibi_strength_init: initial (or fixed) multiplier on the ALiBi slopes.
            learn_alibi_strength: if True (and ALiBi is enabled) the multiplier
                is a parameter kept positive via softplus.
            min_strength: constant added to the learned softplus strength.
            use_content_read: add a content-conditioned term to the read logits.
            content_read_init: raw (pre-softplus) initial value of its gain.
            content_read_max_gamma: upper clamp on the gain, applied when > 0.
            use_slotspace_refine: enable the slot-space linear-attention branch.
            slotspace_dim: width M of the slot-space coordinates.
            slotspace_gate_init: raw (pre-softplus) initial value of the gate
                scaling the refinement branch.
            slotspace_dropout: dropout probability inside the refinement branch.
            slotspace_signed_weights: decode slot weights with tanh (True) or
                softmax (False).
            use_rope_slotspace: apply RoPE to slot-space Q/K; only effective
                when ``use_slotspace_refine`` is True (requires even
                ``slotspace_dim``).
            rope_base_slotspace: frequency base for the slot-space RoPE.
            write_chunk_size: chunk length of the streaming write/read loop.
            slotspace_chunk_size: chunk length of the slot-space prefix scan.
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.head_dim = embed_dim // num_heads

        self.dropout = nn.Dropout(dropout)

        self.read_temperature = float(read_temperature)
        self.write_temperature = float(write_temperature)
        self.state_fp32 = bool(state_fp32)
        self.slot_dropout = float(slot_dropout)
        self.normalize_k = bool(normalize_k)
        # Optional hook read by forward(): a callable or a [B,H,T,K] tensor that
        # replaces the built-in read routing. Not a constructor argument; set it
        # on the instance after construction.
        self.routing_override = None

        self.use_rope_keys = bool(use_rope_keys)
        self.use_alibi_write = bool(use_alibi_write)
        self.learn_alibi_strength = bool(learn_alibi_strength)
        self.min_strength = float(min_strength)

        self.use_content_read = bool(use_content_read)
        self.content_read_max_gamma = float(content_read_max_gamma)

        self.use_slotspace_refine = bool(use_slotspace_refine)
        self.slotspace_dim = int(slotspace_dim)
        self.slotspace_dropout = nn.Dropout(float(slotspace_dropout))
        self.slotspace_signed_weights = bool(slotspace_signed_weights)

        self.write_chunk_size = int(write_chunk_size)
        self.slotspace_chunk_size = int(slotspace_chunk_size)

        # Learned slot keys per head: [H,K,d]
        # (scaled by 1/sqrt(head_dim) at init)
        self.slot_keys = nn.Parameter(
            torch.randn(num_heads, num_slots, self.head_dim) / math.sqrt(self.head_dim)
        )

        # Projections
        self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wq_read  = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # RoPE (write geometry)
        self.rope = RotaryEmbedding(self.head_dim, base=rope_base) if self.use_rope_keys else None

        # ALiBi slopes (buffer)
        # Non-persistent: not stored in the state_dict. Zeros when ALiBi is off
        # so the attribute always exists.
        if self.use_alibi_write:
            self.register_buffer("_alibi_slopes", alibi_slopes(num_heads), persistent=False)  # [H]
        else:
            self.register_buffer("_alibi_slopes", torch.zeros(num_heads), persistent=False)

        # Learnable ALiBi strength (positive via softplus)
        if self.use_alibi_write and self.learn_alibi_strength:
            # Store the raw value whose softplus (+ min_strength) equals
            # alibi_strength_init; clamp keeps the inverse-softplus argument > 0.
            # Note: `self.alibi_strength` is NOT set on this branch.
            init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8)
            self._alibi_strength_param = nn.Parameter(_inv_softplus(init))
        else:
            self._alibi_strength_param = None
            self.alibi_strength = float(alibi_strength_init)

        # Content read gamma (>=0 via softplus)
        if self.use_content_read:
            self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init)))
        else:
            self._content_read_gamma_raw = None

        # -------------------------
        # Optional slot-space refinement
        # -------------------------
        # Slot-space RoPE is only meaningful when the refinement branch exists.
        self.use_rope_slotspace = bool(use_rope_slotspace) and bool(self.use_slotspace_refine)
        if self.use_slotspace_refine:
            # slot_in: K -> M encoder of read weights; slot_q/k/v: M -> M;
            # slot_out: M -> K decoder back to per-slot weights.
            self.slot_in  = nn.Linear(num_slots, self.slotspace_dim, bias=False)
            self.slot_q   = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_k   = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_v   = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_out = nn.Linear(self.slotspace_dim, num_slots, bias=False)
            # Raw gate; forward applies softplus to it.
            self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init)))

            if self.use_rope_slotspace:
                assert (self.slotspace_dim % 2) == 0, "use_rope_slotspace requires even slotspace_dim"
                self.rope_slotspace = RotaryEmbedding(self.slotspace_dim, base=float(rope_base_slotspace))
            else:
                self.rope_slotspace = None
        else:
            # Keep the attributes defined (as None) so attribute access is safe.
            self.slot_in = None
            self.slot_q = self.slot_k = self.slot_v = None
            self.slot_out = None
            self._slotspace_gate_raw = None
            self.rope_slotspace = None

    def _alibi_strength(self, dtype, device) -> torch.Tensor:
        if not (self.use_alibi_write and self.learn_alibi_strength):
            return torch.tensor(self.alibi_strength, dtype=dtype, device=device)
        return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device)

    def _content_read_gamma(self, dtype, device) -> torch.Tensor:
        if not self.use_content_read:
            return torch.tensor(0.0, dtype=dtype, device=device)
        g = F.softplus(self._content_read_gamma_raw)  # >= 0
        if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0:
            g = g.clamp(max=self.content_read_max_gamma)
        return g.to(dtype=dtype, device=device)

    def _slotspace_gate(self, dtype, device) -> torch.Tensor:
        if not self.use_slotspace_refine:
            return torch.tensor(0.0, dtype=dtype, device=device)
        return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device)

    @staticmethod
    def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        diff = s - m
        diff = diff.masked_fill(~torch.isfinite(m), float("-inf"))
        return torch.exp(diff)

    @torch.compile
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        routing_mode: str = "softmax",           # "softmax" | "top1" | "topk" | "external"
        routing_topk: int = 2,                   # used if routing_mode=="topk"
        read_weights_override: Optional[torch.Tensor] = None,  # [B,H,T,K] or [B,H,L,K]
        routing_noise: Optional[str] = None,     # None | "gumbel" | "gaussian"
        routing_noise_scale: float = 1.0,

    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        B, T, C = x.shape
        H, K, d = self.num_heads, self.num_slots, self.head_dim

        # Project (write K/V, read Q)
        k_write = self.Wk_write(x).view(B, T, H, d).transpose(1, 2)  # [B,H,T,d]
        v_write = self.Wv_write(x).view(B, T, H, d).transpose(1, 2)  # [B,H,T,d]
        q_read  = self.Wq_read(x).view(B, T, H, d).transpose(1, 2)   # [B,H,T,d]

        if self.normalize_k:
            k_write = F.normalize(k_write, dim=-1, eps=1e-8)

        # RoPE on write keys
        if self.use_rope_keys:
            cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype)
            k_write = apply_rope(k_write, cos, sin)

        # Slot dropout
        slot_keys = self.slot_keys
        if self.training and self.slot_dropout > 0.0:
            drop = (torch.rand((H, K), device=x.device) < self.slot_dropout)
            slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1)

        # WRITE logits: [B,H,K,T]
        write_logits_raw = torch.einsum("hkd,bhtd->bhkt", slot_keys, k_write) / math.sqrt(d)

        # Stable dtype for prefix-softmax math
        state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype
        write_logits = write_logits_raw.to(state_dtype)

        # Write temperature
        wtemp = max(1e-6, self.write_temperature)
        write_logits = write_logits / wtemp

        # ALiBi distance bias (prefix-friendly)
        alibi_bias_applied = None
        if self.use_alibi_write:
            strength = self._alibi_strength(dtype=state_dtype, device=x.device)  # scalar
            slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength  # [H]
            pos_i = torch.arange(T, device=x.device, dtype=state_dtype)  # [T]
            alibi_bias = slopes.view(1, H, 1, 1) * pos_i.view(1, 1, 1, T) # [1,H,1,T]
            write_logits = write_logits + alibi_bias
            alibi_bias_applied = alibi_bias

        # Key padding mask (mask positions that are padding)
        if attention_mask is not None:
            valid = attention_mask.to(dtype=torch.bool)
            write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf"))
        else:
            valid = None

        # =====================================================
        # STREAMING WRITE + READ (no [B,H,K,T,d] slot states)
        # =====================================================
        content_read_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device)
        rtemp = max(1e-6, self.read_temperature)

        out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype)
        read_weights = torch.empty((B, H, T, K), device=x.device, dtype=q_read.dtype)

        # Optional analytics: [B,H,T,K] (later permuted to [B,H,K,T])
        slot_state_norm_t = torch.empty((B, H, T, K), device=x.device, dtype=torch.float32) if return_info else None

        denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
        numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
        m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)

        WRITE_CHUNK = self.write_chunk_size

        for t0 in range(0, T, WRITE_CHUNK):
            t1 = min(T, t0 + WRITE_CHUNK)
            L = t1 - t0

            wlog_c = write_logits[:, :, :, t0:t1]  # [B,H,K,L]

            # streaming cummax
            m_c, _ = torch.cummax(wlog_c, dim=-1)  # [B,H,K,L]
            m_new = torch.maximum(m_state.unsqueeze(-1), m_c)  # [B,H,K,L]

            # rescale old prefix state to new max reference
            scale = torch.exp(m_state.unsqueeze(-1) - m_new)  # [B,H,K,L] (exp(-inf)=0)

            denom_c = denom_state.unsqueeze(-1) * scale                  # [B,H,K,L]
            numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1)    # [B,H,K,L,d]

            # new weights
            w_new = self._safe_exp_sub_max(wlog_c, m_new)  # [B,H,K,L]

            # accumulate within chunk
            denom_c = denom_c + torch.cumsum(w_new, dim=-1)  # [B,H,K,L]
            v_c = v_write[:, :, t0:t1, :].to(state_dtype)    # [B,H,L,d]
            add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2)  # [B,H,K,L,d]
            numer_c = numer_c + add

            # per-token slot state for this chunk only
            slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1)  # [B,H,K,L,d]
            slot_state_t = slot_state_c.permute(0, 1, 3, 2, 4).contiguous() # [B,H,L,K,d]

            # READ routing logits
            q_read_c = q_read[:, :, t0:t1, :]  # [B,H,L,d]

            # base (key) term
            read_logits_key = torch.einsum("bhld,hkd->bhlk", q_read_c, slot_keys) / math.sqrt(d)

            # optional content term
            read_logits_content = None
            read_logits = read_logits_key
            if self.use_content_read:
                read_logits_content = torch.einsum(
                    "bhld,bhlkd->bhlk",
                    q_read_c,
                    slot_state_t.to(q_read_c.dtype)
                ) / math.sqrt(d)
                read_logits = read_logits + content_read_gamma.to(read_logits.dtype) * read_logits_content

            # Optional: noise on logits to probe routing stability (off by default)
            # You can plumb these as forward() kwargs; see signature snippet below.
            if routing_noise is not None:
                if routing_noise == "gumbel":
                    # gumbel(0,1) noise; scale by routing_noise_scale
                    u = torch.rand_like(read_logits)
                    g = -torch.log(-torch.log(u.clamp_min(1e-8)).clamp_min(1e-8))
                    read_logits = read_logits + routing_noise_scale * g
                elif routing_noise == "gaussian":
                    read_logits = read_logits + routing_noise_scale * torch.randn_like(read_logits)
                else:
                    raise ValueError(f"Unknown routing_noise={routing_noise}")


            if self.routing_override is not None:
                if callable(self.routing_override):
                    ctx = {
                        "t0": t0,
                        "t1": t1,
                        "B": B, "H": H, "T": T, "K": K, "d": d,
                        "rtemp": rtemp,
                        "state_dtype": state_dtype,
                        "q_read_c": q_read_c,          # [B,H,L,d]
                        "slot_keys": slot_keys,        # [H,K,d]
                        "slot_state_t": slot_state_t,  # [B,H,L,K,d] (current prefix slot states)
                        "valid": valid,                # [B,T] or None
                    }



                    # must return [B,H,L,K]
                    read_w_c = self.routing_override(
                        t0, t1, read_logits,    # [B,H,L,K] full (key + content + noise if applied),
                        read_logits_key,        # [B,H,L,K] key-only
                        read_logits_content,    # [B,H,L,K] or None
                        ctx,
                    )
                else:
                    # tensor override: [B,H,T,K]
                    read_w_c = self.routing_override[:, :, t0:t1, :].to(read_logits.dtype)

                # safety: ensure finite + normalize
                read_w_c = torch.nan_to_num(read_w_c, nan=0.0, posinf=0.0, neginf=0.0)
                read_w_c = read_w_c.clamp_min(0.0)
                read_w_c = read_w_c / read_w_c.sum(dim=-1, keepdim=True).clamp_min(1e-8)

            else:

                # Routing mode
                if routing_mode == "softmax":
                    read_w_c = torch.softmax(read_logits / rtemp, dim=-1)  # [B,H,L,K]

                elif routing_mode == "top1":
                    # hard one-hot on argmax
                    top = read_logits.argmax(dim=-1)  # [B,H,L]
                    read_w_c = F.one_hot(top, num_classes=K).to(read_logits.dtype)

                elif routing_mode == "topk":
                    kk = int(routing_topk)
                    kk = max(1, min(K, kk))
                    # mask out everything except top-k then renormalize with softmax-like
                    vals, idx = torch.topk(read_logits, k=kk, dim=-1)
                    masked = torch.full_like(read_logits, float("-inf"))
                    masked.scatter_(-1, idx, vals)
                    read_w_c = torch.softmax(masked / rtemp, dim=-1)

                elif routing_mode == "external":
                    if read_weights_override is None:
                        raise ValueError("routing_mode='external' requires read_weights_override")
                    # accept either full [B,H,T,K] or chunk [B,H,L,K]
                    if read_weights_override.shape[-2] == T:
                        read_w_c = read_weights_override[:, :, t0:t1, :]
                    else:
                        read_w_c = read_weights_override
                    # safety: renormalize
                    read_w_c = read_w_c / read_w_c.sum(dim=-1, keepdim=True).clamp_min(1e-8)

                else:
                    raise ValueError(f"Unknown routing_mode={routing_mode}")

            read_weights[:, :, t0:t1, :] = read_w_c

            # token output
            out_h[:, :, t0:t1, :] = torch.einsum(
                "bhlk,bhlkd->bhld",
                read_w_c.to(state_dtype),
                slot_state_t.to(state_dtype),
            )

            if return_info:
                slot_state_norm_t[:, :, t0:t1, :] = slot_state_t.to(torch.float32).norm(dim=-1)

            # update running states to end-of-chunk
            m_state = m_new[:, :, :, -1]
            denom_state = denom_c[:, :, :, -1]
            numer_state = numer_c[:, :, :, -1, :]

        # =====================================================
        # Optional: causal linear attention in slot-space (CHUNKED prefix scan)
        # =====================================================
        slotspace_delta_norm_mean = None
        if self.use_slotspace_refine:
            slotspace_dtype = state_dtype
            M = self.slotspace_dim

            # Encode read weights into slot-space coordinates
            u = self.slot_in(read_weights.to(slotspace_dtype))  # [B,H,T,M]
            q_s  = self.slot_q(u)
            k_s  = self.slot_k(u)
            v_s  = self.slot_v(u)

            # RoPE in slot-space matcher (Q/K only)
            if self.use_rope_slotspace:
                cos_s, sin_s = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=q_s.dtype)
                q_s = apply_rope(q_s, cos_s, sin_s)
                k_s = apply_rope(k_s, cos_s, sin_s)

            qf = phi(q_s)
            kf = phi(k_s)

            if valid is not None:
                mask = valid.view(B, 1, T, 1).to(slotspace_dtype)
                qf = qf * mask
                kf = kf * mask
                v_s = v_s * mask

            u2 = torch.empty((B, H, T, M), device=x.device, dtype=slotspace_dtype)

            S_state = torch.zeros((B, H, M, M), device=x.device, dtype=slotspace_dtype)
            Z_state = torch.zeros((B, H, M), device=x.device, dtype=slotspace_dtype)

            SS_CHUNK = self.slotspace_chunk_size

            for t0 in range(0, T, SS_CHUNK):
                t1 = min(T, t0 + SS_CHUNK)
                L = t1 - t0

                qf_c = qf[:, :, t0:t1, :]   # [B,H,L,M]
                kf_c = kf[:, :, t0:t1, :]   # [B,H,L,M]
                v_c  = v_s[:, :, t0:t1, :]  # [B,H,L,M]

                kv = torch.einsum("bhlm,bhln->bhlmn", kf_c, v_c)  # [B,H,L,M,M]
                S_c = torch.cumsum(kv, dim=2)
                Z_c = torch.cumsum(kf_c, dim=2)

                S_c = S_c + S_state.unsqueeze(2)
                Z_c = (Z_c + Z_state.unsqueeze(2)).clamp_min(1e-8)

                num = torch.einsum("bhlm,bhlmn->bhln", qf_c, S_c)
                den = torch.einsum("bhlm,bhlm->bhl", qf_c, Z_c).unsqueeze(-1).clamp_min(1e-8)
                u2[:, :, t0:t1, :] = num / den

                S_state = S_c[:, :, -1, :, :]
                Z_state = Z_c[:, :, -1, :]

            u2 = self.slotspace_dropout(u2)

            # Decode slot weights per token
            slot_w = self.slot_out(u2)  # [B,H,T,K]
            if self.slotspace_signed_weights:
                slot_w = torch.tanh(slot_w)
            else:
                slot_w = torch.softmax(slot_w, dim=-1)

            # Second streaming pass to decode slotspace contribution through slot states
            gate = self._slotspace_gate(dtype=state_dtype, device=x.device).to(state_dtype)

            denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
            numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
            m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)

            delta_norm_sum = torch.zeros((), device=x.device, dtype=torch.float32)
            delta_norm_count = 0

            for t0 in range(0, T, WRITE_CHUNK):
                t1 = min(T, t0 + WRITE_CHUNK)
                L = t1 - t0

                wlog_c = write_logits[:, :, :, t0:t1]  # [B,H,K,L]
                m_c, _ = torch.cummax(wlog_c, dim=-1)
                m_new = torch.maximum(m_state.unsqueeze(-1), m_c)

                scale = torch.exp(m_state.unsqueeze(-1) - m_new)
                denom_c = denom_state.unsqueeze(-1) * scale
                numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1)

                w_new = self._safe_exp_sub_max(wlog_c, m_new)
                denom_c = denom_c + torch.cumsum(w_new, dim=-1)

                v_c = v_write[:, :, t0:t1, :].to(state_dtype)  # [B,H,L,d]
                add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2)
                numer_c = numer_c + add

                slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1)  # [B,H,K,L,d]
                slot_state_t = slot_state_c.permute(0, 1, 3, 2, 4).contiguous() # [B,H,L,K,d]

                slot_w_c = slot_w[:, :, t0:t1, :].to(state_dtype)  # [B,H,L,K]
                delta_c = torch.einsum("bhlk,bhlkd->bhld", slot_w_c, slot_state_t.to(state_dtype))  # [B,H,L,d]

                out_h[:, :, t0:t1, :] = out_h[:, :, t0:t1, :] + gate * delta_c

                delta_norm_sum = delta_norm_sum + delta_c.detach().to(torch.float32).norm(dim=-1).sum()
                delta_norm_count += (B * H * L)

                m_state = m_new[:, :, :, -1]
                denom_state = denom_c[:, :, :, -1]
                numer_state = numer_c[:, :, :, -1, :]

            slotspace_delta_norm_mean = (delta_norm_sum / max(1, delta_norm_count)).detach().cpu()

        # Finish
        out = out_h.transpose(1, 2).contiguous().view(B, T, C)
        out = self.out_proj(out)
        out = self.dropout(out)

        info = None
        if return_info:
            info = {
                "write_logits_raw": write_logits_raw.detach(),
                "write_logits": write_logits.detach().to(torch.float32),
                "read_weights": read_weights.detach(),
                # [B,H,K,T]
                "slot_state_norm": slot_state_norm_t.detach().permute(0, 1, 3, 2).contiguous() if slot_state_norm_t is not None else None,
                "content_read_gamma": content_read_gamma.detach().to(torch.float32).cpu(),
            }
            if alibi_bias_applied is not None:
                info["alibi_bias_applied"] = alibi_bias_applied.detach().to(torch.float32)
            if self.use_alibi_write and self.learn_alibi_strength:
                info["alibi_strength"] = self._alibi_strength(dtype=torch.float32, device=x.device).detach().cpu()
            if self.use_slotspace_refine:
                info["slotspace_gate"] = self._slotspace_gate(dtype=torch.float32, device=x.device).detach().cpu()
                info["use_rope_slotspace"] = torch.tensor(bool(self.use_rope_slotspace))
                if slotspace_delta_norm_mean is not None:
                    info["slotspace_delta_norm"] = slotspace_delta_norm_mean

            info["read_logits"] = read_logits.detach().to(torch.float32)
            info["read_logits_key"] = read_logits_key.detach().to(torch.float32)
            if read_logits_content is not None:
                info["read_logits_content"] = read_logits_content.detach().to(torch.float32)
            info["routing_mode"] = routing_mode

        return out, info

In [None]:

#@title Addressed State Attention (training-focused efficient version) — ONLINE slotspace scan, single pass, chunkwise dropout
# Drop-in replacement for your current "training-focused efficient version".
# Key change:
#   - Slotspace refine is computed ONLINE inside the same write/read chunk loop.
#   - No [B,H,T,K] read_weights buffer, no [B,H,T,M] u2 buffer, no second streaming decode pass.
#   - Slotspace dropout is applied chunkwise to u2_c (as requested).

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple

# -------------------------
# RoPE helper (rotate-half)
# -------------------------
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    The angle table is ``[T, dim]``, laid out as ``[freqs, freqs]`` so that
    channel ``i`` and channel ``i + dim/2`` share a frequency. Tables are
    computed in float32 from ``dim`` and ``base`` regardless of the module's
    parameter dtype: after ``.half()`` / ``.bfloat16()`` the ``inv_freq``
    buffer loses precision and low-precision position indices alias (bf16 is
    exact only up to 256), so neither is used for the computation.

    Args:
        dim: rotary dimension; must be even.
        base: frequency base of the geometric progression.
    """

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        assert dim % 2 == 0, "RoPE requires even dim"
        self.dim = int(dim)
        self.base = float(base)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Retained for backward compatibility with readers of `inv_freq`.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos_cached = None
        self._sin_cached = None
        self._t_cached = None       # number of positions held in the cache
        self._device_cached = None

    def get_cos_sin(self, T: int, device, dtype):
        """Return ``(cos, sin)`` tables of shape ``[1, 1, T, dim]`` in ``dtype``.

        A cached table on the same device that covers at least ``T``
        positions is sliced instead of being rebuilt (rows are independent of
        ``T``, so this is exact).
        """
        if device is not None:
            device = torch.device(device)

        reusable = (
            self._cos_cached is not None
            and self._t_cached is not None
            and self._t_cached >= T
            and self._device_cached == device
        )
        if not reusable:
            positions = torch.arange(T, device=device, dtype=torch.float32)
            exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
            inv_freq = 1.0 / (self.base ** exponents)
            freqs = torch.einsum("t,f->tf", positions, inv_freq)  # [T, d/2]
            emb = torch.cat([freqs, freqs], dim=-1)               # [T, d]

            self._t_cached = T
            self._device_cached = device
            self._cos_cached = emb.cos()[None, None, :, :]        # [1,1,T,d]
            self._sin_cached = emb.sin()[None, None, :, :]        # [1,1,T,d]

        cos = self._cos_cached[:, :, :T, :]
        sin = self._sin_cached[:, :, :T, :]
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` by the angles encoded in the ``cos``/``sin`` tables.

    Evaluates ``x * cos + _rotate_half(x) * sin``; the tables must broadcast
    against ``x``.
    """
    cos_term = x * cos
    sin_term = _rotate_half(x) * sin
    return cos_term + sin_term

# -------------------------
# ALiBi slopes helper
# -------------------------
def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """Build the ``[num_heads]`` tensor of ALiBi slopes.

    Power-of-two head counts get a geometric sequence with first term and
    ratio ``2 ** -(2 ** -(log2(n) - 3))``. Any other count takes the sequence
    of the largest smaller power of two and appends every other slope of the
    next power of two until ``num_heads`` values are available.
    """
    def _pow2_sequence(count):
        head = 2.0 ** (-(2.0 ** -(math.log2(count) - 3)))
        factor = head
        return [head * (factor ** k) for k in range(count)]

    def _build(count):
        log_count = math.log2(count)
        if log_count.is_integer():
            return _pow2_sequence(count)
        floor_pow2 = 2 ** math.floor(log_count)
        borrowed = _build(2 * floor_pow2)[0::2][: count - floor_pow2]
        return _pow2_sequence(floor_pow2) + borrowed

    slope_values = _build(num_heads)
    return torch.tensor(slope_values, device=device, dtype=dtype)  # [H]

def _inv_softplus(y: torch.Tensor) -> torch.Tensor:
    return torch.log(torch.expm1(y))

def phi(x: torch.Tensor) -> torch.Tensor:
    """Non-negative feature map ``elu(x) + 1`` for linear attention."""
    elu_out = F.elu(x)
    return elu_out + 1.0


class AddressedStateAttention(nn.Module):
    """
    Training-focused ASA (ONLINE slotspace scan, single pass):
      - streaming write + read (no [B,H,K,T,d])
      - optional slotspace refine computed ONLINE inside the same chunk loop:
          * causal linear-attn scan state (S_state, Z_state) carried across chunks
          * decode slot_w_c for the chunk only, apply via SAME slot_state_t (no second pass)
      - minimal allocations by default
      - optional lightweight stats for monitoring

    Shape letters used throughout: B=batch, T=sequence length, H=num_heads,
    K=num_slots, d=head_dim (embed_dim // num_heads), M=slotspace_dim,
    L=length of the current chunk.

    Per head, every token writes its value into K slots through a causal
    (prefix) softmax over the write logits ``slot_keys . k_write``, and reads
    back a softmax-weighted mixture of the slot states as they stand at that
    token's position.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 8,
        num_slots: int = 8,
        dropout: float = 0.1,

        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,

        # positions (write geometry)
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,

        # write bias (ALiBi)
        use_alibi_write: bool = True,
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,

        # content-conditioned read term
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,

        # optional slot-space refinement (2nd order primary path)
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -4.0,
        slotspace_dropout: float = 0.05,
        slotspace_signed_weights: bool = True,

        # RoPE in slot-space matcher (Q/K only)
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,

        # perf knobs
        write_chunk_size: int = 128,
        enable_compiled: bool = False,

        # training diag knobs
        return_light_stats_default: bool = False,

        # accepted for caller compatibility; not used (see docstring)
        slotspace_chunk_size: int = 128,
    ):
        """Build the attention module.

        Args:
            embed_dim: model width; must be divisible by ``num_heads``.
            num_heads: number of heads H.
            num_slots: number of slots K per head.
            dropout: dropout applied after the output projection.
            read_temperature: divisor of the read logits (floored at 1e-6).
            write_temperature: divisor of the write logits (floored at 1e-6).
            state_fp32: keep the running write state in float32 when the
                input has a different dtype.
            slot_dropout: in training mode, probability of zeroing a slot's
                key, drawn independently per (head, slot) on each forward.
            normalize_k: L2-normalise the write keys.
            use_rope_keys / rope_base: apply RoPE to the write keys.
            use_alibi_write: add ``slope_h * strength * position`` to the
                write logits.
            alibi_strength_init / learn_alibi_strength / min_strength: the
                strength is ``softplus(param) + min_strength`` when learned,
                otherwise the constant ``alibi_strength_init``.
            use_content_read / content_read_init / content_read_max_gamma:
                add ``gamma * (q . slot_state)`` to the read logits, with
                ``gamma = softplus(raw)`` clamped to ``content_read_max_gamma``
                when that bound is positive.
            use_slotspace_refine / slotspace_dim / slotspace_gate_init /
            slotspace_dropout / slotspace_signed_weights: configure the
                optional slot-space linear-attention refinement; its output
                is added scaled by ``softplus(slotspace_gate_raw)``.
            use_rope_slotspace / rope_base_slotspace: RoPE on the slot-space
                queries and keys (only active with ``use_slotspace_refine``).
            write_chunk_size: number of positions processed per chunk.
            enable_compiled: wrap the chunk kernel with ``torch.compile``.
            return_light_stats_default: default for ``forward``'s
                ``return_light_stats`` argument.
            slotspace_chunk_size: accepted so callers that pass this knob
                next to ``write_chunk_size`` can construct the module. It is
                ignored: the slot-space scan runs inside the write chunk loop.

        Raises:
            AssertionError: if ``embed_dim`` is not divisible by ``num_heads``
                or slot-space RoPE is requested with an odd ``slotspace_dim``.
            ValueError: if ``write_chunk_size`` is smaller than 1.
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.head_dim = embed_dim // num_heads

        self.dropout = nn.Dropout(dropout)

        self.read_temperature = float(read_temperature)
        self.write_temperature = float(write_temperature)
        self.state_fp32 = bool(state_fp32)
        self.slot_dropout = float(slot_dropout)
        self.normalize_k = bool(normalize_k)

        self.use_rope_keys = bool(use_rope_keys)
        self.use_alibi_write = bool(use_alibi_write)
        self.learn_alibi_strength = bool(learn_alibi_strength)
        self.min_strength = float(min_strength)

        self.use_content_read = bool(use_content_read)
        self.content_read_max_gamma = float(content_read_max_gamma)

        self.use_slotspace_refine = bool(use_slotspace_refine)
        self.slotspace_dim = int(slotspace_dim)
        self.slotspace_dropout = nn.Dropout(float(slotspace_dropout))
        self.slotspace_signed_weights = bool(slotspace_signed_weights)

        self.use_rope_slotspace = bool(use_rope_slotspace) and bool(self.use_slotspace_refine)

        self.write_chunk_size = int(write_chunk_size)
        if self.write_chunk_size < 1:
            # A non-positive step would either crash ``range`` or skip the
            # chunk loop and return the uninitialised output buffer.
            raise ValueError(f"write_chunk_size must be >= 1, got {write_chunk_size}")
        self.return_light_stats_default = bool(return_light_stats_default)

        H, K, d = self.num_heads, self.num_slots, self.head_dim

        # Learned slot keys per head: [H,K,d]
        self.slot_keys = nn.Parameter(torch.randn(H, K, d) / math.sqrt(d))

        # Projections
        self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wq_read  = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # RoPE (write geometry)
        self.rope = RotaryEmbedding(d, base=rope_base) if self.use_rope_keys else None

        # ALiBi slopes
        if self.use_alibi_write:
            self.register_buffer("_alibi_slopes", alibi_slopes(H), persistent=False)  # [H]
        else:
            self.register_buffer("_alibi_slopes", torch.zeros(H), persistent=False)

        if self.use_alibi_write and self.learn_alibi_strength:
            # Parameterise so that softplus(param) + min_strength == alibi_strength_init.
            init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8)
            self._alibi_strength_param = nn.Parameter(_inv_softplus(init))
        else:
            self._alibi_strength_param = None
            self.alibi_strength = float(alibi_strength_init)

        # Content read gamma
        if self.use_content_read:
            self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init)))
        else:
            self._content_read_gamma_raw = None

        # Slotspace refine stack
        if self.use_slotspace_refine:
            self.slot_in  = nn.Linear(K, self.slotspace_dim, bias=False)
            self.slot_q   = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_k   = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_v   = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_out = nn.Linear(self.slotspace_dim, K, bias=False)
            self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init)))

            if self.use_rope_slotspace:
                assert (self.slotspace_dim % 2) == 0
                self.rope_slotspace = RotaryEmbedding(self.slotspace_dim, base=float(rope_base_slotspace))
            else:
                self.rope_slotspace = None
        else:
            self.slot_in = None
            self.slot_q = self.slot_k = self.slot_v = None
            self.slot_out = None
            self._slotspace_gate_raw = None
            self.rope_slotspace = None

        # Compile only the inner chunk kernel if desired.
        self._compiled = None
        if enable_compiled:
            self.enable_compiled_kernel()

    def enable_compiled_kernel(self):
        """Wrap ``_write_read_chunk`` with ``torch.compile`` (idempotent)."""
        if self._compiled is None:
            self._compiled = torch.compile(self._write_read_chunk, fullgraph=False)

    def _alibi_strength(self, dtype, device) -> torch.Tensor:
        """Return the scalar ALiBi strength as a tensor of ``dtype`` on ``device``."""
        if not (self.use_alibi_write and self.learn_alibi_strength):
            return torch.tensor(getattr(self, "alibi_strength", 0.0), dtype=dtype, device=device)
        return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device)

    def _content_read_gamma(self, dtype, device) -> torch.Tensor:
        """Return the content-read weight ``gamma`` (0 when the term is disabled)."""
        if not self.use_content_read:
            return torch.tensor(0.0, dtype=dtype, device=device)
        g = F.softplus(self._content_read_gamma_raw)
        if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0:
            g = g.clamp(max=self.content_read_max_gamma)
        return g.to(dtype=dtype, device=device)

    def _slotspace_gate(self, dtype, device) -> torch.Tensor:
        """Return the slot-space gate ``softplus(raw)`` (0 when refinement is off)."""
        if not self.use_slotspace_refine:
            return torch.tensor(0.0, dtype=dtype, device=device)
        return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device)

    @staticmethod
    def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        """Return ``exp(s - m)``, forced to 0 wherever ``m`` is not finite.

        Guards the ``-inf - (-inf) = nan`` case that arises while every
        position seen so far is masked.
        """
        diff = s - m
        diff = diff.masked_fill(~torch.isfinite(m), float("-inf"))
        return torch.exp(diff)

    def _write_read_chunk(
        self,
        wlog_c: torch.Tensor,       # [B,H,K,L] state_dtype
        v_c: torch.Tensor,          # [B,H,L,d] state_dtype
        q_read_c: torch.Tensor,     # [B,H,L,d] q dtype
        slot_keys: torch.Tensor,    # [H,K,d]
        content_gamma: torch.Tensor,# scalar tensor (q dtype)
        rtemp: float,
        m_state: torch.Tensor,      # [B,H,K]
        denom_state: torch.Tensor,  # [B,H,K]
        numer_state: torch.Tensor,  # [B,H,K,d]
    ) -> Tuple[
        torch.Tensor,  # out_base_c [B,H,L,d] (state_dtype)
        torch.Tensor,  # read_w_c   [B,H,L,K] (q dtype)
        torch.Tensor,  # slot_state_t [B,H,L,K,d] (state_dtype)  <-- needed for online 2nd order
        torch.Tensor,  # m_state_new
        torch.Tensor,  # denom_state_new
        torch.Tensor,  # numer_state_new
    ]:
        """Process one chunk: streaming prefix-softmax write, then read.

        ``(m_state, denom_state, numer_state)`` is the running max, softmax
        denominator and value-weighted numerator carried in from earlier
        chunks. The last three return values are the same triple after the
        final position of this chunk.
        """
        B, H, K, L = wlog_c.shape
        d = numer_state.shape[-1]
        state_dtype = numer_state.dtype

        # prefix-softmax streaming within chunk
        m_c, _ = torch.cummax(wlog_c, dim=-1)              # [B,H,K,L]
        m_new = torch.maximum(m_state.unsqueeze(-1), m_c)  # [B,H,K,L]
        # Rescale carried state to the new running max. While every position
        # so far is masked both maxima are -inf; a plain exp(m_state - m_new)
        # would then be exp(nan) and poison the state for the rest of the
        # sequence, so use the guarded form (scale == 0 there; state is 0 too).
        scale = self._safe_exp_sub_max(m_state.unsqueeze(-1), m_new)  # [B,H,K,L]

        denom_c = denom_state.unsqueeze(-1) * scale
        numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1)

        w_new = self._safe_exp_sub_max(wlog_c, m_new)
        denom_c = denom_c + torch.cumsum(w_new, dim=-1)

        add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2)
        numer_c = numer_c + add

        slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1)      # [B,H,K,L,d]
        slot_state_t = slot_state_c.permute(0, 1, 3, 2, 4).contiguous()     # [B,H,L,K,d]

        # routing
        read_logits_key = torch.einsum("bhld,hkd->bhlk", q_read_c, slot_keys) / math.sqrt(d)
        if self.use_content_read:
            read_logits_content = torch.einsum("bhld,bhlkd->bhlk", q_read_c, slot_state_t.to(q_read_c.dtype)) / math.sqrt(d)
            read_logits = read_logits_key + content_gamma.to(read_logits_key.dtype) * read_logits_content
        else:
            read_logits = read_logits_key

        read_w_c = torch.softmax(read_logits / rtemp, dim=-1)

        # base output
        out_base_c = torch.einsum("bhlk,bhlkd->bhld", read_w_c.to(state_dtype), slot_state_t)

        # update prefix state
        m_state_new = m_new[:, :, :, -1]
        denom_state_new = denom_c[:, :, :, -1]
        numer_state_new = numer_c[:, :, :, -1, :]

        return out_base_c, read_w_c, slot_state_t, m_state_new, denom_state_new, numer_state_new

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        return_light_stats: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        """Run causal addressed-state attention over ``x``.

        Args:
            x: input of shape ``[B, T, embed_dim]``.
            attention_mask: optional ``[B, T]`` mask; truthy entries mark
                valid positions. Masked positions do not write to the slots.
            return_info: when true, return a dict as the second element.
            return_light_stats: when true (and ``return_info``), fill the
                dict with scalar CPU statistics. ``None`` falls back to
                ``return_light_stats_default``.

        Returns:
            ``(out, info)`` where ``out`` is ``[B, T, embed_dim]`` and
            ``info`` is ``None`` unless ``return_info`` is set.
        """

        if return_light_stats is None:
            return_light_stats = self.return_light_stats_default
        collect_stats = bool(return_info and return_light_stats)

        B, T, C = x.shape
        H, K, d = self.num_heads, self.num_slots, self.head_dim

        # projections
        k_write = self.Wk_write(x).view(B, T, H, d).transpose(1, 2)  # [B,H,T,d]
        v_write = self.Wv_write(x).view(B, T, H, d).transpose(1, 2)  # [B,H,T,d]
        q_read  = self.Wq_read(x).view(B, T, H, d).transpose(1, 2)   # [B,H,T,d]

        if self.normalize_k:
            k_write = F.normalize(k_write, dim=-1, eps=1e-8)

        if self.use_rope_keys:
            cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype)
            k_write = apply_rope(k_write, cos, sin)

        # slot dropout
        slot_keys = self.slot_keys
        if self.training and self.slot_dropout > 0.0:
            drop = (torch.rand((H, K), device=x.device) < self.slot_dropout)
            slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1)

        # write logits [B,H,K,T]
        write_logits_raw = torch.einsum("hkd,bhtd->bhkt", slot_keys, k_write) / math.sqrt(d)

        state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype
        write_logits = write_logits_raw.to(state_dtype)

        wtemp = max(1e-6, self.write_temperature)
        write_logits = write_logits / wtemp

        if self.use_alibi_write:
            strength = self._alibi_strength(dtype=state_dtype, device=x.device)
            slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength
            pos = torch.arange(T, device=x.device, dtype=state_dtype)
            write_logits = write_logits + slopes.view(1, H, 1, 1) * pos.view(1, 1, 1, T)

        if attention_mask is not None:
            valid = attention_mask.to(dtype=torch.bool)
            write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf"))
        else:
            valid = None

        content_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device)
        rtemp = max(1e-6, self.read_temperature)

        out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype)

        # write prefix state
        denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
        numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
        m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)

        # slotspace prefix scan state (ONLINE)
        if self.use_slotspace_refine:
            M = self.slotspace_dim
            slotspace_dtype = state_dtype
            S_state = torch.zeros((B, H, M, M), device=x.device, dtype=slotspace_dtype)
            Z_state = torch.zeros((B, H, M), device=x.device, dtype=slotspace_dtype)
            gate = self._slotspace_gate(dtype=state_dtype, device=x.device).to(state_dtype)

            # precompute RoPE cos/sin for slotspace once, slice per chunk (cheap)
            if self.use_rope_slotspace:
                cos_s_full, sin_s_full = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=slotspace_dtype)
            else:
                cos_s_full = sin_s_full = None

            if collect_stats:
                delta_norm_sum = torch.zeros((), device=x.device, dtype=torch.float32)
                delta_norm_count = 0

        if collect_stats:
            entropy_sum = torch.zeros((), device=x.device, dtype=torch.float32)
            top1_sum = torch.zeros((), device=x.device, dtype=torch.float32)
            stat_count = 0

        WRITE_CHUNK = self.write_chunk_size
        kernel = self._compiled if self._compiled is not None else self._write_read_chunk

        for t0 in range(0, T, WRITE_CHUNK):
            t1 = min(T, t0 + WRITE_CHUNK)
            L = t1 - t0

            wlog_c = write_logits[:, :, :, t0:t1]                    # [B,H,K,L]
            v_c    = v_write[:, :, t0:t1, :].to(state_dtype)         # [B,H,L,d]
            q_c    = q_read[:, :, t0:t1, :]                          # [B,H,L,d]

            out_base_c, rw_c, slot_state_t, m_state, denom_state, numer_state = kernel(
                wlog_c, v_c, q_c, slot_keys, content_gamma, rtemp,
                m_state, denom_state, numer_state
            )

            out_c = out_base_c

            # -------------------------
            # ONLINE slotspace refine: scan update + decode + apply on THIS chunk's slot_state_t
            # -------------------------
            if self.use_slotspace_refine:
                # encode read weights into slotspace coords
                u_c = self.slot_in(rw_c.to(slotspace_dtype))      # [B,H,L,M]
                q_s = self.slot_q(u_c)
                k_s = self.slot_k(u_c)
                v_s = self.slot_v(u_c)

                # slotspace RoPE: slice [t0:t1] so positions remain absolute
                if self.use_rope_slotspace:
                    cos_s = cos_s_full[:, :, t0:t1, :]            # [1,1,L,M]
                    sin_s = sin_s_full[:, :, t0:t1, :]
                    q_s = apply_rope(q_s, cos_s, sin_s)
                    k_s = apply_rope(k_s, cos_s, sin_s)

                # optional mask
                if valid is not None:
                    mask_c = valid[:, t0:t1].view(B, 1, L, 1).to(slotspace_dtype)
                    q_s = q_s * mask_c
                    k_s = k_s * mask_c
                    v_s = v_s * mask_c

                qf_c = phi(q_s)
                kf_c = phi(k_s)
                if valid is not None:
                    # phi(0) == 1, so zeroing k_s alone still lets masked
                    # positions add to the normaliser Z_c. Re-mask the
                    # feature-mapped keys so they contribute nothing.
                    kf_c = kf_c * mask_c

                # online prefix scan within chunk
                kv_c = torch.einsum("bhlm,bhln->bhlmn", kf_c, v_s.to(slotspace_dtype))   # [B,H,L,M,M]
                S_c = torch.cumsum(kv_c, dim=2) + S_state.unsqueeze(2)                  # [B,H,L,M,M]
                Z_c = torch.cumsum(kf_c, dim=2) + Z_state.unsqueeze(2)                  # [B,H,L,M]
                Z_c = Z_c.clamp_min(1e-8)

                num = torch.einsum("bhlm,bhlmn->bhln", qf_c, S_c)                        # [B,H,L,M]
                den = torch.einsum("bhlm,bhlm->bhl", qf_c, Z_c).unsqueeze(-1)            # [B,H,L,1]
                u2_c = num / den.clamp_min(1e-8)

                # carry prefix state forward
                S_state = S_c[:, :, -1, :, :]
                Z_state = Z_c[:, :, -1, :]

                # chunkwise dropout (requested)
                u2_c = self.slotspace_dropout(u2_c)

                # decode weights for this chunk only
                slot_w_c = self.slot_out(u2_c)                                          # [B,H,L,K]
                if self.slotspace_signed_weights:
                    slot_w_c = torch.tanh(slot_w_c)
                else:
                    slot_w_c = torch.softmax(slot_w_c, dim=-1)

                # second-order delta through *current* prefix slot state for this chunk
                delta_c = torch.einsum(
                    "bhlk,bhlkd->bhld",
                    slot_w_c.to(state_dtype),
                    slot_state_t,   # already state_dtype
                )

                out_c = out_c + gate * delta_c

                if collect_stats:
                    delta_norm_sum += delta_c.detach().to(torch.float32).norm(dim=-1).sum()
                    delta_norm_count += (B * H * L)

            out_h[:, :, t0:t1, :] = out_c

            if collect_stats:
                # mean read entropy and share of the most frequently chosen slot
                p = rw_c.clamp_min(1e-8)
                ent = -(p * p.log()).sum(dim=-1).mean()
                top = rw_c.argmax(dim=-1).reshape(-1)
                hist = torch.bincount(top, minlength=K).float()
                top1 = (hist.max() / hist.sum().clamp_min(1.0))
                entropy_sum += ent.detach().to(torch.float32)
                top1_sum += top1.detach().to(torch.float32)
                stat_count += 1

        # finish projection
        out = out_h.transpose(1, 2).contiguous().view(B, T, C)
        out = self.out_proj(out)
        out = self.dropout(out)

        info = None
        if return_info:
            info = {}
            if return_light_stats:
                info["alibi_strength_mean"] = self._alibi_strength(dtype=torch.float32, device=x.device).detach().cpu()
                info["content_read_gamma_mean"] = self._content_read_gamma(dtype=torch.float32, device=x.device).detach().cpu()
                info["entropy_mean"] = (entropy_sum / max(1, stat_count)).detach().cpu()
                info["top1freq_mean"] = (top1_sum / max(1, stat_count)).detach().cpu()
                if self.use_slotspace_refine:
                    info["slotspace_gate_mean"] = self._slotspace_gate(dtype=torch.float32, device=x.device).detach().cpu()
                    info["slotspace_delta_norm"] = (delta_norm_sum / max(1, delta_norm_count)).detach().cpu()
            # heavy tensors intentionally omitted

        return out, info

In [None]:

#@title LM and Config defs
# ============================================================================
# Addressed State Models (ASM): Config + Block + LM
# - Naming aligned with paper: slots, read/write, slot-space refinement
# - No compatibility layer (fresh public tooling)
# - Assumes AddressedStateAttention is defined elsewhere (the primitive module)
# ============================================================================

import math
from dataclasses import dataclass
from typing import Tuple, Optional

import torch
import torch.nn as nn


# ============================================================================
# Config
# ============================================================================
@dataclass
class ASMTrainConfig:
    """Hyper-parameters for data, optimisation, model shape and I/O.

    All fields have defaults, so ``ASMTrainConfig()`` is a complete
    configuration. The model-related fields share their names with the
    keyword arguments consumed by ``build_model_from_cfg``.
    """

    # Data
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "gpt2"

    max_seq_len: int = 256
    stride_frac_val: float = 0.50
    seed: int = 1337

    # Sample budgets
    train_samples_target: int = 100_000_000
    val_samples_target: int = 25_000

    # Training
    batch_size: int = 64
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)
    grad_clip: float = 1.0
    warmup_steps: int = 1_000
    total_steps: int = 75_000
    eval_interval: int = 1_000
    log_interval: int = 100

    # Model
    vocab_size: int = 50257
    embed_dim: int = 384
    num_layers: int = 23
    num_heads: int = 8
    num_slots: int = 32
    mlp_ratio: float = 4.0
    dropout: float = 0.1
    tie_weights: bool = True

    # Addressed State Attention (ASA) / numerics
    read_temperature: float = 1.0
    write_temperature: float = 1.0
    slot_dropout: float = 0.05
    state_fp32: bool = True
    normalize_k: bool = False

    # Positions
    use_abs_pos: bool = False
    use_rope_keys: bool = True
    rope_base: float = 10000.0
    use_alibi_write: bool = True
    alibi_strength_init: float = 0.1
    learn_alibi_strength: bool = True
    min_strength: float = 0.0

    # Content-conditioned read term (gamma)
    use_content_read: bool = True
    content_read_init: float = -4.0
    content_read_max_gamma: float = 3.0

    # Optional slot-space refinement (formerly "k-space")
    use_slotspace_refine: bool = True
    slotspace_dim: int = 64
    slotspace_gate_init: float = -4.0
    slotspace_dropout: float = 0.05
    slotspace_signed_weights: bool = True

    # RoPE inside slot-space matcher (Q/K only)
    use_rope_slotspace: bool = True
    rope_base_slotspace: float = 100000.0

    # Perf knobs (behavior-identical)
    write_chunk_size: int = 128
    slotspace_chunk_size: int = 128
    enable_compiled: bool = False

    # Analytics
    eval_max_batches: int = 150
    analytics_last_k: int = 4

    # IO / caches
    output_dir: str = "./drive/MyDrive/asm_outputs"
    tag: str = "asm_wikitext"
    cache_dir: str = "./drive/MyDrive/asm_caches"
    val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"

    @property
    def enable_compile(self) -> bool:
        """Read-only alias of ``enable_compiled``.

        Lets code that reads ``cfg.enable_compile`` work with this config.
        It is a property, not a dataclass field, so the constructor
        signature and ``dataclasses.fields`` are unaffected.
        """
        return self.enable_compiled


# ============================================================================
# Block
# ============================================================================
class ASMBlock(nn.Module):
    """Pre-norm residual block: an ASA sub-layer followed by a GELU MLP.

    ``forward`` computes ``x = x + asa(norm1(x))`` and then
    ``x = x + mlp(norm2(x))``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_slots: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,

        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,

        # positions
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,

        # ALiBi params
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,

        # content-conditioned read (gamma)
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,

        # optional slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,

        # RoPE inside slot-space matcher
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,

        # chunk sizes
        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
        enable_compiled: bool = False,
    ):
        """Build the block.

        ``mlp_ratio`` sets the MLP hidden width to
        ``int(embed_dim * mlp_ratio)`` and ``dropout`` is used in both
        sub-layers. Every other argument except ``slotspace_chunk_size`` is
        passed through to ``AddressedStateAttention``.

        ``slotspace_chunk_size`` is kept in the signature so existing callers
        keep working, but it is not forwarded (see the note in the body).
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)

        # NOTE: ``slotspace_chunk_size`` is deliberately not forwarded. The
        # slot-space scan in AddressedStateAttention runs inside the write
        # chunk loop, so ``write_chunk_size`` is the only chunking knob it
        # uses; passing the extra keyword raises TypeError on a constructor
        # that does not declare it.
        self.asa = AddressedStateAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            dropout=dropout,

            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,

            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,

            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,

            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,

            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,

            write_chunk_size=write_chunk_size,
            enable_compiled=enable_compiled,
        )

        self.norm2 = nn.LayerNorm(embed_dim)
        hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden, bias=False),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, return_info: bool = False):
        """Apply the attention and MLP sub-layers with residual connections.

        Returns ``(x, info)`` where ``info`` is whatever the attention module
        returned as its second element for ``return_info``.
        """
        a, info = self.asa(self.norm1(x), attention_mask=attention_mask, return_info=return_info)
        x = x + a
        x = x + self.mlp(self.norm2(x))
        return x, info


# ============================================================================
# LM
# ============================================================================
class ASMLanguageModel(nn.Module):
    """Token-level language model built from a stack of ``ASMBlock`` layers.

    Pipeline: token embedding (plus an optional learned absolute position
    embedding) -> dropout -> ``num_layers`` blocks -> LayerNorm -> linear
    head over the vocabulary. With ``tie_weights`` the head shares its
    weight with the token embedding.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 384,
        num_layers: int = 6,
        num_heads: int = 8,
        num_slots: int = 8,
        max_seq_len: int = 1024,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,

        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.05,
        normalize_k: bool = False,

        tie_weights: bool = True,

        # LM-level abs pos
        use_abs_pos: bool = False,

        # positions
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,

        # ALiBi
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,

        # content-conditioned read (gamma)
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,

        # optional slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,

        # RoPE inside slot-space matcher
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,

        # chunk sizes
        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
        enable_compile: bool = False,
    ):
        """Build the model.

        ``vocab_size``, ``embed_dim``, ``num_layers``, ``max_seq_len``,
        ``tie_weights`` and ``use_abs_pos`` configure this module; all
        remaining arguments are passed to every ``ASMBlock``.
        ``enable_compile`` is handed to the blocks under their keyword name
        ``enable_compiled``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.use_abs_pos = bool(use_abs_pos)

        self.tok = nn.Embedding(vocab_size, embed_dim)
        self.pos = nn.Embedding(max_seq_len, embed_dim) if self.use_abs_pos else None
        self.drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            ASMBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_slots=num_slots,
                mlp_ratio=mlp_ratio,
                dropout=dropout,

                read_temperature=read_temperature,
                write_temperature=write_temperature,
                state_fp32=state_fp32,
                slot_dropout=slot_dropout,
                normalize_k=normalize_k,

                use_rope_keys=use_rope_keys,
                rope_base=rope_base,
                use_alibi_write=use_alibi_write,

                alibi_strength_init=alibi_strength_init,
                learn_alibi_strength=learn_alibi_strength,
                min_strength=min_strength,

                use_content_read=use_content_read,
                content_read_init=content_read_init,
                content_read_max_gamma=content_read_max_gamma,

                use_slotspace_refine=use_slotspace_refine,
                slotspace_dim=slotspace_dim,
                slotspace_gate_init=slotspace_gate_init,
                slotspace_dropout=slotspace_dropout,
                slotspace_signed_weights=slotspace_signed_weights,
                use_rope_slotspace=use_rope_slotspace,
                rope_base_slotspace=rope_base_slotspace,

                write_chunk_size=write_chunk_size,
                slotspace_chunk_size=slotspace_chunk_size,
                # ASMBlock names this flag ``enable_compiled``; passing it as
                # ``enable_compile`` raises TypeError.
                enable_compiled=enable_compile,
            )
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.tok.weight

        self.apply(self._init)

    def _init(self, m):
        """Initialise one sub-module: N(0, 0.02) for Linear/Embedding weights,
        ones/zeros for LayerNorm weight/bias. Other module types are untouched."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
    ):
        """Compute next-token logits.

        Args:
            input_ids: integer token ids of shape ``[B, T]`` with
                ``T <= max_seq_len``.
            attention_mask: optional mask forwarded unchanged to every block.
            return_info: when true, also return the per-block info objects.

        Returns:
            ``logits`` of shape ``[B, T, vocab_size]``, or ``(logits, infos)``
            with one entry per block when ``return_info`` is true.
        """
        B, T = input_ids.shape
        assert T <= self.max_seq_len, f"T={T} exceeds max_seq_len={self.max_seq_len}"

        x = self.tok(input_ids)
        if self.use_abs_pos:
            pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
            x = x + self.pos(pos)

        x = self.drop(x)

        infos = []
        for blk in self.blocks:
            x, info = blk(x, attention_mask=attention_mask, return_info=return_info)
            if return_info:
                infos.append(info)

        x = self.norm(x)
        logits = self.lm_head(x)
        return (logits, infos) if return_info else logits


# ============================================================================
# Convenience: build model from config
# ============================================================================
def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
    """Construct an ``ASMLanguageModel`` from the model fields of ``cfg``.

    Each constructor argument is read from the config field of the same
    name, with one exception: the model's ``enable_compile`` argument is
    read from ``cfg.enable_compiled``, which is the field name declared on
    ``ASMTrainConfig``.
    """
    return ASMLanguageModel(
        vocab_size=cfg.vocab_size,
        embed_dim=cfg.embed_dim,
        num_layers=cfg.num_layers,
        num_heads=cfg.num_heads,
        num_slots=cfg.num_slots,
        max_seq_len=cfg.max_seq_len,
        mlp_ratio=cfg.mlp_ratio,
        dropout=cfg.dropout,

        read_temperature=cfg.read_temperature,
        write_temperature=cfg.write_temperature,
        state_fp32=cfg.state_fp32,
        slot_dropout=cfg.slot_dropout,
        normalize_k=cfg.normalize_k,

        tie_weights=cfg.tie_weights,

        use_abs_pos=cfg.use_abs_pos,
        use_rope_keys=cfg.use_rope_keys,
        rope_base=cfg.rope_base,
        use_alibi_write=cfg.use_alibi_write,

        alibi_strength_init=cfg.alibi_strength_init,
        learn_alibi_strength=cfg.learn_alibi_strength,
        min_strength=cfg.min_strength,

        use_content_read=cfg.use_content_read,
        content_read_init=cfg.content_read_init,
        content_read_max_gamma=cfg.content_read_max_gamma,

        use_slotspace_refine=cfg.use_slotspace_refine,
        slotspace_dim=cfg.slotspace_dim,
        slotspace_gate_init=cfg.slotspace_gate_init,
        slotspace_dropout=cfg.slotspace_dropout,
        slotspace_signed_weights=cfg.slotspace_signed_weights,
        use_rope_slotspace=cfg.use_rope_slotspace,
        rope_base_slotspace=cfg.rope_base_slotspace,

        write_chunk_size=cfg.write_chunk_size,
        slotspace_chunk_size=cfg.slotspace_chunk_size,
        enable_compile=cfg.enable_compiled,
    )

In [None]:


#@title LM and Config defs efficient
# ============================================================================
# Addressed State Models (ASM): Config + Block + LM
# - Naming aligned with paper: slots, read/write, slot-space refinement
# - No compatibility layer (fresh public tooling)
# - Assumes AddressedStateAttention is defined elsewhere (the primitive module)
# ============================================================================

import math
from dataclasses import dataclass
from typing import Tuple, Optional

import torch
import torch.nn as nn


# ============================================================================
# Config
# ============================================================================
@dataclass
class ASMTrainConfig:
    """Flat configuration covering data, optimisation, model shape and I/O.

    ``train_asm`` reads the data / training / IO fields directly and forwards
    the model fields, unchanged and by the same name, to ``ASMLanguageModel``.
    """
    # Data
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "gpt2"

    # Window length T: every train/val sample is cut as T+1 consecutive tokens
    # (inputs = first T, targets = last T). Also passed to the model as max_seq_len.
    max_seq_len: int = 256
    # Validation windows are spaced max(1, int(T * stride_frac_val)) tokens apart.
    stride_frac_val: float = 0.50
    # Seeds `random`, torch (CPU + CUDA) and the random-window train sampler.
    seed: int = 1337

    # Sample budgets
    # Number of random windows the training stream yields before it stops.
    train_samples_target: int = 100_000_000
    # Upper bound on the number of validation windows that get built.
    val_samples_target: int = 25_000

    # Training
    batch_size: int = 64
    # learning_rate / weight_decay / betas are handed to torch.optim.AdamW.
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)
    # Max gradient norm for nn.utils.clip_grad_norm_.
    grad_clip: float = 1.0
    # warmup_steps / total_steps parameterise the WarmupCosine schedule;
    # total_steps also terminates the training loop.
    warmup_steps: int = 1_000
    total_steps: int = 75_000
    # Run validation / print a training log line every N optimizer steps.
    eval_interval: int = 1_000
    log_interval: int = 100

    # Model
    # presumably the gpt2 tokenizer vocabulary size; verify if tokenizer_name changes
    vocab_size: int = 50257
    embed_dim: int = 384
    num_layers: int = 23
    num_heads: int = 8
    num_slots: int = 32
    # MLP hidden width is int(embed_dim * mlp_ratio).
    mlp_ratio: float = 4.0
    dropout: float = 0.1
    # When True the LM head shares its weight tensor with the token embedding.
    tie_weights: bool = True

    # Addressed State Attention (ASA) / numerics
    # NOTE(review): the ASA fields below are passed straight through to
    # AddressedStateAttention; their exact semantics live in that module.
    read_temperature: float = 1.0
    write_temperature: float = 1.0
    slot_dropout: float = 0.05
    state_fp32: bool = True
    normalize_k: bool = False

    # Positions
    # use_abs_pos adds a learned absolute position embedding at the LM level;
    # the remaining position options are forwarded to the attention module.
    use_abs_pos: bool = False
    use_rope_keys: bool = True
    rope_base: float = 10000.0
    use_alibi_write: bool = True
    alibi_strength_init: float = 0.1
    learn_alibi_strength: bool = True
    min_strength: float = 0.0

    # Content-conditioned read term (gamma)
    use_content_read: bool = True
    content_read_init: float = -4.0
    content_read_max_gamma: float = 3.0

    # Optional slot-space refinement (formerly "k-space")
    use_slotspace_refine: bool = True
    slotspace_dim: int = 64
    slotspace_gate_init: float = -4.0
    slotspace_dropout: float = 0.05
    slotspace_signed_weights: bool = True

    # RoPE inside slot-space matcher (Q/K only)
    use_rope_slotspace: bool = True
    rope_base_slotspace: float = 100000.0

    # Perf knobs (behavior-identical)
    write_chunk_size: int = 128
    enable_compiled: bool = True

    # Analytics
    # Validation stops after this many batches.
    eval_max_batches: int = 150
    # Passed to evaluate() as last_k (size of the trailing window used by the
    # write-weight "last-k mass" statistic).
    analytics_last_k: int = 4

    # IO / caches
    # Checkpoints are written under os.path.join(output_dir, tag).
    output_dir: str = "./drive/MyDrive/asm_outputs"
    tag: str = "asm_wikitext"
    # Directory holding the pickled train/validation token streams.
    cache_dir: str = "./drive/MyDrive/asm_caches"
    # NOTE(review): the file name says "1024" while max_seq_len defaults to 256,
    # and the loader returns an existing cache file without checking its window
    # length -- confirm the cache matches max_seq_len before reusing it.
    val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"


# ============================================================================
# Block
# ============================================================================
class ASMBlock(nn.Module):
    """Pre-norm residual block: an ASA sub-layer followed by a two-layer MLP.

    ``forward`` computes ``x + asa(norm1(x))`` and then adds
    ``mlp(norm2(.))`` on top, returning the new hidden state together with
    whatever ``info`` object the attention module produced.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_slots: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,

        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,

        # positions
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,

        # ALiBi params
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,

        # content-conditioned read (gamma)
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,

        # optional slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,

        # RoPE inside slot-space matcher
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,

        # chunk sizes
        write_chunk_size: int = 128,
        enable_compiled: bool = False,
    ):
        super().__init__()

        # Everything the attention primitive needs, forwarded under the same names.
        asa_options = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            enable_compiled=enable_compiled,
        )
        mlp_width = int(embed_dim * mlp_ratio)

        # Sub-modules are registered in the order norm1, asa, norm2, mlp.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.asa = AddressedStateAttention(**asa_options)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_width, bias=False),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_width, embed_dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, return_info: bool = False, return_light_stats: Optional[bool] = None):
        """Apply the attention and MLP residual branches; return ``(x, info)``."""
        attn_out, info = self.asa(
            self.norm1(x),
            attention_mask=attention_mask,
            return_info=return_info,
            return_light_stats=return_light_stats,
        )
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out, info


# ============================================================================
# LM
# ============================================================================
class ASMLanguageModel(nn.Module):
    """Token-level language model built from a stack of ``ASMBlock`` layers.

    Pipeline: token embedding (+ optional learned absolute positions) ->
    dropout -> ``num_layers`` blocks -> final LayerNorm -> linear LM head.
    With ``tie_weights`` the LM head reuses the token-embedding weight tensor.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 384,
        num_layers: int = 6,
        num_heads: int = 8,
        num_slots: int = 8,
        max_seq_len: int = 1024,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,

        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.05,
        normalize_k: bool = False,

        tie_weights: bool = True,

        # LM-level abs pos
        use_abs_pos: bool = False,

        # positions
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,

        # ALiBi
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,

        # content-conditioned read (gamma)
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,

        # optional slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,

        # RoPE inside slot-space matcher
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,

        # chunk sizes
        write_chunk_size: int = 128,
        enable_compiled: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.use_abs_pos = bool(use_abs_pos)

        # Input side: token table, optional position table, embedding dropout.
        self.tok = nn.Embedding(vocab_size, embed_dim)
        if self.use_abs_pos:
            self.pos = nn.Embedding(max_seq_len, embed_dim)
        else:
            self.pos = None
        self.drop = nn.Dropout(dropout)

        # Every layer is built from the same keyword set.
        block_options = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            enable_compiled=enable_compiled,
        )
        self.blocks = nn.ModuleList(
            [ASMBlock(**block_options) for _ in range(num_layers)]
        )

        # Output side: final norm and vocabulary projection.
        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.tok.weight

        # Re-initialise every Linear / Embedding / LayerNorm in the tree.
        self.apply(self._init)

    def _init(self, m):
        """Per-module initialiser used with ``nn.Module.apply``."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            # Linear biases (if any) are left at their construction-time values.
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        return_light_stats: Optional[bool] = None,
    ):
        """Return logits, or ``(logits, per_layer_infos)`` when ``return_info``.

        ``input_ids`` is unpacked as a 2-D ``(batch, time)`` tensor; ``time``
        must not exceed ``max_seq_len``.
        """
        batch, seq_len = input_ids.shape
        assert seq_len <= self.max_seq_len, f"T={seq_len} exceeds max_seq_len={self.max_seq_len}"

        hidden = self.tok(input_ids)
        if self.use_abs_pos:
            positions = torch.arange(seq_len, device=input_ids.device)
            positions = positions.unsqueeze(0).expand(batch, -1)
            hidden = hidden + self.pos(positions)
        hidden = self.drop(hidden)

        layer_infos = []
        for block in self.blocks:
            hidden, info = block(
                hidden,
                attention_mask=attention_mask,
                return_info=return_info,
                return_light_stats=return_light_stats,
            )
            if return_info:
                layer_infos.append(info)

        logits = self.lm_head(self.norm(hidden))
        if return_info:
            return logits, layer_infos
        return logits


# ============================================================================
# Convenience: build model from config
# ============================================================================
def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
    """Build an ``ASMLanguageModel`` whose constructor arguments come from ``cfg``.

    Each name listed below is looked up on ``cfg`` and forwarded as the
    identically named keyword argument.
    """
    model_fields = (
        # core shape
        "vocab_size",
        "embed_dim",
        "num_layers",
        "num_heads",
        "num_slots",
        "max_seq_len",
        "mlp_ratio",
        "dropout",
        # temperatures / numerics
        "read_temperature",
        "write_temperature",
        "state_fp32",
        "slot_dropout",
        "normalize_k",
        # embeddings
        "tie_weights",
        # positions
        "use_abs_pos",
        "use_rope_keys",
        "rope_base",
        "use_alibi_write",
        # ALiBi
        "alibi_strength_init",
        "learn_alibi_strength",
        "min_strength",
        # content-conditioned read
        "use_content_read",
        "content_read_init",
        "content_read_max_gamma",
        # slot-space refinement
        "use_slotspace_refine",
        "slotspace_dim",
        "slotspace_gate_init",
        "slotspace_dropout",
        "slotspace_signed_weights",
        "use_rope_slotspace",
        "rope_base_slotspace",
        # perf knobs
        "write_chunk_size",
        "enable_compiled",
    )
    return ASMLanguageModel(**{name: getattr(cfg, name) for name in model_fields})

In [None]:
#@title Train, Eval Functions and Utilities

import os, math, random, pickle, time
from dataclasses import asdict
from typing import Tuple, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, IterableDataset, DataLoader
from tqdm.auto import tqdm

from datasets import load_dataset
from transformers import AutoTokenizer


# =========================================================
# Data: cached token streams + val windows + random-window train
# =========================================================
class StableValidationDataset(Dataset):
    """Map-style dataset over a pre-built, in-memory list of (input, target) pairs."""

    def __init__(self, samples: List[Tuple[torch.Tensor, torch.Tensor]]):
        # The list is stored by reference, not copied.
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def _ensure_dir(path: str):
    d = os.path.dirname(path)
    if d:
        os.makedirs(d, exist_ok=True)

def build_or_load_token_stream(
    *,
    cache_path: str,
    dataset_name: str,
    dataset_config: str,
    split: str,
    tokenizer_name: str,
    min_chars: int = 1,
    add_eos_between_rows: bool = True,
    max_rows: Optional[int] = None,
) -> List[int]:
    """Return one flat list of token ids for a dataset split, cached on disk.

    If ``cache_path`` exists it is unpickled and returned as-is (the other
    arguments are not consulted). Otherwise the split is loaded, each row's
    stripped ``"text"`` is tokenized without special tokens, rows shorter than
    ``min_chars`` or producing no tokens are skipped, an EOS id is appended
    after each kept row when ``add_eos_between_rows`` is set, and the result
    is pickled to ``cache_path``. ``max_rows`` caps the number of kept rows.
    """
    # Fast path: reuse a previously pickled stream.
    if os.path.exists(cache_path):
        print(f"Loading cached token stream: {cache_path}")
        with open(cache_path, "rb") as fh:
            cached = pickle.load(fh)
        print(f"Loaded token stream tokens={len(cached):,}")
        return cached

    print(f"Building token stream for {dataset_name}/{dataset_config} split={split} ...")
    _ensure_dir(cache_path)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    eos_id = tokenizer.eos_token_id
    assert eos_id is not None, "Tokenizer must have eos_token_id"

    dataset = load_dataset(dataset_name, dataset_config, split=split)

    tokens: List[int] = []
    rows_used = 0
    for record in tqdm(dataset, desc=f"Tokenizing {split}"):
        if max_rows is not None and rows_used >= max_rows:
            break
        text = (record.get("text") or "").strip()
        if len(text) >= min_chars:
            ids = tokenizer.encode(text, add_special_tokens=False)
            if ids:
                tokens.extend(ids)
                if add_eos_between_rows:
                    tokens.append(eos_id)
                rows_used += 1

    print(f"Built stream: rows_used={rows_used:,} tokens={len(tokens):,}")
    with open(cache_path, "wb") as fh:
        pickle.dump(tokens, fh)
    print(f"Cached token stream to {cache_path}")
    return tokens

def build_or_load_validation_windows(
    *,
    cache_path: str,
    token_stream: List[int],
    max_seq_len: int,
    stride_frac: float,
    val_samples_target: int,
) -> StableValidationDataset:
    """Return a fixed set of validation windows, cached on disk.

    If ``cache_path`` exists its pickled sample list is loaded and returned.
    Otherwise windows of ``max_seq_len + 1`` tokens are cut from
    ``token_stream`` every ``max(1, int(max_seq_len * stride_frac))`` tokens,
    split into ``(x, y)`` = (first T tokens, last T tokens) long tensors, and
    collection stops once ``val_samples_target`` windows exist. The list is
    pickled to ``cache_path`` before being returned.

    Raises:
        ValueError: if ``token_stream`` holds fewer than ``max_seq_len + 1``
            tokens, i.e. not even one window fits.
    """
    T = int(max_seq_len)

    if os.path.exists(cache_path):
        print(f"Loading cached validation windows: {cache_path}")
        with open(cache_path, "rb") as f:
            samples = pickle.load(f)
        print(f"Loaded val windows: {len(samples)}")
        # The cache is keyed by path only; flag windows cut for another length
        # instead of silently evaluating on them.
        if samples and len(samples[0][0]) != T:
            print(
                f"WARNING: cached validation windows have length {len(samples[0][0])} "
                f"but max_seq_len={T}; delete {cache_path} to rebuild them."
            )
        return StableValidationDataset(samples)

    print("Building validation windows (cached)...")
    _ensure_dir(cache_path)

    stride = max(1, int(T * float(stride_frac)))
    need = int(val_samples_target)

    # Largest valid start index. 0 is valid (the stream is exactly one window
    # long), so only a negative value means "too small".
    max_start = len(token_stream) - (T + 1)
    if max_start < 0:
        raise ValueError("Validation token stream too small for max_seq_len+1")

    samples: List[Tuple[torch.Tensor, torch.Tensor]] = []
    # Every start <= max_start yields a full T+1 chunk, so no length check is needed.
    for start in tqdm(range(0, max_start + 1, stride), desc="Chunking val stream"):
        chunk = token_stream[start:start + T + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        samples.append((x, y))
        if len(samples) >= need:
            break

    print(f"Built val windows={len(samples)} (stride={stride}, stream_tokens={len(token_stream):,})")
    with open(cache_path, "wb") as f:
        pickle.dump(samples, f)
    print(f"Cached validation windows to {cache_path}")
    return StableValidationDataset(samples)

class WikiTextRandomWindowStream(IterableDataset):
    """Iterable dataset that yields uniformly random (x, y) windows of a token stream.

    Each sample is a slice of ``max_seq_len + 1`` consecutive tokens starting
    at a random offset; ``x`` is the first T tokens and ``y`` the last T
    (next-token targets), both ``torch.long``.

    Across all DataLoader workers a single pass yields ``train_samples_target``
    samples in total: the target is split between workers, with the remainder
    going to the lowest worker ids. Each worker draws from its own
    ``random.Random(seed + 17 * worker_id)``.

    Raises:
        ValueError: if the stream holds fewer than ``max_seq_len + 1`` tokens.
    """

    def __init__(self, token_stream: List[int], max_seq_len: int, train_samples_target: int, seed: int):
        super().__init__()
        self.stream = token_stream
        self.T = int(max_seq_len)
        self.target = int(train_samples_target)
        self.seed = int(seed)
        # Largest valid start offset; 0 means exactly one window fits, which is fine.
        self.max_start = len(self.stream) - (self.T + 1)
        if self.max_start < 0:
            raise ValueError("Train token stream too small for max_seq_len+1")

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading: this iterator owns the whole budget.
            wid, n_workers = 0, 1
        else:
            wid, n_workers = info.id, info.num_workers

        rng = random.Random(self.seed + 17 * wid)

        # Every worker gets its own copy of the dataset; without this split each
        # of them would emit the full target and the loader would deliver
        # num_workers * target samples.
        quota = self.target // n_workers
        if wid < self.target % n_workers:
            quota += 1

        for _ in range(quota):
            start = rng.randint(0, self.max_start)
            chunk = self.stream[start:start + self.T + 1]
            x = torch.tensor(chunk[:-1], dtype=torch.long)
            y = torch.tensor(chunk[1:], dtype=torch.long)
            yield x, y


# =========================================================
# LR schedule
# =========================================================
class WarmupCosine:
    """Linear warmup followed by cosine decay, written into ``opt.param_groups``.

    For step ``s`` (1-based) the learning rate is ``base * s / warmup`` while
    ``s <= warmup`` and ``base * 0.5 * (1 + cos(pi * p))`` afterwards, where
    ``p`` is the fraction of the post-warmup span completed, capped at 1.

    The constructor immediately writes the step-1 learning rate into the
    optimizer. Without that, a training loop that calls ``opt.step()`` before
    ``sched.step()`` would take its first update at whatever LR the optimizer
    was created with (normally the full ``base_lr``), bypassing warmup.
    """

    def __init__(self, opt, warmup_steps, total_steps, base_lr):
        self.opt = opt
        self.warmup = int(warmup_steps)
        self.total = int(total_steps)
        self.base = float(base_lr)
        self.step_num = 0
        # Prime the optimizer so the very first update already uses the schedule.
        self._apply(self._lr_at(1))

    def _lr_at(self, step_num):
        """Return the scheduled learning rate for 1-based step ``step_num``."""
        if step_num <= self.warmup:
            return self.base * step_num / max(1, self.warmup)
        prog = (step_num - self.warmup) / max(1, (self.total - self.warmup))
        return self.base * 0.5 * (1 + math.cos(math.pi * min(1.0, prog)))

    def _apply(self, lr):
        """Write ``lr`` into every parameter group of the optimizer."""
        for g in self.opt.param_groups:
            g["lr"] = lr

    def step(self):
        """Advance one step, update the optimizer's LR and return it."""
        self.step_num += 1
        lr = self._lr_at(self.step_num)
        self._apply(lr)
        return lr


# =========================================================
# Light metrics helpers (fast + tolerant)
# =========================================================
@torch.no_grad()
def _entropy_mean(read_w: torch.Tensor) -> float:
    # read_w: [B,H,T,K]
    eps = 1e-8
    p = read_w.clamp_min(eps)
    ent = -(p * p.log()).sum(dim=-1)  # [B,H,T]
    return float(ent.mean().detach().cpu().item())

@torch.no_grad()
def _top1freq_mean(read_w: torch.Tensor) -> float:
    # fraction of tokens assigned to most common top1 slot (averaged over batch+heads)
    top1 = read_w.argmax(dim=-1)  # [B,H,T]
    flat = top1.reshape(-1).detach().cpu()
    K = read_w.shape[-1]
    hist = torch.bincount(flat, minlength=K).float()
    denom = hist.sum().clamp_min(1.0)
    return float((hist.max() / denom).item())

@torch.no_grad()
def _write_stats(write_logits: torch.Tensor, last_k: int) -> Tuple[float, float]:
    # write_logits: [B,H,K,T] (already biased/masked)
    w = torch.softmax(write_logits, dim=-1)
    T = w.shape[-1]
    pos = torch.arange(T, device=w.device, dtype=w.dtype).view(1, 1, 1, T)
    com = (w * pos).sum(dim=-1) / max(1.0, float(T - 1))   # [B,H,K] normalized 0..1
    lastk = min(max(1, int(last_k)), T)
    lastk_mass = w[..., -lastk:].sum(dim=-1)               # [B,H,K]
    return float(com.mean().detach().cpu().item()), float(lastk_mass.mean().detach().cpu().item())

@torch.no_grad()
def _get_scalar(info: Dict, key: str) -> Optional[float]:
    v = info.get(key, None)
    if v is None:
        return None
    if isinstance(v, torch.Tensor):
        if v.numel() == 1:
            return float(v.detach().cpu().item())
        return float(v.detach().float().mean().cpu().item())
    if isinstance(v, (int, float)):
        return float(v)
    return None

@torch.no_grad()
def _layer_param_summaries(model) -> Dict[str, float]:
    """
    Truthy parameter summaries:
      - content_read_gamma uses softplus(raw) then clamped to content_read_max_gamma if present
      - slotspace_gate uses softplus(raw)
      - alibi_strength uses module's _alibi_strength when available
    """
    gammas = []
    gates = []
    alibis = []

    for blk in getattr(model, "blocks", []):
        attn = getattr(blk, "asa", None)
        if attn is None:
            continue

        # content read gamma (best-effort)
        if hasattr(attn, "_content_read_gamma_raw"):
            g = F.softplus(attn._content_read_gamma_raw.detach().float())
            mx = getattr(attn, "content_read_max_gamma", None)
            if mx is not None and float(mx) > 0:
                g = g.clamp(max=float(mx))
            gammas.append(float(g.cpu().item()))

        # slot-space gate
        if hasattr(attn, "_slotspace_gate_raw"):
            kg = F.softplus(attn._slotspace_gate_raw.detach().float())
            gates.append(float(kg.cpu().item()))

        # alibi strength (actual used)
        if hasattr(attn, "_alibi_strength"):
            try:
                a = attn._alibi_strength(dtype=torch.float32, device=next(attn.parameters()).device).detach().cpu().item()
                alibis.append(float(a))
            except Exception:
                pass

    out: Dict[str, float] = {}
    if gammas:
        t = torch.tensor(gammas)
        out["content_read_gamma_mean"] = float(t.mean().item())
        out["content_read_gamma_min"]  = float(t.min().item())
        out["content_read_gamma_max"]  = float(t.max().item())
    if gates:
        t = torch.tensor(gates)
        out["slotspace_gate_mean"] = float(t.mean().item())
        out["slotspace_gate_min"]  = float(t.min().item())
        out["slotspace_gate_max"]  = float(t.max().item())
    if alibis:
        t = torch.tensor(alibis)
        out["alibi_strength_mean"] = float(t.mean().item())
        out["alibi_strength_min"]  = float(t.min().item())
        out["alibi_strength_max"]  = float(t.max().item())
    return out


# =========================================================
# Eval
# =========================================================
@torch.no_grad()
def evaluate(model, val_loader, max_batches=50, last_k=8, collect_light_stats=False):
    """Compute mean validation loss, perplexity and diagnostic statistics.

    Reads the module-level globals ``device``, ``amp_dtype`` and ``use_amp``.

    Args:
        model: language model; called as ``model(xb, return_info=...)``.
        val_loader: iterable of ``(inputs, targets)`` batches.
        max_batches: evaluation stops after this many batches.
        last_k: trailing-window size for the write-weight "last-k mass" stat.
        collect_light_stats: when True the model is called with
            ``return_info=True`` and per-layer read/write statistics are
            averaged into the result. Defaults to False, which keeps the
            cheaper logits-only forward pass.
            NOTE(review): which keys the per-layer info dicts carry is decided
            by the attention module -- confirm ``read_weights`` /
            ``write_logits`` are present before relying on those stats.

    Returns:
        ``(mean_loss, perplexity, stats)``. Perplexity is ``exp(min(20, loss))``.
        ``stats`` always attempts to include the parameter summaries from
        ``_layer_param_summaries``.

    The model is put in eval mode for the duration and restored to whatever
    mode it was in beforehand, even if evaluation raises.
    """
    was_training = model.training
    model.eval()

    try:
        losses = []
        ent_acc = 0.0
        top1_acc = 0.0
        com_acc = 0.0
        lastk_acc = 0.0

        delta_acc = 0.0
        delta_n = 0
        n_batches = 0

        for i, (xb, yb) in enumerate(val_loader):
            if i >= max_batches:
                break
            xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)

            infos = None
            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                if collect_light_stats:
                    logits, infos = model(xb, return_info=True)
                else:
                    logits = model(xb, return_info=False)
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), yb.view(-1))
            losses.append(float(loss.item()))

            # Light metrics: layer-avg per batch (only when infos were requested)
            if infos and infos[0] is not None:
                eL, tL, cL, lL = [], [], [], []
                dL = []
                for info in infos:
                    read_w = info.get("read_weights", None)
                    wl = info.get("write_logits", None)

                    if read_w is not None:
                        eL.append(_entropy_mean(read_w))
                        tL.append(_top1freq_mean(read_w))
                    if wl is not None:
                        com, lastm = _write_stats(wl.to(torch.float32), last_k=last_k)
                        cL.append(com)
                        lL.append(lastm)

                    dd = _get_scalar(info, "slotspace_delta_norm")
                    if dd is not None:
                        dL.append(dd)

                if eL:
                    ent_acc += sum(eL) / len(eL)
                    top1_acc += sum(tL) / len(tL)
                if cL:
                    com_acc += sum(cL) / len(cL)
                    lastk_acc += sum(lL) / len(lL)
                if dL:
                    delta_acc += sum(dL) / len(dL)
                    delta_n += 1

                n_batches += 1

        mean = sum(losses) / max(1, len(losses))
        # Cap the exponent so a diverged loss cannot overflow exp().
        ppl = float(math.exp(min(20.0, mean)))

        stats: Dict[str, float] = {}
        if n_batches > 0:
            stats["entropy_mean"] = ent_acc / n_batches
            stats["top1freq_mean"] = top1_acc / n_batches
            stats["write_com_mean"] = com_acc / n_batches
            stats["write_lastk_mass_mean"] = lastk_acc / n_batches
        if delta_n > 0:
            stats["slotspace_delta_norm"] = delta_acc / delta_n

        # Parameter summaries (truthy, cheap). Still best-effort, but a failure
        # is reported instead of vanishing silently.
        try:
            stats.update(_layer_param_summaries(model))
        except Exception as exc:
            print(f"[evaluate] parameter summaries unavailable: {exc!r}")

        return mean, ppl, stats
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)


# =========================================================
# Checkpointing
# =========================================================
def save_ckpt(path, cfg, model, opt, step, best_val):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with keys ``cfg`` (``dataclasses.asdict(cfg)``),
    ``model`` and ``opt`` (their ``state_dict()``), ``step`` and ``best_val``.
    The parent directory is created when ``path`` has one; a bare file name
    is written to the current working directory.
    """
    # os.makedirs("") raises, so only create a directory when one is named.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(
        {
            "cfg": asdict(cfg),
            "model": model.state_dict(),
            "opt": opt.state_dict(),
            "step": step,
            "best_val": best_val,
        },
        path,
    )
    print(f"✓ Saved {path}")


# =========================================================
# Pretty stats formatting (not slow)
# =========================================================
def _fmt_stats(stats: Dict[str, float], last_k: int) -> str:
    keys = [
        "alibi_strength_mean",
        "entropy_mean",
        "top1freq_mean",
        "write_com_mean",
        "write_lastk_mass_mean",
        "content_read_gamma_mean",
        "content_read_gamma_max",
        "slotspace_gate_mean",
        "slotspace_gate_max",
        "slotspace_delta_norm",
    ]
    parts = []
    for k in keys:
        if k in stats:
            if k == "write_lastk_mass_mean":
                parts.append(f"{k}(last_k={last_k})={stats[k]:.4f}")
            else:
                parts.append(f"{k}={stats[k]:.4f}")
    return " | ".join(parts)


# =========================================================
# Train
# =========================================================
def train_asm(cfg: ASMTrainConfig):
    """Train an ASM language model according to ``cfg`` and return it.

    Reads the module-level globals ``device``, ``use_amp`` and ``amp_dtype``.

    Steps:
      1. Seed ``random`` and torch, then build or load the cached train and
         validation token streams and the cached validation windows.
      2. Build the data loaders (random-window training stream with 3
         workers, fixed validation windows).
      3. Construct ``ASMLanguageModel`` from the model fields of ``cfg`` and
         move it to ``device``; create AdamW and the ``WarmupCosine`` schedule.
      4. Run an initial evaluation and save it as ``best.pt`` (step 0).
      5. Loop over training batches: forward under autocast, cross-entropy
         loss, backward, gradient-norm clipping, optimizer and scheduler step.
         Every ``log_interval`` steps a log line is printed; every
         ``eval_interval`` steps validation runs and ``best.pt`` is rewritten
         when the validation loss improves. The loop ends after
         ``total_steps`` steps or when the loader is exhausted.
      6. Save ``final.pt``.

    Checkpoints are written under ``os.path.join(cfg.output_dir, cfg.tag)``.
    """
    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.seed)

    # ---------- Data prep (cached streams) ----------
    os.makedirs(cfg.cache_dir, exist_ok=True)
    train_stream_cache = os.path.join(cfg.cache_dir, f"{cfg.dataset_config}_train_stream.pkl")
    val_stream_cache = os.path.join(cfg.cache_dir, f"{cfg.dataset_config}_val_stream.pkl")

    train_stream = build_or_load_token_stream(
        cache_path=train_stream_cache,
        dataset_name=cfg.dataset_name,
        dataset_config=cfg.dataset_config,
        split="train",
        tokenizer_name=cfg.tokenizer_name,
        min_chars=1,
        add_eos_between_rows=True,
    )
    val_stream = build_or_load_token_stream(
        cache_path=val_stream_cache,
        dataset_name=cfg.dataset_name,
        dataset_config=cfg.dataset_config,
        split="validation",
        tokenizer_name=cfg.tokenizer_name,
        min_chars=1,
        add_eos_between_rows=True,
    )

    val_dataset = build_or_load_validation_windows(
        cache_path=cfg.val_windows_cache,
        token_stream=val_stream,
        max_seq_len=cfg.max_seq_len,
        stride_frac=cfg.stride_frac_val,
        val_samples_target=cfg.val_samples_target,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    train_ds = WikiTextRandomWindowStream(
        token_stream=train_stream,
        max_seq_len=cfg.max_seq_len,
        train_samples_target=cfg.train_samples_target,
        seed=cfg.seed,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=cfg.batch_size,
        num_workers=3,
        pin_memory=torch.cuda.is_available(),
    )

    # ---------- Model ----------
    model = ASMLanguageModel(
        vocab_size=cfg.vocab_size,
        embed_dim=cfg.embed_dim,
        num_layers=cfg.num_layers,
        num_heads=cfg.num_heads,
        num_slots=cfg.num_slots,
        max_seq_len=cfg.max_seq_len,
        mlp_ratio=cfg.mlp_ratio,
        dropout=cfg.dropout,

        read_temperature=cfg.read_temperature,
        write_temperature=cfg.write_temperature,
        state_fp32=cfg.state_fp32,
        slot_dropout=cfg.slot_dropout,
        normalize_k=cfg.normalize_k,

        tie_weights=cfg.tie_weights,

        use_abs_pos=cfg.use_abs_pos,

        use_rope_keys=cfg.use_rope_keys,
        rope_base=cfg.rope_base,
        use_alibi_write=cfg.use_alibi_write,
        alibi_strength_init=cfg.alibi_strength_init,
        learn_alibi_strength=cfg.learn_alibi_strength,
        min_strength=cfg.min_strength,

        use_content_read=cfg.use_content_read,
        content_read_init=cfg.content_read_init,
        content_read_max_gamma=cfg.content_read_max_gamma,

        use_slotspace_refine=cfg.use_slotspace_refine,
        slotspace_dim=cfg.slotspace_dim,
        slotspace_gate_init=cfg.slotspace_gate_init,
        slotspace_dropout=cfg.slotspace_dropout,
        slotspace_signed_weights=cfg.slotspace_signed_weights,

        use_rope_slotspace=cfg.use_rope_slotspace,
        rope_base_slotspace=cfg.rope_base_slotspace,

        write_chunk_size=cfg.write_chunk_size,
        enable_compiled=cfg.enable_compiled,
    ).to(device)

    out_dir = os.path.join(cfg.output_dir, cfg.tag)
    os.makedirs(out_dir, exist_ok=True)

    n_params = sum(p.numel() for p in model.parameters())
    print("=" * 108)
    print(f"Training [{cfg.tag}] on {cfg.dataset_name}/{cfg.dataset_config}")
    print(f"Params: {n_params:,}")
    print(f"Train tokens: {len(train_stream):,} | Val tokens: {len(val_stream):,} | Val windows: {len(val_dataset):,}")
    print(f"T={cfg.max_seq_len} | val_stride_frac={cfg.stride_frac_val} | last_k={cfg.analytics_last_k}")
    print(f"Chunks: write={cfg.write_chunk_size} | amp={use_amp}({amp_dtype}) | state_fp32={cfg.state_fp32}")

    print(f"RoPE: keys={cfg.use_rope_keys}(base={cfg.rope_base:g}) | slotspace={cfg.use_rope_slotspace}(base={cfg.rope_base_slotspace:g})")
    print("=" * 108)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay, betas=cfg.betas)
    sched = WarmupCosine(opt, cfg.warmup_steps, cfg.total_steps, cfg.learning_rate)

    # ---------- Initial eval ----------
    # The untrained model's loss is the baseline every later eval must beat.
    vloss, vppl, vstats = evaluate(model, val_loader, max_batches=cfg.eval_max_batches, last_k=cfg.analytics_last_k)
    best_val = vloss
    save_ckpt(os.path.join(out_dir, "best.pt"), cfg, model, opt, 0, best_val)

    print(f"[VAL step 0] loss={vloss:.3f} ppl={vppl:.2f}")
    if vstats:
        print("  " + _fmt_stats(vstats, last_k=cfg.analytics_last_k))

    # ---------- Training loop ----------
    running = 0.0
    step = 0
    t_last = time.time()

    # Context manager guarantees the progress bar is closed on exit or error.
    with tqdm(total=cfg.total_steps, desc=f"[{cfg.tag}]") as pbar:
        for xb, yb in train_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                logits = model(xb)
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), yb.view(-1))

            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            lr = sched.step()

            step += 1
            running += float(loss.item())
            pbar.update(1)

            if step % cfg.log_interval == 0:
                avg = running / cfg.log_interval
                running = 0.0

                ps = _layer_param_summaries(model)

                it_s = cfg.log_interval / max(1e-9, (time.time() - t_last))
                t_last = time.time()

                postfix = {
                    "loss": f"{avg:.3f}",
                    "ppl": f"{math.exp(min(20.0, avg)):.2f}",
                    "lr": f"{lr:.2e}",
                    "it/s": f"{it_s:.2f}",
                }
                if "content_read_gamma_mean" in ps:
                    postfix["γμ"] = f"{ps['content_read_gamma_mean']:.3f}"
                if "slotspace_gate_mean" in ps:
                    postfix["sgμ"] = f"{ps['slotspace_gate_mean']:.3f}"
                pbar.set_postfix(postfix)

                msg = f"[step {step}] train_loss={avg:.3f} ppl={math.exp(min(20.0, avg)):.2f} lr={lr:.2e} it/s={it_s:.2f}"
                if "content_read_gamma_mean" in ps:
                    msg += f" | content_read_gamma_mean={ps['content_read_gamma_mean']:.4f}"
                if "slotspace_gate_mean" in ps:
                    msg += f" | slotspace_gate_mean={ps['slotspace_gate_mean']:.4f}"
                if "alibi_strength_mean" in ps:
                    msg += f" | alibi_strength_mean={ps['alibi_strength_mean']:.4f}"
                print(msg)

            if step % cfg.eval_interval == 0:
                vloss, vppl, vstats = evaluate(model, val_loader, max_batches=cfg.eval_max_batches, last_k=cfg.analytics_last_k)
                print(f"\n[VAL step {step}] loss={vloss:.3f} ppl={vppl:.2f}")
                if vstats:
                    print("  " + _fmt_stats(vstats, last_k=cfg.analytics_last_k))

                if vloss < best_val:
                    best_val = vloss
                    save_ckpt(os.path.join(out_dir, "best.pt"), cfg, model, opt, step, best_val)

            if step >= cfg.total_steps:
                break

    save_ckpt(os.path.join(out_dir, "final.pt"), cfg, model, opt, step, best_val)
    print(f"[{cfg.tag}] Done. Best val loss: {best_val:.4f}")
    return model

In [None]:

#@title Train, Eval Functions and Utilities — Gradient Accumulation (no checkpointing)

import os, math, random, pickle, time
from dataclasses import asdict
from typing import Tuple, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, IterableDataset, DataLoader
from tqdm.auto import tqdm

from datasets import load_dataset
from transformers import AutoTokenizer


# =========================================================
# Data: cached token streams + val windows + random-window train
# =========================================================
class StableValidationDataset(Dataset):
    """Map-style dataset over a fixed, pre-built list of (input, target) pairs.

    The list is kept by reference (not copied), so indexing follows exactly
    the order of the list that was handed in.
    """

    def __init__(self, samples: List[Tuple[torch.Tensor, torch.Tensor]]):
        self.samples = samples

    def __len__(self):
        """Number of stored pairs."""
        count = len(self.samples)
        return count

    def __getitem__(self, idx):
        """Return the pair stored at position ``idx``."""
        pair = self.samples[idx]
        return pair

def _ensure_dir(path: str):
    """Create the parent directory of ``path`` if it has one.

    A bare filename (empty directory component) is a no-op; an already
    existing directory is not an error.
    """
    parent = os.path.dirname(path)
    if not parent:
        return
    os.makedirs(parent, exist_ok=True)

def build_or_load_token_stream(
    *,
    cache_path: str,
    dataset_name: str,
    dataset_config: str,
    split: str,
    tokenizer_name: str,
    min_chars: int = 1,
    add_eos_between_rows: bool = True,
    max_rows: Optional[int] = None,
) -> List[int]:
    """Return one flat list of token ids for a dataset split, cached on disk.

    If ``cache_path`` exists it is unpickled and returned as-is. None of the
    other arguments are compared against the cached content, so delete the
    file after changing the dataset, tokenizer or filtering options.

    Otherwise the split is loaded with ``load_dataset``, each row's ``"text"``
    field is stripped and encoded without special tokens, and the ids are
    concatenated (optionally followed by the tokenizer's EOS id per row).
    The result is pickled to ``cache_path`` before being returned.

    Args:
        cache_path: Pickle file used as the cache.
        dataset_name: First positional argument to ``load_dataset``.
        dataset_config: Second positional argument to ``load_dataset``.
        split: Dataset split to load.
        tokenizer_name: Name passed to ``AutoTokenizer.from_pretrained``.
        min_chars: Rows whose stripped text is shorter than this are skipped.
        add_eos_between_rows: Append the EOS id after every kept row.
        max_rows: Stop after this many kept rows (``None`` = no limit).

    Returns:
        The token-id stream as a Python list.

    Raises:
        ValueError: If the tokenizer has no ``eos_token_id``.
    """
    if os.path.exists(cache_path):
        print(f"Loading cached token stream: {cache_path}")
        # NOTE: unpickling runs arbitrary code; cache_path must only ever
        # point at a file this function wrote itself.
        with open(cache_path, "rb") as f:
            stream = pickle.load(f)
        print(f"Loaded token stream tokens={len(stream):,}")
        return stream

    print(f"Building token stream for {dataset_name}/{dataset_config} split={split} ...")
    _ensure_dir(cache_path)
    tok = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    eos = tok.eos_token_id
    if eos is None:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError("Tokenizer must have eos_token_id")

    ds = load_dataset(dataset_name, dataset_config, split=split)

    stream: List[int] = []
    used = 0
    for row in tqdm(ds, desc=f"Tokenizing {split}"):
        if max_rows is not None and used >= max_rows:
            break
        text = (row.get("text") or "").strip()
        if len(text) < min_chars:
            continue
        ids = tok.encode(text, add_special_tokens=False)
        if not ids:
            continue
        stream.extend(ids)
        if add_eos_between_rows:
            stream.append(eos)
        used += 1

    print(f"Built stream: rows_used={used:,} tokens={len(stream):,}")
    # Write to a sibling temp file and rename it into place, so an interrupted
    # run cannot leave a truncated pickle that later runs would try to load.
    tmp_path = f"{cache_path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(stream, f)
    os.replace(tmp_path, cache_path)
    print(f"Cached token stream to {cache_path}")
    return stream

def build_or_load_validation_windows(
    *,
    cache_path: str,
    token_stream: List[int],
    max_seq_len: int,
    stride_frac: float,
    val_samples_target: int,
) -> StableValidationDataset:
    """Cut a token stream into fixed-length (input, target) windows, cached on disk.

    If ``cache_path`` exists the pickled windows are returned unchanged. The
    cached windows are not checked against ``max_seq_len`` or the stride, so
    delete the file after changing either.

    Otherwise windows of ``max_seq_len + 1`` tokens are taken every
    ``max(1, int(max_seq_len * stride_frac))`` tokens from the start of the
    stream. Each window yields ``x = chunk[:-1]`` and ``y = chunk[1:]`` as
    ``torch.long`` tensors. Chunking stops once ``val_samples_target`` windows
    exist or the stream runs out of full windows; at least one window is
    always produced.

    Args:
        cache_path: Pickle file used as the cache.
        token_stream: Flat list of token ids.
        max_seq_len: Length ``T`` of each input/target tensor.
        stride_frac: Stride between window starts, as a fraction of ``T``.
        val_samples_target: Maximum number of windows to build.

    Returns:
        A ``StableValidationDataset`` wrapping the list of windows.

    Raises:
        ValueError: If the stream holds fewer than ``max_seq_len + 1`` tokens.
    """
    if os.path.exists(cache_path):
        print(f"Loading cached validation windows: {cache_path}")
        # NOTE: unpickling runs arbitrary code; cache_path must only ever
        # point at a file this function wrote itself.
        with open(cache_path, "rb") as f:
            samples = pickle.load(f)
        print(f"Loaded val windows: {len(samples)}")
        return StableValidationDataset(samples)

    print("Building validation windows (cached)...")
    _ensure_dir(cache_path)

    T = int(max_seq_len)
    stride = max(1, int(T * float(stride_frac)))
    need = int(val_samples_target)

    # Largest start index that still leaves T + 1 tokens. A value of 0 means
    # the stream holds exactly one window, which is valid; only a negative
    # value means the stream is too short.
    max_start = len(token_stream) - (T + 1)
    if max_start < 0:
        raise ValueError("Validation token stream too small for max_seq_len+1")

    samples: List[Tuple[torch.Tensor, torch.Tensor]] = []
    # Every start in this range is <= max_start, so each slice is full length.
    for start in tqdm(range(0, max_start + 1, stride), desc="Chunking val stream"):
        chunk = token_stream[start:start + T + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        samples.append((x, y))
        if len(samples) >= need:
            break

    print(f"Built val windows={len(samples)} (stride={stride}, stream_tokens={len(token_stream):,})")
    # Write to a sibling temp file and rename it into place, so an interrupted
    # run cannot leave a truncated pickle that later runs would try to load.
    tmp_path = f"{cache_path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(samples, f)
    os.replace(tmp_path, cache_path)
    print(f"Cached validation windows to {cache_path}")
    return StableValidationDataset(samples)

class WikiTextRandomWindowStream(IterableDataset):
    """Iterable dataset yielding random fixed-length windows of a token stream.

    Each item is ``(x, y)`` with ``x = chunk[:-1]`` and ``y = chunk[1:]`` for
    a ``max_seq_len + 1`` token chunk starting at a uniformly random offset.
    One pass yields ``train_samples_target`` items in total: with DataLoader
    workers the target is split across them instead of being produced once
    per worker. Every ``__iter__`` call re-seeds its generator from
    ``seed + 17 * worker_id``, so a given worker replays the same sequence
    on each pass.

    Args:
        token_stream: Flat list of token ids.
        max_seq_len: Length ``T`` of each input/target tensor.
        train_samples_target: Total number of windows per pass.
        seed: Base seed for the per-worker random generators.

    Raises:
        ValueError: If the stream holds fewer than ``max_seq_len + 1`` tokens.
    """

    def __init__(self, token_stream: List[int], max_seq_len: int, train_samples_target: int, seed: int):
        super().__init__()
        self.stream = token_stream
        self.T = int(max_seq_len)
        self.target = int(train_samples_target)
        self.seed = int(seed)
        # Largest valid start offset. 0 is fine (exactly one window fits);
        # only a negative value means the stream is too short.
        self.max_start = len(self.stream) - (self.T + 1)
        if self.max_start < 0:
            raise ValueError("Train token stream too small for max_seq_len+1")

    @staticmethod
    def _worker_quota(target: int, worker_id: int, num_workers: int) -> int:
        """Share of ``target`` assigned to one worker; the shares sum to ``target``."""
        base, extra = divmod(max(0, target), max(1, num_workers))
        return base + (1 if worker_id < extra else 0)

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading: this copy produces everything.
            wid, n_workers = 0, 1
        else:
            wid, n_workers = info.id, info.num_workers
        rng = random.Random(self.seed + 17 * wid)

        for _ in range(self._worker_quota(self.target, wid, n_workers)):
            start = rng.randint(0, self.max_start)  # inclusive on both ends
            chunk = self.stream[start:start + self.T + 1]
            x = torch.tensor(chunk[:-1], dtype=torch.long)
            y = torch.tensor(chunk[1:], dtype=torch.long)
            yield x, y


# =========================================================
# LR schedule
# =========================================================
class WarmupCosine:
    """Linear warmup followed by half-cosine decay, applied by hand.

    Every ``step()`` call advances an internal counter, writes the new
    learning rate into each of ``opt.param_groups`` and returns it. The
    optimizer's learning rate is not touched at construction time.
    """

    def __init__(self, opt, warmup_steps, total_steps, base_lr):
        self.opt = opt
        self.warmup = int(warmup_steps)
        self.total = int(total_steps)
        self.base = float(base_lr)
        self.step_num = 0

    def _lr_at(self, n):
        """Learning rate for the ``n``-th call (1-based)."""
        if n > self.warmup:
            # Decay phase: progress through the post-warmup span, capped at 1
            # so the rate stays at its floor past ``total``.
            span = max(1, (self.total - self.warmup))
            frac = (n - self.warmup) / span
            return self.base * 0.5 * (1 + math.cos(math.pi * min(1.0, frac)))
        # Warmup phase: ramps linearly and reaches ``base`` at n == warmup.
        return self.base * n / max(1, self.warmup)

    def step(self):
        """Advance one step, push the new rate to the optimizer, return it."""
        self.step_num += 1
        new_lr = self._lr_at(self.step_num)
        for group in self.opt.param_groups:
            group["lr"] = new_lr
        return new_lr


# =========================================================
# Light metrics helpers (fast + tolerant)
# =========================================================
@torch.no_grad()
def _entropy_mean(read_w: torch.Tensor) -> float:
    """Mean Shannon entropy (in nats) of ``read_w`` along its last dimension.

    Values are floored at 1e-8 before the log so zero weights do not give NaN.
    """
    probs = torch.clamp(read_w, min=1e-8)
    per_row = torch.sum(probs * torch.log(probs), dim=-1).neg()
    return float(per_row.mean().detach().cpu().item())

@torch.no_grad()
def _top1freq_mean(read_w: torch.Tensor) -> float:
    """Fraction of rows whose argmax is the single most frequently chosen index.

    The argmax is taken over the last dimension; all leading dimensions are
    flattened together before counting.
    """
    num_choices = read_w.shape[-1]
    winners = read_w.argmax(dim=-1).reshape(-1).detach().cpu()
    counts = torch.bincount(winners, minlength=num_choices).float()
    # Floor the denominator at 1 so an empty input gives 0 rather than NaN.
    total = counts.sum().clamp_min(1.0)
    share = counts.max() / total
    return float(share.item())

@torch.no_grad()
def _write_stats(write_logits: torch.Tensor, last_k: int) -> Tuple[float, float]:
    """Summarise where softmaxed write weights sit along the last dimension.

    Returns:
        ``(center_of_mass, lastk_mass)``, both averaged over all leading
        dimensions. The centre of mass is the expected position divided by
        ``T - 1`` (so 0 = first position, 1 = last); ``lastk_mass`` is the
        total weight on the final ``last_k`` positions, with ``last_k``
        clamped to ``[1, T]``.
    """
    weights = torch.softmax(write_logits, dim=-1)
    T = weights.shape[-1]

    positions = torch.arange(T, device=weights.device, dtype=weights.dtype).reshape(1, 1, 1, T)
    norm = max(1.0, float(T - 1))  # avoids dividing by zero when T == 1
    center = (weights * positions).sum(dim=-1) / norm

    tail = min(max(1, int(last_k)), T)
    tail_mass = weights[..., -tail:].sum(dim=-1)

    center_mean = float(center.mean().detach().cpu().item())
    tail_mean = float(tail_mass.mean().detach().cpu().item())
    return center_mean, tail_mean

@torch.no_grad()
def _get_scalar(info: Dict, key: str) -> Optional[float]:
    """Fetch ``info[key]`` as a Python float, or ``None`` if it is unusable.

    Single-element tensors are converted directly; larger tensors are reduced
    to their float mean. Plain ints and floats are cast. A missing key, a
    ``None`` value or any other type gives ``None``.
    """
    value = info.get(key, None)
    if isinstance(value, torch.Tensor):
        detached = value.detach()
        reduced = detached if value.numel() == 1 else detached.float().mean()
        return float(reduced.cpu().item())
    if isinstance(value, (int, float)):
        return float(value)
    return None

@torch.no_grad()
def _layer_param_summaries(model) -> Dict[str, float]:
    """Collect mean/min/max of a few per-layer attention scalars.

    Walks ``model.blocks`` (if present) and, for every block exposing an
    ``asa`` attribute, records whichever of these exist on it:

    * ``_content_read_gamma_raw``: softplus of the raw value, clamped to
      ``content_read_max_gamma`` when that attribute is set and positive.
    * ``_slotspace_gate_raw``: softplus of the raw value.
    * ``_alibi_strength(dtype=..., device=...)``: its returned value; any
      exception raised while obtaining it is ignored (best effort).

    Returns:
        A dict with ``<name>_mean`` / ``<name>_min`` / ``<name>_max`` entries
        for each quantity that was found on at least one layer, in the order
        ``content_read_gamma``, ``slotspace_gate``, ``alibi_strength``.
    """
    collected: Dict[str, list] = {
        "content_read_gamma": [],
        "slotspace_gate": [],
        "alibi_strength": [],
    }

    for layer in getattr(model, "blocks", []):
        asa = getattr(layer, "asa", None)
        if asa is None:
            continue

        if hasattr(asa, "_content_read_gamma_raw"):
            gamma = F.softplus(asa._content_read_gamma_raw.detach().float())
            cap = getattr(asa, "content_read_max_gamma", None)
            if cap is not None and float(cap) > 0:
                gamma = gamma.clamp(max=float(cap))
            collected["content_read_gamma"].append(float(gamma.cpu().item()))

        if hasattr(asa, "_slotspace_gate_raw"):
            gate = F.softplus(asa._slotspace_gate_raw.detach().float())
            collected["slotspace_gate"].append(float(gate.cpu().item()))

        if hasattr(asa, "_alibi_strength"):
            try:
                layer_device = next(asa.parameters()).device
                strength = asa._alibi_strength(dtype=torch.float32, device=layer_device)
                collected["alibi_strength"].append(float(strength.detach().cpu().item()))
            except Exception:
                # Deliberately best effort: a layer that cannot report its
                # strength is left out of the summary.
                pass

    summary: Dict[str, float] = {}
    for name, values in collected.items():
        if not values:
            continue
        stacked = torch.tensor(values)
        summary[f"{name}_mean"] = float(stacked.mean().item())
        summary[f"{name}_min"] = float(stacked.min().item())
        summary[f"{name}_max"] = float(stacked.max().item())
    return summary


# =========================================================
# Eval
# =========================================================
@torch.no_grad()
def evaluate(model, val_loader, max_batches=50, last_k=8):
    """Score the model on the first batches of a validation loader.

    Switches the model to eval mode, runs at most ``max_batches`` batches
    under the module-level autocast settings (``device``, ``amp_dtype``,
    ``use_amp``) and switches the model back to train mode before returning.

    Args:
        model: Called as ``model(xb, return_info=False)`` and expected to
            return logits.
        val_loader: Iterable of ``(xb, yb)`` batches.
        max_batches: Upper bound on the number of batches scored.
        last_k: Forwarded to ``_write_stats`` for the last-k write mass.

    Returns:
        ``(mean_loss, ppl, stats)``: the unweighted mean of the per-batch
        cross-entropy losses (0.0 when no batch was scored),
        ``exp(min(20, mean_loss))``, and a dict of analytics. As written the
        dict only ever receives the ``_layer_param_summaries`` entries,
        because per-layer infos are never requested from the model.
    """
    model.eval()

    batch_losses = []
    infos = None  # stays None: the forward pass below uses return_info=False
    totals = {"entropy": 0.0, "top1": 0.0, "com": 0.0, "lastk": 0.0, "delta": 0.0}
    analysed_batches = 0
    delta_batches = 0

    for batch_idx, (xb, yb) in enumerate(val_loader):
        if batch_idx >= max_batches:
            break
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)

        with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            logits = model(xb, return_info=False)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), yb.view(-1))
        batch_losses.append(float(loss.item()))

        # kept as-is: infos is None unless you enable return_info in eval
        if not infos or infos[0] is None:
            continue

        entropies, top1s, coms, lastks, deltas = [], [], [], [], []
        for layer_info in infos:
            read_w = layer_info.get("read_weights", None)
            write_logits = layer_info.get("write_logits", None)

            if read_w is not None:
                entropies.append(_entropy_mean(read_w))
                top1s.append(_top1freq_mean(read_w))
            if write_logits is not None:
                com, last_mass = _write_stats(write_logits.to(torch.float32), last_k=last_k)
                coms.append(com)
                lastks.append(last_mass)

            delta = _get_scalar(layer_info, "slotspace_delta_norm")
            if delta is not None:
                deltas.append(delta)

        # Each batch contributes its across-layer mean to the running totals.
        if entropies:
            totals["entropy"] += sum(entropies) / len(entropies)
            totals["top1"] += sum(top1s) / len(top1s)
        if coms:
            totals["com"] += sum(coms) / len(coms)
            totals["lastk"] += sum(lastks) / len(lastks)
        if deltas:
            totals["delta"] += sum(deltas) / len(deltas)
            delta_batches += 1

        analysed_batches += 1

    mean = sum(batch_losses) / max(1, len(batch_losses))
    # Cap the exponent so a diverged loss cannot overflow the float.
    ppl = float(math.exp(min(20.0, mean)))

    stats: Dict[str, float] = {}
    if analysed_batches > 0:
        stats["entropy_mean"] = totals["entropy"] / analysed_batches
        stats["top1freq_mean"] = totals["top1"] / analysed_batches
        stats["write_com_mean"] = totals["com"] / analysed_batches
        stats["write_lastk_mass_mean"] = totals["lastk"] / analysed_batches
    if delta_batches > 0:
        stats["slotspace_delta_norm"] = totals["delta"] / delta_batches

    try:
        stats.update(_layer_param_summaries(model))
    except Exception:
        # Parameter summaries are optional extras; never fail an eval on them.
        pass

    model.train()
    return mean, ppl, stats


# =========================================================
# Checkpointing (model checkpoints, not activation checkpointing)
# =========================================================
def save_ckpt(path, cfg, model, opt, step, best_val):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The saved dict has the keys ``"cfg"`` (``dataclasses.asdict(cfg)``),
    ``"model"`` and ``"opt"`` (their ``state_dict()``s), ``"step"`` and
    ``"best_val"``.

    Args:
        path: Destination file. Missing parent directories are created; a
            bare filename is written to the current working directory.
        cfg: Dataclass instance holding the run configuration.
        model: Module whose ``state_dict()`` is saved.
        opt: Optimizer whose ``state_dict()`` is saved.
        step: Optimizer-step counter stored alongside the weights.
        best_val: Best validation loss so far.
    """
    # dirname() is "" for a bare filename, and os.makedirs("") raises, so only
    # create the directory when the path actually has one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    payload = {
        "cfg": asdict(cfg),
        "model": model.state_dict(),
        "opt": opt.state_dict(),
        "step": step,
        "best_val": best_val,
    }
    torch.save(payload, path)
    print(f"✓ Saved {path}")


# =========================================================
# Pretty stats formatting (not slow)
# =========================================================
def _fmt_stats(stats: Dict[str, float], last_k: int) -> str:
    """Render a fixed subset of ``stats`` as one ``" | "``-separated line.

    Only the keys listed below are shown, in this order, and only when
    present; every value is printed with four decimals. The last-k write
    mass is labelled with the ``last_k`` it was computed for.
    """
    display_order = (
        "alibi_strength_mean",
        "entropy_mean",
        "top1freq_mean",
        "write_com_mean",
        "write_lastk_mass_mean",
        "content_read_gamma_mean",
        "content_read_gamma_max",
        "slotspace_gate_mean",
        "slotspace_gate_max",
        "slotspace_delta_norm",
    )

    def render(name: str) -> str:
        label = f"{name}(last_k={last_k})" if name == "write_lastk_mass_mean" else name
        return f"{label}={stats[name]:.4f}"

    return " | ".join(render(name) for name in display_order if name in stats)


# =========================================================
# Train (with gradient accumulation)
# =========================================================
def train_asm(cfg: ASMTrainConfig):
    """
    Train an ASMLanguageModel with gradient accumulation and return it.

    Drop-in replacement for the non-accumulating trainer.
    Gradient accumulation is controlled by:
      - cfg.micro_batch_size (DataLoader batch_size)
      - cfg.grad_accum_steps (optimizer steps every N microbatches)

    If these fields are missing on cfg, defaults are used:
      micro_batch_size := cfg.batch_size
      grad_accum_steps := 1

    Flow: seed the RNGs, build or load the cached train/validation token
    streams and validation windows, build the model, run an initial
    evaluation (saved as "best.pt"), then train until cfg.total_steps
    optimizer steps have been taken or the train loader is exhausted.
    Logging, evaluation and LR scheduling are all counted in optimizer
    steps, not microbatches. Writes "best.pt" whenever validation loss
    improves and "final.pt" at the end, under <cfg.output_dir>/<cfg.tag>.

    Relies on the module-level globals `device`, `use_amp` and `amp_dtype`.

    Raises:
      ValueError: if the resolved micro batch size or accumulation count
        is smaller than 1.
    """
    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.seed)

    # ---- accumulation defaults (safe even if cfg lacks the fields) ----
    micro_bs = int(getattr(cfg, "micro_batch_size", cfg.batch_size))
    accum_steps = int(getattr(cfg, "grad_accum_steps", 1))
    # Explicit raises instead of asserts so the checks survive `python -O`.
    if micro_bs < 1:
        raise ValueError(f"micro_batch_size must be >= 1 (got {micro_bs})")
    if accum_steps < 1:
        raise ValueError(f"grad_accum_steps must be >= 1 (got {accum_steps})")

    eff_bs = micro_bs * accum_steps

    # ---------- Data prep (cached streams) ----------
    os.makedirs(cfg.cache_dir, exist_ok=True)
    train_stream_cache = os.path.join(cfg.cache_dir, f"{cfg.dataset_config}_train_stream.pkl")
    val_stream_cache   = os.path.join(cfg.cache_dir, f"{cfg.dataset_config}_val_stream.pkl")

    train_stream = build_or_load_token_stream(
        cache_path=train_stream_cache,
        dataset_name=cfg.dataset_name,
        dataset_config=cfg.dataset_config,
        split="train",
        tokenizer_name=cfg.tokenizer_name,
        min_chars=1,
        add_eos_between_rows=True,
    )
    val_stream = build_or_load_token_stream(
        cache_path=val_stream_cache,
        dataset_name=cfg.dataset_name,
        dataset_config=cfg.dataset_config,
        split="validation",
        tokenizer_name=cfg.tokenizer_name,
        min_chars=1,
        add_eos_between_rows=True,
    )

    val_dataset = build_or_load_validation_windows(
        cache_path=cfg.val_windows_cache,
        token_stream=val_stream,
        max_seq_len=cfg.max_seq_len,
        stride_frac=cfg.stride_frac_val,
        val_samples_target=cfg.val_samples_target,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=micro_bs,
        shuffle=False,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    train_ds = WikiTextRandomWindowStream(
        token_stream=train_stream,
        max_seq_len=cfg.max_seq_len,
        train_samples_target=cfg.train_samples_target,
        seed=cfg.seed,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=micro_bs,
        num_workers=3,
        pin_memory=torch.cuda.is_available(),
    )

    # ---------- Model ----------
    model = ASMLanguageModel(
        vocab_size=cfg.vocab_size,
        embed_dim=cfg.embed_dim,
        num_layers=cfg.num_layers,
        num_heads=cfg.num_heads,
        num_slots=cfg.num_slots,
        max_seq_len=cfg.max_seq_len,
        mlp_ratio=cfg.mlp_ratio,
        dropout=cfg.dropout,

        read_temperature=cfg.read_temperature,
        write_temperature=cfg.write_temperature,
        state_fp32=cfg.state_fp32,
        slot_dropout=cfg.slot_dropout,
        normalize_k=cfg.normalize_k,

        tie_weights=cfg.tie_weights,

        use_abs_pos=cfg.use_abs_pos,

        use_rope_keys=cfg.use_rope_keys,
        rope_base=cfg.rope_base,
        use_alibi_write=cfg.use_alibi_write,
        alibi_strength_init=cfg.alibi_strength_init,
        learn_alibi_strength=cfg.learn_alibi_strength,
        min_strength=cfg.min_strength,

        use_content_read=cfg.use_content_read,
        content_read_init=cfg.content_read_init,
        content_read_max_gamma=cfg.content_read_max_gamma,

        use_slotspace_refine=cfg.use_slotspace_refine,
        slotspace_dim=cfg.slotspace_dim,
        slotspace_gate_init=cfg.slotspace_gate_init,
        slotspace_dropout=cfg.slotspace_dropout,
        slotspace_signed_weights=cfg.slotspace_signed_weights,

        use_rope_slotspace=cfg.use_rope_slotspace,
        rope_base_slotspace=cfg.rope_base_slotspace,

        write_chunk_size=cfg.write_chunk_size,
        enable_compiled=cfg.enable_compiled,
    ).to(device)

    out_dir = os.path.join(cfg.output_dir, cfg.tag)
    os.makedirs(out_dir, exist_ok=True)

    n_params = sum(p.numel() for p in model.parameters())
    print("=" * 108)
    print(f"Training [{cfg.tag}] on {cfg.dataset_name}/{cfg.dataset_config}")
    print(f"Params: {n_params:,}")
    print(f"Train tokens: {len(train_stream):,} | Val tokens: {len(val_stream):,} | Val windows: {len(val_dataset):,}")
    print(f"T={cfg.max_seq_len} | val_stride_frac={cfg.stride_frac_val} | last_k={cfg.analytics_last_k}")
    print(f"Batching: micro_bs={micro_bs} | accum_steps={accum_steps} | effective_bs={eff_bs}")
    print(f"Chunks: write={cfg.write_chunk_size} | amp={use_amp}({amp_dtype}) | state_fp32={cfg.state_fp32}")
    print(f"RoPE: keys={cfg.use_rope_keys}(base={cfg.rope_base:g}) | slotspace={cfg.use_rope_slotspace}(base={cfg.rope_base_slotspace:g})")
    print("=" * 108)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay, betas=cfg.betas)
    sched = WarmupCosine(opt, cfg.warmup_steps, cfg.total_steps, cfg.learning_rate)

    # ---------- Initial eval ----------
    # The untrained model's loss is the baseline later evals must beat.
    vloss, vppl, vstats = evaluate(model, val_loader, max_batches=cfg.eval_max_batches, last_k=cfg.analytics_last_k)
    best_val = vloss
    save_ckpt(os.path.join(out_dir, "best.pt"), cfg, model, opt, 0, best_val)

    print(f"[VAL step 0] loss={vloss:.3f} ppl={vppl:.2f}")
    if vstats:
        print("  " + _fmt_stats(vstats, last_k=cfg.analytics_last_k))

    # ---------- Training loop (grad accumulation) ----------
    model.train()
    opt.zero_grad(set_to_none=True)

    running = 0.0          # for logging (unscaled loss)
    step = 0               # optimizer steps
    micro_step = 0         # microbatches
    t_last = time.time()

    # tqdm counts optimizer steps (matches cfg.total_steps, sched, eval_interval)
    pbar = tqdm(total=cfg.total_steps, desc=f"[{cfg.tag}]")

    try:
        for xb, yb in train_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            micro_step += 1

            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                logits = model(xb)
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), yb.view(-1))

            # log unscaled loss for readability
            running += float(loss.item())

            # scale loss so accumulated grads match a single big batch
            (loss / accum_steps).backward()

            # take an optimizer step every accum_steps microbatches
            if (micro_step % accum_steps) == 0:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                # NOTE(review): the scheduler only writes the LR after
                # opt.step(), so the very first update runs at the
                # optimizer's constructor LR (cfg.learning_rate), not at the
                # warmup value. Confirm whether that is intended.
                opt.step()
                opt.zero_grad(set_to_none=True)
                lr = sched.step()

                step += 1
                pbar.update(1)

                if step % cfg.log_interval == 0:
                    avg = running / (cfg.log_interval * accum_steps)  # avg loss per microbatch (equiv per-sample batch loss)
                    running = 0.0

                    ps = _layer_param_summaries(model)
                    it_s = cfg.log_interval / max(1e-9, (time.time() - t_last))
                    t_last = time.time()

                    postfix = {
                        "loss": f"{avg:.3f}",
                        "ppl": f"{math.exp(min(20.0, avg)):.2f}",
                        "lr": f"{lr:.2e}",
                        "it/s": f"{it_s:.2f}",
                    }
                    if "content_read_gamma_mean" in ps: postfix["γμ"] = f"{ps['content_read_gamma_mean']:.3f}"
                    if "slotspace_gate_mean" in ps: postfix["sgμ"] = f"{ps['slotspace_gate_mean']:.3f}"
                    pbar.set_postfix(postfix)

                    msg = f"[step {step}] train_loss={avg:.3f} ppl={math.exp(min(20.0, avg)):.2f} lr={lr:.2e} it/s={it_s:.2f}"
                    if "content_read_gamma_mean" in ps: msg += f" | content_read_gamma_mean={ps['content_read_gamma_mean']:.4f}"
                    if "slotspace_gate_mean" in ps: msg += f" | slotspace_gate_mean={ps['slotspace_gate_mean']:.4f}"
                    if "alibi_strength_mean" in ps: msg += f" | alibi_strength_mean={ps['alibi_strength_mean']:.4f}"
                    print(msg)

                if step % cfg.eval_interval == 0:
                    vloss, vppl, vstats = evaluate(model, val_loader, max_batches=cfg.eval_max_batches, last_k=cfg.analytics_last_k)
                    print(f"\n[VAL step {step}] loss={vloss:.3f} ppl={vppl:.2f}")
                    if vstats:
                        print("  " + _fmt_stats(vstats, last_k=cfg.analytics_last_k))

                    if vloss < best_val:
                        best_val = vloss
                        save_ckpt(os.path.join(out_dir, "best.pt"), cfg, model, opt, step, best_val)

                if step >= cfg.total_steps:
                    break
    finally:
        # Always release the progress bar, including when training raises.
        pbar.close()

    # If we exit mid-accumulation (micro_step not divisible by accum_steps), you can optionally flush grads.
    # For strict reproducibility, we *do not* flush by default.

    save_ckpt(os.path.join(out_dir, "final.pt"), cfg, model, opt, step, best_val)
    print(f"[{cfg.tag}] Done. Best val loss: {best_val:.4f}")
    return model

In [None]:
#@title train asm func with resume
def train_asm(cfg: ASMTrainConfig):
    """
    Train an ASMLanguageModel with gradient accumulation and optional resume.

    Drop-in replacement for the accumulation-only train_asm(cfg).

    Adds:
      - resume from cfg.resume_path if present and exists
      - correct tqdm initial position
      - scheduler step_num alignment
      - skips "initial eval + best save" when resuming (set
        cfg.eval_on_resume to run a sanity eval anyway)
      - optional periodic "last.pt" every cfg.save_last_every optimizer steps

    Training-loss and it/s logging are computed from the microbatches and
    optimizer steps actually seen since the previous log line, so the first
    window after a resume is reported correctly even when the resumed step
    is not a multiple of cfg.log_interval.

    NOTE: the training window stream is rebuilt from cfg.seed on every call,
    so a resumed run draws the same sequence of windows from its beginning
    rather than continuing where the previous run stopped.

    Required globals (same as the earlier trainers expect):
      device, use_amp, amp_dtype
    Required symbols:
      ASMLanguageModel, ASMTrainConfig, WarmupCosine
      build_or_load_token_stream, build_or_load_validation_windows, WikiTextRandomWindowStream
      evaluate, save_ckpt, _layer_param_summaries, _fmt_stats

    Returns:
      The trained model.

    Raises:
      ValueError: if the resolved micro batch size or accumulation count
        is smaller than 1.
    """
    import os, math, random, time
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from tqdm.auto import tqdm

    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.seed)

    # ---- accumulation defaults (safe even if cfg lacks the fields) ----
    micro_bs = int(getattr(cfg, "micro_batch_size", cfg.batch_size))
    accum_steps = int(getattr(cfg, "grad_accum_steps", 1))
    # Explicit raises instead of asserts so the checks survive `python -O`.
    if micro_bs < 1:
        raise ValueError(f"micro_batch_size must be >= 1 (got {micro_bs})")
    if accum_steps < 1:
        raise ValueError(f"grad_accum_steps must be >= 1 (got {accum_steps})")
    eff_bs = micro_bs * accum_steps

    # ---------- Data prep (cached streams) ----------
    os.makedirs(cfg.cache_dir, exist_ok=True)
    train_stream_cache = os.path.join(cfg.cache_dir, f"{cfg.dataset_config}_train_stream.pkl")
    val_stream_cache   = os.path.join(cfg.cache_dir, f"{cfg.dataset_config}_val_stream.pkl")

    train_stream = build_or_load_token_stream(
        cache_path=train_stream_cache,
        dataset_name=cfg.dataset_name,
        dataset_config=cfg.dataset_config,
        split="train",
        tokenizer_name=cfg.tokenizer_name,
        min_chars=1,
        add_eos_between_rows=True,
    )
    val_stream = build_or_load_token_stream(
        cache_path=val_stream_cache,
        dataset_name=cfg.dataset_name,
        dataset_config=cfg.dataset_config,
        split="validation",
        tokenizer_name=cfg.tokenizer_name,
        min_chars=1,
        add_eos_between_rows=True,
    )

    val_dataset = build_or_load_validation_windows(
        cache_path=cfg.val_windows_cache,
        token_stream=val_stream,
        max_seq_len=cfg.max_seq_len,
        stride_frac=cfg.stride_frac_val,
        val_samples_target=cfg.val_samples_target,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=micro_bs,
        shuffle=False,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    train_ds = WikiTextRandomWindowStream(
        token_stream=train_stream,
        max_seq_len=cfg.max_seq_len,
        train_samples_target=cfg.train_samples_target,
        seed=cfg.seed,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=micro_bs,
        num_workers=3,
        pin_memory=torch.cuda.is_available(),
    )

    # ---------- Model ----------
    model = ASMLanguageModel(
        vocab_size=cfg.vocab_size,
        embed_dim=cfg.embed_dim,
        num_layers=cfg.num_layers,
        num_heads=cfg.num_heads,
        num_slots=cfg.num_slots,
        max_seq_len=cfg.max_seq_len,
        mlp_ratio=cfg.mlp_ratio,
        dropout=cfg.dropout,

        read_temperature=cfg.read_temperature,
        write_temperature=cfg.write_temperature,
        state_fp32=cfg.state_fp32,
        slot_dropout=cfg.slot_dropout,
        normalize_k=cfg.normalize_k,

        tie_weights=cfg.tie_weights,

        use_abs_pos=cfg.use_abs_pos,

        use_rope_keys=cfg.use_rope_keys,
        rope_base=cfg.rope_base,
        use_alibi_write=cfg.use_alibi_write,
        alibi_strength_init=cfg.alibi_strength_init,
        learn_alibi_strength=cfg.learn_alibi_strength,
        min_strength=cfg.min_strength,

        use_content_read=cfg.use_content_read,
        content_read_init=cfg.content_read_init,
        content_read_max_gamma=cfg.content_read_max_gamma,

        use_slotspace_refine=cfg.use_slotspace_refine,
        slotspace_dim=cfg.slotspace_dim,
        slotspace_gate_init=cfg.slotspace_gate_init,
        slotspace_dropout=cfg.slotspace_dropout,
        slotspace_signed_weights=cfg.slotspace_signed_weights,

        use_rope_slotspace=cfg.use_rope_slotspace,
        rope_base_slotspace=cfg.rope_base_slotspace,

        write_chunk_size=cfg.write_chunk_size,
        enable_compiled=cfg.enable_compiled,
    ).to(device)

    out_dir = os.path.join(cfg.output_dir, cfg.tag)
    os.makedirs(out_dir, exist_ok=True)

    n_params = sum(p.numel() for p in model.parameters())
    print("=" * 108)
    print(f"Training [{cfg.tag}] on {cfg.dataset_name}/{cfg.dataset_config}")
    print(f"Params: {n_params:,}")
    print(f"Train tokens: {len(train_stream):,} | Val tokens: {len(val_stream):,} | Val windows: {len(val_dataset):,}")
    print(f"T={cfg.max_seq_len} | val_stride_frac={cfg.stride_frac_val} | last_k={cfg.analytics_last_k}")
    print(f"Batching: micro_bs={micro_bs} | accum_steps={accum_steps} | effective_bs={eff_bs}")
    print(f"Chunks: write={cfg.write_chunk_size} | amp={use_amp}({amp_dtype}) | state_fp32={cfg.state_fp32}")
    print(f"RoPE: keys={cfg.use_rope_keys}(base={cfg.rope_base:g}) | slotspace={cfg.use_rope_slotspace}(base={cfg.rope_base_slotspace:g})")
    print("=" * 108)

    # ---------- Optimizer + Scheduler ----------
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay, betas=cfg.betas)
    sched = WarmupCosine(opt, cfg.warmup_steps, cfg.total_steps, cfg.learning_rate)

    # ---------- Resume (if requested) ----------
    resume_path = getattr(cfg, "resume_path", None)
    start_step = 0
    best_val = float("inf")

    if resume_path is not None and os.path.exists(resume_path):
        print(f"Resuming from checkpoint: {resume_path}")
        ckpt = torch.load(resume_path, map_location=device)
        # Strict load by default to catch config/shape mismatches early
        model.load_state_dict(ckpt["model"], strict=True)
        if "opt" in ckpt and ckpt["opt"] is not None:
            opt.load_state_dict(ckpt["opt"])
        start_step = int(ckpt.get("step", 0))
        best_val = float(ckpt.get("best_val", float("inf")))

        # Align scheduler to resumed optimizer-step count
        sched.step_num = start_step

        # Optional sanity eval on resume (off by default for speed)
        if bool(getattr(cfg, "eval_on_resume", False)):
            vloss, vppl, vstats = evaluate(model, val_loader, max_batches=cfg.eval_max_batches, last_k=cfg.analytics_last_k)
            print(f"[VAL resume@step {start_step}] loss={vloss:.3f} ppl={vppl:.2f}")
            if vstats:
                print("  " + _fmt_stats(vstats, last_k=cfg.analytics_last_k))

    else:
        # ---------- Initial eval (fresh run) ----------
        vloss, vppl, vstats = evaluate(model, val_loader, max_batches=cfg.eval_max_batches, last_k=cfg.analytics_last_k)
        best_val = vloss
        save_ckpt(os.path.join(out_dir, "best.pt"), cfg, model, opt, 0, best_val)

        print(f"[VAL step 0] loss={vloss:.3f} ppl={vppl:.2f}")
        if vstats:
            print("  " + _fmt_stats(vstats, last_k=cfg.analytics_last_k))

    # Optional: periodic "last" checkpoint so you can resume even if best isn't improving
    save_last_every = int(getattr(cfg, "save_last_every", 0))

    # ---------- Training loop (grad accumulation) ----------
    model.train()
    opt.zero_grad(set_to_none=True)

    running = 0.0      # sum of *unscaled* microbatch loss since the last log line
    window_micro = 0   # microbatches accumulated into `running`
    window_steps = 0   # optimizer steps since the last log line
    step = start_step  # optimizer steps completed
    micro_step = 0     # microbatches since (re)start
    t_last = time.time()

    # tqdm counts optimizer steps (matches cfg.total_steps, sched, eval_interval)
    pbar = tqdm(total=cfg.total_steps, initial=step, desc=f"[{cfg.tag}]")

    try:
        for xb, yb in train_loader:
            if step >= cfg.total_steps:
                break

            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            micro_step += 1

            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                logits = model(xb)
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), yb.view(-1))

            running += float(loss.item())
            window_micro += 1
            # scale loss so accumulated grads match a single big batch
            (loss / accum_steps).backward()

            if (micro_step % accum_steps) == 0:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                # NOTE(review): on a fresh run the scheduler only writes the
                # LR after opt.step(), so the very first update runs at the
                # optimizer's constructor LR (cfg.learning_rate), not at the
                # warmup value. Confirm whether that is intended.
                opt.step()
                opt.zero_grad(set_to_none=True)
                lr = sched.step()

                step += 1
                window_steps += 1
                pbar.update(1)

                if step % cfg.log_interval == 0:
                    # Average over what was actually accumulated: after a
                    # resume the first window can be shorter than log_interval.
                    avg = running / max(1, window_micro)
                    now = time.time()
                    it_s = window_steps / max(1e-9, (now - t_last))
                    t_last = now
                    running = 0.0
                    window_micro = 0
                    window_steps = 0

                    ps = _layer_param_summaries(model)

                    postfix = {
                        "loss": f"{avg:.3f}",
                        "ppl": f"{math.exp(min(20.0, avg)):.2f}",
                        "lr": f"{lr:.2e}",
                        "it/s": f"{it_s:.2f}",
                    }
                    if "content_read_gamma_mean" in ps: postfix["γμ"] = f"{ps['content_read_gamma_mean']:.3f}"
                    if "slotspace_gate_mean" in ps: postfix["sgμ"] = f"{ps['slotspace_gate_mean']:.3f}"
                    pbar.set_postfix(postfix)

                    msg = f"[step {step}] train_loss={avg:.3f} ppl={math.exp(min(20.0, avg)):.2f} lr={lr:.2e} it/s={it_s:.2f}"
                    if "content_read_gamma_mean" in ps: msg += f" | content_read_gamma_mean={ps['content_read_gamma_mean']:.4f}"
                    if "slotspace_gate_mean" in ps: msg += f" | slotspace_gate_mean={ps['slotspace_gate_mean']:.4f}"
                    if "alibi_strength_mean" in ps: msg += f" | alibi_strength_mean={ps['alibi_strength_mean']:.4f}"
                    print(msg)

                if step % cfg.eval_interval == 0:
                    vloss, vppl, vstats = evaluate(model, val_loader, max_batches=cfg.eval_max_batches, last_k=cfg.analytics_last_k)
                    print(f"\n[VAL step {step}] loss={vloss:.3f} ppl={vppl:.2f}")
                    if vstats:
                        print("  " + _fmt_stats(vstats, last_k=cfg.analytics_last_k))

                    if vloss < best_val:
                        best_val = vloss
                        save_ckpt(os.path.join(out_dir, "best.pt"), cfg, model, opt, step, best_val)

                if save_last_every > 0 and (step % save_last_every) == 0:
                    save_ckpt(os.path.join(out_dir, "last.pt"), cfg, model, opt, step, best_val)
    finally:
        # Always release the progress bar, including when training raises.
        pbar.close()

    # NOTE: If we exit mid-accumulation (micro_step not divisible by accum_steps), we do NOT flush grads
    # by default, to preserve a clean "optimizer step = accum_steps microbatches" invariant.

    save_ckpt(os.path.join(out_dir, "final.pt"), cfg, model, opt, step, best_val)
    print(f"[{cfg.tag}] Done. Best val loss: {best_val:.4f} | last step: {step}")
    return model

## Train

In [None]:
#@title Launch Run

# =========================================================
# Run
# =========================================================
if __name__ == "__main__":
    # Launch cell: fresh training run of the 12-layer configuration.

    # ---------------------------------------------------------
    # Device + AMP
    # ---------------------------------------------------------
    # These three names are bound at notebook/module scope.
    # NOTE(review): the training loop references `device`, `use_amp` and
    # `amp_dtype` by bare name -- presumably as globals; keep the names stable.
    if torch.cuda.is_available():
        device, use_amp = torch.device("cuda"), True
    else:
        device, use_amp = torch.device("cpu"), False
    amp_dtype = torch.bfloat16  # autocast dtype; only active when use_amp is True

    print("Using device:", device)
    if device.type == "cuda":
        print("GPU:", torch.cuda.get_device_name(0))
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    # ---------------------------------------------------------
    # Run configuration
    # ---------------------------------------------------------
    # NOTE(review): `tag` reads "32h_64sd_64s" while the values below are
    # num_heads=16, slotspace_dim=128, num_slots=32. The tag is left untouched
    # because it looks like it names the output directory -- confirm before
    # renaming.
    cfg = ASMTrainConfig(
        # model size
        embed_dim=1024,
        num_heads=16,
        num_layers=12,
        max_seq_len=1024,

        # slot memory
        num_slots=32,
        slotspace_dim=128,
        use_slotspace_refine=True,
        slotspace_gate_init=-3.0,
        # slotspace_chunk_size=256,  (not passed)
        write_chunk_size=128,
        state_fp32=False,

        # content read path
        use_content_read=True,
        content_read_init=-2.0,
        content_read_max_gamma=3.0,

        # optimisation schedule
        batch_size=4,
        total_steps=200_000,
        warmup_steps=3_000,

        # runtime / bookkeeping
        enable_compiled=False,
        tag="asm_wikitext_1024t_1024d_32h_64sd_64s_128cs_12l",
        val_windows_cache="./drive/MyDrive/asm_nlp/val_cache_wikitext_windows_1024.pkl",
    )

    print("CONFIG:")
    print(cfg)

    # Return value (the trained model) is intentionally not kept in this cell.
    train_asm(cfg)





#

In [None]:



#@title Launch Run

# =========================================================
# Run
# =========================================================
if __name__ == "__main__":
    # Launch cell: resume from a saved checkpoint with gradient accumulation.

    # ---------------------------------------------------------
    # Device + AMP
    # ---------------------------------------------------------
    # These three names are bound at notebook/module scope.
    # NOTE(review): the training loop references `device`, `use_amp` and
    # `amp_dtype` by bare name -- presumably as globals; keep the names stable.
    if torch.cuda.is_available():
        device, use_amp = torch.device("cuda"), True
    else:
        device, use_amp = torch.device("cpu"), False
    amp_dtype = torch.bfloat16  # autocast dtype; only active when use_amp is True

    print("Using device:", device)
    if device.type == "cuda":
        print("GPU:", torch.cuda.get_device_name(0))
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    # ---------------------------------------------------------
    # Run configuration
    # ---------------------------------------------------------
    # NOTE(review): num_layers=15 here, yet `tag` (and the resume directory
    # further down) end in "12l", and the tag also reads "32h_64sd_64s" against
    # num_heads=16 / slotspace_dim=128 / num_slots=32. Confirm the checkpoint
    # being resumed was produced with this architecture.
    cfg = ASMTrainConfig(
        # model size
        embed_dim=1024,
        num_heads=16,
        num_layers=15,
        max_seq_len=1024,

        # slot memory
        num_slots=32,
        slotspace_dim=128,
        use_slotspace_refine=True,
        slotspace_gate_init=-3.0,
        # slotspace_chunk_size=256,  (not passed)
        write_chunk_size=128,
        state_fp32=False,

        # content read path
        use_content_read=True,
        content_read_init=-2.0,
        content_read_max_gamma=3.0,

        # optimisation schedule
        batch_size=2,
        total_steps=50_000,
        warmup_steps=300,

        # runtime / bookkeeping
        enable_compiled=True,
        tag="asm_wikitext_1024t_1024d_32h_64sd_64s_128cs_12l",
        val_windows_cache="./drive/MyDrive/asm_nlp/val_cache_wikitext_windows_1024.pkl",
    )

    # ---------------------------------------------------------
    # Extra attributes attached after construction
    # ---------------------------------------------------------
    # Checkpointing / resume.
    cfg.resume_path = "./drive/MyDrive/asm_outputs/asm_wikitext_1024t_1024d_32h_64sd_64s_128cs_12l/best.pt"
    # The training loop writes last.pt every `save_last_every` optimizer steps
    # when this is > 0, so a run can be resumed even if best.pt stops changing.
    cfg.save_last_every = 500
    # Optional sanity check (presumably runs an evaluation right after loading
    # the checkpoint -- TODO confirm in train_asm).
    cfg.eval_on_resume = True

    # Gradient accumulation: 16 microbatches of 2 per optimizer step.
    cfg.grad_accum_steps = 16
    cfg.micro_batch_size = 2

    # Every optimizer step consumes `grad_accum_steps` microbatches, so the
    # data budget is scaled UP by that factor to cover all `total_steps`
    # optimizer steps. Per the original author's note, train_samples_target is
    # counted in batches yielded by the training iterable -- verify if changed.
    cfg.train_samples_target = cfg.grad_accum_steps * cfg.total_steps

    print("CONFIG:")
    print(cfg)

    model = train_asm(cfg)

#