In [43]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, auc, precision_recall_fscore_support
from sklearn.preprocessing import label_binarize
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from tqdm import tqdm

In [52]:
# Configuration
# Everything a reader might tune for one experiment run lives in this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when one is visible

# Seed torch's global RNG so runs are repeatable after Restart & Run All.
# NOTE: some GPU kernels can still be nondeterministic; this only fixes the RNG streams.
seed = 42
torch.manual_seed(seed)

dataset_name = "lora"  # selects ./data/dataset_<name> and prefixes the output file names
dataset_path = f"./data/dataset_{dataset_name}"
batch_size = 16
num_epochs = 30
num_classes = 3
class_names = ['benign', 'malignant', 'normal']

# All artefacts of this run (weights, history CSV) are written below this directory.
output_dir = f"results/dataset_{dataset_name}/ResNet18/"
os.makedirs(output_dir, exist_ok=True)

In [53]:
# Image transforms
# Both splits are resized to 224x224, converted to tensors and normalised with the
# same per-channel mean/std (the usual ImageNet statistics, matching the pretrained
# weights loaded below). Only the training split gets a random horizontal flip.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}


# Load data
# Expects <dataset_path>/train and <dataset_path>/val, each with one sub-folder per class.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(dataset_path, x), data_transforms[x])
    for x in ['train', 'val']
}

# ImageFolder derives the label indices from each split's own sub-folder names, so a
# missing or extra class folder in one split would silently shift the labels.
if image_datasets['train'].class_to_idx != image_datasets['val'].class_to_idx:
    raise ValueError(
        "train and val splits define different classes: "
        f"{image_datasets['train'].class_to_idx} vs {image_datasets['val'].class_to_idx}"
    )
# The classifier head and the per-class metrics are sized by num_classes.
if len(image_datasets['train'].classes) != num_classes:
    raise ValueError(
        f"num_classes is {num_classes} but the dataset has "
        f"{len(image_datasets['train'].classes)} classes: {image_datasets['train'].classes}"
    )

# Shuffle only the training split; evaluation does not depend on sample order, and a
# fixed order keeps validation batches repeatable.
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=(x == 'train'))
    for x in ['train', 'val']
}

# Model setup
# Start from ImageNet-pretrained ResNet-18 and replace the final fully connected
# layer with a fresh one that outputs num_classes logits; all layers are fine-tuned.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [54]:
# Per-epoch metrics. Every list receives exactly one value at the end of each epoch,
# so the dict converts directly into a DataFrame with one row per epoch.
history = {
    'epoch': [],
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_f1': [],
    'val_auc_roc': [],
    'val_precision': [],
    'val_recall': []
}

# Model selection is by lowest validation loss.
best_val_loss = float('inf')
best_model_wts = None
best_metrics = {}

