# 1. Setup

## 1.1. Import statements

In [None]:
%matplotlib inline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
from torchinfo import summary
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torchvision.transforms import v2
from torchvision.datasets import ImageFolder

## 1.2. Device configuration

In [None]:
def get_device() -> torch.device:
    """Select the best available PyTorch device and configure it.

    Preference order is CUDA, then Apple MPS, then CPU. When CUDA is chosen,
    TF32 is enabled for matmul and cuDNN, and this process's share of GPU 0
    memory is capped (see ``memory_limit_mb`` below).

    Returns:
        torch.device: The selected device ("cuda", "mps" or "cpu").
    """
    if torch.cuda.is_available():
        device = "cuda"

        # The flag below controls whether to allow TF32 on matmul.
        torch.backends.cuda.matmul.allow_tf32 = True
        # The flag below controls whether to allow TF32 on cuDNN.
        # The attribute is `allow_tf32`; the previous spelling `allow_tf16`
        # does not exist and the assignment was silently ignored.
        torch.backends.cudnn.allow_tf32 = True

        # Print GPU info
        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")

        # Set up GPU memory management
        memory_limit_mb = 4095.5  # Adjust as needed

        total_memory = torch.cuda.get_device_properties(0).total_memory

        memory_limit = memory_limit_mb * 1024 ** 2
        # Clamp to 1.0: on a GPU smaller than the cap the raw ratio exceeds 1,
        # which set_per_process_memory_fraction rejects.
        memory_fraction = min(1.0, memory_limit / total_memory)

        torch.cuda.set_per_process_memory_fraction(memory_fraction, device=0)

        print(f"Set GPU memory fraction to {memory_fraction:.2%}")
    elif torch.backends.mps.is_available():
        device = "mps"
        print("Using Apple Silicon MPS device")
    else:
        device = "cpu"
        print("Using CPU device")

    return torch.device(device)

In [None]:
# Allow float32 matrix multiplications to trade some precision for speed
# (PyTorch's 'high' setting) before any model code runs.
torch.set_float32_matmul_precision('high')
# Pick CUDA, MPS or CPU once; later cells reuse this `device`.
device = get_device()

In [None]:
# DataLoader pinning options: page-locked host buffers are only requested when
# batches will be copied to a CUDA device; otherwise pinning is disabled.
pin_memory = device.type == 'cuda'
pin_memory_device = 'cuda' if pin_memory else ''

# 2. Load & transform data

## 2.1. Compute normalization statistics

