In [None]:
"""
Tiny Recursive Model (TRM) for Text Generation

Based on: "Less is More: Recursive Reasoning with Tiny Networks" by Jolicoeur-Martineau (2025)

This implementation adapts TRM for autoregressive text generation:
- Uses recursive reasoning with a tiny 2-layer transformer
- Deep supervision with latent state carried across improvement steps
- Single network architecture (no hierarchical split)
- EMA for training stability

Key insight from the paper: smaller networks with deep recursion can outperform
larger networks by avoiding overfitting while achieving high effective depth.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer
import math
import os
from tqdm import tqdm
import copy


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer
import math
import os
from tqdm import tqdm
import copy


# ============================================================================
# Model Architecture
# ============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation.

    Each vector along the last dimension is divided by its RMS (with ``eps``
    added inside the square root for stability) and then scaled by a learned
    per-channel gain initialised to ones.
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denom = torch.sqrt(mean_square + self.eps)
        return x / denom * self.weight


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) tables.

    Precomputes cos/sin of ``position * inv_freq`` for positions
    ``0 .. max_seq_len-1``; ``forward`` returns the first ``x.shape[1]`` rows
    of each table, each of shape ``[seq_len, dim]``.
    """
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))
        self.max_seq_len = max_seq_len
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len):
        positions = torch.arange(seq_len, device=self.inv_freq.device)
        # Outer product: angles[p, i] = p * inv_freq[i]
        angles = torch.outer(positions.to(self.inv_freq.dtype), self.inv_freq)
        # Duplicate the frequencies so the table spans the full head dimension
        full_angles = angles.repeat(1, 2)
        self.register_buffer('cos_cached', full_angles.cos())
        self.register_buffer('sin_cached', full_angles.sin())

    def forward(self, x):
        n_positions = x.size(1)
        return self.cos_cached[:n_positions], self.sin_cached[:n_positions]


def rotate_half(x):
    """Split the last dim into halves (a, b) and return cat(-b, a), the RoPE 90-degree rotation."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys by their rotary position angles.

    Args:
        q, k: [batch_size, n_heads, seq_len, head_dim] tensors.
        cos, sin: [seq_len, head_dim] tables; two leading singleton dims are
            added so they broadcast over the batch and head dimensions.

    Returns:
        (q_rotated, k_rotated) with the same shapes as ``q`` and ``k``.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    rotated = [(t * cos_b) + (rotate_half(t) * sin_b) for t in (q, k)]
    return rotated[0], rotated[1]


class SwiGLU(nn.Module):
    """Gated MLP: ``w2(silu(w1(x)) * w3(x))``, mapping dim -> hidden_dim -> dim without biases."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Creation order (w1, w2, w3) fixes RNG consumption and state_dict key order.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        max_seq_len: longest sequence covered by the precomputed causal mask
            and RoPE tables.
    """
    def __init__(self, dim, n_heads, max_seq_len=512):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.rope = RotaryEmbedding(self.head_dim, max_seq_len)

        # True strictly above the diagonal: the future positions a token must not see
        future = torch.ones(max_seq_len, max_seq_len).triu(diagonal=1).bool()
        self.register_buffer('mask', future)

    def forward(self, x):
        batch, seq, width = x.shape

        def to_heads(t):
            # [B, T, C] -> [B, n_heads, T, head_dim]
            return t.view(batch, seq, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.qkv(x).split(width, dim=-1))

        cos, sin = self.rope(x)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        scores = scores.masked_fill(self.mask[:seq, :seq], float('-inf'))
        weights = scores.softmax(dim=-1)

        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(batch, seq, width)
        return self.proj(merged)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    Two residual sub-layers: RMSNorm -> causal self-attention, then
    RMSNorm -> SwiGLU MLP with hidden width ``dim * mlp_ratio``.
    """
    def __init__(self, dim, n_heads, mlp_ratio=4, max_seq_len=512):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = CausalSelfAttention(dim, n_heads, max_seq_len)
        self.norm2 = RMSNorm(dim)
        self.mlp = SwiGLU(dim, dim * mlp_ratio)

    def forward(self, x):
        h = x + self.attn(self.norm1(x))
        return h + self.mlp(self.norm2(h))


