# Baseline

## Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from sklearn.model_selection import KFold
from torch.utils.data import random_split
import matplotlib.pyplot as plt
from utils import *

## Helper Function

In [2]:
def custom_collate(batch):
    images, labels = zip(*batch)
    # Stack images into a tensor and convert labels to Float
    return torch.stack(images), torch.tensor(labels, dtype=torch.float32).unsqueeze(1)  # Reshape labels to (batch_size, 1)


### Set your dataset directory

In [3]:
data_dir = '../../data/BiteCount/salient_poses/' 

# Define transformations for training set and validation set
train_transforms = transforms.Compose([
    transforms.Resize((224, 244)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 244)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the dataset with training transformations
dataset = datasets.ImageFolder(data_dir, transform=train_transforms)
dataset_size = len(dataset)
class_names = dataset.classes

# Cross-validation setup
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Create data loaders for train and validation datasets
batch_size = 192
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize model, criterion, optimizer
num_epochs = 25
patience = 5

## Train the model

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from torch.utils.data import DataLoader
import numpy as np

fold_idx = 1
best_val_acc = 0.0
val_accuracies = []
val_f1_scores = []
precisions = []
recalls = []
roc_aucs = []

best_model_acc = 0

# Early stopping parameters
early_stopping_patience = 5

for train_idx, val_idx in kf.split(dataset):
    print(f"Fold {fold_idx}")
    train_dataset = torch.utils.data.Subset(dataset, train_idx)
    val_dataset = torch.utils.data.Subset(dataset, val_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

    # Initialize model, criterion, and optimizer
    model = models.resnet50(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 1)
    model = model.to(device)
    
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
    
    best_val_acc_for_fold = 0.0
    patience_counter = 0

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device).float()

            optimizer.zero_grad()

            outputs = model(inputs)
            outputs = torch.sigmoid(outputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

        # Validation phase
        model.eval()
        val_labels = []
        val_preds = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device).float()

                outputs = model(inputs)
                outputs = torch.sigmoid(outputs).cpu().numpy()
                preds = np.round(outputs)

                val_labels.extend(labels.cpu().numpy())
                val_preds.extend(preds)

        # Calculate validation metrics
        val_acc = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds)
        val_precision = precision_score(val_labels, val_preds)
        val_recall = recall_score(val_labels, val_preds)
        val_roc_auc = roc_auc_score(val_labels, val_preds)

        print(f"Validation Accuracy: {val_acc:.4f}, F1 Score: {val_f1:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, ROC AUC: {val_roc_auc:.4f}")

        # Check for early stopping
        if val_acc > best_val_acc_for_fold:
            best_val_acc_for_fold = val_acc
            patience_counter = 0  # Reset counter if improvement
        else:
            patience_counter += 1

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    # Store metrics for this fold
    val_accuracies.append(best_val_acc_for_fold)
    val_f1_scores.append(val_f1)
    precisions.append(val_precision)
    recalls.append(val_recall)
    roc_aucs.append(val_roc_auc)

    # Check if current validation accuracy is the best so far
    if best_val_acc_for_fold > best_model_acc:
        best_model_acc = best_val_acc_for_fold

    fold_idx += 1

# After all folds, display average results
print(f"Average Validation Accuracy: {np.mean(val_accuracies):.4f}")
print(f"Average F1 Score: {np.mean(val_f1_scores):.4f}")
print(f"Average Precision: {np.mean(precisions):.4f}")
print(f"Average Recall: {np.mean(recalls):.4f}")
print(f"Average ROC AUC: {np.mean(roc_aucs):.4f}")


Fold 1




Epoch 1/25, Loss: 0.6999474002255334
Validation Accuracy: 0.4889, F1 Score: 0.6224, Precision: 0.4807, Recall: 0.8826, ROC AUC: 0.5061
Epoch 2/25, Loss: 0.6702480382389493
Validation Accuracy: 0.5461, F1 Score: 0.6655, Precision: 0.5133, Recall: 0.9462, ROC AUC: 0.5635
Epoch 3/25, Loss: 0.5227804597881105
Validation Accuracy: 0.7655, F1 Score: 0.7235, Precision: 0.8270, Recall: 0.6430, ROC AUC: 0.7601
Epoch 4/25, Loss: 0.29614125357733834
Validation Accuracy: 0.8343, F1 Score: 0.8179, Precision: 0.8598, Recall: 0.7800, ROC AUC: 0.8319
Epoch 5/25, Loss: 0.18481935643487507
Validation Accuracy: 0.8588, F1 Score: 0.8558, Precision: 0.8349, Recall: 0.8778, ROC AUC: 0.8596
Epoch 6/25, Loss: 0.10918362562855084
Validation Accuracy: 0.9160, F1 Score: 0.9093, Precision: 0.9377, Recall: 0.8826, ROC AUC: 0.9145
Epoch 7/25, Loss: 0.09975275293820435
Validation Accuracy: 0.8868, F1 Score: 0.8816, Precision: 0.8805, Recall: 0.8826, ROC AUC: 0.8866
Epoch 8/25, Loss: 0.10648325830698013
Validation Ac



Epoch 1/25, Loss: 0.7031464676062266
Validation Accuracy: 0.4877, F1 Score: 0.5285, Precision: 0.4891, Recall: 0.5748, ROC AUC: 0.4878
Epoch 2/25, Loss: 0.6773571405145857
Validation Accuracy: 0.6114, F1 Score: 0.6312, Precision: 0.6000, Recall: 0.6659, ROC AUC: 0.6115
Epoch 3/25, Loss: 0.5616550064749188
Validation Accuracy: 0.7760, F1 Score: 0.7975, Precision: 0.7269, Recall: 0.8832, ROC AUC: 0.7761
Epoch 4/25, Loss: 0.33453115324179333
Validation Accuracy: 0.8425, F1 Score: 0.8319, Precision: 0.8907, Recall: 0.7804, ROC AUC: 0.8424
Epoch 5/25, Loss: 0.18860393597020042
Validation Accuracy: 0.8670, F1 Score: 0.8649, Precision: 0.8774, Recall: 0.8528, ROC AUC: 0.8670
Epoch 6/25, Loss: 0.11931233315004243
Validation Accuracy: 0.9113, F1 Score: 0.9114, Precision: 0.9093, Recall: 0.9136, ROC AUC: 0.9113
Epoch 7/25, Loss: 0.08490665132800738
Validation Accuracy: 0.9218, F1 Score: 0.9216, Precision: 0.9227, Recall: 0.9206, ROC AUC: 0.9218
Epoch 8/25, Loss: 0.06557335476908419
Validation Ac



Epoch 1/25, Loss: 0.7064103384812673
Validation Accuracy: 0.5193, F1 Score: 0.3814, Precision: 0.5427, Recall: 0.2940, ROC AUC: 0.5211
Epoch 2/25, Loss: 0.6780653728379143
Validation Accuracy: 0.6149, F1 Score: 0.5875, Precision: 0.6386, Recall: 0.5440, ROC AUC: 0.6155
Epoch 3/25, Loss: 0.5676218817631403
Validation Accuracy: 0.7643, F1 Score: 0.7960, Precision: 0.7061, Recall: 0.9120, ROC AUC: 0.7631
Epoch 4/25, Loss: 0.3283341783616278
Validation Accuracy: 0.8506, F1 Score: 0.8443, Precision: 0.8897, Recall: 0.8032, ROC AUC: 0.8510
Epoch 5/25, Loss: 0.1811185884806845
Validation Accuracy: 0.8600, F1 Score: 0.8429, Precision: 0.9699, Recall: 0.7454, ROC AUC: 0.8609
Epoch 6/25, Loss: 0.14402834326028824
Validation Accuracy: 0.8996, F1 Score: 0.8988, Precision: 0.9139, Recall: 0.8843, ROC AUC: 0.8998
Epoch 7/25, Loss: 0.08150501466459698
Validation Accuracy: 0.9207, F1 Score: 0.9220, Precision: 0.9136, Recall: 0.9306, ROC AUC: 0.9206
Epoch 8/25, Loss: 0.07004134087926811
Validation Accu



Epoch 1/25, Loss: 0.7117649449242486
Validation Accuracy: 0.5269, F1 Score: 0.5361, Precision: 0.5166, Recall: 0.5571, ROC AUC: 0.5274
Epoch 2/25, Loss: 0.6714539494779375
Validation Accuracy: 0.6192, F1 Score: 0.6810, Precision: 0.5781, Recall: 0.8286, ROC AUC: 0.6230
Epoch 3/25, Loss: 0.5381505456235673
Validation Accuracy: 0.7325, F1 Score: 0.7675, Precision: 0.6690, Recall: 0.9000, ROC AUC: 0.7356
Epoch 4/25, Loss: 0.30381130261553657
Validation Accuracy: 0.8493, F1 Score: 0.8455, Precision: 0.8506, Recall: 0.8405, ROC AUC: 0.8491
Epoch 5/25, Loss: 0.19267304283049372
Validation Accuracy: 0.8937, F1 Score: 0.8878, Precision: 0.9207, Recall: 0.8571, ROC AUC: 0.8930
Epoch 6/25, Loss: 0.12152295021547212
Validation Accuracy: 0.8937, F1 Score: 0.9001, Precision: 0.8350, Recall: 0.9762, ROC AUC: 0.8952
Epoch 7/25, Loss: 0.08848851153420077
Validation Accuracy: 0.8867, F1 Score: 0.8921, Precision: 0.8372, Recall: 0.9548, ROC AUC: 0.8879
Epoch 8/25, Loss: 0.058576513909631304
Validation A



Epoch 1/25, Loss: 0.7047247555520799
Validation Accuracy: 0.5222, F1 Score: 0.5283, Precision: 0.5289, Recall: 0.5276, ROC AUC: 0.5221
Epoch 2/25, Loss: 0.6722855534818437
Validation Accuracy: 0.6227, F1 Score: 0.6066, Precision: 0.6434, Recall: 0.5737, ROC AUC: 0.6234
Epoch 3/25, Loss: 0.5187958909405602
Validation Accuracy: 0.7360, F1 Score: 0.6762, Precision: 0.8939, Recall: 0.5438, ROC AUC: 0.7387
Epoch 4/25, Loss: 0.2927172796593772
Validation Accuracy: 0.8668, F1 Score: 0.8787, Precision: 0.8162, Recall: 0.9516, ROC AUC: 0.8656
Epoch 5/25, Loss: 0.15880020086963972
Validation Accuracy: 0.9171, F1 Score: 0.9196, Precision: 0.9042, Recall: 0.9355, ROC AUC: 0.9168
Epoch 6/25, Loss: 0.11065834305352634
Validation Accuracy: 0.8925, F1 Score: 0.8982, Precision: 0.8638, Recall: 0.9355, ROC AUC: 0.8919
Epoch 7/25, Loss: 0.0666686145381795
Validation Accuracy: 0.9276, F1 Score: 0.9272, Precision: 0.9450, Recall: 0.9101, ROC AUC: 0.9278
Epoch 8/25, Loss: 0.05946202100151115
Validation Accu

In [6]:
val_accuracies

[0.9544924154025671,
 0.9439906651108518,
 0.9521586931155193,
 0.955607476635514,
 0.9450934579439252]

`Best Validation Accuracy: 0.9602803738317757
Validation Accuracies for all folds: [0.9521586931155193, 0.9591598599766628, 0.9498249708284714, 0.9509345794392523, 0.9602803738317757]
Validation F1 Scores for all folds: [0.9521304644933555, 0.959152518478943, 0.949739592723817, 0.9508518637877322, 0.9602814579833118]
Precision for all folds: [0.96, 0.9712230215827338, 0.9949109414758269, 0.9821428571428571, 0.9716981132075472]
Recall for all folds: [0.9388753056234719, 0.9462616822429907, 0.9050925925925926, 0.9166666666666666, 0.9493087557603687]
ROC AUC for all folds: [0.9841403248340901, 0.9901803803672963, 0.9864270152505447, 0.9902086063783311, 0.9908161705287527]
Mean Validation Accuracy: 0.9545 +/- 0.0044`