In [None]:
# First, create transforms without normalization to calculate dataset statistics.
# Images are converted to float32 (not float16): the mean and especially the
# mean of squares are measured from these tensors, and half precision rounds
# every per-batch result to about three significant digits.
initial_transforms = v2.Compose([
    v2.Resize((256, 256)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

# One-hot encode an integer class index into a length-54 float vector.
label_transforms = v2.Lambda(
    lambda y: torch.zeros(54, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y, dtype=torch.long), value=1)
)

# Create temporary dataset to calculate mean and std
temp_dataset = ImageFolder(root='../data/new_pool', transform=initial_transforms, target_transform=label_transforms)

# shuffle=False: ordering is irrelevant when accumulating statistics.
temp_loader = DataLoader(temp_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=pin_memory, pin_memory_device=pin_memory_device)

In [None]:
# Set to False to skip the full pass over the dataset and reuse the
# previously measured statistics in the `else` branch.
calc_stats = True

if calc_stats:
    print(f"Computing dataset statistics using device: {device}")
    print(f"Number of images to process: {len(temp_dataset)}")

    # Running per-channel sums of E[x] and E[x^2], weighted by images seen.
    channels_sum = torch.zeros(3, device=device)
    channels_sqrd_sum = torch.zeros(3, device=device)
    num_images = 0
    start_time = time.time()

    for batch_idx, (data, _) in enumerate(tqdm(temp_loader, desc="Computing mean/std")):
        # Accumulate in float32 regardless of the loader's dtype.
        data = data.to(device, non_blocking=True).float()
        images_in_batch = data.size(0)

        # Weight each batch by its size so a short final batch does not count
        # as much as a full one (all images share the same spatial size).
        channels_sum += torch.mean(data, dim=[0, 2, 3]) * images_in_batch
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * images_in_batch
        num_images += images_in_batch

        if batch_idx % 10 == 0:
            batch_time = time.time() - start_time
            print(f"\nProcessed {num_images} images in {batch_time:.2f}s")

    mean = channels_sum / num_images
    # Var[x] = E[x^2] - E[x]^2; clamp so rounding can never push it below zero
    # and turn the square root into NaN.
    std = torch.sqrt(torch.clamp(channels_sqrd_sum / num_images - mean ** 2, min=0.0))

    # Plain Python lists are what v2.Normalize is given later on.
    mean = mean.cpu().tolist()
    std = std.cpu().tolist()

    total_time = time.time() - start_time
    print(f"\nTotal processing time: {total_time:.2f} seconds")
    print(f"Dataset mean: {mean}")
    print(f"Dataset std: {std}")
else:
    # Statistics recorded from an earlier run.
    mean=[0.36145609617233276, 0.3521592915058136, 0.3483520746231079]
    std=[0.20100851356983185, 0.21270669996738434, 0.254261314868927]

## 2.2. Define transforms

In [None]:
# Training pipeline: resize, light augmentation on the PIL image, then
# conversion to a normalized float32 tensor. Order matters: the geometric and
# colour augmentations run before ToImage/ToDtype/Normalize.
train_transforms = v2.Compose([
    v2.Resize((256, 256)),
    v2.RandomHorizontalFlip(p=0.2),
    v2.RandomVerticalFlip(p=0.2),
    # New subtle augmentations
    v2.ColorJitter(
        brightness=0.1,  # Subtle brightness changes for different lighting
        contrast=0.1,    # Subtle contrast changes for different cameras/lighting
        saturation=0.1,  # Subtle color saturation changes
    ),
    # p=1.0, so this sharpening is applied to every image (not random).
    v2.RandomAdjustSharpness(sharpness_factor=1.1, p=1.0),  # Subtle sharpness changes
    # Small perspective changes to simulate different card angles
    v2.RandomPerspective(
        distortion_scale=0.15,  # Keep perspective changes minimal
        p=0.2                   # Apply only 20% of the time
    ),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # Uses the per-channel statistics computed (or hard-coded) above.
    v2.Normalize(mean=mean, std=std)
])

In [None]:
# Evaluation pipeline: same resize, sharpening and normalization as training,
# but without the random flips, colour jitter or perspective changes.
val_transforms = v2.Compose([
    v2.Resize((256, 256)),
    # p=1.0, so this sharpening is applied to every image (not random).
    v2.RandomAdjustSharpness(sharpness_factor=1.1, p=1.0),  # Subtle sharpness changes
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std),
])

# 3. Partition dataset

In [None]:
# Load the full image folder without any transform (transform=None); the
# transforms for each split are attached in the partitioning step below.
full_dataset = ImageFolder(root='../data/new_pool/', transform=None)

In [None]:
# Fraction of the dataset assigned to each split
train_ratio, val_ratio, test_ratio = 0.70, 0.20, 0.10

# Split sizes: train and validation are truncated to whole samples and the
# test split takes whatever remains, so the three sizes always sum to the total.
total_size = len(full_dataset)
train_size = int(total_size * train_ratio)
val_size = int(total_size * val_ratio)
test_size = total_size - (train_size + val_size)

In [None]:
# Draw the train/val/test indices once, reproducibly.
train_split, val_split, test_split = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)  # For reproducibility
)