In [None]:

class TinyRecursiveNetwork(nn.Module):
    """
    The single tiny network that TRM applies recursively.

    A stack of ``n_layers`` pre-norm transformer blocks (2 by default, following
    the paper's finding that smaller is better) followed by a final RMSNorm.
    """
    def __init__(self, dim, n_heads=8, n_layers=2, mlp_ratio=4, max_seq_len=512):
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerBlock(dim, n_heads, mlp_ratio, max_seq_len) for _ in range(n_layers)
        )
        self.norm = RMSNorm(dim)

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)


class TinyRecursiveModel(nn.Module):
    """
    Tiny Recursive Model for Text Generation

    Architecture based on the TRM paper:
    - A single tiny network (``n_layers`` transformer blocks, 2 by default)
    - Recursive reasoning with a latent state z and a prediction state y
    - Deep supervision across multiple improvement steps

    For text generation:
    - x: embedded input sequence (token + learned position embeddings)
    - y: current prediction, kept as width-``dim`` hidden states that the
      output head turns into next-token logits
    - z: latent reasoning state

    One latent recursion updates z ``n_latent_recursions`` times from (x, y, z)
    and then updates y once from (y, z). A deep recursion repeats that
    ``n_improvement_cycles`` times; during training only the last cycle is
    back-propagated through.
    """
    def __init__(
        self,
        vocab_size,
        dim=256,
        n_heads=8,
        n_layers=2,
        mlp_ratio=4,
        max_seq_len=256,
        n_latent_recursions=6,  # n in the paper
        n_improvement_cycles=3,  # T in the paper
    ):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.n_latent_recursions = n_latent_recursions
        self.n_improvement_cycles = n_improvement_cycles

        # Embeddings
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # Single tiny network (key insight: one network is better than two)
        self.net = TinyRecursiveNetwork(dim, n_heads, n_layers, mlp_ratio, max_seq_len)

        # Projection layers for combining x, y, z
        self.combine_xyz = nn.Linear(dim * 3, dim, bias=False)
        self.combine_yz = nn.Linear(dim * 2, dim, bias=False)

        # Output head
        self.output_head = nn.Linear(dim, vocab_size, bias=False)

        # Halting head for ACT (simplified - no Q-learning). It is trained via the
        # halting loss in forward() but not consulted during inference/generation.
        self.halt_head = nn.Linear(dim, 1, bias=False)

        # Learnable initial states for y and z
        self.y_init = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.z_init = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear and Embedding weight from N(0, 0.02)."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_embeddings(self, input_ids):
        """Token + position embeddings for the first ``max_seq_len`` tokens, shape [B, T', dim]."""
        B, T = input_ids.shape
        # Clamp input_ids to valid range
        input_ids = input_ids.clamp(0, self.vocab_size - 1)
        # Clamp position to max_seq_len
        T = min(T, self.max_seq_len)
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        return self.token_emb(input_ids[:, :T]) + self.pos_emb(pos)

    def latent_recursion(self, x, y, z):
        """
        Single recursion cycle:
        1. Update z n times given (x, y, z)
        2. Update y once given (y, z)

        Returns the new (y, z).
        """
        # Latent reasoning: update z n times
        for _ in range(self.n_latent_recursions):
            combined = self.combine_xyz(torch.cat([x, y, z], dim=-1))
            z = self.net(combined)

        # Refine prediction: update y given (y, z)
        combined_yz = self.combine_yz(torch.cat([y, z], dim=-1))
        y = self.net(combined_yz)

        return y, z

    def deep_recursion(self, x, y, z, use_grad=True):
        """
        Deep recursion with T improvement cycles.

        use_grad=False: all T cycles run without gradients; returns the
            detached (y, z).
        use_grad=True: the first T-1 cycles run without gradients and the last
            with gradients; returns (y.detach(), z.detach(), logits, halt_logit)
            where logits = output_head(y) and halt_logit = halt_head(mean of y
            over the sequence), shape [B, 1].
        """
        if not use_grad:
            # All cycles without gradients (inference)
            with torch.no_grad():
                for _ in range(self.n_improvement_cycles):
                    y, z = self.latent_recursion(x, y, z)
            return y.detach(), z.detach()

        # T-1 cycles without gradients
        with torch.no_grad():
            for _ in range(self.n_improvement_cycles - 1):
                y, z = self.latent_recursion(x, y, z)

        # Last cycle with gradients
        y, z = self.latent_recursion(x, y, z)

        return y.detach(), z.detach(), self.output_head(y), self.halt_head(y.mean(dim=1))

    @staticmethod
    def _halt_target(logits, targets):
        """Per-sample fraction of non-ignored positions predicted correctly, shape [B]."""
        with torch.no_grad():
            mask = targets != -100
            hits = (logits.argmax(dim=-1) == targets) & mask
            return hits.float().sum(dim=-1) / mask.float().sum(dim=-1).clamp(min=1)

    def forward(self, input_ids, targets=None, n_supervision_steps=4):
        """
        Forward pass with deep supervision.

        Args:
            input_ids: [B, T] input token IDs (only the first max_seq_len are used)
            targets: [B, T] target token IDs, -100 marks ignored positions
            n_supervision_steps: number of deep supervision steps (training only,
                must be >= 1)

        Returns:
            If targets is given: scalar loss, the mean over supervision steps of
                cross-entropy + 0.1 * halting BCE.
            Otherwise: logits of shape [B, min(T, max_seq_len), vocab_size].

        NOTE(review): inference runs a single deep recursion from the learned
        initial states, i.e. the equivalent of the FIRST supervision step only,
        while training carries (y, z) across n_supervision_steps. Consider
        whether inference should also run several supervision steps.
        """
        B, T = input_ids.shape
        T = min(T, self.max_seq_len)
        input_ids = input_ids[:, :T]

        x = self.get_embeddings(input_ids)

        # Initialize y and z
        y = self.y_init.expand(B, T, -1).clone()
        z = self.z_init.expand(B, T, -1).clone()

        if targets is None:
            # Inference: just run deep recursion
            y, z = self.deep_recursion(x, y, z, use_grad=False)
            return self.output_head(y)

        if n_supervision_steps < 1:
            raise ValueError(f'n_supervision_steps must be >= 1, got {n_supervision_steps}')

        # Ensure targets match input length
        targets = targets[:, :T]

        # Training with deep supervision: y and z are carried (detached) between steps
        total_loss = 0.0

        for _ in range(n_supervision_steps):
            y, z, logits, halt_logit = self.deep_recursion(x, y, z, use_grad=True)

            # Cross-entropy loss for token prediction
            ce_loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size),
                targets.reshape(-1),
                ignore_index=-100
            )

            # Halting loss (simplified ACT): each sample's halt logit is pushed
            # towards that sample's own token accuracy, not the batch average.
            halt_loss = F.binary_cross_entropy_with_logits(
                halt_logit.squeeze(-1),
                self._halt_target(logits, targets)
            )

            total_loss = total_loss + ce_loss + 0.1 * halt_loss

        return total_loss / n_supervision_steps

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=40):
        """Generate text autoregressively.

        Args:
            input_ids: [B, T] prompt token IDs.
            max_new_tokens: number of tokens to append.
            temperature: softmax temperature; None or <= 0 means greedy (argmax).
            top_k: sample only from the k most likely tokens; None or 0 disables.

        Returns:
            [B, T + max_new_tokens] token IDs (prompt followed by new tokens).

        The module's train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        # Crop to max_seq_len - 1 to leave room for prediction (at least one token)
        context = max(self.max_seq_len - 1, 1)
        try:
            for _ in range(max_new_tokens):
                # Clamp input ids to valid vocab range
                idx_cond = input_ids[:, -context:].clamp(0, self.vocab_size - 1)

                # Logits for the last position only
                logits = self(idx_cond)[:, -1, :]

                if temperature is None or temperature <= 0:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature

                    # Top-k sampling
                    if top_k:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float('-inf')

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            self.train(was_training)

        return input_ids



In [None]:

# ============================================================================
# Dataset
# ============================================================================

class TinyStoriesDataset(Dataset):
    """Next-token-prediction samples from the TinyStories dataset.

    Each story is tokenized, truncated/padded to ``max_length`` tokens, and
    returned as ``(input_ids, targets)``, both of length ``max_length - 1``.
    ``targets`` is ``input_ids`` shifted left by one. Target positions that
    fall in the padding are set to -100 so the loss ignores them.
    """
    def __init__(self, tokenizer, split='train', max_length=256, max_samples=None):
        print(f"Loading TinyStories {split} split...")
        dataset = load_dataset('roneneldan/TinyStories', split=split)

        if max_samples:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = dataset['text']
        self.vocab_size = tokenizer.vocab_size
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        print(f"Loaded {len(self.texts)} samples")

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer.encode(text, truncation=True, max_length=self.max_length)

        # Ensure all tokens are within valid range and never exceed max_length
        tokens = [min(max(t, 0), self.vocab_size - 1) for t in tokens][:self.max_length]
        n_real = len(tokens)

        # Pad to a fixed length
        tokens = tokens + [self.pad_token_id] * (self.max_length - n_real)
        tokens = torch.tensor(tokens, dtype=torch.long)

        # Input is tokens[:-1], target is tokens[1:]
        input_ids = tokens[:-1].clone()
        targets = tokens[1:].clone()

        # Mask padding by POSITION, not by token id: the pad id can equal the EOS
        # id (e.g. GPT-2 with pad_token = eos_token), and masking by id would
        # also hide genuine EOS targets. targets[i] is tokens[i + 1], which is a
        # real token iff i + 1 < n_real.
        targets[max(n_real - 1, 0):] = -100

        return input_ids, targets


In [None]:


# ============================================================================
# Training
# ============================================================================

class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``update()`` folds the current weights into the shadow copy, i.e.
    ``shadow = decay * shadow + (1 - decay) * param``.
    ``apply_shadow()`` copies the averaged weights into the model (keeping a
    backup of the training weights) and ``restore()`` copies the backup back.
    The shadow tensors are created on the parameters' device at construction
    time, so move the model to its device before building the EMA.
    """
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    @torch.no_grad()
    def update(self):
        """Blend the model's current weights into the shadow copy (in place)."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # shadow + (1 - decay) * (param - shadow) == decay * shadow + (1 - decay) * param,
                # computed in place to avoid allocating a copy of every parameter each step.
                self.shadow[name].lerp_(param.detach(), 1.0 - self.decay)

    @torch.no_grad()
    def apply_shadow(self):
        """Load the averaged weights into the model, backing up the current ones."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.detach().clone()
                # copy_ rather than `param.data = shadow`: assigning .data would make the
                # model share storage with the shadow, so any in-place write to the
                # weights (e.g. an optimizer step) would silently corrupt the EMA.
                param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self):
        """Copy the backed-up training weights back into the model and drop the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.copy_(self.backup[name])
        # Release the full-model backup copy until the next apply_shadow()
        self.backup = {}


def _evaluate(model, loader, device, n_supervision_steps):
    """Return the mean deep-supervision loss of ``model`` over ``loader`` (NaN if it is empty)."""
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for input_ids, targets in tqdm(loader, desc='Validation'):
            input_ids = input_ids.to(device)
            targets = targets.to(device)
            loss = model(input_ids, targets, n_supervision_steps=n_supervision_steps)
            total += loss.item()
            n_batches += 1
    # NaN instead of ZeroDivisionError for an empty loader; NaN < best is False,
    # so an empty validation set never triggers a checkpoint save.
    return total / n_batches if n_batches else float('nan')


def train(
    model,
    train_loader,
    val_loader,
    tokenizer,
    device,
    epochs=5,
    lr=1e-4,
    warmup_steps=1000,
    n_supervision_steps=4,
    ema_decay=0.999,
    save_path='trm_tinystories.pt'
):
    """Training loop with deep supervision and EMA.

    Each epoch trains with AdamW (linear warmup, then constant lr, gradient
    clipping at norm 1.0), then swaps in the EMA weights to compute the
    validation loss, print a sample story, and save a checkpoint whenever
    the validation loss improves. The saved ``model_state_dict`` therefore
    holds the EMA weights; ``ema_shadow`` is saved alongside it.

    Returns:
        The model holding its raw (non-EMA) training weights from the last step.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)
    ema = EMA(model, decay=ema_decay)

    # Learning rate scheduler with warmup. LambdaLR evaluates this at step 0
    # before the first optimizer.step(); using step + 1 keeps that first update
    # from running with lr = 0.
    def lr_schedule(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        return 1.0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

        for input_ids, targets in pbar:
            input_ids = input_ids.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = model(input_ids, targets, n_supervision_steps=n_supervision_steps)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            ema.update()

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.6f}'})

        # Validate / sample / checkpoint with the EMA weights. The finally block
        # guarantees the raw training weights are restored even if any step raises,
        # so training never silently continues from the EMA weights.
        ema.apply_shadow()
        try:
            val_loss = _evaluate(model, val_loader, device, n_supervision_steps)
            print(f'Epoch {epoch+1} - Val Loss: {val_loss:.4f}')

            # Generate sample
            prompt = "Once upon a time"
            prompt_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
            generated = model.generate(prompt_ids, max_new_tokens=100)
            generated_text = tokenizer.decode(generated[0].tolist())
            print(f'Generated: {generated_text[:300]}...\n')

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'ema_shadow': ema.shadow,
                    'epoch': epoch,
                    'val_loss': val_loss
                }, save_path)
                print(f'Saved best model with val_loss={val_loss:.4f}')
        finally:
            ema.restore()

    return model



