In [1]:
# conda env: scalinglaws
import os
import math
from dataclasses import dataclass
from typing import Dict, Tuple
import matplotlib.pyplot as plt
import numpy as np

from sklearn import metrics
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.nn.modules.utils import compute_lr
from src.nn.utils.constants import IGNORE_LABEL_ID
from src.nn.optimizers.muon import Muon

# pip install adam-atan2-pytorch
try:
    from adam_atan2_pytorch import AdamAtan2
    print(f"*"*60)
    print("Imported AdamATan2 successfully")
except ImportError:
    print("Failed to import adam2")

from lightning import LightningModule

from src.nn.modules.sparse_embeddings import (
    CastedSparseEmbedding,
    CastedSparseEmbeddingSignSGD_Distributed,
)
from src.nn.modules.trm_block import (
    CastedEmbedding,
    CastedLinear,
    ReasoningBlock,
    ReasoningBlockConfig,
    ReasoningModule,
    RotaryEmbedding,
    RotaryEmbedding2D,
)
from src.nn.modules.utils import stablemax_cross_entropy, trunc_normal_init_
from src.nn.utils import RankedLogger
from src.nn.data.sudoku_datamodule import SudokuDataModule

log = RankedLogger(__name__, rank_zero_only=True)

************************************************************
Imported AdamATan2 successfully