# The three Subsets returned above all wrap the SAME `full_dataset` object, so
# assigning `subset.dataset.transform` three times leaves only the last value
# in effect (training would silently run with the validation transforms).
# Give each split its own ImageFolder, which carries its own transform, and
# reuse the indices drawn above. ImageFolder lists files in sorted order, so
# the indices refer to the same samples in every instance.
train_data = torch.utils.data.Subset(
    ImageFolder(root=full_dataset.root, transform=train_transforms),
    train_split.indices,
)
val_data = torch.utils.data.Subset(
    ImageFolder(root=full_dataset.root, transform=val_transforms),
    val_split.indices,
)
test_data = torch.utils.data.Subset(
    ImageFolder(root=full_dataset.root, transform=val_transforms),
    test_split.indices,
)

# Create DataLoaders (only the training loader shuffles)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4, pin_memory=pin_memory, pin_memory_device=pin_memory_device)

val_loader = DataLoader(val_data, batch_size=64, shuffle=False, num_workers=4, pin_memory=pin_memory, pin_memory_device=pin_memory_device)

test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4, pin_memory=pin_memory, pin_memory_device=pin_memory_device)

# 4. Building the neural network

In [None]:
# Input shape constants (pixels, pixels, colour channels).
# NOTE(review): the data transforms resize images to 256x256, which does not
# match this 360x640 geometry - confirm which size is intended.
IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS = 360, 640, 3

## 4.1. Defining the model class

In [None]:
class ConvolutionalNeuralNetwork(nn.Module):
    """Five-stage CNN classifier (conv -> batch norm -> ReLU -> 2x2 max-pool).

    Channel widths grow 32 -> 64 -> 128 -> 256 -> 512. The feature map is then
    brought to a fixed 11x20 grid, flattened, and passed through a
    1024-unit hidden layer with dropout to produce raw class logits.

    Args:
        in_channels: Number of input image channels. ``None`` (default) uses
            the notebook-level ``IMG_CHANNELS`` constant.
        num_classes: Number of output logits. Defaults to 54 (Uno cards).
    """

    def __init__(self, in_channels=None, num_classes: int = 54) -> None:
        super().__init__()

        if in_channels is None:
            in_channels = IMG_CHANNELS

        self.conv1_block = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.conv2_block = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.conv3_block = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.conv4_block = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.conv5_block = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # fc1 expects exactly 512 * 11 * 20 features, which the conv stack only
        # produces for 360x640 inputs. The data loaders in this notebook feed
        # 256x256 images (8x8 feature maps), which would otherwise fail with a
        # shape mismatch in fc1. This pool brings any feature map to 11x20; it
        # is the identity for maps that are already 11x20, and it has no
        # parameters, so existing state_dicts still load.
        self.feature_pool = nn.AdaptiveAvgPool2d((11, 20))

        self.fc1 = nn.Linear(512 * 11 * 20, 1024)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(1024, num_classes)  # 54 classes for Uno cards by default

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.conv1_block(x)
        x = self.conv2_block(x)
        x = self.conv3_block(x)
        x = self.conv4_block(x)
        x = self.conv5_block(x)
        x = self.feature_pool(x)
        # flatten (rather than view) also accepts non-contiguous feature maps.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

## 4.2. Create the model

In [None]:
# Build the network, move it to the selected device, then compile it.
# Order matters: the module is placed on `device` before torch.compile wraps it.
model = ConvolutionalNeuralNetwork()
model.to(device)
model = torch.compile(model, backend="inductor")
# Averaged copy for stochastic weight averaging; the training loop updates it
# via swa_model.update_parameters(model) once SWA starts.
# NOTE(review): this wraps the compiled module rather than the plain one -
# confirm that copying/saving the compiled wrapper behaves as intended.
swa_model = AveragedModel(model)

In [None]:
# Layer-by-layer summary for a batch of 64 inputs shaped
# (IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH).
# NOTE(review): the data loaders produce 256x256 images, not
# IMG_HEIGHT x IMG_WIDTH - confirm the summary reflects the real input size.
summary(model, input_size=(64, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH))

