# Inference

This notebook loads the trained CNN model from `train.ipynb` and uses it to generate predictions for a folder of unseen images.

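All predictions are stored in a `results.json` file that maps each image filename to its predicted class name:

```json
{
    "<image filename>": "<predicted class name>",
    ...
}
```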

## Setup



In [4]:
import os
import torch

# Must match the size used during training
pic_size = 64

# Class names must be identical (same names, same order) to the training
# notebook: the model's predicted index i is translated back to a name via
# idx_to_class, which is built from this list. Ideally these would be loaded
# from an artifact saved by the training notebook; here they are rebuilt from
# the training directory, sorted alphabetically, which presumably mirrors how
# train.ipynb built them — TODO confirm.
train_dir = "./data/simpsons/archive/characters_train"

class_names = sorted(
    d for d in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, d))
)
# Fail fast with a clear message instead of building a 0-class model later.
assert class_names, f"No class subdirectories found in {train_dir}"

class_to_idx = {name: i for i, name in enumerate(class_names)}
idx_to_class = {i: name for name, i in class_to_idx.items()}

device = "cuda" if torch.cuda.is_available() else "cpu"

## Image Processing

In [5]:
import cv2

def load_and_preprocess_image(path):
    """
    Read one image from disk and turn it into a single-item batch tensor.

    The image is decoded by OpenCV (BGR channel order), converted to RGB,
    resized to ``pic_size`` x ``pic_size``, reordered from HWC to CHW, and
    scaled by 1/255.

    Args:
        path: Filesystem path of the image file.

    Returns:
        A float32 tensor with a leading batch dimension of size 1, or
        ``None`` if OpenCV could not read/decode the file.
    """
    bgr = cv2.imread(path)
    if bgr is not None:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (pic_size, pic_size))
        # HWC -> CHW, as PyTorch conv layers expect channels first.
        chw = torch.tensor(resized, dtype=torch.float32).permute(2, 0, 1)
        scaled = chw / 255.0
        return scaled[None]  # prepend batch dimension
    return None

## Inference Function

In [6]:
import json
from model import CNN4Conv

def infer(data_dir, model_path, output_path="results.json"):
    """
    Load the trained model and predict a class for every image in data_dir.

    Writes a JSON file mapping each image filename → predicted class name.

    Args:
        data_dir: Directory whose files are classified, in sorted filename
            order. Subdirectories are ignored; files that cannot be decoded
            as images are skipped with a printed message.
        model_path: Path to a state_dict saved for ``CNN4Conv``.
        output_path: Where to write the JSON results. Defaults to
            "results.json" in the current working directory.
    """

    # Load model.
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; only load files from a trusted source. On recent torch
    # versions, weights_only=True restricts loading to tensors/containers.
    model = CNN4Conv(num_classes=len(class_names))
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    results = {}

    # Gradients are never needed at inference time, so disable them once for
    # the whole loop rather than re-entering the context for each image.
    with torch.no_grad():
        for filename in sorted(os.listdir(data_dir)):
            full_path = os.path.join(data_dir, filename)

            if not os.path.isfile(full_path):
                continue

            img_tensor = load_and_preprocess_image(full_path)
            if img_tensor is None:
                print(f"Skipping invalid image: {filename}")
                continue

            logits = model(img_tensor.to(device))
            pred_idx = logits.argmax(1).item()
            results[filename] = idx_to_class[pred_idx]

    # Save results
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    print(f"Saved {output_path} with {len(results)} predictions.")

## Run Inference

In [7]:
model_path = "simpsons_cnn4conv.pth"
# NOTE(review): this folder belongs to the training set, so these are not
# predictions on unseen images as the introduction describes. Point data_dir
# at a held-out folder for a meaningful evaluation.
data_dir = "./data/simpsons/archive/characters_train/abraham_grampa_simpson"  # change this

infer(data_dir=data_dir, model_path=model_path)

Saved results.json with 731 predictions.
