In [1]:
# Cell 1: Imports
import torch
from pathlib import Path
import numpy as np
from src.nn.sudoku_evaluator import SudokuEvaluator

# Set up paths
CHECKPOINT_PATH = "../train/runs/2026-01-06_16-21-25/checkpoints/last.ckpt"
DATA_DIR = "../data/sudoku_6x6_large"

Failed to import adam2


In [2]:
# Cell 2: Initialize evaluator
evaluator = SudokuEvaluator(
    checkpoint_path=CHECKPOINT_PATH,
    data_dir=DATA_DIR,
    batch_size=256,
    device="auto",
    num_workers=0,
    eval_split="val"
)

print(f"Model loaded successfully!")
print(f"Grid size: {evaluator.grid_size}x{evaluator.grid_size}")
print(f"Vocab size: {evaluator.vocab_size}")

Using device: cuda
Loading model from ../train/runs/2026-01-06_16-21-25/checkpoints/last.ckpt
Model loaded: TRMModule
Model configuration:
  hidden_size: 512
  num_layers: 2
  H_cycles: 3
  L_cycles: 6
  N_supervision: 16
  vocab_size: 9
  seq_len: 64
Grid size: 6x6
Max grid size: 8x8
Vocab size: 9
Model loaded successfully!
Grid size: 6x6
Vocab size: 9


## Analyze thinking

In [3]:
evaluator.model.lenet.layers[0]

ReasoningBlock(
  (dropout): Dropout(p=0.25, inplace=False)
  (mlp_t): SwiGLU(
    (gate_up_proj): CastedLinear()
    (down_proj): CastedLinear()
  )
  (mlp): SwiGLU(
    (gate_up_proj): CastedLinear()
    (down_proj): CastedLinear()
  )
)

In [4]:
evaluator.datamodule.val_dataset[45]

  input_tensor = torch.from_numpy(input_flat).long()


