# Comparative Inference Analysis

In [None]:

from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader

from dnn_guidance.data_loader import PathfindingDataset
from dnn_guidance.inference import InferenceHandler
from dnn_guidance.model import UNetFiLM, HRFiLMNet, ResNetFPNFiLM
from utils.visualization import plot_inference_comparison


In [None]:

# Models under comparison, keyed by display name. Each value is a
# (model class, checkpoint path) pair; insertion order is the order in which
# the models are run and collected below.
MODEL_REGISTRY = {
    display_name: (model_cls, Path(ckpt_path))
    for display_name, model_cls, ckpt_path in (
        ('UNet-FiLM', UNetFiLM, 'models/dnn_guidance/unet_film.pth'),
        ('HR-FiLM-Net', HRFiLMNet, 'models/dnn_guidance/hr_film_net.pth'),
        ('ResNet-FPN-FiLM', ResNetFPNFiLM, 'models/dnn_guidance/resnet_fpn_film.pth'),
    )
}


In [None]:

# Test-split inputs and their ground-truth heatmaps, served one sample per
# batch in a fixed order with augmentation disabled.
samples_dir, gt_dir = (
    Path('data/processed/test/samples'),
    Path('data/processed/test/heatmaps'),
)
dataset = PathfindingDataset(samples_dir, gt_dir, augment=False)
loader = DataLoader(dataset, shuffle=False, batch_size=1)


In [None]:

from itertools import islice

# How many test samples to visualize. The loader is unshuffled, so these are
# always the first NUM_SAMPLES samples of the test split.
NUM_SAMPLES = 1


def _decode_grid(g):
    """Collapse one sample's multi-channel grid into a single labelled map.

    ``g`` is the sample's grid tensor as a NumPy array, indexed channel-first.
    Channels 0, 1 and 3 are thresholded at 0.5 and treated as the start, goal
    and obstacle masks respectively (channel 2 is not used here). The result
    is a ``uint8`` array with 0 for other cells, 1 for obstacles, 8 for start
    and 9 for goal. Masks are applied in that order, so start overrides
    obstacle and goal overrides both where they overlap.
    """
    start = g[0] > 0.5
    goal = g[1] > 0.5
    obstacles = g[3] > 0.5
    grid = np.zeros_like(obstacles, dtype=np.uint8)
    grid[obstacles] = 1
    grid[start] = 8
    grid[goal] = 9
    return grid


# Load every model's checkpoint once up front. Constructing the handlers
# inside the data loop would reload each checkpoint for every sample.
handlers = {
    name: InferenceHandler(cls, ckpt, device='cpu')
    for name, (cls, ckpt) in MODEL_REGISTRY.items()
}

for (grid_tensor, robot_tensor), heatmap_tensor in islice(loader, NUM_SAMPLES):
    # The loader uses batch_size=1, so squeeze(0) drops the batch dimension.
    grid = _decode_grid(grid_tensor.squeeze(0).numpy())
    robot = robot_tensor.squeeze(0).numpy()
    gt_heatmap = heatmap_tensor.squeeze().numpy()
    predictions_to_plot = {
        name: handler.predict(grid, robot) for name, handler in handlers.items()
    }
    plot_inference_comparison({
        'input_grid': grid,
        'ground_truth_heatmap': gt_heatmap,
        'predictions': predictions_to_plot,
    })


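Analysis notes go here: compare each model's predicted heatmap against the ground-truth heatmap for the plotted test sample.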