# Experiment 17: Internal RL with Temporal Abstractions

**Self-contained Colab notebook implementing hierarchical RL for code generation**

Based on: ["Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning"](https://arxiv.org/abs/2512.20605) (Kobayashi et al., Google, Dec 2025)

---

## Key Insight

Instead of doing RL on individual tokens (massive search space), we:
1. **Discover** temporally-abstract actions in the model's residual stream
2. **Explore** in a compact 16D latent space instead of 50K+ vocabulary
3. **Execute** each abstract action, which generates multiple tokens until the next switch

```
Standard RL:  Token₁ → Token₂ → ... → Token₁₀₀ → Reward  (massive variance)
Internal RL:  z₁ → z₂ → z₃ → z₄ → z₅ → Reward           (tractable)
```
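
To make the comparison concrete, here is a back-of-the-envelope sketch using the numbers from the text and diagram above (a 50K vocabulary, 100 tokens, 5 abstract actions, 16 latent dimensions). The helper names are illustrative and not part of the notebook. The count only describes the size of the token-level search space; it says nothing about how hard the continuous latent problem is.

```python
import math

VOCAB_SIZE = 50_000   # "50K+" vocabulary from the text above
NUM_TOKENS = 100      # tokens per episode in the diagram
NUM_ABSTRACT = 5      # abstract actions per episode in the diagram
LATENT_DIM = 16       # latent dimension from the text above


def log10_token_sequences(vocab_size: int, num_tokens: int) -> float:
    """log10 of the number of distinct token sequences of a given length."""
    return num_tokens * math.log10(vocab_size)


def tokens_per_abstract_action(num_tokens: int, num_abstract: int) -> float:
    """Average number of tokens each abstract action is responsible for."""
    return num_tokens / num_abstract


print(f"Distinct token sequences: ~10^{log10_token_sequences(VOCAB_SIZE, NUM_TOKENS):.0f}")
print(f"Decisions per episode: {NUM_TOKENS} token-level vs {NUM_ABSTRACT} abstract")
print(f"Tokens per abstract action: {tokens_per_abstract_action(NUM_TOKENS, NUM_ABSTRACT):.0f}")
print(f"Continuous numbers chosen per episode: {NUM_ABSTRACT * LATENT_DIM}")
```

With these numbers, credit for a single end-of-episode reward is spread over 5 decisions rather than 100.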

## Part 1: Setup and Installation

In [None]:
# Install dependencies
!pip install -q transformers accelerate peft bitsandbytes torch
!pip install -q datasets tqdm

In [None]:
import os
import sys
import json
import time
import math
import random
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Tuple
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## Part 2: Configuration

In [None]:
@dataclass
class Config:
    """Configuration for Experiment 17."""
    # Model
    base_model: str = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
    controller_layer: int = 12  # Mid-depth for Qwen 0.5B (24 layers)
    
    # Metacontroller architecture
    latent_dim: int = 16        # Dimension of abstract action space
    gru_dim: int = 64           # GRU hidden dimension
    seq_embed_dim: int = 64     # Sequence embedding dimension
    encoder_hidden: int = 64    # Encoder MLP hidden dim
    decoder_hidden: int = 64    # Decoder MLP hidden dim
    switch_hidden: int = 64     # Switching unit hidden dim
    controller_rank: int = 16   # Low-rank controller (like LoRA)
    
    # Phase 2: Metacontroller training
    mc_batch_size: int = 4
    mc_learning_rate: float = 1e-3
    mc_weight_decay: float = 0.03
    mc_epochs: int = 30
    kl_weight: float = 0.1      # ELBO regularization weight
    
    # Phase 3: Internal RL
    rl_batch_size: int = 8
    rl_learning_rate: float = 3e-5
    rl_steps: int = 1000
    beta_threshold: float = 0.5  # Switching threshold
    clip_epsilon: float = 0.2    # PPO clip
    entropy_coef: float = 0.01
    
    # Generation
    max_tokens: int = 256
    temperature: float = 0.7
    max_seq_len: int = 512
    
    # Logging
    log_interval: int = 10
    eval_interval: int = 50
    
config = Config()
print("Configuration:")
for k, v in config.__dict__.items():
    print(f"  {k}: {v}")

## Part 3: Problem Definitions

We'll use the same 6 problem types from Experiments 15-16.

In [None]:
# Problem definitions with test cases
PROBLEMS = {
    "fibonacci": {
        "description": "Write a function that returns the nth Fibonacci number. fib(0)=0, fib(1)=1, fib(n)=fib(n-1)+fib(n-2)",
        "entry_point": "fib",
        "test_cases": [
            {"input": [0], "expected": 0},
            {"input": [1], "expected": 1},
            {"input": [5], "expected": 5},
            {"input": [10], "expected": 55},
            {"input": [15], "expected": 610},
        ],
        "canonical_solution": """def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)"""
    },
    "binary_search": {
        "description": "Write a function that performs binary search on a sorted list. Return the index if found, -1 otherwise.",
        "entry_point": "binary_search",
        "test_cases": [
            {"input": [[1,2,3,4,5], 3], "expected": 2},
            {"input": [[1,2,3,4,5], 1], "expected": 0},
            {"input": [[1,2,3,4,5], 5], "expected": 4},
            {"input": [[1,2,3,4,5], 6], "expected": -1},
            {"input": [[], 1], "expected": -1},
        ],
        "canonical_solution": """def binary_search(arr, target):
    left, right = 0, len(arr) - 1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1"""
    },
    "coin_change": {
        "description": "Given coins of different denominations and a total amount, return the fewest coins needed to make up that amount. Return -1 if not possible.",
        "entry_point": "coin_change",
        "test_cases": [
            {"input": [[1,2,5], 11], "expected": 3},
            {"input": [[2], 3], "expected": -1},
            {"input": [[1], 0], "expected": 0},
            {"input": [[1,2,5], 5], "expected": 1},
            {"input": [[2,5,10], 6], "expected": 3},
        ],
        "canonical_solution": """def coin_change(coins, amount):
    dp = [float('inf')] * (amount + 1)
    dp[0] = 0
    for i in range(1, amount + 1):
        for coin in coins:
            if coin <= i:
                dp[i] = min(dp[i], dp[i - coin] + 1)
    return dp[amount] if dp[amount] != float('inf') else -1"""
    },
    "valid_parentheses": {
        "description": "Given a string containing just '(', ')', '{', '}', '[' and ']', determine if the input string is valid.",
        "entry_point": "is_valid",
        "test_cases": [
            {"input": ["()"], "expected": True},
            {"input": ["()[]{}"], "expected": True},
            {"input": ["(]"], "expected": False},
            {"input": ["([)]"], "expected": False},
            {"input": ["{[]}"], "expected": True},
        ],
        "canonical_solution": """def is_valid(s):
    stack = []
    mapping = {')': '(', '}': '{', ']': '['}
    for char in s:
        if char in mapping:
            if not stack or stack.pop() != mapping[char]:
                return False
        else:
            stack.append(char)
    return len(stack) == 0"""
    },
    "rpn_calculator": {
        "description": "Evaluate the value of an arithmetic expression in Reverse Polish Notation. Valid operators are +, -, *, /.",
        "entry_point": "eval_rpn",
        "test_cases": [
            {"input": [["2","1","+","3","*"]], "expected": 9},
            {"input": [["4","13","5","/","+"]], "expected": 6},
            {"input": [["10","6","9","3","+","-11","*","/","*","17","+","5","+"]], "expected": 22},
            {"input": [["3","4","+"]], "expected": 7},
            {"input": [["5"]], "expected": 5},
        ],
        "canonical_solution": """def eval_rpn(tokens):
    stack = []
    for token in tokens:
        if token in ('+', '-', '*', '/'):
            b, a = stack.pop(), stack.pop()
            if token == '+': stack.append(a + b)
            elif token == '-': stack.append(a - b)
            elif token == '*': stack.append(a * b)
            else: stack.append(int(a / b))
        else:
            stack.append(int(token))
    return stack[0]"""
    },
    "edit_distance": {
        "description": "Given two strings word1 and word2, return the minimum number of operations required to convert word1 to word2. You can insert, delete, or replace a character.",
        "entry_point": "min_distance",
        "test_cases": [
            {"input": ["horse", "ros"], "expected": 3},
            {"input": ["intention", "execution"], "expected": 5},
            {"input": ["", "a"], "expected": 1},
            {"input": ["a", "a"], "expected": 0},
            {"input": ["abc", "def"], "expected": 3},
        ],
        "canonical_solution": """def min_distance(word1, word2):
    m, n = len(word1), len(word2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if word1[i-1] == word2[j-1]:
                dp[i][j] = dp[i-1][j-1]
            else:
                dp[i][j] = 1 + min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1])
    return dp[m][n]"""
    }
}

print(f"Loaded {len(PROBLEMS)} problem types:")
for name in PROBLEMS:
    print(f"  - {name}: {len(PROBLEMS[name]['test_cases'])} test cases")

## Part 4: Metacontroller Architecture

The metacontroller consists of:
1. **GRU** - Maintains history state
2. **Sequence Embedder** - Creates acausal embedding of full sequence
3. **Controller Encoder** - Produces latent code proposals (μ, log σ²)
4. **Switching Unit** - Decides when to switch abstract actions
5. **Controller Decoder** - Hypernetwork producing low-rank controllers
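
Read together with the `forward_training` loop defined later in this part, the five components amount to the following per-token update. The notation is introduced here for illustration: $e_t$ is the residual at position $t$, $s$ the acausal sequence embedding, $r$ the controller rank, and $\mathrm{Enc}$, $\mathrm{Switch}$, $A(\cdot)$, $B(\cdot)$ stand for the encoder, switching unit, and the two decoder heads.

$$
\begin{aligned}
h_t &= \mathrm{GRU}(e_t, h_{t-1}) \\
(\mu_t, \log\sigma_t^2) &= \mathrm{Enc}(e_t, h_t, s) \\
\tilde{z}_t &= \mu_t + \sigma_t \odot \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, I) \\
\beta_t &= \operatorname{sigmoid}\big(\mathrm{Switch}(e_t, h_t, z_{t-1})\big) \\
z_t &= \beta_t \, \tilde{z}_t + (1 - \beta_t) \, z_{t-1} \\
\hat{e}_t &= e_t + \tfrac{1}{r} \, A(z_t) \, B(z_t) \, e_t
\end{aligned}
$$

When $\beta_t$ is close to 0 the previous abstract action is carried forward unchanged; when it is close to 1 the new proposal replaces it.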

In [None]:
class GRUCell(nn.Module):
    """Gated Recurrent Unit for maintaining history state."""
    
    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.W_r = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.W_z = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.W_h = nn.Linear(input_dim + hidden_dim, hidden_dim)
    
    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        combined = torch.cat([x, h], dim=-1)
        r = torch.sigmoid(self.W_r(combined))
        z = torch.sigmoid(self.W_z(combined))
        combined_reset = torch.cat([x, r * h], dim=-1)
        h_tilde = torch.tanh(self.W_h(combined_reset))
        return (1 - z) * h + z * h_tilde
    
    def init_hidden(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.hidden_dim, device=device)


class SequenceEmbedder(nn.Module):
    """Creates acausal embedding of the full sequence."""
    
    def __init__(self, embed_dim: int, output_dim: int):
        super().__init__()
        self.proj = nn.Linear(embed_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
    
    def forward(self, e_seq: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.proj(e_seq)  # [batch, seq, output_dim]
        if mask is not None:
            mask = mask.unsqueeze(-1).float()
            x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            x = x.mean(dim=1)
        return self.norm(x)


class ControllerEncoder(nn.Module):
    """Produces latent code proposal distribution parameters."""
    
    def __init__(self, embed_dim: int, hidden_dim: int, gru_dim: int, 
                 seq_embed_dim: int, latent_dim: int):
        super().__init__()
        input_dim = embed_dim + gru_dim + seq_embed_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.mu_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)
    
    def forward(self, e_t: torch.Tensor, h_t: torch.Tensor, 
                s_embed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.cat([e_t, h_t, s_embed], dim=-1)
        hidden = self.encoder(x)
        return self.mu_head(hidden), self.logvar_head(hidden)
    
    def sample(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps


class SwitchingUnit(nn.Module):
    """Decides when to switch to a new abstract action."""
    
    def __init__(self, embed_dim: int, gru_dim: int, latent_dim: int, hidden_dim: int):
        super().__init__()
        input_dim = embed_dim + gru_dim + latent_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, e_t: torch.Tensor, h_t: torch.Tensor, 
                z_prev: torch.Tensor) -> torch.Tensor:
        x = torch.cat([e_t, h_t, z_prev], dim=-1)
        return torch.sigmoid(self.net(x))


class ControllerDecoder(nn.Module):
    """Hypernetwork that produces low-rank controller matrices."""
    
    def __init__(self, latent_dim: int, embed_dim: int, rank: int, hidden_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.rank = rank
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.A_head = nn.Linear(hidden_dim, embed_dim * rank)
        self.B_head = nn.Linear(hidden_dim, rank * embed_dim)
        self.scale = 1.0 / rank
    
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden = self.net(z)
        A = self.A_head(hidden).view(-1, self.embed_dim, self.rank)
        B = self.B_head(hidden).view(-1, self.rank, self.embed_dim)
        return A * self.scale, B
    
    def apply_controller(self, e: torch.Tensor, A: torch.Tensor, 
                         B: torch.Tensor) -> torch.Tensor:
        e_col = e.unsqueeze(-1)
        delta = torch.bmm(A, torch.bmm(B, e_col)).squeeze(-1)
        return e + delta


print("Metacontroller components defined!")
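
The `ControllerDecoder` above applies its update as `e + A @ (B @ e)`, so the full `embed_dim × embed_dim` matrix is never formed. Below is a small pure-Python sketch of that idea on nested lists. The function names are illustrative, and the hidden size of 896 used in the size comparison is an assumption about Qwen2.5-0.5B rather than a value stated in this notebook (the notebook reads it from `base_model.config.hidden_size` at runtime).

```python
def matvec(matrix, vector):
    """Multiply a row-major nested-list matrix by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def apply_low_rank(e, A, B):
    """Return e + A @ (B @ e), with A of shape [d][r] and B of shape [r][d]."""
    low = matvec(B, e)       # project down to r numbers
    delta = matvec(A, low)   # project back up to d numbers
    return [e_i + d_i for e_i, d_i in zip(e, delta)]


def low_rank_entries(embed_dim: int, rank: int) -> int:
    """Number of entries in A and B combined."""
    return 2 * embed_dim * rank


def full_entries(embed_dim: int) -> int:
    """Number of entries in a dense embed_dim x embed_dim matrix."""
    return embed_dim * embed_dim


# Tiny example with d=3, r=1
e = [1.0, 2.0, 3.0]
B = [[1.0, 0.0, 0.0]]        # picks out the first coordinate
A = [[0.5], [0.0], [0.0]]    # writes half of it back to the first coordinate
print(apply_low_rank(e, A, B))

# Size comparison, assuming embed_dim=896 and rank=16
print(low_rank_entries(896, 16), "vs", full_entries(896))
```

The decoder's output heads therefore scale with `2 * embed_dim * rank` rather than `embed_dim ** 2`.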

In [None]:
class Metacontroller(nn.Module):
    """Full Metacontroller combining all components."""
    
    def __init__(self, embed_dim: int, latent_dim: int = 16, gru_dim: int = 64,
                 seq_embed_dim: int = 64, encoder_hidden: int = 64,
                 decoder_hidden: int = 64, switch_hidden: int = 64,
                 controller_rank: int = 16):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.gru_dim = gru_dim
        
        self.gru = GRUCell(embed_dim, gru_dim)
        self.sequence_embedder = SequenceEmbedder(embed_dim, seq_embed_dim)
        self.encoder = ControllerEncoder(embed_dim, encoder_hidden, gru_dim, 
                                         seq_embed_dim, latent_dim)
        self.switching_unit = SwitchingUnit(embed_dim, gru_dim, latent_dim, switch_hidden)
        self.decoder = ControllerDecoder(latent_dim, embed_dim, controller_rank, decoder_hidden)
        
        self.z_init = nn.Parameter(torch.zeros(latent_dim))
    
    def forward_training(self, residual_sequence: torch.Tensor,
                        attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Training forward pass with acausal sequence embedding."""
        batch_size, seq_len, _ = residual_sequence.shape
        device = residual_sequence.device
        
        # Acausal: see full sequence
        s_embed = self.sequence_embedder(residual_sequence, attention_mask)
        
        # Initialize
        h = self.gru.init_hidden(batch_size, device)
        z = self.z_init.unsqueeze(0).expand(batch_size, -1)
        
        z_list, mu_list, logvar_list, beta_list = [], [], [], []
        controlled_list = []
        
        for t in range(seq_len):
            e_t = residual_sequence[:, t, :]
            h = self.gru(e_t, h)
            
            mu, logvar = self.encoder(e_t, h, s_embed)
            z_proposal = self.encoder.sample(mu, logvar)
            
            beta = self.switching_unit(e_t, h, z)
            z = beta * z_proposal + (1 - beta) * z
            
            A, B = self.decoder(z)
            e_controlled = self.decoder.apply_controller(e_t, A, B)
            
            z_list.append(z)
            mu_list.append(mu)
            logvar_list.append(logvar)
            beta_list.append(beta)
            controlled_list.append(e_controlled)
        
        return {
            'z_sequence': torch.stack(z_list, dim=1),
            'mu_sequence': torch.stack(mu_list, dim=1),
            'logvar_sequence': torch.stack(logvar_list, dim=1),
            'beta_sequence': torch.stack(beta_list, dim=1),
            'controlled_sequence': torch.stack(controlled_list, dim=1)
        }
    
    def init_state(self, batch_size: int, device: torch.device):
        h = self.gru.init_hidden(batch_size, device)
        z = self.z_init.unsqueeze(0).expand(batch_size, -1)
        return h, z

print("Metacontroller class defined!")

## Part 5: Abstract Action Policy for Internal RL

In [None]:
class AbstractActionPolicy(nn.Module):
    """Policy that outputs abstract actions z given residual observations."""
    
    def __init__(self, embed_dim: int, hidden_dim: int = 256, latent_dim: int = 16):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.mu_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        self.logvar_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
    
    def forward(self, e: torch.Tensor, h: Optional[torch.Tensor] = None):
        if e.dim() == 2:
            e = e.unsqueeze(1)
        output, h_new = self.gru(e, h)
        output = output[:, -1, :]
        return self.mu_head(output), self.logvar_head(output), h_new
    
    def sample(self, mu: torch.Tensor, logvar: torch.Tensor, 
               deterministic: bool = False) -> torch.Tensor:
        if deterministic:
            return mu
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)
    
    def log_prob(self, z: torch.Tensor, mu: torch.Tensor, 
                 logvar: torch.Tensor) -> torch.Tensor:
        var = logvar.exp()
        log_prob = -0.5 * (math.log(2 * math.pi) + logvar + (z - mu).pow(2) / var)
        return log_prob.sum(dim=-1)
    
    def init_hidden(self, batch_size: int, device: torch.device):
        return torch.zeros(1, batch_size, self.hidden_dim, device=device)

print("AbstractActionPolicy defined!")
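
The `log_prob` method above is the log-density of a diagonal Gaussian, summed over latent dimensions. As a sanity check, the same formula can be written in plain Python and compared against `statistics.NormalDist` from the standard library. The function name `diag_gaussian_log_prob` is illustrative only.

```python
import math
from statistics import NormalDist


def diag_gaussian_log_prob(z, mu, logvar):
    """Sum over dimensions of log N(z_i | mu_i, exp(logvar_i))."""
    total = 0.0
    for z_i, mu_i, lv_i in zip(z, mu, logvar):
        var_i = math.exp(lv_i)
        total += -0.5 * (math.log(2 * math.pi) + lv_i + (z_i - mu_i) ** 2 / var_i)
    return total


z = [0.3, -1.2]
mu = [0.0, 0.5]
logvar = [0.0, math.log(0.25)]  # standard deviations 1.0 and 0.5

ours = diag_gaussian_log_prob(z, mu, logvar)
reference = sum(
    math.log(NormalDist(m, math.exp(0.5 * lv)).pdf(zi))
    for zi, m, lv in zip(z, mu, logvar)
)
print(ours, reference)
```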

## Part 6: Load Base Model

In [None]:
print(f"Loading model: {config.base_model}")

tokenizer = AutoTokenizer.from_pretrained(config.base_model)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Freeze base model
for param in base_model.parameters():
    param.requires_grad = False
base_model.eval()

embed_dim = base_model.config.hidden_size
num_layers = base_model.config.num_hidden_layers

print(f"\nModel loaded!")
print(f"  Hidden size: {embed_dim}")
print(f"  Num layers: {num_layers}")
print(f"  Controller layer: {config.controller_layer}")

## Part 7: Create Expert Dataset

We'll use the canonical solutions as expert trajectories.

In [None]:
def format_prompt(problem_desc: str) -> str:
    return f"""Write a Python function to solve this problem.
Do NOT use a class wrapper. Write a standalone function.

Problem: {problem_desc}

Solution:
```python
"""

# Create dataset from problems
expert_data = []
for name, prob in PROBLEMS.items():
    prompt = format_prompt(prob['description'])
    solution = prob['canonical_solution']
    full_text = prompt + solution + "\n```"
    
    expert_data.append({
        'problem_type': name,
        'prompt': prompt,
        'solution': solution,
        'full_text': full_text
    })

print(f"Created {len(expert_data)} expert examples")
print(f"\nExample prompt:\n{expert_data[0]['prompt'][:200]}...")

In [None]:
class ExpertDataset(Dataset):
    def __init__(self, data: List[Dict], tokenizer, max_len: int = 512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        encoding = self.tokenizer(
            item['full_text'],
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        
        # Create labels (mask prompt and padding so neither is supervised)
        prompt_enc = self.tokenizer(item['prompt'], return_tensors='pt')
        prompt_len = prompt_enc['input_ids'].shape[1]

        labels = encoding['input_ids'].clone()
        labels[0, :prompt_len] = -100
        labels[encoding['attention_mask'] == 0] = -100
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': labels.squeeze(0)
        }

# Expand dataset by repeating each example (oversampling; no new data is created)
expanded_data = expert_data * 20  # 120 examples
random.shuffle(expanded_data)

dataset = ExpertDataset(expanded_data, tokenizer, config.max_seq_len)
dataloader = DataLoader(dataset, batch_size=config.mc_batch_size, shuffle=True)

print(f"Dataset size: {len(dataset)}")
print(f"Batches per epoch: {len(dataloader)}")
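
The dataset above marks prompt positions with `-100` so that only solution tokens are supervised. The toy sketch below shows how such labels line up with next-token prediction, where the target for position `t` is the token at `t + 1`. The token IDs are made up, and the helper names are illustrative rather than part of the notebook.

```python
IGNORE_INDEX = -100


def build_labels(input_ids, prompt_len, attention_mask):
    """Copy input_ids, masking prompt positions and padding positions."""
    labels = []
    for pos, (token, keep) in enumerate(zip(input_ids, attention_mask)):
        if pos < prompt_len or keep == 0:
            labels.append(IGNORE_INDEX)
        else:
            labels.append(token)
    return labels


def next_token_targets(labels):
    """Target for position t is the label at t + 1; the last position has none."""
    return labels[1:]


input_ids = [11, 12, 13, 21, 22, 0, 0]   # 3 prompt tokens, 2 solution tokens, 2 pads
attention_mask = [1, 1, 1, 1, 1, 0, 0]

labels = build_labels(input_ids, prompt_len=3, attention_mask=attention_mask)
targets = next_token_targets(labels)
supervised = sum(1 for t in targets if t != IGNORE_INDEX)

print(labels)
print(targets)
print(supervised)
```

Only the two solution tokens end up as targets; the last prompt position is the one that has to predict the first solution token.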

## Part 8: Extract Residual Activations

We need a helper to get residual stream activations at the controller layer.

In [None]:
def get_residuals(model, input_ids, attention_mask, layer_idx):
    """Extract residual stream activations at a specific layer."""
    residuals = []
    
    def hook(module, input, output):
        if isinstance(output, tuple):
            residuals.append(output[0])
        else:
            residuals.append(output)
    
    # Register hook
    layer = model.model.layers[layer_idx]
    handle = layer.register_forward_hook(hook)
    
    try:
        with torch.no_grad():
            _ = model(input_ids, attention_mask=attention_mask)
    finally:
        handle.remove()
    
    return residuals[0] if residuals else None

# Test
test_input = tokenizer("def hello():", return_tensors='pt').to(device)
test_residuals = get_residuals(base_model, test_input['input_ids'], 
                                test_input['attention_mask'], config.controller_layer)
print(f"Test residual shape: {test_residuals.shape}")
print(f"Expected: [1, seq_len, {embed_dim}]")
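
The hook in `get_residuals` only reads activations. PyTorch forward hooks can also replace a module's output by returning a value, which is one way controlled residuals could be fed back into the remaining layers. The following is a minimal sketch on a toy `nn.Sequential`, not on the Qwen model; `run_with_residual_edit` and `edit_fn` are hypothetical names. For a decoder layer whose output is a tuple, the hook would need to rebuild the tuple with the edited hidden state in position 0.

```python
import torch
import torch.nn as nn


def run_with_residual_edit(model: nn.Sequential, layer_idx: int,
                           x: torch.Tensor, edit_fn):
    """Run `model`, replacing the output of one layer with edit_fn(output)."""

    def hook(module, inputs, output):
        # Returning a non-None value from a forward hook replaces the output.
        return edit_fn(output)

    handle = model[layer_idx].register_forward_hook(hook)
    try:
        return model(x)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails


torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2))
x = torch.randn(3, 4)

# Shift the hidden state after layer 1, then let layer 2 consume the edited value.
edited = run_with_residual_edit(toy, 1, x, lambda h: h + 1.0)
manual = toy[2](toy[1](toy[0](x)) + 1.0)
print(torch.allclose(edited, manual))
```

Because this helper does not wrap the forward pass in `torch.no_grad()`, gradients can flow through `edit_fn`, which is what a trainable controller would need.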

## Part 9: Phase 2 - Train Metacontroller

Train the metacontroller with an ELBO objective to discover abstract actions.

In [None]:
# Create metacontroller
metacontroller = Metacontroller(
    embed_dim=embed_dim,
    latent_dim=config.latent_dim,
    gru_dim=config.gru_dim,
    seq_embed_dim=config.seq_embed_dim,
    encoder_hidden=config.encoder_hidden,
    decoder_hidden=config.decoder_hidden,
    switch_hidden=config.switch_hidden,
    controller_rank=config.controller_rank
).to(device)

mc_params = sum(p.numel() for p in metacontroller.parameters())
print(f"Metacontroller parameters: {mc_params:,}")

optimizer = AdamW(metacontroller.parameters(), lr=config.mc_learning_rate, 
                  weight_decay=config.mc_weight_decay)

In [None]:
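def compute_elbo_loss(logits, labels, mu_seq, logvar_seq, kl_weight, mask):
    """Compute ELBO loss for metacontroller training."""
    vocab_size = logits.size(-1)

    # Reconstruction loss. Logits at position t predict the token at t+1,
    # so drop the last logit and the first label before comparing them.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    nll = F.cross_entropy(
        shift_logits.reshape(-1, vocab_size),
        shift_labels.reshape(-1),
        ignore_index=-100,
        reduction='none'
    ).view(shift_labels.shape)

    # KL divergence between N(mu, sigma^2) and the standard normal prior
    kl = 0.5 * (mu_seq.pow(2) + logvar_seq.exp() - logvar_seq - 1).sum(dim=-1)

    # Average NLL over supervised targets only (not prompt, not padding)
    target_mask = ((shift_labels != -100) & (mask[:, 1:] > 0)).float()
    nll = (nll * target_mask).sum() / target_mask.sum().clamp(min=1)

    # Average KL over all non-padding positions
    mask_float = mask.float()
    kl = (kl * mask_float).sum() / mask_float.sum().clamp(min=1)

    loss = nll + kl_weight * kl
    return loss, nll, kl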

print("Loss function defined!")
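
The KL term in `compute_elbo_loss` is the closed-form divergence between a diagonal Gaussian and a standard normal prior. A plain-Python version of the same expression makes it easy to check a few values by hand. The function name `kl_to_standard_normal` is illustrative only.

```python
import math


def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one latent vector."""
    return 0.5 * sum(
        m * m + math.exp(lv) - lv - 1.0
        for m, lv in zip(mu, logvar)
    )


# Zero when the posterior already equals the prior
print(kl_to_standard_normal([0.0] * 16, [0.0] * 16))

# Shifting the mean by 1 in a single dimension costs 0.5 nats
print(kl_to_standard_normal([1.0], [0.0]))

# Doubling the standard deviation (variance 4) in a single dimension
print(kl_to_standard_normal([0.0], [math.log(4.0)]))
```

The term grows as the encoder's proposals move away from the prior, which is what `kl_weight` trades off against reconstruction quality.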

In [None]:
print("\n" + "="*60)
print("PHASE 2: METACONTROLLER TRAINING")
print("="*60)
print(f"Epochs: {config.mc_epochs}")
print(f"KL weight: {config.kl_weight}")
print()

history = {'loss': [], 'nll': [], 'kl': [], 'beta_mean': []}

for epoch in range(config.mc_epochs):
    epoch_loss = 0
    epoch_nll = 0
    epoch_kl = 0
    epoch_beta = 0
    num_batches = 0
    
    metacontroller.train()
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.mc_epochs}")
    
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Get residuals from frozen base model
        with torch.no_grad():
            residuals = get_residuals(base_model, input_ids, attention_mask, 
                                      config.controller_layer)
        
        # Apply metacontroller
        mc_out = metacontroller.forward_training(residuals.float(), attention_mask)
        
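        # Get logits (simplified: these come from the unmodified base model,
        # not from the controlled residuals, so the NLL term carries no gradient
        # to the metacontroller; only the KL term trains it here)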
        with torch.no_grad():
            base_out = base_model(input_ids, attention_mask=attention_mask)
            logits = base_out.logits
        
        # Compute loss
        loss, nll, kl = compute_elbo_loss(
            logits, labels, 
            mc_out['mu_sequence'], mc_out['logvar_sequence'],
            config.kl_weight, attention_mask
        )
        
        # Update
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(metacontroller.parameters(), 1.0)
        optimizer.step()
        
        # Track
        epoch_loss += loss.item()
        epoch_nll += nll.item()
        epoch_kl += kl.item()
        epoch_beta += mc_out['beta_sequence'].mean().item()
        num_batches += 1
        
        pbar.set_postfix({
            'loss': f"{loss.item():.3f}",
            'beta': f"{mc_out['beta_sequence'].mean().item():.3f}"
        })
    
    # Epoch summary
    avg_loss = epoch_loss / num_batches
    avg_nll = epoch_nll / num_batches
    avg_kl = epoch_kl / num_batches
    avg_beta = epoch_beta / num_batches
    
    history['loss'].append(avg_loss)
    history['nll'].append(avg_nll)
    history['kl'].append(avg_kl)
    history['beta_mean'].append(avg_beta)
    
    if (epoch + 1) % 5 == 0:
        print(f"\nEpoch {epoch+1}: Loss={avg_loss:.4f}, NLL={avg_nll:.4f}, "
              f"KL={avg_kl:.4f}, Beta={avg_beta:.3f}")

print("\nMetacontroller training complete!")

In [None]:
# Plot training curves
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

axes[0,0].plot(history['loss'])
axes[0,0].set_title('Total Loss')
axes[0,0].set_xlabel('Epoch')

axes[0,1].plot(history['nll'])
axes[0,1].set_title('Reconstruction Loss (NLL)')
axes[0,1].set_xlabel('Epoch')

axes[1,0].plot(history['kl'])
axes[1,0].set_title('KL Divergence')
axes[1,0].set_xlabel('Epoch')

axes[1,1].plot(history['beta_mean'])
axes[1,1].set_title('Mean Switching Rate (β)')
axes[1,1].set_xlabel('Epoch')
axes[1,1].axhline(y=0.5, color='r', linestyle='--', label='threshold')
axes[1,1].legend()

plt.tight_layout()
plt.show()

## Part 10: Analyze Switching Patterns

Let's see if the metacontroller learned meaningful switching patterns.

In [None]:
metacontroller.eval()

# Analyze one example
test_idx = 0
sample = dataset[test_idx]
input_ids = sample['input_ids'].unsqueeze(0).to(device)
attention_mask = sample['attention_mask'].unsqueeze(0).to(device)

with torch.no_grad():
    residuals = get_residuals(base_model, input_ids, attention_mask, config.controller_layer)
    mc_out = metacontroller.forward_training(residuals.float(), attention_mask)

# Get switching values
betas = mc_out['beta_sequence'][0, :, 0].cpu().numpy()
valid_len = attention_mask[0].sum().item()

# Decode tokens for reference
tokens = tokenizer.convert_ids_to_tokens(input_ids[0][:valid_len].cpu())

# `dataset` wraps the shuffled `expanded_data`, so look the example up there
print(f"Analyzing: {expanded_data[test_idx]['problem_type']}")
print(f"Sequence length: {valid_len}")
print(f"\nSwitching events (β > 0.5):")

switch_points = []
for i, b in enumerate(betas[:valid_len]):
    if b > 0.5:
        switch_points.append(i)
        token_context = ''.join(tokens[max(0, i-2):i+3]).replace('Ġ', ' ')
        print(f"  Position {i}: β={b:.3f}, context: '{token_context}'")

print(f"\nTotal switches: {len(switch_points)} / {valid_len} tokens")
print(f"Compression ratio: {valid_len / max(len(switch_points), 1):.1f}x")
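
The switch points printed above can also be turned into explicit segments, one per abstract action, which makes the segment lengths easy to inspect. This is a small sketch on a made-up β sequence; `segments_from_betas` is an illustrative helper, and treating position 0 as the start of the first segment is an assumption of this sketch.

```python
def segments_from_betas(betas, threshold=0.5):
    """Split positions into half-open (start, end) runs that begin at each switch."""
    segments = []
    start = 0
    for i, beta in enumerate(betas):
        # A switch at position i closes the previous run and opens a new one.
        if beta > threshold and i > start:
            segments.append((start, i))
            start = i
    if betas:
        segments.append((start, len(betas)))
    return segments


toy_betas = [0.9, 0.1, 0.2, 0.8, 0.1, 0.7]
toy_segments = segments_from_betas(toy_betas)
mean_length = len(toy_betas) / len(toy_segments)

print(toy_segments)
print(mean_length)
```

The mean segment length plays the same role as the compression ratio reported above.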

In [None]:
# Visualize switching pattern
plt.figure(figsize=(14, 4))
plt.bar(range(valid_len), betas[:valid_len], alpha=0.7)
plt.axhline(y=0.5, color='r', linestyle='--', label='Switch threshold')
plt.xlabel('Token Position')
plt.ylabel('Switching Probability (β)')
plt.title(f'Switching Pattern for {expanded_data[test_idx]["problem_type"]}')
plt.legend()
plt.tight_layout()
plt.show()

## Part 11: Phase 3 - Internal RL (Simplified)

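Now we train the abstract action policy using RL.
This is a simplified version: tokens are still sampled from the unmodified base model, and the latent `z` is not injected back into the residual stream. A full implementation would need proper environment interaction.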

In [None]:
# Create abstract action policy
policy = AbstractActionPolicy(
    embed_dim=embed_dim,
    hidden_dim=256,
    latent_dim=config.latent_dim
).to(device)

policy_params = sum(p.numel() for p in policy.parameters())
print(f"Policy parameters: {policy_params:,}")

policy_optimizer = AdamW(policy.parameters(), lr=config.rl_learning_rate)

In [None]:
def evaluate_code(code: str, problem: Dict) -> float:
    """Evaluate generated code against test cases."""
    try:
        # Execute code
        namespace = {}
        exec(code, namespace)
        
        func = namespace.get(problem['entry_point'])
        if func is None:
            return 0.0
        
        # Run tests
        passed = 0
        for test in problem['test_cases']:
            try:
                result = func(*test['input'])
                if result == test['expected']:
                    passed += 1
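            except (Exception, SystemExit):
                # A failing test counts as not passed; Ctrl-C still interrupts
                pass

        return passed / len(problem['test_cases'])
    except (Exception, SystemExit):
        return 0.0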

# Test evaluation
test_reward = evaluate_code(PROBLEMS['fibonacci']['canonical_solution'], PROBLEMS['fibonacci'])
print(f"Test evaluation (canonical solution): {test_reward:.1%}")
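
`evaluate_code` runs generated code with `exec` inside the notebook process, so an infinite loop in a generated solution would stall training. One possible mitigation is to run each candidate in a child Python process with a timeout. The sketch below uses only the standard library; `evaluate_code_isolated`, `RUNNER`, and the 5-second default are assumptions for illustration, and this is process isolation with a time limit, not a security sandbox.

```python
import json
import subprocess
import sys

# Script executed in the child process: reads a JSON payload from stdin,
# runs the candidate code, and prints the number of passed tests as JSON.
RUNNER = """
import json, sys
payload = json.loads(sys.stdin.read())
namespace = {}
exec(payload["code"], namespace)
func = namespace[payload["entry_point"]]
passed = 0
for test in payload["test_cases"]:
    try:
        if func(*test["input"]) == test["expected"]:
            passed += 1
    except Exception:
        pass
print(json.dumps({"passed": passed}))
"""


def evaluate_code_isolated(code: str, problem: dict, timeout_s: float = 5.0) -> float:
    """Fraction of test cases passed, or 0.0 on crash or timeout."""
    payload = json.dumps({
        "code": code,
        "entry_point": problem["entry_point"],
        "test_cases": problem["test_cases"],
    })
    try:
        proc = subprocess.run(
            [sys.executable, "-c", RUNNER],
            input=payload, capture_output=True, text=True, timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        return 0.0
    if proc.returncode != 0:
        return 0.0
    try:
        # The result is the last stdout line, after anything the candidate printed.
        passed = json.loads(proc.stdout.strip().splitlines()[-1])["passed"]
    except (ValueError, IndexError, KeyError, TypeError):
        return 0.0
    return passed / len(problem["test_cases"])


toy_problem = {
    "entry_point": "add",
    "test_cases": [{"input": [1, 2], "expected": 3}, {"input": [0, 0], "expected": 0}],
}
print(evaluate_code_isolated("def add(a, b):\n    return a + b", toy_problem))
```

Test inputs and expected values pass through JSON here, which works for the lists, strings, numbers, and booleans used in `PROBLEMS`.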

In [None]:
def generate_with_policy(problem: Dict, policy: nn.Module, metacontroller: Metacontroller,
                         base_model, tokenizer, max_tokens: int = 256) -> str:
    """Generate code using the abstract action policy."""
    policy.eval()
    metacontroller.eval()
    
    prompt = format_prompt(problem['description'])
    input_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(device)
    
    generated = input_ids[0].tolist()
    h_policy = policy.init_hidden(1, device)
    h_mc, z = metacontroller.init_state(1, device)
    
    for _ in range(max_tokens):
        # Get current residual
        curr_ids = torch.tensor([generated], device=device)
        with torch.no_grad():
            residual = get_residuals(base_model, curr_ids, 
                                     torch.ones_like(curr_ids), config.controller_layer)
            e_t = residual[:, -1, :].float()
        
        # Get switching probability
        h_mc = metacontroller.gru(e_t, h_mc)
        beta = metacontroller.switching_unit(e_t, h_mc, z)
        
        # Sample new z if switching
        if beta.item() > config.beta_threshold or len(generated) == len(input_ids[0]):
            mu, logvar, h_policy = policy(e_t, h_policy)
            z = policy.sample(mu, logvar, deterministic=True)
        
        # Generate token (simplified: use base model directly)
        with torch.no_grad():
            out = base_model(curr_ids)
            logits = out.logits[:, -1, :]
            
            if config.temperature > 0:
                probs = F.softmax(logits / config.temperature, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
            else:
                next_token = logits.argmax(-1).item()
        
        generated.append(next_token)
        
        if next_token == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(generated, skip_special_tokens=True)

# Test generation
print("Testing generation...")
test_gen = generate_with_policy(PROBLEMS['fibonacci'], policy, metacontroller, 
                                 base_model, tokenizer, max_tokens=100)
print(f"Prompt + generation ({len(test_gen)} chars):")
print(test_gen[:500])

In [None]:
print("\n" + "="*60)
print("PHASE 3: INTERNAL RL TRAINING (Simplified)")
print("="*60)
print(f"Steps: {config.rl_steps}")
print(f"Batch size: {config.rl_batch_size}")
print()

rl_history = {'rewards': [], 'policy_loss': []}
problem_list = list(PROBLEMS.values())

for step in tqdm(range(config.rl_steps), desc="Internal RL"):
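    batch_rewards = []
    
    # generate_with_policy() switches the policy back to eval mode, so this
    # call has no effect until a gradient-based update is added.
    policy.train()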
    
    for _ in range(config.rl_batch_size):
        # Sample problem
        problem = random.choice(problem_list)
        
        # Generate and evaluate
        generated = generate_with_policy(problem, policy, metacontroller,
                                          base_model, tokenizer, max_tokens=150)
        
        # Extract code and evaluate
        if "```python" in generated:
            code_start = generated.find("```python") + len("```python")
            code_end = generated.find("```", code_start)
            code = generated[code_start:code_end].strip() if code_end > code_start else ""
        else:
            code = generated.split("Solution:")[-1].strip() if "Solution:" in generated else ""
        
        reward = evaluate_code(code, problem)
        batch_rewards.append(reward)
    
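    # Simplified: only the batch mean reward is recorded. No gradient step is
    # taken here; a REINFORCE-style update would also need log-probs of sampled z.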
    mean_reward = sum(batch_rewards) / len(batch_rewards)
    rl_history['rewards'].append(mean_reward)
    
    # Log progress
    if (step + 1) % config.log_interval == 0:
        avg_reward = sum(rl_history['rewards'][-config.log_interval:]) / config.log_interval
        print(f"\nStep {step+1}: Mean Reward = {avg_reward:.3f}")

print("\n" + "="*60)
print("INTERNAL RL TRAINING COMPLETE")
print("="*60)
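
The code-extraction logic in the training loop above is repeated verbatim in the final evaluation. A minimal sketch of a shared helper is shown below; `extract_code` is a hypothetical name that does not appear elsewhere in this notebook, and the behavior is meant to mirror the inline version exactly, including returning an empty string when the closing fence is missing.

```python
FENCE = "`" * 3  # triple backtick, built programmatically to keep this cell readable


def extract_code(generated: str) -> str:
    """Hypothetical helper: pull the candidate solution out of a generation."""
    opener = FENCE + "python"
    if opener in generated:
        start = generated.find(opener) + len(opener)
        end = generated.find(FENCE, start)
        # If the closing fence is missing (find returns -1), return nothing,
        # exactly as the inline version does.
        return generated[start:end].strip() if end > start else ""
    if "Solution:" in generated:
        # Fall back to everything after the last "Solution:" marker
        return generated.split("Solution:")[-1].strip()
    return ""
```

With such a helper, both loops could call `code = extract_code(generated)` before `evaluate_code(code, problem)`. Note that a generation truncated by `max_tokens` before the closing fence yields an empty string and is therefore scored on empty code.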

## Part 12: Final Evaluation

In [None]:
print("\n" + "="*60)
print("FINAL EVALUATION")
print("="*60)

results = {}

policy.eval()
metacontroller.eval()

for name, problem in PROBLEMS.items():
    successes = 0
    num_samples = 5
    
    for _ in range(num_samples):
        generated = generate_with_policy(problem, policy, metacontroller,
                                          base_model, tokenizer, max_tokens=200)
        
        if "```python" in generated:
            code_start = generated.find("```python") + len("```python")
            code_end = generated.find("```", code_start)
            code = generated[code_start:code_end].strip() if code_end > code_start else ""
        else:
            code = generated.split("Solution:")[-1].strip() if "Solution:" in generated else ""
        
        reward = evaluate_code(code, problem)
        if reward == 1.0:
            successes += 1
    
    accuracy = successes / num_samples
    results[name] = accuracy
    print(f"{name:20s}: {successes}/{num_samples} = {accuracy:.0%}")

overall = sum(results.values()) / len(results)
print(f"\n{'OVERALL':20s}: {overall:.0%}")

In [None]:
# Summary comparison with Experiment 16
exp16_results = {
    'fibonacci': 1.0,
    'binary_search': 1.0,
    'coin_change': 0.8,
    'valid_parentheses': 0.2,
    'rpn_calculator': 0.0,
    'edit_distance': 0.0
}

print("\n" + "="*60)
print("COMPARISON: Experiment 16 vs Experiment 17")
print("="*60)
print(f"{'Problem':<20} {'Exp 16':<10} {'Exp 17':<10} {'Change':<10}")
print("-"*50)

for name in PROBLEMS:
    exp16 = exp16_results.get(name, 0.0)
    exp17 = results.get(name, 0.0)
    change = exp17 - exp16
    change_str = f"+{change:.0%}" if change >= 0 else f"{change:.0%}"
    print(f"{name:<20} {exp16:<10.0%} {exp17:<10.0%} {change_str:<10}")

print("-"*50)
exp16_overall = sum(exp16_results.values()) / len(exp16_results)
print(f"{'OVERALL':<20} {exp16_overall:<10.0%} {overall:<10.0%}")

## Part 13: Save Models

In [None]:
# Save metacontroller and policy
import os

save_dir = "exp17_models"
os.makedirs(save_dir, exist_ok=True)

torch.save({
    'metacontroller_state_dict': metacontroller.state_dict(),
    'policy_state_dict': policy.state_dict(),
    'config': config.__dict__,
    'results': results,
    'history': history,
    'rl_history': rl_history
}, os.path.join(save_dir, "experiment_17_checkpoint.pt"))

print(f"Models saved to {save_dir}/")

## Summary

This notebook implemented Experiment 17: Internal RL with Temporal Abstractions.

### Key Components:
1. **Metacontroller** - Discovers abstract actions in the residual stream
2. **Switching Unit** - Learns when to switch between abstract actions
3. **Abstract Action Policy** - RL policy over the latent action space

### Key Insights:
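- Token-level RL faces a very large action space (50K+ vocabulary tokens per step) and a long horizon (~100 steps)
- Internal RL reduces this to ~16D latent actions over ~5 abstract timesteps
- The shorter horizon and compact action space make credit assignment for sparse rewards more tractable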
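
The scale of the gap can be made concrete with a rough back-of-the-envelope calculation. The sketch below takes the approximate figures from the list above at face value (50K vocabulary, ~100 tokens, ~5 abstract steps, 16 latent dimensions). Latent actions are continuous, so the two settings are not directly comparable as counts; the numbers are illustrative only.

```python
import math

vocab_size = 50_000       # "50K+" vocabulary (approximate)
token_horizon = 100       # ~100 token-level decisions per episode
abstract_horizon = 5      # ~5 abstract-action decisions per episode
latent_dim = 16           # dimensionality of each abstract action z

# Number of distinct token sequences, expressed as a power of ten
log10_token_sequences = token_horizon * math.log10(vocab_size)

# How many fewer sequential decisions the reward must be attributed to
horizon_reduction = token_horizon / abstract_horizon

# Total real numbers the abstract policy picks per episode
latent_numbers_per_episode = abstract_horizon * latent_dim

print(f"Distinct token sequences: about 10^{log10_token_sequences:.0f}")
print(f"Decisions per episode: {token_horizon} -> {abstract_horizon} "
      f"({horizon_reduction:.0f}x fewer)")
print(f"Latent values chosen per episode: {latent_numbers_per_episode}")
```

Under these assumptions the token-level search covers roughly 10^470 sequences, while the abstract policy makes 20 times fewer sequential decisions per episode.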

### Next Steps:
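- Full integration of controlled residuals into generation (tokens currently come from the unmodified base model)
- An actual policy-gradient update (the current loop only records rewards), then more sophisticated RL (PPO, better baselines)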
- Analysis of discovered abstract actions
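
As a starting point for the RL item above, here is a minimal sketch of a REINFORCE surrogate loss with a batch-mean baseline over diagonal-Gaussian latent actions. It uses plain Python floats so the arithmetic is easy to check. In the notebook the log-probabilities would instead be torch tensors computed from the policy's `mu` and `logvar`, summed over the z samples in each episode. The function names `gaussian_log_prob` and `reinforce_loss` are illustrative and not part of the notebook.

```python
import math
from typing import Sequence


def gaussian_log_prob(z: Sequence[float], mu: Sequence[float],
                      logvar: Sequence[float]) -> float:
    """Log-density of z under a diagonal Gaussian N(mu, exp(logvar))."""
    total = 0.0
    for z_i, mu_i, lv_i in zip(z, mu, logvar):
        total += -0.5 * (math.log(2 * math.pi) + lv_i
                         + (z_i - mu_i) ** 2 / math.exp(lv_i))
    return total


def reinforce_loss(log_probs: Sequence[float], rewards: Sequence[float]) -> float:
    """Surrogate loss -mean((R - b) * log_prob), with b the batch mean reward."""
    baseline = sum(rewards) / len(rewards)
    advantages = [r - baseline for r in rewards]
    return -sum(a * lp for a, lp in zip(advantages, log_probs)) / len(rewards)


# Toy batch: two episodes, one 2-D latent action each, unit-variance policy
log_probs = [
    gaussian_log_prob([0.0, 0.0], mu=[0.0, 0.0], logvar=[0.0, 0.0]),
    gaussian_log_prob([1.0, 1.0], mu=[0.0, 0.0], logvar=[0.0, 0.0]),
]
rewards = [1.0, 0.0]  # sparse pass/fail rewards
loss = reinforce_loss(log_probs, rewards)
print(f"surrogate loss: {loss:.4f}")
```

Subtracting the batch mean leaves the expected gradient unchanged while reducing its variance, which matters when rewards are sparse pass/fail signals. Such an update also requires sampling z stochastically rather than with `deterministic=True`, since a deterministic policy never explores.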