In [None]:

# ============================================================================
# Main
# ============================================================================

# Configuration: every tunable knob for the model and the training run.
config = dict(
    # Model
    vocab_size=50257,             # GPT-2 vocabulary
    dim=256,                      # hidden dimension
    n_heads=8,                    # attention heads
    n_layers=2,                   # only 2 layers (key insight from the paper)
    mlp_ratio=4,
    max_seq_len=128,              # reduced for stability
    n_latent_recursions=4,        # n in the paper (reduced for memory)
    n_improvement_cycles=2,       # T in the paper (reduced for memory)

    # Training
    batch_size=256,
    epochs=3,
    lr=1e-4,
    warmup_steps=500,
    n_supervision_steps=3,        # deep-supervision steps per training batch
    max_train_samples=2_000_000,  # subset of TinyStories for a faster demo
    max_val_samples=20_000,
)


In [None]:

# Reproducibility: torch.manual_seed seeds torch's global RNG on the CPU and all
# CUDA devices, so weight initialisation and later random draws repeat
# (up to nondeterministic CUDA kernels).
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Tokenizer (GPT-2 has no pad token, so EOS doubles as padding)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# The embedding table must cover every id the tokenizer can produce
assert len(tokenizer) <= config['vocab_size'], (
    f"tokenizer has {len(tokenizer)} tokens but config['vocab_size'] is {config['vocab_size']}"
)

