In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent / "TinyRecursiveModels"))
from TinyRecursiveModels.models.recursive_reasoning.trm import *

import torch 
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from typing import Tuple, Optional
from torch.nn.utils.rnn import pad_sequence
from pathlib import Path

import torch.optim as optim
from tqdm import tqdm
import wandb
from datetime import datetime

In [None]:
# -----------------------------
# CONFIG
# -----------------------------
# --- Data and checkpoint locations ---
TRAIN_PROBLEMS_DIR = 'data/train/problems'
TRAIN_SOLUTIONS_DIR = 'data/train/solutions'
VAL_PROBLEMS_DIR = 'data/val/problems'
VAL_SOLUTIONS_DIR = 'data/val/solutions'
CHKPT_DIR = Path('checkpoints')

# --- Optimisation hyperparameters ---
BATCH_SIZE = 4
NUM_EPOCHS = 10
LR = 1e-4
# N_sup: how many "Deep Supervision" steps are run per batch.
HALT_MAX_STEPS = 16
# Scale applied to the ACT (halting) loss before it is added to the task loss.
ACT_LOSS_WEIGHT = 0.01

# --- Model & data shape ---
# Every problem is padded to one fixed sequence length of MAX_J * MAX_T.
MAX_J = 20
MAX_T = 20
MAX_SEQ_LEN = MAX_J * MAX_T  # passed to the model as 'seq_len'
HIDDEN_SIZE = 256            # internal model dimension
PUZZLE_EMB_DIM = 64          # dimension of the "task ID" embedding

model_cfg = {
    "batch_size": BATCH_SIZE,
    "seq_len": MAX_SEQ_LEN,
    "puzzle_emb_ndim": PUZZLE_EMB_DIM,
    "num_puzzle_identifiers": 1,   # single task: "JSP"
    "vocab_size": 0,               # no token vocabulary is used (CRITICAL)
    "puzzle_emb_len": 16,

    # Recursive reasoning
    "H_cycles": 3,                 # T=3 in the paper
    "L_cycles": 6,                 # n=6 in the paper
    "L_layers": 2,                 # the "tiny" 2-layer model from the paper
    "H_layers": 0,                 # ignored: TRM simplified the hierarchy

    # Transformer
    "hidden_size": HIDDEN_SIZE,
    "expansion": 2.0,              # standard for SwiGLU
    "num_heads": 4,                # must divide HIDDEN_SIZE
    "pos_encodings": "rope",

    # Halting Q-learning
    "halt_max_steps": HALT_MAX_STEPS,
    "halt_exploration_prob": 0.0,  # not needed for simple training
    "no_ACT_continue": True,       # simplified ACT loss from the paper

    "forward_dtype": "float32",
}

In [None]:
# -----------------------------
# DATASET AND DATALOADER
# -----------------------------

class CustomJobShopDataset(Dataset):
    """Job-shop problems and their solutions stored as one ``.pt`` file each.

    Problems and solutions are paired by position after sorting the ``*.pt``
    file names of each directory, so the two directories must sort into the
    same order (presumably by using matching file names - verify against the
    data generator).

    Args:
        problems_dir: Directory containing the problem tensors.
        solutions_dir: Directory containing the solution tensors.
        transform: Stored on the instance but not applied anywhere in this
            class.

    Raises:
        AssertionError: If the two directories hold a different number of
            ``*.pt`` files.
    """

    def __init__(self, problems_dir, solutions_dir, transform=None):
        self.problems_dir = Path(problems_dir)
        self.solutions_dir = Path(solutions_dir)
        self.transform = transform

        self.problems_files = sorted(self.problems_dir.glob('*.pt'))
        self.solutions_files = sorted(self.solutions_dir.glob('*.pt'))

        self.num_problems = len(self.problems_files)
        assert self.num_problems == len(self.solutions_files), "Problem and solution file count mismatch"

    def __len__(self):
        """Return the number of (problem, solution) pairs."""
        return self.num_problems

    def __getitem__(self, idx):
        """Load one pair and flatten it to a sequence.

        Returns:
            dict with ``"inputs"`` of shape (L, 2), ``"labels"`` of shape
            (L, 1) and a scalar int32 ``"puzzle_identifiers"`` tensor.
        """
        # NOTE(review): torch.load unpickles the file; only point this dataset
        # at files you trust.
        problem = torch.load(self.problems_files[idx])
        solution = torch.load(self.solutions_files[idx])

        # Flatten from (J, T, C) to (L, C). reshape() is used instead of view()
        # so tensors saved with non-contiguous strides (e.g. after a transpose)
        # are handled instead of raising.
        problem = problem.reshape(-1, 2)
        solution = solution.reshape(-1, 1)
        # The model expects a puzzle identifier; 0 is the ID of our JSP task.
        puzzle_identifier = torch.tensor(0, dtype=torch.int32)

        return {
            "inputs": problem,
            "labels": solution,
            "puzzle_identifiers": puzzle_identifier
        }


