In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from omegaconf import OmegaConf
from hydra import compose, initialize
from hydra.utils import instantiate

from src.model import TinyRecursiveModelJobShop
from src.datamodule import CustomJobShopDataset, JobShopDataModule

In [2]:
# Compose the Hydra config `default.yaml` from the `configs` directory.
with initialize(version_base="1.3", config_path="configs"):
    cfg = compose(config_name="default.yaml")

# Build the data module from the config and take its validation loader.
max_seq_len = cfg.model.trm_model.seq_len
datamodule = JobShopDataModule(
    config=cfg,
    batch_size=cfg.batch_size,
    max_seq_len=max_seq_len,
)
datamodule.setup()
val_dl = datamodule.val_dataloader()

# Restore the trained model from a saved checkpoint, then switch it to
# evaluation mode and freeze it before running inference.
checkpoint_path = "checkpoints/default-1118_1059/run-1118_1059-epoch=331-val_loss=2121.2168.ckpt"
model = TinyRecursiveModelJobShop.load_from_checkpoint(checkpoint_path)
model.eval()
model.freeze()

# Final expression of the cell: display the module tree.
model

Train samples: 1000, Val samples: 301


TinyRecursiveModelJobShop(
  (model): TinyRecursiveReasoningModel_ACTV1(
    (inner): TinyRecursiveReasoningModel_ACTV1_Inner(
      (input_proj): CastedLinear()
      (lm_head): CastedLinear()
      (q_head): CastedLinear()
      (puzzle_emb): CastedSparseEmbedding()
      (rotary_emb): RotaryEmbedding()
      (L_level): TinyRecursiveReasoningModel_ACTV1ReasoningModule(
        (layers): ModuleList(
          (0-1): 2 x TinyRecursiveReasoningModel_ACTV1Block(
            (self_attn): Attention(
              (qkv_proj): CastedLinear()
              (o_proj): CastedLinear()
            )
            (mlp): SwiGLU(
              (gate_up_proj): CastedLinear()
              (down_proj): CastedLinear()
            )
          )
        )
      )
    )
  )
  (loss_fn): MSELoss()
  (act_loss_fn): BCEWithLogitsLoss()
)

In [5]:
# Take the first batch from the validation loader for manual inspection.
batch = next(iter(val_dl))

In [6]:
# List the field names available in the validation batch.
batch.keys()

dict_keys(['inputs', 'labels', 'mask', 'puzzle_identifiers'])

In [None]:
# Move every entry of the batch onto the device the model currently lives on.
device = model.device
batch = {
    field: tensor.to(device)
    for field, tensor in batch.items()
}

In [27]:
# Run the model for the configured maximum number of halting steps without
# tracking gradients, threading the recurrent `carry` through every call.
with torch.no_grad():
    # Targets and mask are read straight from the batch and do not change
    # between steps, so fetch them once instead of on every iteration.
    y_true = batch['labels']
    mask = batch['mask']

    carry = model.model.initial_carry(batch)
    carry = carry.to(model.device)
    for step in range(cfg.model.trm_model.halt_max_steps):
        carry, output = model(carry=carry, batch=batch)
        # Overwritten on purpose each step: only the logits produced by the
        # final step remain in `y_hat` once the loop finishes.
        y_hat = output['logits']