# Model
model = TinyRecursiveModel(
    vocab_size=config['vocab_size'],
    dim=config['dim'],
    n_heads=config['n_heads'],
    n_layers=config['n_layers'],
    mlp_ratio=config['mlp_ratio'],
    max_seq_len=config['max_seq_len'],
    n_latent_recursions=config['n_latent_recursions'],
    n_improvement_cycles=config['n_improvement_cycles'],
)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f'Model parameters: {n_params:,} ({n_params/1e6:.2f}M)')
print(f'Effective depth per supervision step: {config["n_improvement_cycles"] * (config["n_latent_recursions"] + 1) * config["n_layers"]}')


Using device: cuda
Model parameters: 28,191,232 (28.19M)
Effective depth per supervision step: 20


In [None]:

# Datasets: max_length is max_seq_len + 1 because each sample is split into
# inputs tokens[:-1] and next-token targets tokens[1:].
train_dataset = TinyStoriesDataset(
    tokenizer,
    split='train',
    max_length=config['max_seq_len'] + 1,
    max_samples=config['max_train_samples']
)
val_dataset = TinyStoriesDataset(
    tokenizer,
    split='validation',
    max_length=config['max_seq_len'] + 1,
    max_samples=config['max_val_samples']
)

# Shared loader settings. Pinned host memory only helps host-to-GPU copies, and
# requesting it on a CPU-only machine just produces a warning.
loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)


