In [1]:
from pathlib import Path
import json
import random

import torch
from torch.utils.data import Dataset, DataLoader
from ollama import Client
import pandas as pd
import numpy as np

from TinyRecursiveModels.models.losses import ACTLossHead
from TinyRecursiveModels.dataset.build_sudoku_dataset import DataProcessConfig, preprocess_data
from TinyRecursiveModels.models.recursive_reasoning.trm import (
    TinyRecursiveReasoningModel_ACTV1, 
    TinyRecursiveReasoningModel_ACTV1Carry, 
    TinyRecursiveReasoningModel_ACTV1InnerCarry
 )

In [2]:
# Seed every random number generator this notebook imports (Python's `random`,
# NumPy and PyTorch) with the same constant so repeated runs start from the
# same state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Last expression of the cell: its return value (a torch Generator) is what
# the cell displays.
torch.manual_seed(SEED)

<torch._C.Generator at 0x212ffac4050>

In [3]:
# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
# The conditional expression short-circuits, so CUDA is only probed when MPS
# is unavailable.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
device

device(type='cuda')

In [4]:
# Ollama client built with the library defaults (no arguments passed); it is
# used later by evaluate_llm for the chat requests.
# NOTE(review): presumably this targets a locally running Ollama server — confirm.
ollama_client = Client()

In [5]:
# Locations: everything lives under ./data/sudoku-extreme-mini relative to the notebook.
LAB_DIR = Path(".")
DATA_ROOT = LAB_DIR / "data" / "sudoku-extreme-mini"

# Preprocessing settings handed to preprocess_data when the dataset has to be built.
config = DataProcessConfig(
    output_dir=str(DATA_ROOT),
    subsample_size=512,
    min_difficulty=None,
    num_aug=1,
)

# The presence of the train inputs file is used as the "already built" marker.
train_inputs_path = DATA_ROOT.joinpath("train", "all__inputs.npy")

if train_inputs_path.exists():
    print(f"Using cached dataset at {DATA_ROOT}")
else:
    DATA_ROOT.mkdir(parents=True, exist_ok=True)
    preprocess_data(config)



train.csv:   0%|          | 0.00/719M [00:00<?, ?B/s]

100%|██████████| 512/512 [00:00<00:00, 8419.92it/s]


test.csv:   0%|          | 0.00/79.4M [00:00<?, ?B/s]

100%|██████████| 422786/422786 [00:00<00:00, 1747734.15it/s]


In [None]:
def grid_to_text(tokens, side=None):
    """Render a grid of dataset tokens as text, one row per line.

    Every token is shifted down by one before printing (``arr - 1``), and the
    cells of a row are separated by single spaces.

    Args:
        tokens: Array-like of integer tokens. Either already 2-D, or flat.
        side: If given, ``tokens`` is reshaped to ``(side, side)`` first.
            If omitted, a 2-D input is used as is and a flat (or scalar)
            input is rendered as a single row.

    Returns:
        The rows joined with newlines.

    Raises:
        ValueError: If ``side`` is given and the number of tokens is not
            ``side * side``.
    """
    arr = np.array(tokens, dtype=np.int64)
    if side is not None:
        arr = arr.reshape(side, side)
    elif arr.ndim < 2:
        # Iterating a 1-D array yields scalars, which cannot be iterated as
        # rows; present flat input as one row instead of raising TypeError.
        arr = arr.reshape(1, -1)
    decoded = arr - 1
    return "\n".join(" ".join(str(int(x)) for x in row) for row in decoded)


def load_split(base_dir, split="train"):
    """Read one dataset split from ``<base_dir>/<split>``.

    Args:
        base_dir: Directory that contains the split sub-directories.
        split: Name of the split sub-directory.

    Returns:
        A 4-tuple ``(meta, inputs, labels, puzzle_ids)``: the parsed contents
        of ``dataset.json`` followed by the arrays loaded from the three
        ``all__*.npy`` files.
    """
    root = Path(base_dir, split)
    with open(root / "dataset.json") as handle:
        meta = json.load(handle)
    array_files = ("all__inputs.npy", "all__labels.npy", "all__puzzle_identifiers.npy")
    arrays = [np.load(root / filename) for filename in array_files]
    return (meta, *arrays)