def custom_collate_fn(batch, max_seq_len: Optional[int] = None):
    """Pad (or truncate) every sample to a fixed length and stack the batch.

    Args:
        batch: List of sample dicts with keys ``'inputs'`` (L, 2),
            ``'labels'`` (L, 1) and ``'puzzle_identifiers'`` (scalar tensor).
        max_seq_len: Length every sequence is padded/truncated to. Defaults to
            the module-level ``MAX_SEQ_LEN`` when omitted, so the function can
            still be passed directly as ``collate_fn``.

    Returns:
        dict with
        ``'inputs'``  float32 (B, max_seq_len, 2), zero padded;
        ``'labels'``  float32 (B, max_seq_len, 1), padded with -1.0;
        ``'mask'``    bool    (B, max_seq_len, 1), True on real positions;
        ``'puzzle_identifiers'``  the stacked per-sample identifiers.
    """
    if max_seq_len is None:
        max_seq_len = MAX_SEQ_LEN

    batch_size = len(batch)
    # Pad inputs (the (M, D) puzzle)
    padded_problems = torch.zeros(batch_size, max_seq_len, 2, dtype=torch.float32)
    # Pad labels (the (S) solution); -1.0 marks padding, which the mask hides
    padded_solutions = torch.full((batch_size, max_seq_len, 1), -1.0, dtype=torch.float32)
    # Mask to identify real data vs. padding
    mask = torch.zeros(batch_size, max_seq_len, 1, dtype=torch.bool)

    identifiers = []

    for i, item in enumerate(batch):
        # Samples longer than max_seq_len are truncated.
        seq_len = min(item['inputs'].shape[0], max_seq_len)

        padded_problems[i, :seq_len] = item['inputs'][:seq_len]
        padded_solutions[i, :seq_len] = item['labels'][:seq_len]
        mask[i, :seq_len] = True
        identifiers.append(item['puzzle_identifiers'])

    return {
        'inputs': padded_problems,
        'labels': padded_solutions,
        'mask': mask,
        'puzzle_identifiers': torch.stack(identifiers)
    }

In [None]:
# -----------------------------
# INITIALIZE DATA LOADERS
# -----------------------------
print("Starting training...")
# Training samples are reshuffled every epoch.
train_dataset = CustomJobShopDataset(TRAIN_PROBLEMS_DIR, TRAIN_SOLUTIONS_DIR)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn)
print(f"There are {len(train_dataset)} training job shop problem samples")

# Validation is NOT shuffled: the loss is normalised per batch, so a fixed
# batch composition keeps the validation metric comparable between epochs
# (it drives best-checkpoint selection).
val_dataset = CustomJobShopDataset(VAL_PROBLEMS_DIR, VAL_SOLUTIONS_DIR)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn)
print(f"There are {len(val_dataset)} validation job shop problem samples")

Starting training...
There are 1000 training job shop problem samples
There are 200 validation job shop problem samples


