In [2]:
import torch

# Choose the compute backend: prefer CUDA, then MPS, and fall back to the CPU
# when neither accelerator is reported as available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f'Using device: {device}')

Using device: mps


In [3]:
# side length (in pixels) every image is resized / cropped to
image_resolution = 64

# number of colour channels fed to the model:
# 3 -> RGB, 1 -> greyscale
colour_channels = 3

# root folder of the image dataset; check that this path exists on your machine
folder_path = './data/my_dataset'

In [4]:
from torchvision.transforms.v2 import Compose, ToImage, ToDtype, Grayscale, Resize, RandomCrop, Normalize

# Pre-processing pipeline applied on-the-fly each time an image is loaded.

# Final step of the pipeline: collapse to greyscale when training on
# single-channel images, otherwise pass the image through unchanged.
colour_step = Grayscale() if colour_channels == 1 else torch.nn.Identity()

transformation = Compose([
    # image -> float32 tensor (scale=True rescales the pixel values)
    ToImage(),
    ToDtype(torch.float32, scale=True),

    # resize, then take a random crop of image_resolution pixels
    Resize(image_resolution),
    RandomCrop(image_resolution),

    # normalise pixel values to be between -1 and 1
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

    # greyscale conversion or no-op, chosen above
    colour_step,
])

In [5]:
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

import torch

from IPython.display import display

from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

In [None]:
# fraction of the images that is held out for the testing set,
# e.g. 0.2 means 20% of the images
test_size = 0.2 

# how many images are grouped into one batch (i.e. processed in one training step),
# a good value usually depends on your model / type of data / CPU or GPU's capability
batch_size = 16

In [13]:
# Instantiate train and test datasets. Both read the same folder with the same
# transform; the actual train / test separation is done through the index
# subsets below.
train_dataset = ImageFolder(folder_path, transform=transformation)
test_dataset = ImageFolder(folder_path, transform=transformation)

# Get length of dataset and indices of every sample
num_train = len(train_dataset)
indices = list(range(num_train))

# Get train / test split for data points (fixed random_state keeps the split reproducible)
train_indices, test_indices = train_test_split(indices, test_size=test_size, random_state=42)

# Restrict each dataset to the samples belonging to its split
train_sub_dataset = torch.utils.data.Subset(train_dataset, train_indices)
test_sub_dataset = torch.utils.data.Subset(test_dataset, test_indices)

# Settings shared by both data loaders; worker processes are started with the
# "forkserver" context when running on MPS, and with the default otherwise.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    multiprocessing_context="forkserver" if device == 'mps' else None,
)

# Create training and testing data loaders.
# Only the training data is shuffled: evaluation does not benefit from
# shuffling, and a fixed order makes the sample picked from the test loader
# later in the notebook the same on every run.
# NOTE(review): the shared transform still contains a random crop, so test
# images are not pixel-identical between runs.
train_loader = DataLoader(train_sub_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_sub_dataset, shuffle=False, **loader_kwargs)


In [None]:
# class names as exposed by the dataset (derived from the folder names)
class_names = train_dataset.classes

# report the size of each split and the classes that were found
print(
    f'{len(train_indices)} training images loaded',
    f'{len(test_indices)} testing images loaded',
    f'classes: {class_names}',
    sep='\n',
)


In [None]:
# Pull a single batch from the training loader to sanity-check the tensor shapes.
data, labels = next(iter(train_loader))

print(f'data shape: {data.shape}', f'labels shape: {labels.shape}', sep='\n')

In [None]:
# Tile the batch into a single image, 8 pictures per row.
grid = make_grid(data, nrow=8)

# Shift / scale the pixel values ((x + 1) / 2) before converting to a PIL image for display.
display(to_pil_image((grid + 1) / 2))

# Position in the batch -> class name of that image.
print([f'{i}: {class_names[label_index]}' for i, label_index in enumerate(labels)])

In [17]:
from src.model import ConvNeuralNetwork

In [None]:
# Build the network for the configured channel count, image size and number
# of classes, then move it to the selected device.
model = ConvNeuralNetwork(
    img_channel=colour_channels,
    img_resolution=image_resolution,
    num_classes=len(class_names),
)
model.to(device)

In [19]:
# Training objective: cross entropy between the model outputs and the true classes
loss_function = torch.nn.CrossEntropyLoss()

# Adam (default hyper-parameters) updates every parameter of the model
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
# we can save the model regularly
save_every_n_epoch = 5

# total number of epochs we aim for
num_epochs = 8

# keep track of the losses, we can plot them in the end
train_losses = []
test_losses = []

print('Epoch 0')

