In [1]:
"""
Prepare the Shakespeare dataset for character-level language modeling.
So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
Will save train.bin, val.bin containing the ids, and meta.pkl containing the
encoder and decoder and some other related info.
"""
import os
import pickle
import requests
import numpy as np
import os
from pathlib import Path

# Resolve the directory this code runs from; notebooks/REPLs have no __file__.
try:
    base_dir = Path(__file__).parent
except NameError:
    base_dir = Path(os.getcwd())  # fallback if __file__ is not defined (e.g. in REPL)

# download the tiny shakespeare dataset
# NOTE: files live in the *parent* of base_dir; the training cell reads
# train.bin / val.bin / meta.pkl from the same os.path.dirname(base_dir).
input_file_path = os.path.join(os.path.dirname(base_dir), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Fetch first and fail loudly on HTTP errors, so a failed download never
    # leaves an empty or garbage input.txt that later runs would treat as valid.
    response = requests.get(data_url, timeout=60)
    response.raise_for_status()
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)

# Explicit encoding keeps the character set independent of the platform default.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# collect the distinct characters of the corpus, in sorted order
chars = sorted(set(data))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# lookup tables: character -> integer id, and integer id -> character
stoi = dict(zip(chars, range(vocab_size)))
itos = dict(enumerate(chars))
def encode(s):
    """Encoder: take a string, output the list of integer ids looked up in `stoi`."""
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids
def decode(l):
    """Decoder: take a list of integer ids, output the string looked up in `itos`."""
    return ''.join(itos[token_id] for token_id in l)

# create the train and test splits: first 90% of characters train, remainder validation
n = len(data)
split_at = int(n*0.9)
train_data, val_data = data[:split_at], data[split_at:]

# encode both to integers
train_ids, val_ids = encode(train_data), encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files (raw uint16 ids, written next to input.txt)
out_dir = os.path.dirname(base_dir)
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(out_dir, 'train.bin'))
val_ids.tofile(os.path.join(out_dir, 'val.bin'))

