In [1]:
import hydra
import os
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torch

import os
import sys
from pathlib import Path
from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra

# Locate the project root (the directory that contains `src/`) and make it the
# working directory so `src.*` imports and relative config/checkpoint paths resolve.
notebook_path = Path.cwd()
# Re-running this cell must not climb yet another level: if the cwd already is
# the project root (it contains `src/`), stay there instead of taking `.parent`.
project_root = notebook_path if (notebook_path / "src").is_dir() else notebook_path.parent
os.chdir(project_root)
# Also put the root on sys.path so `from src...` does not depend on the kernel
# resolving '' to the current working directory.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from src.datamodule import JobShopDataModule
from src.model import TinyRecursiveModelJobShop

In [2]:
# Compose the experiment config with Hydra's compose API.
# Tune these two constants instead of editing the compose call itself.
CONFIG_NAME = "default_complex.yaml"
CONFIG_OVERRIDES = []  # e.g. ["batch_size=32"]

# initialize_config_dir requires an absolute directory.
abs_config_dir = project_root / "configs" / "v3"
# initialize_config_dir raises if GlobalHydra is already initialized in this
# kernel (e.g. by another initialize call); clear it so the cell can be re-run.
GlobalHydra.instance().clear()
with initialize_config_dir(version_base="1.3", config_dir=str(abs_config_dir)):
    cfg = compose(config_name=CONFIG_NAME, overrides=CONFIG_OVERRIDES)

In [3]:
# Build the job-shop data module from the composed config. The max sequence
# length comes from the model section of the config.
datamodule = JobShopDataModule(
    config=cfg,
    batch_size=cfg.batch_size,
    max_seq_len=cfg.model.trm_model.seq_len,
)

# Prepare the datasets, then keep the validation loader for the inference cells below.
datamodule.setup()
dataloader = datamodule.val_dataloader()

There are 1000 problems
There are 1000 solutions
There are 300 problems
There are 300 solutions
Train samples: 1000, Val samples: 300




In [4]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint to evaluate, addressed relative to the project root instead of a
# machine-specific absolute home directory.
# NOTE(review): assumes `checkpoints/` lives directly under project_root (as the
# original absolute path suggested) — confirm on other machines.
CHECKPOINT_RUN = "default_v3_complex_1222_2037"
checkpoint_path = (
    project_root
    / "checkpoints"
    / CHECKPOINT_RUN
    / f"{CHECKPOINT_RUN}-epoch=00-val_loss=1.5413.ckpt"
)

# map_location loads the tensors straight onto the chosen device, so a
# GPU-saved checkpoint also loads on a CPU-only machine.
model = TinyRecursiveModelJobShop.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()
model.freeze()  # inference only: no gradients for any parameter
model.to(device)

TinyRecursiveModelJobShop(
  (model): TinyRecursiveReasoningModel_ACTV1(
    (inner): TinyRecursiveReasoningModel_ACTV1_Inner(
      (input_proj): CastedLinear()
      (cls_head): CastedLinear()
      (q_head): CastedLinear()
      (puzzle_emb): CastedSparseEmbedding()
      (rotary_emb): RotaryEmbedding()
      (L_level): TinyRecursiveReasoningModel_ACTV1ReasoningModule(
        (layers): ModuleList(
          (0-1): 2 x TinyRecursiveReasoningModel_ACTV1Block(
            (self_attn): Attention(
              (qkv_proj): CastedLinear()
              (o_proj): CastedLinear()
            )
            (mlp): SwiGLU(
              (gate_up_proj): CastedLinear()
              (down_proj): CastedLinear()
            )
          )
        )
      )
    )
  )
  (ce_loss): CrossEntropyLoss()
  (act_loss_fn): BCEWithLogitsLoss()
)

In [8]:
# Draft of the step-by-step inference loop (the next cell is the cleaned-up version).
batch = next(iter(dataloader))
# Only tensors can be moved to the device; skip any other batch entries
# (same filtering as the next cell).
batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
B, L, K = batch['labels'].shape

carry = model.model.initial_carry(batch)
carry.to(device)  # NOTE(review): return value is discarded — assumes `.to` moves the carry in place; confirm

