In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch.utils.data import DataLoader

In [2]:
# Set device
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

Using device: mps


In [3]:
# Define transforms
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load CIFAR-10 test dataset
test_data   = torchvision.datasets.CIFAR10(root="../data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

Files already downloaded and verified


In [4]:
# Load pre-trained ResNet-50 model
model = resnet50()

# Modify the ResNet-50 architecture for CIFAR-10
model.conv1   = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = torch.nn.Identity()
model.fc      = torch.nn.Linear(model.fc.in_features, 10)

# Load pre-trained weights
model.load_state_dict(torch.load("../data/models/resnet50_cifar10.pt"))  
model.to(device)
model.eval()

# Define CIFAR-10 classes
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

In [5]:
# Get a batch of test images
images, labels = next(iter(test_loader))
images, labels = images.to(device), labels.to(device)

# Perform inference
with torch.no_grad():
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

# Print results
print("Prediction Results:")
for i in range(10):
    print(f"Image {i+1}:")
    print(f"  Predicted: {classes[predicted[i]]}")
    print(f"  Actual:    {classes[labels[i]]}")
    print(f"  Correct:   {predicted[i] == labels[i]}\n")

# Calculate and print accuracy for this batch
correct = (predicted == labels).sum().item()
print(f"Batch Accuracy: {correct / 10 * 100:.2f}%")

Prediction Results:
Image 1:
  Predicted: plane
  Actual:    cat
  Correct:   False

Image 2:
  Predicted: plane
  Actual:    ship
  Correct:   False

Image 3:
  Predicted: plane
  Actual:    ship
  Correct:   False

Image 4:
  Predicted: plane
  Actual:    plane
  Correct:   True

Image 5:
  Predicted: cat
  Actual:    frog
  Correct:   False

Image 6:
  Predicted: plane
  Actual:    frog
  Correct:   False

Image 7:
  Predicted: cat
  Actual:    car
  Correct:   False

Image 8:
  Predicted: cat
  Actual:    frog
  Correct:   False

Image 9:
  Predicted: plane
  Actual:    cat
  Correct:   False

Image 10:
  Predicted: plane
  Actual:    car
  Correct:   False

Batch Accuracy: 10.00%