# Training loop
# NOTE: `epoch` is 0-based in `history` and `best_metrics`, while the progress bar
# shows it 1-based (epoch + 1).
for epoch in range(num_epochs):
    model.train()
    train_loss, train_corrects = 0.0, 0

    for inputs, labels in tqdm(dataloaders['train'], desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        # criterion returns the batch mean; weight by batch size so the epoch value
        # is a per-sample average even when the last batch is smaller.
        train_loss += loss.item() * inputs.size(0)
        train_corrects += (preds == labels).sum().item()

    train_loss /= len(image_datasets['train'])
    train_acc = train_corrects / len(image_datasets['train'])
    print(f"train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}")

    # Validation
    model.eval()
    val_loss, val_corrects = 0.0, 0
    all_preds, all_labels, all_probs = [], [], []

    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Class probabilities are needed for ROC AUC; hard predictions for the rest.
            probs = torch.softmax(outputs, dim=1)
            preds = probs.argmax(dim=1)

            all_probs.append(probs.cpu().numpy())
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

            val_loss += loss.item() * inputs.size(0)
            val_corrects += (preds == labels).sum().item()

    val_loss /= len(image_datasets['val'])
    val_acc = val_corrects / len(image_datasets['val'])

    # Stitch the per-batch arrays into one array per quantity for the sklearn metrics.
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    # One-hot labels so ROC AUC is computed per class against the matching probability column.
    all_labels_bin = label_binarize(all_labels, classes=list(range(num_classes)))

    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro', zero_division=0)
    roc_auc = roc_auc_score(all_labels_bin, all_probs, multi_class='ovr')

    print(
        f"val_loss: {val_loss:.4f}, "
        f"val_acc: {val_acc:.4f}, "
        f"f1: {f1:.4f}, \n"
        f"roc_auc: {roc_auc:.4f}, "
        f"precision: {precision:.4f}, "
        f"recall: {recall:.4f}"
    )
    print()

    history['epoch'].append(epoch)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(f1)
    history['val_auc_roc'].append(roc_auc)
    history['val_precision'].append(precision)
    history['val_recall'].append(recall)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns references to the live tensors, which later epochs
        # update in place. Take a deep copy so this snapshot really is the best
        # epoch's weights rather than whatever the model holds after the last epoch.
        best_model_wts = copy.deepcopy(model.state_dict())
        best_metrics = {
            "epoch": epoch,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "f1": f1,
            "auc_roc": roc_auc,
            "precision": precision,
            "recall": recall
        }

# Without any improving epoch there is no snapshot to report or save; fail with a
# clear message instead of a KeyError on the empty best_metrics below.
if best_model_wts is None:
    raise RuntimeError(
        "No epoch produced a validation loss below infinity "
        "(num_epochs is 0 or val_loss was never a finite number); nothing to save."
    )

print()
print()
print("Best validation metrics:")
print(f"Epoch: {best_metrics['epoch']}")
print(f"Validation Loss: {best_metrics['val_loss']:.4f}")
print(f"Validation Accuracy: {best_metrics['val_acc']:.4f}")
print(f"F1 Score: {best_metrics['f1']:.4f}")
print(f"AUC ROC: {best_metrics['auc_roc']:.4f}")
print(f"Precision: {best_metrics['precision']:.4f}")
print(f"Recall: {best_metrics['recall']:.4f}")



# Save best model
# Restore the best snapshot into the model before writing it to disk, so both the
# saved file and the in-memory model hold the selected weights.
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), os.path.join(output_dir, f"{dataset_name}_best_resnet18.pth"))

# Save history
history_df = pd.DataFrame(history)
history_df.to_csv(os.path.join(output_dir, f"{dataset_name}_training_history.csv"), index=False)

Epoch 1/30: 100%|██████████| 66/66 [00:08<00:00,  7.41it/s]


train_loss: 0.5377, train_acc: 0.7600
val_loss: 0.5198, val_acc: 0.8089, f1: 0.8073, 
roc_auc: 0.9458, precision: 0.7961, recall: 0.8355



Epoch 2/30: 100%|██████████| 66/66 [00:08<00:00,  7.45it/s]


train_loss: 0.2503, train_acc: 0.9029
val_loss: 0.3867, val_acc: 0.8471, f1: 0.8152, 
roc_auc: 0.9623, precision: 0.8947, recall: 0.7735



Epoch 3/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.1460, train_acc: 0.9552
val_loss: 0.3233, val_acc: 0.9172, f1: 0.9041, 
roc_auc: 0.9671, precision: 0.9041, recall: 0.9041



Epoch 4/30: 100%|██████████| 66/66 [00:08<00:00,  7.54it/s]


train_loss: 0.1088, train_acc: 0.9695
val_loss: 0.2631, val_acc: 0.8981, f1: 0.8844, 
roc_auc: 0.9783, precision: 0.8844, recall: 0.8844



Epoch 5/30: 100%|██████████| 66/66 [00:08<00:00,  7.52it/s]


train_loss: 0.0870, train_acc: 0.9752
val_loss: 0.2832, val_acc: 0.8917, f1: 0.8764, 
roc_auc: 0.9782, precision: 0.8681, recall: 0.8892



Epoch 6/30: 100%|██████████| 66/66 [00:08<00:00,  7.56it/s]


train_loss: 0.0825, train_acc: 0.9714
val_loss: 0.3639, val_acc: 0.8662, f1: 0.8525, 
roc_auc: 0.9654, precision: 0.8866, recall: 0.8364



Epoch 7/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.0769, train_acc: 0.9771
val_loss: 0.3626, val_acc: 0.8726, f1: 0.8610, 
roc_auc: 0.9694, precision: 0.8423, recall: 0.8903



Epoch 8/30: 100%|██████████| 66/66 [00:08<00:00,  7.54it/s]


train_loss: 0.0919, train_acc: 0.9705
val_loss: 0.3899, val_acc: 0.8408, f1: 0.8161, 
roc_auc: 0.9751, precision: 0.8548, recall: 0.8156



