# 1. Setup


In [None]:
%matplotlib inline

In [None]:
import os
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

In [None]:
# Let float32 matmuls use faster internal formats where the backend offers them.
torch.set_float32_matmul_precision('high')

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Only query the CUDA device name when CUDA was actually selected; asking for
# it on an MPS/CPU-only machine raises instead of printing.
if device == "cuda":
    print(f"Using {device} device: ", torch.cuda.get_device_name(0))
else:
    print(f"Using {device} device")

In [None]:
if device == "cuda":
    # Desired per-process memory cap on GPU 0, in MB.
    memory_limit_mb = 4095.5
    # Total memory of GPU 0, in bytes.
    total_memory = torch.cuda.get_device_properties(0).total_memory
    # MB -> bytes.
    memory_limit = memory_limit_mb * 1024 ** 2
    # The fraction must lie in [0, 1]; on a GPU smaller than the cap the raw
    # ratio exceeds 1 and the call below would raise, so clamp it.
    memory_fraction = min(memory_limit / total_memory, 1.0)
    torch.cuda.set_per_process_memory_fraction(memory_fraction, device=0)
    print(f"Set GPU 0 memory fraction to {memory_fraction:.2%}")

# 2. Load & transform data


In [None]:
# Statistics-pass pipeline: resize and convert to scaled float32 tensors,
# deliberately without Normalize so the raw per-channel mean/std can be measured.
_stat_steps = [
    v2.Resize(size=(256, 256)),
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32, scale=True),
]
initial_transforms = v2.Compose(_stat_steps)

In [None]:
def _to_one_hot(y):
    """Turn an integer class index into a length-54 float one-hot vector."""
    target = torch.zeros(54, dtype=torch.float)
    position = torch.tensor(y, dtype=torch.long)
    return target.scatter_(dim=0, index=position, value=1)


label_transforms = v2.Lambda(_to_one_hot)

In [None]:
# Temporary dataset/loader used only to measure the per-channel mean and std.
temp_dataset = ImageFolder(root='../data/data_pool', transform=initial_transforms, target_transform=label_transforms)
# Pinned host memory only helps host->CUDA copies. Forcing it (and passing the
# device string as pin_memory_device) on "cpu"/"mps" is at best a no-op and at
# worst an error, so enable it only when CUDA is the selected device.
temp_loader = DataLoader(
    temp_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=(device == "cuda"),
)

In [None]:
# Running per-channel sums of the batch means of x and x^2, each weighted by
# the batch's sample count. Averaging the per-batch means without weights would
# give a short final batch the same influence as a full one and bias the result.
channels_sum = torch.zeros(3, device=device)
channels_sqrd_sum = torch.zeros(3, device=device)
num_samples = 0

# No gradients are needed for a pure statistics pass.
with torch.no_grad():
    for data, _ in temp_loader:
        data = data.to(device)  # Move data to the specified device
        batch_size = data.size(0)
        channels_sum += torch.mean(data, dim=[0, 2, 3]) * batch_size
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * batch_size
        num_samples += batch_size

# Compute final mean and standard deviation: Var[x] = E[x^2] - E[x]^2.
mean = channels_sum / num_samples
# Clamp tiny negative values caused by floating-point round-off so sqrt cannot
# produce NaN for (near-)constant channels.
variance = torch.clamp(channels_sqrd_sum / num_samples - mean ** 2, min=0.0)
std = torch.sqrt(variance)

# Move mean and std to CPU and convert to plain lists for v2.Normalize
mean = mean.cpu().tolist()
std = std.cpu().tolist()

print(f"Dataset mean: {mean}")
print(f"Dataset std: {std}")

In [None]:
# Training pipeline: resize, light random augmentation (flip, small rotation,
# small translation), then tensor conversion and dataset-specific normalization.
_train_steps = [
    v2.Resize(size=(256, 256)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=10),
    v2.RandomAffine(degrees=0, translate=(0.05, 0.05)),
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std),
]
train_transforms = v2.Compose(_train_steps)

In [None]:
# Evaluation pipeline: same resize / conversion / normalization as training,
# but with no random augmentation.
_val_steps = [
    v2.Resize(size=(256, 256)),
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std),
]
val_transforms = v2.Compose(_val_steps)

In [None]:
# Full image pool with the augmenting training pipeline attached.
train_dataset = ImageFolder(
    root='../data/data_pool',
    transform=train_transforms,
)

# 3. Split data


In [None]:
# Fractions of the data pool assigned to train / validation / test.
train_ratio, val_ratio, test_ratio = 0.75, 0.15, 0.10

In [None]:
# Turn the ratios into sample counts; the test split absorbs the rounding
# remainder so the three sizes always add up to the full dataset.
total_size = len(train_dataset)
train_size, val_size = (int(ratio * total_size) for ratio in (train_ratio, val_ratio))
test_size = total_size - (train_size + val_size)