# save the meta information as well, to help us encode/decode later
meta = dict(vocab_size=vocab_size, itos=itos, stoi=stoi)
with open(os.path.join(out_dir, 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


In [320]:
#copyright joshuah rainstar Joshuah.rainstar@gmail.com 2025
#covered under Gratis Public License

import math

import numpy as np
import scipy.fftpack
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch_mice import (
    BatchedICNN,
    PositiveLinearHK,
    BatchAffineNorm,
)
def dct_basis(L, k):
    """
    Return the first `k` rows of the orthonormal DCT of the L x L identity
    as a float32 tensor of shape (k, L).

    NOTE(review): scipy transforms along the last axis by default, so row i
    is the transform of unit vector i. If the k lowest-frequency basis
    functions were intended, the result may need transposing -- confirm.
    """
    transformed = scipy.fftpack.dct(np.eye(L), norm='ortho')
    return torch.tensor(transformed[:k], dtype=torch.float32)
    
class BatchedSingleICNN(nn.Module):
    """
    Thin adapter that drives a single-petal ``BatchedICNN`` with (B, S, D) input.

    The input is flattened to (N = B*S, D), given a leading petal axis of
    size one, passed through the ICNN, and reshaped back to (B, S, D_out).
    """
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.icnn = BatchedICNN(in_dim, petals=1, out_dim=out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, S, D_in), where D_in must equal ``self.in_dim``.
        Returns:
            output: (B, S, D_out)
        """
        batch, steps, feat = x.shape
        assert feat == self.in_dim, f"Expected last dim {self.in_dim}, got {feat}"

        flat = x.reshape(batch * steps, feat)         # (N, D)
        # Per the original author's note, BatchedICNN takes ((P, N, D), (N, D))
        # and returns (N, P, D_out); torch_mice is not visible here -- confirm.
        per_petal = self.icnn(flat.unsqueeze(0), flat)

        # Drop the size-one petal axis and restore the batch/sequence layout.
        return per_petal.squeeze(1).reshape(batch, steps, self.out_dim)

        
# === VAE-based Query Generator ===
class DisentangledQueryGenerator(nn.Module):
    """
    Samples a query tensor from a hash vector via two Gaussian latents.

    A static latent z_f (one per sample) and a dynamic latent z_t (one per
    position) are drawn with the reparameterisation trick, concatenated per
    position and decoded by a ``BatchedSingleICNN`` into Q of shape (B, L, D).
    """
    def __init__(self, hash_dim, static_dim, dynamic_dim, out_shape):
        super().__init__()
        self.out_shape = out_shape  # (L, D)
        self.L, self.D = out_shape
        self.static_dim = static_dim
        self.dynamic_dim = dynamic_dim

        # q(z_f | h): emits mean and log-variance, hence static_dim * 2 outputs
        self.encoder_f = nn.Sequential(
            nn.Linear(hash_dim, 64),
            nn.ReLU(),
            nn.Linear(64, static_dim * 2),
        )

        # q(z_t | h): mean and log-variance for every one of the L positions
        self.encoder_t = nn.Sequential(
            nn.Linear(hash_dim, self.L * dynamic_dim * 2),
        )

        # Decoder maps (z_f ⊕ z_t) → Q_t
        self.decoder = BatchedSingleICNN(in_dim=static_dim + dynamic_dim, out_dim=self.D)

    @staticmethod
    def _reparameterize(mu, logvar):
        """Draw mu + eps * std with eps ~ N(0, I) and std = exp(0.5 * logvar)."""
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    @staticmethod
    def _clamped_kl(mu, logvar):
        """Per-dimension KL to N(0, I), floored at 0.2, summed over the last axis, then averaged."""
        per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        return torch.clamp(per_dim, min=0.2).sum(dim=-1).mean()

    def forward(self, h):  # h: (B, hash_dim)
        """Return (Q, kl): Q is (B, L, D); kl is the scalar sum of both clamped KL terms."""
        batch = h.size(0)
        seq_len, _ = self.out_shape

        # Static latent first, so random draws happen in the same order as before.
        mu_f, logvar_f = self.encoder_f(h).chunk(2, dim=-1)        # (B, D_f) each
        z_f = self._reparameterize(mu_f, logvar_f)                 # (B, D_f)

        # Dynamic latent, one per position.
        stats_t = self.encoder_t(h).view(batch, seq_len, 2 * self.dynamic_dim)
        mu_t, logvar_t = stats_t.chunk(2, dim=-1)                  # (B, L, D_t) each
        z_t = self._reparameterize(mu_t, logvar_t)                 # (B, L, D_t)

        # Broadcast the static latent over positions and join with the dynamic one.
        z_cat = torch.cat([z_f.unsqueeze(1).expand(-1, seq_len, -1), z_t], dim=-1)

        Q = self.decoder(z_cat)                                    # (B, L, D)
        return Q, self._clamped_kl(mu_f, logvar_f) + self._clamped_kl(mu_t, logvar_t)

class ConvexHyperQueryGenerator(nn.Module):
    """
    Hypernetwork that expands a hash vector (B, hash_dim) into a query tensor
    of shape (B, *out_shape).

    Two ``PositiveLinearHK`` + ReLU stages feed an ordinary ``nn.Linear`` head.
    PositiveLinearHK comes from torch_mice; its weight constraint is presumed
    from the name -- confirm against that library.
    """
    def __init__(self, hash_dim, out_shape):
        super().__init__()
        self.out_shape = out_shape
        flat_dim = int(torch.tensor(out_shape).prod())

        stages = [
            PositiveLinearHK(hash_dim, 128),
            nn.ReLU(),
            PositiveLinearHK(128, 128),
            nn.ReLU(),
            nn.Linear(128, flat_dim),  # allows directional freedom
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, h):
        """Map h (B, hash_dim) to Q (B, *out_shape)."""
        batch = h.size(0)
        return self.net(h).view(batch, *self.out_shape)
        
# === Hypernetwork-based Key Generator ===
class ConvexHyperKeyGenerator(nn.Module):
    """
    Hypernetwork that expands a hash vector (B, hash_dim) into a key tensor
    of shape (B, *out_shape).

    Structurally identical to ``ConvexHyperQueryGenerator`` but holds its own
    parameters.
    """
    def __init__(self, hash_dim, out_shape):
        super().__init__()
        self.out_shape = out_shape
        flat_dim = int(torch.tensor(out_shape).prod())

        stages = [
            PositiveLinearHK(hash_dim, 128),
            nn.ReLU(),
            PositiveLinearHK(128, 128),
            nn.ReLU(),
            nn.Linear(128, flat_dim),  # final head is an unconstrained linear layer
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, h):
        """Map h (B, hash_dim) to K (B, *out_shape)."""
        batch = h.size(0)
        return self.net(h).view(batch, *self.out_shape)


class ConvexICNNHash(nn.Module):
    """
    Summarise a sequence (B, L, D) as a unit-norm hash vector (B, hash_dim).

    Features are averaged per position, the resulting length-L profile is
    passed through a ``BatchedSingleICNN``, and the output is L2-normalised.
    """
    def __init__(self, seq_len, hash_dim):
        super().__init__()
        self.icnn = BatchedSingleICNN(in_dim=seq_len, out_dim=hash_dim)
        # Registered on the module but not applied anywhere in forward().
        self.final_activation = nn.Softplus()

    def forward(self, x):  # x: (B, L, D)
        pooled = x.mean(dim=-1)                               # (B, L)
        # The ICNN wrapper expects a 3-D input, so add a length-one middle axis.
        code = self.icnn(pooled.unsqueeze(1)).squeeze(1)      # (B, hash_dim)
        length = code.norm(dim=-1, keepdim=True)
        return code / (length + 1e-6)

def normalized_alpha_softplus_attention(Q, K, V, alpha=1.5, tau=0.5, eps=1e-6, log=True):
    """
    Attention whose weights come from a powered softplus, normalised by their sum.

    Args:
        Q: queries, (B, Lq, D).
        K: keys, (B, Lk, D).
        V: values, (B, Lk, Dv).
        alpha: shape parameter; scores are
            ``softplus((alpha - 1) * logits - tau) ** (1 / (alpha - 1))``.
            Must differ from 1 (the exponent divides by ``alpha - 1``).
        tau: offset subtracted inside the softplus.
        eps: stabiliser added to both normalisation denominators.
        log: unused; kept so existing call sites remain valid.

    Returns:
        (output, attn_score):
            output is ``weights @ V`` with shape (B, Lq, Dv);
            attn_score is (B, Lk): the total weight each key receives over all
            queries, min-max rescaled to [0, 1] per sample.
    """
    # Scaled dot product between every query and every key.
    logits = torch.einsum("bqd,bkd->bqk", Q, K)   # (B, Lq, Lk)
    logits = logits / math.sqrt(Q.size(-1))       # scale by sqrt(D)

    # Smooth, strictly positive scores (softplus instead of a hard ReLU).
    scores = F.softplus((alpha - 1) * logits - tau) ** (1 / (alpha - 1))

    # Each query's weights are its scores divided by their sum over keys.
    # This is a plain sum-normalisation, not a softmax.
    weights = scores / (scores.sum(dim=-1, keepdim=True) + eps)

    # Attention received per key: sum over the *query* axis -> (B, Lk).
    # Summing over the key axis instead gives ~1 for every query, because each
    # row is already normalised, so the rescale below would only amplify
    # rounding noise.
    attn_score = weights.sum(dim=1)

    # Min-max rescale each sample's scores to [0, 1].
    lo = attn_score.min(dim=-1, keepdim=True).values
    hi = attn_score.max(dim=-1, keepdim=True).values
    attn_score = (attn_score - lo) / (hi - lo + eps)

    return torch.matmul(weights, V), attn_score
    
class PerceptualAttentionBlock(nn.Module):
    """
    Attention layer whose queries and keys are generated from a hash of the input.

    The input is hashed to (B, hash_dim); two hypernetworks expand that hash
    into Q and K of shape (B, L, D). Each is linearly projected and
    L2-normalised, and the input itself serves as the values.
    ``latent_dim`` is accepted but not used by this class.
    """
    def __init__(self, model_dim, hash_dim, latent_dim, seq_len):
        super().__init__()
        qk_shape = (seq_len, model_dim)
        self.hash = ConvexICNNHash(seq_len, hash_dim)
        self.query_gen = ConvexHyperQueryGenerator(hash_dim, qk_shape)
        self.key_gen = ConvexHyperKeyGenerator(hash_dim, qk_shape)
        self.q_proj = nn.Linear(model_dim, model_dim)
        self.k_proj = nn.Linear(model_dim, model_dim)

    def forward(self, x):
        """
        x: (B, L, D) -- hashed to produce Q and K, and used directly as V.
        Returns a 3-tuple:
            attention output: (B, L, D)
            the constant 0 (placeholder in the KL slot; no KL term is computed here)
            attention score as returned by normalized_alpha_softplus_attention
        """
        code = self.hash(x)                                          # (B, hash_dim)
        # Generate, project, then normalise to unit direction along the feature axis.
        Q = F.normalize(self.q_proj(self.query_gen(code)), dim=-1)   # (B, L, D)
        K = F.normalize(self.k_proj(self.key_gen(code)), dim=-1)     # (B, L, D)

        attn_out, attn_score = normalized_alpha_softplus_attention(Q, K, x)
        return attn_out, 0, attn_score


        
class ConvexBlock(nn.Module):
    """
    One residual block: normalised attention followed by a normalised ICNN.

    forward() returns a 3-tuple (x, kl, attn_score) although its annotation
    says Tensor; callers in this notebook unpack three values.
    """
    def __init__(self, model_dim: int, seq_len: int, hash_dim: int, latent_dim: int):
        super().__init__()
        self.attn  = PerceptualAttentionBlock(model_dim, hash_dim, latent_dim, seq_len)
        self.norm1 = BatchAffineNorm(model_dim)
        self.icnn1 = BatchedSingleICNN(in_dim=model_dim, out_dim=model_dim)
        self.norm2 = BatchAffineNorm(model_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, L, dim)
        attended, kl, attn_score = self.attn(self.norm1(x))

        # residual around the attention branch
        after_attn = x + attended                          # (B, L, dim)

        # residual around the ICNN branch
        refined = self.icnn1(self.norm2(after_attn))       # (B, L, dim)
        return after_attn + refined, kl, attn_score


class ConvexLanguageModel(nn.Module):
    """
    Character-level language model: token + learned position embeddings, a
    stack of ``ConvexBlock`` layers, and a ``BatchedSingleICNN`` output head.
    """
    def __init__(
        self,
        vocab_size: int,
        model_dim: int = 128,
        seq_len: int = 128,
        num_blocks: int = 6,
        hash_dim: int = 32,
        latent_dim: int = 16,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, model_dim)
        self.wpe = nn.Embedding(seq_len, model_dim)
        self.num_blocks = num_blocks
        self.blocks = nn.ModuleList([
            ConvexBlock(model_dim, seq_len, hash_dim, latent_dim)
            for _ in range(num_blocks)
        ])

        self.decoder = BatchedSingleICNN(in_dim=model_dim, out_dim=vocab_size)

    def forward(self, input_ids: torch.LongTensor) -> tuple:
        """
        Args:
            input_ids: (B, L) token ids.

        Returns:
            (logits, kl, attn_vis):
                logits   -- (B, L, vocab_size)
                kl       -- sum of the blocks' KL terms divided by num_blocks
                attn_vis -- the blocks' attention scores averaged over blocks
        """
        # -- Embed
        x = self.embedding(input_ids)            # (B, L, E)
        _, t, _ = x.shape
        # Take the device from the input instead of a notebook-level global, so
        # the model works before any `device` variable exists and on any device.
        pos = torch.arange(0, t, dtype=torch.long, device=input_ids.device)  # shape (t)
        x = x + self.wpe(pos)                    # position embeddings broadcast over batch

        # -- Convex blocks
        kl = 0.0
        attn_scores = []
        for block in self.blocks:
            x, klt, attn_score = block(x)        # (B, L, E)
            kl += klt
            attn_scores.append(attn_score)

        # -- Decode
        attn_vis = torch.stack(attn_scores).mean(dim=0)
        logits = self.decoder(x)                 # (B, L, V)
        return logits, kl / self.num_blocks, attn_vis

In [321]:
import torch

def convexity_test(block, B=2, L=4, D=12, tol=1e-6):
    """
    Spot-check Jensen's inequality for `block` on one random pair of inputs.

    Puts `block` in eval mode, draws two uniform random inputs of shape
    (B, L, D), mixes them with weight 0.3, and prints whether
    f(mix) <= 0.3*f(x1) + 0.7*f(x2) holds elementwise within `tol`.
    Only the first element of the block's 3-tuple output is compared.
    Returns None; the verdict is printed.
    """
    block.eval()
    lam = 0.3

    # two random inputs and their convex combination
    X1 = torch.rand(B, L, D)
    X2 = torch.rand(B, L, D)
    Xm = lam*X1 + (1-lam)*X2

    outputs = []
    with torch.no_grad():
        for inp in (X1, X2, Xm):
            y, _, _ = block(inp)
            outputs.append(y)
    f1, f2, fm = outputs

    # Convexity: f(λx1+(1−λ)x2) ≤ λ f(x1) + (1−λ) f(x2)
    rhs = lam*f1 + (1-lam)*f2
    if not torch.all(fm <= rhs + tol):
        excess = (fm - rhs).clamp(min=0)
        print(f"❌ Violation max Δ = {excess.max().item():.3e}")
    else:
        print("✅ Block is numerically convex (within tol).")

# Example usage: build one small ConvexBlock and run the spot check on it.
# The check samples a single random input pair with one mixing weight, so a
# pass here is evidence for that sample only, not a proof of convexity.
block = ConvexBlock(model_dim=12, seq_len=4, hash_dim=6, latent_dim=24)
convexity_test(block)


✅ Block is numerically convex (within tol).


In [322]:
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
import ipywidgets as widgets
from IPython.display import display, clear_output
import time

# --- Configuration Constants ---
CHAR_WIDTH = 8    # horizontal pixels per token cell
CHAR_HEIGHT = 11  # vertical pixels per token row
SEQ_LEN = 128
BATCH_SIZE = 16
LOSS_BAR_HEIGHT = 32  # height of the moving loss-ticker strip
EWMA_HEIGHT = 32      # height of the strip holding the smoothed-loss readout

# Full-resolution framebuffer dimensions
container_width = CHAR_WIDTH * SEQ_LEN  # 1024 pixels
container_height = CHAR_HEIGHT * BATCH_SIZE  # 176 pixels
total_height = container_height + LOSS_BAR_HEIGHT + EWMA_HEIGHT  # 240 pixels

# Display dimensions (currently equal to the framebuffer size, i.e. no down-scaling)
scaled_width = container_width   # 1024 pixels
scaled_height = total_height  # 240 pixels

# Initialize framebuffer
framebuffer = np.zeros((total_height, container_width, 3), dtype=np.uint8)

# EWMA storage
ticker_history = np.zeros(SEQ_LEN, dtype=np.float32)  # Stock ticker moving buffer
loss_memory = 0.0
# Load font; fall back to Pillow's built-in font when the TTF file cannot be
# read. Only that failure is caught, so unrelated errors are not hidden.
try:
    font = ImageFont.truetype("DejaVuSansMono.ttf", 8)  # Monospaced font
    font_large = ImageFont.truetype("DejaVuSansMono.ttf", 64)  # Large EWMA display
except OSError:
    font = ImageFont.load_default()
    font_large = font

# --- Color Mapping Functions ---
def get_flame_color(val):
    """Return a uint8 RGB triple: full-scale red, half-scale green, no blue, all scaled by `val`."""
    red = int(val * 255)
    green = int(val * 0.5 * 255)
    return np.array([red, green, 0], dtype=np.uint8)

# --- IPython Display Setup ---
# A single Output widget is created and shown once; update_display() later
# clears it and draws each new frame inside it.
out = widgets.Output()
display(out)

def get_dynamic_color(attn_val, loss_val):
    """
    Blend between an orange tone (high loss) and a green tone (low loss).

    attn_val: normalised attention value in [0, 1]; scales the brightness of both tones.
    loss_val: normalised loss value in [0, 1]; certainty is taken as 1 - loss_val.

    Returns the blended RGB colour as a uint8 NumPy array.
    Author's note: late in training colours will often be red; swapping in
    get_flame_color is suggested, but only when fine tuning on new data.
    """
    # Both endpoint colours are quantised to uint8 before they are mixed.
    uncertain_rgb = np.array([attn_val * 255, attn_val * 0.5 * 255, 0], dtype=np.uint8)
    confident_rgb = np.array([attn_val * 0.5 * 255, attn_val * 255, attn_val * 0.25 * 255], dtype=np.uint8)

    # 0 -> fully uncertain (orange), 1 -> fully confident (green)
    confidence = 1 - loss_val
    blended = (confidence * confident_rgb) + ((1 - confidence) * uncertain_rgb)
    return blended.astype(np.uint8)
def normalize_rows(x: np.ndarray) -> np.ndarray:
    """Min-max rescale each row of a 2-D array to [0, 1]; constant rows map to 0."""
    lo = np.min(x, axis=1, keepdims=True)
    span = np.max(x, axis=1, keepdims=True) - lo
    # The tiny offset keeps constant rows (span == 0) from dividing by zero.
    return (x - lo) / (span + 1e-16)
    
# --- Framebuffer Update Function ---
def update_framebuffer(attn_weights, token_losses, current_loss, tokens):
    """Render the text grid with coloration based on attn * inverse loss.

    Args:
        attn_weights: 2-D array indexed [row, col] over
            range(BATCH_SIZE) x range(SEQ_LEN); row-normalised before use.
        token_losses: 2-D array with the same indexing; row-normalised before use.
        current_loss: scalar loss of the current step, fed to the smoothed
            readout and to the ticker graph.
        tokens: nested sequence where tokens[row][col] is the text to draw.

    Side effects:
        Replaces the module-level `framebuffer` with the newly drawn image and
        updates `ticker_history` and `loss_memory`.
    """
    global framebuffer, ticker_history, loss_memory

    # Normalize to [0,1]
    token_losses = normalize_rows(token_losses)
    attn_weights = normalize_rows(attn_weights)

    # Create image buffer
    img = Image.new("RGB", (container_width, total_height), (0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Render text with colored intensity; the grid sits below the two header strips.
    text_top = EWMA_HEIGHT + LOSS_BAR_HEIGHT
    for row in range(BATCH_SIZE):
        for col in range(SEQ_LEN):
            color = tuple(get_dynamic_color(attn_weights[row, col], token_losses[row, col]))
            draw.text((col * CHAR_WIDTH, row * CHAR_HEIGHT + text_top),
                      tokens[row][col], font=font, fill=color)

    # Smoothed loss: the clipped loss is blended into the memory twice with
    # the same constant (approximately 1/e).
    etcerta = 0.367879441  # Constant used in update rule
    et = 1 - etcerta
    update = loss_memory * et + np.minimum(12, np.maximum(current_loss , 0)) * etcerta
    loss_memory = loss_memory * et + update * etcerta

    # --- EWMA Display (LARGE FONT) ---
    ewma_text = f"{loss_memory:.4f}"
    draw.text((container_width-128, 0), ewma_text, font_size=32, fill=(65,255, 125))

    # --- Moving Loss Ticker Graph ---
    ticker_history = np.roll(ticker_history, -1)  # Shift left
    ticker_history[-1] = current_loss  # Insert new loss on the right

    # Rescale ticker dynamically like a stock ticker (normalize to min-max range)
    min_loss = np.min(ticker_history)
    max_loss = np.max(ticker_history)
    range_loss = max_loss - min_loss if max_loss != min_loss else 1
    normalized_ticker = (ticker_history - min_loss) / range_loss

    # Draw ticker graph line, one segment per pair of neighbouring samples
    y_vals = EWMA_HEIGHT + (1 - normalized_ticker) * LOSS_BAR_HEIGHT
    x_vals = np.arange(SEQ_LEN) * CHAR_WIDTH
    for i in range(SEQ_LEN - 1):
        draw.line([(x_vals[i], y_vals[i]), (x_vals[i + 1], y_vals[i + 1])], fill=(0, 255, 255), width=2)

    framebuffer = np.array(img)

# --- IPython Display Update Function ---
def update_display():
    """
    Push the current framebuffer into the notebook's output widget.

    The image is resized to (scaled_width, scaled_height) with LANCZOS
    filtering; with the values set in this cell those equal the framebuffer's
    own size, so no actual scaling takes place.
    """
    frame = Image.fromarray(framebuffer).resize((scaled_width, scaled_height), Image.LANCZOS)

    with out:
        # wait=True delays clearing until the new frame is ready to display
        clear_output(wait=True)
        display(frame)

# NOTE(review): nothing in the cells shown appends to or reads this list;
# per-step losses are collected in `losses` by the training cell instead.
loss_history = []

Output()

In [None]:
import os
import pickle
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from torch import nn
import torch
import torch.nn.functional as F


# Device for the model, every batch, and the marginal `q` tensor created below.
device = "cpu"
def wolf_update(p: torch.Tensor,
                g: torch.Tensor,
                state_p: torch.Tensor,
                lr: float):
    """
    One functional step of the Wolf update rule.

    Args:
        p: current parameter values.
        g: gradient for `p`.
        state_p: running state carried between steps.
        lr: learning rate.

    Returns:
        (p_new, new_state). Entries of `p` move only where the smoothed update
        and the gradient share a sign; elsewhere `p` is returned unchanged.
        Inputs are not modified in place.
    """
    # constants are defined locally rather than captured from an outer scope
    etcerta: float = 0.367879441
    et:      float = 1.0 - etcerta

    # blend the gradient into the running state, then fold the blend back in
    smoothed  = state_p * et + g * etcerta
    new_state = state_p * et + smoothed * etcerta

    # mask of entries where the smoothed direction agrees with the raw gradient
    agree = torch.sign(smoothed) * torch.sign(g) > 0

    # multiplicative noise, uniform in [-etcerta, etcerta), applied to the step
    jitter = (torch.rand_like(smoothed)*2 - 1) * etcerta
    noisy  = smoothed + jitter * smoothed

    return torch.where(agree, p - lr * noisy, p), new_state

class Wolf(Optimizer):
    """
    Optimizer wrapper around ``wolf_update``.

    Keeps one state tensor per parameter under the key ``'p'``, initialised
    to zeros with the parameter's shape.
    """
    def __init__(self, params, lr=1e-3):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                self.state[p]['p'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Apply one update to every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It runs with gradients enabled, because this method
                is otherwise executed under ``torch.no_grad()``.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # Without this, the closure could not build a graph or call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                # Parameters added after construction (e.g. via add_param_group)
                # have no state yet; create it on first use.
                if 'p' not in state:
                    state['p'] = torch.zeros_like(p)
                p_new, new_state = wolf_update(p.data, p.grad, state['p'], lr)
                p.data.copy_(p_new)
                state['p'].copy_(new_state)
        return loss

# 1) Load data and meta as before
# `base_dir` is defined by the dataset-preparation cell; the files are read
# from the same parent directory that cell wrote them to.
data_dir  = os.path.dirname(base_dir)
train_ids = np.fromfile(os.path.join(data_dir, 'train.bin'), dtype=np.uint16)
val_ids   = np.fromfile(os.path.join(data_dir, 'val.bin'),   dtype=np.uint16)
# NOTE: pickle.load executes arbitrary code from the file; only load a
# meta.pkl that this notebook produced itself.
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# 2) Compute the data marginal q[v]: relative frequency of each token id in the training split
counts = np.bincount(train_ids, minlength=vocab_size).astype(float)
q = torch.tensor(counts / counts.sum(), dtype=torch.float32, device=device)  # [V]

# 3) Dataset + DataLoader
class CharDataset(Dataset):
    """
    Sliding-window dataset over a 1-D array of token ids.

    Item `idx` is the pair (x, y): x holds `block_size` consecutive ids
    starting at `idx`, and y is the same window shifted one position right.
    """
    def __init__(self, data, block_size):
        # Widen to int64 in NumPy before handing over to torch: the .bin files
        # are uint16, which torch.from_numpy rejects on PyTorch builds without
        # uint16 support. The resulting tensor dtype (int64) is unchanged.
        self.data = torch.from_numpy(np.asarray(data).astype(np.int64))
        self.block_size = block_size

    def __len__(self):
        # Every start index must leave room for block_size inputs plus one
        # extra target; never report a negative length for short data.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        x = self.data[idx : idx + self.block_size]
        y = self.data[idx + 1 : idx + self.block_size + 1]
        return x, y

# Context length and batch size. update_framebuffer() iterates over
# range(BATCH_SIZE) x range(SEQ_LEN), so these must match the values in the
# visualisation cell (16 and 128).
block_size = 128
train_loader = DataLoader(CharDataset(train_ids, block_size),
                          batch_size=16, shuffle=True, drop_last=True)
val_loader   = DataLoader(CharDataset(val_ids,   block_size),
                          batch_size=16, shuffle=False, drop_last=True)

# Size the model from the loaded vocabulary and the context length rather
# than repeating the literals, so a different dataset or block size cannot
# silently disagree with the embedding tables.
virgin = ConvexLanguageModel(vocab_size=vocab_size, model_dim=128, seq_len=block_size,
                             num_blocks=6, hash_dim=64, latent_dim=16)
n_params = sum(p.numel() for p in virgin.parameters())
print(n_params)

print("Number of parameters: ", n_params)
model = virgin.to(device)
# AdamW is what actually runs here; the Wolf optimizer defined above is an
# alternative the author prefers and can be swapped in.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-2)
criterion = nn.CrossEntropyLoss()
losses = []   # per-step training losses, appended by train_epoch()
beta = 0.01   # NOTE(review): not read by this cell's train_epoch, which hard-codes 0.1
# 6) Train / eval functions
def train_epoch():
    """
    Run one pass over `train_loader`, updating the model and the live display.

    For every batch: compute per-token cross-entropy, render the framebuffer
    from the attention scores and token losses, back-propagate, clip the
    gradient norm to 1.0 and step the optimizer. Each step's loss is printed
    and appended to `losses`.

    Returns:
        Mean of the per-step losses over the epoch.
    """
    model.train()
    running = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()

        logits, kl, attn_weights = model(xb)
        B, T, V = logits.shape
        flat_logits = logits.view(-1, logits.size(-1))
        # reduction='none' keeps one loss per token; reshaped back to (B, T)
        per_token_loss = F.cross_entropy(flat_logits, yb.view(-1), reduction='none').reshape(B, T)
        loss = per_token_loss.mean() + kl *0.1
        step_loss = loss.item()

        # Live visualisation of the target characters for this batch
        tokens = [[itos[idx] for idx in seq.tolist()] for seq in yb]
        update_framebuffer(attn_weights.cpu().detach().numpy(),
                           per_token_loss.cpu().detach().numpy(),
                           step_loss,
                           tokens)
        update_display()

        # Backprop, clip, step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        print(step_loss)
        running += step_loss
        losses.append(step_loss)
    return running / len(train_loader)

@torch.no_grad()
def eval_epoch():
    """
    Evaluate the model on `val_loader` without tracking gradients.

    Returns:
        Mean cross-entropy (`criterion`) over the validation batches.
    """
    model.eval()
    total_loss = 0
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        # The model returns (logits, kl, attn_vis); only the logits are needed
        # here. Binding the whole tuple to `logits` made `.shape` fail.
        logits, _, _ = model(xb)
        B, T, V = logits.shape
        total_loss += criterion(logits.view(B*T,V),
                                yb.view(B*T)).item()
    return total_loss / len(val_loader)

# 7) Run training
# Each epoch is one full pass over train_loader followed by one full pass
# over val_loader; the two mean losses are printed side by side.
num_epochs = 10
for epoch in range(1, num_epochs+1):
    train_loss = train_epoch()
    val_loss   = eval_epoch()
    print(f"Epoch {epoch:2d} | train: {train_loss:.4f} | val: {val_loss:.4f}")



# --- helpers ---------------------------------------------------------
def fenchel_decode(logits, tau=1.0, iters=3):
    """
    Iteratively sharpen a distribution over the last axis of `logits`.

    Described by the author as a Fenchel-dual KL-regularised projection of the
    energy -logits. Starting from the uniform distribution, each pass sets
    p <- softmax(-energy / tau + log p).

    Args:
        logits: tensor whose last axis indexes the candidates, e.g. (B, V).
        tau: temperature dividing the energy.
        iters: number of passes; 0 returns the uniform start.
    """
    energy = -logits                                     # (B,V)
    scaled = -energy / tau                               # constant across passes
    p = torch.full_like(energy, 1.0 / energy.size(-1))   # uniform start
    done = 0
    while done < iters:
        p = torch.softmax(scaled + p.log(), dim=-1)
        done += 1
    return p



28115013
Number of parameters:  28115013
6811745.0
1745661.625
693612.375
322185.3125
128809.203125
77945.046875
42894.390625
19881.140625
2844.34619140625
172.14068603515625
12.848677635192871
4.09986686706543
4.08946418762207
4.077268600463867
4.04638671875
4.022488594055176
4.004036903381348
3.958981513977051
3.9515013694763184
3.910064935684204
3.901932716369629
3.868828058242798
3.8374361991882324
3.826965808868408
3.803696632385254
3.759023904800415
3.7373483180999756
3.7235372066497803
3.702165365219116
3.7042391300201416
3.666679859161377
3.610358715057373
3.5868782997131348
3.5799131393432617
3.5938656330108643
3.583686113357544
3.5562286376953125
3.5474319458007812
3.525012969970703
3.528042793273926
3.5587961673736572
3.5471444129943848
3.465327501296997
3.527066946029663
3.4418249130249023
3.5142877101898193
3.4188084602355957
3.425403118133545
3.4439144134521484
3.4973256587982178
3.43867564201355
3.477679491043091
3.448380947113037
3.426710844039917
3.4328293800354004
3.4

In [None]:
# Replace the previous cell's optimizer with a fresh AdamW at a lower learning
# rate; being a new object, it starts with empty optimizer state.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)  # AdamW is used here; the author notes a preference for Wolf
criterion = nn.CrossEntropyLoss()
# Weight applied to the model's kl term in this cell's train_epoch.
beta = 0.1
# 6) Train / eval functions
def train_epoch():
    """
    Run one pass over `train_loader` with the loss `criterion + beta * kl`.

    Gradients are clipped to a total norm of 1.0 before each optimizer step.
    The kl term and the loss are printed every step, and the loss is appended
    to `losses`.

    Returns:
        Mean of the per-step losses over the epoch.
    """
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        # The model returns (logits, kl, attn_vis); the attention map is not
        # needed here. Unpacking into two names raised ValueError.
        logits, kl, _ = model(xb)
        B, T, V = logits.shape
        # 1) Standard CE over all B*T positions
        loss = criterion(logits.view(B*T, V),
                                yb.view(B*T))        # Forward
        loss = loss + beta * kl
        print(kl)
        # Backprop
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        step_loss = loss.item()
        print(step_loss)
        total_loss += step_loss
        losses.append(step_loss)
    return total_loss / len(train_loader)

@torch.no_grad()
def eval_epoch():
    """
    Evaluate the model on `val_loader` without tracking gradients.

    Returns:
        Mean cross-entropy (`criterion`) over the validation batches.
    """
    model.eval()
    total_loss = 0
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        # The model returns (logits, kl, attn_vis); only the logits are needed
        # here. Binding the whole tuple to `logits` made `.shape` fail.
        logits, _, _ = model(xb)
        B, T, V = logits.shape
        total_loss += criterion(logits.view(B*T,V),
                                yb.view(B*T)).item()
    return total_loss / len(val_loader)

# 7) Run training
# Continues training the same `model` object with the optimizer created at
# the top of this cell; each epoch is a full train pass and a full validation pass.
num_epochs = 10
for epoch in range(1, num_epochs+1):
    train_loss = train_epoch()
    val_loss   = eval_epoch()
    print(f"Epoch {epoch:2d} | train: {train_loss:.4f} | val: {val_loss:.4f}")



# --- helpers ---------------------------------------------------------
def fenchel_decode(logits, tau=1.0, iters=3):
    """
    Iteratively sharpen a distribution over the last axis of `logits`.

    Described by the author as a Fenchel-dual KL-regularised projection of the
    energy -logits. Starting from the uniform distribution, each pass sets
    p <- softmax(-energy / tau + log p).

    Args:
        logits: tensor whose last axis indexes the candidates, e.g. (B, V).
        tau: temperature dividing the energy.
        iters: number of passes; 0 returns the uniform start.
    """
    energy = -logits                                     # (B,V)
    scaled = -energy / tau                               # constant across passes
    p = torch.full_like(energy, 1.0 / energy.size(-1))   # uniform start
    done = 0
    while done < iters:
        p = torch.softmax(scaled + p.log(), dim=-1)
        done += 1
    return p
