<a href="https://colab.research.google.com/github/falseywinchnet/PyITD/blob/main/FusedQKVT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
"""
Prepare the Shakespeare dataset for character-level language modeling.
So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
Will save train.bin, val.bin containing the ids, and meta.pkl containing the
encoder and decoder and some other related info.
"""
import os
import pickle
import requests
import numpy as np
import os
from pathlib import Path

try:
    base_dir = Path(__file__).parent
except NameError:
    base_dir = Path(os.getcwd())  # fallback if __file__ is not defined (e.g. in REPL / notebook)

# Download the tiny shakespeare dataset (cached on disk after the first run).
# NOTE(review): the file is placed in the *parent* of base_dir because of
# os.path.dirname(base_dir) - confirm that this location is intended.
input_file_path = os.path.join(os.path.dirname(base_dir), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Fetch first and validate the response, so that a failed request can never
    # leave an empty or error-page file behind that later runs would treat as data.
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)

# Read with an explicit encoding so the result does not depend on the platform default.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(set(data))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers (and the inverse)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    """Encoder: map each character of string `s` to its integer id via `stoi`."""
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids
def decode(l):
    """Decoder: map each integer id in `l` back to its character via `itos` and join them."""
    pieces = [itos[token_id] for token_id in l]
    return ''.join(pieces)

# Split the raw text 90% / 10% into train and validation portions.
n = len(data)
split_at = int(n*0.9)
train_data, val_data = data[:split_at], data[split_at:]

# Turn both portions into integer id sequences.
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# Write the ids out as raw uint16 binary files next to input.txt.
out_dir = os.path.dirname(base_dir)
train_ids = np.asarray(train_ids, dtype=np.uint16)
val_ids = np.asarray(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(out_dir, 'train.bin'))
val_ids.tofile(os.path.join(out_dir, 'val.bin'))