# 5. Optimising model parameters

## 5.1. Learning parameters

In [None]:
# Initial learning rate handed to the optimizer.
learning_rate = 1e-3
# Passed to the optimizer as `weight_decay`.
REG_FACTOR = 5e-3
# Upper bound on training epochs; early stopping may end training sooner.
epochs = 50

## 5.2. Optimizer & cost function

In [None]:
# Cross-entropy loss, applied directly to the raw logits returned by the model
# (the model's forward() ends with a Linear layer, no softmax).
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Adam with weight decay. The fused kernel is requested only when training on
# CUDA: get_device() may also return "mps" or "cpu", and fused Adam is not
# available for non-CUDA parameters on every PyTorch 2.x release, where
# fused=True raises at construction time.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=REG_FACTOR,
    fused=(device.type == 'cuda'),
)

In [None]:
# Plateau scheduler used before SWA starts: multiplies the learning rate by 0.5
# when the monitored value (mode='min') stops improving, with patience=5.
# The training loop passes it the training loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

In [None]:
# Learning-rate schedule used once SWA is active; it targets swa_lr=1e-3 and
# shares the same optimizer as the plateau scheduler.
swa_scheduler = SWALR(optimizer, swa_lr=1e-3)
# Start SWA after 10 epochs (the training loop compares the 0-based epoch
# index against this value).
swa_start = 10  

# 6. Training the model