Epoch 9/30: 100%|██████████| 66/66 [00:08<00:00,  7.54it/s]


train_loss: 0.0842, train_acc: 0.9686
val_loss: 0.3465, val_acc: 0.9045, f1: 0.8857, 
roc_auc: 0.9690, precision: 0.8824, recall: 0.8924



Epoch 10/30: 100%|██████████| 66/66 [00:08<00:00,  7.58it/s]


train_loss: 0.0493, train_acc: 0.9819
val_loss: 0.3368, val_acc: 0.8981, f1: 0.8865, 
roc_auc: 0.9713, precision: 0.8915, recall: 0.8847



Epoch 11/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.0309, train_acc: 0.9895
val_loss: 0.2757, val_acc: 0.8981, f1: 0.8853, 
roc_auc: 0.9814, precision: 0.8855, recall: 0.8976



Epoch 12/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.0207, train_acc: 0.9943
val_loss: 0.2969, val_acc: 0.9172, f1: 0.9039, 
roc_auc: 0.9804, precision: 0.8986, recall: 0.9215



Epoch 13/30: 100%|██████████| 66/66 [00:08<00:00,  7.50it/s]


train_loss: 0.0225, train_acc: 0.9933
val_loss: 0.3141, val_acc: 0.9236, f1: 0.9075, 
roc_auc: 0.9744, precision: 0.9117, recall: 0.9037



Epoch 14/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.0203, train_acc: 0.9924
val_loss: 0.2952, val_acc: 0.8917, f1: 0.8831, 
roc_auc: 0.9833, precision: 0.8636, recall: 0.9146



Epoch 15/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.0118, train_acc: 0.9971
val_loss: 0.2695, val_acc: 0.9045, f1: 0.8920, 
roc_auc: 0.9829, precision: 0.8842, recall: 0.9009



Epoch 16/30: 100%|██████████| 66/66 [00:08<00:00,  7.56it/s]


train_loss: 0.0273, train_acc: 0.9895
val_loss: 0.2617, val_acc: 0.8917, f1: 0.8757, 
roc_auc: 0.9822, precision: 0.8652, recall: 0.8892



Epoch 17/30: 100%|██████████| 66/66 [00:08<00:00,  7.57it/s]


train_loss: 0.0218, train_acc: 0.9952
val_loss: 0.3083, val_acc: 0.9045, f1: 0.8917, 
roc_auc: 0.9801, precision: 0.9138, recall: 0.8758



Epoch 18/30: 100%|██████████| 66/66 [00:08<00:00,  7.56it/s]


train_loss: 0.0224, train_acc: 0.9914
val_loss: 0.3804, val_acc: 0.9045, f1: 0.8925, 
roc_auc: 0.9755, precision: 0.8976, recall: 0.9058



Epoch 19/30: 100%|██████████| 66/66 [00:08<00:00,  7.53it/s]


train_loss: 0.0308, train_acc: 0.9895
val_loss: 0.3061, val_acc: 0.9236, f1: 0.9135, 
roc_auc: 0.9733, precision: 0.9101, recall: 0.9252



Epoch 20/30: 100%|██████████| 66/66 [00:08<00:00,  7.54it/s]


train_loss: 0.0531, train_acc: 0.9829
val_loss: 0.3376, val_acc: 0.8917, f1: 0.8861, 
roc_auc: 0.9762, precision: 0.8655, recall: 0.9190



Epoch 21/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.0212, train_acc: 0.9943
val_loss: 0.3161, val_acc: 0.8790, f1: 0.8641, 
roc_auc: 0.9768, precision: 0.8525, recall: 0.8816



Epoch 22/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.0559, train_acc: 0.9752
val_loss: 0.2747, val_acc: 0.9108, f1: 0.8937, 
roc_auc: 0.9795, precision: 0.9025, recall: 0.8878



Epoch 23/30: 100%|██████████| 66/66 [00:08<00:00,  7.53it/s]


train_loss: 0.0391, train_acc: 0.9876
val_loss: 0.2433, val_acc: 0.9108, f1: 0.8960, 
roc_auc: 0.9842, precision: 0.9106, recall: 0.8834



Epoch 24/30: 100%|██████████| 66/66 [00:08<00:00,  7.52it/s]