In [15]:
def visualize_sudoku_text(batch, idx=0, grid_size=6):
    """Print one puzzle/solution pair from a batch as ASCII Sudoku boards.

    Args:
        batch: Mapping with 'input' and 'output' entries; each is indexed by
            ``idx`` to obtain a flat token tensor for one sample.
        idx: Index of the sample inside the batch.
        grid_size: Side length of the real puzzle, which occupies the top-left
            corner of the (possibly larger) stored board.

    Token encoding used for rendering: 2 is an empty cell ("."), values above
    2 are digits shifted by 2, anything else is shown as "?".
    """
    puzzle_flat = batch['input'][idx]
    solution_flat = batch['output'][idx]

    # The flat sample is treated as a square board; its side is recovered
    # from the number of elements.
    max_grid_size = int(puzzle_flat.numel() ** 0.5)

    # (rows, cols) of a single box. Sizes that are not listed use 2x3.
    box_rows, box_cols = 2, 3
    for known_size, box_shape in ((6, (2, 3)), (9, (3, 3)), (4, (2, 2))):
        if grid_size == known_size:
            box_rows, box_cols = box_shape
            break

    def cell_symbol(token):
        """Map a single token to the character shown in the board."""
        value = token.item()
        if value == 2:
            return "."
        # 0/1 (PAD/EOS) and negative ignore labels all render as "?".
        return str(value - 2) if value > 2 else "?"

    def board_text(flat):
        """Render the top-left grid_size x grid_size corner of a flat board."""
        # Restore the full 2D layout first, then crop to the real puzzle.
        board = flat.reshape(max_grid_size, max_grid_size)[:grid_size, :grid_size]

        border = "+" + "+".join(["-" * (box_cols * 2 + 1)] * (grid_size // box_cols)) + "+"

        rendered = []
        for r in range(grid_size):
            if r % box_rows == 0:
                rendered.append(border)

            pieces = ["|"]
            for c in range(grid_size):
                pieces.append(f" {cell_symbol(board[r, c])}")
                if (c + 1) % box_cols == 0:
                    pieces.append(" |")
            rendered.append("".join(pieces))
        rendered.append(border)
        return "\n".join(rendered)

    # Statistics are computed on the cropped puzzle only, so padding is excluded.
    visible = puzzle_flat.reshape(max_grid_size, max_grid_size)[:grid_size, :grid_size]
    givens = (visible > 2).sum().item()
    empty = (visible == 2).sum().item()

    rule = "=" * 60
    print(rule)
    print(f"Sample {idx} (grid_size={grid_size}x{grid_size})")
    print(rule)
    print(f"Givens: {givens}, Empty: {empty}\n")

    print("Puzzle:")
    print(board_text(puzzle_flat))
    print("\nSolution:")
    print(board_text(solution_flat))
    print("\n")

In [16]:
@dataclass
class TRMInnerCarry:
    """Latent state that TRMModel carries from one reasoning step to the next.

    Both tensors are created by ``TRMModel.initial_carry`` with shape
    (batch_size, seq_len + puzzle_emb_len, hidden_size) and are returned
    detached by ``TRMModel.inner_forward``.
    """
    # Latent state updated once per H cycle; the output heads read from it.
    z_H: torch.Tensor
    # Latent state updated L_cycles times inside every H cycle.
    z_L: torch.Tensor

@dataclass
class TRMCarry:
    """Full recurrent state passed into and returned from ``TRMModel.forward``."""
    # Latent tensors (z_H, z_L) of the recurrent core.
    inner_carry: TRMInnerCarry
    # Per-sample count of reasoning steps taken since the last reset
    # (int32, one entry per batch element in TRMModel.forward).
    steps: torch.Tensor
    # Per-sample boolean flag; a True entry makes the next forward call reset
    # that sample's latent state and step counter.
    halted: torch.Tensor
    # The batch dict the carry was last updated with.
    current_data: Dict[str, torch.Tensor]

class TRMModel(nn.Module):
    """Recurrent TRM reasoning network over flattened puzzle grids.

    The model embeds the input tokens (optionally prefixed with per-puzzle
    embedding tokens), repeatedly refines two latent states ``z_H`` and
    ``z_L`` with one shared ``ReasoningModule``, and reads token logits and a
    halting logit from ``z_H``.

    Args:
        hidden_size: Width of token embeddings and latent states.
        num_layers: Number of ``ReasoningBlock`` layers in the shared module.
        num_heads: Number of heads; the rotary dimension is
            ``hidden_size // num_heads``.
        max_grid_size: Forwarded to the block config as both ``rows`` and ``cols``.
        ffn_expansion: Forwarded to the block config as ``expansion``.
        puzzle_emb_dim: Per-puzzle embedding width; values <= 0 disable the
            puzzle embedding entirely.
        puzzle_emb_len: Number of prefix positions holding the puzzle
            embedding (forced to 0 when the puzzle embedding is disabled).
        pos_emb_type: "1d" or "2d" selects a rotary embedding; any other
            value (including None) creates no positional embedding.
        use_mlp_t: Forwarded to the block config as ``mlp_t``. When False,
            ``pos_emb_type`` must not be None.
        use_conv_swiglu: Forwarded to the block config.
        use_board_swiglu: Forwarded to the block config.
        vocab_size: Size of the token vocabulary (input embedding and LM head).
        num_puzzles: Number of rows in the puzzle embedding table.
        batch_size: Forwarded to the sparse puzzle embedding.
        pad_value: Stored on the module; not used by the visible methods.
        seq_len: Number of grid tokens per sample (without the puzzle prefix).
        H_cycles: Outer refinement cycles per forward step. All but the last
            run without gradient tracking. Must be >= 1.
        L_cycles: ``z_L`` updates per outer cycle. Must be >= 0.

    Raises:
        ValueError: If ``H_cycles < 1`` or ``L_cycles < 0``.
    """

    def __init__(
        self,
        hidden_size: int = 512,
        num_layers: int = 2,
        num_heads: int = 8,
        max_grid_size: int = 30,
        ffn_expansion: int = 2,
        puzzle_emb_dim: int = 512,
        puzzle_emb_len: int = 16,
        pos_emb_type: str = "1d",
        use_mlp_t: bool = False,
        use_conv_swiglu: bool = False,
        use_board_swiglu: bool = False,
        vocab_size: int = 0,
        num_puzzles: int = 0,
        batch_size: int = 0,
        pad_value: int = -1,
        seq_len: int = 0,
        H_cycles: int = 3,
        L_cycles: int = 6,
    ):
        super().__init__()

        if H_cycles < 1:
            raise ValueError(f"H_cycles must be >= 1, got {H_cycles}")
        if L_cycles < 0:
            raise ValueError(f"L_cycles must be >= 0, got {L_cycles}")

        self.vocab_size = vocab_size
        self.pad_value = pad_value
        self.forward_dtype = torch.bfloat16
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.H_cycles = H_cycles
        self.L_cycles = L_cycles

        # Token embeddings are initialised with std 1/sqrt(hidden) and scaled
        # back up by sqrt(hidden) in _input_embeddings.
        self.embed_scale = math.sqrt(self.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        # Input embedding
        self.input_embedding = CastedEmbedding(
            vocab_size, hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype
        )

        # Puzzle embedding. Built exactly once, and before the positional
        # embedding because the latter depends on self.puzzle_emb_len.
        if puzzle_emb_dim > 0:
            self.puzzle_emb = CastedSparseEmbedding(
                num_embeddings=num_puzzles,
                embedding_dim=puzzle_emb_dim,
                batch_size=batch_size,
                init_std=0.0,  # Reference uses 0 init
                cast_to=self.forward_dtype,
            )
            self.puzzle_emb_len = puzzle_emb_len
            log.info(f"Created puzzle_emb with num_puzzles={num_puzzles}, batch_size={batch_size}")
            log.info(f"puzzle_emb.local_weights.shape: {self.puzzle_emb.local_weights.shape}")
            log.info(f"puzzle_emb.weights.shape: {self.puzzle_emb.weights.shape}")
        else:
            log.info("puzzle_emb_dim <= 0, not creating puzzle embeddings")
            self.puzzle_emb = None
            self.puzzle_emb_len = 0

        # Positional embeddings. No attribute is created for other values of
        # pos_emb_type; inner_forward checks for its presence with hasattr.
        if pos_emb_type == "2d":
            self.pos_embedding = RotaryEmbedding2D(
                dim=self.hidden_size // num_heads,
                prefix_len=self.puzzle_emb_len,
                max_grid_size=int(math.sqrt(self.seq_len)),
                base=10000,
            )
        elif pos_emb_type == "1d":
            self.pos_embedding = RotaryEmbedding(
                dim=self.hidden_size // num_heads,
                max_position_embeddings=self.seq_len + self.puzzle_emb_len,
                base=10000,
            )

        if not use_mlp_t:
            assert pos_emb_type is not None, "Rotary embeddings required if using attention"

        # Reasoning Block
        reasoning_config = ReasoningBlockConfig(
            hidden_size=self.hidden_size,
            num_heads=num_heads,
            expansion=ffn_expansion,
            rms_norm_eps=1e-5,
            seq_len=self.seq_len,
            mlp_t=use_mlp_t,
            puzzle_emb_ndim=puzzle_emb_dim,
            puzzle_emb_len=self.puzzle_emb_len,
            use_conv_swiglu=use_conv_swiglu,
            use_board_swiglu=use_board_swiglu,
            rows=max_grid_size,
            cols=max_grid_size,
        )

        # One shared module updates both z_L and z_H.
        self.lenet = ReasoningModule(
            layers=[ReasoningBlock(reasoning_config) for _ in range(num_layers)]
        )

        self.lm_head = CastedLinear(self.hidden_size, vocab_size, bias=False)
        self.q_head = CastedLinear(self.hidden_size, 1, bias=True)

        # Start the halting head at a strongly negative logit (zero weight,
        # bias -5) so its initial output is the same for every input.
        with torch.no_grad():
            self.q_head.weight.zero_()
            if self.q_head.bias is not None:
                self.q_head.bias.fill_(-5.0)

        self.carry = None

        # Initial latent vectors, broadcast over batch and sequence when a
        # sample is reset in forward().
        self.z_H_init = nn.Buffer(
            trunc_normal_init_(torch.empty(self.hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )
        self.z_L_init = nn.Buffer(
            trunc_normal_init_(torch.empty(self.hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Embed the input tokens and prepend the puzzle-embedding prefix, if any.

        Args:
            input: Token ids; cast to int32 before the embedding lookup.
            puzzle_identifiers: Ids for the puzzle embedding lookup. Ignored
                (may be None) when the puzzle embedding is disabled.

        Returns:
            The embeddings multiplied by ``embed_scale``. With a puzzle
            embedding, ``puzzle_emb_len`` prefix positions are concatenated
            in front of the token positions.
        """
        # Token embedding
        embedding = self.input_embedding(input.to(torch.int32))

        if self.puzzle_emb is not None:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
            # Zero-pad the flat puzzle vector so it fills exactly
            # puzzle_emb_len positions of width hidden_size.
            pad_count = self.puzzle_emb_len * self.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat(
                (puzzle_embedding.view(-1, self.puzzle_emb_len, self.hidden_size), embedding),
                dim=-2,
            )

        return self.embed_scale * embedding

    def _run_cycle(self, z_H, z_L, input_embeddings, seq_info):
        """Run one outer cycle: L_cycles updates of z_L, then one update of z_H."""
        for _ in range(self.L_cycles):
            z_L = self.lenet(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.lenet(z_H, z_L, **seq_info)
        return z_H, z_L

    def inner_forward(self, carry: TRMInnerCarry, batch: Dict[str, torch.Tensor]):
        """The core recurrent block: processes H_cycles and L_cycles.

        Args:
            carry: Current latent state.
            batch: Must contain "input"; "puzzle_identifiers" is only read
                when the puzzle embedding is enabled.

        Returns:
            Tuple ``(new_carry, output, q_logits)``: the detached latent
            state, the LM-head logits with the puzzle prefix sliced off, and
            one float32 halting logit per sample.
        """
        # Rotary embeddings, if the model was built with them.
        seq_info = dict(
            cos_sin=self.pos_embedding() if hasattr(self, "pos_embedding") else None,
        )

        # Only require puzzle ids when they are actually consumed.
        puzzle_identifiers = (
            batch["puzzle_identifiers"] if self.puzzle_emb is not None else None
        )
        input_embeddings = self._input_embeddings(batch["input"], puzzle_identifiers)

        z_H, z_L = carry.z_H, carry.z_L

        # The first H_cycles - 1 cycles run without gradients to save memory.
        with torch.no_grad():
            for _ in range(self.H_cycles - 1):
                z_H, z_L = self._run_cycle(z_H, z_L, input_embeddings, seq_info)

        # Only the final cycle tracks gradients.
        z_H, z_L = self._run_cycle(z_H, z_L, input_embeddings, seq_info)

        # The carried state is detached so the graph does not span steps.
        new_carry = TRMInnerCarry(z_H=z_H.detach(), z_L=z_L.detach())

        # Slice off the puzzle-embedding prefix to keep only grid predictions.
        output = self.lm_head(z_H)[:, self.puzzle_emb_len :]
        # The halting logit is read from the first sequence position.
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, q_logits[..., 0]

    def initial_carry(self, batch_size, device):
        """Create an all-zero latent state of shape
        (batch_size, seq_len + puzzle_emb_len, hidden_size) in ``forward_dtype``."""
        state_shape = (batch_size, self.seq_len + self.puzzle_emb_len, self.hidden_size)
        return TRMInnerCarry(
            z_H=torch.zeros(state_shape, device=device, dtype=self.forward_dtype),
            z_L=torch.zeros(state_shape, device=device, dtype=self.forward_dtype),
        )

    def forward(self, carry: TRMCarry, batch: Dict[str, torch.Tensor], n_supervision: int):
        """Run one reasoning step.

        1. Samples flagged ``halted`` in ``carry`` are reset to the initial
           latent vectors and a step count of 0.
        2. ``inner_forward`` is run on the (possibly reset) state.
        3. A sample is marked halted once its step count reaches
           ``n_supervision``. The Q-head logit is returned but is not used
           for the halting decision here.

        Args:
            carry: Previous state, or None to start a fresh batch.
            batch: Batch dict; "input" determines batch size and device.
            n_supervision: Step count at which a sample is marked halted.

        Returns:
            Tuple ``(new_carry, logits, q_halt_logits)``.
        """
        batch_size = batch["input"].shape[0]
        device = batch["input"].device

        # A missing carry means the start of a new batch.
        if carry is None:
            carry = TRMCarry(
                inner_carry=self.initial_carry(batch_size, device),
                steps=torch.zeros((batch_size,), dtype=torch.int32, device=device),
                # Start as halted so every sample is reset immediately below.
                halted=torch.ones((batch_size,), dtype=torch.bool, device=device),
                current_data=batch,
            )

        # Samples that halted on the previous step start over from the
        # z_H_init / z_L_init buffers (broadcast over batch and sequence).
        reset_mask = carry.halted.view(-1, 1, 1)
        new_inner_carry = TRMInnerCarry(
            z_H=torch.where(reset_mask, self.z_H_init, carry.inner_carry.z_H),
            z_L=torch.where(reset_mask, self.z_L_init, carry.inner_carry.z_L),
        )
        new_steps = torch.where(carry.halted, 0, carry.steps)

        # Actual forward pass
        new_inner_carry, logits, q_halt_logits = self.inner_forward(new_inner_carry, batch)

        new_steps = new_steps + 1

        # Halting is purely step-based: stop after n_supervision steps.
        halted = new_steps >= n_supervision

        new_carry = TRMCarry(
            inner_carry=new_inner_carry,
            steps=new_steps,
            halted=halted,
            current_data=batch,
        )

        return new_carry, logits, q_halt_logits

## Prepare Sudoku dataset

In [17]:
# 1. Instantiate the DataModule in generation mode (data_dir=None), i.e.
#    without reading puzzles from a directory.
grid_size = 6
dm = SudokuDataModule(
    data_dir=None,       
    batch_size=32,       
    num_train_puzzles=1000,
    num_val_puzzles=100,
    num_test_puzzles=100,
    grid_size=grid_size,
    num_workers=0  
)

# 2. Set up the data (generates the puzzle pool)
dm.setup()

# 3. Build the train and validation loaders
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

# 4. Extract metadata needed for the model dimensions.
# In generation mode, these are calculated based on grid_size.
vocab_size = dm.vocab_size 
seq_len = dm.seq_len
# Number of tokens reserved for the puzzle embedding.
# NOTE(review): the visible model construction reads CONFIG["puzzle_emb_len"]
# instead of this variable — confirm whether it is still needed.
puzzle_emb_len = 16

print(f"Data Ready: {grid_size}x{grid_size} Sudoku")
print(f"Vocab Size: {vocab_size}, Sequence Length: {seq_len}")

Data Ready: 6x6 Sudoku
Vocab Size: 9, Sequence Length: 64


In [18]:
# Pull one training batch and render its first sample as a sanity check.
batch = next(iter(train_loader))
# Use the grid_size the data module was built with instead of repeating the
# literal, so the preview stays correct if the puzzle size is changed.
visualize_sudoku_text(batch, idx=0, grid_size=grid_size)

Sample 0 (grid_size=6x6)
Givens: 18, Empty: 18

Puzzle:
+-------+-------+
| 1 . . | 5 . 7 |
| . . . | 4 3 1 |
+-------+-------+
| . . . | 7 5 3 |
| . 7 5 | 1 . 8 |
+-------+-------+
| 5 . 1 | . 7 4 |
| . . . | 8 . . |
+-------+-------+

Solution:
+-------+-------+
| 1 4 3 | 5 8 7 |
| 7 5 8 | 4 3 1 |
+-------+-------+
| 8 1 4 | 7 5 3 |
| 3 7 5 | 1 4 8 |
+-------+-------+
| 5 8 1 | 3 7 4 |
| 4 3 7 | 8 1 5 |
+-------+-------+




## Training

In [19]:
def visualize_batch(batch, logits, grid_size=4):
    """Plot the first sample of a batch as three grids: input, prediction, target.

    Args:
        batch: Mapping with 'input' and 'output' token tensors; only sample 0
            is shown.
        logits: Per-token logits; the prediction is the argmax over the last
            dimension of sample 0.
        grid_size: Side length of the real puzzle. Samples stored on a larger
            square board are cropped to their top-left corner.

    Tokens >= 3 are shown as digits (token - 2); everything else (pad, EOS,
    empty cell, negative ignore labels) is shown as 0.
    """
    inputs = batch['input'][0].cpu().numpy()
    targets = batch['output'][0].cpu().numpy()
    preds = logits[0].argmax(dim=-1).cpu().numpy()

    def decode(flat_arr):
        """Turn a flat token array into a grid_size x grid_size digit grid."""
        side = math.isqrt(flat_arr.size)
        if side * side == flat_arr.size and side >= grid_size:
            # Samples are laid out row-major on a padded side x side board, so
            # restore the 2D layout first and then crop. Taking the first
            # grid_size**2 flat tokens would mix row padding into the grid.
            arr = flat_arr.reshape(side, side)[:grid_size, :grid_size]
        else:
            # Non-square length: fall back to reading the leading tokens.
            arr = flat_arr[:grid_size * grid_size].reshape(grid_size, grid_size)
        # Shift back: tokens below 3 -> 0 (empty), 3 -> 1, 4 -> 2, ...
        res = np.zeros_like(arr)
        mask = arr >= 3
        res[mask] = arr[mask] - 2
        return res

    fig, axes = plt.subplots(1, 3, figsize=(10, 4))
    titles = ["Input (0=Empty)", "Prediction", "Target"]
    for ax, data, title in zip(axes, [inputs, preds, targets], titles):
        grid = decode(data)
        ax.matshow(grid, cmap='Blues')
        for (i, j), z in np.ndenumerate(grid):
            # Empty cells are drawn in light gray so the digits stand out.
            ax.text(j, i, f'{z}', ha='center', va='center',
                    color='black' if z != 0 else 'lightgray')
        ax.set_title(title)
        ax.set_xticks([]); ax.set_yticks([])
    plt.show()

In [22]:
# Hyperparameters for the model and optimizer built in this cell.
# NOTE(review): "max_epochs" and "N_supervision" are not read by the visible
# training cell, which defines its own EPOCHS / N_SUPERVISION — confirm which
# values are intended.
CONFIG = {
    "grid_size": 6,             
    "batch_size": 32,
    "max_epochs": 1000,
    "hidden_size": 512,
    "num_layers": 2,
    "num_heads": 8,
    "ffn_expansion": 4,
    "puzzle_emb_dim": 0,        # 0 disables the puzzle embedding in TRMModel
    "puzzle_emb_len": 0,        # no puzzle-prefix positions
    "pos_emb_type": None,       # neither "1d" nor "2d": no rotary embedding is created
    "use_mlp_t": True,          # forwarded as mlp_t; skips the rotary-embedding assertion
    "use_conv_swiglu": False,
    "use_board_swiglu": False,
    "learning_rate": 1e-4,
    "weight_decay": 1.0,
    "N_supervision": 16
}

# Training is pinned to the CPU; the commented-out expression would pick CUDA
# when available.
device = "cpu" #torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Initializing TRM Model...")

# Model dimensions come from the data module; architecture options from CONFIG.
model = TRMModel(
    vocab_size=3 + dm.max_grid_size,  # 0=PAD, 1=EOS, 2=Empty, 3+=Values
    seq_len=dm.seq_len,
    max_grid_size=dm.max_grid_size,
    pad_value=0,
    hidden_size=CONFIG["hidden_size"],
    num_layers=CONFIG["num_layers"],
    num_heads=CONFIG["num_heads"],
    ffn_expansion=CONFIG["ffn_expansion"],
    puzzle_emb_dim=CONFIG["puzzle_emb_dim"],
    puzzle_emb_len=CONFIG["puzzle_emb_len"],
    pos_emb_type=CONFIG["pos_emb_type"],  
    use_mlp_t=CONFIG["use_mlp_t"],    
    use_conv_swiglu=CONFIG["use_conv_swiglu"],
    use_board_swiglu=CONFIG["use_board_swiglu"],
    # Dataset specific (required for safeguards even if puzzle_emb_dim=0)
    num_puzzles=0,    
    batch_size=CONFIG["batch_size"]
).to(device)

# A single AdamW optimizer over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=CONFIG["learning_rate"], 
    weight_decay=CONFIG["weight_decay"]
)

# Summary: device, parameter count in millions, and key architecture switches.
print("-" * 40)
print(f"Model successfully initialized on {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
print(f"Configuration: MLP_T={CONFIG['use_mlp_t']}, PosEmb={CONFIG['pos_emb_type']}")
print("-" * 40)


Initializing TRM Model...
----------------------------------------
Model successfully initialized on cpu
Parameters: 4.83M
Configuration: MLP_T=True, PosEmb=None
----------------------------------------


In [None]:
print(f"Starting training on {device}...")
# NOTE(review): CONFIG also defines "max_epochs" and "N_supervision" with
# different values; the constants below are the ones this loop actually uses.
EPOCHS = 3
N_SUPERVISION = 4
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    num_updates = 0

    for batch_idx, batch in enumerate(train_loader):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Show one example board and its raw tokens once, as a sanity check,
        # instead of flooding the output on every batch.
        if epoch == 0 and batch_idx == 0:
            visualize_sudoku_text(batch, idx=0, grid_size=CONFIG["grid_size"])
            print(batch['input'][0])
            print(batch['output'][0])

        # 1. Fresh recurrent state for this new batch
        carry = None
        labels = batch['output']

        # 2. Reasoning loop (the "recurrent" part of TRM).
        # The SAME batch is processed N_SUPERVISION times so the model can
        # refine its hidden state; every step is its own optimizer update.
        for step in range(N_SUPERVISION):
            optimizer.zero_grad()
            carry, logits, q_logits = model(carry, batch, n_supervision=N_SUPERVISION)

            # ---------------------------------------------------------
            # Loss 1: Language Modeling (Sudoku Solution)
            # ---------------------------------------------------------
            # Flatten for CrossEntropy: [batch * seq_len, vocab_size].
            # The loss is computed in float32 because the model runs in
            # bfloat16; this is a no-op if the logits are already float32.
            flat_logits = logits.reshape(-1, logits.shape[-1]).float()
            flat_labels = labels.reshape(-1)

            lm_loss = F.cross_entropy(
                flat_logits,
                flat_labels,
                ignore_index=-100
            )

            # ---------------------------------------------------------
            # Loss 2: Halting (Q-Head)
            # ---------------------------------------------------------
            # The target for the Q-head is 1.0 if the model solved the whole
            # puzzle at this step, 0.0 otherwise.
            with torch.no_grad():
                preds = logits.argmax(dim=-1)

                # Ignore padding / ignored positions (-100)
                mask = labels != -100
                correct_cells = (preds == labels) & mask

                # A sequence is correct ONLY if all non-ignored tokens match.
                required_correct = mask.sum(dim=-1)
                actual_correct = correct_cells.sum(dim=-1)
                seq_is_correct = (actual_correct == required_correct).float()

            # q_logits has shape [batch_size], matching seq_is_correct.
            q_halt_loss = F.binary_cross_entropy_with_logits(
                q_logits,
                seq_is_correct.to(q_logits.device)
            )

            # ---------------------------------------------------------
            # Total Loss: 1.0 * LM + 0.5 * Q_Halt
            # ---------------------------------------------------------
            loss = lm_loss + 0.5 * q_halt_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            num_updates += 1

            # Detach carry to prevent the graph from growing across steps.
            carry.inner_carry.z_H = carry.inner_carry.z_H.detach()
            carry.inner_carry.z_L = carry.inner_carry.z_L.detach()

    # Report the mean loss over all optimizer updates of this epoch.
    print(f"Epoch {epoch + 1}/{EPOCHS} - mean loss: {total_loss / max(num_updates, 1):.4f}")

Starting training on cpu...
Sample 0 (grid_size=6x6)
Givens: 16, Empty: 20

Puzzle:
+-------+-------+
| 7 4 . | 8 . 1 |
| . . . | . 7 . |
+-------+-------+
| 1 . . | 2 . . |
| 4 8 2 | 7 . 6 |
+-------+-------+
| . . 4 | . 8 7 |
| . . . | 4 . . |
+-------+-------+

Solution:
+-------+-------+
| 7 4 6 | 8 2 1 |
| 2 1 8 | 6 7 4 |
+-------+-------+
| 1 6 7 | 2 4 8 |
| 4 8 2 | 7 1 6 |
+-------+-------+
| 6 2 4 | 1 8 7 |
| 8 7 1 | 4 6 2 |
+-------+-------+


tensor([ 9,  6,  2, 10,  2,  3,  0,  0,  2,  2,  2,  2,  9,  2,  0,  0,  3,  2,
         2,  4,  2,  2,  0,  0,  6, 10,  4,  9,  2,  8,  0,  0,  2,  2,  6,  2,
        10,  9,  0,  0,  2,  2,  2,  6,  2,  2,  0,  0,  1,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
tensor([   9,    6,    8,   10,    4,    3, -100, -100,    4,    3,   10,    8,
           9,    6, -100, -100,    3,    8,    9,    4,    6,   10, -100, -100,
           6,   10,    4,    9,    3,    8, -100, -100,    8,    4,    6,    3,
          10, 

KeyboardInterrupt: 