# Simple Language Modeling Notebook

---

**Author: Windsor Nguyen '25**

This is a lightweight Python notebook for language modeling experiments.

The goal of this notebook is to:
- Be easy to prototype research ideas with
- Be robust enough to give an accurate gauge of a model architecture's abilities


To get you started, we have code for the [Transformer](https://arxiv.org/abs/1706.03762) and the [Flash STU](https://arxiv.org/abs/2409.10489) model architectures.

> NOTE: It is *highly* recommended that you run everything on a GPU/TPU. This notebook was written with PyTorch/GPUs in mind, so 100% compatibility with other frameworks or TPUs is not guaranteed.

**May divine benevolence be with you and your research ideas!**


# Install required packages

Pip install required Python packages here.

In [None]:
%%capture

!pip install tiktoken

# Import required packages

In [None]:
import glob
import logging
import math
import os
import time

from contextlib import nullcontext
from functools import partial
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import safetensors
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
from transformers import PreTrainedModel, PretrainedConfig


# Logging settings

Adjust logging settings here.


In [None]:
# Configure the root logger: INFO level, each record prefixed with timestamp and level name.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger used by the data-loading code further down in this notebook.
logger = logging.getLogger(__name__)


# Plotting settings

Adjust plotting settings here.

In [None]:
%matplotlib inline

# PyTorch settings

In [None]:
# Set PyTorch's float32 matmul precision mode to "high".
torch.set_float32_matmul_precision("high")

# Seed NumPy and PyTorch from one shared value for reproducibility.
SEED = 1746 # @param {type:"integer"}
np.random.seed(SEED)
torch.manual_seed(SEED)

# Choose the compute device once; later cells reuse `CUDA_AVAILABLE` and `device`.
CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda" if CUDA_AVAILABLE else "cpu")
print(
    f"Connected to {torch.cuda.get_device_name(device)}"
    if CUDA_AVAILABLE
    else "No CUDA devices found, running on CPU."
)


# Google Colab settings

This will point to your Google Drive so that you can access training data, save outputs, etc.


In [None]:
# Mount Google Drive (only available inside Google Colab) under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Root folder for experiment outputs; the training cell creates "log" and
# "checkpoints" sub-directories under it.
save_dir = "/content/drive/MyDrive/stu_exps"  # @param {type:"string"}


# Utility functions


In [None]:
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the power of two adjacent to ``x``.

    Integer inputs are handled with exact bit arithmetic, so the result stays
    correct for values too large to be represented exactly as a float (where
    ``math.log2`` would round to the wrong exponent).

    Args:
        x: Positive number, normally an ``int``.
        round_up: If False (default), return the largest power of two that is
            ``<= x``; if True, return the smallest power of two that is ``>= x``.

    Returns:
        The selected power of two as an ``int``.

    Raises:
        ValueError: If ``x`` is not positive.
    """
    if x <= 0:
        raise ValueError(f"x must be positive, got {x}")

    if int(x) != x:
        # Non-integral input: keep the floating-point path for backward compatibility.
        log = math.log2(x)
        return 1 << (math.ceil(log) if round_up else math.floor(log))

    x = int(x)
    floor_pow = 1 << (x.bit_length() - 1)  # largest power of two <= x
    if not round_up or floor_pow == x:
        return floor_pow
    return floor_pow << 1

def get_hankel(seq_len: int, use_hankel_L: bool = False) -> np.ndarray:
    """Build the ``(seq_len, seq_len)`` Hankel matrix used to derive spectral filters.

    Every entry depends only on ``s = i + j`` with 1-based indices ``i`` and ``j``.

    Args:
        seq_len: Side length of the square matrix.
        use_hankel_L: If False, entries are ``2 / (s**3 - s)``. If True, entries are
            ``((-1)**(s - 2) + 1) * 8 / ((s + 3) * (s - 1) * (s + 1))``, which is
            zero whenever ``s`` is odd.

    Returns:
        The matrix as a float32 NumPy array.
    """
    idx = np.arange(1, seq_len + 1, dtype=np.float32)
    s = idx[:, None] + idx[None, :]  # s[i, j] = i + j (1-based)

    if not use_hankel_L:
        return 2.0 / (s**3 - s)

    parity = (-1.0) ** (s - 2.0) + 1.0  # 2 for even s, 0 for odd s
    cubic = (s + 3.0) * (s - 1.0) * (s + 1.0)
    return parity * (8.0 / cubic)

def get_spectral_filters(
    seq_len: int,
    k: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Compute the top-``k`` spectral filters of the Hankel matrix.

    The filters are the eigenvectors belonging to the ``k`` largest eigenvalues
    of ``get_hankel(seq_len, use_hankel_L)``, each scaled by the fourth root of
    its eigenvalue.

    Args:
        seq_len: Side length of the Hankel matrix (one filter tap per position).
        k: Number of leading eigenpairs to keep.
        use_hankel_L: Forwarded to ``get_hankel`` to select the matrix variant.
        device: Device on which the returned tensor is created.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``(seq_len, k)`` holding one filter per column.
    """
    print("Generating spectral filters...")
    Z_np = get_hankel(seq_len, use_hankel_L)

    # eigh returns eigenvalues in ascending order, so the trailing k are the largest.
    sigma, phi = np.linalg.eigh(Z_np)
    sigma, phi = sigma[-k:], phi[:, -k:]

    # Tiny eigenvalues can come out slightly negative from floating-point
    # round-off; clamp them so the fractional power cannot produce NaNs.
    sigma = np.maximum(sigma, 0.0)
    phi = phi * sigma**0.25

    print("Spectral filters built!")
    return torch.tensor(phi, device=device, dtype=dtype)

def stu_conv(u: torch.Tensor, v: torch.Tensor, n: int, use_tensordot: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Convolve ``u`` with the filters ``v`` along the sequence axis via FFT.

    Two convolutions are computed at once: one on ``u`` itself and one on ``u``
    with the sign of every other time step flipped. The second result is
    multiplied by the same alternating sign pattern before being returned.

    Args:
        u: 3-D input, unpacked as ``(bsz, seq_len, d_in)``.
        v: 2-D filters. Presumably ``(seq_len, d_in)`` when ``use_tensordot`` is
            True and ``(seq_len, K)`` otherwise; confirm against the caller.
        n: FFT length. It should be at least ``2 * seq_len - 1`` so the circular
            convolution does not wrap around.
        use_tensordot: Selects the layout. True keeps ``u`` 3-D; False broadcasts
            ``u`` across the ``K`` filters, giving 4-D outputs.

    Returns:
        ``(U_plus, U_minus)`` in float32, truncated to the first ``seq_len`` steps.
    """
    batch, length, channels = u.shape

    # Alternating +1 / -1 pattern along the sequence axis (integer dtype).
    sgn = torch.full((1, length, 1), 1, device=u.device)
    sgn[:, 1::2] *= -1

    _, width = v.shape
    if not use_tensordot:
        # Give every one of the `width` filters its own copy of the input channels.
        sgn = sgn.unsqueeze(-1)
        v = v.view(1, -1, width, 1, 1).to(torch.float32)
        u = u.view(batch, -1, 1, channels).expand(batch, -1, width, channels)
    else:
        v = v.view(1, -1, width, 1).to(torch.float32)

    # Stack the plain and sign-flipped inputs on a new trailing axis so that a
    # single FFT round trip handles both convolutions.
    stacked = torch.stack((u, u * sgn), dim=-1).to(torch.float32)
    v_freq = torch.fft.rfft(v, n=n, dim=1)
    u_freq = torch.fft.rfft(stacked, n=n, dim=1)
    convolved = torch.fft.irfft(v_freq * u_freq, n=n, dim=1)[:, :length]

    U_plus, U_minus = convolved.unbind(dim=-1)
    return U_plus, U_minus * sgn

def linear_decay_with_warmup( # https://arxiv.org/pdf/2310.07831
    current_step: int,
    warmup_steps: int,
    num_steps: int,
    max_lr: float = 3e-4,
    min_lr: float = 3e-5,
) -> float:
    """Learning-rate schedule: linear warmup followed by linear decay.

    The rate rises linearly from ``min_lr`` to ``max_lr`` over the first
    ``warmup_steps`` steps, then falls linearly back to ``min_lr`` at
    ``num_steps``. Steps outside ``[0, num_steps]`` are clamped to that range,
    so the rate never drops below ``min_lr`` if training runs past ``num_steps``.

    Args:
        current_step: Step index to evaluate the schedule at.
        warmup_steps: Number of warmup steps.
        num_steps: Step at which the decay reaches ``min_lr``.
        max_lr: Peak learning rate, reached at the end of warmup.
        min_lr: Learning rate at step 0 and at/after ``num_steps``.

    Returns:
        The learning rate for ``current_step``.
    """
    # Clamp so that overshooting num_steps cannot push the rate below min_lr.
    step = min(max(current_step, 0), num_steps)

    if step < warmup_steps:
        return min_lr + (max_lr - min_lr) * float(step) / float(max(warmup_steps, 1))
    return max_lr - (max_lr - min_lr) * float(step - warmup_steps) / float(max(num_steps - warmup_steps, 1))


# MLP

SwiGLU variant. See more [here](https://arxiv.org/abs/2002.05202).

In [None]:
class MLP(nn.Module):
    """Gated feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Input and output feature size.
        inter_dim: Hidden feature size of the two parallel branches.
        dtype: Parameter dtype for all three bias-free linear layers.
    """

    def __init__(self, dim: int, inter_dim: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        linear_opts = {"dtype": dtype, "bias": False}
        self.w1 = nn.Linear(dim, inter_dim, **linear_opts)  # branch passed through SiLU
        self.w2 = nn.Linear(inter_dim, dim, **linear_opts)  # projection back to `dim`
        self.w3 = nn.Linear(dim, inter_dim, **linear_opts)  # branch multiplied in as-is
        # Marker attribute checked (via hasattr) by the models' _init_weights.
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to the last dimension of ``x``."""
        gate = F.silu(self.w1(x))
        hidden = gate * self.w3(x)
        return self.w2(hidden)


# Rotary Positional Embeddings

See more [here](https://arxiv.org/abs/2104.09864).


In [None]:
class RoPE(nn.Module):
    """Rotary positional embedding backed by a precomputed cos/sin table.

    Args:
        head_dim: Size of each attention head; consecutive feature pairs are rotated.
        max_seq_len: Number of positions precomputed in the table.
        base: Base of the geometric progression of rotation frequencies.
    """

    def __init__(self, head_dim: int, max_seq_len: int = 4096, base: int = 10000):
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        self.max_seq_len = max_seq_len

        # One rotation frequency per feature pair.
        half = self.head_dim // 2
        exponents = torch.arange(0, self.head_dim, 2)[:half].float() / self.head_dim
        theta = 1.0 / (self.base ** exponents)
        self.register_buffer("theta", theta, persistent=False)

        # Table of shape (max_seq_len, head_dim // 2, 2): [..., 0] is cos, [..., 1] is sin.
        positions = torch.arange(self.max_seq_len, dtype=theta.dtype, device=theta.device)
        angles = positions[:, None] * theta[None, :]
        table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
        self.register_buffer("cache", table, persistent=False)

    def forward(self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Rotate consecutive feature pairs of ``x`` by position-dependent angles.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, num_heads, head_dim].
            input_pos (Optional[torch.Tensor]): Indices into the table; when omitted,
                positions ``0..seq_len-1`` are used.

        Returns:
            torch.Tensor: Rotated tensor with the same shape and dtype as ``x``.
        """
        if input_pos is None:
            table = self.cache[: x.size(1)]
        else:
            table = self.cache[input_pos]

        # Compute in float32 with the last dimension split into (pair, 2).
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        # Insert a singleton head axis so the table broadcasts over heads.
        table = table.view(-1, pairs.size(1), 1, pairs.size(3), 2)

        cos, sin = table[..., 0], table[..., 1]
        first, second = pairs[..., 0], pairs[..., 1]
        rotated = torch.stack(
            [first * cos - second * sin, second * cos + first * sin],
            -1,
        )
        return rotated.flatten(3).type_as(x)


# Attention

Standard Multi-head Attention with RoPE.

In [None]:
class Attention(nn.Module):
    """Standard Multi-head Attention with Rotary Positional Embeddings

    All hyperparameters are read from ``config`` (``dim``, ``num_heads``,
    ``seq_len``, ``rope_theta``), never from notebook-level globals, so the
    module is built correctly whatever values those globals currently hold.
    """
    def __init__(self, config):
        super().__init__()
        assert config.dim % config.num_heads == 0, (
            f"dim must be divisible by num_heads, got dim={config.dim} and num_heads={config.num_heads}"
        )
        self.dim = config.dim
        self.num_heads = config.num_heads
        self.head_dim = config.dim // config.num_heads
        self.seq_len = config.seq_len
        self.rope_theta = config.rope_theta

        # Rotary positional embeddings
        self.rope = RoPE(self.head_dim, self.seq_len, self.rope_theta)

        # Learned projections
        # NOTE(review): these keep nn.Linear's defaults (bias=True, default dtype)
        # rather than config.bias / config.torch_dtype. Left as-is so the
        # parameter set and state_dict keys do not change.
        self.wq = nn.Linear(self.dim, self.dim)
        self.wk = nn.Linear(self.dim, self.dim)
        self.wv = nn.Linear(self.dim, self.dim)
        self.wo = nn.Linear(self.dim, self.dim)
        # Marker attribute checked (via hasattr) by the models' _init_weights.
        self.wo.SCALE_INIT = 1

        # Register causal mask as a buffer (lower-triangular ones, seq_len x seq_len)
        causal_mask = torch.tril(torch.ones(self.seq_len, self.seq_len))
        self.register_buffer("causal_mask", causal_mask)

    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        """Apply causal multi-head self-attention.

        Args:
            x: Input of shape ``(batch_size, seq_len, dim)``.
            training: Unused; kept so existing callers keep working.

        Returns:
            Tensor of shape ``(batch_size, seq_len, dim)``.
        """
        batch_size, seq_len, dim = x.shape

        # Linear projections
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # Split into multiple attention heads
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply positional embeddings to queries and keys, then move heads ahead of sequence
        q, k = self.rope(q).transpose(1, 2), self.rope(k).transpose(1, 2)

        # Compute scaled similarity scores
        scale = self.head_dim ** 0.5
        scores = torch.matmul(q, k.transpose(-1, -2)) / scale

        # Apply causal mask (sliced so inputs shorter than config.seq_len also work)
        causal_mask = self.causal_mask[:seq_len, :seq_len]  # For inference
        scores = scores.masked_fill(causal_mask == 0, float("-inf"))

        # Apply softmax
        attn_weights = torch.softmax(scores, dim=-1)

        # Apply attention weights
        ctxt = torch.matmul(attn_weights, v)

        # Concatenate attention heads back together
        ctxt = ctxt.transpose(1, 2)
        ctxt = ctxt.contiguous().view(batch_size, seq_len, -1)

        # Output projection
        out = self.wo(ctxt)
        return out

class AttentionLayer(nn.Module):
    """Pre-norm residual block: an attention sub-layer followed by an MLP sub-layer."""

    def __init__(self, config) -> None:
        super().__init__()
        width, dtype = config.dim, config.torch_dtype
        self.attn_norm = nn.LayerNorm(width, dtype=dtype)
        self.attn = Attention(config)
        self.mlp_norm = nn.LayerNorm(width, dtype=dtype)
        self.mlp = MLP(width, config.inter_dim, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the attention output, then the MLP output, to the residual stream."""
        attn_out = self.attn(self.attn_norm(x))
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out


# Transformer

Classic dense Transformer


In [None]:
# Transformer configurations

class TransformerConfig(PretrainedConfig):
    """Hyperparameter container for the ``Transformer`` model.

    Args:
        bsz: Batch size; the demo cell uses it to shape the dummy input.
        dim: Model (embedding) width.
        num_heads: Number of attention heads; ``Attention`` asserts that ``dim``
            is divisible by it.
        num_layers: Number of ``AttentionLayer`` blocks.
        seq_len: Sequence length; ``Attention`` uses it to size the RoPE table
            and the causal mask.
        vocab_size: Size of the token vocabulary.
        mlp_scale: Multiplier used to derive ``inter_dim = dim * mlp_scale``.
        weight_tying: If True, the model shares the embedding and LM-head weights.
        bias: Whether the LM head has a bias term.
        rope_theta: Stored for the attention module's rotary embedding.
        torch_dtype: Dtype passed to the LayerNorm and MLP layers of ``AttentionLayer``.
        device: Stored as-is; not read by the model code visible in this
            notebook section (TODO confirm).
        **kwargs: Forwarded to ``PretrainedConfig``.
    """
    model_type = "transformer"

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 896,
        num_heads: int = 16,
        num_layers: int = 12,
        seq_len: int = 8192,
        vocab_size: int = 200064,
        mlp_scale: int = 12,
        weight_tying: bool = True,
        bias: bool = False,
        rope_theta: float = 10000.0,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Attributes are assigned in a fixed order; the demo cell prints them
        # via vars(config), so this order is visible in its output.
        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.mlp_scale = mlp_scale
        # Hidden width of the MLP, derived from dim and mlp_scale.
        self.inter_dim = self.dim * self.mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.torch_dtype = torch_dtype
        self.device = device


In [None]:
# Transformer model architecture

class Transformer(PreTrainedModel):
    """Decoder-only Transformer language model.

    Token embedding -> ``num_layers`` x ``AttentionLayer`` -> final LayerNorm ->
    linear LM head. When ``config.weight_tying`` is set, the embedding and the
    LM head share one weight tensor.
    """

    def __init__(self, config):
        super().__init__(config)
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.num_layers = config.num_layers

        self.layers = nn.ModuleList([
            AttentionLayer(config) for _ in range(self.num_layers)
        ])

        self.norm_f = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)

        if config.weight_tying:
            self.tok_emb.weight = self.lm_head.weight

        # Base standard deviation for Linear weights; never mutated after this point.
        self.std = config.dim ** -0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``."""
        x = self.tok_emb(x)

        for layer in self.layers:
            x = layer(x)

        return self.lm_head(self.norm_f(x))

    def _init_weights(self, module):
        """Initialise one sub-module; called for every sub-module via ``self.apply``.

        Linear weights are drawn from ``N(0, std**2)`` with ``std = dim ** -0.5``.
        Layers tagged with ``SCALE_INIT`` additionally scale ``std`` by
        ``(2 * num_layers) ** -0.5``. The scaling is applied to a local copy, so
        it does not compound across modules, and each Linear is initialised
        exactly once (the same scheme as ``FlashSTU._init_weights``).
        """
        if isinstance(module, nn.Linear):
            std = self.std
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _get_num_params(self):
        """Return the parameter count, excluding ``pos_emb`` if the model has one."""
        n_params = sum(p.numel() for p in self.parameters())
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
        return n_params

# Smoke test: build a Transformer from the form values below and run one
# forward pass on random token ids.
bsz = 1  # @param {type:"integer"}
dim = 768 # @param {type:"integer"}
num_heads = 12 # @param {type:"integer"}
num_layers = 8 # @param {type:"integer"}
seq_len = 1024 # @param {type:"integer"}
vocab_size = 200064 # @param {type:"integer"}
mlp_scale = 12 # @param {type:"integer"}
weight_tying = True # @param {type:"boolean"}
bias = False # @param {type:"boolean"}
rope_theta = 10000.0 # @param {type:"number"}

# Must match an attribute in torch (e.g., "bfloat16", "float32")
torch_dtype = "float32"  # @param {type:"string"}
# Resolve the dtype name to the actual torch.dtype object.
torch_dtype = getattr(torch, torch_dtype)
print("Torch dtype:", torch_dtype)

device_str = "cuda"  # @param {type:"string"}
# Fall back to CPU when CUDA was requested but is not available.
if not CUDA_AVAILABLE and device_str == "cuda":
    print("No CUDA devices detected but CUDA requested, setting device to CPU...")
    device_str = "cpu"
device = torch.device(device_str)

config = TransformerConfig(
    bsz=bsz,
    dim=dim,
    num_heads=num_heads,
    num_layers=num_layers,
    seq_len=seq_len,
    vocab_size=vocab_size,
    mlp_scale=mlp_scale,
    weight_tying=weight_tying,
    bias=bias,
    rope_theta=rope_theta,
    torch_dtype=torch_dtype,
    device=torch.device(device_str),
)

# Dump every attribute stored on the config object.
print("\nConfigs:")
for key, value in vars(config).items():
    print(f"  {key}: {value}")

# Build the model, move/cast it, and push one random batch of token ids through it.
model = Transformer(config).to(device=device, dtype=torch_dtype)
x = torch.randint(0, config.vocab_size, (config.bsz, config.seq_len), device=device)
outputs = model(x)

print("Output shape:", outputs.shape)
print("Sample output:", outputs[0, 0, :10])


# Spectral Transform Unit (STU)

See more [here](https://arxiv.org/abs/2312.06837).

In [None]:
class STU(nn.Module):
    """Spectral Transform Unit: convolves the input with fixed spectral filters.

    Two parameterisations are supported, selected by ``config.use_tensordot``:

    * True: the input and the filters are first projected to rank ``r`` with
      ``M_inputs`` ``(dim, r)`` and ``M_filters`` ``(K, r)``, convolved, and
      mapped back to ``dim`` by ``out_proj``.
    * False: the input is convolved with all ``K`` filters and contracted with
      the full ``M_phi_plus`` (and, unless ``use_hankel_L``, ``M_phi_minus``)
      tensors of shape ``(K, dim, dim)``.

    Parameters are allocated uninitialised with ``torch.empty``; the enclosing
    model is expected to initialise them (``FlashSTU._init_weights`` does).

    Args:
        config: Object providing ``dim``, ``seq_len``, ``num_eigh``, ``r``,
            ``use_hankel_L``, ``use_tensordot``, ``torch_dtype`` and ``bias``.
        filters: Spectral filter tensor, used as ``(seq_len, K)``.
    """

    def __init__(self, config, filters) -> None:
        super().__init__()
        self.config = config
        self.dim = config.dim
        # Registered as a non-persistent buffer: the filters follow the module
        # through .to()/.cuda(), but are not written to state_dict, so
        # checkpoint keys stay the same as when this was a plain attribute.
        self.register_buffer("stu_filters", filters, persistent=False)
        # FFT length: next power of two that fits a full linear convolution.
        self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
        self.K = config.num_eigh
        self.r = config.r
        self.use_hankel_L = config.use_hankel_L
        self.use_tensordot = config.use_tensordot

        if self.use_tensordot:
            # Projection matrices
            self.M_inputs = nn.Parameter(torch.empty(self.dim, self.r, dtype=config.torch_dtype))
            self.M_filters = nn.Parameter(torch.empty(self.K, self.r, dtype=config.torch_dtype))
            self.out_proj = nn.Linear(self.r, self.dim, bias=config.bias)
        else:
            # Full M matrix
            self.M_phi_plus = nn.Parameter(torch.empty(self.K, self.dim, self.dim, dtype=config.torch_dtype))

            # If not using Hankel_L, we compute the negative featurization separately
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(torch.empty(self.K, self.dim, self.dim, dtype=config.torch_dtype))

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        """Apply the spectral transform to ``u`` of shape ``(B, L, D)``; returns ``(B, L, D)``."""
        if self.use_tensordot:
            # Project first
            u_proj = u @ self.M_inputs                     # (B, L, D) x (D, r) -> (B, L, r)
            phi_proj = self.stu_filters @ self.M_filters   # (L, K) x (K, r) -> (L, r)

            # Then, convolve: (B, L, r) ⊗ (L, r) -> (B, L, r)
            spectral_plus, spectral_minus = stu_conv(u_proj, phi_proj, self.n, self.use_tensordot)
        else:
            # Convolve first to get featurized inputs: (B, L, D) x (L, K) -> (B, L, K, D)
            U_plus, U_minus = stu_conv(u, self.stu_filters, self.n, self.use_tensordot)

            # Compute sum-product of featurized inputs and M matrices over the K filters
            B, L, K, D = U_plus.shape

            # Spectral output: (B, L, K * D) x (K * D, D) -> (B, L, D)
            spectral_plus = U_plus.view(B, L, K * self.dim) @ self.M_phi_plus.view(K * self.dim, self.dim)

            if not self.use_hankel_L:
                spectral_minus = U_minus.view(B, L, K * self.dim) @ self.M_phi_minus.view(K * self.dim, self.dim)

        # With Hankel_L only the "plus" branch contributes to the output.
        out = spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
        out = self.out_proj(out) if self.use_tensordot else out
        return out

class STULayer(nn.Module):
    """Pre-norm residual block: an STU sub-layer followed by an MLP sub-layer."""

    def __init__(self, config, stu_filters):
        super().__init__()
        width = config.dim
        self.stu_norm = nn.LayerNorm(width)
        self.stu = STU(config, stu_filters)
        self.mlp_norm = nn.LayerNorm(width)
        self.mlp = MLP(width, config.inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the STU output, then the MLP output, to the residual stream."""
        stu_out = self.stu(self.stu_norm(x))
        x = x + stu_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out


# Flash STU

Model definition for the Flash STU architecture. See more [here](https://arxiv.org/abs/2409.10489).

In [None]:
class FlashSTUConfig(PretrainedConfig):
    """Hyperparameter container for the ``FlashSTU`` model.

    Args:
        bsz: Batch size; the demo cell uses it to shape the dummy input.
        dim: Model (embedding) width.
        num_heads: Number of attention heads; ``dim`` must be divisible by it.
        num_layers: Total number of layers (STU and attention combined).
        seq_len: Sequence length; sizes the STU FFT length and the attention
            RoPE table and causal mask.
        window_size: Stored as-is; not read by the model code visible in this
            notebook section (presumably a sliding-window size, TODO confirm).
        vocab_size: Size of the token vocabulary.
        mlp_scale: Multiplier used to derive ``inter_dim = dim * mlp_scale``.
        weight_tying: If True, the model shares the embedding and LM-head weights.
        bias: Whether the LM head and the STU output projection have a bias term.
        num_eigh: Number of spectral filters ``K`` used by each STU.
        r: Projection rank used by the STU when ``use_tensordot`` is True.
        use_hankel_L: Selects the Hankel-L variant; the STU then uses only the
            "plus" branch.
        use_tensordot: Selects the low-rank (project-then-convolve) STU variant.
        use_attn: If True, odd-indexed layers are attention layers; otherwise
            every layer is an STU layer.
        rope_theta: Stored for the attention module's rotary embedding.
        torch_dtype: Dtype used for the STU parameters and the attention layers'
            LayerNorm/MLP.
        device: Stored as-is; not read by the model code visible in this
            notebook section (TODO confirm).
        **kwargs: Forwarded to ``PretrainedConfig``.
    """
    model_type = "flash_stu"

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 8,
        seq_len: int = 8192,
        window_size: int = 1024,
        vocab_size: int = 200064,
        mlp_scale: int = 12,
        weight_tying: bool = True,
        bias: bool = False,
        num_eigh: int = 24,
        r: int = 8,
        use_hankel_L: bool = False,
        use_tensordot: bool = True,
        use_attn: bool = True,
        rope_theta: float = 10000.0,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Attributes are assigned in a fixed order; the demo cell prints them
        # via vars(config), so this order is visible in its output.
        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.mlp_scale = mlp_scale
        # Hidden width of the MLP, derived from dim and mlp_scale.
        self.inter_dim = self.dim * self.mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.num_eigh = num_eigh
        self.r = r
        self.use_hankel_L = use_hankel_L
        self.use_tensordot = use_tensordot
        self.use_attn = use_attn
        self.rope_theta = rope_theta
        self.torch_dtype = torch_dtype
        self.device = device


In [None]:
class FlashSTU(PreTrainedModel):
    """Hybrid language model that interleaves STU layers with attention layers.

    Even-indexed layers are ``STULayer``; odd-indexed layers are
    ``AttentionLayer`` when ``config.use_attn`` is True and ``STULayer``
    otherwise. The stack sits between a token embedding and a final
    LayerNorm + linear LM head (optionally weight-tied to the embedding).

    Args:
        config: A ``FlashSTUConfig``.
        filters: Spectral filters shared by every STU layer.
    """
    config_class = FlashSTUConfig

    def __init__(self, config, filters) -> None:
        super().__init__(config)
        assert config.dim % config.num_heads == 0, (
            f"dim ({config.dim}) must be divisible by num_heads ({config.num_heads})"
        )
        self.head_dim = config.dim // config.num_heads

        self.use_tensordot = config.use_tensordot
        self.use_hankel_L = config.use_hankel_L

        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)

        self.num_layers = config.num_layers
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_layers):
            # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
            if layer_idx % 2 == 0:
                self.layers.append(STULayer(config, filters))
            else:
                self.layers.append(AttentionLayer(config) if config.use_attn else STULayer(config, filters))

        self.norm_f = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)

        if config.weight_tying:
            self.tok_emb.weight = self.lm_head.weight

        # Base standard deviation for Linear weights; never mutated after this point.
        self.std = config.dim ** -0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``."""
        x = self.tok_emb(x)

        for layer in self.layers:
            x = layer(x)

        out = self.lm_head(self.norm_f(x))
        return out

    def _init_weights(self, module):
        """Initialise one sub-module; called for every sub-module via ``self.apply``.

        Linear weights are drawn from ``N(0, std**2)`` with ``std = dim ** -0.5``.
        Layers tagged with ``SCALE_INIT`` additionally scale ``std`` by
        ``(2 * num_layers) ** -0.5``. The scaling is applied to a local copy:
        mutating ``self.std`` in place would compound the factor on every tagged
        layer and drive later layers (and the LM head) towards zero. STU
        projection tensors get Xavier-normal initialisation.
        """
        if isinstance(module, nn.Linear):
            std = self.std
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, STU):
            if self.use_tensordot:
                torch.nn.init.xavier_normal_(module.M_inputs)
                torch.nn.init.xavier_normal_(module.M_filters)
            else:
                torch.nn.init.xavier_normal_(module.M_phi_plus)
                if not self.use_hankel_L:
                    torch.nn.init.xavier_normal_(module.M_phi_minus)

    def _get_num_params(self):
        """Return the total number of parameters (tied weights are counted once)."""
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

# Smoke test: build a FlashSTU model from the form values below and run one
# forward pass on random token ids.
bsz = 1  # @param {type:"integer"}
dim = 768 # @param {type:"integer"}
num_heads = 12 # @param {type:"integer"}
num_layers = 8 # @param {type:"integer"}
seq_len = 1024 # @param {type:"integer"}
window_size = 128 # @param {type:"integer"}
vocab_size = 200064 # @param {type:"integer"}
mlp_scale = 4 # @param {type:"integer"}
weight_tying = True # @param {type:"boolean"}
bias = False # @param {type:"boolean"}
num_eigh = 24 # @param {type:"integer"}
r = 64 # @param {type:"integer"}
use_hankel_L = False # @param {type:"boolean"}
use_tensordot = True # @param {type:"boolean"}
use_attn = True #@param {type:"boolean"}
rope_theta = 10000.0 # @param {type:"number"}

# Must match an attribute in torch (e.g., "bfloat16", "float32")
torch_dtype = "float32"  # @param {type:"string"}
# Resolve the dtype name to the actual torch.dtype object.
torch_dtype = getattr(torch, torch_dtype)
print("Torch dtype:", torch_dtype)

device_str = "cuda"  # @param {type:"string"}
# Fall back to CPU when CUDA was requested but is not available.
if not CUDA_AVAILABLE and device_str == "cuda":
    print("No CUDA devices detected but CUDA requested, setting device to CPU...")
    device_str = "cpu"
device = torch.device(device_str)

config = FlashSTUConfig(
    bsz=bsz,
    dim=dim,
    num_heads=num_heads,
    num_layers=num_layers,
    seq_len=seq_len,
    window_size=window_size,
    vocab_size=vocab_size,
    mlp_scale=mlp_scale,
    weight_tying=weight_tying,
    bias=bias,
    num_eigh=num_eigh,
    r=r,
    use_hankel_L=use_hankel_L,
    use_tensordot=use_tensordot,
    use_attn=use_attn,
    rope_theta=rope_theta,
    torch_dtype=torch_dtype,
    device=torch.device(device_str),
)

# Spectral filters are computed once here and handed to the model constructor.
filters = get_spectral_filters(
    seq_len=seq_len,
    k=num_eigh,
    use_hankel_L=use_hankel_L,
    device=torch.device(device_str),
)

# Dump every attribute stored on the config object.
print("\nConfigs:")
for key, value in vars(config).items():
    print(f"  {key}: {value}")

# Build the model, move/cast it, and push one random batch of token ids through it.
model = FlashSTU(config, filters).to(device=device, dtype=torch_dtype)
x = torch.randint(0, config.vocab_size, (config.bsz, config.seq_len), device=device)
outputs = model(x)

print("Output shape:", outputs.shape)
print("Sample output:", outputs[0, 0, :10])


# Dataloader
Simple dataloader for next-token prediction training.


In [None]:
def load_tokens(filename):
    """Load one token shard from disk as a tensor.

    ``.npy`` files are read with NumPy, cast to int32 and returned as a
    ``torch.long`` tensor. ``.pt`` files are returned exactly as ``torch.load``
    (with ``weights_only=True``) produces them.

    Args:
        filename: Path to a ``.npy`` or ``.pt`` shard.

    Returns:
        The loaded tensor.

    Raises:
        ValueError: If the extension is neither ``.npy`` nor ``.pt``.
        Exception: Any loading error is logged and then re-raised unchanged.
    """
    try:
        suffix = os.path.splitext(filename)[1]

        if suffix == ".pt":
            return torch.load(filename, weights_only=True)
        if suffix != ".npy":
            raise ValueError(f"Unsupported file extension: {suffix}")

        token_array = np.load(filename).astype(np.int32)
        return torch.tensor(token_array, dtype=torch.long)

    except Exception as e:
        logger.error(f"Error loading file {filename}: {str(e)}")
        raise

class Dataloader:
    """Endless iterator over token shards yielding ``(x, y)`` next-token batches.

    Shards are the ``.pt`` / ``.npy`` files in ``dataset`` whose names contain
    ``split``. Each batch is a contiguous run of ``bsz * seq_len + 1`` tokens;
    ``x`` is the run without its last token and ``y`` is the run shifted by one.
    Each rank starts at its own offset and advances by ``world_size`` batches.

    Args:
        bsz: Sequences per batch.
        seq_len: Tokens per sequence.
        rank: Index of this process among ``world_size`` processes.
        world_size: Total number of processes sharing the data.
        dataset: Directory containing the shard files.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        main_process: If True, log how many shards were found.
    """

    def __init__(
        self,
        bsz: int,
        seq_len: int,
        rank: int,
        world_size: int,
        dataset: str,
        split: str,
        main_process: bool = False,
    ):
        self.bsz = bsz
        self.seq_len = seq_len
        self.rank = rank
        self.world_size = world_size
        assert split in {'train', 'val', 'test'}, f"Invalid split: {split}"

        data_root = dataset
        shard_names = sorted(
            name
            for name in os.listdir(data_root)
            if split in name and name.endswith(('.pt', '.npy'))
        )
        self.shards = [os.path.join(data_root, name) for name in shard_names]
        assert len(self.shards) > 0, f'No shards found for split {split}'
        if main_process:
            logger.info(f'Found {len(self.shards)} shards for split {split}')

        # Until set_epoch() is called, shards are visited in sorted order.
        self.shard_order = list(range(len(self.shards)))
        self.shard_order_idx = 0
        self._load_active_shard()

    def _load_active_shard(self):
        """Load the shard selected by ``shard_order_idx`` and rewind to this rank's offset."""
        shard_path = self.shards[self.shard_order[self.shard_order_idx]]
        self.tokens = load_tokens(shard_path)
        self.current_position = (self.bsz * self.seq_len * self.rank) % len(self.tokens)

    def reset(self):
        """Go back to the first shard of the current shard order."""
        self.shard_order_idx = 0
        self._load_active_shard()

    def set_epoch(self, epoch):
        """Reshuffle the shard order with a generator seeded by ``epoch``."""
        self.generator = torch.Generator()
        self.generator.manual_seed(epoch)
        permutation = torch.randperm(len(self.shards), generator=self.generator)
        self.shard_order = permutation.tolist()
        self.shard_order_idx = self.rank % len(self.shard_order)
        self._load_active_shard()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(x, y)`` pair, moving to the next shard when the current one runs out."""
        span = self.bsz * self.seq_len

        if self.current_position + (span + 1) > len(self.tokens):
            # Not enough tokens left: advance (cyclically) to the next shard in the order
            self.shard_order_idx = (self.shard_order_idx + 1) % len(self.shard_order)
            self._load_active_shard()

        start = self.current_position
        buf = self.tokens[start : start + span + 1]
        x = buf[:-1].view(self.bsz, self.seq_len)
        y = buf[1:].view(self.bsz, self.seq_len)

        # Skip over the batches consumed by the other ranks.
        self.current_position += span * self.world_size
        return x, y.to(torch.long)


# Training Loop

Basic standalone training loop setup.

In [None]:
# -----------------------
# Distributed setup

rank = 0        # This should equal 0 if NOT doing distributed training
world_size = 1  # This should equal 1 if NOT doing distributed training

# -----------------------
# Model configurations

dim = 128                # @param {type:"integer"}
num_heads = 8            # @param {type:"integer"}
num_layers = 4           # @param {type:"integer"}
seq_len = 1024           # @param {type:"integer"}
window_size = 128        # @param {type:"integer"}
vocab_size = 200064      # This depends on what tokenizer was used on dataset
mlp_scale = 4            # @param {type:"integer"}
weight_tying = True      # @param {type:"boolean"}
rope_theta = 10000.0     # @param {type:"number"}
torch_dtype = "float32"  # @param {type:"string"}
bias = False             # @param {type:"boolean"}
num_eigh = 24            # @param {type:"integer"}
r = 64                   # @param {type:"integer"}
use_hankel_L = False     # @param {type:"boolean"}
use_tensordot = False    # @param {type:"boolean"}
use_attn = True          # @param {type:"boolean"}

torch_dtype = getattr(torch, torch_dtype)
print("Training with datatype:", torch_dtype)

# -----------------------
# Training configurations

num_epochs = 1           # @param {type:"integer"}
global_bsz = 524288      # @param {type:"integer"}
micro_bsz = 1            # @param {type:"integer"}

os.makedirs(save_dir, exist_ok=True)
log_dir = os.path.join(save_dir, "log")
checkpoint_dir = os.path.join(save_dir, "checkpoints")
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

device_str = "cuda" if CUDA_AVAILABLE else "cpu"
device = torch.device(device_str)

# -----------------------
# Optimizations
torch_compile = False    # @param {type:"boolean"}

# -----------------------
# Dataset configurations

dataset = "/content/drive/MyDrive/fineweb-edu-10B"  # @param {type:"string"}
os.makedirs(dataset, exist_ok=True)
total_tokens = 10_000_000_000                       # @param {type:"integer"}

assert (
    global_bsz % (micro_bsz * seq_len * world_size) == 0
), f"global_bsz ({global_bsz}) must be divisible by micro_bsz * seq_len * world_size ({micro_bsz * seq_len * world_size}),"
grad_accum_steps = global_bsz // (micro_bsz * seq_len * world_size)

# -----------------------
# Compute derived parameters

num_steps = total_tokens // global_bsz
max_steps = num_steps * num_epochs

# -----------------------
# Training configurations

optimizer_name = "Adagrad"  # @param {type:"string"}
eval_period = 50            # @param {type:"integer"}
save_period = 1000          # @param {type:"integer"}
max_lr = 3.0e-4             # @param {type:"number"}
min_lr = 3.0e-5             # @param {type:"number"}
max_norm = 1.0              # @param {type:"number"}
warmup_steps = 1907         # @param {type:"integer"}

optimizer_cls = getattr(torch.optim, optimizer_name, None)
if optimizer_cls is None:
    raise ValueError(f"Optimizer {optimizer_cls} is not available in torch.optim")

print(f"Total (desired) batch size: {global_bsz}")
print(f"=> Number of gradient accumulation steps: {grad_accum_steps}")
print(f"\nTraining for {max_steps} steps")

# -----------------------
# Model configuration

config = FlashSTUConfig(
    bsz=micro_bsz,
    dim=dim,
    num_heads=num_heads,
    num_layers=num_layers,
    seq_len=seq_len,
    window_size=window_size,
    vocab_size=vocab_size,
    mlp_scale=mlp_scale,
    weight_tying=weight_tying,
    bias=bias,
    num_eigh=num_eigh,
    r=r,
    use_hankel_L=use_hankel_L,
    use_tensordot=use_tensordot,
    use_attn=use_attn,
    rope_theta=rope_theta,
    torch_dtype=torch_dtype,
    device=torch.device(device_str),
)

phi = get_spectral_filters(seq_len, num_eigh, use_hankel_L, device)
model = FlashSTU(config, phi)
if torch_compile:
    if torch.cuda.is_available():
        torch.compile(model)
        print(f"Successfully torch.compiled the {model.__class__.__name__} model.")
    else:
        print("Warning: Torch compiler enabled but no CUDA devices detected.")
model = model.to(device)
print(model.eval())

optimizer = optimizer_cls(model.parameters(), lr=max_lr)
loss_fn = torch.nn.CrossEntropyLoss()

# -----------------------
# Checkpointing

# Look for the latest checkpoint file in the checkpoint directory
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.pt"))
if checkpoint_files:
    latest_checkpoint = max(checkpoint_files, key=os.path.getmtime)
    checkpoint = torch.load(latest_checkpoint)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    start_step = checkpoint["step"]
    best_val_loss = checkpoint["val_loss"]
    # Use the timestamp from the checkpoint for consistency
    timestamp = checkpoint.get("timestamp", time.strftime("%Y%m%d-%H%M%S"))
    log_mode = "a"
    print(f"Resumed training from checkpoint '{latest_checkpoint}' at step {start_step}, best validation loss: {best_val_loss:.6f}, timestamp: {timestamp}")
else:
    start_step = 0
    best_val_loss = float("inf")
    log_mode = "w"
    timestamp = time.strftime("%Y%m%d-%H%M%S")

timestamp = time.strftime("%Y%m%d-%H%M%S")
log_file = os.path.join(log_dir, f"log_{timestamp}.txt")

with open(log_file, log_mode) as f:
    pass

# -----------------------
# Dataloader setup

# Build the training and validation loaders with identical settings; only the split differs.
train_loader, val_loader = (
    Dataloader(
        bsz=micro_bsz,
        seq_len=seq_len,
        rank=rank,
        world_size=world_size,
        dataset=dataset,
        split=split_name,
    )
    for split_name in ("train", "val")
)

# -----------------------
# Training loop

# Main optimization loop: one iteration == one optimizer step made of
# `grad_accum_steps` micro-batches, with periodic validation and checkpointing.
for step in range(start_step + 1, max_steps + 1):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Start epoch
    epoch = step // num_steps
    # True on the final step of an epoch (not only of the whole run).
    last_step = step % num_steps == 0
    if step == 1 or step % num_steps == 1:
        print(f"Starting epoch {epoch + 1}...")
        train_loader.set_epoch(epoch)

    # Evaluate on validation set every once in a while
    if step == 1 or step % eval_period == 0 or last_step:
        print(f"Evaluating the model at step {step}...")
        val_steps = 20  # Arbitrarily set to reduce long evaluations
        model.eval()
        val_loader.reset()

        total_val_loss = 0.0
        num_val_batches = 0
        with torch.no_grad():
            # range() comes first so zip stops without pulling an extra batch from the loader.
            for _, (val_x, val_y) in zip(range(val_steps), val_loader, strict=False):
                val_x, val_y = val_x.to(device), val_y.to(device)
                val_preds = model(val_x)
                loss = loss_fn(val_preds.flatten(0, 1), val_y.flatten(0, 1))
                total_val_loss += loss.detach().float()
                num_val_batches += 1

        if num_val_batches == 0:
            raise RuntimeError(
                "Validation loader yielded no batches; cannot compute a validation loss."
            )
        # Average over the batches actually seen (the loader may hold fewer than
        # val_steps) and convert to a plain float so the value stored in the
        # checkpoint is device-independent.
        avg_val_loss = (total_val_loss / num_val_batches).item()

        with open(log_file, "a") as f:
            f.write(f"{step} val {avg_val_loss}\n")

        # Only checkpoint on save steps (or epoch ends), and only when validation improved.
        if (step % save_period == 0 or last_step) and avg_val_loss < best_val_loss:
            print(f"Validation loss improved from {best_val_loss:.6f} to {avg_val_loss:.6f}!")
            best_val_loss = avg_val_loss
            timestamp_short = time.strftime("%Y%m%d-%H%M")
            new_checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{timestamp_short}.pt")
            checkpoint_data = {
                "model_class": model.__class__.__name__,
                "model_config": config.to_dict(),
                "model_state": model.state_dict(),
                "model_repr": repr(model),
                "optimizer_state": optimizer.state_dict(),
                "step": step,
                "val_loss": best_val_loss,
                "timestamp": timestamp_short,
            }
            torch.save(checkpoint_data, new_checkpoint_path)
            print(f"Saved checkpoint at step {step} with validation loss: {avg_val_loss:.6f}, timestamp: {timestamp_short}, file: {new_checkpoint_path}")

    # Start timing only after evaluation/checkpointing so tok/s reflects training work.
    t0 = time.perf_counter()

    # Training step
    model.train()
    train_loss = 0.0
    optimizer.zero_grad()

    for micro_step, (x, y) in zip(range(grad_accum_steps), train_loader, strict=False):
        x, y = x.to(device), y.to(device)
        preds = model(x)
        loss = loss_fn(preds.flatten(0, 1), y.flatten(0, 1))
        del preds
        # Scale so the accumulated gradient equals the mean over all micro-batches.
        loss = loss / grad_accum_steps
        train_loss += loss.detach().float()
        loss.backward()

    # Clip gradients
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    # Get next learning rate from scheduler
    lr = linear_decay_with_warmup(step, warmup_steps, max_steps, max_lr, min_lr)

    # Update the learning rate(s)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Take grad step
    try:
        optimizer.step()
    except KeyError as optim_key_err:
        raise RuntimeError(
            "optimizer.step() failed; are you using the same optimizer from the checkpoint?"
        ) from optim_key_err

    # Zero out grads for the next forward pass
    optimizer.zero_grad()

    # CUDA kernels run asynchronously; wait for them before reading the clock so
    # the measured interval covers the whole step, optimizer update included.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    dt = t1 - t0

    # Log step metrics
    toks_processed = (
        train_loader.bsz * train_loader.seq_len * grad_accum_steps * world_size
    )
    toks_per_sec = toks_processed / dt
    log_message = (
        f"step {step:5d} | "
        f"loss: {train_loss:.6f} | "
        f"lr {lr:.4e} | "
        f"norm: {norm:.4f} | "
        f"dt: {dt*1000:.2f}ms | "
        f"tok/s: {toks_per_sec:.2f}"
    )
    print(log_message)
    with open(log_file, "a") as f:
        f.write(
            f"{step} train {train_loss:.6f} lr {lr:.4e} norm {norm:.4f} dt {dt*1000:.2f} tok/s {toks_per_sec:.2f}\n"
        )


# Running inference

As a sanity check, the following cell runs inference on a trained checkpoint.


In [None]:
# Checkpoint file to run inference on; passed to load_checkpoint() further down in this cell.
CHECKPOINT_PATH = "/content/drive/MyDrive/stu_exps/checkpoints/checkpoint_20250324-0126.pt"  # @param {type:"string"}

def load_checkpoint(checkpoint_path: str, device: torch.device):
    """
    Load a checkpoint and rebuild the model from the stored metadata.

    Args:
        checkpoint_path: Path to a checkpoint file readable by ``torch.load``.
        device: Device the stored tensors are mapped to and the model is moved to.

    Returns:
        model, optimizer_state, step, val_loss, timestamp, model_repr

    Raises:
        ValueError: If the checkpoint lacks "model_class" or "model_config",
            or if the stored class name is not a known model class.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_class_name = checkpoint.get("model_class")
    model_config_dict = checkpoint.get("model_config")
    # Validate *before* indexing into the config; otherwise a missing config
    # surfaces as an opaque TypeError instead of this explicit error.
    if model_class_name is None or model_config_dict is None:
        raise ValueError("Missing required model metadata in checkpoint.")

    # Work on a copy so the loaded checkpoint dict is left untouched.
    model_config_dict = dict(model_config_dict)
    # The dtype is expected as a string (e.g. "bfloat16"); turn it back into a
    # torch.dtype. A "torch." prefix or an already-resolved dtype is tolerated.
    stored_dtype = model_config_dict.get("torch_dtype")
    if isinstance(stored_dtype, str):
        model_config_dict["torch_dtype"] = getattr(torch, stored_dtype.removeprefix("torch."))

    # Map stored class name to actual class.
    MODEL_MAPPING = {
        "FlashSTU": FlashSTU,
        "Transformer": Transformer,
    }
    if model_class_name not in MODEL_MAPPING:
        raise ValueError(f"Unknown model class: {model_class_name}")
    model_cls = MODEL_MAPPING[model_class_name]

    # This depends on which model you trained!
    # NOTE(review): the config class and the (config, filters) constructor call
    # below are FlashSTU-specific even though MODEL_MAPPING also lists
    # Transformer — confirm before loading a Transformer checkpoint.
    config = FlashSTUConfig(**model_config_dict)

    filters = get_spectral_filters(
        seq_len=config.seq_len,
        k=config.num_eigh,
        use_hankel_L=config.use_hankel_L,
        device=device,
    )

    # Instantiate and load the model.
    model = model_cls(config, filters)
    model = model.to(device=device, dtype=config.torch_dtype)
    model.load_state_dict(checkpoint["model_state"], strict=True)
    model.eval()

    optimizer_state = checkpoint.get("optimizer_state")
    step = checkpoint.get("step")
    val_loss = checkpoint.get("val_loss")
    timestamp = checkpoint.get("timestamp")
    model_repr = checkpoint.get("model_repr")

    return model, optimizer_state, step, val_loss, timestamp, model_repr

