How to train a neural network to classify images using the CIFAR-10 dataset

TABLE OF CONTENTS:
1. Introduction
2. Dataset and Libraries
3. Data Preprocessing and Visualisation
4. Model Definition and Training
5. Hyperparameter Tuning
6. Evaluation and Results
7. Conclusion
8. References

ABSTRACT:
This tutorial demonstrates how to use Jupyter Notebook to develop a machine learning pipeline for image classification using PyTorch. 
This tutorial uses the CIFAR-10 dataset, which consists of 60,000 32x32 colour images (50,000 for training and 10,000 for testing) spread evenly across 10 object categories.
The topics covered include data preprocessing, model training, evaluation and hyperparameter tuning.
By the end of this guide, readers will have a fully functional image classifier and an understanding of
how different hyperparameters affect model performance.

To begin, we import the libraries that the pipeline depends on: PyTorch (together with its companion package torchvision) to build and train the model, Matplotlib to plot results and NumPy to handle arrays.

In [3]:
# Importing the required libraries
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# PyTorch Library: PyTorch Team. (2025). torch module documentation. Retrieved February 25, 2025, from https://pytorch.org/docs/stable/index.html
# Torchvision: PyTorch Team. (2025). torchvision module documentation. Retrieved February 25, 2025, from https://pytorch.org/vision/stable/index.html
# Matplotlib: Hunter, J. D. (2007). Matplotlib: A 2D Graphics Environment. Computing in Science & Engineering, 9(3), 90-95. Retrieved February 25, 2025, from https://matplotlib.org/stable/index.html
# NumPy: Harris, C. R., et al. (2020). Array programming with NumPy. Nature, 585(7825), 357–362. Retrieved February 25, 2025, from https://numpy.org/doc/stable/

After importing the required libraries, we select the device on which the model will train. An accelerator, such as a CUDA-capable GPU, is a
device that works alongside the CPU to speed up the computations of our machine learning model. If no accelerator is available, the code falls back to the CPU.

In [4]:
# Step 1: Define Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
# [Source: https://pytorch.org/docs/stable/notes/cuda.html]
# https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html#model-layers 

Using device: cpu


The next stage of the machine learning pipeline is loading and normalising the data used to train the model.
In this example, we use image data from the CIFAR-10 dataset. From these images and their class labels, the model learns which
visual attributes characterise each class, and it then uses those attributes to assign new images to the appropriate class.
Before training, we transform the data so that it is suitable as input to our network, and we apply random augmentations
(horizontal flips, small rotations and colour jitter) so that the model sees slightly different versions of each training image.

In [5]:
# Step 2: Load and Normalise Data
# Download CIFAR-10 and wrap it in DataLoaders, as shown in https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Transforms change attributes of the data to make it suitable for the pipeline.
# Random augmentations are applied to the training set only.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Randomly flips images horizontally to introduce variation
    transforms.RandomRotation(10),  # Rotates images by up to 10 degrees in either direction to enhance robustness
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Randomly adjusts brightness, contrast, saturation and hue
    transforms.ToTensor(),  # Converts images to tensors with values in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Maps each channel from [0, 1] to [-1, 1] to improve learning stability
])
# Test images are evaluated unmodified: no random augmentation, same normalisation as training
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Loading CIFAR-10 dataset [Source: https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.CIFAR10]
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=test_transform, download=True)

# DataLoader allows efficient batch loading [Source: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader]
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)  # Shuffle so each epoch sees batches in a new order
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# PyTorch Team. (2025). torch.cuda documentation. Retrieved February 25, 2025, from https://pytorch.org/docs/stable/cuda.html

Files already downloaded and verified
Files already downloaded and verified
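To make the normalisation step concrete: `transforms.ToTensor()` scales 8-bit pixel values (0 to 255) to [0, 1], and `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` then computes `(x - mean) / std` for each channel, which maps [0, 1] onto [-1, 1]. The visualisation step later reverses this with `img * 0.5 + 0.5`. Note that 0.5 is a simple convention rather than per-channel statistics measured on CIFAR-10. The plain-Python sketch below traces a single channel value through both directions; the function names are hypothetical and not part of torchvision.

```python
def to_tensor_scale(pixel):
    """Scale an 8-bit pixel value (0 to 255) to [0, 1], as ToTensor does."""
    return pixel / 255.0

def normalise(x, mean=0.5, std=0.5):
    """Per-channel normalisation as in transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

def unnormalise(x, mean=0.5, std=0.5):
    """Inverse of normalise; with mean = std = 0.5 this is x * 0.5 + 0.5."""
    return x * std + mean

# Trace black, mid-grey and white pixels through scaling, normalisation and back
for pixel in (0, 128, 255):
    scaled = to_tensor_scale(pixel)
    normed = normalise(scaled)
    print(pixel, round(scaled, 4), round(normed, 4), round(unnormalise(normed), 4))
```

A pure black pixel (0) becomes -1.0 and a pure white pixel (255) becomes 1.0, so the network receives inputs centred around zero.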


Next, we define the convolutional neural network (CNN) used in our machine learning pipeline.
Because the pipeline is learning to classify images, a CNN is the most appropriate choice: its convolutional layers automatically learn
spatial hierarchies of features, from low-level edges and textures in the early layers to more complex patterns, such as object parts,
in the deeper layers.
Our network uses three convolutional layers, each followed by batch normalisation, a ReLU activation and max pooling, then two fully connected layers with dropout between them.

In [6]:
# Step 3: Define CNN Model with Batch Normalisation and Dropout
class CNN(nn.Module):  # [Source: PyTorch Official Examples]
    def __init__(self):
        super(CNN, self).__init__()

        # First convolutional layer extracts low-level features [Source: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html]
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # 3 input channels (RGB) -> 64 feature maps; padding=1 keeps 32x32
        self.bn1 = nn.BatchNorm2d(64)

        # Second convolutional layer extracts mid-level features.
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)

        # Third convolutional layer extracts deeper patterns.
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)  # 128 -> 256 feature maps
        self.bn3 = nn.BatchNorm2d(256)

        # Max pooling reduces spatial dimensions while retaining important features. [Source: https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html]
        self.pool = nn.MaxPool2d(2, 2)

        # Dropout prevents overfitting [Source: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html]
        self.dropout = nn.Dropout(0.4)  # Randomly zeroes 40% of activations during training

        # Fully connected layers for classification.
        self.fc1 = nn.Linear(256 * 4 * 4, 256)  # 256 channels x 4x4 spatial size after three 2x2 poolings (32 -> 16 -> 8 -> 4)
        self.fc2 = nn.Linear(256, 10)

    # Define the forward pass of the neural network
    def forward(self, x):
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))  # Conv -> batch norm -> ReLU -> 2x2 max pool: 32x32 -> 16x16
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))  # Same block with 128 filters: 16x16 -> 8x8
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))  # Same block with 256 filters: 8x8 -> 4x4
        x = torch.flatten(x, 1) # Flattens feature maps into a vector for FC layers [Source: https://pytorch.org/docs/stable/generated/torch.flatten.html]
        x = torch.relu(self.fc1(x))
        x = self.dropout(x) # Applies dropout for regularisation [Source: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html]
        x = self.fc2(x) # Final classification layer [Source: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html]
        return x

# Create an instance of the model and move it to the selected device.
model = CNN().to(device)

# https://medium.com/@myringoleMLGOD/simple-convolutional-neural-network-cnn-for-dummies-in-pytorch-a-step-by-step-guide-6f4109f6df80
# PyTorch Team. (2025). torchvision.transforms documentation. Retrieved February 25, 2025, from https://pytorch.org/vision/stable/transforms.html
# CIFAR-10 dataset: Krizhevsky, A. (2009). Learning Multiple Layers of Features from Tiny Images. University of Toronto.
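The input size of `fc1` (`256 * 4 * 4`) follows from how each layer changes the spatial size of a 32x32 CIFAR-10 image. For `Conv2d` and `MaxPool2d` with dilation 1 and the default floor rounding, the output size is `floor((size + 2 * padding - kernel_size) / stride) + 1`, so each 3x3 convolution with `padding=1` preserves the size and each 2x2 max pool with stride 2 halves it. The plain-Python trace below reproduces that arithmetic; the helper names are hypothetical.

```python
def layer_out_size(size, kernel, padding=0, stride=1):
    """Output height/width of Conv2d or MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

def cnn_flatten_size(image_size=32, channels=(64, 128, 256)):
    """Trace the spatial size through conv(3x3, padding 1) + MaxPool2d(2, 2) blocks."""
    size = image_size
    for _ in channels:
        size = layer_out_size(size, kernel=3, padding=1)  # 3x3 conv with padding=1 keeps the size
        size = layer_out_size(size, kernel=2, stride=2)   # 2x2 max pool with stride 2 halves it
    # Flattened length = channels of the last conv layer x height x width
    return channels[-1] * size * size, size

flat_features, final_size = cnn_flatten_size()
print(f"Final feature map: {final_size}x{final_size}, flattened length: {flat_features}")
```

If you add or remove a pooling layer or change the input resolution, the first argument of `fc1` must be updated to match; otherwise the forward pass fails with a shape-mismatch error.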

The next stage of the machine learning pipeline is to define the loss function, the optimiser and the learning-rate scheduler.
Within this tutorial, I use cross-entropy loss, stochastic gradient descent (SGD) with momentum and weight decay from the torch.optim package, and a cosine annealing learning-rate schedule.
The loss function, CrossEntropyLoss, is commonly used for multi-class classification problems; we use it to quantify how far the model's predicted class scores are from the true labels.
The optimiser updates the model's weights to reduce this loss: momentum smooths the updates across batches, and weight decay penalises large weights to limit overfitting. The scheduler gradually lowers the learning rate, allowing larger steps early in training and finer adjustments towards the end.
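For a single image, cross-entropy converts the model's raw scores (logits) into probabilities with softmax and returns the negative log-probability of the true class. `nn.CrossEntropyLoss` performs the softmax internally, which is why the model's forward pass ends at `fc2` without a softmax, and by default it averages the per-image losses over the batch. The plain-Python sketch below shows the same computation; the function names are illustrative, not PyTorch's.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    # Subtract the largest logit before exponentiating for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    """Average over a batch, like CrossEntropyLoss with its default reduction='mean'."""
    losses = [cross_entropy(logits, t) for logits, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

print(cross_entropy([0.0] * 10, 3))           # uniform scores over 10 classes
print(cross_entropy([2.0, 1.0, 0.1], 0))      # moderately confident, correct
print(mean_cross_entropy([[10.0, 0.0], [10.0, 0.0]], [0, 1]))  # one right, one wrong
```

With ten classes, uniform scores give a loss of ln(10), roughly 2.303, which is a useful reference point for an untrained model. A confident correct prediction gives a loss close to 0, while a confident wrong prediction is penalised heavily.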

In [5]:
# Step 4: Define Loss, Optimiser, and Scheduler

criterion = nn.CrossEntropyLoss() # Cross-entropy loss for multi-class classification
optimiser = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)  # SGD with momentum and L2 weight decay
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=15)  # Cosine annealing; T_max matches num_epochs=15 in the training step

# [Source: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html]
# [Source: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html]
# [Source: https://pytorch.org/docs/stable/generated/torch.optim.SGD.html]
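`CosineAnnealingLR` follows the schedule given in its PyTorch documentation, `eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * T_cur / T_max)) / 2`, where `eta_max` is the optimiser's initial learning rate (0.01 here), `eta_min` defaults to 0 and `T_cur` counts calls to `scheduler.step()`. Because the training loop calls `scheduler.step()` once per epoch and `T_max=15` equals the number of epochs, the learning rate falls from 0.01 in the first epoch to close to zero in the last. A plain-Python sketch of that formula, assuming no other code modifies the learning rate (`cosine_annealing_lr` is a hypothetical helper, not a PyTorch function):

```python
import math

def cosine_annealing_lr(epoch, base_lr=0.01, t_max=15, eta_min=0.0):
    """Learning rate after `epoch` calls to scheduler.step(), using the closed-form schedule."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

# Epoch e (1-based) trains with the rate reached after e - 1 scheduler steps
for epoch in range(15):
    print(f"epoch {epoch + 1:2d}: lr = {cosine_annealing_lr(epoch):.5f}")
```

If you change `num_epochs`, it is usually sensible to change `T_max` with it; otherwise the schedule either stops partway down the curve or reaches its minimum before training ends.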

The next stage of the machine learning pipeline is to train the model on the dataset. One way to influence the accuracy of the model is to change the number of epochs used to train it. Each epoch is one full pass over the entire training dataset. Increasing the number of epochs generally improves the model's fit to the training data, but the number should not be too large: too many epochs can lead to overfitting, where the model predicts the training data accurately but performs poorly on validation and test data.
Within the training loop, we calculate and report the model's accuracy both as text and graphically, using the Matplotlib library.
The loop uses backpropagation to compute gradients and the optimiser to update the weights, so the model's predictions improve over successive batches.
After each epoch, we measure loss and accuracy in two settings: on the training data with the model in training mode, and on the test data (used here as a validation set) with the model in evaluation mode, where dropout is disabled and batch normalisation uses its running statistics. Comparing the two shows how well the model handles seen and unseen data; a widening gap between them indicates overfitting.
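One simple way to read the two loss curves for signs of overfitting is to find the epoch with the lowest validation loss and track the gap between validation and training loss. The plain-Python sketch below does this; the helper names are hypothetical, and the loss values are made up for illustration and are not results from this notebook. A negative gap early in training is common (the first-epoch output of this notebook shows validation loss below training loss), because the training loss is averaged while the weights are still improving and dropout and augmentation are active only during training.

```python
def best_epoch(val_losses):
    """1-based epoch with the lowest validation loss."""
    return min(range(len(val_losses)), key=val_losses.__getitem__) + 1

def generalisation_gaps(train_losses, val_losses):
    """Validation loss minus training loss for each epoch."""
    return [v - t for t, v in zip(train_losses, val_losses)]

def epochs_since_improvement(val_losses):
    """Number of epochs since the validation loss last reached its minimum."""
    return len(val_losses) - best_epoch(val_losses)

# Hypothetical loss curves for illustration only; these are not results from this notebook
train_losses = [1.37, 1.00, 0.80, 0.60, 0.45, 0.35]
val_losses = [1.09, 0.90, 0.80, 0.78, 0.82, 0.88]

print("Best epoch:", best_epoch(val_losses))
print("Gaps:", [round(g, 2) for g in generalisation_gaps(train_losses, val_losses)])
print("Epochs since improvement:", epochs_since_improvement(val_losses))
```

In this made-up example, training loss keeps falling while validation loss rises after epoch 4, which is the typical overfitting pattern. With the real notebook, the lists returned by `train_model` (in the order `train_losses, train_accuracies, val_losses, val_accuracies`) could be passed in directly.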

In [None]:
# Step 5: Training Loop
def train_model(model, train_loader, test_loader, criterion, optimiser, scheduler, num_epochs=15):
    model.train()
    train_losses, val_losses = [], []  # Lists that store the per-epoch losses
    train_accuracies, val_accuracies = [], []  # Lists that store the per-epoch accuracies
    epochs = []

    for epoch in range(num_epochs):
        running_loss, correct_train, total_train = 0.0, 0, 0  # Accumulators for this epoch's loss, correct predictions and sample count
        for images, labels in train_loader:  # For each batch of images and labels in the training data
            images, labels = images.to(device), labels.to(device)  # Send the batch to the selected device
            optimiser.zero_grad()  # Clear gradients from the previous batch; PyTorch accumulates them otherwise
            outputs = model(images)  # Forward pass
            loss = criterion(outputs, labels)  # Calculates how far the predicted values are from the true labels
            loss.backward()  # Backpropagation to compute gradients
            optimiser.step()  # Update the weights using the computed gradients
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)  # The index of the highest score is the predicted class
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()
        
        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct_train / total_train
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        epochs.append(epoch + 1)
        
        model.eval() # Sets the model to evaluation mode for validation
        val_loss, correct_val, total_val = 0.0, 0, 0
        with torch.no_grad(): # Disable gradient computation for validation
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels) # Compute validation loss
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
        
        val_loss /= len(test_loader)
        val_acc = 100 * correct_val / total_val
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        scheduler.step() # Adjust learning rate
        
        print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        model.train() # Switch back to training mode
        
    # Plot Training Results
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Over Epochs')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label='Train Accuracy')
    plt.plot(epochs, val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.title('Accuracy Over Epochs')
    plt.legend()
    plt.show()
    
    return train_losses, train_accuracies, val_losses, val_accuracies

history = train_model(model, train_loader, test_loader, criterion, optimiser, scheduler, num_epochs=15)  # Train, evaluate and plot; returns per-epoch losses and accuracies
# https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.to
# https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.forward
# https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
# https://pytorch.org/docs/stable/autograd.html
# https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html
# https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.step.html
# https://pytorch.org/docs/stable/generated/torch.max.html

Epoch 1/15 | Train Loss: 1.3726, Train Acc: 50.15% | Val Loss: 1.0860, Val Acc: 60.70%
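Each line of the epoch summary reports the mean of the per-batch losses (`running_loss / len(train_loader)`) and the percentage of correctly classified images. CIFAR-10 has 50,000 training and 10,000 test images, and with `batch_size=128` the final batch of each loader is smaller than the rest, so the reported loss weights those few images slightly more than a per-sample average would; the effect is small. The plain-Python sketch below reproduces this arithmetic. `batch_sizes` and `epoch_metrics` are hypothetical helpers, and the per-sample average is shown only for comparison; it is not what the notebook computes.

```python
def batch_sizes(num_samples, batch_size):
    """Sizes of the batches a DataLoader yields with the default drop_last=False."""
    full, remainder = divmod(num_samples, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])

def epoch_metrics(batch_losses, batch_correct, sizes):
    """Return (mean of per-batch losses, per-sample mean loss, accuracy in %)."""
    # What train_model reports: every batch mean counts equally, whatever its size
    mean_of_batches = sum(batch_losses) / len(batch_losses)
    # Alternative weighting for comparison: every sample counts equally
    per_sample = sum(loss * n for loss, n in zip(batch_losses, sizes)) / sum(sizes)
    accuracy = 100 * sum(batch_correct) / sum(sizes)
    return mean_of_batches, per_sample, accuracy

train_sizes = batch_sizes(50000, 128)
test_sizes = batch_sizes(10000, 128)
print(len(train_sizes), train_sizes[-1])  # number of training batches, size of the last one
print(len(test_sizes), test_sizes[-1])    # number of test batches, size of the last one
```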


After training the model, we can use it to classify images from the test set.
The following code plots nine test images, each labelled with its true class and the class that the model predicts.

In [None]:
# Step 6: Visualising Predictions
def visualise_predictions(model, test_loader):
    classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
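    model.eval()  # Evaluation mode: dropout off, batch norm uses running statistics
    images, labels = next(iter(test_loader))  # Take the first batch of test images
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():  # Gradients are not needed for inference
        outputs = model(images)
    _, preds = torch.max(outputs, 1)
    
    fig, axes = plt.subplots(3, 3, figsize=(8, 8))
    axes = axes.flatten()
    for i in range(9):
        img = images[i].cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC, the layout Matplotlib expects
        img = np.clip(img * 0.5 + 0.5, 0, 1)  # Unnormalise back to [0, 1]
        axes[i].imshow(img)
        axes[i].set_title(f'True: {classes[labels[i].item()]}\nPred: {classes[preds[i].item()]}')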
        axes[i].axis('off')
    plt.show()

visualise_predictions(model, test_loader)  # Call visualisation function
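The nine-image grid gives a qualitative picture of the model's behaviour. To see which classes the model finds hardest, one option, not part of the original notebook, is to compute the accuracy for each class over the whole test set. The plain-Python sketch below assumes `labels` and `preds` are lists of integer class indices, for example gathered with `.tolist()` from each batch of the evaluation loop; `per_class_accuracy` is a hypothetical helper and the example values are made up.

```python
from collections import Counter

CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

def per_class_accuracy(labels, preds, classes=CLASSES):
    """Percentage of correct predictions for each class that appears in `labels`."""
    totals = Counter(labels)  # how many test images belong to each class
    correct = Counter(l for l, p in zip(labels, preds) if l == p)  # correct predictions per class
    return {classes[c]: 100 * correct[c] / totals[c] for c in sorted(totals)}

# Made-up example: two cat images (class 3) and three dog images (class 5)
example_labels = [3, 3, 5, 5, 5]
example_preds = [3, 5, 5, 5, 3]
print(per_class_accuracy(example_labels, example_preds))
```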

To make the predictions more accurate, we can change the number of epochs, 