for epoch in range(num_epochs):

    #---- Training loop -----------------------------
    train_loss = 0.0
    model.train()

    # Unpack each batch directly in the loop header so the notebook-level
    # `data` variable (the sample batch shown earlier) is not overwritten.
    for i, (inputs, true_labels) in enumerate(train_loader):
        # Load: move the batch of training data and its true class labels to the device.
        inputs = inputs.to(device)
        true_labels = true_labels.to(device)

        # Pass: Forward pass the training data to our model, and get the predicted classes.
        pred_labels = model(inputs)

        # Loss: The loss function compares the predicted classes to the true classes, and calculates the error.
        loss = loss_function(pred_labels, true_labels)
        train_loss += loss.item()

        # Optimise: back-propagate, update the weights, then clear the
        # gradients so they do not accumulate into the next step.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % 50 == 0:
            print(f'  -> Step {i + 1:04}, train loss: {loss.item():.4f}')


    #---- Testing loop -----------------------------
    test_loss = 0.0
    model.eval()

    # No gradients are needed while evaluating.
    with torch.inference_mode():
        for inputs, true_labels in test_loader:
            # Load: move the batch of testing data and its true class labels to the device.
            inputs = inputs.to(device)
            true_labels = true_labels.to(device)

            # Pass: Forward pass the testing data to our model, and get the predicted classes.
            pred_labels = model(inputs)

            # Loss: The loss function compares the predicted classes to the true classes, and calculates the error.
            loss = loss_function(pred_labels, true_labels)
            test_loss += loss.item()


    #---- Report some numbers -----------------------------

    # Average the accumulated per-batch losses over the number of batches in this epoch
    train_loss = train_loss / len(train_loader)
    test_loss = test_loss / len(test_loader)

    # Add the averaged losses to the lists for later display
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    print(f'Epoch {epoch + 1}, train loss: {train_loss:.3f}, test loss: {test_loss:.3f}')

    # save our model every n epoch; the file name uses the same 1-based epoch
    # number as the report above, so the checkpoint written after
    # "Epoch 5" is model_epoch0005.pt
    if (epoch + 1) % save_every_n_epoch == 0:
        torch.save(model.state_dict(), f'model_epoch{epoch + 1:04}.pt')

# save the model at the end of the training process
torch.save(model.state_dict(), 'model_final.pt')

print("training finished, model saved to 'model_final.pt'")

In [21]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Plot the per-epoch training and validation losses collected during training.
fig, ax = plt.subplots(figsize=(5, 3))

ax.plot(train_losses, label='training loss')
ax.plot(test_losses, label='validation loss')

# one tick per recorded epoch
ax.set_xticks(np.arange(len(train_losses)))
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend(loc='upper right')

plt.show()

In [23]:
from src.model import ConvNeuralNetwork

In [None]:
# make sure the parameters are the same as when the model is created
eval_model = ConvNeuralNetwork(img_channel = colour_channels, 
                          img_resolution = image_resolution,
                          num_classes = len(class_names))

# load the saved model, make sure the path is correct
# - map_location="cpu" lets the checkpoint load even on a machine that lacks
#   the device it was saved from; the model is moved to `device` right below
# - weights_only=True restricts unpickling to tensors / plain containers,
#   which is all a state_dict needs
state_dict = torch.load('model_final.pt', map_location="cpu", weights_only=True)
eval_model.load_state_dict(state_dict)

# move to the selected device and switch to evaluation mode
eval_model.to(device)
eval_model.eval() 

In [25]:
# Take one batch from the test loader and look up the class name of its first image.
test_data, test_labels = next(iter(test_loader))

first_label_index = int(test_labels[0])
true_label = class_names[first_label_index]

In [None]:
# Prepare the first test image for matplotlib: shift / scale the values
# ((x + 1) / 2), move the first axis last (assumes the tensor is laid out as
# channels, height, width) and convert to a NumPy array on the CPU.
first_image = test_data[0].add(1).div(2).permute((1,2,0)).cpu().detach().numpy()

fig, ax = plt.subplots(figsize=(3, 3))
ax.set_title(f"True Label: {true_label}")
ax.imshow(first_image)

plt.show()

In [27]:
# Run the loaded model on the sampled test batch; inference mode disables gradient tracking.
with torch.inference_mode():
    test_batch_on_device = test_data.to(device)
    pred_labels = eval_model(test_batch_on_device)

In [None]:
# Plot the model output for the first test image, one value per class.
# NOTE(review): the axis is labelled "Probabilities"; whether the values really
# are probabilities depends on the last layer of ConvNeuralNetwork, which is
# not visible here — confirm.
fig, ax = plt.subplots(figsize=(6, 2))
ax.plot(pred_labels[0].cpu().detach().numpy())

# one tick per class, labelled with the class name
ax.set_xticks(np.arange(len(class_names)))
ax.set_xticklabels(class_names, rotation='vertical')

ax.set_ylabel("Probabilities")
ax.grid(axis='x', color='0.95')
plt.show()

In [None]:
# Index of the largest output value for the first test image = predicted class.
predicted_class_index = torch.max(pred_labels[0], 0).indices

# Same image as before, now titled with both the true and the predicted label.
first_image = test_data[0].add(1).div(2).permute((1,2,0)).cpu().detach().numpy()

fig, ax = plt.subplots(figsize=(3, 3))
ax.set_title(f"True label: {true_label}, predicted label: {class_names[predicted_class_index]}")
ax.imshow(first_image)

plt.show()

In [None]:
# Count correct predictions over the whole test split and report the accuracy.
num_samples = 0
num_correct = 0

with torch.inference_mode():
    # Unpack in the loop header and use a dedicated name for the predictions so
    # the notebook-level `data` / `pred_labels` used by earlier cells survive.
    for inputs, true_labels in test_loader:
        # Load: move the batch of testing data and its true class labels to the device.
        inputs = inputs.to(device)
        true_labels = true_labels.to(device)

        # Pass: forward the batch and take the highest-scoring class per sample.
        predicted_classes = torch.argmax(eval_model(inputs), dim=1)

        # .item() turns the count into a plain Python int, so the totals are
        # printed as numbers rather than as device tensors.
        num_correct += (predicted_classes == true_labels).sum().item()
        num_samples += true_labels.size(0)

# Accuracy is undefined for an empty test set; report NaN instead of dividing by zero.
accuracy = num_correct / num_samples if num_samples else float('nan')
print(f'correct samples: {num_correct}  \ntotal samples: {num_samples}  \nmodel accuracy: {accuracy:.3f}')