In [1]:
# Cell 1: Imports
import torch
from pathlib import Path
from src.nn.sudoku_evaluator_uncertainty import SudokuEvaluator

# Input locations, given as relative paths (resolved against the kernel's
# working directory). Both are handed to SudokuEvaluator in the next cell.
DATA_DIR = "../data/sudoku_6x6_large"
CHECKPOINT_PATH = "../train/runs/2026-01-06_16-21-25/checkpoints/last.ckpt"

Failed to import adam2


In [2]:
# Cell 2: Initialize evaluator
# Build the evaluator from the checkpoint and dataset paths defined in the
# first cell. Constructing it loads the model (see the "Loading model from ..."
# line in the recorded output).
evaluator = SudokuEvaluator(
    checkpoint_path=CHECKPOINT_PATH,
    data_dir=DATA_DIR,
    batch_size=256,
    device="auto",  # the recorded run reported "Using device: cuda"
    num_workers=0,
    eval_split="val",
)

# Quick confirmation of what was loaded, read back from the evaluator itself.
# The first message has no placeholders, so it is a plain string, not an f-string.
print("Model loaded successfully!")
print(f"Grid size: {evaluator.grid_size}x{evaluator.grid_size}")
print(f"Vocab size: {evaluator.vocab_size}")

Using device: cuda
Loading model from ../train/runs/2026-01-06_16-21-25/checkpoints/last.ckpt


Model loaded: TRMModule
Model configuration:
  hidden_size: 512
  num_layers: 2
  H_cycles: 3
  L_cycles: 6
  N_supervision: 16
  vocab_size: 9
  seq_len: 64
Grid size: 6x6
Max grid size: 8x8
Vocab size: 9
Model loaded successfully!
Grid size: 6x6
Vocab size: 9


In [12]:
# Cell 3: Run evaluation on validation split
# Score the model on the "val" split; print_examples=True also makes the
# evaluator print a few example puzzles before the summary below.
results = evaluator.evaluate(split="val", print_examples=True)


def _ratio_with_percent(value):
    """Format a fraction as 'x.xxxx (yy.yy%)' for the summary table."""
    return f"{value:.4f} ({value * 100:.2f}%)"


# Horizontal rule framing the summary.
divider = "=" * 60

print("\n" + divider)
print("EVALUATION RESULTS")
print(divider)
print(f"Cell Accuracy:   {_ratio_with_percent(results['cell_accuracy'])}")
print(f"Puzzle Accuracy: {_ratio_with_percent(results['puzzle_accuracy'])}")
print(f"Validity Rate:   {_ratio_with_percent(results['validity_rate'])}")
# Raw counts behind the puzzle-level rates above.
print(f"Puzzles Solved:  {results['puzzles_correct']}/{results['total_puzzles']}")
print(f"Valid Solutions: {results['valid_puzzles']}/{results['total_puzzles']}")
print(f"Avg Steps:       {results['avg_steps']:.1f}")
print(divider)


Evaluating on val split...


  input_tensor = torch.from_numpy(input_flat).long()



EXAMPLE PREDICTIONS

--- Example 1 ---
Status: ✓ CORRECT | Valid Sudoku: ✓

INPUT           TARGET          PREDICTION
--------------  --------------  --------------
1 2 4 |_ _ _    1 2 4 |5 3 6    1 2 4 |5 3 6
5 6 _ |1 _ _    5 6 3 |1 4 2    5 6 3 |1 4 2
-----+-----     -----+-----     -----+-----
_ 5 2 |6 1 _    4 5 2 |6 1 3    4 5 2 |6 1 3
_ 3 _ |2 5 4    6 3 1 |2 5 4    6 3 1 |2 5 4
-----+-----     -----+-----     -----+-----
_ _ 5 |3 6 1    2 4 5 |3 6 1    2 4 5 |3 6 1
_ 1 _ |_ 2 _    3 1 6 |4 2 5    3 1 6 |4 2 5

--- Example 2 ---
Status: ✓ CORRECT | Valid Sudoku: ✓

INPUT           TARGET          PREDICTION
--------------  --------------  --------------
_ 4 _ |_ 6 5    1 4 2 |3 6 5    1 4 2 |3 6 5
3 5 _ |_ 1 _    3 5 6 |2 1 4    3 5 6 |2 1 4
-----+-----     -----+-----     -----+-----
6 _ 3 |_ 4 _    6 1 3 |5 4 2    6 1 3 |5 4 2
_ 2 _ |_ _ 1    5 2 4 |6 3 1    5 2 4 |6 3 1
-----+-----     -----+-----     -----+-----
_ _ _ |4 _ 6    2 3 1 |4 5 6    2 3 1 |4 5 6
_ _ 5 |1 _ _    