## 6.1. Define train function

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one full pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; must expose
            ``.dataset`` (used only for progress logging).
        model: Module to optimise; batches are moved to the device of its
            parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor that is the mean loss over the batch.
        optimizer: Optimizer bound to ``model``'s parameters.

    Returns:
        tuple[float, float]: Sample-weighted average loss and accuracy
        (fraction in ``[0, 1]``) over the epoch.
    """
    size = len(dataloader.dataset)
    # Training mode is required for layers such as BatchNorm and Dropout
    # (both used by the CNN in this notebook) to behave correctly.
    model.train()

    # Follow the model's own device rather than a notebook-level global, so the
    # function works no matter which cells have been run.
    target_device = next(model.parameters(), torch.empty(0)).device

    training_loss = 0.0
    correct = 0
    total = 0

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(target_device, non_blocking=True), y.to(target_device, non_blocking=True)

        optimizer.zero_grad()

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Read the scalar once per batch; each .item() forces a device sync.
        loss_item = loss.item()
        # The loss is a batch mean, so weight it by batch size before summing.
        training_loss += loss_item * X.size(0)
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        total += y.size(0)

        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss_item:>7f}  [{current:>5d}/{size:>5d}]")

    avg_loss = training_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

## 6.2. Define test function

In [None]:
def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)  # Move data to device
            pred = model(X)
            loss = loss_fn(pred, y)
            test_loss += loss.item() * X.size(0)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            total += y.size(0)

    avg_loss = test_loss / total
    accuracy = correct / total
    print(f"Avg loss: {avg_loss:>8f}, Accuracy: {(100*accuracy):>0.1f}%\n")

    return avg_loss, accuracy

## 6.3. Define overfitting check function

In [None]:
def check_overfitting(train_loss, val_loss, train_acc, val_acc, threshold=0.1):
    """Flag an epoch whose train/validation metrics have drifted apart.

    An epoch is flagged when the absolute loss difference exceeds
    ``threshold`` AND training accuracy exceeds validation accuracy by more
    than ``threshold``. A warning with both gaps is printed when flagged.

    Args:
        train_loss: Training loss for the epoch.
        val_loss: Validation loss for the epoch.
        train_acc: Training accuracy (fraction).
        val_acc: Validation accuracy (fraction).
        threshold: Gap both checks must exceed. Defaults to 0.1.

    Returns:
        bool: True when the epoch is flagged, otherwise False.
    """
    loss_delta = abs(train_loss - val_loss)
    acc_delta = abs(train_acc - val_acc)

    # Both conditions must hold; the accuracy check is one-directional
    # (training ahead of validation), the loss check is not.
    flagged = (threshold < loss_delta) and (val_acc + threshold < train_acc)
    if not flagged:
        return flagged

    print("Warning: Possible overfitting detected")
    print(f"Loss gap: {loss_delta:.4f}, Accuracy gap: {acc_delta:.4f}")
    return flagged

## 6.4. Training loop

In [None]:
# Per-epoch history, appended to by the training loop and plotted afterwards.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
epoch_times = []

# Checkpointing state: lowest validation loss seen so far and its metrics.
best_val_loss = float('inf')
best_model_metrics = None

# Early-stopping state. `is_overfitting` is a counter of consecutive epochs
# flagged as overfitting; training stops once it reaches `patience`.
stopped_early = False
is_overfitting = 0
patience = 10

# Wall-clock reference for the total training time.
total_start_time = time.time()

In [None]:
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    epoch_start_time = time.time()

    train_loss, train_accuracy = train_loop(train_loader, model, loss_fn, optimizer)

    # From `swa_start` onward, fold the current weights into the SWA average
    # and follow the SWA learning-rate schedule; before that, the plateau
    # scheduler is driven by the training loss.
    if epoch >= swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step(train_loss)  # Use regular scheduler before SWA

    val_loss, val_accuracy = test_loop(val_loader, model, loss_fn)

    epoch_time = time.time() - epoch_start_time
    epoch_times.append(epoch_time)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds")
    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n")

    # Count CONSECUTIVE overfitting epochs; any clean epoch resets the counter.
    if check_overfitting(train_loss, val_loss, train_accuracy, val_accuracy):
        is_overfitting += 1
    else:
        is_overfitting = 0

    # Checkpoint gate: a checkpoint is only eligible while losses stay above
    # 0.13, accuracies stay below 0.99 and train/val accuracy differ by < 0.03.
    # NOTE(review): presumably this excludes "too good to be true" epochs -
    # confirm the intent of these thresholds.
    accuracy_gap = abs(train_accuracy - val_accuracy)
    conditions_met = (
        train_loss > 0.13 and
        val_loss > 0.13 and
        train_accuracy < 0.99 and
        val_accuracy < 0.99 and
        accuracy_gap < 0.03
    )

    # Save model if conditions are met and validation loss improved
    if conditions_met and val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_metrics = {
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_accuracy': train_accuracy,
            'val_accuracy': val_accuracy
        }
        torch.save(model.state_dict(), '../data/models/best_model.pth')

    # Stop if overfitting persists for multiple epochs
    if is_overfitting >= patience:
        print("Early stopping triggered due to persistent overfitting.")
        stopped_early = True
        break

# SWA only has something to offer if at least one snapshot was averaged. That
# is not guaranteed: early stopping can fire before `swa_start`, or `epochs`
# may not exceed it. In that case skip the BatchNorm re-estimation pass (a full
# sweep over the training set) and do not save an un-averaged copy as "SWA".
swa_was_updated = int(swa_model.n_averaged) > 0
if swa_was_updated:
    update_bn(train_loader, swa_model, device=device)

total_training_time = time.time() - total_start_time

# Save the final model, and the SWA model when it holds averaged weights
torch.save(model, '../data/models/final_model.pth')
if swa_was_updated:
    torch.save(swa_model, '../data/models/final_swa_model.pth')
else:
    print("SWA was never updated during training; skipping SWA model save.")

if best_model_metrics:
    print("\nBest model saved with metrics:")
    for key, value in best_model_metrics.items():
        print(f"{key}: {value}")

print(f"\nTraining complete in {total_training_time:.2f} seconds")
print("\n-------------------------------\nDone!")

# 7. Plot model metrics

In [None]:
# 1-based epoch numbers for the x-axis; sized from the recorded history so it
# stays correct when training stopped early.
epochs_range = range(1, len(train_losses) + 1)

## 7.1. Loss graph

In [None]:
# Plot Losses: training vs. validation loss per epoch
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(epochs_range, train_losses, label='Training Loss')
loss_ax.plot(epochs_range, val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
plt.show()

## 7.2. Accuracy graph

In [None]:
# Training vs. validation accuracy per epoch
acc_fig, acc_ax = plt.subplots(figsize=(10, 5))
acc_ax.plot(epochs_range, train_accuracies, label='Training Accuracy')
acc_ax.plot(epochs_range, val_accuracies, label='Validation Accuracy')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Training and Validation Accuracy')
acc_ax.legend()
plt.show()

## 7.3. Epoch duration graph

In [None]:
# Wall-clock duration of each epoch (training plus validation)
time_fig, time_ax = plt.subplots(figsize=(10, 5))
time_ax.plot(epochs_range, epoch_times, label='Time per Epoch')
time_ax.set_xlabel('Epochs')
time_ax.set_ylabel('Time (seconds)')
time_ax.set_title('Time Taken per Epoch')
time_ax.legend()
plt.show()

# 8. Test the model

## 8.1. Create & load model

In [None]:
# Fresh, untrained network with the same construction steps as in training
# (device placement, then compile). Rebinding `model` discards the trained
# in-memory instance; weights are restored from disk in the next cell.
model = ConvolutionalNeuralNetwork()
model.to(device)
model = torch.compile(model, backend="inductor")

In [None]:
# Reuse the flag set by the training loop instead of overwriting it (resetting
# it to False made the best-checkpoint branch unreachable). Fall back to False
# when this section is run on its own, e.g. in a fresh kernel without training.
stopped_early = globals().get('stopped_early', False)

# map_location makes the checkpoints loadable on whichever device is active now,
# even if they were saved from a different one.
if stopped_early:
    # Best checkpoint written during training (a state_dict).
    model.load_state_dict(torch.load('../data/models/best_model.pth', map_location=device))
else:
    # Final model written at the end of training. This file is a fully pickled
    # module, so it needs weights_only=False; unpickling can execute arbitrary
    # code, so only load files produced by this notebook.
    model = torch.load('../data/models/final_model.pth', map_location=device, weights_only=False)

## 8.2. Test model on test dataset

In [None]:
# Evaluation mode (test_loop also sets this itself).
model.eval()

# Report average loss and accuracy on the held-out test split; the returned
# (loss, accuracy) tuple is also shown as the cell's output.
print("Test Results on the Test Set:")
test_loop(test_loader, model, loss_fn)

## 8.3. Test model on own image

In [None]:
# Preprocessing for a single user-supplied image: resize, sharpen, convert to
# a float32 tensor and normalize with the dataset statistics.
# NOTE(review): this resizes to (360, 640) whereas the train/val/test
# transforms resize to (256, 256) - confirm the inference size matches what
# the loaded model was trained on.
image_transform = v2.Compose([
    v2.Resize((360, 640)),
    # p=1.0, so this sharpening is applied to every image (not random).
    v2.RandomAdjustSharpness(sharpness_factor=1.1, p=1.0),  # Subtle sharpness changes
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std),
])

In [None]:
# Read the image as RGB, preprocess it, add a leading batch dimension and
# move the result to the active device.
image_path = 'Yellow_Draw_2.jpg'
image = Image.open(image_path).convert('RGB')
input_tensor = image_transform(image).unsqueeze(0).to(device)

In [None]:
# Make prediction: run the model without gradient tracking and take the index
# of the largest logit for the single image in the batch.
with torch.no_grad():
    output = model(input_tensor)
    predicted_class = output.argmax(dim=1).item()

In [None]:
# Get class names: map the predicted index back to its class name using the
# `classes` list of the dataset loaded earlier.
class_names = full_dataset.classes
predicted_label = class_names[predicted_class]
print(f"Predicted class: {predicted_label}")