# Persist the vocabulary so later cells can encode/decode consistently.
meta = dict(vocab_size=vocab_size, itos=itos, stoi=stoi)
with open(os.path.join(out_dir, 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


#if you use my ideas, please credit me, don't just steal
joshuah.rainstar@gmail.com


In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class RainstarActivation(nn.Module):
    """
    Element-wise activation

        y = sign(x) * |x| * tanh(|x|)**p * (s_min + (1 - s_min) / (1 + (|x| / kappa)**gamma))

    where the `|x| * tanh(|x|)**p` factor is evaluated in log space.
    """
    def __init__(self, p=0.5, kappa=1.0, gamma=2.0, s_min=0.48):
        super().__init__()
        self.p = p
        self.kappa = kappa
        self.gamma = gamma
        self.s_min = s_min

    def forward(self, x):
        # Works on tensors of any shape; everything below is element-wise.
        magnitude = torch.abs(x)
        # The 1e-12 offsets keep both logarithms finite at x == 0.
        log_magnitude = torch.log(magnitude + 1e-12)
        log_squash = torch.log(torch.tanh(magnitude) + 1e-12)
        core = torch.sign(x) * torch.exp(log_magnitude + self.p * log_squash)
        # Scale factor that equals 1 at |x| == 0 and approaches s_min as |x| grows.
        damping = self.s_min + (1 - self.s_min) / (1 + (magnitude / self.kappa) ** self.gamma)
        return core * damping


class ExpertMLP(nn.Module):
    """
    Mixture of `num_experts` linear maps (d_model -> d_model) whose outputs are
    blended per token with non-negative weights that sum to one. The weights
    come from a linear gate passed through a Rainstar activation.
    """
    def __init__(self, d_model: int, num_experts: int = 4,
                 p=0.5, kappa=1.0, gamma=2.0, s_min=0.48, eps=1e-6):
        super().__init__()
        self.num_experts = num_experts
        self.eps = eps

        # Activation applied to the gate pre-activations.
        self.rainstar = RainstarActivation(p, kappa, gamma, s_min)
        # Gate: one pre-activation per expert for every token.
        self.gate = nn.Linear(d_model, num_experts)

        # The experts themselves are plain linear layers.
        expert_layers = []
        for _ in range(num_experts):
            expert_layers.append(nn.Linear(d_model, d_model))
        self.experts = nn.ModuleList(expert_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, D)
        returns: (B, T, D)
        """
        # Gate scores after Rainstar, shape (B, T, E).
        activated = self.rainstar(self.gate(x))

        # Make the scores strictly positive, then normalise over the expert axis.
        weights = activated.abs() + self.eps
        weights = weights / weights.sum(dim=-1, keepdim=True)

        # Every expert sees the same input; stack results as (B, T, D, E).
        stacked = torch.stack([expert(x) for expert in self.experts], dim=-1)

        # Weighted sum over the expert axis gives (B, T, D).
        return (stacked * weights.unsqueeze(2)).sum(dim=-1)

@torch.jit.script
def trig_modulate(x: torch.Tensor, pos: torch.Tensor,
                  phase_index: int= 0) -> torch.Tensor:
    """
    Multiply `x` of shape [B, H, T, D] element-wise by cos(pos * inv_freq + phase).

    `inv_freq[j] = 1 / 10000**(j / D)` for j in 0..D-1, and the phase offset is
    `phase_index * pi / 3`. `pos` is a 1-D tensor of T positions.
    """
    batch, heads, steps, width = x.shape
    exponents = torch.arange(0, width, dtype=torch.float32) / width
    inv_freq = 1.0 / (10000 ** exponents).to(x.device)
    angles = torch.outer(pos.float(), inv_freq).to(x.device)  # [T, D]
    phase_shift = (math.pi / 3) * float(phase_index)
    angles = angles + phase_shift

    carrier = torch.cos(angles)
    # Add leading batch and head axes so the carrier broadcasts as [1, 1, T, D].
    carrier = carrier.unsqueeze(0).unsqueeze(0)
    return x * carrier


class TrigQKVAttention(nn.Module):
    """
    Multi-head self-attention whose Q, K and V come from three separate
    ExpertMLPs and are multiplied by a position-dependent cosine
    (`trig_modulate`) with a different phase index each.

    Args:
        d_model: model width; must be divisible by `n_heads`.
        n_heads: number of attention heads.
        causal:  when True (default) position t may only attend to positions
                 <= t. This is required for next-token training: without the
                 mask each position can read the very token it is asked to
                 predict. Pass False to get unmasked (bidirectional) attention.
    """
    def __init__(self, d_model, n_heads, causal=True):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.causal = causal

        self.q_mlp = ExpertMLP(d_model)
        self.k_mlp =  ExpertMLP(d_model)
        self.v_mlp =  ExpertMLP(d_model)

        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """x: [B, T, D] -> [B, T, D]"""
        B, T, D = x.shape
        H, d = self.n_heads, self.head_dim

        # Positional indices
        pos = torch.arange(T, device=x.device)

        # Generate Q, K, V via distinct MLPs
        Q = self.q_mlp(x).view(B, T, H, d).transpose(1, 2)  # [B, H, T, d]
        K = self.k_mlp(x).view(B, T, H, d).transpose(1, 2)
        V = self.v_mlp(x).view(B, T, H, d).transpose(1, 2)

        # Trig-based modulation (distinct phase per role)
        Q = trig_modulate(Q, pos, 1)
        K = trig_modulate(K, pos, 2)
        V = trig_modulate(V, pos, 3)

        # Scaled dot-product attention
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d)  # [B, H, T, T]
        if self.causal:
            # True strictly above the diagonal == "key lies in the future".
            future = torch.triu(
                torch.ones([T, T], dtype=torch.bool, device=x.device), diagonal=1)
            scores = scores.masked_fill(future, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = (attn @ V)  # [B, H, T, d]

        out = out.transpose(1, 2).reshape(B, T, D)
        return self.out_proj(out)

class HybridBlock(nn.Module):
    """
    One processing layer: attention, then an expert MLP, each followed by a
    Rainstar activation. A single residual connection adds the result back
    onto the block input.
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = TrigQKVAttention(d_model, n_heads)
        self.mlp =  ExpertMLP(d_model)
        self.rainstar = RainstarActivation()

    def forward(self, x):
        # The attention output is activated and fed straight into the MLP
        # (there is no residual around the attention itself).
        attended = self.rainstar(self.attn(x))
        transformed = self.rainstar(self.mlp(attended))
        # Only residual in the block: input + activated MLP output.
        return x + transformed

class TinyTransformer(nn.Module):
    """
    Token embedding -> `n_layers` HybridBlocks (each followed by a Rainstar
    activation) -> bias-free linear head producing vocabulary logits.
    """
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=6):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.token_emb = nn.Embedding(vocab_size, d_model)
        layers = []
        for _ in range(n_layers):
            layers.append(HybridBlock(d_model=d_model, n_heads=n_heads))
        self.blocks = nn.ModuleList(layers)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.rainstar = RainstarActivation()

    def forward(self, x):
        # x holds token ids; the embedding lookup yields the hidden states.
        hidden = self.token_emb(x)

        # Run the stack, applying the shared Rainstar activation after every block.
        for layer in self.blocks:
            hidden = self.rainstar(layer(hidden))
        return self.head(hidden)

In [56]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer

@torch.jit.script
def wolf_update(p: torch.Tensor,
                g: torch.Tensor,
                state_p: torch.Tensor,
                lr: float):
    """
    One Wolf step for a single tensor; returns `(new_param, new_state)`.

    The gradient is blended into the running state, the step is scaled by
    random multiplicative noise, and it is applied only to elements where the
    blended update and the raw gradient share a sign.
    """
    # Constants live inside the function so TorchScript does not need to capture them.
    blend: float = 0.367879441  # numerically 1/e
    keep:  float = 1.0 - blend

    # Blend the gradient into the state, then fold that blend back into the state.
    smoothed   = state_p * keep + g * blend
    next_state = state_p * keep + smoothed * blend
    # Agreement is measured on the noise-free update.
    agreement  = torch.sign(smoothed) * torch.sign(g)
    # Uniform noise in [-1, 1) rescales each element by a factor in [1 - blend, 1 + blend).
    jitter     = torch.rand_like(smoothed)*2 - 1
    noisy      = smoothed + jitter * blend * smoothed
    stepped    = p - lr * noisy
    # Elements whose signs disagree (or are zero) keep their old value.
    return torch.where(agreement > 0, stepped, p), next_state

class Wolf(Optimizer):
    """
    Optimizer that applies `wolf_update` to every parameter that has a gradient.

    Per-parameter state is a single tensor stored under the key 'p'
    (the running blend of past gradients), initialised to zeros.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr:     step size passed to `wolf_update`.
    """
    def __init__(self, params, lr=1e-3):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                self.state[p]['p'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                     the loss. It is run with gradients enabled, so it may call
                     `backward()` even though `step` itself runs under no_grad.

        Returns:
            The closure's return value, or None when no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                # Parameters added later (e.g. via add_param_group) have no
                # state yet; create it on first use instead of raising KeyError.
                if 'p' not in state:
                    state['p'] = torch.zeros_like(p)
                state_p = state['p']
                p_new, new_state = wolf_update(p.data, p.grad, state_p, lr)
                p.data.copy_(p_new)
                state_p.copy_(new_state)
        return loss

# 1) Load data and meta as before
data_dir  = os.path.dirname(base_dir)
train_ids = np.fromfile(os.path.join(data_dir, 'train.bin'), dtype=np.uint16)
val_ids   = np.fromfile(os.path.join(data_dir, 'val.bin'),   dtype=np.uint16)
# NOTE: pickle.load can execute arbitrary code; only load a meta.pkl that was
# written by the data-preparation cell of this notebook.
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# The device must be known before q is created below; otherwise this cell
# raises NameError on a fresh kernel.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 2) Compute data-marginal q[v]: empirical frequency of every token id in the train split
counts = np.bincount(train_ids, minlength=vocab_size).astype(float)
q = torch.tensor(counts / counts.sum(), dtype=torch.float32, device=device)  # [V]

# 3) Dataset + DataLoader
class CharDataset(Dataset):
    """
    Sliding-window dataset over a 1-D array of token ids.

    Item `idx` is the pair `(x, y)` where `x = data[idx : idx + block_size]`
    and `y` is the same window shifted one position to the right.

    Args:
        data:       1-D numpy array of integer token ids.
        block_size: window length.
    """
    def __init__(self, data, block_size):
        # Convert to int64 in numpy first: torch.from_numpy does not accept
        # uint16 arrays on every torch version.
        self.data = torch.from_numpy(np.asarray(data).astype(np.int64))
        self.block_size = block_size

    def __len__(self):
        # Clamp at zero so data shorter than one window yields an empty
        # dataset instead of making len() raise on a negative value.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        x = self.data[idx : idx + self.block_size]
        y = self.data[idx + 1 : idx + self.block_size + 1]
        return x, y

block_size = 128
# Both loaders share batch size and drop the last incomplete batch; only the
# training loader shuffles.
_shared_loader_args = dict(batch_size=32, drop_last=True)
train_loader = DataLoader(CharDataset(train_ids, block_size),
                          shuffle=True, **_shared_loader_args)
val_loader   = DataLoader(CharDataset(val_ids, block_size),
                          shuffle=False, **_shared_loader_args)

# 4) Model, optimizer, loss
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Eager (unscripted) model is kept under its own name; `model` is the TorchScript version.
virgin = TinyTransformer(vocab_size=vocab_size, d_model=256, n_heads=4, n_layers=4)
model = torch.jit.script(virgin)
model.to(device)
optimizer = Wolf(model.parameters(), lr=0.3)#adam would explode at this training rate
criterion = nn.CrossEntropyLoss()

# 5) Regularization weights
λ_ent = 0.1   # weight of the entropy term in the training loss
λ_kl  = 0.5   # weight of the marginal-KL term in the training loss

# 6) Train / eval functions
def train_epoch(log_every=1):
    """
    Run one pass over `train_loader` and return the mean combined loss.

    The loss per batch is
        cross-entropy + λ_ent * mean predictive entropy
                      + λ_kl * KL(batch-mean prediction || data marginal q).

    Args:
        log_every: print the batch loss every `log_every` steps
                   (default 1 = every step; 0 or None disables printing).
    """
    model.train()
    # Tokens that never occur in the train split have q == 0; clamp before the
    # log so the KL term stays finite instead of turning the loss into inf/NaN.
    log_q = torch.log(q.clamp_min(1e-12))
    total_loss = 0
    for step, (xb, yb) in enumerate(train_loader, 1):
        xb, yb = xb.to(device), yb.to(device)

        # Forward
        logits = model(xb)                 # (B, T, V)
        B, T, V = logits.shape
        p = F.softmax(logits, dim=-1)      # (B, T, V)

        # 1) Standard CE
        ce_loss = criterion(logits.view(B*T, V),
                            yb.view(B*T))

        # 2) Entropy penalty: mean entropy of the per-position predictions
        ent = -(p * (p + 1e-12).log()).sum(dim=-1)  # (B, T)
        ent_loss = ent.mean()

        # 3) Marginal-KL penalty between the batch-averaged prediction and q
        p_m = p.mean(dim=(0,1))            # [V]
        kl_loss = (p_m * (p_m + 1e-12).log() - p_m * log_q).sum()

        # 4) Combined loss
        loss = ce_loss + λ_ent * ent_loss + λ_kl * kl_loss

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        loss_value = loss.item()
        if log_every and step % log_every == 0:
            print(loss_value)
        total_loss += loss_value
    return total_loss / len(train_loader)

@torch.no_grad()
def eval_epoch():
    """Return the mean cross-entropy of `model` over all batches of `val_loader`."""
    model.eval()
    batch_losses = []
    for xb, yb in val_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        logits = model(xb)
        B, T, V = logits.shape
        # Flatten batch and time so every position counts as one classification.
        flat_logits = logits.view(B*T, V)
        flat_targets = yb.view(B*T)
        batch_losses.append(criterion(flat_logits, flat_targets).item())
    return sum(batch_losses) / len(val_loader)

# 7) Run training: one train pass followed by one validation pass per epoch
num_epochs = 10
for epoch in range(1, 1 + num_epochs):
    train_loss, val_loss = train_epoch(), eval_epoch()
    print(f"Epoch {epoch:2d} | train: {train_loss:.4f} | val: {val_loss:.4f}")


5.279724597930908
5.275017261505127
5.271737098693848
5.265516757965088
5.260200500488281


KeyboardInterrupt: 

In [43]:
bcontext_str = "To be, or not to be,"
# Put the prompt on the same device as the model; a CPU tensor fed to a CUDA
# model raises a device-mismatch error.
gen_device  = next(model.parameters()).device
context_ids = torch.tensor([[ stoi[c] for c in bcontext_str ]],
                           dtype=torch.long, device=gen_device)

max_new_tokens = 5000
temperature     = 1.0
top_k           = 50
block_size      = 128

generated = context_ids  # (1, T)

model.eval()
# Sampling needs no gradients; without no_grad every step would build and
# keep an autograd graph.
with torch.no_grad():
    for _ in range(max_new_tokens):
        # only pass at most block_size tokens into the model
        input_ids = generated[:, -block_size:]

        logits = model(input_ids)              # (1, T_cur, vocab_size)
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            # torch.topk fails when k exceeds the vocabulary size, so clamp it.
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))

        probs   = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (1,1)
        generated = torch.cat([generated, next_id], dim=1)

output_str = ''.join([itos[i] for i in generated[0].tolist()])
print(output_str)


To be, or not to be, t
te, t, t, t t, t, t,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,','E,,E,,,,,,,,,R,,,,,,,,,,,,,,'ZE'''''''''''''''''''''''''''''''''''''''''''E''''''''''''''''''q''''''''''''''''''''''''''''''''''''E''''''Q''''''''''''E'''''''''''''''''''''''''''''''''''''''''E'''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''E''''''''''''''''''''''''''''''''''''''''''''''''''''E''''''''''''''''''E'''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''Q'E'''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''E''''''''''E''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''E''''''''''''''''''''''''''''''''E''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''