# Load both splits from DATA_ROOT into module-level arrays that the dataset,
# loader and LLM-evaluation code read later.
meta_train, train_inputs, train_labels, train_puzzle_ids = load_split(DATA_ROOT, "train")
meta_test, test_inputs, test_labels, test_puzzle_ids = load_split(DATA_ROOT, "test")
# The train metadata is the one used for the rest of the notebook; the bare
# name on the last line makes the cell display it.
dataset_meta = meta_train
dataset_meta

In [None]:
class NumpyPuzzleDataset(Dataset):
    """Map-style dataset over three parallel, index-aligned array-likes.

    Item ``idx`` is the triple ``(inputs[idx], labels[idx], puzzle_ids[idx])``
    with each element converted to a ``torch.long`` tensor.
    """

    def __init__(self, inputs, labels, puzzle_ids):
        """Keep references to the three arrays; nothing is copied here."""
        self.inputs, self.labels, self.puzzle_ids = inputs, labels, puzzle_ids

    def __len__(self):
        """Number of examples, taken from the inputs array."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(input, label, puzzle_id)`` for ``idx`` as long tensors."""
        sources = (self.inputs, self.labels, self.puzzle_ids)
        return tuple(torch.tensor(source[idx], dtype=torch.long) for source in sources)


def collate_batch(batch):
    """Stack a list of ``(input, label, puzzle_id)`` tensor triples into a batch dict.

    Args:
        batch: Sequence of 3-tuples of tensors, as produced by the dataset.

    Returns:
        Dict with keys ``'inputs'``, ``'labels'`` and ``'puzzle_identifiers'``,
        each holding the corresponding tensors stacked along a new leading
        dimension.
    """
    # Transpose "list of triples" into three per-field tuples.
    input_items, label_items, id_items = zip(*batch)
    return {
        'inputs': torch.stack(input_items, dim=0),
        'labels': torch.stack(label_items, dim=0),
        'puzzle_identifiers': torch.stack(id_items, dim=0),
    }


def make_loaders(batch_size=64):
    """Build the train and test ``DataLoader`` objects from the module-level arrays.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Only the train loader shuffles; both
        use ``collate_batch`` to assemble batches.
    """
    shared_options = dict(batch_size=batch_size, collate_fn=collate_batch)
    train_ds = NumpyPuzzleDataset(train_inputs, train_labels, train_puzzle_ids)
    test_ds = NumpyPuzzleDataset(test_inputs, test_labels, test_puzzle_ids)
    return (
        DataLoader(train_ds, shuffle=True, **shared_options),
        DataLoader(test_ds, shuffle=False, **shared_options),
    )


# Build the (train, test) loader pair once with batch size 64; later cells
# unpack `loaders`.
loaders = make_loaders(batch_size=64)
# Displayed as the cell output: number of batches in each loader.
{'train_batches': len(loaders[0]), 'test_batches': len(loaders[1])}

In [None]:
def softmax_cross_entropy_masked(logits, labels, ignore_index: int = -100, valid_mask=None):
    """Per-element softmax cross-entropy, computed in float32, with an optional mask.

    Args:
        logits: Tensor whose last dimension indexes the classes; the leading
            dimensions must flatten to the same number of elements as ``labels``.
        labels: Integer class ids. Positions equal to ``ignore_index``
            contribute a loss of zero.
        ignore_index: Label value to ignore.
        valid_mask: Optional tensor multiplied element-wise with the loss; must
            broadcast against ``labels.shape``.

    Returns:
        Unreduced loss tensor with the shape of ``labels`` (multiplied by
        ``valid_mask`` when one is given).
    """
    num_classes = logits.shape[-1]
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors,
    # are flattened instead of raising a RuntimeError.
    per_element = torch.nn.functional.cross_entropy(
        logits.to(torch.float32).reshape(-1, num_classes),
        labels.to(torch.long).reshape(-1),
        ignore_index=ignore_index,
        reduction='none',
    ).reshape(labels.shape)
    if valid_mask is None:
        return per_element
    return per_element * valid_mask

