In [1]:
import os
import math
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Union, Any
from types import SimpleNamespace

from sklearn import metrics
import torch
import torch.distributed as dist
import torch.nn as nn
import einops
import numpy as np
import torch.nn.functional as F

from src.nn.modules.utils import compute_lr
from src.nn.utils.constants import IGNORE_LABEL_ID
from torch.nn.functional import scaled_dot_product_attention
from src.nn.modules.utils import trunc_normal_init_
from src.nn.data.sudoku_datamodule import SudokuDataModule

from src.nn.modules.utils import stablemax_cross_entropy, trunc_normal_init_
from src.nn.utils import RankedLogger

from examples.utils import visualize_sudoku_text

CosSin = Tuple[torch.Tensor, torch.Tensor]
log = RankedLogger(__name__, rank_zero_only=True)
torch.manual_seed(0)
np.random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
# set seed
torch.manual_seed(42)
np.random.seed(42)

device: cuda


In [2]:
class CastedLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(
                torch.empty((out_features, in_features)), std=1.0 / (in_features**0.5)
            )
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features,)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(
            input,
            self.weight.to(input.dtype),
            bias=self.bias.to(input.dtype) if self.bias is not None else None,
        )
    
class CastedEmbedding(nn.Module):
    def __init__(
        self, num_embeddings: int, embedding_dim: int, init_std: float, cast_to: torch.dtype
    ):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))

class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)

