In [1]:
# Import necessary modules

import sys
import os

# Set root folder to project root
os.chdir(os.path.dirname(os.getcwd()))

# Add root folder to path
sys.path.append(os.getcwd())

import numpy as np
import torch
from matplotlib import pyplot as plt

from src.utils.testing import compare_mazes
from src.utils.loading import load_model, get_mazes
from src.utils.plotting import plot_mazes

In [2]:
# Load model and mazes

model = load_model('dt_net')

inputs, solutions = get_mazes(
    dataset='maze-dataset', 
    maze_size=9, 
    num_mazes=30,
    percolation=0.0,
    deadend_start=True)

Using device: cuda
Loaded pi_net to cuda


In [3]:
# Compute predictions and compare to solutions

predictions = torch.zeros_like(solutions)
for i in range(inputs.shape[0]):
    predictions[i:i+1] = model.predict(inputs[i:i+1], iters=300)

corrects = torch.tensor(compare_mazes(predictions, solutions), dtype=torch.bool)
incorrects = ~corrects

In [4]:
# Plot incorrect predictions

if incorrects.any():
    plot_mazes(inputs[incorrects], 
            predictions=predictions[incorrects], 
            solutions=solutions[incorrects], 
            file_name=f'outputs/mazes/{model.name()}_incorrects.pdf')
else:
    print('No incorrect predictions found.')

No incorrect predictions found.


In [5]:
incorrect_inputs = inputs[incorrects].cpu().numpy()
incorrect_inputs = np.moveaxis(incorrect_inputs, 1, -1) # Move RGB axis to last

plt.imshow(incorrect_inputs[0], cmap='gray')

IndexError: index 0 is out of bounds for axis 0 with size 0