# 1. Setup & Dependencies

In [None]:
# Install requirements.tx

!pip install -r requirements.txt

In [None]:
!pip install --upgrade torch torchvision

In [None]:
# Show the version of the `python` executable that the shell resolves
!python --version

In [None]:
# Check the CUDA version by asking the nvcc command for its version string

!nvcc --version

## 1.1. Import statements

In [None]:
# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
import os

# Select 'cudaMallocAsync' as the PyTorch CUDA allocator backend.
# This cell is placed ahead of the one that imports torch.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='backend:cudaMallocAsync')

In [None]:
import torch
import torch.nn as nn
import time
import gc
import matplotlib.pyplot as plt
from torchinfo import summary
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from torchvision.transforms import v2
from torchvision.datasets import ImageFolder

## 1.2. Device configuration

In [None]:
def get_device(memory_limit_mb: float = 4095.5) -> torch.device:
    """Get the best available device for PyTorch.

    Preference order is CUDA, then Apple Silicon MPS, then CPU. When CUDA is
    selected, TF32 is enabled for matmul and cuDNN, and this process's share
    of GPU 0 memory is capped at ``memory_limit_mb``.

    Args:
        memory_limit_mb: Per-process GPU memory cap in MB (CUDA only). A cap
            larger than the card's total memory is clamped to the whole card.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        device = "cuda"

        # The flag below controls whether to allow TF32 on matmul.
        torch.backends.cuda.matmul.allow_tf32 = True
        # The flag below controls whether to allow TF32 on cuDNN.
        torch.backends.cudnn.allow_tf32 = True

        # Print GPU info
        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")

        # Fetching total memory of GPU
        total_memory = torch.cuda.get_device_properties(0).total_memory
        print(f"Total GPU memory: {total_memory / 1024**2:.2f} MB")

        # Convert the cap to a fraction of the card. The fraction must stay
        # within [0, 1]: set_per_process_memory_fraction rejects anything else,
        # which previously broke on GPUs smaller than the requested cap.
        memory_limit = memory_limit_mb * 1024 ** 2
        memory_fraction = min(max(memory_limit / total_memory, 0.0), 1.0)
        torch.cuda.set_per_process_memory_fraction(memory_fraction, device=0)

        print(f"Set GPU memory fraction to {memory_fraction:.2%}")

        # Empty cache to measure memory usage
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        device = "mps"
        print("Using Apple Silicon MPS device")
    else:
        device = "cpu"
        print("Using CPU device")

    return torch.device(device)

In [None]:
# GPU optimization: request the 'high' float32 matmul precision setting
torch.set_float32_matmul_precision('high')

# Set device (get_device also prints which device was chosen)
device = get_device()

In [None]:
# Pinned memory is only requested when the selected device is CUDA
pin_memory = device.type == 'cuda'
pin_memory_device = 'cuda' if pin_memory else ''

# 2. Load & transform data

In [None]:
# Input geometry shared by the resize transforms, the model head and the summary call
IMG_WIDTH, IMG_HEIGHT = 384, 216
IMG_CHANNELS = 3

## 2.1. Normalize the dataset

In [None]:
# Un-normalized pipeline used only for measuring dataset statistics:
# resize, convert to a tensor image, then scale to float32.
# NOTE(review): torchvision's Resize reads a 2-tuple as (height, width), so this
# produces images IMG_WIDTH tall and IMG_HEIGHT wide — confirm that is intended.
_stats_steps = [
    v2.Resize((IMG_WIDTH, IMG_HEIGHT)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
]
initial_transforms = v2.Compose(_stats_steps)

In [None]:
# Temporary dataset/loader pair, used only to measure per-channel mean and std
temp_dataset = ImageFolder(
    root='../data/merged_pool',
    transform=initial_transforms,
)

temp_loader = DataLoader(
    temp_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=6,
    pin_memory=pin_memory,
    pin_memory_device=pin_memory_device,
)

In [None]:
# Set to False to skip the pass over the dataset and use the cached statistics below
calc_stats = True

if calc_stats:
    print(f"Computing dataset statistics using device: {device}")
    print(f"Number of images to process: {len(temp_dataset)}")

    # Per-channel running totals of E[x] and E[x^2], each weighted by batch size
    channels_sum = torch.zeros(IMG_CHANNELS, device=device)
    channels_sqrd_sum = torch.zeros(IMG_CHANNELS, device=device)
    num_batches = 0
    num_images = 0
    start_time = time.time()

    # Iterating over the data loader with a progress bar
    for batch_idx, (data, _) in enumerate(tqdm(temp_loader, desc="Computing mean/std")):
        # Move data to the specified device
        data = data.to(device, non_blocking=True)
        batch_images = data.size(0)

        # Weight each batch mean by its image count so a smaller final batch
        # does not count as much as a full one (every image has the same size
        # after the resize transform). The reductions run in plain float32;
        # autocast is not used so the statistics keep full precision.
        channels_sum += torch.mean(data, dim=[0, 2, 3]) * batch_images
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * batch_images
        num_batches += 1
        num_images += batch_images

        if batch_idx % 10 == 0:
            batch_time = time.time() - start_time
            # Report the images actually processed so far, including this batch
            print(f"\nProcessed {num_images} images in {batch_time:.2f}s")

    # Calculate overall mean and standard deviation: Var[x] = E[x^2] - E[x]^2
    mean = channels_sum / num_images
    variance = (channels_sqrd_sum / num_images) - (mean ** 2)
    # Rounding can push a near-zero variance slightly negative; clamp to avoid NaN
    std = torch.sqrt(torch.clamp(variance, min=0.0))

    mean = mean.cpu().tolist()
    std = std.cpu().tolist()

    total_time = time.time() - start_time
    print(f"\nTotal processing time: {total_time:.2f} seconds")
    print(f"Dataset mean: {mean}")
    print(f"Dataset std: {std}")

    # Release memory
    del temp_dataset, temp_loader, channels_sum, channels_sqrd_sum, variance, data
    gc.collect()

    torch.cuda.empty_cache()
else:
    # Statistics recorded from an earlier run
    mean=[0.5479778051376343, 0.526210367679596, 0.4944702088832855]
    std=[0.2237844169139862, 0.23763211071491241, 0.26044926047325134]

## 2.2. Define transforms

In [None]:
# Transform and data augmentation for training dataset.
# The resize target reuses the shared image constants instead of repeating the literals.
# NOTE(review): torchvision's Resize reads a 2-tuple as (height, width), so this
# produces images IMG_WIDTH tall and IMG_HEIGHT wide — confirm that is intended.
train_transforms = v2.Compose([
    v2.Resize((IMG_WIDTH, IMG_HEIGHT)),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomRotation(degrees=10),
    v2.ColorJitter(brightness=0.1, contrast=0.1),
    v2.RandomPerspective(distortion_scale=0.3, p=0.3),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std)
])

In [None]:
# Transform for validation dataset (no augmentation).
# The resize target reuses the shared image constants instead of repeating the literals.
# NOTE(review): torchvision's Resize reads a 2-tuple as (height, width), so this
# produces images IMG_WIDTH tall and IMG_HEIGHT wide — confirm that is intended.
val_transforms = v2.Compose([
    v2.Resize((IMG_WIDTH, IMG_HEIGHT)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std),
])

# 3. Partition dataset

In [None]:
# Build the two ImageFolder datasets, each paired with its own transform pipeline
train_data = ImageFolder(
    root='../data/merged_pool',
    transform=train_transforms,
)

val_data = ImageFolder(
    root='../data/validation_pool',
    transform=val_transforms,
)

In [None]:
# DataLoaders: the training set is shuffled, the validation set keeps its order
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=6,
    pin_memory=pin_memory,
    pin_memory_device=pin_memory_device,
)

val_loader = DataLoader(
    val_data,
    batch_size=64,
    shuffle=False,
    num_workers=6,
    pin_memory=pin_memory,
    pin_memory_device=pin_memory_device,
)

# 4. Building the convolutional neural network

In [None]:
class UnoSymbolClassifier(nn.Module):
    """Convolutional classifier: four conv blocks followed by a fully connected head.

    Every conv block ends in a 2x2 max-pool, so the spatial size is halved four
    times; the first linear layer is therefore sized from ``IMG_HEIGHT // 16``
    and ``IMG_WIDTH // 16``.

    Args:
        num_classes: Number of output logits (defaults to 54).
    """

    def __init__(self, num_classes: int = 54) -> None:
        super().__init__()

        # Convolutional feature extractor: (in, mid, out) channels per block
        self.conv1_block = self._conv_block(3, 8, 16)
        self.conv2_block = self._conv_block(16, 16, 32)
        self.conv3_block = self._conv_block(32, 32, 64)
        self.conv4_block = self._conv_block(64, 64, 128)

        # Flatten layer to convert 2D feature maps to 1D feature vectors
        self.flatten = nn.Flatten()

        # Fully connected head
        flat_features = 128 * (IMG_HEIGHT // 16) * (IMG_WIDTH // 16)
        self.fc1 = self._fc_block(flat_features, 64)
        self.fc2 = self._fc_block(64, 128)
        self.fc3 = self._fc_block(128, 64)

        # Output layer producing one logit per class
        self.fc4 = nn.Linear(64, num_classes)

    @staticmethod
    def _conv_block(in_channels: int, mid_channels: int, out_channels: int) -> nn.Sequential:
        """Build one 1x1 -> padded 3x3 -> 1x1 conv stack ending in a 2x2 max-pool."""
        return nn.Sequential(
            # 1x1 convolution mapping in_channels to mid_channels
            nn.Conv2d(in_channels, mid_channels, 1, 1),
            # Batch normalization for stability
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            # Reflection padding so the 3x3 convolution keeps the spatial size
            nn.ReflectionPad2d(1),
            # 3x3 convolution to capture more spatial features
            nn.Conv2d(mid_channels, out_channels, 3, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # Max pooling to reduce spatial dimensions
            nn.MaxPool2d(2, 2),
        )

    @staticmethod
    def _fc_block(in_features: int, out_features: int, dropout: float = 0.2) -> nn.Sequential:
        """Build one Linear -> BatchNorm1d -> ReLU -> Dropout stack."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            # Dropout for regularization
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv blocks and the fully connected head; returns the class logits."""
        x = self.conv1_block(x)
        x = self.conv2_block(x)
        x = self.conv3_block(x)
        x = self.conv4_block(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)

        return x

## 4.2. Create the model

In [None]:
# Create an instance of the model and keep it on the CPU for now.
# As the cell's last expression, the module returned by .to() is shown as the cell output.
model = UnoSymbolClassifier()
model.to("cpu")

In [None]:
# Summarize the model for a batch of 64 inputs shaped (IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
summary(model, input_size=(64, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH))

In [None]:
# Place the model on the selected device, then wrap it with torch.compile (inductor backend)
model = torch.compile(model.to(device, non_blocking=True), backend="inductor")

# Ask the CUDA caching allocator to release its unused cached memory
torch.cuda.empty_cache()

# 5. Optimising model parameters

## 5.1. Learning parameters

In [None]:
# Tunable training settings: optimizer learning rate, optimizer weight decay, epoch count
LEARNING_RATE, WEIGHT_DECAY, EPOCHS = 6e-4, 1e-6, 30

## 5.2. Optimizer & cost function

In [None]:
# Cross-entropy loss applied to the model's output and the integer class labels
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Use the Adam optimizer for training the model, with weight decay for regularization.
# The fused implementation is requested only on CUDA: get_device() can also return
# cpu or mps, where fused Adam may be unavailable depending on the PyTorch version.
# Passing None elsewhere lets PyTorch pick its default implementation.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    fused=True if device.type == 'cuda' else None,
)

In [None]:
# Learning-rate scheduler: multiplies the LR by 0.25 once the metric handed to
# scheduler.step() has stopped decreasing (mode='min') for longer than `patience` steps
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.25,
    patience=5,
)

# 6. Training the model

## 6.1. Define train function

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    # Initialize gradient scaler for mixed precision
    scaler = GradScaler()

    # Total number of samples in the dataset
    size = len(dataloader.dataset)
    
    # Set the model to training mode
    model.train()

    training_loss = 0.0
    correct = 0
    total = 0

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)

        # Reset gradients
        optimizer.zero_grad()

        with autocast(device.type):
            # Forward pass to calculate predictions
            pred = model(X)
            # Calculate loss
            loss = loss_fn(pred, y)

        # Backpropagation
        scaler.scale(loss).backward()
        # Update model parameters
        scaler.step(optimizer)
        scaler.update()

        training_loss += loss.item() * X.size(0)
        
        # Count correct predictions
        correct += (pred.argmax(1) == y).type(torch.float32).sum().item()
        total += y.size(0)

        if batch % 100 == 0:
            loss_item = loss.item()
            current = batch * len(X)
            print(f"loss: {loss_item:>7f}  [{current:>5d}/{size:>5d}]")

    avg_loss = training_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

## 6.2. Define validate & test function

In [None]:
def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode
    model.eval()

    test_loss = 0.0
    correct = 0
    total = 0

    # Disable gradient computation for evaluation to reduce memory usage and speed up computations
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)  # Move data to device

            with autocast(device.type):
                pred = model(X)
                loss = loss_fn(pred, y)

            test_loss += loss.item() * X.size(0)
            correct += (pred.argmax(1) == y).type(torch.float32).sum().item()
            total += y.size(0)

    avg_loss = test_loss / total
    accuracy = correct / total
    print(f"Avg loss: {avg_loss:>8f}, Accuracy: {(100*accuracy):>0.1f}%\n")

    return avg_loss, accuracy

## 6.3. Define overfitting function

In [None]:
def check_overfitting(train_loss, val_loss, train_acc, val_acc, threshold=0.1):
    """Report whether the train/validation metrics look like overfitting.

    Overfitting is flagged when the absolute loss gap exceeds ``threshold`` AND
    training accuracy exceeds validation accuracy by more than ``threshold``.
    A warning with both gaps is printed when flagged.

    Returns:
        True when flagged, otherwise False.
    """
    loss_gap = abs(train_loss - val_loss)
    acc_gap = abs(train_acc - val_acc)

    # Guard clause: both criteria must hold for the epoch to be flagged
    if not (loss_gap > threshold and train_acc > val_acc + threshold):
        return False

    print("Warning: Possible overfitting detected")
    print(f"Loss gap: {loss_gap:.4f}, Accuracy gap: {acc_gap:.4f}")
    return True

## 6.4. Training loop

In [None]:
# Per-epoch metric histories, appended to once per epoch by the training loop
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
epoch_times = []

best_model_metrics = None  # metrics dict recorded whenever the "best" checkpoint is saved
stopped_early = False  # set to True when the loop breaks on persistent overfitting
is_overfitting = 0  # number of consecutive epochs flagged by check_overfitting
# NOTE(review): patience (300) is larger than EPOCHS (30), so the early-stopping
# check in the training loop can never trigger — confirm this is intended.
patience = 300
total_start_time = time.time()

In [None]:
torch.cuda.empty_cache()

# Checkpoint locations. The directory is created up front so torch.save cannot
# fail on a missing folder after an epoch of training has already been spent.
best_model_path = '../data/models/best_symbol_classifier.pth'
full_model_path = '../data/models/full_symbol_classifier.pth'
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}\n-------------------------------")
    epoch_start_time = time.time()

    # Execute the training loop and retrieve training loss and accuracy
    train_loss, train_accuracy = train_loop(train_loader, model, loss_fn, optimizer)

    # Execute the validation loop and retrieve validation loss and accuracy
    val_loss, val_accuracy = test_loop(val_loader, model, loss_fn)

    # The plateau scheduler is meant to react to validation loss, so it is
    # stepped after validation (previously it was fed the training loss).
    scheduler.step(val_loss)

    epoch_time = time.time() - epoch_start_time
    epoch_times.append(epoch_time)

    # Append training and validation metrics to their respective lists
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds")
    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n")

    # Increment overfitting counter if overfitting detected, otherwise reset it
    if check_overfitting(train_loss, val_loss, train_accuracy, val_accuracy):
        is_overfitting += 1
    else:
        is_overfitting = 0

    # Eligibility conditions for a "best" checkpoint
    accuracy_gap = abs(train_accuracy - val_accuracy)
    conditions_met = (
        train_loss >= 0.1 and
        val_loss >= 0.1 and
        train_accuracy <= 0.975 and
        val_accuracy <= 0.975 and
        accuracy_gap <= 0.07
    )

    # Save model only if the conditions are met AND validation loss improved on
    # the previously saved best (previously any qualifying epoch overwrote it).
    val_improved = (
        best_model_metrics is None or
        val_loss < best_model_metrics['val_loss']
    )
    if conditions_met and val_improved:
        best_model_metrics = {
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_accuracy': train_accuracy,
            'val_accuracy': val_accuracy
        }
        # NOTE(review): `model` is wrapped by torch.compile; its state_dict keys
        # may carry an `_orig_mod.` prefix — confirm the loading code expects that.
        torch.save(model.state_dict(), best_model_path)

    # Stop if overfitting persists for multiple epochs
    if is_overfitting >= patience:
        print("Early stopping triggered due to persistent overfitting.")
        stopped_early = True
        break

total_training_time = time.time() - total_start_time

# Save the final-epoch weights alongside the best checkpoint
torch.save(model.state_dict(), full_model_path)

if best_model_metrics:
    print("\nBest model saved with metrics:")
    for key, value in best_model_metrics.items():
        print(f"{key}: {value}")

print(f"\nTraining complete in {total_training_time:.2f} seconds")
print("\n-------------------------------\nDone!")

# 7. Plot model metrics

In [None]:
# 1-based epoch numbers, one per recorded training loss, used as the x-axis of the plots
epochs_range = range(1, len(train_losses) + 1)

## 7.1. Loss graph

In [None]:
# Training vs. validation loss per epoch
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_range, train_losses, label='Training Loss')
ax.plot(epochs_range, val_losses, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
plt.show()

## 7.2. Accuracy graph

In [None]:
# Training vs. validation accuracy per epoch
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_range, train_accuracies, label='Training Accuracy')
ax.plot(epochs_range, val_accuracies, label='Validation Accuracy')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
ax.set_title('Training and Validation Accuracy')
ax.legend()
plt.show()

## 7.3. Epoch duration graph

In [None]:
# Wall-clock duration of each epoch
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_range, epoch_times, label='Time per Epoch')
ax.set_xlabel('Epochs')
ax.set_ylabel('Time (seconds)')
ax.set_title('Time Taken per Epoch')
ax.legend()
plt.show()