In [None]:
# Core
import os, re, random, numpy as np, pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

# Torch / Vision
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models, datasets

# Metrics
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Progress bar
from tqdm.auto import tqdm

# ---- PATHS: update base_path if needed ----
base_path = "/../FIVES A Fundus Image Dataset for AI-based Vessel Segmentation"
train_gt = os.path.join(base_path, "train/Ground truth")
test_gt  = os.path.join(base_path, "test/Ground truth")
excel_path = os.path.join(base_path, "Quality Assessment.xlsx")

print("Base exists:", os.path.exists(base_path))
print("Train Ground Truth exists:", os.path.exists(train_gt))
print("Test Ground Truth exists:", os.path.exists(test_gt))
print("Excel exists:", os.path.exists(excel_path))

# ---- Reproducibility ----
# Seed every RNG this notebook relies on (Python `random`, NumPy, torch CPU + CUDA) so that
# DataLoader shuffling, torchvision augmentations and the new classifier head's initialisation
# repeat across "Restart & Run All". (GPU kernels may still add small non-determinism.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # no-op without CUDA

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)



In [None]:
# Sheets named exactly "Train" and "Test"
df_train = pd.read_excel(excel_path, sheet_name="Train")
df_test  = pd.read_excel(excel_path, sheet_name="Test")

# Map disease codes to integers
label_map = {"A":0, "D":1, "G":2, "N":3}
inv_label_map = {v:k for k,v in label_map.items()}

# Normalise the Disease column and attach the integer `label` column to both sheets.
for _split, _df in (("Train", df_train), ("Test", df_test)):
    # strip(): a stray space in an Excel cell (e.g. "A ") would otherwise map to NaN silently.
    _df["Disease"] = _df["Disease"].astype(str).str.strip().str.upper()
    _df["label"] = _df["Disease"].map(label_map)
    _unknown = _df["label"].isna()
    if _unknown.any():
        # Surface unmapped codes instead of letting NaN labels flow on unnoticed.
        print(f"[WARN] {_split} sheet: {int(_unknown.sum())} row(s) with unknown Disease code(s) "
              f"{sorted(_df.loc[_unknown, 'Disease'].unique())}; their label is NaN.")

print("Train sheet sample:\n", df_train.head(3))
print("\nTest sheet sample:\n", df_test.head(3))


In [None]:
class FundusGTDataset(Dataset):
    """
    Loads GROUND TRUTH (segmented) images for classification.

    Expects filenames like ``1_A.png`` or ``23_D.png``: a number, an underscore, a
    disease code A/D/G/N and a ``.png`` extension (code and extension in any case).
    Each file is matched on (Number, Disease) against ``df``, which must provide
    ``Number``, ``Disease`` and ``label`` columns. Files that do not match the
    pattern, have no matching row, or whose row has a missing label are skipped.

    ``samples`` is sorted by filename so the dataset order does not depend on the
    filesystem's directory listing order.

    Items are ``(image, label, filename)``; ``image`` is an RGB PIL image, or
    whatever ``transform`` returns when one is given.
    """
    _FNAME_RE = re.compile(r"^\s*(\d+)_([ADGN])\.png$", re.IGNORECASE)

    def __init__(self, img_dir, df, transform=None):
        self.img_dir = img_dir
        self.df = df.copy()
        self.transform = transform
        self.samples = []   # list of (fname, label)

        # One dict lookup per file instead of filtering the whole DataFrame per file.
        lookup = self._build_lookup(self.df)
        for fname in sorted(os.listdir(img_dir)):
            m = self._FNAME_RE.match(fname)
            if not m:
                continue
            key = (int(m.group(1)), m.group(2).upper())
            if key in lookup:
                self.samples.append((fname, lookup[key]))

        # Basic sanity
        if len(self.samples) == 0:
            print(f"[WARN] No samples matched in {img_dir}. Check naming and Excel mapping.")

    @staticmethod
    def _build_lookup(df):
        """Map (Number, Disease) -> int label; the first row wins, rows with a missing Number/label are ignored."""
        numbers = pd.to_numeric(df["Number"], errors="coerce")
        diseases = df["Disease"].astype(str).str.strip().str.upper()
        lookup = {}
        for num, dis, lab in zip(numbers, diseases, df["label"]):
            if pd.isna(num) or pd.isna(lab):
                continue  # int(nan) would crash; such rows cannot be labelled anyway
            lookup.setdefault((int(num), dis), int(lab))
        return lookup

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        fname, label = self.samples[idx]
        path = os.path.join(self.img_dir, fname)
        # Context manager closes the file handle deterministically (many workers x many files).
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label, fname


