In [None]:
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import json

# Pick the compute device once: CUDA when torch reports it as available,
# otherwise the CPU. Later cells read this module-level name.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
class SimpsonsCNN(nn.Module):
    """Image classifier built from four convolutional stages and a dense head.

    Every stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2)``; the channel widths go 3 -> 32 -> 64 -> 128 -> 256.
    Each stage halves the spatial size, and the first ``Linear`` layer of the
    head is sized for ``256 * 8 * 8`` flattened features, which corresponds to
    128x128 input images.

    Args:
        num_classes: Number of logits produced for each image.
    """

    def __init__(self, num_classes):
        super().__init__()

        # The attribute names (block1..block4, classifier) and the layer order
        # inside each Sequential define the state_dict keys, so they are kept
        # exactly as-is to stay loadable from existing checkpoints.
        self.block1 = self._conv_stage(3, 32)
        self.block2 = self._conv_stage(32, 64)
        self.block3 = self._conv_stage(64, 128)
        self.block4 = self._conv_stage(128, 256)

        flat_features = 256 * 8 * 8
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv -> batch-norm -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

    def forward(self, x):
        """Run the four stages in order, then the classifier head.

        Args:
            x: Batch of images with 3 channels.

        Returns:
            Tensor of raw class scores with ``num_classes`` columns.
        """
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        return self.classifier(x)


In [None]:

def _iter_image_paths(data_dir, extensions=('.jpg', '.jpeg', '.png')):
    """Yield paths of image files under ``data_dir``, recursively, in sorted order.

    Extension matching is case-insensitive. Directory errors (for example a
    missing ``data_dir``) are raised rather than silently ignored.
    """
    def _raise(error):
        raise error

    for root, dirs, files in os.walk(data_dir, onerror=_raise):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()
        for name in sorted(files):
            if name.lower().endswith(extensions):
                yield os.path.join(root, name)


def infer(data_dir, model_path):
    """Classify every image under ``data_dir`` and write the results to JSON.

    The checkpoint at ``model_path`` must contain ``'class_names'`` and
    ``'model_state_dict'``. Images are found recursively, so both a flat
    directory and a one-folder-per-class layout are supported. Each image is
    resized to 128x128, normalized, and passed through ``SimpsonsCNN``.

    Args:
        data_dir: Directory that holds the images (searched recursively).
        model_path: Path to the checkpoint file loaded with ``torch.load``.

    Returns:
        Dict mapping each image's absolute path to
        ``{"true": <name of the directory containing the image>,
        "pred": <predicted class name>}``. The same mapping is written to
        ``results.json`` in the current working directory.

    Raises:
        OSError: If ``data_dir`` cannot be listed or an image cannot be opened.
    """
    # SECURITY: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    class_names = checkpoint['class_names']

    # Rebuild the network and restore the trained weights.
    model = SimpsonsCNN(len(class_names)).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # 128x128 matches the input size the classifier head is dimensioned for.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    results = {}

    # Gradients are never needed here, so disable them for the whole loop.
    with torch.no_grad():
        for img_path in _iter_image_paths(data_dir):
            full_path = os.path.abspath(img_path)

            # The context manager closes the underlying file handle; convert()
            # returns an independent, fully loaded RGB copy.
            with Image.open(img_path) as img:
                img_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

            output = model(img_tensor)
            _, predicted = torch.max(output, 1)
            predicted_class = class_names[predicted.item()]

            results[full_path] = {
                "true": os.path.basename(os.path.dirname(full_path)),
                "pred": predicted_class
            }

    # Save results
    with open('results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f'Inference complete. Processed {len(results)} images.')
    print('Results saved to results.json')

    return results


In [None]:
# Run inference over the test images with the saved checkpoint
TEST_DIR = 'test'  # Point this at the directory that holds your test images
MODEL_PATH = 'model.pth'

results = infer(TEST_DIR, MODEL_PATH)

# Preview the first five entries; each value is the {"true", "pred"} record
# stored for that image path.
print('\nSample predictions:')
preview = list(results.items())[:5]
for image_path, record in preview:
    print(f'{image_path}: {record}')