In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, random_split
from sklearn.utils.class_weight import compute_class_weight

# --- Hyperparameters and paths ---
ImageDir = r"C:\Users\hgood\OneDrive\Processed_REMBRANDT"
CSVAcess = r"C:\Users\hgood\Downloads\clinical_cleaned_v2.csv"
BatchSize = 32
EpochCount = 10          # epochs to run in this session, on top of any resumed epochs
LearningRate = 0.001
Classes = 3
BestModelPathCNN = 'best_model.pth'
SplitSeed = 42           # fixed seed so train/val/test membership is identical on every run

# --- Dataset and class weights ---
# REMDataset, TransformImagesREM and build_resnet are defined outside this view
# (presumably an earlier notebook cell).
datasetREM = REMDataset(ImageDir, CSVAcess, transform=TransformImagesREM)
# NOTE(review): this loads every sample (including the image transform) just to
# read its label; if REMDataset exposes its labels directly, use them instead.
REMlabels = [int(label) for _, label in datasetREM]
ClassIds = np.unique(REMlabels)
# CrossEntropyLoss takes one weight per class, indexed by label value, so the
# labels must be exactly 0..Classes-1 with every class present; otherwise the
# weight vector would be the wrong length or mis-aligned with the labels.
if not np.array_equal(ClassIds, np.arange(Classes)):
    raise ValueError(
        f"Expected labels 0..{Classes - 1} with every class present, "
        f"found {ClassIds.tolist()}"
    )
# 'balanced' gives rarer classes proportionally larger weights.
Weights = compute_class_weight(class_weight='balanced', classes=ClassIds, y=REMlabels)
TensorWeights = torch.tensor(Weights, dtype=torch.float)

# --- 80/10/10 train/validation/test split ---
TrainSplit = int(0.8 * len(datasetREM))
ValSplit = int(0.1 * len(datasetREM))
TestSplit = len(datasetREM) - TrainSplit - ValSplit  # remainder absorbs rounding
# Seeded generator: an unseeded split reshuffles on every run, so after resuming
# from a checkpoint the validation/test sets would contain images the model had
# already been trained on.
TrainData, ValData, TestData = random_split(
    datasetREM,
    [TrainSplit, ValSplit, TestSplit],
    generator=torch.Generator().manual_seed(SplitSeed),
)
# Only the training loader shuffles; evaluation order does not matter.
TrainingLoader = DataLoader(TrainData, batch_size=BatchSize, shuffle=True)
ValidationLoader = DataLoader(ValData, batch_size=BatchSize, shuffle=False)
TestingLoader = DataLoader(TestData, batch_size=BatchSize, shuffle=False)

# --- Model, class-weighted loss, Adam optimizer, and StepLR (lr x0.1 every 5 scheduler steps) ---
TorchDevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CNNModel = build_resnet(num_classes=Classes).to(TorchDevice)
EntropyREM = nn.CrossEntropyLoss(weight=TensorWeights.to(TorchDevice))
OptimizerREM = optim.Adam(CNNModel.parameters(), lr=LearningRate)
SchedulerREM = optim.lr_scheduler.StepLR(OptimizerREM, step_size=5, gamma=0.1)

# --- Resume from the best-model checkpoint if one exists, else start fresh ---
# NOTE(review): this resumes from BestModelPathCNN, not from the per-epoch
# 'model_checkpoint.pth' the training loop also writes, so any epochs after the
# best one are re-run. Confirm that is intended.
start_epoch = 0
best_val_accuracy = 0.0
if os.path.exists(BestModelPathCNN):
    # map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
    checkpoint = torch.load(BestModelPathCNN, map_location=TorchDevice)
    CNNModel.load_state_dict(checkpoint['model_state_dict'])
    OptimizerREM.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints carry no scheduler state; without it the StepLR counter
    # restarts at 0 and the decay boundaries drift relative to the epoch count.
    if 'scheduler_state_dict' in checkpoint:
        SchedulerREM.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_accuracy = checkpoint['val_accuracy']
    print(f"Resuming training at epoch {start_epoch} with Validation Accuracy: {best_val_accuracy:.2f}%")
else:
    print("Starting training from scratch")

# Run EpochCount more epochs, continuing the epoch numbering after a resume.
TotalEpochs = start_epoch + EpochCount

