## Imports

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import models, transforms
import pandas as pd
from sklearn.metrics import roc_auc_score
import numpy as np
from dataset import ActionImageDataset  


# Define transformations for training set and validation set
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # Adjusted size for ResNet input
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Parameters
batch_size = 32
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_folds = 10  # Number of folds

# Root directory for dataset
root_dir = '../../data/RepCount/extracted/train'
actions_csv = './all_action.csv'

# DataFrame to store benchmarking results
results_df = pd.DataFrame(columns=['Model', 'Fold', 'Epoch', 'Val_Loss', 'Val_Accuracy', 'ROC_AUC'])
    
# Load the corresponding training and validation CSV files for the current fold
train_annotation_file = os.path.join(root_dir, 'annotation', 'pose_train.csv')

# Create dataset instances for the current fold
dataset = ActionImageDataset(root_directory=root_dir, 
                            action_file_path=actions_csv,
                            transform=train_transforms
                            )

# Define the split proportions
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Split the dataset into training and validation sets
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoaders for the current fold
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = models.resnet50(pretrained=True)

    
# Adjust final fully connected layer for binary classification
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 16)
model = model.to(device)

# Calculate class weights
class_counts = np.zeros(16)  # Assuming you have 16 classes

# Count each class in the dataset
for batch in DataLoader(dataset, batch_size=1):
    label = batch['label'].item()  # Extract label
    class_counts[label] += 1

# Compute weights as the inverse of the class frequency
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum()  # Normalize to sum to 1
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Define loss function with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.8)


best_val_acc = 0.0

# Training loop for the current fold
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    # Training loop
    for batch in train_loader:
        images = batch['rgb'].to(device)  # Get the RGB images
        labels = batch['label'].to(device).long()  # Convert labels to float for BCEWithLogitsLoss

        optimizer.zero_grad()  # Zero the gradients
        outputs = model(images)  # Forward pass
        loss = criterion(outputs, labels)  # Calculate loss
        loss.backward()  # Backward pass
        optimizer.step()  # Optimize weights

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Evaluation on validation set
    model.eval()  # Set the model to evaluation mode
    val_running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in val_loader:
            images = batch['rgb'].to(device)
            labels = batch['label'].to(device)

            outputs = model(images)
            val_loss = criterion(outputs, labels)
            val_running_loss += val_loss.item()

            # For accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Store true labels and predicted probabilities for ROC AUC calculation
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(torch.softmax(outputs, dim=1).cpu().numpy())  # Use softmax for probabilities

    avg_val_loss = val_running_loss / len(val_loader)
    val_accuracy = correct / total

    # For multi-class ROC AUC (one-vs-rest strategy)
    roc_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')

    print(f"Validation ROC AUC: {roc_auc:.4f}")

    print(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

    # Create a dictionary for current results
    epoch_result = {
        'Epoch': epoch + 1,
        'Val_Loss': avg_val_loss,
        'Val_Accuracy': val_accuracy,
        'ROC_AUC': roc_auc
    }

    # Use pd.concat() to add the row to results_df
    results_df = pd.concat([results_df, pd.DataFrame([epoch_result])], ignore_index=True)

    # Save the best model based on validation accuracy
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), f'./best_resnet50_v2.pth')
    
    
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
    

# Save the results dataframe to a CSV file
results_df.to_csv('resnet_benchmark_results.csv', index=False)




Epoch [1/10], Loss: 0.8503
Validation ROC AUC: 0.9925
Validation Loss: 0.3483, Validation Accuracy: 0.8403
Epoch [2/10], Loss: 0.3195
Validation ROC AUC: 0.9953
Validation Loss: 0.3597, Validation Accuracy: 0.8556
Epoch [3/10], Loss: 0.1677
Validation ROC AUC: 0.9984
Validation Loss: 0.1691, Validation Accuracy: 0.9405
Epoch [4/10], Loss: 0.0864
Validation ROC AUC: 0.9988
Validation Loss: 0.1795, Validation Accuracy: 0.9405
Epoch [5/10], Loss: 0.0828
Validation ROC AUC: 0.9976
Validation Loss: 0.2691, Validation Accuracy: 0.9193
Epoch [6/10], Loss: 0.1162
Validation ROC AUC: 0.9993
Validation Loss: 0.1549, Validation Accuracy: 0.9626
Epoch [7/10], Loss: 0.0631
Validation ROC AUC: 0.9994
Validation Loss: 0.1150, Validation Accuracy: 0.9686
Epoch [8/10], Loss: 0.0403
Validation ROC AUC: 0.9951
Validation Loss: 0.2936, Validation Accuracy: 0.9473
Epoch [9/10], Loss: 0.0447
Validation ROC AUC: 0.9978
Validation Loss: 0.2080, Validation Accuracy: 0.9584
Epoch [10/10], Loss: 0.0447
Validatio