In [None]:
# %% IMPORTS
import os
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms, datasets
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter

In [None]:
# %% CONFIGURATION (YOLO-aligned)
# Backbone to fine-tune; must be a name that create_model() accepts.
SELECTED_MODEL = "resnet50"  # "densenet169"
# Split folders to iterate over under data/2-splits/.
CV_TYPES = ["sgkf05"]
IMAGE_SIZE = (224, 224)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
if DEVICE.type == "cuda":
    # The arrow was previously stored as mis-decoded bytes ("â†’"); restored to the intended character.
    print(f"→ GPU name: {torch.cuda.get_device_name(0)}")

# === TRAINING HYPERPARAMETERS ===
EPOCHS = 100            # maximum number of epochs; also T_max of the cosine LR schedule
PATIENCE = 20           # epochs without validation-loss improvement before early stopping
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_CLASSES = 2
WEIGHT_DECAY = 0.0005
DROPOUT_RATE = 0.4      # dropout placed before the new classification layer
LABEL_SMOOTHING = 0.1   # passed to one_hot_smooth() for the training targets
MIXUP_ALPHA = 0.4       # Beta(alpha, alpha) parameter used by mixup_data()

# === TRANSFORMS ===
# Random augmentation pipeline. Every sample is perturbed differently on each
# access, so this is meant for training data.
data_transforms = transforms.Compose([
    # Colour jitter runs before the grayscale conversion below.
    transforms.ColorJitter(hue=0.015, saturation=0.7, brightness=0.4),
    transforms.RandomRotation(degrees=2),
    transforms.RandomAffine(degrees=0, translate=(0.4, 0.4), shear=10),
    transforms.RandomHorizontalFlip(p=0.5),
    # Collapse to gray but keep 3 channels so ImageNet-pretrained backbones accept the input.
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SIZE),
    # p=0 means erasing is currently disabled; raise p to enable it.
    transforms.RandomErasing(p=0, scale=(0.02, 0.33)),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

In [None]:
# %% HELPERS
def one_hot_smooth(labels, num_classes, smoothing=0.1):
    """Turn integer class labels into label-smoothed target distributions.

    Args:
        labels: 1-D integer tensor of class indices, one per sample.
        num_classes: Number of classes (width of the returned matrix).
        smoothing: Probability mass taken away from the true class and
            spread evenly over the remaining ``num_classes - 1`` classes.

    Returns:
        Tensor of shape ``(len(labels), num_classes)`` on the same device as
        ``labels``. Each row holds ``1 - smoothing`` at the true class and
        ``smoothing / (num_classes - 1)`` elsewhere. No gradient is tracked.
    """
    off_value = smoothing / (num_classes - 1)
    on_value = 1.0 - smoothing
    with torch.no_grad():
        targets = torch.full(
            (labels.size(0), num_classes),
            off_value,
            dtype=torch.get_default_dtype(),
            device=labels.device,
        )
        # Overwrite the true-class column of every row with the "on" value.
        targets.scatter_(1, labels.detach().unsqueeze(1), on_value)
    return targets

def mixup_data(x, y, alpha=0.4):
    """Blend each sample in a batch with another randomly chosen sample.

    Args:
        x: Batch of inputs; the first dimension is the batch dimension.
        y: Labels aligned with ``x`` along the first dimension.
        alpha: Parameter of the ``Beta(alpha, alpha)`` distribution the mixing
            weight is drawn from. A value ``<= 0`` disables mixing (weight 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y``,
        ``y_b`` is ``y[perm]`` and ``lam`` is the mixing weight.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    # Random pairing of every sample with a partner from the same batch.
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[perm, :]
    blended = lam * x + (1 - lam) * partner_x
    return blended, y, y[perm], lam

In [None]:
# %% MODEL CREATION
def create_model(name):
    """Build a pretrained torchvision backbone with a fresh classification head.

    The final layer is replaced by a linear layer with ``NUM_CLASSES`` outputs.
    For the ResNet and DenseNet variants a ``Dropout(DROPOUT_RATE)`` is placed
    in front of it. For EfficientNet only ``classifier[1]`` is swapped, so
    ``DROPOUT_RATE`` is not applied on that branch.

    Args:
        name: One of ``"resnet18"``, ``"resnet50"``, ``"densenet121"``,
            ``"densenet169"``, or any ``torchvision.models`` constructor name
            starting with ``"efficientnet"``.

    Returns:
        The model, moved to ``DEVICE``.

    Raises:
        ValueError: If ``name`` is not one of the supported architectures.
    """
    # All branches form a single if/elif chain. Previously "resnet18" was a
    # separate `if`, so it fell through to the final `else` and raised
    # "Unsupported model" even though the model had been built.
    if name in ("resnet18", "resnet50"):
        model = getattr(models, name)(weights="DEFAULT")
        model.fc = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(model.fc.in_features, NUM_CLASSES),
        )
    elif name in ("densenet121", "densenet169"):
        model = getattr(models, name)(weights="DEFAULT")
        model.classifier = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(model.classifier.in_features, NUM_CLASSES),
        )
    elif name.startswith("efficientnet"):
        model = getattr(models, name)(weights="DEFAULT")
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
    else:
        raise ValueError(f"Unsupported model: {name}")
    return model.to(DEVICE)

In [None]:
# %% TRAIN AND EVALUATE
def _build_eval_transforms():
    """Return the deterministic preprocessing used for validation and test images.

    Mirrors the non-random steps of ``data_transforms`` (grayscale, tensor
    conversion, resize, normalisation) so evaluation metrics are repeatable.
    """
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Resize(IMAGE_SIZE),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])


def _save_confusion_matrix(cm, title, path):
    """Draw ``cm`` as an annotated heatmap titled ``title`` and save it to ``path``."""
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    fig.savefig(path)
    plt.close(fig)


def _predict(model, loader, criterion=None):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(mean_loss, y_true, y_pred)``. ``mean_loss`` is the per-sample mean of
        ``criterion`` (0.0 when ``criterion`` is None); the other two are plain
        lists of Python ints.
    """
    model.eval()
    loss_sum, n_samples = 0.0, 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            batch_size = labels.size(0)
            if criterion is not None:
                # criterion returns a batch mean; weight it by batch size so a
                # short final batch does not distort the overall mean.
                loss_sum += criterion(outputs, labels).item() * batch_size
            n_samples += batch_size
            y_pred.extend(torch.argmax(outputs, dim=1).tolist())
            y_true.extend(labels.tolist())
    return loss_sum / max(n_samples, 1), y_true, y_pred


def train_and_evaluate(data_root, project_name, run_name):
    """Train ``SELECTED_MODEL`` on one data split and evaluate it on the test set.

    Training uses mixup and label smoothing, AdamW, a cosine learning-rate
    schedule stepped once per epoch, and early stopping on validation loss.
    The weights with the lowest validation loss are restored before testing.

    Outputs written under ``runs/classify/<project_name>/``:
        * ``<run_name>/results.csv``: per-epoch train loss, val loss and
          validation top-1 accuracy (losses are per-sample means).
        * ``<run_name>/tensorboard/``: the same scalars plus train accuracy.
        * ``<run_name>/best_model.pt``: state dict of the best epoch.
        * ``<run_name>/val_confusion_matrix.png``: refreshed every epoch.
        * ``<run_name>/test_confusion_matrix.png``
        * ``test-<project_name>.csv``: one appended row of test metrics.

    Args:
        data_root: Directory containing ``train``, ``val`` and ``test``
            sub-directories in ``ImageFolder`` layout.
        project_name: Output folder name; the text before the first ``-`` is
            written to the ``sgkf`` column of the test CSV.
        run_name: Output sub-folder name; its 2nd and 3rd ``-``-separated
            fields are written to the ``repeat`` and ``fold`` columns.

    Raises:
        RuntimeError: If no checkpoint was saved during this run (validation
            loss never improved on infinity, e.g. because it was NaN).
    """
    project_dir = os.path.join("runs", "classify", project_name)
    output_dir = os.path.join(project_dir, run_name)
    os.makedirs(output_dir, exist_ok=True)  # also creates project_dir

    results_csv = os.path.join(output_dir, "results.csv")
    test_csv = os.path.join(project_dir, f"test-{project_name}.csv")
    best_path = os.path.join(output_dir, "best_model.pt")

    with open(results_csv, "w", newline="") as f:
        csv.writer(f).writerow(["epoch", "train/loss", "val/loss", "metrics/accuracy_top1"])
    # The test CSV is shared by all runs of a project, so write the header only once.
    if not os.path.exists(test_csv):
        with open(test_csv, "w", newline="") as f:
            csv.writer(f).writerow(["sgkf", "model", "repeat", "fold", "TP", "TN", "FP", "FN", "accuracy", "precision", "sensitivity", "specificity", "f1_score", "data_path_match"])

    # Random augmentation is applied to training data only. Validation and test
    # use deterministic preprocessing so early stopping and reported metrics
    # do not depend on augmentation noise.
    eval_transforms = _build_eval_transforms()
    train_ds = datasets.ImageFolder(os.path.join(data_root, "train"), transform=data_transforms)
    val_ds = datasets.ImageFolder(os.path.join(data_root, "val"), transform=eval_transforms)
    test_ds = datasets.ImageFolder(os.path.join(data_root, "test"), transform=eval_transforms)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

    model = create_model(SELECTED_MODEL)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_val_loss = float("inf")
    wait = 0
    checkpoint_saved = False
    class_labels = list(range(NUM_CLASSES))

    writer = SummaryWriter(log_dir=os.path.join(output_dir, "tensorboard"))
    try:
        for epoch in range(EPOCHS):
            # ---- training pass ----
            model.train()
            loss_sum, n_seen = 0.0, 0
            train_true, train_pred = [], []

            for images, labels in train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                images, y_a, y_b, lam = mixup_data(images, labels, alpha=MIXUP_ALPHA)
                optimizer.zero_grad()
                outputs = model(images)
                smoothed_a = one_hot_smooth(y_a, NUM_CLASSES, smoothing=LABEL_SMOOTHING)
                smoothed_b = one_hot_smooth(y_b, NUM_CLASSES, smoothing=LABEL_SMOOTHING)
                loss = lam * F.cross_entropy(outputs, smoothed_a) + (1 - lam) * F.cross_entropy(outputs, smoothed_b)
                loss.backward()
                optimizer.step()

                batch_size = labels.size(0)
                loss_sum += loss.item() * batch_size
                n_seen += batch_size
                # Inputs are blends of two samples; score against whichever
                # label carries the larger mixing weight.
                dominant = y_a if lam >= 0.5 else y_b
                train_pred.extend(torch.argmax(outputs, dim=1).tolist())
                train_true.extend(dominant.tolist())

            # T_max is expressed in epochs, so the schedule advances once per
            # epoch rather than once per batch.
            scheduler.step()

            train_loss = loss_sum / max(n_seen, 1)
            train_acc = accuracy_score(train_true, train_pred)

            # ---- validation pass ----
            val_loss, val_true, val_pred = _predict(model, val_loader, criterion)
            val_acc = accuracy_score(val_true, val_pred)

            writer.add_scalar("Loss/train", train_loss, epoch + 1)
            writer.add_scalar("Loss/val", val_loss, epoch + 1)
            writer.add_scalar("Accuracy/train", train_acc, epoch + 1)
            writer.add_scalar("Accuracy/val", val_acc, epoch + 1)

            with open(results_csv, "a", newline="") as f:
                csv.writer(f).writerow([epoch+1, round(train_loss, 6), round(val_loss, 6), round(val_acc, 6)])

            print(f"[{run_name}] Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

            # Plot before the early-stopping check so the stopping epoch is
            # included, and use the validation predictions the title promises.
            _save_confusion_matrix(
                confusion_matrix(val_true, val_pred, labels=class_labels),
                "Validation Confusion Matrix",
                os.path.join(output_dir, "val_confusion_matrix.png"),
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                wait = 0
                torch.save(model.state_dict(), best_path)
                checkpoint_saved = True
            else:
                wait += 1
                if wait >= PATIENCE:
                    print(f"[{run_name}] Early stopping triggered at epoch {epoch+1}")
                    break
    finally:
        # Flush and release the event file even if training raises.
        writer.close()

    # Refuse to fall back to a best_model.pt left over from an earlier run.
    if not checkpoint_saved:
        raise RuntimeError(f"[{run_name}] No checkpoint was saved during training; validation loss never improved")

    # ---- test evaluation with the best checkpoint ----
    model.load_state_dict(torch.load(best_path, map_location=DEVICE))
    _, y_true, y_pred = _predict(model, test_loader)

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    # Fix the label set so the matrix is always 2x2, even when one class is
    # missing from both y_true and y_pred; otherwise the unpacking below fails.
    cm_test = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm_test.ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    _save_confusion_matrix(
        cm_test,
        "Test Confusion Matrix",
        os.path.join(output_dir, "test_confusion_matrix.png"),
    )

    with open(test_csv, "a", newline="") as f:
        csv.writer(f).writerow([
            project_name.split("-")[0], SELECTED_MODEL,
            run_name.split("-")[1], run_name.split("-")[2],
            tp, tn, fp, fn,
            round(acc, 4), round(prec, 4), round(rec, 4), round(specificity, 4), round(f1, 4), data_root
        ])

    print(f"[{run_name}] Test Accuracy: {acc:.4f} | F1: {f1:.4f} | Precision: {prec:.4f} | Recall: {rec:.4f}")


In [None]:
# Train and evaluate one model for every <cv_type>/<repeat>/<fold> directory on disk.
for cv_type in CV_TYPES:
    split_base = f"data/2-splits/{cv_type}"
    project_name = f"{cv_type}-{SELECTED_MODEL}"
    for repeat in sorted(os.listdir(split_base)):
        repeat_path = os.path.join(split_base, repeat)
        # Skip stray files at the repeat level; listing them below would
        # raise NotADirectoryError.
        if not os.path.isdir(repeat_path):
            continue
        for fold in sorted(os.listdir(repeat_path)):
            fold_path = os.path.join(repeat_path, fold)
            if not os.path.isdir(fold_path):
                continue
            run_name = f"train-{repeat}-{fold}"
            print(f"\n=== Training {SELECTED_MODEL} on {cv_type}/{repeat}/{fold} ===")
            train_and_evaluate(fold_path, project_name, run_name)
