In [1]:
import matplotlib.pyplot as plt
import random
import torch

from data_utils import ucmerced_prepare_data
from train_utils import convnode_task

In [2]:
def visualize_predictions(model, dataloader, num_samples=25, class_names=None,
                          mean=(0.4914, 0.4822, 0.4465),
                          std=(0.2023, 0.1994, 0.2010), seed=None):
    """
    Visualize ground truth vs. predicted labels for random samples from the dataloader.

    Every batch of ``dataloader`` is run through ``model`` (without gradients),
    ``num_samples`` examples are drawn at random, and they are plotted in a
    near-square grid titled with the ground-truth and predicted label — green
    when they match, red otherwise.

    Args:
        model (torch.nn.Module): Trained PyTorch model. It is put in eval mode
            (and left there); inputs are moved to the device of its first
            parameter, or kept on CPU if it has none.
        dataloader (DataLoader): Iterable of ``(images, labels)`` batches.
            Images are assumed to be normalized ``(N, C, H, W)`` float tensors
            with 3 channels — TODO confirm against ``ucmerced_prepare_data``.
        num_samples (int): Number of random examples to visualize. Values
            larger than the dataset size are clamped to it.
        class_names (Sequence[str] | None): Optional index -> name lookup for
            the titles; integer class indices are shown when omitted.
        mean (Sequence[float]): Per-channel mean used to undo normalization.
        std (Sequence[float]): Per-channel std used to undo normalization.
            The defaults are the values previously hard-coded here, which are
            the widely used CIFAR-10 statistics (ImageNet's are
            (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225)); pass whatever the
            dataset's ``Normalize`` transform actually used.
        seed (int | None): Seed for picking the samples so the figure is
            reproducible; ``None`` uses the global ``random`` state.

    Raises:
        ValueError: If ``num_samples`` < 1 or the dataloader yields no samples.
    """
    import math  # local import keeps this cell self-contained

    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()  # disable dropout / use running batch-norm statistics
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameter-less model: run on CPU
        device = torch.device("cpu")

    all_images, all_preds, all_labels = [], [], []
    with torch.no_grad():
        for images, labels in dataloader:
            logits = model(images.to(device))
            preds = logits.argmax(dim=1)
            # Keep everything on CPU so plotting works regardless of `device`.
            all_images.extend(images.cpu())
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if not all_images:
        raise ValueError("dataloader yielded no samples to visualize")

    # Sample without replacement; clamp so small datasets don't raise.
    n = min(num_samples, len(all_images))
    rng = random.Random(seed) if seed is not None else random
    indices = rng.sample(range(len(all_images)), n)

    # Near-square grid sized to n (5x5 at the default of 25 samples).
    ncols = math.ceil(math.sqrt(n))
    nrows = math.ceil(n / ncols)
    _, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    axes = axes.flatten()

    mean_t = torch.as_tensor(mean, dtype=torch.float32)
    std_t = torch.as_tensor(std, dtype=torch.float32)

    for ax, idx in zip(axes, indices):
        img = all_images[idx].permute(1, 2, 0)  # CHW -> HWC for imshow
        # Undo normalization, then clamp: out-of-range values would otherwise
        # wrap around (or be undefined) in the uint8 cast and show up as noise.
        img = (img * std_t + mean_t).clamp(0.0, 1.0)
        ax.imshow((img * 255.0).numpy().astype("uint8"))
        ax.axis("off")

        gt, pred = all_labels[idx], all_preds[idx]
        gt_text = class_names[gt] if class_names is not None else gt
        pred_text = class_names[pred] if class_names is not None else pred
        ax.set_title(f"GT: {gt_text}\nPred: {pred_text}",
                     color="green" if gt == pred else "red")

    # Blank out grid cells left over when n does not fill the grid exactly.
    for ax in axes[n:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [3]:
# Load the UC Merced train/validation DataLoaders. `labels` is presumably the list of
# class names printed below as "ALL LABELS" — confirm against ucmerced_prepare_data.
train_loader, val_loader, labels = ucmerced_prepare_data.prepare_data()

ALL LABELS
['agricultural', 'airplane', 'baseballdiamond', 'beach', 'buildings', 'chaparral', 'denseresidential', 'forest', 'freeway', 'golfcourse', 'harbor', 'intersection', 'mediumresidential', 'mobilehomepark', 'overpass', 'parkinglot', 'river', 'runway', 'sparseresidential', 'storagetanks', 'tenniscourt']


In [4]:
# Derive the class count from the loaded label list (21 for UC Merced) instead of
# hard-coding it, so the classifier head stays in sync with the dataset.
# Assumes `labels` from the data-loading cell is the list of class names.
model = convnode_task.UCMERCDNeuralODE(num_classes=len(labels))

In [5]:
# TODO: no training step is visible above, so `model` is still untrained at this point.
# Once it has been trained, visualize validation predictions with:
# visualize_predictions(model, val_loader)