def _find_multiple(a, b):
    return (-(a // -b)) * b

class ReasoningBlockConfig:
    """
    Configuration A: The "Standard Transformer"
        mlp_t=False (Use Attention)
        use_convswiglu/use_boardswiglu=False (Use Standard MLP)
        Result: Classic powerful reasoning. Attention handles global context; MLP handles logic.

    Configuration B: The "Spatial-Inductive Transformer"
        mlp_t=False (Use Attention)
        use_convswiglu/use_boardswiglu=True (Use Conv MLP)
        Result: Strongest. Attention sees the whole board ("I can win in column 7"), while ConvSwiGLU recognizes patterns immediately ("I have 3-in-a-row here"). This gives the best of both worlds.

    Configuration C: The "MLP-Mixer" (Pure MLP)
        mlp_t=True (Use Token MLP)
        use_convswiglu/use_boardswiglu=False (Use Standard MLP)
        Result: Very fast, very stable, but no Attention. The model mixes information globally using a fixed matrix. It might struggle with "dynamic" reasoning.

    Configuration D: The "ConvMixer"
        mlp_t=True
        use_convswiglu/use_boardswiglu=True
        Result: A fully convolutional/MLP network. It has zero attention mechanisms: mlp_t mixes the board globally (fixed weights), convswiglu mixes neighbors locally.
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        expansion: int,
        rms_norm_eps: float,
        mlp_t: bool = False,
        seq_len: int = 0,
        cols: int = None,
        rows: int = None,
        puzzle_emb_ndim: int = 0,
        puzzle_emb_len: int = 0,
        use_conv_swiglu: bool = False,
        use_board_swiglu: bool = False,
        dropout: float = 0.1
    ) -> None:
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.expansion = expansion
        self.rms_norm_eps = rms_norm_eps
        self.mlp_t = mlp_t
        self.puzzle_emb_ndim = puzzle_emb_ndim
        self.puzzle_emb_len = puzzle_emb_len
        self.seq_len = seq_len
        self.cols = cols
        self.rows = rows
        self.use_conv_swiglu = use_conv_swiglu
        self.use_board_swiglu = use_board_swiglu
        self.dropout = dropout

class ReasoningBlock(nn.Module):
    def __init__(self, config: ReasoningBlockConfig) -> None:
        super().__init__()
        self.config = config
        self.norm_eps = config.rms_norm_eps
        self.dropout = nn.Dropout(config.dropout)

        # 1. Calculate Effective Length
        # If config is 0 (auto), infer from dimensions. Otherwise use config.
        # This handles the case where puzzle_emb_ndim > 0 but puzzle_emb_len was not manually set.
        self.puzzle_emb_len = (
            -(config.puzzle_emb_ndim // -config.hidden_size)
            if config.puzzle_emb_len == 0
            else config.puzzle_emb_len
        )

        self.mlp_t = SwiGLU(
            hidden_size=config.seq_len + self.puzzle_emb_len, 
            expansion=config.expansion,
        )
        self.mlp = SwiGLU(
            hidden_size=config.hidden_size,
            expansion=config.expansion,
        )

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:

        if self.config.mlp_t:
            hidden_states = hidden_states.transpose(1,2)
            out = self.mlp_t(hidden_states)
            out = self.dropout(out)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1,2)
        else:
            attn_out = self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states)
            hidden_states = rms_norm(hidden_states + attn_out, variance_epsilon=self.norm_eps)

        mlp_out = self.mlp(hidden_states)
        hidden_states = rms_norm(hidden_states + mlp_out, variance_epsilon=self.norm_eps)
            
        return hidden_states

class ReasoningModule(nn.Module):
    def __init__(self, layers: List[ReasoningBlock]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(
        self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        hidden_states = hidden_states + input_injection
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)
        return hidden_states

In [3]:
@dataclass
class TRMInnerCarry:
    z_H: torch.Tensor  # High-level state (y = the solution representation)
    z_L: torch.Tensor  # Low-level state (z = the problem representation)


@dataclass
class TRMCarry:
    """Carry structure for maintaining state across steps."""

    inner_carry: TRMInnerCarry
    steps: torch.Tensor
    halted: torch.Tensor
    current_data: Dict[str, torch.Tensor]  # Stores current batch data
    
class TRMModel(nn.Module):
    """
    HRM implementation following Figure 2 pseudocode exactly.
    """

    def __init__(
        self,
        hidden_size: int = 512,
        num_layers: int = 2,
        num_heads: int = 8,  # min(2, hidden_size // 64)
        max_grid_size: int = 30,
        H_cycles: int = 3,
        L_cycles: int = 6,
        N_supervision: int = 16,
        N_supervision_val: int = 16,
        ffn_expansion: int = 2,
        learning_rate: float = 1e-4,
        learning_rate_emb: float = 1e-2,
        weight_decay: float = 0.01,
        warmup_steps: int = 2000,
        halt_exploration_prob: float = 0.1,
        puzzle_emb_dim: int = 512,  # Puzzle embedding dimension
        puzzle_emb_len: int = 16,  # How many tokens for puzzle embedding
        rope_theta: int = 10000,
        pos_emb_type: str = "1d",
        use_mlp_t: bool = False,
        use_conv_swiglu: bool = False,
        use_board_swiglu: bool = False,
        lr_min_ratio: float = 1.0,
        use_muon: bool = False,
        vocab_size: int = 0,  # Should be set from datamodule
        num_puzzles: int = 0,  # Should be set from datamodule
        batch_size: int = 0,  # Should be set from datamodule
        pad_value: int = -1,  # Should be set from datamodule
        seq_len: int = 0,  # Should be set from datamodule
        output_dir: str = None,
    ):
        super().__init__()
        self.hidden_size=hidden_size
        self.num_layers=num_layers
        self.num_heads=num_heads
        self.max_grid_size=max_grid_size
        self.H_cycles=H_cycles
        self.L_cycles=L_cycles
        self.N_supervision=N_supervision
        self.N_supervision_val=N_supervision_val
        self.ffn_expansion=ffn_expansion
        self.learning_rate=learning_rate
        self.learning_rate_emb=learning_rate_emb
        self.weight_decay=weight_decay
        self.warmup_steps=warmup_steps
        self.halt_exploration_prob=halt_exploration_prob
        self.puzzle_emb_dim=puzzle_emb_dim
        self.puzzle_emb_len=puzzle_emb_len
        self.rope_theta=rope_theta
        self.pos_emb_type=pos_emb_type
        self.use_mlp_t=use_mlp_t
        self.use_conv_swiglu=use_conv_swiglu
        self.use_board_swiglu=use_board_swiglu
        self.lr_min_ratio=lr_min_ratio
        self.use_muon=use_muon
        self.vocab_size=vocab_size
        self.num_puzzles=num_puzzles
        self.batch_size=batch_size
        self.pad_value=pad_value
        self.seq_len=seq_len
        self.output_dir=output_dir
        self.forward_dtype = torch.bfloat16

        # Token embeddings
        self.embed_scale = math.sqrt(hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        log.info(f"Creating TRM with vocab size={vocab_size}, seq_len={seq_len}, puzzle_emb_len={puzzle_emb_len} {pos_emb_type=} {puzzle_emb_dim=}")
        log.info(f"{use_mlp_t=}, {use_conv_swiglu=}, {use_board_swiglu=}")

         # Input embedding
        self.input_embedding = CastedEmbedding(
            vocab_size, hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype
        )

        # Positional encoding
        log.info("Not using Rotary Embeddings")

        if not use_mlp_t:
            assert pos_emb_type is not None, "Rotary embeddings required if using attention"

        # a single network (not two separate networks)
        reasoning_config = ReasoningBlockConfig(
            hidden_size=hidden_size,
            num_heads=num_heads,
            expansion=ffn_expansion,
            rms_norm_eps=1e-5,
            seq_len=seq_len,
            mlp_t=use_mlp_t,
            puzzle_emb_ndim=puzzle_emb_dim,
            puzzle_emb_len=puzzle_emb_len,
            use_conv_swiglu=use_conv_swiglu,
            use_board_swiglu=use_board_swiglu,
            rows = max_grid_size,
            cols = max_grid_size,
            dropout=0.25
        )

        self.lenet = ReasoningModule(
            layers=[ReasoningBlock(reasoning_config) for _ in range(num_layers)]
        )

        self.lm_head = CastedLinear(hidden_size, vocab_size, bias=False)
        self.q_head = CastedLinear(hidden_size, 1, bias=True) # learn to stop, not to continue

        with torch.no_grad():
            self.q_head.weight.zero_()
            if self.q_head.bias is not None:
                self.q_head.bias.fill_(-5.0)  # Strong negative bias

        # State for carry (persisted across training steps)
        self.carry: Optional[TRMCarry] = None

        # Init states (registered buffers)
        self.register_buffer(
            "z_H_init",
            trunc_normal_init_(torch.empty(hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )
        self.register_buffer(
            "z_L_init",
            trunc_normal_init_(torch.empty(hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )

        log.info("puzzle_emb_dim <= 0, not creating puzzle embeddings")
        self.puzzle_emb = None
        self.puzzle_emb_len = 0

        self.manual_step = 0
        self.total_steps: float = float("inf")

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.input_embedding(input.to(torch.int32))

        # Puzzle embeddings
        if self.puzzle_emb_dim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.hidden_size - puzzle_embedding.shape[-1]

            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat(
                (
                    puzzle_embedding.view(-1, self.puzzle_emb_len, self.hidden_size),
                    embedding,
                ),
                dim=-2,
            )

        # Scale
        return self.embed_scale * embedding

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["input"].shape[0]
        device = batch["input"].device

        return TRMCarry(
            inner_carry=self.empty_carry(
                batch_size, device
            ),  # Empty is expected, it will be reseted in first pass as all sequences are halted.
            steps=torch.zeros((batch_size,), dtype=torch.int32, device=device),
            halted=torch.ones((batch_size,), dtype=torch.bool, device=device),  # Default to halted
            current_data={k: torch.empty_like(v, device=device) for k, v in batch.items()},
        )

    def empty_carry(self, batch_size: int, device: torch.device) -> TRMInnerCarry:
        return TRMInnerCarry(
            z_H=torch.empty(
                batch_size,
                self.seq_len + self.puzzle_emb_len,
                self.hidden_size,
                dtype=self.forward_dtype,
                device=device,
            ),
            z_L=torch.empty(
                batch_size,
                self.seq_len + self.puzzle_emb_len,
                self.hidden_size,
                dtype=self.forward_dtype,
                device=device,
            ),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TRMInnerCarry) -> TRMInnerCarry:
        return TRMInnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.z_H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.z_L_init, carry.z_L),
        )

    def inner_forward(
        self, carry: TRMInnerCarry, batch: Dict[str, torch.Tensor]
    ) -> Tuple[TRMInnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.pos_embedding() if hasattr(self, "pos_embedding") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["input"], batch["puzzle_identifiers"])

        # Forward iterations
        z_H, z_L = carry.z_H, carry.z_L
        # H_cycles-1 without grad
        with torch.no_grad():
            for _ in range(self.H_cycles - 1):
                for _ in range(self.L_cycles):
                    z_L = self.lenet(z_L, z_H + input_embeddings, **seq_info)
                z_H = self.lenet(z_H, z_L, **seq_info)
        # 1 with grad
        for _ in range(self.L_cycles):
            z_L = self.lenet(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.lenet(z_H, z_L, **seq_info)
    
        # LM Outputs
        new_carry = TRMInnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len :] # discard puzzle embeddings
        q_logits = self.q_head(z_H[:, 0]).to(
            torch.float32
        )  # Q-head; uses the first puzzle_emb position

        return new_carry, output, q_logits[..., 0]

    def forward(
        self, carry: TRMCarry, batch: Dict[str, torch.Tensor]
    ) -> Tuple[TRMCarry, Dict[str, torch.Tensor]]:
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {
            k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
            for k, v in carry.current_data.items()
        }

        # Forward inner model
        new_inner_carry, logits, q_halt_logits = self.inner_forward(
            new_inner_carry, new_current_data
        )

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            n_supervision_steps = (
                self.N_supervision if self.training else self.N_supervision_val
            )

            is_last_step = new_steps >= n_supervision_steps

            halted = is_last_step

            # if training, and ACT is enabled
            if self.training and (self.N_supervision > 1):
                # Halt signal
                # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes

                halted = halted | (q_halt_logits > 0)

                # Exploration
                min_halt_steps = (
                    torch.rand_like(q_halt_logits) < self.halt_exploration_prob
                ) * torch.randint_like(new_steps, low=2, high=self.N_supervision + 1)
                halted = halted & (new_steps >= min_halt_steps)

        return TRMCarry(new_inner_carry, new_steps, halted, new_current_data), outputs

    def compute_loss_and_metrics(self, carry, batch):
        """Compute loss and metrics without circular reference."""
        # Get model outputs
        new_carry, outputs = self.forward(carry, batch)
        labels = new_carry.current_data["output"]

        with torch.no_grad():
            outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)

            # Correctness
            mask = labels != IGNORE_LABEL_ID
            loss_counts = mask.sum(-1)

            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)

            metrics = {
                "count": valid_metrics.sum(),
                "accuracy": torch.where(
                    valid_metrics, (is_correct.float() / loss_divisor).sum(-1), 0
                ).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
                "q_halt_accuracy": (
                    valid_metrics & ((outputs["q_halt_logits"].squeeze() >= 0) == seq_is_correct)
                ).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Compute losses: These are per-sequence losses that will be summed
        lm_loss = (
            stablemax_cross_entropy(
                outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask
            )
            / loss_divisor
        ).sum()

        q_halt_loss = F.binary_cross_entropy_with_logits(
            outputs["q_halt_logits"],
            seq_is_correct.to(outputs["q_halt_logits"].dtype),
            reduction="sum",
        )
        metrics.update(
            {
                "lm_loss": lm_loss.detach(),
                "q_halt_loss": q_halt_loss.detach(),
            }
        )

        total_loss = lm_loss + 0.5 * q_halt_loss

        return new_carry, total_loss, metrics, new_carry.halted.all()

    def log_metrics(self, metrics: dict, lr_this_step: float = None, batch_size: int = None):

        # Log learning rate (will log the last optimizer's LR)
        self.log("train/lr", lr_this_step, on_step=True)

        # Log metrics
        if metrics.get("count", 0) > 0:
            with torch.no_grad():
                count = metrics["count"]
                self.log("train/accuracy", metrics.get("accuracy", 0) / count, on_step=True)
                self.log(
                    "train/exact_accuracy",
                    metrics.get("exact_accuracy", 0) / count,
                    prog_bar=True,
                    on_step=True,
                )
                self.log(
                    "train/q_halt_accuracy",
                    metrics.get("q_halt_accuracy", 0) / count,
                    on_step=True,
                )
                self.log(
                    "train/steps",
                    metrics.get("steps", 0) / count,
                    prog_bar=True,
                    on_step=True,
                )

                self.log("train/lm_loss", metrics.get("lm_loss", 0) / batch_size, on_step=True)
                self.log(
                    "train/q_halt_loss", metrics.get("q_halt_loss", 0) / batch_size, on_step=True
                )

                avg_halt_steps = metrics.get("steps", 0) / metrics["count"]
                early_halt_rate = avg_halt_steps < self.N_supervision
                self.log("train/early_halt_rate", early_halt_rate, on_step=True)

### Load sudoku data

In [4]:
grid_size = 6
max_grid_size = 8
dm = SudokuDataModule(
    data_dir=None,       
    batch_size=32,       
    num_train_puzzles=1000,
    num_val_puzzles=100,
    num_test_puzzles=100,
    grid_size=grid_size,
    max_grid_size=max_grid_size,
    num_workers=0,
    seed=42
)
dm.setup()

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

vocab_size = dm.vocab_size 
seq_len = dm.seq_len
puzzle_emb_len = 0

print(f"Data Ready: {grid_size}x{grid_size} Sudoku")
print(f"Vocab Size: {vocab_size}, Sequence Length: {seq_len}")
batch = next(iter(train_loader))
# Visualize the first element in the batch
inp = batch['input'][0].reshape(max_grid_size, max_grid_size)
tgt = batch['output'][0].reshape(max_grid_size, max_grid_size)
inp_grid = inp[:grid_size, :grid_size]
tgt_grid = tgt[:grid_size, :grid_size]
print(inp_grid)
visualize_sudoku_text(batch, idx=0, grid_size=6)

Data Ready: 6x6 Sudoku
Vocab Size: 9, Sequence Length: 64
tensor([[ 2,  2,  7,  2,  9,  8],
        [ 2,  9,  2,  2,  4,  3],
        [ 2, 10,  2,  8,  2,  2],
        [ 9,  2,  8,  2, 10,  2],
        [10,  8,  4,  9,  2,  2],
        [ 2,  3,  2,  4,  8, 10]])
Sample 0 (grid_size=6x6)
Givens: 19, Empty: 17

Puzzle:
+-------+-------+
| . . 5 | . 7 6 |
| . 7 . | . 2 1 |
+-------+-------+
| . 8 . | 6 . . |
| 7 . 6 | . 8 . |
+-------+-------+
| 8 6 2 | 7 . . |
| . 1 . | 2 6 8 |
+-------+-------+

Solution:
+-------+-------+
| 1 2 5 | 8 7 6 |
| 6 7 8 | 5 2 1 |
+-------+-------+
| 2 8 1 | 6 5 7 |
| 7 5 6 | 1 8 2 |
+-------+-------+
| 8 6 2 | 7 1 5 |
| 5 1 7 | 2 6 8 |
+-------+-------+




### Define TRM model

In [5]:
model = TRMModel(
    hidden_size=512,
    num_layers=1,
    num_heads=8,
    max_grid_size=6,
    H_cycles=3,
    L_cycles=6,
    N_supervision=16,
    N_supervision_val=16,
    ffn_expansion=4,
    learning_rate=1e-4,
    learning_rate_emb=1e-4,
    weight_decay=1.0,
    warmup_steps=2000,
    halt_exploration_prob=0.1,
    puzzle_emb_dim=0,
    puzzle_emb_len=0,
    rope_theta=10000,
    pos_emb_type=None, # IMPORTANT: since use_mlp_t=True
    use_mlp_t=True,
    use_conv_swiglu=False,
    use_board_swiglu=False,
    lr_min_ratio=0.01,
    use_muon=False,
    vocab_size=3+dm.max_grid_size,
    num_puzzles=0,
    batch_size=train_loader.batch_size,
    pad_value=getattr(dm, "pad_value", 0),
    seq_len=seq_len,
    output_dir=None,
).to(device)

### Define TRM optimizer

In [6]:
base_lr = model.learning_rate
wd = model.weight_decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=base_lr,
    weight_decay=wd,
    betas=(0.9, 0.95),
)

### Train

In [7]:
def move_batch_to_device(batch, device):
    return {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) for k, v in batch.items()}

In [8]:
def training_step(model, carry, batch, optimizer, *, total_steps, grad_clip=1.0):
    """
    Mirrors your Lightning training_step, but as a pure function.
    - one forward pass per batch (carry persists across batches)
    - manual backward
    - warmup LR schedule
    - gradient clipping
    """
    model.train()
    batch_size = batch["input"].shape[0]

    if carry is None:
        carry = model.initial_carry(batch)

    carry, loss, metrics, _ = model.compute_loss_and_metrics(carry, batch)

    (loss / batch_size).backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)

    # Warmup LR schedule (same logic)
    current_step = model.manual_step

    last_lr = None
    if current_step < model.warmup_steps:
        lr_this_step = compute_lr(
            base_lr=model.learning_rate,
            lr_warmup_steps=model.warmup_steps,
            lr_min_ratio=model.lr_min_ratio,
            current_step=current_step,
            total_steps=total_steps,
        )
    else:
        lr_this_step = model.learning_rate

    if hasattr(optimizer, "_optimizer"):
        for param_group in optimizer._optimizer.param_groups:
            param_group["lr"] = lr_this_step
            optimizer._optimizer.step()
            optimizer._optimizer.zero_grad()
        else:
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_this_step
            optimizer.step()
            optimizer.zero_grad()

    # safety
    if torch.isnan(metrics["lm_loss"]):
        raise RuntimeError(f"LM loss is NaN at step {model.manual_step}")

    model.manual_step += 1

    # return carry + some printable scalars
    count = float(metrics["count"].item())
    logs = {
        "train/loss": float(loss.item()),
        "train/lm_loss": float((metrics["lm_loss"] / batch_size).item()),
        "train/q_halt_loss": float((metrics["q_halt_loss"] / batch_size).item()),
        "train/lr": float(last_lr) if last_lr is not None else float("nan"),
    }
    if count > 0:
        logs.update(
            {
                "train/accuracy": float((metrics["accuracy"] / metrics["count"]).item()),
                "train/exact_accuracy": float((metrics["exact_accuracy"] / metrics["count"]).item()),
                "train/q_halt_accuracy": float((metrics["q_halt_accuracy"] / metrics["count"]).item()),
                "train/steps": float((metrics["steps"] / metrics["count"]).item()),
            }
        )
    return carry, loss, logs

In [None]:
check_val_every_n_epoch = 10
print_every = 50
max_epochs = 100
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * max_epochs
train_carry = None

for epoch in range(max_epochs):
    for batch_idx, batch in enumerate(train_loader):
        batch = move_batch_to_device(batch, device)
        # If you want to reset carry every batch (instead of persisting), uncomment:
        # train_carry = None

        train_carry, loss, logs = training_step(
            model, train_carry, batch, optimizer, total_steps=total_steps, grad_clip=1.0
        )

        if batch_idx % print_every == 0:
            msg = (
                f"epoch {epoch:03d} | batch {batch_idx:04d} | "
                f"loss {logs['train/loss']:.4f} | "
                f"exact {logs.get('train/exact_accuracy', float('nan')):.4f} | "
                f"steps {logs.get('train/steps', float('nan')):.2f} | "
                f"lr {logs['train/lr']:.2e}"
            )
            print(msg)
    """
    if (epoch + 1) % check_val_every_n_epoch == 0:
        val_logs = run_validation_epoch(model, val_loader, device)
        print(
            f"[VAL @ epoch {epoch:03d}] "
            f"exact={val_logs['val/exact_accuracy']:.4f} "
            f"acc={val_logs['val/accuracy']:.4f} "
            f"steps={val_logs['val/steps']:.2f} "
            f"loss={val_logs['val/loss']:.4f}"
        )
    """

tensor([[ 2,  2, 10,  2,  7,  3,  0,  0],
        [ 2,  7,  2,  2,  4,  6,  0,  0],
        [ 2,  5,  2,  3,  2,  2,  0,  0],
        [ 7,  2,  3,  2,  5,  2,  0,  0],
        [ 5,  3,  4,  7,  2,  2,  0,  0],
        [ 2,  6,  2,  4,  3,  5,  0,  0],
        [ 1,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')


epoch 000 | batch 0000 | loss 82.0539 | exact nan | steps nan | lr nan
tensor([[ 2,  2,  7,  5, 10,  2,  0,  0],
        [ 2, 10,  2,  3,  8,  2,  0,  0],
        [ 2,  2,  2,  9,  2,  8,  0,  0],
        [ 9,  2,  8,  7,  3, 10,  0,  0],
        [ 8,  9,  2,  2,  7,  5,  0,  0],
        [10,  7,  2,  2,  9,  2,  0,  0],
        [ 1,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')
tensor([[ 2,  3, 10,  2,  5,  2,  0,  0],
        [ 2,  2,  5,  6, 10,  2,  0,  0],
        [ 2,  2,  8,  2,  3,  2,  0,  0],
        [ 5,  2,  2,  2,  8,  6,  0,  0],
        [ 2, 10,  9,  3,  6,  5,  0,  0],
        [ 3,  5,  6,  2,  2, 10,  0,  0],
        [ 1,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')
tensor([[ 7,  2,  5,  9,  2,  4,  0,  0],
        [ 2,  2,  4,  2,  7,  2,  0,  0],
        [ 2,  2,  7,  2, 10,  6,  0,  0],
        [10,  2,  2,  5,  2,  2,  0,  0],
        [ 5,  2,  2,  2,  4,  2,  0,  0],
        [ 2

KeyboardInterrupt: 

In [None]:
@torch.no_grad()
def validation_step(model, batch):
    """
    Mirrors your Lightning validation_step:
    - fresh carry
    - loop until all_halted
    - accumulate metrics then average
    """
    model.eval()
    batch_size = batch["input"].shape[0]
    carry = model.initial_carry(batch)

    accumulated = {}
    total_loss = 0.0
    n_steps = 0

    while True:
        carry, loss, metrics, all_halted = model.compute_loss_and_metrics(carry, batch)

        for k, v in metrics.items():
            accumulated[k] = accumulated.get(k, 0.0) + float(v.item())

        total_loss += float(loss.item())
        n_steps += 1

        if bool(all_halted.item()):
            break

    count = accumulated.get("count", float(batch_size))
    if count > 0:
        return {
            "val/loss": total_loss / (n_steps * batch_size),
            "val/accuracy": accumulated.get("accuracy", 0.0) / count,
            "val/exact_accuracy": accumulated.get("exact_accuracy", 0.0) / count,
            "val/q_halt_accuracy": accumulated.get("q_halt_accuracy", 0.0) / count,
            "val/steps": accumulated.get("steps", 0.0) / count,
            "val/lm_loss": accumulated.get("lm_loss", 0.0) / (n_steps * batch_size),
            "val/q_halt_loss": accumulated.get("q_halt_loss", 0.0) / (n_steps * batch_size),
        }
    else:
        return {k: 0.0 for k in ["val/loss","val/accuracy","val/exact_accuracy","val/q_halt_accuracy","val/steps","val/lm_loss","val/q_halt_loss"]}

@torch.no_grad()
def run_validation_epoch(model, val_loader, device):
    totals = {
        "val/loss": 0.0,
        "val/accuracy": 0.0,
        "val/exact_accuracy": 0.0,
        "val/q_halt_accuracy": 0.0,
        "val/steps": 0.0,
        "val/lm_loss": 0.0,
        "val/q_halt_loss": 0.0,
    }
    n = 0
    for batch in val_loader:
        batch = move_batch_to_device(batch, device)
        out = validation_step(model, batch)
        for k in totals:
            totals[k] += out[k]
        n += 1
    for k in totals:
        totals[k] /= max(n, 1)
    return totals