# Convolutional neural networks (CNN) for CIFAR-10/100 using PyTorch

Markus Enzweiler, markus.enzweiler@hs-esslingen.de

This is a demo used in a Computer Vision & Machine Learning lecture. Feel free to use and contribute.

We test a pretrained CNN for CIFAR-10 / CIFAR-100 image classification, see https://www.cs.toronto.edu/~kriz/cifar.html. We use the Python code from https://github.com/menzHSE/torch-cifar-10-cnn.git.

## Setup

Adapt `package_path` to point to the directory containing this notebook.

In [None]:
# Imports
import sys
import os
import subprocess

In [None]:
# Directory that contains this notebook (and into which the repo is cloned).
# Adapt this path if the notebook lives elsewhere.
package_path = "./"  # local
print("Package path: " + package_path)


def check_for_colab():
  try:
      import google.colab
      return True
  except ImportError:
      return False

# Running on Colab?
on_colab = check_for_colab()

In [None]:
# Clone the git repository, or hard-reset an existing clone to the remote's HEAD

# Path of the repository directory (joined onto package_path)
repo_dir = os.path.join(package_path, "torch-cifar-10-cnn")
repo_url = "https://github.com/menzHSE/torch-cifar-10-cnn.git"

# Store the original working directory
original_cwd = os.getcwd()

# Check if the repository directory already exists
if os.path.isdir(os.path.join(original_cwd, repo_dir)):
    print("Repository exists. Resetting to HEAD...")
    # Run git *inside* the repository via cwd= rather than os.chdir():
    # if a command raises (e.g. FileNotFoundError when git is not installed),
    # the notebook's working directory is never left pointing into the repo.
    # First fetch the latest changes from the remote, then reset the local
    # branch to the remote's HEAD commit. Failures are reported but not
    # fatal, so an existing clone stays usable when offline.
    for git_cmd in (["git", "fetch", "origin"],
                    ["git", "reset", "--hard", "origin/HEAD"]):
        result = subprocess.run(git_cmd, cwd=repo_dir)
        if result.returncode != 0:
            cmd_str = " ".join(git_cmd)
            print(f"Warning: '{cmd_str}' failed with exit code {result.returncode}")
else:
    print("Cloning repository...")
    # Clone the repository if it doesn't exist
    result = subprocess.run(["git", "clone", repo_url, repo_dir])
    # None of the following cells can run without the repository, so fail
    # loudly here instead of with a confusing ImportError later.
    if result.returncode != 0:
        raise RuntimeError(
            f"git clone of {repo_url} into {repo_dir} failed "
            f"with exit code {result.returncode}"
        )


In [None]:
# Install requirements in the current Jupyter kernel
req_file = os.path.join(repo_dir, "requirements.txt")
if os.path.exists(req_file):
    !{sys.executable} -m pip install -r {req_file}
else:
    print(f"Requirements file not found: {req_file}")

In [None]:
# Make the cloned repository (which provides model.py, dataset.py and
# device.py) importable. Insert it at the FRONT of sys.path so these local
# modules take precedence over any installed package that happens to share a
# generic name such as `dataset` or `model`. Skip the insert when the cell is
# re-run, so sys.path does not accumulate duplicate entries.
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

# Now we can import torch
import torch

# Now we can import the model and dataset
import model
import dataset
import device

# Inference on CIFAR-10



## Load the pretrained model and the CIFAR-10 test data

In [None]:
# device: let the repository's helper pick the compute device
dev = device.autoselectDevice()
print(f"Device: {dev}")

# load the dataset
cifar_version = "CIFAR-10"
batch_size = 64
img_size = 32  # we assume 32x32 pixel images
_, test_loader, classes = dataset.cifar(
    batch_size=batch_size, custom_transforms=None, cifar_version=cifar_version
)
num_classes = len(classes)

# load the model that matches cifar_version. Deriving the file name from
# cifar_version (instead of hard-coding CIFAR-10) keeps the network's number
# of output classes in sync with the dataset when cifar_version is changed.
# NOTE(review): assumes the repo ships pretrained_models/model_<cifar_version>.pth
# for every version used here - confirm for CIFAR-100.
model_fname = os.path.join(repo_dir, 'pretrained_models', f'model_{cifar_version}.pth')
if not os.path.isfile(model_fname):
    raise FileNotFoundError(f"No pretrained model for {cifar_version} found at {model_fname}")

cnn = model.CNN(num_classes)
cnn.load(model_fname, dev)

print(f"Loaded model for {cifar_version} from {model_fname}")

## Run inference on a batch of test data

In [None]:
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision

# Layout of the image grid: images per row, and the padding (in pixels) that
# make_grid puts between images. Both are used to position the text labels.
grid_cols = 8
grid_pad = 2

# put the model in evaluation mode and on the device we are using
cnn.eval()
cnn.to(dev)

# get a batch of testing images
data = next(iter(test_loader))

with torch.no_grad():
    # Get the testing data and push the data to the device we are using
    images, labels = data[0].to(dev), data[1].to(dev)

    # Get the model predictions (index of the highest score per image)
    predictions = cnn(images)
    _, predicted_labels = torch.max(predictions, 1)
    # CNN output posterior probability estimate via softmax along the class dimension
    probabilities = F.softmax(predictions, dim=1)

# Move everything needed for plotting to the CPU once, instead of indexing
# device tensors element by element inside the loop.
true_ids = labels.cpu().tolist()
pred_ids = predicted_labels.cpu().tolist()
probs_cpu = probabilities.cpu()

# plot the images in the batch, along with the corresponding labels and predictions
grid = torchvision.utils.make_grid(
    images.cpu(), nrow=grid_cols, padding=grid_pad, normalize=True, scale_each=True
)
# Convert the (C, H, W) grid tensor to an (H, W, C) numpy image for imshow
grid = grid.numpy().transpose((1, 2, 0))

plt.figure(figsize=(15, 15))
plt.imshow(grid)
plt.axis('off')

# Add labels. Iterate over the images actually in the batch: a batch can be
# smaller than batch_size (e.g. a test set with fewer images).
for i, (true_id, pred_id) in enumerate(zip(true_ids, pred_ids)):
    row, col = divmod(i, grid_cols)

    # Ground truth and predicted class, plus the confidence of the prediction
    gt_label = classes[true_id]
    pr_label = classes[pred_id]
    output_prob = probs_cpu[i, pred_id].item()

    # Green label for correct predictions, red label for incorrect ones
    label_color = 'lightgreen' if true_id == pred_id else 'red'

    # Top-left corner of image i inside the padded grid
    x = col * (img_size + grid_pad) + grid_pad
    y = row * (img_size + grid_pad)

    plt.text(x, y + 4, f'T: {gt_label}',
             color=label_color, backgroundcolor='black')
    plt.text(x, y + 10, f'P: {pr_label} ({output_prob:.2f})',
             color=label_color, backgroundcolor='black')

plt.title(f"True (T) and Predicted Classes (P) with confidence value for {cifar_version}")
plt.show()