In [None]:
# Split the sample indices once, with a fixed seed for reproducibility.
train_data, val_data, test_data = random_split(
    train_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# All three subsets returned by random_split wrap the SAME underlying dataset
# object, so assigning `subset.dataset.transform` would also strip the
# augmentation from the training subset. Instead, open a second view of the
# same folder with the evaluation transforms and re-point the validation/test
# indices at it. ImageFolder enumerates files in a deterministic order, so the
# indices refer to the same samples in both views.
eval_dataset = ImageFolder(root='../data/data_pool', transform=val_transforms)
val_data = torch.utils.data.Subset(eval_dataset, val_data.indices)
test_data = torch.utils.data.Subset(eval_dataset, test_data.indices)

# Pinned host memory only helps host->CUDA copies; enable it only for CUDA.
use_pinned_memory = device == "cuda"

# Create DataLoaders (only the training loader shuffles)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4, pin_memory=use_pinned_memory)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False, num_workers=4, pin_memory=use_pinned_memory)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4, pin_memory=use_pinned_memory)

# 4. Building the neural network


In [None]:
# Expected network input: square RGB images.
IMG_WIDTH = IMG_HEIGHT = 256
IMG_CHANNELS = 3

# L2 regularisation strength (passed to the optimizer as weight decay).
REG_FACTOR = 0.0001

In [None]:
class ConvolutionalNeuralNetwork(nn.Module):
    """CNN classifier: three conv blocks followed by a fully connected head.

    Each conv block is (Conv3x3 -> BatchNorm -> ReLU) x 2 followed by a 2x2
    max-pool, so the spatial size is divided by 8 overall. The head is five
    Linear -> BatchNorm -> ReLU -> Dropout(0.3) stages and a final Linear layer
    producing raw logits (no softmax).

    Args:
        num_classes: Number of output logits. Defaults to 54.
        in_channels: Number of input image channels. Defaults to 3.
        img_height: Input image height. Defaults to the module-level
            ``IMG_HEIGHT`` constant when omitted.
        img_width: Input image width. Defaults to the module-level
            ``IMG_WIDTH`` constant when omitted.

    The sub-module names and layer order match the original hard-coded
    definition, so existing state dicts load unchanged.
    """

    def __init__(self, num_classes=54, in_channels=3, img_height=None, img_width=None):
        super().__init__()

        # Fall back to the notebook-level constants only when no size is given.
        if img_height is None:
            img_height = IMG_HEIGHT
        if img_width is None:
            img_width = IMG_WIDTH

        self.conv_block1 = self._conv_block(in_channels, 16)
        self.conv_block2 = self._conv_block(16, 32)
        self.conv_block3 = self._conv_block(32, 64)

        self.flatten = nn.Flatten()

        # Three 2x2 poolings shrink each spatial dimension by a factor of 8.
        flat_features = 64 * (img_height // 8) * (img_width // 8)
        widths = [flat_features, 256, 128, 128, 64, 64]
        head = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            head += [
                # Bias is redundant directly before BatchNorm.
                nn.Linear(fan_in, fan_out, bias=False),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ]
        head.append(nn.Linear(widths[-1], num_classes))
        self.fc_layers = nn.Sequential(*head)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build one (Conv3x3 -> BN -> ReLU) x 2 + MaxPool(2) block."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.flatten(x)
        x = self.fc_layers(x)
        return x

In [None]:
# Build the network, move it to the selected device, then wrap it with
# torch.compile.
model = torch.compile(ConvolutionalNeuralNetwork().to(device))

# 5. Optimising the model parameters

In [None]:
# Optimizer step size and the maximum number of training epochs.
learning_rate, epochs = 5e-4, 50

In [None]:
# Multi-class cross-entropy on raw logits with integer class-index targets.
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# Adam over all model parameters, with REG_FACTOR as weight decay.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=REG_FACTOR,
)

# 6. Train the model

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one epoch and return ``(avg_loss, accuracy)``.

    Args:
        dataloader: Iterable of ``(inputs, class_index_targets)`` batches that
            also exposes ``.dataset`` (used for the progress print-out).
        model: Network to optimise; batches are moved to the device of its
            first parameter.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            mean loss for the batch.
        optimizer: Optimizer already bound to ``model``'s parameters.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the sample-weighted mean loss and the
        fraction of correct top-1 predictions over the epoch.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    size = len(dataloader.dataset)
    # Training mode matters for layers such as BatchNorm and Dropout, which
    # this notebook's CNN uses: it enables batch statistics and dropout.
    model.train()

    # Follow the model's own device instead of a notebook-level global so the
    # function keeps working regardless of cell execution order.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    training_loss = 0.0
    correct = 0
    total = 0

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(run_device), y.to(run_device)

        optimizer.zero_grad()

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # loss is a batch mean; re-weight by batch size so the epoch average
        # is exact even when the last batch is smaller.
        loss_item = loss.item()
        training_loss += loss_item * X.size(0)
        correct += (pred.argmax(1) == y).sum().item()
        total += y.size(0)

        if batch % 100 == 0:
            # `total` is the true number of samples seen so far (batch * len(X)
            # is wrong whenever the current batch is a short one).
            print(f"loss: {loss_item:>7f}  [{total:>5d}/{size:>5d}]")

    if total == 0:
        raise ValueError("train_loop received an empty dataloader")

    avg_loss = training_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)  # Move data to device
            pred = model(X)
            loss = loss_fn(pred, y)
            test_loss += loss.item() * X.size(0)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            total += y.size(0)

    avg_loss = test_loss / total
    accuracy = correct / total
    print(f"Avg loss: {avg_loss:>8f}, Accuracy: {(100*accuracy):>0.1f}%\n")

    return avg_loss, accuracy

