## Imports

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import models, transforms
import pandas as pd
from sklearn.metrics import roc_auc_score
import numpy as np
from dataset import ActionImageDatasetV2

# Define transformations for training set and validation set
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Parameters
batch_size = 32
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_folds = 10

# Root directory for dataset
root_dir = '../../data/RepCount/extracted/train'
actions_csv = './all_action.csv'

# DataFrame to store benchmarking results
results_df = pd.DataFrame(columns=['Model', 'Fold', 'Epoch', 'Val_Loss', 'Val_Accuracy', 'ROC_AUC'])

# Load dataset
dataset = ActionImageDatasetV2(root_directory=root_dir, 
                               action_file_path=actions_csv,
                               transform=train_transforms)

# Define the split proportions
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Split dataset into training and validation sets
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoaders for the current fold
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Calculate class weights based on training dataset
class_counts = np.zeros(dataset.num_actions)

# Iterate through the dataset to get counts for each class
for batch in DataLoader(train_dataset, batch_size=1):
    label = batch['label'].squeeze().numpy()  # Get label as a numpy array
    class_counts += label  # Assuming label is a binary array (multi-hot encoding)

# Compute weights as the inverse of class frequency
class_weights = 1.0 / (class_counts + 1e-6)  # Small epsilon to prevent division by zero
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Define model
model = models.resnet50(pretrained=True)

# Adjust final fully connected layer for multi-label classification
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, dataset.num_actions)
model = model.to(device)

# Define loss function with class weights
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
best_val_acc = 0.0

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        images = batch['rgb'].to(device)
        labels = batch['label'].float().to(device)  # Convert labels to float for BCEWithLogitsLoss

        optimizer.zero_grad()
        outputs = model(images)  # Forward pass
        loss = criterion(outputs, labels)  # Calculate loss
        loss.backward()  # Backward pass
        optimizer.step()  # Optimize weights

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Evaluation on validation set
    model.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in val_loader:
            images = batch['rgb'].to(device)
            labels = batch['label'].float().to(device)

            outputs = model(images)
            val_loss = criterion(outputs, labels)
            val_running_loss += val_loss.item()

            # Apply sigmoid for probabilities in BCEWithLogitsLoss
            predicted_probs = torch.sigmoid(outputs)
            predicted_labels = (predicted_probs >= 0.5).float()  # Threshold to get binary output

            # Calculate accuracy for multi-label: compare each class individually
            correct += (predicted_labels == labels).float().sum().item()
            total += labels.numel()  # Total number of elements across all classes

            # Store true labels and predicted probabilities for ROC AUC calculation
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(predicted_probs.cpu().numpy())

    avg_val_loss = val_running_loss / len(val_loader)
    val_accuracy = correct / total

    # For multi-label ROC AUC (one-vs-rest strategy)
    roc_auc = roc_auc_score(np.array(all_labels), np.array(all_probs), average='macro')

    print(f"Validation ROC AUC: {roc_auc:.4f}")
    print(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

    # Log results
    epoch_result = {
        'Epoch': epoch + 1,
        'Val_Loss': avg_val_loss,
        'Val_Accuracy': val_accuracy,
        'ROC_AUC': roc_auc
    }

    results_df = pd.concat([results_df, pd.DataFrame([epoch_result])], ignore_index=True)

    # Save the best model based on validation accuracy
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), f'./best_resnet50_v4.pth')

print(f"Best Validation Accuracy: {best_val_acc:.4f}")

# Save results to CSV
results_df.to_csv('resnet_benchmark_results.csv', index=False)




Epoch [1/10], Loss: 0.0206
Validation ROC AUC: 0.5471
Validation Loss: 0.0032, Validation Accuracy: 0.9401
Epoch [2/10], Loss: 0.0028
Validation ROC AUC: 0.5442
Validation Loss: 0.0025, Validation Accuracy: 0.9401
Epoch [3/10], Loss: 0.0024
Validation ROC AUC: 0.5612
Validation Loss: 0.0023, Validation Accuracy: 0.9401
Epoch [4/10], Loss: 0.0023
Validation ROC AUC: 0.5703
Validation Loss: 0.0022, Validation Accuracy: 0.9401
Epoch [5/10], Loss: 0.0022
Validation ROC AUC: 0.5829
Validation Loss: 0.0021, Validation Accuracy: 0.9401
Epoch [6/10], Loss: 0.0021
Validation ROC AUC: 0.5889
Validation Loss: 0.0021, Validation Accuracy: 0.9401
Epoch [7/10], Loss: 0.0021
Validation ROC AUC: 0.5885
Validation Loss: 0.0021, Validation Accuracy: 0.9401
Epoch [8/10], Loss: 0.0021
Validation ROC AUC: 0.6020
Validation Loss: 0.0020, Validation Accuracy: 0.9401
Epoch [9/10], Loss: 0.0020
Validation ROC AUC: 0.6002
Validation Loss: 0.0020, Validation Accuracy: 0.9401
Epoch [10/10], Loss: 0.0020
Validatio