<a href="https://colab.research.google.com/github/falseywinchnet/PyITD/blob/main/ConvexTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Prepare the Shakespeare dataset for character-level language modeling.
So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
Will save train.bin, val.bin containing the ids, and meta.pkl containing the
encoder and decoder and some other related info.
"""
import os
import pickle
import requests
import numpy as np
import os
from pathlib import Path

try:
    base_dir = Path(__file__).parent
except NameError:
    base_dir = Path(os.getcwd())  # fallback if __file__ is not defined (e.g. in REPL)
# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(base_dir), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# create the train and test splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(base_dir), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(base_dir), 'val.bin'))

# save the meta information as well, to help us encode/decode later
meta = {
    'vocab_size': vocab_size,
    'itos': itos,
    'stoi': stoi,
}
with open(os.path.join(os.path.dirname(base_dir), 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


#if you use my ideas, please credit me, dont just steal
joshuah.rainstar@gmail.com


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from pathlib import Path
from typing import List


import torch, torch.nn as nn, torch.nn.functional as F, math
from typing import Literal

# --------‑‑‑‑ helper ---------------------------------------------------
@torch.jit.script   # optional: proves it scripts
def next_pow_two(x: int) -> int:
    """
    TorchScript‑safe power‑of‑two ceiling using tensor math.
    """
    t = torch.tensor(float(x))            # scalar tensor
    k = torch.ceil(torch.log2(t)).int()   # round‑up exponent
    return int((2 ** k).item())           # convert back to Python int

# ---------------------------------------------------------------------
#  S4DFFT  (augmented)
# ---------------------------------------------------------------------
class S4DFFT(nn.Module):
    """
    Diagonal State‑Space (S4D) layer with length‑agnostic FFT or recurrent scan.

      x : (B,T,D)  ➜  y : (B,T,D)
    """

    def __init__(
        self,
        d_model: int,
        N: int          = 64,          # # diagonal modes
        init: str       = "hippoD",    # 'hippoD' | 'inverse' | 'linear'
        short_thresh: int = 512,       # switch to recurrent if T ≤ this
        tau_min: float  = 1e-4,        # clamp on exp(log_tau)
    ):
        super().__init__()
        assert N % 2 == 0, "N must be even (conjugate pairs)."

        self.d_model, self.N = d_model, N
        self.tau_min = tau_min

        # unconstrained parameters for N/2 distinct modes
        self.log_tau = nn.Parameter(torch.randn(N // 2))
        self.freq    = nn.Parameter(torch.randn(N // 2))
        self.B       = nn.Parameter(torch.randn(N // 2))
        self.C       = nn.Parameter(torch.randn(N // 2))

        # input/output projections
        self.in_proj  = nn.Linear(d_model, N // 2, bias=False)
        self.out_proj = nn.Linear(N // 2, d_model, bias=False)

        # learnable global time‑scale Δt  (log‑domain)
        self.log_dt = nn.Parameter(torch.zeros(()))

        self._init_modes(init)

    # ----- initialisers --------------------------------------------------
    def _init_modes(self, kind: Literal["hippoD", "inverse", "linear"]):
        n = torch.arange(self.N // 2)
        with torch.no_grad():
            self.log_tau.fill_(math.log(0.5))
            if kind == "hippoD":
                self.freq.copy_(math.pi * (2*n + 1) / 2)
            elif kind == "inverse":
                self.freq.copy_((self.N / math.pi) / (2*n + 1))
            elif kind == "linear":
                self.freq.copy_(math.pi * n)
            else:
                raise ValueError(kind)
            nn.init.normal_(self.B,  mean=1.0, std=0.2)
            nn.init.normal_(self.C,  std=1.0 / math.sqrt(self.N/2))

    # ---------------------------------------------------------------------------
    # Real‑only kernel builder
    # ---------------------------------------------------------------------------
    def _kernel_fft(self, T: int):
        """
        Return RFFT(K) where K is the real convolution kernel of length T.
          output: (N, L/2+1) complex
        Everything up to the final rfft is real‑typed.
        """
        L   = next_pow_two(2 * T)

        dt   = torch.exp(self.log_dt)                      # scalar
        tau  = torch.exp(self.log_tau).clamp(min=self.tau_min)   # (N/2,)
        angle = self.freq * dt                                   # (N/2,)

        # |lam|  = exp(-tau*dt)            (real)
        # arg(lam)= angle                  (real)
        lam_mag = torch.exp(-tau * dt)                         # (N/2,)
        log_gain = (self.B.abs() + 1e-9).log() + \
                  (self.C.abs() + 1e-9).log()                 # (N/2,)

        i = torch.arange(T, device=tau.device)                 # (T,)

        # amplitude term  (N/2,T)   — still real
        amp = torch.exp(log_gain[:, None] + i[None] * torch.log(lam_mag)[:, None])

        # phase term
        phase = i[None] * angle[:, None]                       # (N/2,T)

        K_half = amp * torch.cos(phase)                        # (N/2,T) real

        # build full length‑N kernel (conjugate pair ⇒ symmetry in mode index)
        K_full = torch.cat([K_half, K_half.flip(0)], dim=0)     # (N,T) real

        return torch.fft.rfft(K_full, n=L, dim=-1)             # (N,L/2+1) complex

    # ----- forward (FFT or scan) ----------------------------------------
    def forward(self, x: torch.Tensor):
        B, T, _ = x.shape
        x_proj  = self.in_proj(x)                               # (B,T,N/2)
        x_modes = torch.cat([x_proj, x_proj.flip(-1)], dim=-1)  # (B,T,N)  real

        L  = next_pow_two(2 * T)
        Uf = torch.fft.rfft(x_modes, n=L, dim=1).transpose(1, 2)   # (B,N,L/2+1)

        Kf = self._kernel_fft(T)                                   # (N,L/2+1)
        Yf = Uf * Kf[None]                                         # broadcast

        y_modes = torch.fft.irfft(Yf, n=L, dim=2)[..., :T]          # (B,N,T)
        y_modes = y_modes.transpose(1, 2)                          # (B,T,N)
        y       = y_modes[..., : self.N // 2]                       # (B,T,N/2)
        return self.out_proj(y)


class S4PreMix(nn.Module):
    def __init__(self, embed_dim, inner_dim, heads, N_modes=64):
        super().__init__()
        assert inner_dim % heads == 0
        self.heads = heads
        self.d_k   = inner_dim // heads

        # ---- NEW: linear up‑proj to inner_dim -------------------------
        self.up   = nn.Linear(embed_dim, inner_dim, bias=False)

        # ---- S4D operates at reduced width ---------------------------
        self.s4d  = S4DFFT(d_model=inner_dim, N=N_modes)
        self.qkv  = nn.Linear(inner_dim, inner_dim * 3, bias=False)

    def forward(self, x):                       # x: (B,S,E)
        z  = self.s4d(self.up(x))               # (B,S,inner_dim)
        q, k, v = self.qkv(z).chunk(3, dim=-1)  # each (B,S,inner_dim)

        B, S, _ = x.shape
        new_shape = (B, S, self.heads, self.d_k)

        # safe reshape regardless of contiguity
        q = q.reshape(new_shape).transpose(1, 2)   # (B,H,S,d_k)
        k = k.reshape(new_shape).transpose(1, 2)
        v = v.reshape(new_shape).transpose(1, 2)
        return q, k, v
# ----------  Positive weight layer with Hoedt–Klambauer init ----------
# ---------------------------------------------------------------------
# PositiveLinear – strictly‑positive weights with HL init + safe interface
# ---------------------------------------------------------------------
class PositiveLinear(nn.Module):
    def __init__(self, d_in, d_out, bias=True):
        super().__init__()
        self.raw  = nn.Parameter(torch.empty(d_out, d_in))
        self.bias = nn.Parameter(torch.empty(d_out)) if bias else None
        with torch.no_grad():
            nn.init.normal_(self.raw, mean=math.log(math.sqrt(2/d_in)), std=0.2)
            if self.bias is not None: self.bias.zero_()

    @property
    def weight(self):                        # strictly positive
        return F.softplus(self.raw)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

# ---------------- ICNN petal -----------------------------------------
class ICNN(nn.Module):
    def __init__(self, dim, hidden_dims):
        super().__init__()
        layers = [PositiveLinear(dim if i==0 else h, h) for i, h in enumerate(hidden_dims)]
        layers.append(PositiveLinear(hidden_dims[-1], dim))        # keep dimension
        self.layers = nn.ModuleList(layers)

    def forward(self, x):                                          # (..., D)
        z = x
        for layer in self.layers:
            z = F.softplus(layer(z))
        return z

#locally convex gate
class ConvexGate(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, 1, bias=True)

    def forward(self, x):                       # (...,D)
        u = F.softplus(self.lin(x))             # convex, ≥0
        return 1.0 - torch.exp(-u)              # convex, ∈(0,1)


class ConvexGate(nn.Module):
    """
    Convex & bounded gate: g(x) = 1 - exp(-softplus(Wx + b)) ∈ (0,1)
    """
    def __init__(self, in_dim: int):
        super().__init__()
        self.lin = nn.Linear(in_dim, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = F.softplus(self.lin(x))      # convex, ≥ 0
        return 1.0 - torch.exp(-u)       # convex, ∈ (0,1)

class ScalarHull(nn.Module):
    """
    Scalar-valued convex hull with a bounded convex gate.
      x : (..., D)  ->  y : (...,)     (scalar output)
    """
    def __init__(self, in_dim: int, hidden: List[int], petals: int):
        super().__init__()
        # convex ICNN petals
        self.petals = nn.ModuleList(ICNN(in_dim, hidden) for _ in range(petals))
        # convex & bounded gate
        self.gate   = ConvexGate(in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        g      = self.gate(x)                                          # (...,1)
        xg     = x * g                                                 # (...,D)
        scores = [p(xg).mean(-1, keepdim=True) for p in self.petals]    # list of (...,1)
        return torch.logsumexp(torch.cat(scores, dim=-1), dim=-1)      # (...,)

class VectorHull(nn.Module):
    """
    Vector-valued convex hull with a bounded convex gate.
      x : (..., D)  ->  y : (..., D)
    """
    def __init__(self, dim: int, hidden: List[int], petals: int):
        super().__init__()
        # convex ICNN petals
        self.petals = nn.ModuleList(ICNN(dim, hidden) for _ in range(petals))
        # convex & bounded gate
        self.gate   = ConvexGate(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        g    = self.gate(x)                                           # (...,1)
        xg   = x * g                                                 # (...,D)
        outs = torch.stack([p(xg) for p in self.petals], dim=-1)      # (...,D,P)
        return torch.logsumexp(outs, dim=-1)                         # (...,D)

class ValueICNN(nn.Module):
    """
    Produce a vector-valued value embedding via a convex ICNN hull.
      x : (..., embed_dim) -> v : (..., embed_dim)
    """
    def __init__(self, embed_dim: int, hidden_dims: List[int], petals: int):
        super().__init__()
        # a VectorHull already returns a vector of same dim
        self.vh = VectorHull(embed_dim, hidden_dims, petals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, S, E)
        B, S, E = x.shape
        # flatten tokens, run through hull, then restore shape
        v = self.vh(x.reshape(-1, E))    # (B*S, E)
        return v.view(B, S, E)
# --------------- Pairwise Hull Attention -----------------------------
# ---------------------------------------------------------------------
# Pair‑wise Hull attention • with Rotary Positional Embedding (RoPE) and S4D
# ---------------------------------------------------------------------

class LinearPreMix(nn.Module):
    """
    Pure‑attention pre‑projection:
        x  : (B,S,E)
        qkv: (B,S,3·inner_dim)   with a single Linear
    """
    def __init__(self, embed_dim, inner_dim, heads,N_modes):
        super().__init__()
        assert inner_dim % heads == 0
        self.heads = heads
        self.d_k   = inner_dim // heads

        self.up   = nn.Linear(embed_dim, inner_dim,  bias=False)   # optional width lift
        self.qkv  = nn.Linear(inner_dim, inner_dim * 3, bias=False)

    def forward(self, x):               # x:(B,S,E)
        z        = self.up(x)           # (B,S,inner_dim)
        q, k, v  = self.qkv(z).chunk(3, dim=-1)

        B, S, _  = x.shape
        new_shape= (B, S, self.heads, self.d_k)

        q = q.reshape(new_shape).transpose(1,2)  # (B,H,S,d_k)
        k = k.reshape(new_shape).transpose(1,2)
        v = v.reshape(new_shape).transpose(1,2)
        return q, k, v

class ConvexHullMixer(nn.Module):
    """
    Standard multi-head attention mixer:
      y[b,h,i,d] = sum_j softmax( (q·k)/√d + bias )[i,j] · val(v)[b,h,j,d]
    """
    def __init__(self, d_k: int, petals=None):
        super().__init__()
        # we don’t use “petals” any more; just keep a learnable linear for values
        self.val = nn.Linear(d_k, d_k, bias=False)

    def forward(
        self,
        q: torch.Tensor,            # (B, H, S, d_k)
        k: torch.Tensor,            # (B, H, S, d_k)
        v: torch.Tensor,            # (B, H, S, d_k)
        extra_score: torch.Tensor = None  # (H, S, S) or broadcastable
    ) -> torch.Tensor:
        B, H, S, D = q.shape

        # 1) dot-products
        #    (B,H,S,S) = sum over last dim of q * k
        scores = torch.einsum('bhid,bhjd->bhij', q, k) / math.sqrt(D)

        # 2) add any positional / bias term
        if extra_score is not None:
            # extra_score: (H,S,S) → (1,H,S,S)
            scores = scores + extra_score.unsqueeze(0)

        # 3) softmax
        attn = F.softmax(scores, dim=-1)  # over j

        # 4) linear-project values
        v_lin = self.val(v)               # (B,H,S,D)

        # 5) weighted sum
        y = torch.einsum('bhij,bhjd->bhid', attn, v_lin)  # (B,H,S,D)

        return y                                 # (B,H,S,D)


class ConvexPositionalBias(nn.Module):
    """
    Bias(i,j) = - w * |i-j|    with   w ≥ 0  (learned per head)
    Convex in positional indices; monotone non‑increasing with distance.
    """
    def __init__(self, heads):
        super().__init__()
        self.w_raw = nn.Parameter(torch.zeros(heads))   # raw parameter
    def forward(self, S: int):
        device = self.w_raw.device
        w = F.softplus(self.w_raw)                      # (H,)
        pos  = torch.arange(S, device=device, dtype=torch.float32)
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()  # (S,S)
        bias = - w[:, None, None] * dist                # (H,S,S)
        return bias

# --- REPLACED PairwiseHullAttention -----------------------------------
class PairwiseHullAttention(nn.Module):
    def __init__(self, embed_dim, heads, petals,
                 inner_dim=128, use_s4d=False):
        super().__init__()
        assert inner_dim % heads == 0
        self.heads = heads
        self.d_k   = inner_dim // heads

        # q‑k‑v projections (linear or S4D pre‑mix)
        self.pre = (S4PreMix if use_s4d else LinearPreMix)(
            embed_dim, inner_dim, heads, N_modes=64
        )

        # convex mixer
        self.mixer = ConvexHullMixer(self.d_k, petals)
        self.posbias = ConvexPositionalBias(heads)
        self.W_O   = nn.Linear(inner_dim, embed_dim)


    def forward(self, x):                      # x: (B,S,E)
        B, S, _ = x.shape
        device  = x.device

        # 1. q, k, v
        q, k, v = self.pre(x)


        # 3. convex rope
        p_bias = self.posbias(S)    # (H,S,S)

        # ConvexHullMixer expects q,k,v; include bias by
        # adding it to the scalar hull score inside the mixer:
        y = self.mixer(q, k, v, extra_score=p_bias)  # see below

        # 4. reshape & output proj
        out = y.transpose(1, 2).reshape(B, S, -1)   # (B,S,inner_dim)
        return self.W_O(out)

# --------------- Convex Residual Block -------------------------------
class OmniHullBlock(nn.Module):
    """
    Convex GPT block:               y = (1‑α₁)·x + α₁·Attn(LN₁(x))
                                    z = (1‑α₂)·y + α₂·HFF (LN₂(y))
    Both α₁ and α₂ are learned, shared across positions.
    """
    def __init__(self, dim, heads, petals,use_s4d):
        super().__init__()
        self.attn = PairwiseHullAttention(dim, heads, petals,use_s4d=use_s4d)
        self.hff  = VectorHull(dim, [dim*2], petals)

        self.ln1  = nn.LayerNorm(dim)
        self.ln2  = nn.LayerNorm(dim)

        # unconstrained scalars → α ∈ (0,1)
        self.alpha_raw1 = nn.Parameter(torch.zeros(()))
        self.alpha_raw2 = nn.Parameter(torch.zeros(()))

    @staticmethod
    def _mix(x, gx, alpha_raw):
        alpha = F.softplus(alpha_raw) / (1 + F.softplus(alpha_raw))  # sigmoid-softplus
        return (1 - alpha) * x + alpha * gx

    def forward(self, x):
        y = self._mix(x, self.attn(self.ln1(x)), self.alpha_raw1)
        z = self._mix(y, self.hff (self.ln2(y)), self.alpha_raw2)
        return z


# --------------- OmniHull GPT ----------------------------------------
class BigFatMommaGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, depth, heads, petals):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Parameter(torch.randn(1, 1024, embed_dim))  # fixed max length
        self.blocks = nn.ModuleList([
            OmniHullBlock(embed_dim, heads, petals, use_s4d=(i < depth - 1)) for i in range(depth)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, idx):
        B, S = idx.shape
        x = self.token_emb(idx) + self.pos_emb[:, :S, :]
        for blk in self.blocks:
            x = blk(x)
        logits = self.head(self.ln_f(x))
        return logits
#big, correct, and convex everywhere. She's your ideal woman.


In [15]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F



class ICNN(nn.Module):
    """
    Input-Convex Neural Network module.
    """
    def __init__(self, input_dim, hidden_dims):
        super().__init__()
        self.V = nn.Linear(input_dim, hidden_dims[0], bias=True)
        self.U = nn.Linear(hidden_dims[0], hidden_dims[0], bias=True)
        self.hidden = nn.ModuleList(
            nn.Linear(hidden_dims[i], hidden_dims[i+1], bias=True)
            for i in range(len(hidden_dims)-1)
        )
        nn.init.uniform_(self.U.weight, a=0.0, b=0.01)

    def forward(self, x):
        z = F.softplus(self.V(x))
        z = F.softplus(self.U(z) + z)
        for layer in self.hidden:
            z = F.softplus(layer(z))
        return z.mean(dim=-1, keepdim=True)

class HullModule(nn.Module):
    """
    Constructs a convex‐hull via log‐sum‐exp over gated ICNN petals.
    """
    def __init__(self, input_dim, hidden_dims, petals, beta=1.0, tanh_kwargs=None):
        super().__init__()
        # petals: a list of convex ICNN modules
        self.petals = nn.ModuleList([ICNN(input_dim, hidden_dims) for _ in range(petals)])
        # gate: produces a positive scalar per input, using TighteningTanh
        self.gate = nn.Sequential(
            nn.Linear(input_dim, 1, bias=True),
            nn.Softplus(),              # clamp ≥ 0
        )

    def forward(self, x):
        # x: (..., input_dim)
        # 1) compute a positive gate per example
        g = self.gate(x)          # (..., 1), ≥0
        # 2) scale inputs into each petal
        x_scaled = x * g          # broadcasting over last dim
        # 3) compute each petal’s score
        scores = torch.stack([p(x_scaled) for p in self.petals], dim=-1)  # (..., petals)
        # 4) log‐sum‐exp hull
        return  torch.logsumexp(scores, dim=-1, keepdim=False)

class DotSelfAttentionWithRoPE(nn.Module):
    """
    Standard multi-head dot-product self-attention with RoPE.
    """
    def __init__(self, embed_dim: int, heads: int):
        super().__init__()
        assert embed_dim % heads == 0, "embed_dim must be divisible by heads"
        self.h = heads
        self.d_k = embed_dim // heads

        self.W_Q = nn.Linear(embed_dim, embed_dim, bias=True)
        self.W_K = nn.Linear(embed_dim, embed_dim, bias=True)
        self.W_V = nn.Linear(embed_dim, embed_dim, bias=True)
        self.W_O = nn.Linear(embed_dim, embed_dim, bias=True)

    @staticmethod
    def _rope(x, sin, cos):
        # x: (..., D) where D = 2 * (D/2)
        x1, x2 = x[..., ::2], x[..., 1::2]
        y = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return y.flatten(-2)

    @staticmethod
    def _get_sin_cos(seq_len:int, d_half:int, device:torch.device):
        pos  = torch.arange(seq_len, device=device).unsqueeze(1)           # (S,1)
        freq = 1. / (10000 ** (torch.arange(0, d_half, 2, device=device) / d_half))  # (D/2,)
        ang  = pos * freq                                                  # (S,D/2)
        sin, cos = torch.sin(ang), torch.cos(ang)
        # shape (1, S, 1, D/2) so they broadcast over batch & heads
        return sin[None, :, None, :], cos[None, :, None, :]

    def forward(self, x):                         # x: (B, S, E)
        B, S, _ = x.shape

        # 1) project and split heads → (B, S, H, D)
        q = self.W_Q(x).view(B, S, self.h, self.d_k)
        k = self.W_K(x).view(B, S, self.h, self.d_k)
        v = self.W_V(x).view(B, S, self.h, self.d_k)

        # 2) apply RoPE
        sin, cos = self._get_sin_cos(S, self.d_k, x.device)
        q = self._rope(q, sin, cos)
        k = self._rope(k, sin, cos)

        # 3) reshape for matmul: (B, H, S, D)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # 4) scaled dot-product: (B, H, S, S)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 5) softmax and attention output: (B, H, S, D)
        attn = F.softmax(scores, dim=-1)
        out  = torch.matmul(attn, v)

        # 6) merge heads back → (B, S, E) and final projection
        out = out.permute(0, 2, 1, 3).reshape(B, S, self.h * self.d_k)
        return self.W_O(out)

class OmniHullBlock(nn.Module):
    def __init__(self, embed_dim, heads, petals):
        super().__init__()
        self.attn = DotSelfAttentionWithRoPE(embed_dim, heads)
        self.hull_mlp = HullModule(embed_dim, [embed_dim * 2, embed_dim], petals)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.hull_mlp(self.ln2(x))
        return x

class SkinnyLittleGPT(nn.Module):
    """
    NanoGPT-style model with recursive convex-hull ICNN modules and RoPE.
    """
    def __init__(self, vocab_size, embed_dim, depth, heads, petals):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.blocks = nn.ModuleList([
            OmniHullBlock(embed_dim, heads, petals)
            for _ in range(depth)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, idx):

        x = self.token_emb(idx)
        for block in self.blocks:
            x = block(x)
        logits = self.head(self.ln_f(x))
        return logits

# shes got a flat butt but shes faster

In [None]:
#currently, skinnylittlegpt is in mode collapse.
#maybe her butt just isnt big enough. maybe she isnt done being baked yet.
#shes not really convex, but, she has some convexity.
#bigfatmommy is very very convex. in many ways. rounded. Rounded lumps.
#perfectly so? no. But there are tradeoffs involved.
#mommy also collapses almost immediately, outputting nothing but spaces.
#am i pushing her too hard?

In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
device    = 'cuda' if torch.cuda.is_available() else 'cpu'

@torch.jit.script
def wolf_update(p: torch.Tensor,
                g: torch.Tensor,
                state_p: torch.Tensor,
                lr: float):
    # define your constants here instead of capturing them
    etcerta: float = 0.367879441
    et:      float = 1.0 - etcerta

    # same logic as before
    update    = state_p * et + g * etcerta
    new_state = state_p * et + update * etcerta
    sign_agree = torch.sign(update) * torch.sign(g)
    update    = update + (torch.rand_like(update)*2 - 1) * etcerta * update
    p_new     = torch.where(sign_agree > 0, p - lr * update, p)
    return p_new, new_state

class Wolf(Optimizer):
    def __init__(self, params, lr=1e-3):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                self.state[p]['p'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                state_p = self.state[p]['p']
                p_new, new_state = wolf_update(p.data, p.grad, state_p, lr)
                p.data.copy_(p_new)
                state_p.copy_(new_state)
        return loss

# 1) Load data and meta as before
data_dir  = os.path.dirname(base_dir)
train_ids = np.fromfile(os.path.join(data_dir, 'train.bin'), dtype=np.uint16)
val_ids   = np.fromfile(os.path.join(data_dir, 'val.bin'),   dtype=np.uint16)
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# 2) Compute data‐marginal q[v]
counts = np.bincount(train_ids, minlength=vocab_size).astype(float)
q = torch.tensor(counts / counts.sum(), dtype=torch.float32, device=device)  # [V]

# 3) Dataset + DataLoader
class CharDataset(Dataset):
    def __init__(self, data, block_size):
        self.data = torch.from_numpy(data).long()
        self.block_size = block_size
    def __len__(self):
        return len(self.data) - self.block_size
    def __getitem__(self, idx):
        x = self.data[idx : idx + self.block_size]
        y = self.data[idx + 1 : idx + self.block_size + 1]
        return x, y

block_size = 128
train_loader = DataLoader(CharDataset(train_ids, block_size),
                          batch_size=8, shuffle=True, drop_last=True)
val_loader   = DataLoader(CharDataset(val_ids,   block_size),
                          batch_size=8, shuffle=False, drop_last=True)

# 4) Model, optimizer, loss
virgin = BigFatMommaGPT(
        vocab_size = vocab_size,
        embed_dim  = 384,
        depth      = 4,
        heads      = 4,
        petals     = 2
)

# (Optional) Re‑initialise *only* the PositiveLinear layers:
print("Number of parameters: ", sum(p.numel() for p in virgin.parameters()))
model = torch.jit.script(virgin)
model.to(device)
optimizer = Wolf(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

λ_ent         = 0.5    # encourage weight diversification
λ_kl          = 0.5     # discourage mode collapse

# 6) Train / eval functions
def train_epoch():
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)

        # Forward
        logits = model(xb)                 # (B, T, V)
        B, T, V = logits.shape
        p = F.softmax(logits, dim=-1)      # (B, T, V)

        # 1) Standard CE
        ce_loss = criterion(logits.view(B*T, V),
                            yb.view(B*T))

        # 2) Entropy penalty
        ent = -(p * (p + 1e-12).log()).sum(dim=-1)  # (B, T)
        ent_loss = ent.mean()

        p_m = p.mean(dim=(0,1))            # [V]
        kl_loss = (p_m * (p_m + 1e-12).log() - p_m * q.log()).sum()

        # 4) Combined loss
        loss = ce_loss + λ_ent * ent_loss + λ_kl * kl_loss

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.item())
        total_loss += loss.item()
    return total_loss / len(train_loader)

@torch.no_grad()
def eval_epoch():
    model.eval()
    total_loss = 0
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        B, T, V = logits.shape
        total_loss += criterion(logits.view(B*T,V),
                                yb.view(B*T)).item()
    return total_loss / len(val_loader)

# 7) Run training
num_epochs = 10
for epoch in range(1, num_epochs+1):
    train_loss = train_epoch()
    val_loss   = eval_epoch()
    print(f"Epoch {epoch:2d} | train: {train_loss:.4f} | val: {val_loss:.4f}")


Number of parameters:  5799839
7.089759349822998
7.059626579284668
7.024545192718506
7.020734786987305
6.986845016479492
6.977483749389648
6.955672740936279
6.915209770202637
6.830282211303711
6.808751106262207
6.788784027099609
6.764035701751709
6.756332874298096
6.671619415283203
6.662999629974365
6.598953723907471
6.588779926300049
6.520735263824463
6.512465953826904
6.502041339874268


In [27]:
optimizer = Wolf(model.parameters(), lr=0.5)#adam would explode at this training rate
criterion = nn.CrossEntropyLoss()

λ_ent         = 0.5    # encourage weight diversification
λ_kl          = 0.5     # discourage mode collapse

# 6) Train / eval functions
def train_epoch():
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)

        # Forward
        logits = model(xb)                 # (B, T, V)
        B, T, V = logits.shape
        p = F.softmax(logits, dim=-1)      # (B, T, V)

        # 1) Standard CE
        ce_loss = criterion(logits.view(B*T, V),
                            yb.view(B*T))

        # 2) Entropy penalty
        ent = -(p * (p + 1e-12).log()).sum(dim=-1)  # (B, T)
        ent_loss = ent.mean()

        p_m = p.mean(dim=(0,1))            # [V]
        kl_loss = (p_m * (p_m + 1e-12).log() - p_m * q.log()).sum()

        # 4) Combined loss
        loss = ce_loss + λ_ent * ent_loss + λ_kl * kl_loss

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.item())
        total_loss += loss.item()
    return total_loss / len(train_loader)

@torch.no_grad()
def eval_epoch():
    model.eval()
    total_loss = 0
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        B, T, V = logits.shape
        total_loss += criterion(logits.view(B*T,V),
                                yb.view(B*T)).item()
    return total_loss / len(val_loader)

# 7) Run training
num_epochs = 10
for epoch in range(1, num_epochs+1):
    train_loss = train_epoch()
    val_loss   = eval_epoch()
    print(f"Epoch {epoch:2d} | train: {train_loss:.4f} | val: {val_loss:.4f}")

0.14970003068447113
0.10967803746461868
0.116947703063488
0.1101749986410141
0.10924039781093597
0.1043446809053421
0.11175236850976944
0.09217088669538498
0.09896231442689896
0.09130249172449112
0.09938418865203857
0.09552960842847824
0.09382826834917068
0.10704723000526428
0.12667648494243622
0.08107704669237137
0.08900757879018784
0.10800237953662872
0.08867326378822327
0.10503402352333069
0.09006978571414948
0.09783803671598434
0.10546399652957916
0.1254078894853592
0.0985318124294281
0.09559228271245956
0.08799614757299423
0.07349701970815659


KeyboardInterrupt: 

In [28]:
# --- helpers ---------------------------------------------------------
def fenchel_decode(logits, tau=1.0, iters=3):
    """Fenchel‑dual KL‑regularised projection of -logits (energy)."""
    energy = -logits                        # (B,V)
    p = torch.full_like(energy, 1.0 / energy.size(-1))  # uniform start
    for _ in range(iters):
        p = torch.softmax((-energy / tau) + p.log(), dim=-1)
    return p

# --- generation ------------------------------------------------------
use_fenchel   = False          # flip to compare
tau           = 1.0           # λ  (temperature analogue)
max_new_tokens = 200
top_k          = 25
block_size     = 128
temperature    = 1.0

bcontext_str = "To be, or not to be,"
context_ids = torch.tensor([[ stoi[c] for c in bcontext_str ]],
                           dtype=torch.long)
context_ids = context_ids.to(device)

generated = context_ids.clone()  # (1,T0)

for _ in range(max_new_tokens):
    input_ids = generated[:, -block_size:]        # casual block
    logits = model(input_ids)                     # (1,cur_T,V)
    logits = logits[:, -1, :] / temperature       # (1,V)

    # top‑k mask
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[:, [-1]]] = -1e10

    if use_fenchel:
        probs = fenchel_decode(logits, tau=tau, iters=3)
    else:
        probs = torch.softmax(logits, dim=-1)

    next_id = torch.multinomial(probs, num_samples=1)   # (1,1)
    generated = torch.cat([generated, next_id], dim=1)

print('> ', ''.join(itos[i] for i in generated[0].tolist()))

>  To be, or not to be,We,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, d,,                                          