train_loss: 0.0385, train_acc: 0.9886
val_loss: 0.2637, val_acc: 0.9299, f1: 0.9149, 
roc_auc: 0.9833, precision: 0.9153, recall: 0.9161



Epoch 25/30: 100%|██████████| 66/66 [00:08<00:00,  7.54it/s]


train_loss: 0.0404, train_acc: 0.9848
val_loss: 0.2374, val_acc: 0.9172, f1: 0.9049, 
roc_auc: 0.9856, precision: 0.9015, recall: 0.9085



Epoch 26/30: 100%|██████████| 66/66 [00:08<00:00,  7.55it/s]


train_loss: 0.0269, train_acc: 0.9914
val_loss: 0.4175, val_acc: 0.8917, f1: 0.8769, 
roc_auc: 0.9751, precision: 0.8669, recall: 0.9019



Epoch 27/30: 100%|██████████| 66/66 [00:08<00:00,  7.53it/s]


train_loss: 0.0625, train_acc: 0.9743
val_loss: 0.4432, val_acc: 0.8471, f1: 0.8170, 
roc_auc: 0.9625, precision: 0.8612, recall: 0.7906



Epoch 28/30: 100%|██████████| 66/66 [00:08<00:00,  7.54it/s]


train_loss: 0.0455, train_acc: 0.9905
val_loss: 0.4392, val_acc: 0.8854, f1: 0.8661, 
roc_auc: 0.9570, precision: 0.8549, recall: 0.8810



Epoch 29/30: 100%|██████████| 66/66 [00:08<00:00,  7.51it/s]


train_loss: 0.0467, train_acc: 0.9848
val_loss: 0.3246, val_acc: 0.9045, f1: 0.8863, 
roc_auc: 0.9758, precision: 0.9111, recall: 0.8669



Epoch 30/30: 100%|██████████| 66/66 [00:08<00:00,  7.53it/s]


train_loss: 0.0279, train_acc: 0.9914
val_loss: 0.3657, val_acc: 0.8726, f1: 0.8503, 
roc_auc: 0.9693, precision: 0.8578, recall: 0.8436



Best validation metrics:
Epoch: 24
Validation Loss: 0.2374
Validation Accuracy: 0.9172
F1 Score: 0.9049
AUC ROC: 0.9856
Precision: 0.9015
Recall: 0.9085


In [56]:
# Combine Results
# For each experiment, read its saved training history, pick the epoch with the
# lowest validation loss, and gather those rows into one table (one row per
# experiment). Requires that every listed experiment has already been trained and
# its history CSV written under results/dataset_<exp>/ResNet18/.

combined_best_metrics = {}

for exp in ['real', 'lora', 'lora_ti', 'lora_ti_controlnet', 'lora_ti_controlnet_refined']:
    # Use a dedicated name: rebinding `history_df` here would replace the current
    # run's training history with another experiment's.
    exp_history_df = pd.read_csv(os.path.join('results', f"dataset_{exp}", "ResNet18", f"{exp}_training_history.csv"))
    best_rows = exp_history_df[exp_history_df['val_loss'] == exp_history_df['val_loss'].min()]
    if best_rows.empty:
        # Happens when the file has no rows or every val_loss is NaN.
        raise ValueError(f"No usable val_loss values in the training history of experiment '{exp}'")
    # Several epochs can tie on val_loss; keep the earliest one.
    best_metric = best_rows.iloc[0].to_dict()
    combined_best_metrics[exp] = best_metric

# Transpose so experiments become rows and metrics become columns.
combined_df = pd.DataFrame(combined_best_metrics).T
combined_df.to_csv(os.path.join('results', "combined_best_metrics.csv"), index=True)

display(combined_df)

Unnamed: 0,epoch,train_loss,train_acc,val_loss,val_acc,val_f1,val_auc_roc,val_precision,val_recall
real,3.0,0.101756,0.967897,0.250601,0.904459,0.88654,0.979247,0.890395,0.883798
lora,24.0,0.040437,0.984762,0.23743,0.917197,0.904877,0.985554,0.901478,0.90849
lora_ti,11.0,0.030079,0.993333,0.267857,0.923567,0.912332,0.980282,0.905823,0.920575
lora_ti_controlnet,7.0,0.036173,0.991429,0.256576,0.904459,0.890187,0.98111,0.877091,0.913881
lora_ti_controlnet_refined,29.0,0.013763,0.995238,0.284104,0.917197,0.9042,0.985152,0.88855,0.925345