In [65]:
"""
Problem
          task_0   task_1   task_2   task_3   task_4   task_5   task_6   task_7   task_8   task_9  task_10
job_0    (8, 4)  (13, 8)  (10, 4)   (9, 3)  (11, 4)   (9, 2)  (11, 0)   (2, 7)   (7, 0)   (5, 1)   (5, 4)
job_1   (13, 5)   (7, 7)  (12, 0)   (9, 5)  (10, 1)   (1, 0)   (6, 5)  (12, 6)   (8, 9)  (13, 0)   (0, 1)
job_2    (5, 7)  (11, 5)   (4, 5)  (11, 4)  (11, 0)   (5, 5)  (12, 3)   (1, 4)   (6, 0)   (6, 9)   (6, 3)
job_3    (9, 8)  (12, 2)   (1, 6)   (5, 4)  (12, 6)  (11, 3)   (1, 4)   (2, 4)   (9, 5)   (0, 1)  (12, 9)
job_4   (10, 1)   (0, 9)   (5, 8)   (8, 1)  (11, 4)  (11, 6)   (7, 4)  (10, 6)   (5, 8)  (10, 4)   (3, 3)
job_5   (11, 3)   (1, 5)  (12, 3)   (4, 5)   (0, 9)   (1, 6)   (2, 7)   (9, 1)   (8, 7)   (5, 0)  (11, 5)
job_6    (1, 3)   (7, 7)   (4, 0)   (5, 0)   (3, 3)  (13, 2)   (0, 9)   (5, 3)  (10, 2)   (1, 7)  (11, 1)
job_7    (7, 2)   (0, 5)   (3, 0)   (8, 9)   (2, 1)   (9, 4)   (2, 6)   (7, 8)   (8, 8)   (0, 3)   (7, 8)
job_8    (6, 4)  (11, 8)  (10, 4)   (3, 4)   (5, 7)   (0, 9)   (6, 0)   (3, 0)   (9, 4)   (8, 5)  (10, 4)
job_9    (3, 2)   (0, 3)   (8, 4)  (11, 2)   (1, 0)   (2, 4)  (11, 9)   (9, 5)   (9, 3)  (11, 8)   (1, 1)
job_10   (1, 2)  (12, 0)   (2, 1)   (5, 8)   (3, 0)  (10, 1)   (0, 8)  (12, 6)   (6, 5)   (3, 1)  (12, 7)
job_11   (9, 4)   (5, 3)   (6, 5)  (10, 7)   (4, 8)  (11, 3)  (13, 7)  (12, 9)   (8, 5)   (7, 3)  (11, 6)
job_12   (7, 7)   (0, 4)  (12, 3)  (11, 0)   (3, 8)   (9, 3)  (12, 1)  (12, 9)   (2, 6)   (8, 8)   (8, 9)
job_13   (3, 6)  (13, 9)   (2, 3)   (4, 4)   (5, 2)   (9, 0)  (13, 6)   (4, 6)   (9, 4)   (2, 2)   (6, 5)
job_14  (11, 6)   (0, 8)  (10, 5)  (10, 1)   (9, 1)   (6, 8)  (11, 0)  (10, 6)   (8, 7)   (1, 0)  (13, 5)
job_15   (3, 7)  (13, 9)  (13, 3)   (2, 1)   (2, 9)   (5, 4)  (11, 7)   (9, 0)   (5, 1)   (4, 3)   (4, 3)
job_16   (7, 3)   (3, 2)   (1, 6)   (1, 8)   (9, 7)   (1, 3)   (1, 3)   (4, 9)   (6, 8)   (8, 0)  (13, 8)
job_17   (3, 6)  (13, 2)   (4, 1)   (4, 9)   (7, 2)   (3, 0)   (7, 5)   (0, 2)   (7, 8)   (8, 1)   (0, 5)
"""

'\nProblem\n          task_0   task_1   task_2   task_3   task_4   task_5   task_6   task_7   task_8   task_9  task_10\njob_0    (8, 4)  (13, 8)  (10, 4)   (9, 3)  (11, 4)   (9, 2)  (11, 0)   (2, 7)   (7, 0)   (5, 1)   (5, 4)\njob_1   (13, 5)   (7, 7)  (12, 0)   (9, 5)  (10, 1)   (1, 0)   (6, 5)  (12, 6)   (8, 9)  (13, 0)   (0, 1)\njob_2    (5, 7)  (11, 5)   (4, 5)  (11, 4)  (11, 0)   (5, 5)  (12, 3)   (1, 4)   (6, 0)   (6, 9)   (6, 3)\njob_3    (9, 8)  (12, 2)   (1, 6)   (5, 4)  (12, 6)  (11, 3)   (1, 4)   (2, 4)   (9, 5)   (0, 1)  (12, 9)\njob_4   (10, 1)   (0, 9)   (5, 8)   (8, 1)  (11, 4)  (11, 6)   (7, 4)  (10, 6)   (5, 8)  (10, 4)   (3, 3)\njob_5   (11, 3)   (1, 5)  (12, 3)   (4, 5)   (0, 9)   (1, 6)   (2, 7)   (9, 1)   (8, 7)   (5, 0)  (11, 5)\njob_6    (1, 3)   (7, 7)   (4, 0)   (5, 0)   (3, 3)  (13, 2)   (0, 9)   (5, 3)  (10, 2)   (1, 7)  (11, 1)\njob_7    (7, 2)   (0, 5)   (3, 0)   (8, 9)   (2, 1)   (9, 4)   (2, 6)   (7, 8)   (8, 8)   (0, 3)   (7, 8)\njob_8    (6, 4)  (11, 8)

In [56]:
# View the first sample's labels as an 18 x 11 grid. Only the first 18 * 11
# entries of the sequence are kept; anything after them is dropped.
# NOTE(review): presumably 18 jobs x 11 tasks — confirm against the dataset.
n_rows, n_cols = 18, 11
labels_first = batch['labels'][0]
labels_first[:n_rows * n_cols].reshape(n_rows, n_cols)

