In [None]:
"""
Helper notebook for debugging and visually evaluating TimTamNet model quality
on a random sample of test data.
"""
from os.path import join
import numpy as np
from skimage.io import imread
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils import get_data
from timtamnet import TimTamNet

In [None]:
# Constants.
# All paths are resolved against base_directory, presumably the repository
# root two directories above this notebook (confirm if the notebook moves).
base_directory = "../../"

# Checkpoint of the trained TimTamNet weights to evaluate.
model_file = "out/models/weights.06-0.501515.hdf5"
# Data directory handed to get_data() to obtain the test split.
robot_data = "out/data/robot--4209387126734636757/"

model_path, robot_directory = (join(base_directory, sub) for sub in (model_file, robot_data))

In [None]:
# Get test data. get_data() yields four values; only the second (test inputs)
# and fourth (test targets) are needed to evaluate the model.
_, x_test, _, y_test = get_data(robot_directory)

# Per-sample shape (leading batch axis dropped) and its spatial extent (the
# trailing axis dropped as well, presumably a single channel), used later to
# reshape each sample into a 2-D image for display.
input_shape, image_size = x_test.shape[1:], x_test.shape[1:-1]

In [None]:
# Load TimTamNet and weights.
# The network is built for the per-sample shape of the test inputs, then its
# weights are restored from the checkpoint at model_path.
model = TimTamNet(input_shape=input_shape)
model.load_weights(model_path)

In [None]:
# Constants.
num_samples = 6  # test examples to visualise; each becomes one figure column
num_rows = 3  # one row each for input, ground truth and prediction
inch_size = 2.7  # size in inches of each subplot in the figure
random_seed = None  # None reseeds from OS entropy (new picks each run); set an int to reproduce

# Sanity-check the test split, and never request more samples than it holds:
# a larger num_samples would silently yield fewer rows below and make the
# plotting loop index past the end of the sampled arrays.
assert len(x_test) > 0, "test set is empty"
assert len(x_test) == len(y_test), "test inputs and targets differ in length"
num_samples = min(num_samples, len(x_test))
num_cols = num_samples

# Get random sample from test data (distinct indices, no replacement).
np.random.seed(random_seed)
indices = np.random.permutation(len(x_test))[:num_samples]
original = x_test[indices]
truth = y_test[indices]

# Predict through model.
prediction = model.predict(original)

In [None]:
# Visually evaluate results: one column per sampled test example, with the
# input, the ground truth and the model prediction stacked top to bottom.
# Every panel uses the same fixed grey scale (vmin=0, vmax=1) so the rows are
# directly comparable; values outside [0, 1] saturate at black/white. Pinning
# the range avoids the misleading contrast matplotlib's per-image auto-scaling
# produced when a sample spans a narrow value range (e.g. when the robot finds
# a path to the goal without invalid footsteps).
row_images = (
    ("Input", original),
    ("Ground truth", truth),
    ("Prediction", prediction),
)
# Derive the column count from the data actually sampled, so this cell can
# never index past the end of the arrays.
num_shown = len(original)
assert num_shown > 0, "no samples to plot"

fig, axes = plt.subplots(
    len(row_images),
    num_shown,
    figsize=(num_shown * inch_size, len(row_images) * inch_size),
    squeeze=False,  # keep axes 2-D even when only one sample is shown
)
for row, (label, images) in enumerate(row_images):
    for col in range(num_shown):
        ax = axes[row, col]
        # cmap is passed per panel instead of plt.gray() to avoid changing
        # the global default colormap for the rest of the session.
        ax.imshow(images[col].reshape(*image_size), cmap="gray", vmin=0, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
    # Label each row once, on its leftmost panel.
    axes[row, 0].set_ylabel(label)
plt.show()