halt_max_steps = model.hparams.config.halt_max_steps
if halt_max_steps < 1:
    raise ValueError(f"halt_max_steps must be >= 1, got {halt_max_steps}")

final_logits = None
with torch.no_grad():
    for step in range(halt_max_steps):
        # One forward step: returns the updated carry and this step's outputs.
        carry, output = model(carry=carry, batch=batch)
        # Keep the most recent logits. Previously they were computed and dropped,
        # so the preallocated `final_logits` stayed all zeros.
        final_logits = output["logits"]

# Take the class count from the model output instead of hard-coding 10.
num_classes = final_logits.shape[-1]
assert final_logits.shape == (B, L, K, num_classes), final_logits.shape



In [10]:
# 1. Prepare Batch
batch = next(iter(dataloader))
batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}

# 2. Initialize State
halt_max_steps = model.hparams.config.halt_max_steps
carry = model.model.initial_carry(batch)
carry.to(device)

print(f"Running inference for {halt_max_steps} steps...")

with torch.no_grad():
    for step in range(halt_max_steps):
        # Forward pass
        # The model internally updates the 'carry' (hidden state) based on the previous step
        carry, output = model(carry=carry, batch=batch)

    # 3. Get Final Result (from the very last step)
    final_logits = output["logits"]  # Shape: (B, L, K, 10)
    
    # Get predictions
    predictions = final_logits.argmax(dim=-1) # Shape: (B, L, K)

    # Optional: Zero out padding for cleaner inspection
    if 'mask' in batch:
        mask = batch['mask'].expand_as(predictions).bool()
        predictions[~mask] = -1

print("Done.")
print("Predictions shape:", predictions.shape)
# Example: Print first active task of first sample
# print(predictions[0, 0])

Running inference for 16 steps...
Done.
Predictions shape: torch.Size([4, 400, 3])


In [11]:
# First 200 positions of the first sample's predictions (padding shows as -1).
predictions[0, :200]

tensor([[ 0,  1,  0],
        [ 0,  1,  0],
        [ 0,  3,  0],
        [ 0,  4,  5],
        [ 0,  5,  9],
        [ 0,  5,  9],
        [ 0,  7,  9],
        [ 0,  5,  7],
        [ 0,  7,  9],
        [ 0,  8,  9],
        [ 0,  0,  0],
        [ 0,  1,  0],
        [ 0,  2,  0],
        [ 0,  5,  5],
        [ 0,  5,  5],
        [ 0,  7,  5],
        [ 0,  5,  5],
        [ 0,  7,  5],
        [ 0,  7,  5],
        [ 0,  8,  9],
        [ 0,  0,  0],
        [ 0,  1,  0],
        [ 0,  1,  0],
        [ 0,  3,  0],
        [ 0,  5,  5],
        [ 0,  7,  9],
        [ 0,  5,  5],
        [ 0,  7,  9],
        [ 0,  7,  9],
        [ 0,  7,  3],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  1,  0],
        [ 0,  4,  5],
        [ 0,  5,  5],
        [ 0,  5,  7],
        [ 0,  7,  5],
        [ 0,  7,  9],
        [ 0,  7,  5],
        [ 1,  7,  9],
        [ 0,  0,  0],
        [ 0,  1,  0],
        [ 0,  0,  0],
        [ 0,  5,  5],
        [ 0,  5,  5],
        [ 

In [12]:
# Position-wise comparison of predictions against labels.
targets = batch['labels']
matches = (predictions == targets)

# A position counts as correct if it matches, carries the ignore label (-100),
# or is padding.
ignore_mask = (targets == -100)
# Derive padding from the batch itself instead of reusing `mask` from an earlier
# cell, which is only defined there when the batch has a 'mask' entry.
if 'mask' in batch:
    is_padding = ~batch['mask'].expand_as(targets).bool()
else:
    is_padding = torch.zeros_like(targets, dtype=torch.bool)
effective_correctness = matches | ignore_mask | is_padding

In [13]:
# A sample is solved only if every position is correct: reduce over the last
# axis, then over the next-to-last one.
per_position_ok = torch.all(effective_correctness, dim=-1)
is_sample_correct = torch.all(per_position_ok, dim=-1)
is_sample_correct

tensor([False, False, False, False], device='cuda:0')