def build_trm(meta, hidden_size=64, heads=4, halt_steps=2):
    """Construct the recursive model wrapped in its loss head, on the global ``device``.

    Args:
        meta: Dataset metadata; ``meta['seq_len']`` and ``meta['vocab_size']``
            are read.
        hidden_size: Passed through as ``hidden_size``.
        heads: Passed through as ``num_heads``.
        halt_steps: Passed through as ``halt_max_steps``.

    Returns:
        An ``ACTLossHead`` around a ``TinyRecursiveReasoningModel_ACTV1``,
        moved to ``device``. All remaining configuration values are fixed here.
    """
    trm_config = {
        'batch_size': 32,
        'seq_len': meta['seq_len'],
        'puzzle_emb_ndim': 0,
        'num_puzzle_identifiers': 1,
        'vocab_size': meta['vocab_size'],
        'H_cycles': 2,
        'L_cycles': 2,
        'H_layers': 0,
        'L_layers': 1,
        'hidden_size': hidden_size,
        'expansion': 2,
        'num_heads': heads,
        'pos_encodings': 'rope',
        'halt_max_steps': halt_steps,
        'halt_exploration_prob': 0.05,
        'forward_dtype': 'float32',
        'mlp_t': True,
        # keep seq length aligned when puzzle_emb_ndim=0
        'puzzle_emb_len': 0,
        'no_ACT_continue': True,
    }
    core_model = TinyRecursiveReasoningModel_ACTV1(trm_config)
    with_loss_head = ACTLossHead(core_model, loss_type='softmax_cross_entropy_masked')
    return with_loss_head.to(device)


def move_batch(batch):
    """Return a new dict with every value of ``batch`` moved to the global ``device``."""
    moved = {}
    for key, tensor in batch.items():
        moved[key] = tensor.to(device)
    return moved


def move_carry_to_device(carry: TinyRecursiveReasoningModel_ACTV1Carry, device):
    """Rebuild ``carry`` with all of its tensors moved to ``device``.

    A new inner carry (``z_H``, ``z_L``) and a new outer carry (``steps``,
    ``halted``, ``current_data``) are constructed; ``carry`` itself is not
    modified.
    """
    source_inner = carry.inner_carry
    moved_inner = TinyRecursiveReasoningModel_ACTV1InnerCarry(
        z_H=source_inner.z_H.to(device),
        z_L=source_inner.z_L.to(device),
    )
    moved_data = {}
    for key, tensor in carry.current_data.items():
        moved_data[key] = tensor.to(device)
    return TinyRecursiveReasoningModel_ACTV1Carry(
        inner_carry=moved_inner,
        steps=carry.steps.to(device),
        halted=carry.halted.to(device),
        current_data=moved_data,
    )