Loading TinyStories train split...
Loaded 2000000 samples
Loading TinyStories validation split...
Loaded 20000 samples


In [None]:

# Train: model, data and tokenizer positionally; schedule knobs from config.
model = train(
    model,
    train_loader,
    val_loader,
    tokenizer,
    device,
    **{key: config[key] for key in ('epochs', 'lr', 'warmup_steps', 'n_supervision_steps')},
)

print('\nTraining complete!')


Epoch 1/3: 100%|██████████| 7813/7813 [1:18:09<00:00,  1.67it/s, loss=2.0730, lr=0.000100]
Validation: 100%|██████████| 79/79 [00:22<00:00,  3.48it/s]


Epoch 1 - Val Loss: 2.0817
Generated: Once upon a time, there was a boy named Timmy. Timmy loved to play with his toys all day long. One day, Timmy's mom gave him a big, colorful toy car. Timmy was so happy to have it, but his mom warned him not to eat it because it might not fit in the toy car.

Timmy didn't like any idea, so he asked ...

Saved best model with val_loss=2.0817


Epoch 2/3: 100%|██████████| 7813/7813 [1:18:09<00:00,  1.67it/s, loss=1.8105, lr=0.000100]
Validation: 100%|██████████| 79/79 [00:22<00:00,  3.48it/s]


