In [1]:
# Import necessary modules
import sys
import os
import torch

# Set the working directory to the project root: the parent of the directory the kernel
# started in (presumably the notebook's folder). The root is computed once and cached in
# PROJECT_ROOT so this cell is idempotent. The original unconditional
# os.chdir(os.path.dirname(os.getcwd())) climbed one more directory on every re-run,
# which broke the `src.*` imports below.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.dirname(os.getcwd())
os.chdir(PROJECT_ROOT)

# Add the project root to the import path so the `src` package resolves;
# guard against appending a duplicate entry each time the cell is re-run.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src.utils.config import Hyperparameters
from src.utils.model_loading import load_model
from src.utils.maze_loading import load_mazes
from src.utils.testing import is_correct
from src.utils.analysis import plot_mazes

In [None]:
# Load models
dt_net = load_model(pretrained='models/dt_net/original.pth')
pi_net = load_model(pretrained='models/pi_net/original.pth')

# Load mazes
hyperparams = Hyperparameters()
hyperparams.iters = 300
hyperparams.num_mazes = 10
# NOTE(review): the saved output of this cell logs "percolation: 0.3", but this line sets 0.0,
# so the stored output is stale. Re-run the cell so the printed results match the code.
hyperparams.percolation = 0.0
hyperparams.maze_size = 9
# TODO(review): load_mazes appears to generate mazes (its log reports "Attempting N mazes");
# seed whichever RNG it uses so the maze set is reproducible across runs.
inputs, solutions = load_mazes(hyperparams)

# Predict. Gradients are not needed for evaluation, so disable autograd to avoid building
# a graph across all iterations (harmless if predict already does this internally).
with torch.no_grad():
    dt_predictions = dt_net.predict(inputs, iters=hyperparams.iters)
    pi_predictions = pi_net.predict(inputs, iters=hyperparams.iters)

# Evaluate predictions
dt_net_corrects = is_correct(inputs, dt_predictions, solutions)
pi_net_corrects = is_correct(inputs, pi_predictions, solutions)
# Report against the number of mazes actually scored rather than the number requested
# (hyperparams.num_mazes). Assumes is_correct returns one entry per maze — TODO confirm.
for label, corrects in (('DT Net', dt_net_corrects), ('PI Net', pi_net_corrects)):
    print(f'{label} solved {corrects.sum()}/{len(corrects)} mazes correctly')

# Plot results
plot_mazes([('Input', inputs), ('Solution', solutions), ('DT-Net', dt_predictions), ('PI-Net', pi_predictions)])


2025-05-14 21:10:12,551 - src.utils.model_loading - INFO - Loaded model: dt_net from models/dt_net/original.pth to device: cuda:1
2025-05-14 21:10:12,580 - src.utils.model_loading - INFO - Loaded pi_net from models/pi_net/original.pth to device: cuda:1
2025-05-14 21:10:12,581 - src.utils.maze_loading - INFO - Attempting 10 mazes to generate 10 mazes with size: 9, percolation: 0.3, and deadend_start: True
2025-05-14 21:10:12,588 - src.utils.maze_loading - INFO - Attempting 20 mazes to generate 10 mazes with size: 9, percolation: 0.3, and deadend_start: True
2025-05-14 21:10:12,607 - src.utils.maze_loading - INFO - Loaded 10 mazes with size: 9, percolation: 0.3, and deadend_start: True


DT Net solved 0/10 mazes correctly
PI Net solved 0/10 mazes correctly