def train_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and report average metrics.

    For every batch a fresh carry is created via ``model.initial_carry``, the
    model is called once, and one optimizer step is taken on the returned loss.

    Args:
        model: Callable as ``model(return_keys=..., carry=..., batch=...)``
            returning a 5-tuple whose second element is the loss and whose
            fourth element is a dict containing ``'preds'``.
        loader: Iterable of batch dicts with at least ``'inputs'`` and ``'labels'``.
        optimizer: Optimizer over the model parameters.

    Returns:
        ``{'loss': mean of the per-batch loss values,
           'token_acc': fraction of label positions where preds == labels}``.
        Both denominators are clamped to at least 1, so an empty loader
        yields zeros.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0
    for batch in loader:
        batch = move_batch(batch)
        carry = model.initial_carry(batch)
        carry = move_carry_to_device(carry, batch['inputs'].device)
        # Clear gradients before the forward/backward pass so anything left on
        # the parameters (e.g. from an interrupted earlier run of this cell)
        # cannot leak into this update.
        optimizer.zero_grad()
        carry, loss, _, outputs, _ = model(return_keys=['preds'], carry=carry, batch=batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        preds = outputs['preds']
        total_correct += (preds == batch['labels']).sum().item()
        total_tokens += batch['labels'].numel()
    return {'loss': total_loss / max(len(loader), 1), 'token_acc': total_correct / max(total_tokens, 1)}


def evaluate(model, loader):
    """Score ``model`` on ``loader`` without gradient tracking.

    Each batch gets a fresh carry from ``model.initial_carry`` and a single
    model call.

    Returns:
        ``{'loss': mean of the per-batch loss values,
           'token_acc': fraction of label positions where preds == labels}``,
        with both denominators clamped to at least 1.
    """
    model.eval()
    loss_sum, n_correct, n_tokens = 0.0, 0, 0
    with torch.no_grad():
        for raw_batch in loader:
            on_device = move_batch(raw_batch)
            start_carry = move_carry_to_device(
                model.initial_carry(on_device),
                on_device['inputs'].device,
            )
            _, batch_loss, _, outputs, _ = model(
                return_keys=['preds'], carry=start_carry, batch=on_device
            )
            targets = on_device['labels']
            loss_sum += batch_loss.item()
            n_correct += (outputs['preds'] == targets).sum().item()
            n_tokens += targets.numel()
    n_batches = max(len(loader), 1)
    return {'loss': loss_sum / n_batches, 'token_acc': n_correct / max(n_tokens, 1)}


def predict_batch(model, batch):
    """Return the model's ``'preds'`` for one batch as a NumPy array on the CPU.

    The model is put in eval mode and called once, with a fresh carry and
    gradient tracking disabled.
    """
    model.eval()
    with torch.no_grad():
        on_device = move_batch(batch)
        start_carry = move_carry_to_device(
            model.initial_carry(on_device),
            on_device['inputs'].device,
        )
        result = model(return_keys=['preds'], carry=start_carry, batch=on_device)
        # The call returns a 5-tuple; the fourth element is the outputs dict.
        _, _, _, outputs, _ = result
        return outputs['preds'].cpu().numpy()


In [None]:
# Results of this cell: the fitted model and one log record per epoch.
trained_model = None
training_logs = []

print(f"\n=== Training sudoku (seq_len={dataset_meta['seq_len']}) ===")
model = build_trm(dataset_meta, hidden_size=96, heads=4, halt_steps=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.01)
train_loader, test_loader = loaders
epochs = 2

for epoch in range(1, epochs + 1):
    train_metrics = train_epoch(model, train_loader, optimizer)
    val_metrics = evaluate(model, test_loader)

    # One flat record per epoch; metric names are prefixed with their split.
    training_logs.append(dict(
        task='sudoku9',
        epoch=epoch,
        **{f"train_{k}": v for k, v in train_metrics.items()},
        **{f"val_{k}": v for k, v in val_metrics.items()},
    ))
    print(
        f"epoch {epoch:02d} | "
        f"train_loss {train_metrics['loss']:.3f} acc {train_metrics['token_acc']:.3f} | "
        f"val_loss {val_metrics['loss']:.3f} acc {val_metrics['token_acc']:.3f}"
    )

trained_model = model
training_logs[:3]

In [None]:
def render_sample(meta, tensor_1d):
    """Format a flat token sequence as a square text grid.

    The grid side is the integer part of the square root of ``meta['seq_len']``;
    the formatting itself is delegated to ``grid_to_text``.
    """
    return grid_to_text(tensor_1d, side=int(np.sqrt(meta['seq_len'])))

# Take the first test batch and get the trained model's predictions for it.
batch = next(iter(test_loader))
preds = predict_batch(trained_model, batch)

# Render at most the first two examples as input / target / prediction grids.
rows = []
for i in range(min(2, len(preds))):
    rows.append(dict(
        input=render_sample(dataset_meta, batch['inputs'][i].numpy()),
        target=render_sample(dataset_meta, batch['labels'][i].numpy()),
        prediction=render_sample(dataset_meta, preds[i]),
    ))

json.dumps(rows, indent=2)

In [None]:
def _grid_values_or_none(tokens, n_cells):
    """Return ``tokens`` as ints if there are exactly ``n_cells`` of them, all in 0-9; else None."""
    if len(tokens) != n_cells:
        return None
    try:
        values = [int(tok) for tok in tokens]
    except ValueError:
        return None
    # Sudoku values only; 0 stands for a blank cell.
    if any(v < 0 or v > 9 for v in values):
        return None
    return values


def _strict_grid_values(lines, n_cells):
    """Treat the whole response as the grid: every token of every line is a cell."""
    cells = []
    for ln in lines:
        parts = [p for p in ln.replace(',', ' ').split(' ') if p]
        if len(parts) == n_cells:
            # A single line already holding the full grid wins outright.
            cells = parts
            break
        cells.extend(parts)
    return _grid_values_or_none(cells, n_cells)


def _tolerant_grid_values(lines, side):
    """Find a grid inside a response that also contains prose or decoration.

    Lines without any letter or digit (rules, code fences) are skipped.
    Each remaining line becomes a list of ints when all of its tokens are
    values in 0-9, otherwise ``None``. The last run of ``side`` consecutive
    rows of ``side`` values is returned; failing that, the last line that
    holds all ``side * side`` values.
    """
    if side < 1:
        return None
    numeric_rows = []
    for ln in lines:
        if not any(ch.isalnum() for ch in ln):
            continue
        tokens = ln.replace(',', ' ').replace('|', ' ').split()
        numeric_rows.append(_grid_values_or_none(tokens, len(tokens)))

    # Scan from the end: the final answer normally follows any explanation
    # or echo of the input grid.
    for start in range(len(numeric_rows) - side, -1, -1):
        window = numeric_rows[start:start + side]
        if all(row is not None and len(row) == side for row in window):
            return [value for row in window for value in row]
    for row in reversed(numeric_rows):
        if row is not None and len(row) == side * side:
            return row
    return None


def parse_grid(text, side):
    """Parse an LLM text response into a side x side numpy grid of ints.

    The response is first read strictly: lines are separated by newlines,
    cells by spaces or commas, and every token must be a cell. If that fails
    (for example because the response contains an explanation, a markdown
    code fence, tab-separated cells or ``|`` column separators), the last
    well-formed grid found in the text is used instead.

    Args:
        text: Raw response text; may be empty or None.
        side: Grid side length.

    Returns:
        An int64 array of shape ``(side, side)`` with values in 0-9
        (0 = blank), or None if no such grid can be extracted.
    """
    if not text:
        return None
    lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
    if not lines:
        return None
    values = _strict_grid_values(lines, side * side)
    if values is None:
        values = _tolerant_grid_values(lines, side)
    if values is None:
        return None
    try:
        return np.array(values, dtype=np.int64).reshape(side, side)
    except ValueError:
        return None


In [None]:
def build_prompt(meta, grid_in, strategy='zero'):
    """Compose the text prompt asking an LLM to complete one Sudoku grid.

    Args:
        meta: Dataset metadata; ``meta['seq_len']`` determines the grid side.
        grid_in: Flat token sequence of the puzzle, rendered with ``grid_to_text``.
        strategy: ``'cot'`` appends a request for a brief explanation; any
            other value yields the plain prompt.

    Returns:
        The prompt string.
    """
    side = int(np.sqrt(meta['seq_len']))
    grid_txt = grid_to_text(grid_in, side=side)
    header = f"Fill the {side}x{side} Sudoku. Zeros/blanks mean empty cells. Return the completed grid with spaces separated."
    suffix = "Explain briefly then give the final grid." if strategy == 'cot' else ""
    return f"{header}\nGrid:\n{grid_txt}\n" + suffix


def evaluate_llm(meta, max_examples=5, strategies=("zero", "cot"), model='ministral-3:3b'):
    """Query the Ollama model on the first test puzzles and score its answers.

    For each of the first ``max_examples`` pairs of the module-level
    ``test_inputs`` / ``test_labels`` and for each prompting strategy, a
    prompt is built, sent through ``ollama_client.chat`` and the reply is
    parsed with ``parse_grid``.

    Args:
        meta: Dataset metadata; ``meta['seq_len']`` determines the grid side.
        max_examples: Number of test examples to evaluate.
        strategies: Prompt strategies passed to ``build_prompt``.
        model: Name of the Ollama model to query.

    Returns:
        A list with one dict per (example, strategy) holding ``task``,
        ``example_id``, ``strategy``, ``prompt``, ``raw_response``, ``parsed``
        (nested list or None) and ``token_acc`` (float or None). When the
        request fails, ``raw_response`` is ``"ERROR: <message>"`` and the
        other two are None. Returns ``[]`` if no client is available.
    """
    if ollama_client is None:
        print("Ollama client not available; skipping LLM eval.")
        return []

    side = int(np.sqrt(meta['seq_len']))
    results = []
    for idx, (inp, out) in enumerate(zip(test_inputs, test_labels)):
        if idx >= max_examples:
            break

        for strat in strategies:
            prompt = build_prompt(meta, inp, strategy=strat)
            request_failed = False
            try:
                resp = ollama_client.chat(model=model, messages=[{'role': 'user', 'content': prompt}])
                # Subscript access works for a plain dict and for the client's
                # response object alike; an isinstance(resp, dict) guard would
                # throw the reply away for the latter.
                content = resp['message']['content'] or ''
            except Exception as e:
                # Best effort: record the failure and continue with the next item.
                content = f"ERROR: {e}"
                request_failed = True

            # Never try to read a grid out of an error message.
            parsed = None if request_failed else parse_grid(content, side)
            token_acc = None

            if parsed is not None:
                # parse_grid yields displayed values; labels are stored shifted up by one.
                parsed_tokens = (parsed + 1).reshape(-1)
                token_acc = float((parsed_tokens == out.reshape(-1)).mean())

            results.append({
                'task': 'sudoku9',
                'example_id': idx,
                'strategy': strat,
                'prompt': prompt,
                'raw_response': content,
                'parsed': parsed.tolist() if parsed is not None else None,
                'token_acc': token_acc
            })

    return results

In [None]:
# Run the LLM baseline on a single test puzzle with both prompt strategies
# (one request per strategy).
llm_results = evaluate_llm(dataset_meta, max_examples=1, strategies=("zero", "cot"))
# Display the first two result records as the cell output.
llm_results[:2]

In [None]:
# TRM: one summary row with the metrics from a full pass over the test loader.
trm_eval = []
metrics = evaluate(trained_model, test_loader)
trm_eval.append({'task': 'sudoku9', 'model': 'trm', **metrics})
trm_df = pd.DataFrame(trm_eval)

# LLM: keep only the responses that could be parsed and scored.
llm_df = pd.DataFrame(list(filter(lambda r: r.get('token_acc') is not None, llm_results)))

if llm_df.empty or not {'task', 'strategy', 'token_acc'}.issubset(llm_df.columns):
    # Nothing scoreable: fall back to an empty frame with the expected columns.
    llm_summary = pd.DataFrame(columns=['task','strategy','token_acc'])
else:
    llm_summary = llm_df.groupby(['task', 'strategy'])['token_acc'].mean().reset_index()

trm_df, llm_summary.head()