# Baseline

## Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from sklearn.model_selection import KFold
from torch.utils.data import random_split
import matplotlib.pyplot as plt
from utils import *

## Helper Function

In [2]:
def custom_collate(batch):
    images, labels = zip(*batch)
    # Stack images into a tensor and convert labels to Float
    return torch.stack(images), torch.tensor(labels, dtype=torch.float32).unsqueeze(1)  # Reshape labels to (batch_size, 1)


### Set your dataset directory

In [3]:
data_dir = '../../data/BiteCount/salient_poses/' 

# Define transformations for training set and validation set
train_transforms = transforms.Compose([
    transforms.Resize((224, 244)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 244)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the dataset with training transformations
dataset = datasets.ImageFolder(data_dir, transform=train_transforms)
dataset_size = len(dataset)
class_names = dataset.classes

# Cross-validation setup
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Create data loaders for train and validation datasets
batch_size = 192
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize model, criterion, optimizer
num_epochs = 25
patience = 5

## Train the model

In [4]:
# Perform 5-Fold Cross Validation
fold_idx = 1
best_val_acc = 0.0
val_accuracies = []
val_f1_scores = []
precisions = []
recalls = []
roc_aucs = []

best_model_acc = 0

for train_idx, val_idx in kf.split(dataset):
    print(f"Fold {fold_idx}")
    train_dataset = torch.utils.data.Subset(dataset, train_idx)
    val_dataset = torch.utils.data.Subset(dataset, val_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

    # Initialize model and optimizer
    model = models.resnet50(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 1)
    model = model.to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
    
    # Training loop with early stopping mechanism
    best_val_acc, best_f1, best_precision, best_recall, best_roc_auc = train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, device, 'bce_loss')
    
    # Store val_acc and val_f1 for current fold
    val_accuracies.append(best_val_acc)
    val_f1_scores.append(best_f1)
    precisions.append(best_precision)
    recalls.append(best_recall)
    roc_aucs.append(best_roc_auc)

    # Check if current validation accuracy is the best so far
    if best_val_acc > best_model_acc:
        best_model_acc = best_val_acc
    
    fold_idx += 1

Fold 1




Epoch [1/25] - Train Loss: 0.7041, Train Acc: 0.4953, Val Loss: 0.7189, Val Acc: 0.5298, Val F1: 0.4279, Val Precision: 0.5429, Val Recall: 0.0929, Val ROC AUC: 0.5092
Validation accuracy improved to 0.5298. Saving model.
Epoch [2/25] - Train Loss: 0.6678, Train Acc: 0.6019, Val Loss: 0.6543, Val Acc: 0.6138, Val F1: 0.6110, Val Precision: 0.6102, Val Recall: 0.5281, Val ROC AUC: 0.6693
Validation accuracy improved to 0.6138. Saving model.
Epoch [3/25] - Train Loss: 0.5168, Train Acc: 0.7434, Val Loss: 0.4885, Val Acc: 0.7515, Val F1: 0.7481, Val Precision: 0.7988, Val Recall: 0.6406, Val ROC AUC: 0.8608
Validation accuracy improved to 0.7515. Saving model.
Epoch [4/25] - Train Loss: 0.2821, Train Acc: 0.8786, Val Loss: 0.3598, Val Acc: 0.8250, Val F1: 0.8241, Val Precision: 0.7627, Val Recall: 0.9193, Val ROC AUC: 0.9319
Validation accuracy improved to 0.8250. Saving model.
Epoch [5/25] - Train Loss: 0.1802, Train Acc: 0.9270, Val Loss: 0.2571, Val Acc: 0.8786, Val F1: 0.8787, Val Pre



Epoch [1/25] - Train Loss: 0.6989, Train Acc: 0.5032, Val Loss: 0.6865, Val Acc: 0.5018, Val F1: 0.4781, Val Precision: 0.5008, Val Recall: 0.7150, Val ROC AUC: 0.5379
Validation accuracy improved to 0.5018. Saving model.
Epoch [2/25] - Train Loss: 0.6729, Train Acc: 0.5695, Val Loss: 0.6875, Val Acc: 0.5671, Val F1: 0.4909, Val Precision: 0.7938, Val Recall: 0.1799, Val ROC AUC: 0.6875
Validation accuracy improved to 0.5671. Saving model.
Epoch [3/25] - Train Loss: 0.5487, Train Acc: 0.7212, Val Loss: 0.4410, Val Acc: 0.8086, Val F1: 0.8078, Val Precision: 0.8548, Val Recall: 0.7430, Val ROC AUC: 0.8902
Validation accuracy improved to 0.8086. Saving model.
Epoch [4/25] - Train Loss: 0.2968, Train Acc: 0.8675, Val Loss: 0.3074, Val Acc: 0.8646, Val F1: 0.8645, Val Precision: 0.8861, Val Recall: 0.8364, Val ROC AUC: 0.9475
Validation accuracy improved to 0.8646. Saving model.
Epoch [5/25] - Train Loss: 0.1677, Train Acc: 0.9297, Val Loss: 0.3637, Val Acc: 0.8425, Val F1: 0.8409, Val Pre



Epoch [1/25] - Train Loss: 0.6987, Train Acc: 0.5123, Val Loss: 0.6945, Val Acc: 0.5379, Val F1: 0.5065, Val Precision: 0.5849, Val Recall: 0.2870, Val ROC AUC: 0.5739
Validation accuracy improved to 0.5379. Saving model.
Epoch [2/25] - Train Loss: 0.6592, Train Acc: 0.6074, Val Loss: 0.6184, Val Acc: 0.6359, Val F1: 0.6123, Val Precision: 0.5938, Val Recall: 0.8796, Val ROC AUC: 0.7257
Validation accuracy improved to 0.6359. Saving model.
Epoch [3/25] - Train Loss: 0.5066, Train Acc: 0.7449, Val Loss: 0.3347, Val Acc: 0.8355, Val F1: 0.8354, Val Precision: 0.8255, Val Recall: 0.8542, Val ROC AUC: 0.9323
Validation accuracy improved to 0.8355. Saving model.
Epoch [4/25] - Train Loss: 0.2606, Train Acc: 0.8932, Val Loss: 0.3305, Val Acc: 0.8436, Val F1: 0.8432, Val Precision: 0.8880, Val Recall: 0.7894, Val ROC AUC: 0.9438
Validation accuracy improved to 0.8436. Saving model.
Epoch [5/25] - Train Loss: 0.1595, Train Acc: 0.9370, Val Loss: 0.2100, Val Acc: 0.9172, Val F1: 0.9172, Val Pre



Epoch [1/25] - Train Loss: 0.7036, Train Acc: 0.4879, Val Loss: 0.6804, Val Acc: 0.5175, Val F1: 0.4633, Val Precision: 0.5050, Val Recall: 0.8429, Val ROC AUC: 0.5593
Validation accuracy improved to 0.5175. Saving model.
Epoch [2/25] - Train Loss: 0.6636, Train Acc: 0.5854, Val Loss: 0.6477, Val Acc: 0.6215, Val F1: 0.5908, Val Precision: 0.7474, Val Recall: 0.3452, Val ROC AUC: 0.7281
Validation accuracy improved to 0.6215. Saving model.
Epoch [3/25] - Train Loss: 0.4847, Train Acc: 0.7631, Val Loss: 0.3986, Val Acc: 0.7921, Val F1: 0.7915, Val Precision: 0.7564, Val Recall: 0.8500, Val ROC AUC: 0.8957
Validation accuracy improved to 0.7921. Saving model.
Epoch [4/25] - Train Loss: 0.2736, Train Acc: 0.8821, Val Loss: 0.2603, Val Acc: 0.8575, Val F1: 0.8572, Val Precision: 0.8197, Val Recall: 0.9095, Val ROC AUC: 0.9575
Validation accuracy improved to 0.8575. Saving model.
Epoch [5/25] - Train Loss: 0.1658, Train Acc: 0.9326, Val Loss: 0.1894, Val Acc: 0.9077, Val F1: 0.9077, Val Pre



Epoch [1/25] - Train Loss: 0.7113, Train Acc: 0.4955, Val Loss: 0.7033, Val Acc: 0.5035, Val F1: 0.4798, Val Precision: 0.5184, Val Recall: 0.2926, Val ROC AUC: 0.5253
Validation accuracy improved to 0.5035. Saving model.
Epoch [2/25] - Train Loss: 0.6880, Train Acc: 0.5372, Val Loss: 0.6520, Val Acc: 0.5771, Val F1: 0.5203, Val Precision: 0.5499, Val Recall: 0.9147, Val ROC AUC: 0.6741
Validation accuracy improved to 0.5771. Saving model.
Epoch [3/25] - Train Loss: 0.5485, Train Acc: 0.7234, Val Loss: 0.4492, Val Acc: 0.7500, Val F1: 0.7470, Val Precision: 0.7107, Val Recall: 0.8548, Val ROC AUC: 0.8597
Validation accuracy improved to 0.7500. Saving model.
Epoch [4/25] - Train Loss: 0.3130, Train Acc: 0.8588, Val Loss: 0.3401, Val Acc: 0.8376, Val F1: 0.8366, Val Precision: 0.7980, Val Recall: 0.9101, Val ROC AUC: 0.9305
Validation accuracy improved to 0.8376. Saving model.
Epoch [5/25] - Train Loss: 0.1946, Train Acc: 0.9110, Val Loss: 0.3082, Val Acc: 0.8645, Val F1: 0.8628, Val Pre

In [5]:
print(f"Best Validation Accuracy: {best_val_acc}")
print(f"Validation Accuracies for all folds: {val_accuracies}")
print(f"Validation F1 Scores for all folds: {val_f1_scores}")
print(f"Precision for all folds: {precisions}")
print(f"Recall for all folds: {recalls}")
print(f"ROC AUC for all folds: {roc_aucs}")

# Calculate mean and standard deviation of validation accuracies
mean_val_acc = np.mean(val_accuracies)
std_val_acc = np.std(val_accuracies)

print(f"Mean Validation Accuracy: {mean_val_acc:.4f} +/- {std_val_acc:.4f}")

Best Validation Accuracy: 0.9602803738317757
Validation Accuracies for all folds: [0.9521586931155193, 0.9591598599766628, 0.9498249708284714, 0.9509345794392523, 0.9602803738317757]
Validation F1 Scores for all folds: [0.9521304644933555, 0.959152518478943, 0.949739592723817, 0.9508518637877322, 0.9602814579833118]
Precision for all folds: [0.96, 0.9712230215827338, 0.9949109414758269, 0.9821428571428571, 0.9716981132075472]
Recall for all folds: [0.9388753056234719, 0.9462616822429907, 0.9050925925925926, 0.9166666666666666, 0.9493087557603687]
ROC AUC for all folds: [0.9841403248340901, 0.9901803803672963, 0.9864270152505447, 0.9902086063783311, 0.9908161705287527]
Mean Validation Accuracy: 0.9545 +/- 0.0044


`Best Validation Accuracy: 0.9602803738317757
Validation Accuracies for all folds: [0.9521586931155193, 0.9591598599766628, 0.9498249708284714, 0.9509345794392523, 0.9602803738317757]
Validation F1 Scores for all folds: [0.9521304644933555, 0.959152518478943, 0.949739592723817, 0.9508518637877322, 0.9602814579833118]
Precision for all folds: [0.96, 0.9712230215827338, 0.9949109414758269, 0.9821428571428571, 0.9716981132075472]
Recall for all folds: [0.9388753056234719, 0.9462616822429907, 0.9050925925925926, 0.9166666666666666, 0.9493087557603687]
ROC AUC for all folds: [0.9841403248340901, 0.9901803803672963, 0.9864270152505447, 0.9902086063783311, 0.9908161705287527]
Mean Validation Accuracy: 0.9545 +/- 0.0044`