In [None]:
# ImageNet normalization for pretrained backbones
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std  = [0.229, 0.224, 0.225]

# Train: light scale jitter + horizontal flip; Test: deterministic resize + center crop.
train_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

test_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# ---- Ground Truth datasets ----
train_ds = FundusGTDataset(train_gt, df_train, transform=train_tfms)
test_ds  = FundusGTDataset(test_gt,  df_test,  transform=test_tfms)

batch_size = 16
# Pinned host memory only speeds up host->GPU copies; on a CPU-only run it is useless
# (recent PyTorch versions warn about it), so enable it only when training on CUDA.
use_pin_memory = device.type == "cuda"
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=use_pin_memory)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=use_pin_memory)

print("Train samples:", len(train_ds))
print("Test samples:", len(test_ds))

# Peek a few mappings straight from the (fname, label) index: no need to decode images or
# run random augmentations just to print filenames and labels.
for fname, lbl in train_ds.samples[:5]:
    print(f"Example -> {fname}  => label {lbl} ({inv_label_map[lbl]})")


In [None]:
# Compute class weights from the actual loaded Ground Truth training samples
num_classes = len(label_map)  # single source of truth for the class count (A, D, G, N)
labels_in_train = [lbl for _, lbl in train_ds.samples]
class_counts = np.bincount(np.asarray(labels_in_train, dtype=np.int64), minlength=num_classes)
# "Balanced" weights N / (K * n_k); np.maximum(..., 1) avoids dividing by zero for an absent class.
class_weights = (len(labels_in_train) / (num_classes * np.maximum(class_counts, 1))).astype(np.float32)
print("Class counts:", class_counts)
print("Class weights:", class_weights)

weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)

