In [2]:
import random
from collections import Counter

import numpy as np
import torch

from utils.dataloader import get_dataloaders
from utils.train import train_model, test_model
from utils.metrics import (
    compute_full_metrics,
    plot_confusion_matrix,
    plot_roc_auc,
    save_epoch_history_excel,
    save_training_summary_excel,
)
from utils.gradcam import save_gradcam_samples
from utils.load_model import load_trained_model

from models.single_models import get_single_model
from models.ensemble import SoftVotingEnsemble, WeightedEnsemble

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
DATA_DIR = "lung_ct_split"  # change if needed
BATCH_SIZE = 8
EPOCHS = 20
SEED = 42  # fixed so single-model vs. ensemble comparisons are repeatable
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the pipeline may draw from (e.g. weight init of new heads,
# loader shuffling, augmentation). This must run before the dataloaders are
# built. torch.manual_seed also seeds all CUDA devices; cuDNN may still pick
# non-deterministic kernels, so GPU runs can differ slightly.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Device:", DEVICE)

Device: cuda


In [4]:
# Build the train / validation / test loaders from DATA_DIR; the helper also
# returns the class names, whose count sets the classifier output size.
(
    train_loader,
    val_loader,
    test_loader,
    class_names,
) = get_dataloaders(batch_size=BATCH_SIZE, data_dir=DATA_DIR)

num_classes = len(class_names)
print("Classes: ", class_names)

Classes:  ['Bengin cases', 'Malignant cases', 'Normal cases']


In [5]:
def _builder_for(arch):
    """Return a zero-argument factory building a fresh `arch` model.

    `num_classes` is looked up when the factory is called, not when it is
    created.
    """
    return lambda: get_single_model(arch, num_classes)


# Architecture name -> factory; each call produces a brand-new network.
single_models = {
    arch: _builder_for(arch)
    for arch in ("resnet50", "efficientnet_b0", "vgg16")
}

# Per-model outputs filled in by the training loop.
results, histories, trained_models = {}, {}, {}

In [None]:
for name, builder in single_models.items():
    print(f"\n===== TRAINING {name.upper()} =====")

    # --- Train ----------------------------------------------------------------
    # A fresh instance of the architecture is trained. train_model also returns
    # the per-epoch history and a run summary, which the save_*_excel helpers
    # write out further down.
    model, history, summary = train_model(
        builder().to(DEVICE), train_loader, val_loader, DEVICE,
        epochs=EPOCHS, model_name=name,
    )

    # --- Reload best checkpoint ------------------------------------------------
    # Evaluate the "<name>_best.pth" weights on a newly built instance instead
    # of the weights train_model returned. NOTE(review): presumably that
    # checkpoint is written by train_model - confirm.
    model = load_trained_model(
        builder().to(DEVICE), f"results/checkpoints/{name}_best.pth", DEVICE,
    )

    # --- Test -----------------------------------------------------------------
    test_acc, report, cm, labels, preds, probs, images = test_model(
        model, test_loader, DEVICE, class_names, return_details=True,
    )
    accuracy, precision, recall, f1 = compute_full_metrics(labels, preds)

    print(f"\n{name} Test Accuracy: {accuracy:.4f}")
    print(report)

    # Keep test accuracy, training history and the reloaded model for the
    # ensemble and summary cells below.
    results[name], histories[name], trained_models[name] = accuracy, history, model

    # --- Plots ----------------------------------------------------------------
    plot_confusion_matrix(cm, class_names)
    plot_roc_auc(labels, probs, class_names, save_path=f"results/roc_auc/{name}.png")

    # --- Logs & Grad-CAM --------------------------------------------------------
    save_epoch_history_excel(history, model_name=name)
    save_training_summary_excel(summary, model_name=name)
    save_gradcam_samples(
        model, images, class_names,
        save_dir=f"results/gradcam/{name}", device=DEVICE,
    )

In [None]:

print("\n===== SOFT VOTING ENSEMBLE =====")

# Fail fast with a readable message if the training cell has not been run;
# otherwise the ensemble construction / max() below fail with opaque errors.
assert trained_models, "No trained models found - run the training cell first."

soft_ensemble = SoftVotingEnsemble(list(trained_models.values())).to(DEVICE)

acc, report, cm, labels, preds, probs, images = test_model(
    soft_ensemble,
    test_loader,
    DEVICE,
    class_names,
    return_details=True,
)

accuracy, precision, recall, f1 = compute_full_metrics(labels, preds)

print(f"Soft Ensemble Accuracy: {accuracy:.4f}")
print(report)

plot_confusion_matrix(cm, class_names)
plot_roc_auc(labels, probs, class_names, save_path="results/roc_auc/soft_ensemble.png")

# Grad-CAM is run on one member network rather than the ensemble wrapper.
# Pick the member with the highest single-model accuracy recorded in
# `results`, not simply the first one trained.
best_single_name = max(trained_models, key=lambda m: results[m])
save_gradcam_samples(
    trained_models[best_single_name],
    images,
    class_names,
    save_dir="results/gradcam/soft_ensemble",
    device=DEVICE,
)

results["soft_ensemble"] = accuracy


In [None]:
print("\n===== WEIGHTED ENSEMBLE =====")

# Fail fast with a readable message if the training cell has not been run.
assert trained_models, "No trained models found - run the training cell first."

# Weight each member by its accuracy on the *validation* split. The values in
# `results` are test-set accuracies; weighting by them and then scoring the
# ensemble on that same test set would leak test information into the model.
val_accuracies = {}
for member_name, member in trained_models.items():
    _, _, _, val_labels, val_preds, _, _ = test_model(
        member,
        val_loader,
        DEVICE,
        class_names,
        return_details=True,
    )
    val_accuracies[member_name] = compute_full_metrics(val_labels, val_preds)[0]

# Same order as the model list handed to WeightedEnsemble below.
weights = [val_accuracies[m] for m in trained_models]
print(
    "Ensemble weights (validation accuracy):",
    ", ".join(f"{m}={a:.4f}" for m, a in val_accuracies.items()),
)

weighted_ensemble = WeightedEnsemble(
    list(trained_models.values()),
    weights,
).to(DEVICE)

acc, report, cm, labels, preds, probs, images = test_model(
    weighted_ensemble,
    test_loader,
    DEVICE,
    class_names,
    return_details=True,
)

accuracy, precision, recall, f1 = compute_full_metrics(labels, preds)

print(f"Weighted Ensemble Accuracy: {accuracy:.4f}")
print(report)

plot_confusion_matrix(cm, class_names)
plot_roc_auc(labels, probs, class_names, save_path="results/roc_auc/weighted_ensemble.png")

# Grad-CAM is run on one member network rather than the ensemble wrapper;
# use the member with the highest single-model accuracy in `results`.
best_single_name = max(trained_models, key=lambda m: results[m])
save_gradcam_samples(
    trained_models[best_single_name],
    images,
    class_names,
    save_dir="results/gradcam/weighted_ensemble",
    device=DEVICE,
)

results["weighted_ensemble"] = accuracy

In [None]:
print("\n===== FINAL RESULTS =====")

# One row per single model / ensemble, in the order they were evaluated.
for model_name, score in results.items():
    print(f"{model_name:20s} : {score:.4f}")

print("\nPipeline Complete ✔")