{'input': tensor([2, 6, 3, 2, 5, 2, 0, 0, 2, 2, 2, 6, 3, 7, 0, 0, 2, 2, 2, 2, 8, 2, 0, 0,
         2, 2, 2, 2, 6, 2, 0, 0, 6, 5, 7, 2, 2, 8, 0, 0, 2, 2, 8, 2, 7, 2, 0, 0,
         1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'output': tensor([   7,    6,    3,    8,    5,    4, -100, -100,    4,    8,    5,    6,
            3,    7, -100, -100,    5,    7,    6,    4,    8,    3, -100, -100,
            8,    3,    4,    7,    6,    5, -100, -100,    6,    5,    7,    3,
            4,    8, -100, -100,    3,    4,    8,    5,    7,    6, -100, -100,
            1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100]),
 'puzzle_identifiers': np.int32(0)}

In [5]:
# Cell 4: (Optional) Visualize model thinking on a sample
results_viz = evaluator.visualize_sample(
    split="val",
    sample_idx=45,
    show_confidence=True,
    save_gif=True  # Set to True if you want to save GIF
)

[inner forward] input_embeddings shape: torch.Size([256, 64, 512])


[inner forward] z_H shape: torch.Size([256, 64, 512])
[inner forward] q_logits shape: torch.Size([256, 1])
[inner forward] output shape: torch.Size([256, 64, 9])
[visualize_thinking] q_halt[45] = -1.5546875
[inner forward] input_embeddings shape: torch.Size([256, 64, 512])
[inner forward] z_H shape: torch.Size([256, 64, 512])
[inner forward] q_logits shape: torch.Size([256, 1])
[inner forward] output shape: torch.Size([256, 64, 9])
[visualize_thinking] q_halt[45] = -0.36328125
[inner forward] input_embeddings shape: torch.Size([256, 64, 512])
[inner forward] z_H shape: torch.Size([256, 64, 512])
[inner forward] q_logits shape: torch.Size([256, 1])
[inner forward] output shape: torch.Size([256, 64, 9])
[visualize_thinking] q_halt[45] = -0.2236328125
[inner forward] input_embeddings shape: torch.Size([256, 64, 512])
[inner forward] z_H shape: torch.Size([256, 64, 512])
[inner forward] q_logits shape: torch.Size([256, 1])
[inner forward] output shape: torch.Size([256, 64, 9])
[visualize_t

## Evaluation (full)

In [12]:
# Cell 3: Run evaluation on validation split
results = evaluator.evaluate(split="val", print_examples=True)

print("\n" + "=" * 60)
print("EVALUATION RESULTS")
print("=" * 60)
print(f"Cell Accuracy:   {results['cell_accuracy']:.4f} ({results['cell_accuracy'] * 100:.2f}%)")
print(f"Puzzle Accuracy: {results['puzzle_accuracy']:.4f} ({results['puzzle_accuracy'] * 100:.2f}%)")
print(f"Validity Rate:   {results['validity_rate']:.4f} ({results['validity_rate'] * 100:.2f}%)")
print(f"Puzzles Solved:  {results['puzzles_correct']}/{results['total_puzzles']}")
print(f"Valid Solutions: {results['valid_puzzles']}/{results['total_puzzles']}")
print(f"Avg Steps:       {results['avg_steps']:.1f}")
print("=" * 60)


Evaluating on val split...


  input_tensor = torch.from_numpy(input_flat).long()



EXAMPLE PREDICTIONS

--- Example 1 ---
Status: ✓ CORRECT | Valid Sudoku: ✓

INPUT           TARGET          PREDICTION
--------------  --------------  --------------
1 2 4 |_ _ _    1 2 4 |5 3 6    1 2 4 |5 3 6
5 6 _ |1 _ _    5 6 3 |1 4 2    5 6 3 |1 4 2
-----+-----     -----+-----     -----+-----
_ 5 2 |6 1 _    4 5 2 |6 1 3    4 5 2 |6 1 3
_ 3 _ |2 5 4    6 3 1 |2 5 4    6 3 1 |2 5 4
-----+-----     -----+-----     -----+-----
_ _ 5 |3 6 1    2 4 5 |3 6 1    2 4 5 |3 6 1
_ 1 _ |_ 2 _    3 1 6 |4 2 5    3 1 6 |4 2 5

--- Example 2 ---
Status: ✓ CORRECT | Valid Sudoku: ✓

INPUT           TARGET          PREDICTION
--------------  --------------  --------------
_ 4 _ |_ 6 5    1 4 2 |3 6 5    1 4 2 |3 6 5
3 5 _ |_ 1 _    3 5 6 |2 1 4    3 5 6 |2 1 4
-----+-----     -----+-----     -----+-----
6 _ 3 |_ 4 _    6 1 3 |5 4 2    6 1 3 |5 4 2
_ 2 _ |_ _ 1    5 2 4 |6 3 1    5 2 4 |6 3 1
-----+-----     -----+-----     -----+-----
_ _ _ |4 _ 6    2 3 1 |4 5 6    2 3 1 |4 5 6
_ _ 5 |1 _ _    

Evaluating val: 100%|██████████| 7/7 [00:13<00:00,  1.91s/it]


EVALUATION RESULTS
Cell Accuracy:   0.9094 (90.94%)
Puzzle Accuracy: 0.7070 (70.70%)
Validity Rate:   0.8711 (87.11%)
Puzzles Solved:  1267/1792
Valid Solutions: 1561/1792
Avg Steps:       16.0





In [7]:
results_viz = evaluator.visualize_sample(
    split="val",
    sample_idx=45,
    show_confidence=True,
    save_gif=False,
    num_stochastic_runs=3,      
    dropout_enabled=True,       
)


Per-step MC sampling: 3 sample(s)/step (dropout=ON)
Initial carry z_H: torch.Size([256, 64, 512])
Initial carry z_L: torch.Size([256, 64, 512])
Initial carry steps: torch.Size([256])
Initial carry halted: torch.Size([256])
Initial carry current_data['input']: torch.Size([256, 64])
Initial carry current_data['output']: torch.Size([256, 64])
Initial carry current_data['puzzle_identifiers']: torch.Size([256])

Total steps executed: 16

TRM THINKING VISUALIZATION (MC runs side-by-side)
H_cycles=3, L_cycles=6
Runs=3, Steps=16

INPUT            TARGET
---------------  ---------------
 ? 4 1| ? 3 ?     5 4 1| 6 3 2
 ? ? ?| 4 1 5     2 6 3| 4 1 5
------+------    ------+------
 ? ? ?| ? 6 ?     3 5 4| 2 6 1
 ? ? ?| ? 4 ?     6 1 2| 5 4 3
------+------    ------+------
 4 3 5| ? ? 6     4 3 5| 1 2 6
 ? ? 6| ? 5 ?     1 2 6| 3 5 4

Empty cells to fill: 22

------------------------------------------------------------------------------------------------------------------------
STEP-BY-STEP REASON

In [5]:
solution = np.array([
    [1,2,3,4,5,6],
    [4,5,6,1,2,3],
    [2,3,1,5,6,4],
    [3,4,5,6,1,2],
    [5,6,4,2,3,1],
    [6,1,2,3,4,5],
], dtype=np.int32)

puzzle = np.array([
    [0,2,0,4,0,6],
    [4,0,6,0,2,0],
    [2,0,0,5,6,0],
    [0,4,5,0,0,2],
    [5,0,4,0,3,0],
    [0,0,2,3,0,5],
], dtype=np.int32)

# Encode & pad (0=PAD,1=EOS,2=empty,3..=values+2)
inp_flat, labels_flat = evaluator.datamodule.val_dataset.pad_and_encode(puzzle, solution)

# Convert to tensors and add batch-dimension
inp_t = torch.from_numpy(inp_flat).long().unsqueeze(0).to(evaluator.device)
labels_t = torch.from_numpy(labels_flat).long().unsqueeze(0).to(evaluator.device)
puzzle_ids = torch.zeros(256, dtype=torch.long).to(evaluator.device)
batch = {"input": inp_t, "output": labels_t, "puzzle_identifiers": puzzle_ids}
print(batch)
# Visualize
viz = evaluator.visualize_thinking(
    batch,
    sample_idx=0,
    show_confidence=True,
    save_gif=False,
)

{'input': tensor([[2, 4, 2, 6, 2, 8, 0, 0, 6, 2, 8, 2, 4, 2, 0, 0, 4, 2, 2, 7, 8, 2, 0, 0,
         2, 6, 7, 2, 2, 4, 0, 0, 7, 2, 6, 2, 5, 2, 0, 0, 2, 2, 4, 5, 2, 7, 0, 0,
         1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'), 'output': tensor([[   3,    4,    5,    6,    7,    8, -100, -100,    6,    7,    8,    3,
            4,    5, -100, -100,    4,    5,    3,    7,    8,    6, -100, -100,
            5,    6,    7,    8,    3,    4, -100, -100,    7,    8,    6,    4,
            5,    3, -100, -100,    8,    3,    4,    5,    6,    7, -100, -100,
            1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100]], device='cuda:0'), 'puzzle_identifiers': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0

In [None]:
results_viz = evaluator.visualize_sample(
    split="val",
    sample_idx=5,
    show_confidence=True,
    save_gif=False,
    num_stochastic_runs=3,      
    dropout_enabled=True,       
)


Running 3 stochastic forward pass(es) (dropout=ON)

TRM THINKING VISUALIZATION (STOCHASTIC - 3 forward passes with dropout)
H_cycles=3, L_cycles=6

----------------------------------------------------------------------------------------------------
STEP-BY-STEP REASONING
----------------------------------------------------------------------------------------------------

┌─ Step 1/1 ─────────────────────────────────────────────────────────────────┐
│ Consensus Accuracy: 100.0% | Disagreement: 0.0% | q=+3.03 (HALT) ← STOPPED
│ Model Confidence: avg=1.00
└─ Run predictions (rows show each forward pass):
   Run  1: 1 2 4 5 3 6  | 5 6 3 1 4 2  | 4 5 2 6 1 3  | 6 3 1 2 5 4  | 2 4 5 3 6 1  | 3 1 6 4 2 5 
   Run  2: 1 2 4 5 3 6  | 5 6 3 1 4 2  | 4 5 2 6 1 3  | 6 3 1 2 5 4  | 2 4 5 3 6 1  | 3 1 6 4 2 5 
   Run  3: 1 2 4 5 3 6  | 5 6 3 1 4 2  | 4 5 2 6 1 3  | 6 3 1 2 5 4  | 2 4 5 3 6 1  | 3 1 6 4 2 5 
   CONS:  1 2 4 5 3 6  | 5 6 3 1 4 2  | 4 5 2 6 1 3  | 6 3 1 2 5 4  | 2 4 5 3 6 1  | 3 1 6 4 