Epoch 2 - Val Loss: 1.7954
Generated: Once upon a time, there was a little girl named Lily. One day, she went to visit her grandma's house. Lily loved the color pink and her grandma wore them. 

As they walked around grandma's house, Lily noticed that her grandma's grandma was wearing a sparkly dress. Mom asked Lily where her grandma wa...

Saved best model with val_loss=1.7954


Epoch 3/3: 100%|██████████| 7813/7813 [1:18:09<00:00,  1.67it/s, loss=1.7581, lr=0.000100]
Validation: 100%|██████████| 79/79 [00:22<00:00,  3.49it/s]


Epoch 3 - Val Loss: 1.6934
Generated: Once upon a time, there was a little girl named Lily. She loved to play outside and pick flowers in the garden. One day, she found a tiny seed and watered it with a big smile. She watered it every day and waited for it to grow.

One day, Lily went outside to play. She saw a butterfly and said hello ...

Saved best model with val_loss=1.6934

Training complete!


In [None]:

# Final generation examples.
# train() hands back `model` with its raw (non-EMA) weights from the last step,
# while the best-validation checkpoint stores the EMA weights, so load those.
CHECKPOINT_PATH = 'trm_tinystories.pt'  # train()'s default save_path

if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best checkpoint (epoch {checkpoint['epoch'] + 1}, "
          f"val_loss={checkpoint['val_loss']:.4f})")
else:
    print(f'No checkpoint at {CHECKPOINT_PATH}; using the in-memory weights')
model.eval()

prompts = [
    "Once upon a time",
    "The little girl",
    "One day, a rabbit",
    "Tom and his friend"
]

print('\n=== Generated Stories ===\n')
for prompt in prompts:
    prompt_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
    generated = model.generate(prompt_ids, max_new_tokens=150, temperature=0.8)
    text = tokenizer.decode(generated[0].tolist())
    print(f'Prompt: "{prompt}"')
    print(f'Story: {text}\n')
    print('-' * 50 + '\n')


=== Generated Stories ===

Prompt: "Once upon a time"
Story: Once upon a time, there was a big lion. He was very lazy, but he always went for a walk. One day, he saw a little girl walking by. She was running around in a bush and looking at the lion. The lion didn't like that, so he kept trying to catch her. 

As he walked, he heard a voice. It was a friendly voice calling for help! He looked around and noticed a little girl walking on the path. The girl was smiling and she waved her hand. The lion was happy to help her and hopped over to her. 

The lion led her through the forest together, until it was out of sight. The little girl was so excited to see the lion and wanted to get closer

--------------------------------------------------

Prompt: "The little girl"
Story: The little girl was very excited. She had a tray for her birthday party. It was big and blue and had lots of things inside. She couldn't wait to play with it! 

The little girl spent the day with the tray. But then sh