# Evaluation helper shared by the validation and test passes.
def evaluate_accuracy_metrics(loader, model=None, criterion=None, device=None):
    """Compute the average loss and accuracy (%) of a classifier over a loader.

    Runs without gradient tracking and leaves the model in eval mode.

    Args:
        loader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        model: classifier to evaluate; defaults to the global ``CNNModel``.
        criterion: loss function; defaults to the global ``EntropyREM``.
        device: device each batch is moved to; defaults to the global
            ``TorchDevice``.

    Returns:
        ``(average_loss, accuracy_percent)``. Each batch's loss is weighted by
        its number of samples, so a short final batch counts proportionally
        rather than as much as a full one.

    Raises:
        ValueError: if the loader yields no samples (e.g. an empty split).
    """
    # Globals are resolved at call time, so existing single-argument calls behave as before.
    model = CNNModel if model is None else model
    criterion = EntropyREM if criterion is None else criterion
    device = TorchDevice if device is None else device

    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError("evaluate_accuracy_metrics: loader yielded no samples")
    return loss_sum / seen, 100 * correct / seen

# --- Training loop ---
def _checkpoint_payload(epoch, val_accuracy, test_accuracy):
    """Return a checkpoint dict containing everything needed to resume training.

    Includes the LR-scheduler state, so the StepLR decay stays aligned with the
    epoch count after a resume. Includes the accuracies, so either saved file
    carries the keys the resume code reads.
    """
    return {
        'model_state_dict': CNNModel.state_dict(),
        'optimizer_state_dict': OptimizerREM.state_dict(),
        'scheduler_state_dict': SchedulerREM.state_dict(),
        'epoch': epoch,
        'val_accuracy': val_accuracy,
        'test_accuracy': test_accuracy,
    }


for epoch in range(start_epoch, TotalEpochs):
    CNNModel.train()
    TotalLoss = 0.0
    CorrectPredictions = 0
    TotalPredictions = 0
    # One optimisation step per batch: reset gradients, forward, loss, backward, update.
    # (Loop variable is not named REMlabels so the global label list is not overwritten.)
    for inputs, batch_labels in TrainingLoader:
        inputs, batch_labels = inputs.to(TorchDevice), batch_labels.to(TorchDevice)
        OptimizerREM.zero_grad()
        outputs = CNNModel(inputs)
        loss = EntropyREM(outputs, batch_labels)
        loss.backward()
        OptimizerREM.step()
        batch_size = batch_labels.size(0)
        # Weight by batch size so a short final batch doesn't skew the epoch mean.
        TotalLoss += loss.item() * batch_size
        CorrectPredictions += (outputs.argmax(dim=1) == batch_labels).sum().item()
        TotalPredictions += batch_size

    AvgTrainingLoss = TotalLoss / TotalPredictions
    TrainingAccuracy = 100 * CorrectPredictions / TotalPredictions
    ValidationLoss, ValidationAccuracy = evaluate_accuracy_metrics(ValidationLoader)
    # NOTE(review): test metrics are computed every epoch for monitoring only;
    # model selection below uses validation accuracy alone.
    TestingLoss, TestingAccuracy = evaluate_accuracy_metrics(TestingLoader)

    print(f"Epoch [{epoch + 1}/{TotalEpochs}], Train Loss: {AvgTrainingLoss:.4f}, Train Accuracy: {TrainingAccuracy:.2f}%, "
          f"Val Loss: {ValidationLoss:.4f}, Val Accuracy: {ValidationAccuracy:.2f}%, "
          f"Test Loss: {TestingLoss:.4f}, Test Accuracy: {TestingAccuracy:.2f}%")

    # Advance the LR schedule once per epoch (before saving, so checkpoints hold the next epoch's LR).
    SchedulerREM.step()

    # Overwrite the latest checkpoint every epoch.
    torch.save(_checkpoint_payload(epoch, ValidationAccuracy, TestingAccuracy), 'model_checkpoint.pth')

    # Keep a separate copy whenever validation accuracy improves.
    if ValidationAccuracy > best_val_accuracy:
        best_val_accuracy = ValidationAccuracy
        torch.save(_checkpoint_payload(epoch, ValidationAccuracy, TestingAccuracy), BestModelPathCNN)
        print(f"Best model saved at epoch {epoch + 1} with Val Accuracy: {ValidationAccuracy:.2f}% and Test Accuracy: {TestingAccuracy:.2f}%")