In [1]:
# Import necessary modules
import sys
import os
import torch 
# Set the working directory to the project root, i.e. the parent of the directory
# the kernel was started in (presumably this notebook's folder).
# The root is captured only on the first run in a kernel: calling
# os.chdir(os.path.dirname(os.getcwd())) unconditionally would move one more
# level up every time this cell is re-executed.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.dirname(os.getcwd())
os.chdir(PROJECT_ROOT)

# Add the project root to the import path so `src.*` resolves; skip it if it is
# already present so re-runs don't accumulate duplicate entries.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

from src.utils.config import Hyperparameters
from src.utils.model_loading import load_model, summarize_models
from src.utils.maze_loading import load_mazes
from src.utils.testing import is_correct
from src.utils.analysis import plot_mazes

In [2]:
# Print a table of parameter count (millions) and size (MB) for each model.
# Per the logged output below, this loads each model's checkpoint onto the device.
summarize_models()

2025-06-12 14:12:36,897 - src.utils.model_loading - INFO - Loaded model: ff_net from models/ff_net/2025-04-21_00:26:24/best.pth to device: cuda:0
2025-06-12 14:12:36,935 - src.utils.model_loading - INFO - Loaded pi_net from models/pi_net/original.pth to device: cuda:0
2025-06-12 14:12:36,947 - src.utils.model_loading - INFO - Loaded model: dt_net from models/dt_net/original.pth to device: cuda:0
2025-06-12 14:12:36,959 - src.utils.model_loading - INFO - Loaded model: it_net from models/it_net/2025-03-05_13:43:46/best.pth to device: cuda:0


Model   Params (M)  Size (MB)
-----------------------------
ff_net  8.89        33.91    
pi_net  0.78        2.99     
dt_net  0.78        2.99     
it_net  1.37        5.24     


In [None]:
# Load models
dt_net = load_model(pretrained='models/dt_net/original.pth')
pi_net = load_model(pretrained='models/pi_net/original.pth')
it_net = load_model(pretrained='models/it_net/2025-03-27_16:16:36/best.pth')  # Trained with percolation 0.0

# Load mazes
hyperparams = Hyperparameters()
hyperparams.iters = 300
hyperparams.num_mazes = 20
hyperparams.percolation = 0.2  # 0.5 has good examples
hyperparams.maze_size = 11
hyperparams.dataset_name = 'easy-to-hard-data'
inputs, solutions = load_mazes(hyperparams)
print(f'{inputs.shape = }')
print(f'{solutions.shape = }')

# Fail fast with an actionable message if no mazes were loaded, instead of letting
# an empty tensor reach the models. NOTE(review): the saved output of this cell ends
# in "RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got
# input of size: [0]" right after loading, and the download log reports
# "Downloaded 0.00 GB" — TODO confirm whether the download or this
# dataset/percolation combination is the cause.
num_loaded = len(inputs)
if num_loaded == 0:
    raise ValueError(
        f'load_mazes returned no mazes (dataset_name={hyperparams.dataset_name!r}, '
        f'maze_size={hyperparams.maze_size}, percolation={hyperparams.percolation}); '
        'check these hyperparameters and the downloaded data'
    )

# Predict
dt_predictions = dt_net.predict(inputs, iters=hyperparams.iters)
pi_predictions = pi_net.predict(inputs, iters=hyperparams.iters)
it_predictions = it_net.predict(inputs, iters=hyperparams.iters)

# Evaluate predictions. The denominator is the number of mazes actually loaded,
# which may differ from the requested hyperparams.num_mazes; int() keeps tensor
# reprs (e.g. device info) out of the message.
dt_net_corrects = is_correct(inputs, dt_predictions, solutions)
pi_net_corrects = is_correct(inputs, pi_predictions, solutions)
it_net_corrects = is_correct(inputs, it_predictions, solutions)
for label, corrects in [
    ('DT Net', dt_net_corrects),
    ('PI Net', pi_net_corrects),
    ('IT Net', it_net_corrects),
]:
    print(f'{label} solved {int(corrects.sum())}/{num_loaded} mazes correctly')

# Plot results
plot_mazes([
    ('Input', inputs),
    ('Solution', solutions),
    ('DT-Net', dt_predictions),
    ('PI-Net', pi_predictions),
    ('IT-Net', it_predictions)
]);

2025-06-12 14:12:36,979 - src.utils.model_loading - INFO - Loaded model: dt_net from models/dt_net/original.pth to device: cuda:0


2025-06-12 14:12:37,006 - src.utils.model_loading - INFO - Loaded pi_net from models/pi_net/original.pth to device: cuda:0
2025-06-12 14:12:37,019 - src.utils.model_loading - INFO - Loaded model: it_net from models/it_net/2025-03-27_16:16:36/best.pth to device: cuda:0


Downloading https://cs.umd.edu/~tomg/download/Easy_to_Hard_Datav2/maze_data_test_11.tar.gz


Downloaded 0.00 GB: 100%|██████████| 4/4 [00:00<00:00, 13.96it/s]


Loading mazes of size 11 x 11.


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [0]