def generate_text(
    model,
    tokenizer,
    prompt,
    num_return_sequences=1,
    max_length=512,
    device="cuda",
    temperature=1.0,
    top_k=50,
    seed=1746,
):
    """
    Generate text from the given prompt using top-k sampling.

    Args:
        model: The model instance; called as ``model(tokens)`` and expected to
            return logits of shape [batch, seq, vocab].
        tokenizer: The tokenizer used for encoding/decoding.
        prompt (str): Input prompt text.
        num_return_sequences (int): How many sequences to return.
        max_length (int): Maximum total length in tokens (prompt included).
        device: torch device (or device string) to run on.
        temperature (float): Sampling temperature. Higher = more random.
            Values <= 0 select the most likely token at every step (greedy).
        top_k (int): Top-K sampling parameter. Clamped to the vocabulary size;
            ``None`` or a non-positive value samples from the full vocabulary.
        seed (int): Seed for the sampling random number generator.

    Returns:
        list[str]: A list of generated text sequences. Each one contains the
        prompt and stops right after its first generated end-of-text token.

    Raises:
        ValueError: If the prompt encodes to zero tokens.
    """
    model.eval()

    eos_token_id = tokenizer.encode(
        "<|endoftext|>", allowed_special={"<|endoftext|>"}
    )[0]

    # Encode prompt tokens.
    prompt_ids = tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token.")
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    tokens = tokens.repeat(num_return_sequences, 1)
    prompt_len = tokens.size(1)

    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    # bfloat16 autocast is only requested when actually running on CUDA;
    # asking for device_type="cuda" on another device merely emits a warning.
    use_autocast = torch.device(device).type == "cuda"

    # Per-sequence "has emitted EOS" flags, so one finished sequence does not
    # cut the remaining sequences in the batch short.
    finished = torch.zeros(tokens.size(0), dtype=torch.bool, device=tokens.device)

    with torch.no_grad():
        for _ in range(max_length - prompt_len):
            autocast_ctx = (
                torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
                if use_autocast
                else nullcontext()
            )
            with autocast_ctx:
                # Fwd pass. Inspect logits here.
                logits = model(tokens)  # shape: [batch, seq, vocab]

            # Last-token logits, in float32 so softmax/sampling are not done
            # in reduced precision.
            logits = logits[:, -1, :].float()

            if temperature > 0:
                # Temperature-scaled top-k sampling.
                probs = F.softmax(logits / temperature, dim=-1)
                vocab_size = probs.size(-1)
                if top_k is None or top_k <= 0:
                    k = vocab_size
                else:
                    k = min(top_k, vocab_size)
                top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
                ix = torch.multinomial(top_k_probs, 1, generator=sample_rng)
                next_token = torch.gather(top_k_indices, -1, ix)
            else:
                # Zero temperature is the greedy limit: take the arg-max token.
                next_token = logits.argmax(dim=-1, keepdim=True)

            # Sequences that already ended keep emitting EOS as padding.
            next_token = torch.where(
                finished.unsqueeze(1),
                torch.full_like(next_token, eos_token_id),
                next_token,
            )

            # Append next token.
            tokens = torch.cat((tokens, next_token), dim=1)

            # Stop once every sequence has produced EOS.
            finished |= next_token.squeeze(1) == eos_token_id
            if finished.all():
                break

    # Decode all sequences, dropping the EOS padding after the first generated EOS.
    generated_sequences = []
    for row in tokens.tolist():
        continuation = row[prompt_len:]
        if eos_token_id in continuation:
            row = row[: prompt_len + continuation.index(eos_token_id) + 1]
        generated_sequences.append(tokenizer.decode(row))

    return generated_sequences