tensor([[ 0.,  5., 13., 22., 25., 56., 68., 71., 78., 82., 83.],
        [ 0.,  5., 12., 12., 17., 18., 18., 23., 29., 84., 84.],
        [ 2., 18., 34., 43., 47., 48., 53., 62., 66., 66., 80.],
        [ 4., 16., 18., 28., 37., 47., 50., 60., 64., 70., 71.],
        [ 0., 10., 20., 28., 29., 37., 55., 67., 73., 81., 85.],
        [ 0.,  3.,  8., 11., 23., 32., 38., 45., 46., 68., 68.],
        [15., 38., 48., 48., 50., 55., 61., 70., 73., 75., 87.],
        [ 0.,  2.,  7.,  7., 16., 17., 21., 30., 38., 77., 80.],
        [ 0.,  4., 25., 29., 33., 40., 49., 50., 52., 58., 63.],
        [ 0.,  7., 19., 23., 32., 51., 59., 69., 74., 79., 87.],
        [ 0., 11., 11., 12., 50., 50., 51., 65., 75., 80., 81.],
        [ 0.,  9., 13., 18., 25., 33., 36., 43., 53., 70., 73.],
        [12., 19., 30., 33., 33., 49., 52., 56., 65., 71., 79.],
        [14., 46., 55., 58., 62., 64., 64., 70., 77., 81., 83.],
        [12., 32., 40., 45., 48., 49., 57., 57., 63., 70., 79.],
        [ 7., 16., 25., 2

In [61]:
# View the first sample's rounded predictions as an 18 x 11 grid. Only the
# first 18 * 11 entries of the sequence are kept.
# NOTE(review): presumably 18 jobs x 11 tasks — confirm against the dataset.
n_rows, n_cols = 18, 11
pred_first = y_hat[0].round()
pred_first[:n_rows * n_cols].reshape(n_rows, n_cols)

tensor([[  3.,   5.,  14.,  21.,  25.,  34.,  46.,  53.,  62.,  64.,  59.],
        [ 56.,  52.,  53.,  49.,  50.,  49.,  44.,  43.,  42.,  41.,  42.],
        [ 44.,  52.,  61.,  67.,  67.,  60.,  56.,  52.,  49.,  44.,  46.],
        [ 47.,  49.,  49.,  49.,  50.,  50.,  50.,  51.,  53.,  53.,  51.],
        [ 52.,  48.,  49.,  55.,  61.,  64.,  69.,  70.,  67.,  63.,  54.],
        [ 46.,  41.,  40.,  41.,  43.,  46.,  50.,  57.,  63.,  71.,  78.],
        [ 78.,  76.,  66.,  47.,  36.,  32.,  28.,  32.,  41.,  48.,  52.],
        [ 52.,  48.,  50.,  48.,  52.,  55.,  56.,  46.,  39.,  37.,  35.],
        [ 38.,  44.,  31.,  31.,  36.,  41.,  46.,  51.,  56.,  59.,  60.],
        [ 63.,  66.,  59.,  52.,  41.,  36.,  37.,  40.,  47.,  47.,  51.],
        [ 49.,  53.,  52.,  48.,  47.,  48.,  45.,  44.,  44.,  41.,  37.],
        [ 38.,  40.,  47.,  50.,  49.,  50.,  46.,  44.,  44.,  45.,  48.],
        [ 53.,  55.,  58.,  54.,  45.,  41.,  42.,  45.,  50.,  55.,  54.],
        [ 49

In [80]:
# Load one v1 validation instance and its solution from disk.
# Both files share the same case name, so the two paths are built from a single
# string and cannot drift apart.
# `weights_only=True` restricts unpickling to tensors and plain containers
# instead of running arbitrary pickled code.
case_v1 = "case_0_m14_j18_t11.pt"
prob01 = torch.load(f"data/v1/val/problems/{case_v1}", weights_only=True)
sol01 = torch.load(f"data/v1/val/solutions/{case_v1}", weights_only=True)

In [None]:
# Display the loaded v1 problem.
# NOTE(review): the saved output for this cell is a torch.Size, so it appears to
# predate the current code (the cell has no execution count) — re-run to refresh.
prob01

torch.Size([18, 11, 2])

In [83]:
# Drop the trailing size-1 axis so the v1 solution prints as a 2-D grid.
sol01.squeeze(dim=-1)

tensor([[ 0,  5, 13, 22, 25, 56, 68, 71, 78, 82, 83],
        [ 0,  5, 12, 12, 17, 18, 18, 23, 29, 84, 84],
        [ 2, 18, 34, 43, 47, 48, 53, 62, 66, 66, 80],
        [ 4, 16, 18, 28, 37, 47, 50, 60, 64, 70, 71],
        [ 0, 10, 20, 28, 29, 37, 55, 67, 73, 81, 85],
        [ 0,  3,  8, 11, 23, 32, 38, 45, 46, 68, 68],
        [15, 38, 48, 48, 50, 55, 61, 70, 73, 75, 87],
        [ 0,  2,  7,  7, 16, 17, 21, 30, 38, 77, 80],
        [ 0,  4, 25, 29, 33, 40, 49, 50, 52, 58, 63],
        [ 0,  7, 19, 23, 32, 51, 59, 69, 74, 79, 87],
        [ 0, 11, 11, 12, 50, 50, 51, 65, 75, 80, 81],
        [ 0,  9, 13, 18, 25, 33, 36, 43, 53, 70, 73],
        [12, 19, 30, 33, 33, 49, 52, 56, 65, 71, 79],
        [14, 46, 55, 58, 62, 64, 64, 70, 77, 81, 83],
        [12, 32, 40, 45, 48, 49, 57, 57, 63, 70, 79],
        [ 7, 16, 25, 28, 29, 44, 50, 81, 81, 82, 85],
        [ 2,  5,  9, 24, 32, 39, 42, 48, 57, 71, 71],
        [23, 29, 33, 39, 48, 50, 50, 59, 61, 70, 72]])

## New Version of Data

In [86]:
# Load one v2 training instance and its solution from disk.
# Both files share the same case name, so the two paths are built from a single
# string and cannot drift apart.
# `weights_only=True` restricts unpickling to tensors and plain containers
# instead of running arbitrary pickled code.
case_v2 = "case_0_m6_j7_t17.pt"
prob02 = torch.load(f"data/v2/train/1000/problems/{case_v2}", weights_only=True)
sol02 = torch.load(f"data/v2/train/1000/solutions/{case_v2}", weights_only=True)

In [None]:
# Show the dimensions of the loaded v2 problem.
# NOTE(review): the saved output for this cell is a full tensor dump rather than
# a shape, so it appears to predate the current code (the cell has no execution
# count) — re-run to refresh.
prob02.shape

tensor([[[ 0.,  0.,  4.,  1.],
         [ 0.,  1.,  5.,  2.],
         [ 0.,  2.,  5.,  5.],
         [ 0.,  3.,  0.,  5.],
         [ 0.,  4.,  1.,  6.],
         [ 0.,  5.,  3.,  3.],
         [ 0.,  6.,  1.,  3.],
         [ 0.,  7.,  4.,  8.],
         [ 0.,  8.,  5.,  6.],
         [ 0.,  9.,  3.,  9.],
         [ 0., 10.,  3.,  6.],
         [ 0., 11.,  1.,  5.],
         [ 0., 12.,  4.,  6.],
         [ 0., 13.,  2.,  9.],
         [ 0., 14.,  5.,  3.],
         [ 0., 15.,  5.,  9.],
         [ 0., 16.,  4.,  6.]],

        [[ 1.,  0.,  2.,  4.],
         [ 1.,  1.,  0.,  7.],
         [ 1.,  2.,  4.,  5.],
         [ 1.,  3.,  5., 10.],
         [ 1.,  4.,  0.,  9.],
         [ 1.,  5.,  2.,  5.],
         [ 1.,  6.,  0.,  1.],
         [ 1.,  7.,  5.,  2.],
         [ 1.,  8.,  1.,  5.],
         [ 1.,  9.,  1.,  1.],
         [ 1., 10.,  0.,  3.],
         [ 1., 11.,  1.,  4.],
         [ 1., 12.,  0.,  7.],
         [ 1., 13.,  2.,  6.],
         [ 1., 14.,  2.,  6.],
      

In [96]:
# Drop the trailing size-1 axis so the v2 solution prints as a 2-D grid.
sol02.squeeze(dim=-1)

tensor([[ 10,  26,  61,  66,  71,  77,  84,  87,  95, 101, 110, 116, 121, 127,
         136, 139, 149],
        [ 12,  16,  23,  46,  57,  69,  74,  77,  79,  97,  98, 101, 105, 115,
         121, 127, 148],
        [  0,   9,  23,  28,  36,  42,  46,  53,  59,  69,  79,  87,  97, 101,
         105, 111, 117],
        [  0,   1,   9,  15,  16,  41,  42,  52,  57,  71,  86,  93,  94, 107,
         117, 135, 145],
        [  0,  18,  25,  33,  42,  66,  69,  71,  77,  82,  86,  90, 100, 107,
         120, 130, 140],
        [  0,   2,  34,  42,  46,  50,  57,  75,  85,  88, 112, 128, 130, 134,
         136, 143, 150],
        [ 16,  25,  26,  32,  36,  47,  56,  61,  74,  82,  87,  95,  99, 121,
         134, 137, 149]])