<a href="https://colab.research.google.com/github/falseywinchnet/PyITD/blob/main/FusedQKVT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
"""
Prepare the Shakespeare dataset for character-level language modeling.
So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
Will save train.bin, val.bin containing the ids, and meta.pkl containing the
encoder and decoder and some other related info.
"""
import os
import pickle
import requests
import numpy as np
import os
from pathlib import Path

try:
    base_dir = Path(__file__).parent
except NameError:
    base_dir = Path(os.getcwd())  # fallback if __file__ is not defined (e.g. in REPL)
# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(base_dir), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# create the train and test splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(base_dir), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(base_dir), 'val.bin'))

# save the meta information as well, to help us encode/decode later
meta = {
    'vocab_size': vocab_size,
    'itos': itos,
    'stoi': stoi,
}
with open(os.path.join(os.path.dirname(base_dir), 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


#if you use my ideas, please credit me, dont just steal
joshuah.rainstar@gmail.com


In [71]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class BandHannFilter(nn.Module):
    """
    Extracts the designated sub-band of x via RFFT + Hann window.
    """
    def __init__(self,):
        super().__init__()

    def forward(self, x, depth, path_index):
        B, T, D = x.shape
        Fbins = T // 2 + 1
        band_size = Fbins // (2 ** (depth + 1))
        start = path_index * band_size
        end = start + band_size

        Xf = torch.fft.rfft(x, dim=1)
        mask = torch.zeros(Fbins, device=x.device)
        hann = torch.hann_window(band_size, device=x.device)
        mask[start:end] = hann
        mask = mask.view(1, Fbins, 1)
        out = torch.fft.irfft(Xf * mask, n=T, dim=1)
        return out


def apply_rope(x):
    B, H, T, D = x.shape
    half = D // 2
    freq = torch.exp(-torch.arange(0, half, 2, device=x.device) * (math.log(10000.0) / half))
    pos = torch.arange(T, device=x.device).unsqueeze(1)
    angles = pos * freq.unsqueeze(0)
    sin = torch.sin(angles).repeat(1, 2)
    cos = torch.cos(angles).repeat(1, 2)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

class QKVHybridAttention(nn.Module):
    """
    Hybrid attention block: uses depth/path_index to filter K, returns x_out and per-node significance.
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "embed_dim must be divisible by num_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model)
        )
        self.filter = BandHannFilter()
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, depth, path_index):
        # Query via self-attention
        q_out, _ = self.attn(x, x, x, need_weights=False)
        # Key via band-pass filter at this depth/path_index
        k_band = self.filter(x, depth, path_index)
        # Value via MLP injection
        v_out = self.mlp(x)

        B, T, D = x.shape
        def split(z): return z.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        Qh, Kh, Vh = split(q_out), split(k_band), split(v_out)
        Qh = apply_rope(Qh)
        Kh = apply_rope(Kh)

        scores = (Qh @ Kh.transpose(-2, -1)) / math.sqrt(self.head_dim)
        w = F.softmax(scores, dim=-1)
        significance = Qh.mean(dim=[1, 2, 3])  # [B]

        o = (w @ Vh).transpose(1, 2).reshape(B, T, D)
        return self.out(o), significance

class HybridBlock(nn.Module):
    """
    A single-level processor: attention + MLP with residuals.
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = QKVHybridAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x, depth, path_index):
        h = self.norm1(x)
        attn_out, sig = self.attn(h, depth, path_index)
        x = x + attn_out
        h2 = self.norm2(x)
        mlp_out = self.mlp(h2)
        x = x + mlp_out
        return x, sig

class TinyHybridGPT(nn.Module):
    """
    Constructs a binary tree of HybridBlocks according to n_layers.
    Final depth nodes' x and significance are used for selection, gating, fusion.
    Blocks are pre-created in __init__ for each depth and node.
    """
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=3):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.n_layers = n_layers
        self.n_heads = n_heads
        # Pre-create HybridBlocks for each depth 0..n_layers-1
        self.blocks_per_depth = nn.ModuleList()
        for depth in range(n_layers):
            num_blocks = 2 ** (depth + 1)
            layer_blocks = nn.ModuleList([
                HybridBlock(d_model=d_model, n_heads=n_heads)
                for _ in range(num_blocks)
            ])
            self.blocks_per_depth.append(layer_blocks)
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        B, T = idx.shape
        x0 = self.token_emb(idx)
        # BFS tree traversal
        nodes = [(x0, 0)]  # (x, path_index)
        sigs = []
        for depth in range(self.n_layers):
            next_nodes = []
            # each depth has 2**(depth+1) blocks: one per child slot
            blocks = self.blocks_per_depth[depth]
            for idx_node, (x, pidx) in enumerate(nodes):
                # assign distinct blocks for low and high children
                low_block = blocks[2 * idx_node]
                high_block = blocks[2 * idx_node + 1]
                # compute child path indices
                low_idx = pidx * 2
                high_idx = pidx * 2 + 1
                # process children with their own blocks
                x_low, sig_low = low_block(x, depth, low_idx)
                x_high, sig_high = high_block(x, depth, high_idx)
                next_nodes.append((x_low, low_idx))
                next_nodes.append((x_high, high_idx))
            nodes = next_nodes
        # At final depth, nodes has 2**n_layers items
        # Collect final leaves and significance by re-running last depth's blocks
        sigs = []
        final_blocks = self.blocks_per_depth[self.n_layers-1]
        for idx_node, (x, pidx) in enumerate(nodes):
            blk = final_blocks[idx_node]
            x_out, sig = blk(x, self.n_layers-1, pidx)
            sigs.append((x_out, sig))
        xs, S = zip(*sigs)
        X = torch.stack(xs, dim=1)   # [B, M, T, D]
        S = torch.stack(S, dim=1)    # [B, M]
        gates = F.softmax(S, dim=1)               # [B, M], all >0, sum to 1
        fused = (gates.unsqueeze(-1).unsqueeze(-1) * X).sum(dim=1)
        x = self.norm(fused)
        logits = self.head(x)
        return logits, S


