In [None]:
import copy
import os

import numpy as np
import pandas as pd
import torch
import torchvision.transforms.v2 as transforms
from sklearn.model_selection import StratifiedGroupKFold
from torch.utils.data import Subset, DataLoader

from src.classes.dataset import MRIDataset, MRISubset
from src.classes.models import ResNet50variant
from src.classes.training import Trainer
from src.config import PATH_TO_DATASET, PATH_TO_DATASET_CSV, ID_TO_NAME
from src.config import PATH_TO_MODELS

In [None]:
# Read the semicolon-separated dataset description (first row is the header).
df = pd.read_csv(PATH_TO_DATASET_CSV, sep=';', header=0)

# Index every row as row-index -> (image path, label). Each image lives in the
# sub-directory that ID_TO_NAME associates with its label.
data = dict(
    (
        idx,
        (
            os.path.join(PATH_TO_DATASET, ID_TO_NAME[row['label']], str(row['img_name'])),
            row['label'],
        ),
    )
    for idx, row in df.iterrows()
)

# Sample identifiers (the dictionary keys) as a numpy array for the splitter
X = np.array(list(data))

# Stratification targets and group membership, one entry per CSV row
y, groups = (df[column].to_numpy() for column in ('label', 'group'))

# 5-fold splitter that stratifies on the label while keeping groups intact;
# the fixed random_state makes the shuffled split repeatable.
sgkf = StratifiedGroupKFold(random_state=7, shuffle=True, n_splits=5)

# Transformations applied to the held-out (test) samples.
# NOTE(review): ToDtype sits between ToPILImage and ToTensor - confirm it has
# the intended effect on the PIL image produced by the previous step.
test_transforms = transforms.Compose(
    [transforms.ToPILImage(), transforms.ToDtype(torch.float32), transforms.ToTensor()]
)

# Prefer the first CUDA device when one is available, otherwise run on the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Build the network and populate it with the stored weights.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted location.
model = ResNet50variant()
path_to_model = os.path.join(PATH_TO_MODELS, "resnet50v.pth")
model.load_state_dict(torch.load(path_to_model, map_location=device))

# Best result seen so far across folds (nothing recorded yet)
best_model_state, best_test_acc = None, 0

# Run Stratified Group K-Fold cross-validation.
#
# Every fold trains the model on its training indices, evaluates it on the
# held-out indices, and the weights of the fold with the highest test accuracy
# are written to disk once all folds are done.

# Snapshot the weights loaded before the loop so that every fold starts from
# the same point. Without this reset the model would keep the weights learned
# in earlier folds, whose training sets contain the samples that later folds
# hold out (data leakage into the test score).
initial_model_state = copy.deepcopy(model.state_dict())

for fold, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
    print(f"\nTraining Fold {fold + 1}...")

    # Restore the starting weights for this fold
    model.load_state_dict(initial_model_state)

    # Both subsets must be carved out of the *full* dataset: train_index and
    # test_index are positions into X, not positions inside another subset.
    full_dataset = MRIDataset(data)
    train_dataset = MRISubset(Subset(full_dataset, train_index), train_bool=True)
    test_dataset = MRISubset(Subset(full_dataset, test_index), train_bool=False, transform=test_transforms)

    # Create dataloaders for training and testing (only training is shuffled)
    dataloaders = {
        "train": DataLoader(train_dataset, batch_size=32, shuffle=True),
        "test": DataLoader(test_dataset, batch_size=32)
    }

    # A fresh loss function and optimizer per fold, so no optimizer state
    # (e.g. Adam moment estimates) leaks from one fold into the next.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trainer = Trainer(model=model, criterion=criterion, optimizer=optimizer, device=device)

    # Train the model
    print(f"Training on Fold {fold + 1}...")
    trainer.train(dataloaders['train'], num_epochs=10)

    # Evaluate the model on the held-out fold using the Trainer class
    test_loss, test_acc = trainer.evaluate(dataloaders['test'])

    # Track the best model. state_dict() returns references to the live
    # parameter tensors, so it has to be deep-copied; otherwise later folds
    # would silently overwrite the stored "best" weights.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_model_state = copy.deepcopy(model.state_dict())

# Save the best model after training across all folds
if best_model_state is not None:
    final_model_path = os.path.join(PATH_TO_MODELS, "best_model_resnet50.pth")
    torch.save(best_model_state, final_model_path)
    print(f"Best model saved at: {final_model_path}")