### Set up the model and define helper functions

In [None]:
import os
from typing import Optional, Type

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Subset
from torchvision.transforms import v2 as transforms

from data.dataset_visualizer import ImageDatasetWithLabels
from models.centernet import ModelBuilder
from postprocess_visual.object_detection_visualizer import (
    ObjectDetectionVisualizer,
)


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    Device = torch.device("cuda")
else:
    Device = torch.device("cpu")


def load_model(
    model_type: Type[nn.Module],
    checkpoint_path: Optional[str] = None,
    *,
    alpha: float = 0.25,
) -> nn.Module:
    """Instantiate ``model_type`` and load its weights from a checkpoint.

    Args:
        model_type: Model class to build; it is called as
            ``model_type(alpha=alpha)``.
        checkpoint_path: Path to a saved state dict. When ``None``, the
            default ``../models/checkpoints/pretrained_weights.pt`` is used.
        alpha: Value forwarded to the model constructor's ``alpha`` keyword.
            Defaults to 0.25, the value that used to be hard-coded here.

    Returns:
        The model, moved to ``Device``, with the checkpoint's state dict
        loaded and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if checkpoint_path is None:
        checkpoint_path = "../models/checkpoints/pretrained_weights.pt"

    # Fail early with a readable message instead of a torch.load traceback.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    model = model_type(alpha=alpha).to(Device)

    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(
        checkpoint_path,
        map_location=Device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)

    model.eval()
    return model


def transform_dataset(dataset):
    """Wrap ``dataset`` in ``ImageDatasetWithLabels`` with a 256x256 resize.

    The resize is passed to the wrapper as its ``transformation``; how and
    when it is applied is up to ``ImageDatasetWithLabels``.
    """
    resize_step = transforms.Resize(size=(256, 256))
    pipeline = transforms.Compose([resize_step])
    return ImageDatasetWithLabels(dataset=dataset, transformation=pipeline)


def get_predictions(model, dataset):
    """Run ``model`` on every item of ``dataset`` and collect the outputs.

    Each item is converted to a float32 image tensor (``ToImage`` followed by
    ``ToDtype(torch.float32, scale=True)``), given a leading batch dimension
    of size 1, moved to ``Device`` and passed through the model with autograd
    disabled.

    Args:
        model: Callable invoked as ``model(batch)`` with a one-image batch.
        dataset: Iterable of images accepted by the transforms above.

    Returns:
        A list with one model output per dataset item, in iteration order.
        Outputs are returned as produced by the model (not moved off
        ``Device``).
    """
    to_float_tensor = transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
        ]
    )

    predictions = []
    # Inference only: disable gradient tracking once for the whole loop
    # rather than re-entering the context manager for every image.
    with torch.no_grad():
        for orig_img in dataset:
            batch = to_float_tensor(orig_img).unsqueeze(0).to(Device)
            predictions.append(model(batch))

    return predictions


# Build a ModelBuilder network and load weights from load_model's default
# checkpoint path (no explicit checkpoint_path is passed here).
model = load_model(ModelBuilder)

### Get predictions for the defined dataset

In [None]:
def prepare_dataset(root: str = "../VOC", num_samples: int = 10):
    """Load the VOC2007 validation split and return its first samples.

    Args:
        root: Directory holding the already-downloaded VOC data
            (``download=False`` is used, so nothing is fetched).
        num_samples: Number of leading samples to keep. Values larger than
            the dataset are clamped to the dataset length.

    Returns:
        A ``Subset`` over the first ``num_samples`` items of the VOC
        detection dataset wrapped for the transforms v2 API.

    Raises:
        ValueError: If ``num_samples`` is negative.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    # Load VOC dataset
    voc_val = torchvision.datasets.VOCDetection(
        root=root, year="2007", image_set="val", download=False
    )
    voc_val = torchvision.datasets.wrap_dataset_for_transforms_v2(voc_val)

    # Keep only the leading samples; clamp so the subset never holds indices
    # past the end of the underlying dataset.
    subset_size = min(num_samples, len(voc_val))
    return Subset(voc_val, range(subset_size))

# Build the evaluation subset returned by prepare_dataset()
dataset = prepare_dataset()

# Wrap the subset with transform_dataset (256x256 resize) so it has the
# form expected by the prediction and visualization steps below
dataset_transformed = transform_dataset(dataset)

# Run the model over every item of the wrapped dataset; one output per item
predictions = get_predictions(model, dataset_transformed)

### Predictions visualization

In [None]:
# Create the visualizer over the transformed dataset. The 256x256 input size
# matches the Resize applied in transform_dataset.
# NOTE(review): confidence_threshold=0.3 presumably hides lower-scoring
# detections - confirm against ObjectDetectionVisualizer.
visualizer = ObjectDetectionVisualizer(
    dataset=dataset_transformed, input_height=256, input_width=256, confidence_threshold=0.3
)

# Visualize the predictions collected by get_predictions
visualizer.visualize_predictions(predictions)