In [1]:
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import numpy as np
import os
import random
from model import UNetSmall
from load_dataset import PATH_TEST, PATH_TEST_T32, load_all_pairs
from ipywidgets import interact

In [10]:
# Load the trained one-step model and the (cached) test split.
# map_location="cpu": the checkpoint may have been written from a GPU run; without a
# map location that fails on CPU-only machines. load_state_dict copies the weights
# into the freshly built model, so where they were deserialised does not matter.
# weights_only=True: only tensors / plain containers are unpickled, so a tampered
# .pt file cannot execute arbitrary code (requires torch >= 1.13).
unet_model = UNetSmall()
unet_model.load_state_dict(
    torch.load("gol_unet_best.pt", map_location="cpu", weights_only=True)
)
unet_model.eval()  # inference mode for the plots below

DATASET_CACHE = "dataset_test.pt"  # cached (X, Y) tensors, built on first run
if os.path.exists(DATASET_CACHE):
    X, Y = torch.load(DATASET_CACHE, weights_only=True)
else:
    X, Y = load_all_pairs(PATH_TEST)
    # Insert a channel axis: (num_samples, H, W) -> (num_samples, 1, H, W), float32.
    X = torch.tensor(X[:, None, :, :].astype("float32"))
    Y = torch.tensor(Y[:, None, :, :].astype("float32"))
    torch.save((X, Y), DATASET_CACHE)
test_ds = TensorDataset(X, Y)
print("Model & dataset loaded")

Model & dataset loaded


In [11]:
# Draw a fixed, reproducible set of test samples for the interactive viewer.
# A dedicated seeded RNG keeps the selection stable across Restart & Run All
# without touching the global `random` state.
SAMPLE_SEED = 0
N = min(10, len(test_ds))  # random.sample cannot draw more items than exist
random_indices = random.Random(SAMPLE_SEED).sample(range(len(test_ds)), N)


def plot_selected_sample(n):
    """Plot the model's prediction for one of the pre-drawn random test samples.

    Intended as the ``interact`` slider callback: ``n`` selects an entry of
    ``random_indices`` and the corresponding ``test_ds`` pair is visualised.

    Args:
        n: Slider position, ``0 <= n < len(random_indices)``.

    Shows a 3x2 figure (input frame, binarised prediction, true next frame,
    colour-coded error map with G = TP / R = FP / B = FN, FP mask, FN mask),
    then prints pixel accuracy and the IoU of the cells equal to 1.
    """
    i = random_indices[n]
    x_frame, y_true = test_ds[i]

    # Board size from the trailing two dims so non-square boards work; fall back
    # to the square-board assumption if a flattened frame is ever stored.
    if x_frame.dim() >= 2:
        height, width = x_frame.shape[-2], x_frame.shape[-1]
    else:
        height = width = int(x_frame.numel() ** 0.5)

    device = next(unet_model.parameters()).device
    x_img = x_frame.reshape(1, 1, height, width).to(device)

    with torch.no_grad():
        # Model output is treated as logits: sigmoid, then threshold at 0.5.
        y_pred_prob = torch.sigmoid(unet_model(x_img))
        y_pred_binary = (y_pred_prob > 0.5).float()

    # reshape (not squeeze) so a size-1 height or width is not collapsed too.
    x_plot = x_img.reshape(height, width).cpu()
    y_pred_plot = y_pred_binary.reshape(height, width).cpu()
    y_true_plot = y_true.reshape(height, width).cpu()

    # Confusion masks, treating 1 as the positive class.
    tp = (y_pred_plot == 1) & (y_true_plot == 1)
    fp = (y_pred_plot == 1) & (y_true_plot == 0)
    fn = (y_pred_plot == 0) & (y_true_plot == 1)

    # Channel-last RGB error map: green = TP, red = FP, blue = FN, black otherwise.
    rgb = torch.zeros(3, height, width)
    rgb[1][tp] = 1.0
    rgb[0][fp] = 1.0
    rgb[2][fn] = 1.0
    error_map = rgb.permute(1, 2, 0).numpy()

    pixel_acc = (y_pred_plot == y_true_plot).float().mean().item()
    union = int(tp.sum() + fp.sum() + fn.sum())
    # Prediction and target both empty is a perfect match: report 1.0
    # instead of 0 / eps == 0.
    iou = int(tp.sum()) / union if union else 1.0

    fig, axes = plt.subplots(3, 2, figsize=(16, 16))

    # Pin the colour range of binary panels: with autoscaling, a panel that is
    # uniformly 1 gets vmin == vmax, normalises to 0 and renders all black.
    binary = {"vmin": 0, "vmax": 1}
    panels = (
        (axes[0, 0], x_plot, {"cmap": "gray"}, f"Input Frame t (sample {i})"),
        (axes[0, 1], y_pred_plot, {"cmap": "gray", **binary}, "Predicted Frame t+1"),
        (axes[1, 0], y_true_plot, {"cmap": "gray", **binary}, "True Frame t+1"),
        (axes[1, 1], error_map, {}, "Error Map\nG=TP  R=FP  B=FN"),
        (
            axes[2, 0],
            fp.float(),
            {"cmap": ListedColormap(["black", "red"]), **binary},
            "False Positives",
        ),
        (
            axes[2, 1],
            fn.float(),
            {"cmap": ListedColormap(["black", "blue"]), **binary},
            "False Negatives",
        ),
    )
    for ax, image, imshow_kwargs, title in panels:
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    fig.tight_layout()
    plt.show()

    print(f"Pixel Accuracy: {pixel_acc:.4f}")

    print(f"IoU (Intersection over Union): {iou:.4f}")


# Slider over the pre-drawn samples. The trailing ';' suppresses the echoed
# function repr returned by `interact`; only the widget is displayed.
interact(plot_selected_sample, n=(0, N - 1));

interactive(children=(IntSlider(value=4, description='n', max=9), Output()), _dom_classes=('widget-interact',)…

<function __main__.plot_selected_sample(n)>

#### Now let's test the model when 32 steps are required
Simply put: given the n-th frame, predict the (n+32)-th frame.