In [None]:
# -----------------------------
# MODEL, LOSS, OPTIMIZER
# -----------------------------
print("Setting up model...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# The training loop moves every batch to `device`, so the model's parameters
# must live there too; otherwise a CUDA run fails with a device mismatch.
# NOTE(review): model.initial_carry() must also allocate on this device - confirm.
model = TinyRecursiveReasoningModel_ACTV1(model_cfg).to(device)
# 'none' keeps the per-element loss so the padding mask can be applied manually.
loss_fn = torch.nn.MSELoss(reduction='none')
# Binary target for the halting signal.
act_loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters.")
print(f"Model is as follows: \n{model}")

Setting up model...
Using device: cpu
Model created with 1,312,002 parameters.
Model is as follows: 
TinyRecursiveReasoningModel_ACTV1(
  (inner): TinyRecursiveReasoningModel_ACTV1_Inner(
    (input_proj): CastedLinear()
    (lm_head): CastedLinear()
    (q_head): CastedLinear()
    (puzzle_emb): CastedSparseEmbedding()
    (rotary_emb): RotaryEmbedding()
    (L_level): TinyRecursiveReasoningModel_ACTV1ReasoningModule(
      (layers): ModuleList(
        (0-1): 2 x TinyRecursiveReasoningModel_ACTV1Block(
          (self_attn): Attention(
            (qkv_proj): CastedLinear()
            (o_proj): CastedLinear()
          )
          (mlp): SwiGLU(
            (gate_up_proj): CastedLinear()
            (down_proj): CastedLinear()
          )
        )
      )
    )
  )
)


In [None]:
# -----------------------------
# WANDB
# -----------------------------
# The run is named after its start time, formatted month-day_hour-minute.
now = datetime.now().strftime("%m%d_%H%M")
exp_name = str(now)

# Training settings come first; model_cfg is merged last, so for a key present
# in both (e.g. "batch_size") the model_cfg value is the one recorded.
wandb_config = {
    "epochs": NUM_EPOCHS,
    "batch_size": BATCH_SIZE,
    "optimizer": "AdamW",
    "learning_rate": LR,
}
wandb_config.update(model_cfg)

wandb.login()
wandb.init(project="trm-scheduling", name=exp_name, config=wandb_config)

[34m[1mwandb[0m: Currently logged in as: [33mjana-mila[0m ([33mjana-mila-mila[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# -----------------------------
# TRAINING
# -----------------------------
def _deep_supervision_loss(model, batch, loss_fn, act_loss_fn, num_steps, act_loss_weight):
    """Run `num_steps` supervision steps on one batch and return the summed loss.

    Each step adds the masked mean regression loss plus `act_loss_weight` times
    the halting loss. Shared by training and validation so both always measure
    the same quantity.

    Args:
        model: Object exposing `initial_carry(batch)` and
            `model(carry=..., batch=...) -> (carry, outputs)`, where `outputs`
            has the keys 'logits' and 'q_halt_logits'.
        batch: Dict with at least 'labels' and a boolean 'mask' of shape
            (B, L, 1) as produced by custom_collate_fn.
        loss_fn: Element-wise regression loss (reduction='none').
        act_loss_fn: Loss applied to the halting logits.
        num_steps: Number of supervision steps to run.
        act_loss_weight: Scale applied to the halting loss.

    Returns:
        Scalar tensor: the sum over all steps of the per-step total loss.
    """
    carry = model.initial_carry(batch)
    labels = batch['labels']
    mask = batch['mask']
    # Small epsilon guards against division by zero if a batch has no real
    # tokens (which shouldn't happen, but is good practice).
    num_real_tokens = mask.sum() + 1e-8

    total_loss = 0
    for step in range(num_steps):
        carry, outputs = model(carry=carry, batch=batch)

        # --- 1. Main task loss (regression), averaged over real tokens only ---
        # NOTE(review): assumes outputs['logits'] is broadcast-compatible with
        # labels (B, L, 1) - confirm against the model.
        step_loss_masked = loss_fn(outputs['logits'], labels) * mask
        task_loss = step_loss_masked.sum() / num_real_tokens

        # --- 2. ACT (halting) loss ---
        # "Good enough" is hard to define for regression, so the model is simply
        # trained to halt on the last step: target 0 everywhere except there.
        is_last_step = (step == num_steps - 1)
        halt_target = torch.full_like(outputs['q_halt_logits'], float(is_last_step))
        act_loss = act_loss_fn(outputs['q_halt_logits'], halt_target)

        # --- 3. Combine ---
        total_loss = total_loss + (task_loss + (act_loss * act_loss_weight))

    return total_loss


# torch.save does not create missing directories.
CHKPT_DIR.mkdir(parents=True, exist_ok=True)

best_val_loss = 1e10
for epoch in range(NUM_EPOCHS):
    print(f"\n ---Epoch {epoch+1}/{NUM_EPOCHS} ---")
    model.train()
    train_epoch_loss = 0.0
    train_batch = 0

    for batch in tqdm(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        total_loss_for_batch = _deep_supervision_loss(
            model, batch, loss_fn, act_loss_fn, HALT_MAX_STEPS, ACT_LOSS_WEIGHT
        )

        # Backpropagate the *sum* of losses from all HALT_MAX_STEPS steps
        optimizer.zero_grad()
        total_loss_for_batch.backward()
        optimizer.step()

        train_batch += 1
        train_epoch_loss += total_loss_for_batch.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
    avg_train_epoch_loss = (train_epoch_loss / max(train_batch, 1)) / HALT_MAX_STEPS
    print(f"Epoch {epoch+1} finished. Average training loss: {avg_train_epoch_loss:.6f}")

    # -----------------------------
    # VALIDATE
    # -----------------------------
    model.eval()
    val_epoch_total_loss = 0.0
    val_batches = 0

    with torch.no_grad():
        for batch in tqdm(val_dataloader, leave=False, desc="Validating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            # Each validation batch gets its own fresh accumulator; it must not
            # be added onto the last training batch's loss.
            total_val_loss_for_batch = _deep_supervision_loss(
                model, batch, loss_fn, act_loss_fn, HALT_MAX_STEPS, ACT_LOSS_WEIGHT
            )

            val_epoch_total_loss += total_val_loss_for_batch.item()
            val_batches += 1

    avg_val_loss = (val_epoch_total_loss / max(val_batches, 1)) / HALT_MAX_STEPS
    print(f"Epoch {epoch+1} Validation loss: {avg_val_loss: .6f}") 
    wandb.log({'train_loss': avg_train_epoch_loss, 'val_loss': avg_val_loss})

    # -----------------------------
    # CHECKPOINT (IF BEST)
    # -----------------------------
    if avg_val_loss < best_val_loss:
        chckpt_name = f"{exp_name}-epoch-{epoch}-valloss-{avg_val_loss:.4f}.pth"
        chkpt_file = CHKPT_DIR / chckpt_name
        torch.save(model.state_dict(), chkpt_file)
        best_val_loss = avg_val_loss
        artifact = wandb.Artifact(name=chckpt_name, type='model')
        artifact.add_file(str(chkpt_file))
        wandb.log_artifact(artifact)
        print(f"New best model saved at {chkpt_file}")

wandb.finish()
print('Training complete!')


 ---Epoch 1/10 ---


  0%|          | 0/250 [00:02<?, ?it/s]


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x13ebd70e0>> (for post_run_cell), with arguments args (<ExecutionResult object at 16402be30, execution_count=7 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 16402b5c0, raw_cell="# %%
# -----------------------------
# TRAINING
# .." transformed_cell="# %%
# -----------------------------
# TRAINING
# .." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Interactive-2.interactive#X10sdW50aXRsZWQ%3D> result=None>,),kwargs {}:


ConnectionResetError: Connection lost