In [72]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import math
from torch.optim.optimizer import Optimizer
class Wolf(Optimizer):
  """Implements Wolf algorithm."""
  #Wolf, also called Rainstar Optimizer, is fast. it is resistant to valleys and other things where adam hangs.
  #on some problems, it is faster than adam. Try a high LR and lower it until it doesnt explode.
  #wolf is initially smoother than adam over difficult plateaus and at high LR.


  def __init__(self, params, lr=0.25, betas=(0.9, 0.999), eps=1e-8):
        # Define default parameters
        defaults = dict(lr=lr, betas=betas, eps=eps)
        self.lr = lr
        # Initialize the parent Optimizer class first
        super().__init__(params, defaults)
        # Constants specific to Wolf
        # Initialize state for each parameter
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['p'] = torch.zeros_like(p)  # Second moment estimate

  @torch.no_grad()
  def step(self, closure=None):
    """Performs a single optimization step.

    Args:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.

    Returns:
      the loss.
    """
    etcerta = 0.367879441
    et = 1 - etcerta

    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        state = self.state[p]

            # Update step count


        # Perform stepweight decay
        grad = p.grad
        state = self.state[p]
        # State initialization

        exp_avg = state['p']
        # Weight update
        update = exp_avg * et + grad * etcerta
        state['p']  = exp_avg * et + update * etcerta
        sign_agreement = torch.sign(update) * torch.sign(grad)

        update = update + (torch.rand_like(update)*2 - 1) * etcerta * update
        # Where signs agree (positive), apply normal update
        mask = (sign_agreement > 0)
        p.data = torch.where(mask,
                            p.data - self.lr * update,
                            p.data)
    return loss
# 2. Load the binary data
data_dir = os.path.dirname(base_dir)
train_ids = np.fromfile(os.path.join(data_dir, 'train.bin'), dtype=np.uint16)
val_ids   = np.fromfile(os.path.join(data_dir, 'val.bin'),   dtype=np.uint16)
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# 3. Define a simple PyTorch Dataset for next‐char prediction
class CharDataset(Dataset):
    def __init__(self, data, block_size):
        self.data = torch.from_numpy(data).long()
        self.block_size = block_size
    def __len__(self):
        return len(self.data) - self.block_size
    def __getitem__(self, idx):
        x = self.data[idx : idx + self.block_size]
        y = self.data[idx + 1 : idx + self.block_size + 1]
        return x, y