# Model: ImageNet-pretrained ResNet-18 with a fresh classification head
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss(weight=weights_tensor)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
def train_one_epoch(model, loader, optimizer, criterion, epoch, device=None):
    """Run one optimisation pass over `loader` with a tqdm progress bar.

    Args:
        model: network to train (switched to train mode here).
        loader: iterable of (images, labels, filenames) batches.
        optimizer, criterion: optimiser and loss used for each step.
        epoch: epoch number, only used in the progress-bar label.
        device: where batches are moved. Defaults to the device of the model's first
            parameter (CPU for a parameter-less model) instead of a notebook global.

    Returns:
        (loss, accuracy) for the epoch; the loss is the batch-size-weighted average
        of the per-batch criterion values.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
    for imgs, labels, _ in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()

        # Re-weight by batch size (assumes a mean-reduced criterion) so a short last batch counts less.
        running_loss += loss.item() * imgs.size(0)
        correct += (out.argmax(1) == labels).sum().item()
        total += labels.size(0)
        pbar.set_postfix(loss=running_loss/total, acc=correct/total)
    if total == 0:
        raise ValueError("train_one_epoch(): loader yielded no samples")
    return running_loss/total, correct/total

@torch.no_grad()
def evaluate(model, loader, criterion, device=None):
    """Compute loss and accuracy of `model` over `loader` without gradients.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: iterable of (images, labels, filenames) batches.
        criterion: loss function applied per batch.
        device: where batches are moved. Defaults to the device of the model's first
            parameter (CPU for a parameter-less model) instead of a notebook global.

    Returns:
        (loss, accuracy); the loss is the batch-size-weighted average of the
        per-batch criterion values.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    for imgs, labels, _ in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        out = model(imgs)
        running_loss += criterion(out, labels).item() * imgs.size(0)
        correct += (out.argmax(1) == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        raise ValueError("evaluate(): loader yielded no samples")
    return running_loss/total, correct/total


In [None]:
epochs = 5
# Baseline fine-tuning: one training pass per epoch, then a full pass over the test set.
for ep in range(1, epochs + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion, ep)
    te_loss, te_acc = evaluate(model, test_loader, criterion)
    print("Epoch {:02d} | train: loss {:.4f}, acc {:.3f} | test: loss {:.4f}, acc {:.3f}".format(
        ep, tr_loss, tr_acc, te_loss, te_acc))


In [None]:
@torch.no_grad()
def preds_and_labels(model, loader, device=None):
    """Return (predicted class ids, true labels) as NumPy arrays for every sample in `loader`.

    Args:
        model: classifier producing (batch, n_classes) logits.
        loader: iterable of (images, labels, filenames) batches.
        device: where images are moved. Defaults to the device of the model's first
            parameter (CPU for a parameter-less model) instead of a notebook global.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    all_preds, all_labels = [], []
    for imgs, labels, _ in loader:
        out = model(imgs.to(device))
        all_preds.extend(out.argmax(1).cpu().numpy())
        all_labels.extend(labels.numpy())
    return np.array(all_preds), np.array(all_labels)

preds, gts = preds_and_labels(model, test_loader)

# Class ids/codes come from the label map so the matrix axes cannot drift from the training labels.
cm_labels = sorted(inv_label_map)                  # [0, 1, 2, 3]
cm_names = [inv_label_map[i] for i in cm_labels]   # ["A", "D", "G", "N"]
cm = confusion_matrix(gts, preds, labels=cm_labels)
codes = ",".join(cm_names)
print(f"Confusion Matrix (rows=true {codes}; cols=pred {codes}):\n", cm)

fig, ax = plt.subplots()
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=cm_names)
disp.plot(ax=ax, values_format="d", cmap="Blues")
# The model is trained on FundusGTDataset (the "Ground truth" segmentation images), so the
# title must not call it an "original"-image baseline.
ax.set_title("Ground-Truth Segmentation Baseline — Confusion Matrix")
plt.show()


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report

# Score every class in the label map, even one missing from this test split, so the macro
# averages and the per-class report always cover A, D, G and N. (Without `labels=`,
# classification_report raises when fewer classes occur than there are target_names.)
metric_labels = sorted(inv_label_map)
metric_names = [inv_label_map[i] for i in metric_labels]

# 1) Accuracy
acc = accuracy_score(gts, preds)
print("Accuracy:", acc)

# 2) Macro-averaged precision / recall / F1 (unweighted mean over the classes).
#    zero_division=0 scores a class with no predicted/true samples as 0 instead of warning.
precision = precision_score(gts, preds, labels=metric_labels, average='macro', zero_division=0)
recall = recall_score(gts, preds, labels=metric_labels, average='macro', zero_division=0)
f1 = f1_score(gts, preds, labels=metric_labels, average='macro', zero_division=0)
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")

# 3) Per-class metrics
print("\nClassification Report:\n")
print(classification_report(gts, preds, labels=metric_labels,
                            target_names=metric_names, zero_division=0))


In [None]:
save_path = "/../segmented_trained_weights.pth"
# NOTE(review): "/.." resolves to the filesystem root on POSIX; this looks like a placeholder path.

# Persist only the learned parameters and buffers (the state_dict), not the pickled nn.Module;
# reload with model.load_state_dict(torch.load(save_path)).
torch.save(obj=model.state_dict(), f=save_path)
print("Segmented-trained model weights saved to:", save_path)

In [None]:
# Denormalize for display
def denorm(img_t, mean=None, std=None):
    """Undo `transforms.Normalize` on an image tensor and clamp to [0, 1] for display.

    Args:
        img_t: tensor whose channel axis is third from last, e.g. (C, H, W) or (B, C, H, W).
        mean, std: per-channel statistics; default to the ImageNet values used by the
            train/test transforms (`imagenet_mean` / `imagenet_std`).

    Returns:
        Tensor of the same shape, dtype and device as `img_t`.
    """
    if mean is None:
        mean = imagenet_mean
    if std is None:
        std = imagenet_std
    # Match img_t's dtype/device so float16/float64 inputs are not silently promoted.
    mean_t = torch.as_tensor(mean, dtype=img_t.dtype, device=img_t.device)[:, None, None]
    std_t = torch.as_tensor(std, dtype=img_t.dtype, device=img_t.device)[:, None, None]
    return (img_t * std_t + mean_t).clamp(0, 1)

class GradCAM:
    """Grad-CAM heat maps for a CNN classifier.

    Hooks `target_layer` to capture its forward output (activations, expected to be a
    (B, K, H, W) feature map) and the gradient of the selected class score w.r.t. it.
    The map is ReLU(sum_k mean_hw(grad_k) * act_k), min-max normalised to [0, 1] per sample.

    The hook closures reference `self` and are stored on `target_layer`, so this object
    stays alive as long as the layer does and `__del__` will not run on its own while the
    hooks are attached: call `remove()` when the wrapper is no longer needed.
    """

    def __init__(self, backbone, target_layer):
        self.backbone = backbone.eval()
        self.target_layer = target_layer
        self.activations = None   # last forward output of target_layer (detached)
        self.gradients = None     # gradient of the selected scores w.r.t. that output (detached)

        def fwd_hook(module, inp, out):
            self.activations = out.detach()

        def bwd_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()

        self.fwd_handle = target_layer.register_forward_hook(fwd_hook)
        self.bwd_handle = target_layer.register_full_backward_hook(bwd_hook)

    def remove(self):
        """Detach both hooks from `target_layer`; safe to call more than once."""
        for handle in (getattr(self, "fwd_handle", None), getattr(self, "bwd_handle", None)):
            if handle is not None:
                handle.remove()

    def __del__(self):
        # Fallback only; getattr-based remove() also copes with a partially built object.
        self.remove()

    def __call__(self, x, class_idx=None):
        """Compute Grad-CAM maps for a batch.

        Args:
            x: input batch for `backbone`. It is not modified.
            class_idx: class to explain — None (each sample's argmax), an int (same class
                for every sample) or a per-sample sequence/tensor of class ids.

        Returns:
            (cam, scores): cam is (B, 1, h, w) at target_layer's spatial resolution with
            values in [0, 1); scores are the backbone outputs for `x`.

        Raises:
            RuntimeError: if no gradient reached `target_layer`.
        """
        # Detached alias so the caller's tensor is not flipped to requires_grad=True; an input that
        # requires grad also guarantees the full backward hook fires even with frozen parameters.
        x = x.detach().requires_grad_(True)
        self.gradients = None
        # Grad-CAM needs autograd: re-enable it in case we are called inside torch.no_grad().
        with torch.enable_grad():
            scores = self.backbone(x)
            if class_idx is None:
                class_idx = scores.argmax(dim=1)
            else:
                class_idx = torch.as_tensor(class_idx, dtype=torch.long, device=scores.device).reshape(-1)
                if class_idx.numel() == 1:
                    class_idx = class_idx.expand(scores.size(0))
            sel = scores.gather(1, class_idx.unsqueeze(1)).squeeze(1)
            sel.backward(torch.ones_like(sel))
        # Do not leave Grad-CAM gradients sitting in the parameters' .grad fields.
        self.backbone.zero_grad()

        if self.gradients is None:
            raise RuntimeError("GradCAM: no gradient reached target_layer; is it used in backbone's forward?")

        # activations: [B, K, H, W], gradients: [B, K, H, W]
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)            # [B, K, 1, 1]
        cam = torch.relu((weights * self.activations).sum(dim=1, keepdim=True))  # [B, 1, H, W]

        # normalize 0..1 per-sample (epsilon keeps an all-zero map finite)
        cam_min = cam.amin(dim=(2, 3), keepdim=True)
        cam_max = cam.amax(dim=(2, 3), keepdim=True)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
        return cam, scores

# Attach Grad-CAM to the last residual stage (layer4) of the ResNet-18 `model`.
# The wrapper's hook closures reference the wrapper and live on model.layer4, so rebinding
# `gradcam_wrapper` never frees the previous one. On a re-run, detach its hooks explicitly,
# otherwise every forward/backward would also fire a stale pair of hooks.
# (RemovableHandle.remove() is a no-op when the hook is already gone.)
if "gradcam_wrapper" in globals():
    gradcam_wrapper.fwd_handle.remove()
    gradcam_wrapper.bwd_handle.remove()
target_layer = model.layer4
gradcam_wrapper = GradCAM(model, target_layer)


In [None]:
def show_gradcam_samples(n=6, loader=None):
    """Plot image / Grad-CAM map / overlay triplets for the first `n` images of one batch.

    Uses the notebook-level `model`, `gradcam_wrapper`, `device`, `denorm` and
    `inv_label_map`. `loader` defaults to `test_loader`. Each row is titled with the
    true disease code (plus filename) and the predicted disease code.
    """
    if loader is None:
        loader = test_loader
    model.eval()
    imgs, labels, names = next(iter(loader))
    n = min(n, imgs.size(0))
    if n <= 0:
        return
    # Follow the configured device instead of hard-coding "cuda" (which crashes on CPU runs),
    # and only push the images that will actually be displayed through Grad-CAM.
    imgs, labels, names = imgs[:n].to(device), labels[:n], names[:n]

    with torch.no_grad():
        preds = model(imgs).argmax(1)

    # run grad-cam with the wrapper, then upsample the maps to the input resolution
    cams, _ = gradcam_wrapper(imgs, None)
    cams = torch.nn.functional.interpolate(
        cams, size=imgs.shape[-2:], mode="bilinear", align_corners=False
    )

    jet = plt.get_cmap("jet")
    fig, axes = plt.subplots(n, 3, figsize=(12, n * 3), squeeze=False)
    for i in range(n):
        img = denorm(imgs[i]).permute(1, 2, 0).detach().cpu().numpy()
        cam = cams[i, 0].detach().cpu().numpy()
        overlay = 0.6 * img + 0.4 * jet(cam)[..., :3]

        # The third batch element is the filename, not the label: map both ids to disease codes.
        true_label = inv_label_map[int(labels[i])]
        pred_label = inv_label_map[int(preds[i])]

        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"True: {true_label} ({names[i]})\nPred: {pred_label}")
        axes[i, 1].imshow(cam, cmap="jet")
        axes[i, 1].set_title("CAM")
        axes[i, 2].imshow(overlay)
        axes[i, 2].set_title("Overlay")
        for ax in axes[i]:
            ax.axis("off")

    fig.tight_layout()
    plt.show()


In [None]:
show_gradcam_samples(n=6)  # six rows of image / CAM / overlay from the first test batch

In [None]:
# ROC
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc, RocCurveDisplay
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import copy
import os

# ---------- Helpers to get probabilities and labels ----------
@torch.no_grad()
def get_probs_and_labels(model, loader, device):
    """Run `model` over every batch of `loader` and gather softmax probabilities.

    Returns:
        probs: np.ndarray of shape (N, n_classes), row-wise softmax of the logits.
        labels: np.ndarray of shape (N,).
        filenames: list of the per-sample names from each batch, in loader order.
    """
    model.eval()
    prob_chunks, label_chunks, filenames = [], [], []
    for batch_imgs, batch_labels, batch_names in loader:
        logits = model(batch_imgs.to(device))
        prob_chunks.append(F.softmax(logits, dim=1).cpu().numpy())
        label_chunks.append(batch_labels.numpy())
        filenames += batch_names
    return np.vstack(prob_chunks), np.concatenate(label_chunks), filenames

# ---------- ROC plotting for multiclass ----------
def plot_multiclass_roc(y_true, y_score, n_classes=4, class_names=None, figsize=(8,6)):
    """Plot one-vs-rest ROC curves per class plus micro- and macro-averaged curves.

    Args:
        y_true: array of shape (N,) with integer labels in 0..n_classes-1.
        y_score: array of shape (N, n_classes) with a score/probability per class.
        n_classes: number of classes (columns of y_score).
        class_names: legend names per class; defaults to "0", "1", ...
        figsize: matplotlib figure size.

    Returns:
        dict of AUC values keyed by class index, "micro" and "macro".
    """
    y_score = np.asarray(y_score)
    # Binarize the true labels (one column per class)
    y_test_bin = label_binarize(y_true, classes=np.arange(n_classes))
    if n_classes == 2:
        # label_binarize collapses the binary case to a single column; expand to one per class.
        y_test_bin = np.hstack([1 - y_test_bin, y_test_bin])
    if class_names is None:
        class_names = [str(i) for i in range(n_classes)]

    # Per-class ROC and AUC
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average: pool every (sample, class) decision into one binary problem
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate each class's TPR onto the union of FPR points, then average
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.mean([np.interp(all_fpr, fpr[i], tpr[i]) for i in range(n_classes)], axis=0)
    fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Plot (the unused plt.cm.get_cmap call was dropped: it no longer exists in Matplotlib >= 3.9)
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(fpr["micro"], tpr["micro"],
            label=f'micro-average (AUC = {roc_auc["micro"]:.3f})', linestyle=':', linewidth=2)
    ax.plot(fpr["macro"], tpr["macro"],
            label=f'macro-average (AUC = {roc_auc["macro"]:.3f})', linestyle='-.', linewidth=2)
    for i, cname in zip(range(n_classes), class_names):
        ax.plot(fpr[i], tpr[i], label=f'{cname} (AUC = {roc_auc[i]:.3f})', linewidth=1.8)

    ax.plot([0, 1], [0, 1], 'k--', linewidth=1)
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Multiclass ROC')
    ax.legend(loc='lower right', fontsize='small')
    ax.grid(alpha=0.3)
    plt.show()
    return roc_auc

# ---------- Training loop with best-model saving and ROC plotted at end ----------
def _fit_one_pass(model, loader, optimizer, criterion, device):
    """One training pass over `loader`; returns (batch-size-weighted mean loss, accuracy)."""
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for imgs, labels, _ in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * imgs.size(0)
        n_correct += (out.argmax(1) == labels).sum().item()
        n_seen += labels.size(0)
    if n_seen == 0:
        raise ValueError("train_and_evaluate(): train_loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def _score_one_pass(model, loader, criterion, device):
    """One gradient-free pass over `loader`; returns (batch-size-weighted mean loss, accuracy)."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for imgs, labels, _ in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        out = model(imgs)
        loss_sum += criterion(out, labels).item() * imgs.size(0)
        n_correct += (out.argmax(1) == labels).sum().item()
        n_seen += labels.size(0)
    if n_seen == 0:
        raise ValueError("train_and_evaluate(): val_loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion,
                       device, epochs=10, save_path="/../segmented_trained_weights_best.pth",
                       n_classes=4, class_names=None, scheduler=None, print_freq=1):
    """Train for `epochs`, checkpoint the best-validation-accuracy epoch, then plot its ROC.

    After every epoch the state_dict with the highest validation accuracy so far is
    saved to `save_path`. At the end those weights are loaded back into `model`, and
    softmax probabilities on `val_loader` are plotted as multiclass ROC curves.

    NOTE: the epoch is selected on `val_loader`. If that is the test set (as in the
    example call below), the reported test metrics are optimistically biased.

    Args:
        scheduler: optional LR scheduler. ReduceLROnPlateau is stepped with the
            validation loss; any other scheduler is stepped once per epoch with no argument.
        print_freq: print a progress line every `print_freq` epochs.

    Returns:
        (model, history, probs, labels, names): the model with the best weights loaded,
        a dict of per-epoch train/val loss and accuracy lists, and the outputs of
        get_probs_and_labels on `val_loader`.
    """
    # -inf (not 0.0) so the first epoch is always checkpointed, even at 0% accuracy.
    best_acc = float("-inf")
    best_model_w = copy.deepcopy(model.state_dict())
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _fit_one_pass(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = _score_one_pass(model, val_loader, criterion, device)
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        if scheduler is not None:
            # Only ReduceLROnPlateau consumes a metric; other schedulers would read val_loss as an epoch.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # save best
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_w = copy.deepcopy(model.state_dict())
            torch.save(best_model_w, save_path)

        if epoch % print_freq == 0:
            print(f"Epoch {epoch}/{epochs} | train_loss {train_loss:.4f} acc {train_acc:.3f} | val_loss {val_loss:.4f} acc {val_acc:.3f} | best_val_acc {best_acc:.3f}")

    # load best before producing final metrics/plots
    model.load_state_dict(best_model_w)
    print("Loaded best model with val_acc = {:.4f}".format(best_acc))

    # get probs + labels from validation (or test) loader
    probs, labels, names = get_probs_and_labels(model, val_loader, device)

    # Plot ROC
    if class_names is None:
        class_names = ["A","D","G","N"][:n_classes]
    plot_multiclass_roc(labels, probs, n_classes=n_classes, class_names=class_names)

    return model, history, probs, labels, names

# ---------- Example usage ----------
# Continues fine-tuning the already-trained `model` with the same optimizer for 10 more epochs,
# keeps the epoch with the best accuracy on `test_loader`, and plots that epoch's ROC curves.
# NOTE(review): choosing the "best" epoch on the test set leaks test information into model
# selection; pass a separate validation loader here for an unbiased test estimate.
save_path = "/../segmented_trained_weights_best.pth"
model, history, probs, labels, names = train_and_evaluate(
    model,
    train_loader,
    test_loader,          # val_loader slot
    optimizer,
    criterion,
    device,
    epochs=10,
    save_path=save_path,
    n_classes=4,
    class_names=["A", "D", "G", "N"],
    scheduler=None,       # e.g. torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
)
print("ROC plotting finished. Best model saved to:", save_path)