# Load model (only the model itself is needed here; the remaining return values are ignored).
model, *_ = load_checkpoint(CHECKPOINT_PATH, device)
# NOTE(review): assumes the checkpoint was trained with the o200k_base vocabulary — confirm.
tokenizer = tiktoken.get_encoding("o200k_base")

# Collect prompt(s) from user.
PROMPT_ONE = "Hi, I'm a language model, and"  # @param {type:"string"}
PROMPT_TWO = "The biggest scientific discovery in the 21st century was"  # @param {type:"string"}
PROMPT_THREE = "The capital of France is"  # @param {type:"string"}
prompts = [PROMPT_ONE, PROMPT_TWO, PROMPT_THREE]

# -------------------------------------------------------------------
# BASE SETTINGS:
BASE_TEMPERATURE = 0.7  # @param {type:"number"}  Increase for more randomness.
BASE_TOP_K = 50         # @param {type:"integer"} Limit sampling to the top k tokens.
MAX_LENGTH = 512        # @param {type:"integer"} Maximum number of tokens to generate.
# -------------------------------------------------------------------

total_tokens = 0
start_time = time.perf_counter()

for i, prompt in enumerate(prompts, 1):
    logger.info("Generating text for prompt %d: %s", i, prompt)
    # The returned text contains the prompt, so remember its token count and
    # only credit newly generated tokens towards the throughput figure.
    prompt_token_count = len(
        tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
    )
    generated_texts = generate_text(
        model,
        tokenizer,
        prompt,
        num_return_sequences=1,
        max_length=MAX_LENGTH,
        device=device,
        temperature=BASE_TEMPERATURE,
        top_k=BASE_TOP_K,
    )
    for gen_text in generated_texts:
        print(f"\nPrompt: {prompt}")
        print(f"Generated Text: {gen_text}\n")
        sequence_token_count = len(
            tokenizer.encode(gen_text, allowed_special={"<|endoftext|>"})
        )
        # Re-encoding decoded text can merge tokens differently, so guard against
        # a (theoretically) negative difference.
        total_tokens += max(sequence_token_count - prompt_token_count, 0)
end_time = time.perf_counter()
tokens_per_second = total_tokens / (end_time - start_time)
logger.info("Tokens per second: %.2f", tokens_per_second)