block_size = 128
train_ds = CharDataset(train_ids, block_size)
val_ds   = CharDataset(val_ids,   block_size)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, drop_last=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TinyHybridGPT(vocab_size=vocab_size, d_model=64, n_heads=4, n_layers=3).to(device)
print(sum(p.numel() for p in model.parameters()), ' parameters')
# 5. Set up optimizer & loss
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
criterion = torch.nn.CrossEntropyLoss()

# 6. Training loop
def train_epoch():
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits,S = model(xb)               # (B, T, vocab_size)
        print(S)
        B, T, V = logits.shape
        loss = criterion(logits.view(B*T, V), yb.view(B*T))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(loss.item())
    return total_loss / len(train_loader)

@torch.no_grad()
def eval_epoch():
    model.eval()
    total_loss = 0
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        B, T, V = logits.shape
        loss = criterion(logits.view(B*T, V), yb.view(B*T))
        total_loss += loss.item()
    return total_loss / len(val_loader)

# 7. Run training
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch()
    val_loss   = eval_epoch()
    print(f"Epoch {epoch:2d} | train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")

1229696  parameters
tensor([[ 2.4489e-03, -5.0535e-03,  5.3906e-03,  1.8213e-02,  6.2389e-03,
         -1.9977e-02,  1.2376e-04,  1.1495e-02],
        [-3.2691e-03, -1.0262e-02,  9.8340e-03,  1.5876e-02,  1.1024e-02,
         -1.8255e-02, -3.4833e-03,  1.5504e-02],
        [ 4.1652e-03, -4.7368e-03,  1.2671e-02,  1.8207e-02,  1.0422e-02,
         -1.7236e-02,  1.7473e-04,  9.9058e-03],
        [ 1.7250e-03, -5.3229e-03,  1.4014e-02,  1.6561e-02,  7.5398e-03,
         -2.3941e-02, -3.5752e-03,  1.5962e-02],
        [-1.7691e-03, -4.9972e-03,  6.6807e-03,  2.0760e-02,  9.5559e-03,
         -2.2263e-02, -1.8436e-05,  1.3636e-02],
        [ 3.4264e-03, -5.7315e-03,  1.2903e-02,  1.3000e-02,  8.8614e-03,
         -1.6049e-02, -3.9150e-03,  1.8445e-02],
        [ 6.3008e-03, -5.9741e-03,  3.9536e-03,  2.2181e-02,  1.1852e-02,
         -1.1942e-02, -4.1958e-04,  1.1580e-02],
        [ 2.0568e-03, -6.7012e-03,  8.2796e-03,  1.8119e-02,  9.1765e-03,
         -1.9789e-02, -3.7834e-03,  1.3285e-0

KeyboardInterrupt: 

In [68]:
def train_epoch():
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits,S = model(xb)               # (B, T, vocab_size)
        print(S)
        B, T, V = logits.shape
        loss = criterion(logits.view(B*T, V), yb.view(B*T))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(loss.item())
    return total_loss / len(train_loader)

@torch.no_grad()
def eval_epoch():
    model.eval()
    total_loss = 0
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        B, T, V = logits.shape
        loss = criterion(logits.view(B*T, V), yb.view(B*T))
        total_loss += loss.item()
    return total_loss / len(val_loader)

# 7. Run training
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch()
    val_loss   = eval_epoch()
    print(f"Epoch {epoch:2d} | train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")

tensor([[0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078],
        [0.0078, 0.0078, 0.0078, 0.007

KeyboardInterrupt: 

In [42]:
def train_epoch():
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits,S = model(xb)               # (B, T, vocab_size)
        B, T, V = logits.shape
        loss = criterion(logits.view(B*T, V), yb.view(B*T))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(loss.item())
    return total_loss / len(train_loader)

@torch.no_grad()
def eval_epoch():
    model.eval()
    total_loss = 0
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        B, T, V = logits.shape
        loss = criterion(logits.view(B*T, V), yb.view(B*T))
        total_loss += loss.item()
    return total_loss / len(val_loader)

# 7. Run training
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch()
    val_loss   = eval_epoch()
    print(f"Epoch {epoch:2d} | train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")

3.3272504806518555
3.304229497909546
3.2672877311706543
3.2476792335510254
3.275294303894043
3.3815929889678955
3.3547918796539307
3.301131010055542
3.3120720386505127
3.323371410369873
3.3096659183502197
3.2997119426727295
3.334392547607422
3.2494325637817383
3.360769510269165
3.2848989963531494
3.2806174755096436
3.27695369720459
3.331876039505005
3.3024120330810547
3.414670705795288
3.354112148284912
3.2829952239990234
3.3061180114746094
3.3639492988586426
3.300365924835205
3.3417983055114746
3.287287473678589
3.284991979598999
3.3076436519622803
3.2893238067626953
3.3513002395629883
3.3044443130493164
3.3326029777526855
3.396044969558716
3.305065393447876
3.280106544494629
3.325437545776367
3.3612101078033447
3.3514082431793213
3.287095546722412
3.2867677211761475
3.2750561237335205
3.274765729904175
3.314539909362793
3.291142702102661
3.330069065093994
3.375945806503296
3.3013930320739746
3.2833380699157715
3.292588710784912
3.308565378189087
3.2666268348693848
3.29486346244812
3.

KeyboardInterrupt: 

In [12]:
bcontext_str = "To be, or not to be,"
context_ids  = torch.tensor([[ stoi[c] for c in bcontext_str ]], dtype=torch.long)

max_new_tokens = 5000
temperature     = 1.0
top_k           = 50
block_size      = 128

generated = context_ids  # (1, T)

for _ in range(max_new_tokens):
    # only pass at most block_size tokens into the model
    input_ids = generated if generated.size(1) <= block_size \
                else generated[:, -block_size:]

    logits = model(input_ids)              # (1, T_cur, vocab_size)
    logits = logits[:, -1, :] / temperature

    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[:, [-1]]] = -1e10

    probs   = F.softmax(logits, dim=-1)
    next_id = torch.multinomial(probs, num_samples=1)  # (1,1)
    generated = torch.cat([generated, next_id], dim=1)

output_str = ''.join([itos[i] for i in generated[0].tolist()])
print(output_str)


To be, or not to be, t beet beet bebet bebet bebete be beToe bereribereribereribererieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieieioieieioieieeeieneeoieeoeoieioieieioeieoioeieeieieioeieieioeioioieioioaie eieeeeioeeeeioeoioioioioioeeeeeeeeeeeeeeeneeee eee eioeioeoeeeeeeeeeeeeeeeeeeeeeeeeeeeeenooooooooooooooo eeeeeeee eeeeoeeeeoeoeeoeoeeoeeeeeeeeeeeeeeeeeeioooooooooooooooooooooooo,ooeoeoeoeoeoeoeoeoooeoooooooooooo'eeeeeeeeeeeeeeeeeeeeeoeeoeoeoeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeoeeeeeeeeeeeeeeeeeoee eoeoeoeoeoeeneonooooooooooooooooooooooooooooooooooooooooooooooooooeeeeeeeeeeeeeeeeeeoooeeeooeoeoeoeooeoononoooonooooeeeeeooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooeeoooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooaoooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo

In [48]:
import numpy as np
import scipy.linalg as la

# ── Primitive definitions ─────────────────────────────────────────────────────

def attention_map(Q, K, V):
    """
    Q,K,V: (T, d)
    returns: (T, d)
    """
    scores = (Q @ K.T) / np.sqrt(Q.shape[1])
    exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    return weights @ V

def s4d_map(x, B, C, alpha, beta):
    """
    x: (T, d)
    B,C,alpha,beta: (N,)
    returns: y of shape (T, d)
    """
    T, d = x.shape
    N = len(B)
    L = 2*T - 1

    # Build complex-exponential kernels for each mode
    t = np.arange(T)
    A = -np.exp(alpha) + 1j*beta               # (N,)
    modes = np.exp(A[:,None] * t[None,:])      # (N, T)
    K = (B[:,None] * modes) * C[:,None]        # (N, T)
    Kf = np.fft.fft(K, n=L, axis=1)            # (N, L)

    # FFT of input features: (T,d) → (L,d)
    Xf = np.fft.fft(x, n=L, axis=0)            # (L, d)

    # Multiply per-mode and sum, then IFFT back
    Yf = np.sum(Kf[:,:,None] * Xf[None,:,:], axis=0)  # (L, d)
    y  = np.fft.ifft(Yf, n=L, axis=0).real            # (L, d)

    return y[:T,:]                                    # (T, d)

def mlp_map(x, W1, b1, W2, b2):
    """
    x: (T, d)
    returns: (T, d)
    """
    # using ReLU for simplicity
    h = np.maximum(0, x @ W1.T + b1)
    return h @ W2.T + b2

# ── Six permutation definitions ────────────────────────────────────────────────

def Y_mode(m, x, params):
    """
    x: (T, d)
    params: dict of weight matrices/vectors
    m: 1..6
    returns: Y^{(m)}(x) of shape (T, d)
    """
    # compute primitives with correct matmuls
    Qx = attention_map(x @ params['WQ'], x @ params['WK'], x @ params['WV'])
    Sx = s4d_map(x, params['B'], params['C'], params['alpha'], params['beta'])
    Mx = mlp_map(x, params['W1'], params['b1'], params['W2'], params['b2'])

    assign = {
        1: (Qx, Sx, Mx),
        2: (Qx, Mx, Sx),
        3: (Sx, Qx, Mx),
        4: (Sx, Mx, Qx),
        5: (Mx, Qx, Sx),
        6: (Mx, Sx, Qx),
    }[m]

    return attention_map(*assign)

# ── Spectral‐norm Lipschitz bound ──────────────────────────────────────────────

def spectral_norm(mat):
    return la.svd(mat, compute_uv=False)[0]

def compute_L_bound(m, norms):
    pick = {
        1: ('A','S','M'),
        2: ('A','M','S'),
        3: ('S','A','M'),
        4: ('S','M','A'),
        5: ('M','A','S'),
        6: ('M','S','A'),
    }[m]
    def get_norm(role, prim):
        if prim == 'A':
            return norms[f"A_{role}"]
        else:
            return norms[prim]
    LQ = get_norm('Q', pick[0])
    LK = get_norm('K', pick[1])
    LV = get_norm('V', pick[2])
    return LQ * LK * LV

# ── Hessian top‐eigenvalue via finite‐difference Hv ────────────────────────────

def numeric_grad(f, x, h=1e-3):
    T, d = x.shape
    g = np.zeros_like(x)
    fx = f(x)
    for i in range(T):
        for j in range(d):
            xp = x.copy()
            xp[i,j] += h
            g[i,j] = (f(xp).sum() - fx.sum()) / h
    return g

def approx_top_eig(Yfunc, x0, eps=1e-3, iters=5):
    theta = x0.flatten()
    v = np.random.randn(theta.size)
    v /= np.linalg.norm(v)
    for _ in range(iters):
        g0 = numeric_grad(Yfunc, x0).flatten()
        xp = x0 + eps * v.reshape(x0.shape)
        g1 = numeric_grad(Yfunc, xp).flatten()
        Hv = (g1 - g0) / eps
        v = Hv / (np.linalg.norm(Hv) + 1e-12)
    return float(v @ Hv)

# ── Main exploration ──────────────────────────────────────────────────────────

if __name__ == "__main__":
    T, d, N = 20, 16, 4

    # Random “trained” weights
    params = {
        'WQ':    np.random.randn(d,d),
        'WK':    np.random.randn(d,d),
        'WV':    np.random.randn(d,d),
        'W1':    np.random.randn(4*d,d),
        'b1':    np.random.randn(4*d),
        'W2':    np.random.randn(d,4*d),
        'b2':    np.random.randn(d),
        'B':     np.random.randn(N),
        'C':     np.random.randn(N),
        'alpha': np.random.randn(N),
        'beta':  np.random.randn(N),
    }

    # Precompute primitive norms
    norms = {
        'A_Q': spectral_norm(params['WQ']),
        'A_K': spectral_norm(params['WK']),
        'A_V': spectral_norm(params['WV']),
        'M':   spectral_norm(params['W2']) * spectral_norm(params['W1']),
        'S':   max(np.abs(np.fft.fft(
                   s4d_map(np.random.randn(T,d),
                          params['B'],params['C'],
                          params['alpha'],params['beta']),
                   n=2*T-1, axis=0).flatten())),
    }

    x0 = np.random.randn(T, d)

    print("Mode | L_bound  | Hessian λ_max")
    for m in range(1, 7):
        Lb  = compute_L_bound(m, norms)
        eig = approx_top_eig(lambda x: Y_mode(m, x, params), x0)
        print(f"  {m}  | {Lb:8.3f} | {eig:8.3f}")


Mode | L_bound  | Hessian λ_max
  1  | 16971.954 | 76178.107
  2  | 16971.954 | 399011.787
  3  | 18144.262 | 24709.401
  4  | 15723.201 | 3325.412
  5  | 18144.262 | 13574.675
  6  | 15723.201 | 8001.522


In [50]:
import numpy as np
import scipy.linalg as la
from sklearn.decomposition import PCA

# ── [ primitives: attention_map, s4d_map, mlp_map ] ─────────────────────────
# (unchanged—copy from previous script)

# ── Jacobian singular-value spectrum via randomized SVD ──────────────────────

def topk_jac_svs(Yfunc, x0, k=10, n_probes=20, eps=1e-3):
    """
    Approximate the top-k singular values of the Jacobian J=∂Y/∂x at x0.
    We use randomized power method with directional derivatives.
    """
    T, d = x0.shape
    D = T*d
    # probes: random vectors in input space
    probes = np.random.randn(n_probes, D)
    probes /= np.linalg.norm(probes, axis=1, keepdims=True)
    # build Jv for each probe
    Jv = np.zeros((D, n_probes))
    for i, v in enumerate(probes):
        v_mat = v.reshape(T, d)
        # directional derivative: (Y(x+eps v) - Y(x-eps v)) / (2 eps)
        Yp = Yfunc(x0 + eps*v_mat)
        Ym = Yfunc(x0 - eps*v_mat)
        diff = ((Yp - Ym) / (2*eps)).reshape(D)
        Jv[:, i] = diff
    # now approximate SVs of J via SVD on Jv (D×n_probes)
    U, s, VT = la.svd(Jv, full_matrices=False)
    return s[:k]

# ── Directional second derivatives ───────────────────────────────────────────

def directional_second_derivative(Yfunc, x0, v, eps=1e-3):
    """
    Approximate v^T Hessian v at x0 via finite differences:
      v^T H v ≈ (Y(x+eps v) - 2Y(x) + Y(x-eps v))·v / eps^2
    """
    Y0 = Yfunc(x0).reshape(-1)
    v_mat = v.reshape(x0.shape)
    Yp = Yfunc(x0 + eps*v_mat).reshape(-1)
    Ym = Yfunc(x0 - eps*v_mat).reshape(-1)
    # project onto v
    return float(((Yp - 2*Y0 + Ym) @ v) / (eps**2))

def sample_second_derivatives(Yfunc, x0, num_dirs=50):
    """
    Sample directional curvatures v^T H v for random unit v's.
    """
    D = x0.size
    vals = []
    for _ in range(num_dirs):
        v = np.random.randn(D)
        v /= np.linalg.norm(v)
        vals.append(directional_second_derivative(Yfunc, x0, v))
    return np.array(vals)

# ── Tangent vs normal gains on a PCA manifold ────────────────────────────────

def tangent_normal_gains(Yfunc, X_batch, pca_dim=5, eps=1e-3):
    """
    Estimate average gain ||Y(x+δ)-Y(x)||/||δ|| for δ along
    - top-pca_dim directions (tangent)
    - random orthonormal directions in complement (normal)
    """
    B, T, d = X_batch.shape
    flat = X_batch.reshape(B, T*d)
    pca = PCA(n_components=pca_dim).fit(flat)
    tangent_dirs = pca.components_       # (pca_dim, D)
    # find orth complement basis via QR
    Q, _ = np.linalg.qr(flat.T)           # Q: (D, D)
    normal_dirs = Q[:, pca_dim:]          # (D, D-pca_dim)

    gains = {'tangent': [], 'normal': []}
    for x in flat:
        x_mat = x.reshape(T, d)
        Y0 = Yfunc(x_mat)
        for v in tangent_dirs:
            δ = v / np.linalg.norm(v)
            Yp = Yfunc(x_mat + eps*δ.reshape(T,d))
            gains['tangent'].append(np.linalg.norm(Yp-Y0)/eps)
        for j in range(5):  # sample 5 random normals
            v = normal_dirs[:, np.random.randint(normal_dirs.shape[1])]
            v /= np.linalg.norm(v)
            Yp = Yfunc(x_mat + eps*v.reshape(T,d))
            gains['normal'].append(np.linalg.norm(Yp-Y0)/eps)
    return {k: np.mean(v) for k,v in gains.items()}

# ── Main exploration with extended analysis ─────────────────────────────────

if __name__ == "__main__":
    # [ set T,d,N, params, norms, x0 as before ]
    # ...
    X_batch = np.random.randn(10, T, d)  # synthetic “validation batch”

    print("Mode | L_bound | λ_max |  top10 SVs... | 2nd-deriv stats | tangent vs normal")
    for m in range(1,7):
        # Lipschitz & Hessian max
        Lb  = compute_L_bound(m, norms)
        lmax= approx_top_eig(lambda x: Y_mode(m,x,params), x0)
        # Jacobian SVs
        svs = topk_jac_svs(lambda x: Y_mode(m,x,params), x0, k=6)
        # 2nd-derivative distribution
        curvs = sample_second_derivatives(lambda x: Y_mode(m,x,params), x0, num_dirs=20)
        sec_mean, sec_std = curvs.mean(), curvs.std()
        # manifold gains
        gains = tangent_normal_gains(lambda x: Y_mode(m,x,params), X_batch)
        # ... compute Lb, lmax, svs, sec_mean, sec_std, gains as before ...
        svs_str = ", ".join(f"{v:.3f}" for v in svs)
        print(f"{m:>4} | {Lb:7.1f} | {lmax:7.1f} | [{svs_str}] | "
              f"2ndμ={sec_mean:.1f},σ={sec_std:.1f} | "
              f"tan={gains['tangent']:.3f}, nor={gains['normal']:.3f}")


Mode | L_bound | λ_max |  top10 SVs... | 2nd-deriv stats | tangent vs normal
   1 | 16972.0 | 82645.0 | [194.092, 159.507, 123.940, 94.827, 74.558, 67.483] | 2ndμ=3.1,σ=28.7 | tan=80.706, nor=73.236
   2 | 16972.0 | 399011.8 | [164.661, 23.946, 1.821, 1.603, 1.506, 1.492] | 2ndμ=1.2,σ=120.7 | tan=12.434, nor=16.881
   3 | 18144.3 | 27116.0 | [339.537, 180.237, 134.718, 103.053, 84.968, 65.123] | 2ndμ=1.8,σ=6.9 | tan=88.987, nor=83.698
   4 | 15723.2 |  3444.1 | [94.983, 79.802, 45.735, 42.337, 33.154, 29.991] | 2ndμ=0.1,σ=5.6 | tan=21.483, nor=20.794
   5 | 18144.3 | 13575.5 | [64.068, 42.104, 28.238, 10.195, 8.300, 6.610] | 2ndμ=-0.2,σ=4.0 | tan=17.864, nor=15.867
   6 | 15723.2 |  8283.2 | [68.914, 47.907, 43.308, 38.520, 35.510, 31.590] | 2ndμ=-0.1,σ=1.8 | tan=22.553, nor=22.915


Together, these hyper-dimensional signatures show Mode 4 is most predisposed to (a) collapse irrelevant dimensions, (b) preserve task-relevant structure, and (c) yield clusters of semantically related inputs in its output space.