Evaluating val: 100%|██████████| 7/7 [00:13<00:00,  1.91s/it]


EVALUATION RESULTS
Cell Accuracy:   0.9094 (90.94%)
Puzzle Accuracy: 0.7070 (70.70%)
Validity Rate:   0.8711 (87.11%)
Puzzles Solved:  1267/1792
Valid Solutions: 1561/1792
Avg Steps:       16.0





In [3]:
# Cell 4: (Optional) Visualize model thinking on a sample
# Walk through the model's step-by-step predictions for one validation puzzle.
results_viz = evaluator.visualize_sample(
    sample_idx=1,
    split="val",
    save_gif=True,  # GIF saving is switched on here; set to False to skip it
    show_confidence=True,
)

  input_tensor = torch.from_numpy(input_flat).long()



Running 1 stochastic forward pass(es) (dropout=OFF)

TRM THINKING VISUALIZATION
H_cycles=3, L_cycles=6

INPUT            TARGET
---------------  ---------------
 ? 4 ?| ? 6 5     1 4 2| 3 6 5
 3 5 ?| ? 1 ?     3 5 6| 2 1 4
------+------    ------+------
 6 ? 3| ? 4 ?     6 1 3| 5 4 2
 ? 2 ?| ? ? 1     5 2 4| 6 3 1
------+------    ------+------
 ? ? ?| 4 ? 6     2 3 1| 4 5 6
 ? ? 5| 1 ? ?     4 6 5| 1 2 3

Empty cells to fill: 21

--------------------------------------------------------------------------------
STEP-BY-STEP REASONING (each step = H×L iterations of reasoning blocks)
--------------------------------------------------------------------------------

┌─ Step 1 ─────────────────────────────────────────────────────────
│ Accuracy: 81.0% (4 errors) | Changes: - | q=+1.05 (HALT)
│ Confidence: avg=1.00, min=1.00
└──────────────────────────────────────────────────────────────────────
  !2 4!1| 3 6 5
   3 5 6| 2 1 4
  ------+------
   6 1 3| 5 4 2
   5 2 4| 6 3 1
  ------+------
 

In [4]:
# Cell 5: Visualize a sample using several dropout-enabled forward passes
# With dropout switched on, each forward pass draws random masks, so the
# per-run predictions can differ between executions of this cell. Seeding
# torch's generators first keeps the output reproducible on a re-run.
STOCHASTIC_SEED = 0
torch.manual_seed(STOCHASTIC_SEED)

# NOTE: this rebinds `results_viz`, replacing the value from the previous cell.
results_viz = evaluator.visualize_sample(
    split="val",
    sample_idx=0,
    show_confidence=True,
    save_gif=False,
    num_stochastic_runs=3,
    dropout_enabled=True,
)


Running 3 stochastic forward pass(es) (dropout=ON)

TRM THINKING VISUALIZATION (STOCHASTIC - 3 forward passes with dropout)
H_cycles=3, L_cycles=6

----------------------------------------------------------------------------------------------------
STEP-BY-STEP REASONING
----------------------------------------------------------------------------------------------------

┌─ Step 1/1 ─────────────────────────────────────────────────────────────────┐
│ Consensus Accuracy: 100.0% | Disagreement: 0.0% | q=+3.03 (HALT) ← STOPPED
│ Model Confidence: avg=1.00
└─ Run predictions (rows show each forward pass):
   Run  1: 1 2 4 5 3 6  | 5 6 3 1 4 2  | 4 5 2 6 1 3  | 6 3 1 2 5 4  | 2 4 5 3 6 1  | 3 1 6 4 2 5 
   Run  2: 1 2 4 5 3 6  | 5 6 3 1 4 2  | 4 5 2 6 1 3  | 6 3 1 2 5 4  | 2 4 5 3 6 1  | 3 1 6 4 2 5 
   Run  3: 1 2 4 5 3 6  | 5 6 3 1 4 2  | 4 5 2 6 1 3  | 6 3 1 2 5 4  | 2 4 5 3 6 1  | 3 1 6 4 2 5 
   CONS:  1 2 4 5 3 6  | 5 6 3 1 4 2  | 4 5 2 6 1 3  | 6 3 1 2 5 4  | 2 4 5 3 6 1  | 3 1 6 4 