In [None]:
# Per-epoch history, filled in by the training loop.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
epoch_times = []

# Early-stopping bookkeeping: stop after `patience` consecutive epochs without
# a new best validation loss.
patience = 5
best_val_loss = float('inf')
epochs_no_improve = 0
stopped_early = False

# Wall-clock reference for the total training time.
total_start_time = time.time()

In [None]:
# Make sure the checkpoint directory exists up front; otherwise the first
# torch.save would fail only after a full epoch of training.
os.makedirs('../data/models', exist_ok=True)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    epoch_start_time = time.time()

    train_loss, train_accuracy = train_loop(train_loader, model, loss_fn, optimizer)
    val_loss, val_accuracy = test_loop(val_loader, model, loss_fn)

    epoch_time = time.time() - epoch_start_time
    epoch_times.append(epoch_time)

    # Record the history used by the plotting cells.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds")
    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n")

    # Early stopping on validation loss: checkpoint every new best, otherwise
    # count the stale epoch and stop once `patience` is exhausted.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), '../data/models/best_model.pth')
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {patience} epochs with no improvement.")
            stopped_early = True
            break

total_training_time = time.time() - total_start_time
# Also keep the last-epoch model (pickled as a whole module).
torch.save(model, '../data/models/full_model.pth')

print(f"\nTraining complete in {total_training_time:.2f} seconds")
print("\n-------------------------------\nDone!")

# 7. Plot model metrics

In [None]:
epochs_range = range(1, len(train_losses) + 1)

# One entry per figure: ([(values, legend label), ...], y-axis label, title).
_figure_specs = [
    (
        [(train_losses, 'Training Loss'), (val_losses, 'Validation Loss')],
        'Loss',
        'Training and Validation Loss',
    ),
    (
        [(train_accuracies, 'Training Accuracy'), (val_accuracies, 'Validation Accuracy')],
        'Accuracy',
        'Training and Validation Accuracy',
    ),
    (
        [(epoch_times, 'Time per Epoch')],
        'Time (seconds)',
        'Time Taken per Epoch',
    ),
]

# Draw the loss, accuracy and epoch-time figures with identical styling.
for _curves, _y_label, _heading in _figure_specs:
    plt.figure(figsize=(10, 5))
    for _values, _legend_label in _curves:
        plt.plot(epochs_range, _values, label=_legend_label)
    plt.xlabel('Epochs')
    plt.ylabel(_y_label)
    plt.title(_heading)
    plt.legend()
    plt.show()

# 8. Test the model

In [None]:
# Rebuild the architecture and wrap it exactly as during training, so the
# checkpoint's parameter names line up with the compiled wrapper's.
model = torch.compile(ConvolutionalNeuralNetwork().to(device))

# Always evaluate the best-validation checkpoint. The training loop writes it
# on every improvement, not only when early stopping fires, whereas the
# last-epoch model may be several epochs past the best one. Loading a state
# dict also avoids unpickling a whole module and works with weights_only=True.
best_state = torch.load(
    '../data/models/best_model.pth',
    map_location=device,
    weights_only=True,
)
model.load_state_dict(best_state)

model.eval()

print("Test Results on the Test Set:")
test_loop(test_loader, model, loss_fn)

In [None]:
# Single-image inference pipeline: the same deterministic steps as validation
# (resize, tensor conversion, dataset normalization), with no augmentation.
_inference_steps = [
    v2.Resize(size=(256, 256)),
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std),
]
image_transform = v2.Compose(_inference_steps)

In [None]:
# Read the image as RGB, preprocess it, add a leading batch dimension and move
# it to the selected device.
image_path = 'test_image.jpg'
image = Image.open(image_path).convert('RGB')
input_tensor = image_transform(image).unsqueeze(0).to(device)

In [None]:
# Forward pass without gradient tracking; the predicted class is the index of
# the largest logit.
with torch.no_grad():
    output = model(input_tensor)
predicted_class = torch.argmax(output, dim=1).item()

In [None]:
# Map the predicted index back to its class name; the dataset's `classes`
# list is indexed by class index.
class_names = train_dataset.classes
predicted_label = class_names[predicted_class]
print(
    f"Predicted class: {predicted_label}"
)