In [None]:
#Libraries 
import os
import random
import math
import numpy as np
from PIL import Image, ImageOps
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torchvision.models as models
from gudhi import CubicalComplex
import pandas as pd
from sklearn.model_selection import train_test_split
from torchvision.transforms.functional import InterpolationMode
from sklearn.metrics import confusion_matrix

In [None]:
# Hyperparameters, paths and runtime settings (every later cell reads these names)

# --- data locations: image folders and the CSV split files (id_code, diagnosis) ---
DATA_DIR_TRAIN = "train_images/train_images"
DATA_DIR_VAL = "val_images/val_images"
DATA_DIR_TEST = "test_images/test_images"
CSV_TRAIN, CSV_VAL, CSV_TEST = "train_1.csv", "valid.csv", "test.csv"

# --- geometry: PIL-style (width, height) for inputs, and the persistence-image grid ---
IMG_SIZE = (256, 256)
TDA_OUT_SIZE = (64, 64)

# --- optimisation / early stopping (PATIENCE = epochs without val-QWK improvement) ---
BATCH_SIZE, NUM_EPOCHS, PATIENCE = 16, 30, 7
LR = 1e-4
RANDOM_SEED = 50

# --- compute device: Apple MPS first, then CUDA, otherwise CPU ---
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- loss weighting: multiplier of the attention-vs-topology MSE term ---
LAMBDA_MASK = 1.0

# --- TDA caching: False => recompute persistence images on every __getitem__ ---
PRECOMPUTE_TDA = False
TDA_DIR = "./tda_cache_fixed"
NUM_WORKERS = 0

In [None]:
# Reproducibility: seed Python's `random`, NumPy's global RNG and torch,
# plus every CUDA device when running on CUDA.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(RANDOM_SEED)
del _seed_fn

if DEVICE == "cuda":
    torch.cuda.manual_seed_all(RANDOM_SEED)

In [None]:
# Preprocess & TDA Functions
def preprocess_green_clahe(img_rgb: Image.Image, resize_shape=IMG_SIZE):
    """
    Build the float32 map that the TDA pipeline filters on.

    Despite the name, the contrast step is *global* histogram equalization
    (`ImageOps.equalize`), not CLAHE.

    Steps:
      1. convert to RGB, resize to `resize_shape` and take the green channel;
      2. histogram-equalize the green channel;
      3. z-score the pixels inside a centred disc of radius
         0.9 * min(h // 2, w // 2); pixels outside the disc are set to 0,
         i.e. to the disc mean;
      4. min-max rescale the whole map to [0, 1] (skipped when it is constant).

    Args:
        img_rgb: input PIL image. Non-RGB modes (L, RGBA, P, ...) are converted
            to RGB first so the green-channel index is always valid.
        resize_shape: target size in PIL (width, height) order.

    Returns:
        np.ndarray of dtype float32 and shape (height, width) with values in [0, 1].
    """
    # convert() guards the [..., 1] index below: an "L" image would yield a 2-D array.
    img = img_rgb.convert("RGB").resize(resize_shape)
    arr = np.asarray(img).astype(np.float32) / 255.0
    green = arr[..., 1]
    g8 = (np.clip(green, 0, 1) * 255).astype(np.uint8)
    # Global histogram equalization (not CLAHE).
    green_eq = np.array(ImageOps.equalize(Image.fromarray(g8)), dtype=np.float32) / 255.0

    # Centred disc mask (presumably the fundus field of view -- radius is a fixed 90% heuristic).
    h, w = green_eq.shape
    cy, cx = h // 2, w // 2
    r = int(0.9 * min(cy, cx))
    Y, X = np.ogrid[:h, :w]
    mask = (X - cx) ** 2 + (Y - cy) ** 2 <= r ** 2

    # Mean / std over the disc only; fall back to (0, 1) if degenerate.
    vals = green_eq[mask]
    mu, sigma = (float(vals.mean()), float(vals.std())) if vals.size > 0 else (0.0, 1.0)
    sigma = sigma if sigma > 1e-6 else 1.0

    # Outside the disc stays 0, which is the z-score of the disc mean.
    norm = np.zeros_like(green_eq, dtype=np.float32)
    norm[mask] = (green_eq[mask] - mu) / sigma

    # Scale to [0, 1] for a stable filtration range in the persistence diagram.
    if norm.max() != norm.min():
        nmin, nmax = norm.min(), norm.max()
        norm = (norm - nmin) / (nmax - nmin + 1e-8)
    return norm

def compute_persistence_diagram_from_array(a2d: np.ndarray):
    """
    Compute 0- and 1-dimensional persistence intervals of a 2-D array with GUDHI.

    The array is negated first, so the cubical filtration is built on `-a2d`
    (GUDHI filters by increasing cell value, so high values of `a2d` enter first).

    Args:
        a2d: 2-D array of cell values (converted to float32).

    Returns:
        (pd0, pd1): arrays of shape (N, 2) holding (birth, death) rows; deaths may
        be +inf for essential classes. If GUDHI raises, both are empty (0, 2)
        float32 arrays -- a deliberate best-effort fallback that makes the caller
        produce a blank persistence image.
    """
    a2d = np.asarray(a2d, dtype=np.float32)
    top_cells = (-a2d).ravel()  # C order: the LAST axis varies fastest
    try:
        # With an explicit `dimensions`, GUDHI reads the flat cell list with the
        # FIRST dimension varying fastest, so a C-ordered ravel needs the shape
        # reversed. For square arrays the previous `a2d.shape` amounted to a
        # transpose (same diagrams); for non-square arrays it scrambled pixels.
        cc = CubicalComplex(dimensions=list(a2d.shape[::-1]), top_dimensional_cells=top_cells)
        cc.persistence()
        # reshape(-1, 2) keeps empty results 2-D so callers can always index [:, 0].
        pd0 = np.asarray(cc.persistence_intervals_in_dimension(0), dtype=np.float64).reshape(-1, 2)
        pd1 = np.asarray(cc.persistence_intervals_in_dimension(1), dtype=np.float64).reshape(-1, 2)
    except Exception:
        pd0 = np.zeros((0, 2), dtype=np.float32)
        pd1 = np.zeros((0, 2), dtype=np.float32)
    return pd0, pd1

def persistence_image_from_diagram(diag, out_size=(64, 64), sigma=0.03):
    """
    Rasterise a persistence diagram into a persistence image scaled to [0, 1].

    Each (birth, persistence) point is min-max normalised to the unit square
    (birth -> x axis / columns, persistence -> y axis / rows), then splatted as an
    isotropic Gaussian of width `sigma`. The sum is divided by its peak.

    Args:
        diag: array-like of (birth, death) pairs, shape (N, 2), or None. Plain
            lists of pairs are accepted. Infinite deaths are given a lifetime of 1.0.
        out_size: (H, W) of the output grid.
        sigma: Gaussian width in normalised units.

    Returns:
        float32 array of shape (H, W); all zeros for an empty or None diagram.
    """
    H, W = out_size
    img = np.zeros((H, W), dtype=np.float32)
    if diag is None or len(diag) == 0:
        return img
    # Accept lists / tuples as well as arrays; guarantee an (N, 2) layout.
    diag = np.asarray(diag, dtype=np.float64).reshape(-1, 2)
    births, deaths = diag[:, 0], diag[:, 1]
    # Essential classes (death = inf) get a unit lifetime so they land on the grid.
    finite_deaths = np.where(np.isinf(deaths), births + 1.0, deaths)
    pers = finite_deaths - births
    # Per-diagram min-max normalisation (epsilons keep single-point diagrams finite).
    b_min, b_max = births.min(), births.max() + 1e-8
    p_min, p_max = pers.min(), pers.max() + 1e-8
    births_n = (births - b_min) / (b_max - b_min + 1e-8)
    pers_n = (pers - p_min) / (p_max - p_min + 1e-8)
    ys = np.linspace(0, 1, H)
    xs = np.linspace(0, 1, W)
    Xg, Yg = np.meshgrid(xs, ys)
    for b, p in zip(births_n, pers_n):
        d2 = (Xg - b) ** 2 + (Yg - p) ** 2
        img += np.exp(-0.5 * d2 / (sigma ** 2))
    if img.max() > 0:
        img /= img.max()
    return img.astype(np.float32)

def compute_persistence_image_from_green(green_norm, out_size=TDA_OUT_SIZE):
    """
    Turn a preprocessed green-channel map into a persistence image.

    The dimension-1 diagram is rasterised when it has at least one interval;
    otherwise the dimension-0 diagram is used as a fallback.

    Args:
        green_norm: 2-D float map, as produced by the green-channel preprocessing.
        out_size: (H, W) of the returned persistence image.
    """
    dgm_dim0, dgm_dim1 = compute_persistence_diagram_from_array(green_norm)
    chosen = dgm_dim0 if len(dgm_dim1) == 0 else dgm_dim1
    return persistence_image_from_diagram(chosen, out_size=out_size)

In [None]:
# ------------------ DATASET ------------------
class AptosTDA_Dataset(Dataset):
    """
    Dataset pairing each fundus RGB image with a persistence-image map.

    For each sample:
    - Load the RGB image (a black image is substituted when no file is found)
    - Compute, or load from `precompute_dir`, a single-channel float32 persistence image
    - If `augment` is True, apply a random horizontal flip and a random rotation
      in [-10, 10] degrees with the SAME parameters to both images
    - Return (image tensor, persistence tensor [1 x H x W] in [0, 1], label, id_code)

    Args:
        df: DataFrame with `id_code` and `diagnosis` columns.
        data_dir: folder containing `<id_code>.png|.jpg|.jpeg`.
        img_size: PIL (width, height) used for both the RGB and persistence images.
        compute_tda_on_the_fly: recompute the persistence image on every access.
        tda_out_size: grid size of the persistence image before it is resized.
        precompute_dir: `.npz` cache directory; required when not computing on the fly.
        transforms: callable applied to the PIL image (ToTensor + normalisation is
            expected); defaults to `TF.to_tensor`.
        augment: enable the random flip/rotation. Defaults to True for backward
            compatibility; pass False for validation/test datasets so evaluation
            is deterministic.
    """

    def __init__(self, df, data_dir, img_size=IMG_SIZE,
                 compute_tda_on_the_fly=True, tda_out_size=TDA_OUT_SIZE,
                 precompute_dir=None, transforms=None, augment=True):

        self.df = df.reset_index(drop=True)
        self.data_dir = data_dir
        self.img_size = img_size
        self.compute_tda_on_the_fly = compute_tda_on_the_fly
        self.tda_out_size = tda_out_size
        self.precompute_dir = precompute_dir
        self.transforms = transforms  # only normalization / ToTensor expected here
        self.augment = augment

        if (not self.compute_tda_on_the_fly) and (self.precompute_dir is None):
            raise ValueError("If not computing on the fly, you must provide precompute_dir")

        if self.precompute_dir:
            os.makedirs(self.precompute_dir, exist_ok=True)

    def __len__(self):
        return len(self.df)

    def resolve_image_path(self, id_code):
        """Return the first existing `<data_dir>/<id_code>{.png,.jpg,.jpeg}`, else None."""
        stem = str(id_code)  # pandas may parse all-digit ids as integers
        for ext in (".png", ".jpg", ".jpeg"):
            p = os.path.join(self.data_dir, stem + ext)
            if os.path.exists(p):
                return p
        return None

    def maybe_load_precomputed(self, idc):
        """Load a cached persistence image; None if missing, unreadable or the wrong size."""
        fname = os.path.join(self.precompute_dir, f"{idc}.npz")
        if not os.path.exists(fname):
            return None
        try:
            # Context manager closes the underlying zip file handle.
            with np.load(fname) as data:
                arr = data["tda_img"].astype(np.float32)
        except Exception:
            # Corrupt or partial cache entry: best-effort, the caller recomputes it.
            return None
        # The cache is keyed only by id_code; reject entries built with another tda_out_size.
        if arr.shape != tuple(self.tda_out_size):
            return None
        return arr

    def precompute_and_save(self, idc, green_norm):
        """Compute the persistence image for `green_norm`, cache it as `<idc>.npz`, return it."""
        arr = compute_persistence_image_from_green(green_norm, out_size=self.tda_out_size)
        fname = os.path.join(self.precompute_dir, f"{idc}.npz")
        np.savez_compressed(fname, tda_img=arr)
        return arr

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        idc, label = row['id_code'], int(row['diagnosis'])
        p = self.resolve_image_path(idc)
        if p:
            # convert() loads the pixels into a new image, so the file can be closed here.
            with Image.open(p) as im:
                pil = im.convert("RGB").resize(self.img_size)
        else:
            pil = Image.new("RGB", self.img_size)

        # compute / load persistence image based on green-channel preprocessing
        if self.compute_tda_on_the_fly:
            green_norm = preprocess_green_clahe(pil, resize_shape=self.img_size)
            tda_img = compute_persistence_image_from_green(green_norm, out_size=self.tda_out_size)
        else:
            loaded = self.maybe_load_precomputed(idc)
            if loaded is None:
                green_norm = preprocess_green_clahe(pil, resize_shape=self.img_size)
                tda_img = self.precompute_and_save(idc, green_norm)
            else:
                tda_img = loaded

        # Convert to an 8-bit PIL image so the same geometric ops apply to both inputs.
        tda_pil = Image.fromarray((tda_img * 255).astype(np.uint8)).convert("L").resize(self.img_size)

        if self.augment:
            # Random horizontal flip, shared by both images
            if random.random() < 0.5:
                pil = TF.hflip(pil)
                tda_pil = TF.hflip(tda_pil)

            # Random rotation between -10 and 10 degrees, shared by both images
            angle = random.uniform(-10, 10)
            pil = TF.rotate(pil, angle, interpolation=InterpolationMode.BILINEAR)
            tda_pil = TF.rotate(tda_pil, angle, interpolation=InterpolationMode.BILINEAR)

        if self.transforms is not None:
            img_tensor = self.transforms(pil)
        else:
            img_tensor = TF.to_tensor(pil)

        # to_tensor on an "L" image already yields float32 in [0, 1] with shape [1, H, W].
        tda_tensor = TF.to_tensor(tda_pil).float()

        return img_tensor, tda_tensor, label, idc


In [None]:
# ------------------ MODEL ------------------
class CNNBackboneSpatial(nn.Module):
    """
    EfficientNet-B0 convolutional trunk that returns a spatial feature map.

    Only `efficientnet_b0(...).features` is kept (as `self.stem`); the pooling and
    classifier head are dropped, so `forward` yields [B, C, Hf, Wf] features.

    Args:
        pretrained: load torchvision's IMAGENET1K_V1 weights when True.
        freeze_backbone: set requires_grad=False on every trunk parameter.
    """

    def __init__(self, pretrained=True, freeze_backbone=False):
        super().__init__()
        from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

        if pretrained:
            weights = EfficientNet_B0_Weights.IMAGENET1K_V1
        else:
            weights = None

        # The attribute name `stem` appears in saved state_dict keys -- keep it stable.
        self.stem = efficientnet_b0(weights=weights).features

        if freeze_backbone:
            self.stem.requires_grad_(False)

    def forward(self, x):
        return self.stem(x)

class TopologyGuidedAttention(nn.Module):
    """
    Spatial attention gate steered by a topological prior.

        att = sigmoid(conv(channel_max(Fi) + channel_mean(Fi) + resize(pimg)))

    The channel max/mean reduce over dim 1, and the three maps are summed element-wise.

    forward(Fi, pimg) returns a 3-tuple:
        Fi_att    -- Fi * att, same shape as Fi
        att       -- attention map [B, 1, Hf, Wf] in (0, 1)
        p_resized -- pimg bilinearly resized to [B, 1, Hf, Wf]
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # 1 -> 1 channel conv; padding k // 2 preserves Hf x Wf for odd kernel sizes.
        self.conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=kernel_size // 2)

    def forward(self, Fi, pimg):
        """
        Fi: [B, C, Hf, Wf] backbone features.
        pimg: [B, 1, Horig, Worig] topology map (resized here to Hf x Wf).
        """
        topo = F.interpolate(pimg, size=Fi.shape[2:], mode="bilinear", align_corners=False)

        channel_max = Fi.max(dim=1, keepdim=True).values
        channel_mean = Fi.mean(dim=1, keepdim=True)

        gate = torch.sigmoid(self.conv(channel_max + channel_mean + topo))
        return Fi * gate, gate, topo

class TDA_CNN_Attention_Model(nn.Module):
    """
    Backbone -> TopologyGuidedAttention -> post conv -> classifier.

    forward(x_rgb, pimg) returns (logits [B, num_classes], attention map
    [B, 1, Hf, Wf], persistence image resized to [B, 1, Hf, Wf]).

    Args:
        backbone: module mapping [B, 3, H, W] images to [B, C, Hf, Wf] features.
        num_classes: number of output logits.
        input_size: optional PIL-style (width, height) for the channel-probe
            forward pass; defaults to the module-level IMG_SIZE.
    """
    def __init__(self, backbone, num_classes=5, input_size=None):
        super().__init__()
        self.backbone = backbone

        # Infer backbone feature channels with one dummy forward pass.
        C = self._infer_feature_channels(backbone, IMG_SIZE if input_size is None else input_size)

        self.att_module = TopologyGuidedAttention(kernel_size=7)

        # post-attention conv then global pool -> fc
        hidden = max(32, C // 2)
        self.post_conv = nn.Sequential(
            nn.Conv2d(C, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(hidden, num_classes)

    @staticmethod
    def _infer_feature_channels(backbone, input_size):
        """Return C of the backbone's [B, C, Hf, Wf] output for a zero image of `input_size`."""
        width, height = input_size
        first_param = None
        if isinstance(backbone, nn.Module):
            first_param = next(backbone.parameters(), None)
        # Build the probe on the backbone's own device/dtype (CPU float32 if unknown).
        if first_param is not None:
            dummy = torch.zeros(1, 3, height, width, device=first_param.device, dtype=first_param.dtype)
        else:
            dummy = torch.zeros(1, 3, height, width)

        # Run in eval mode: in train mode BatchNorm would fold this all-zero batch
        # into its running statistics (no_grad does not prevent that). Restore each
        # submodule's own mode afterwards.
        saved_modes = []
        if isinstance(backbone, nn.Module):
            saved_modes = [(m, m.training) for m in backbone.modules()]
            backbone.eval()
        try:
            with torch.no_grad():
                feat = backbone(dummy)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        _, C, _, _ = feat.shape
        return C

    def forward(self, x_rgb, pimg):
        Feat = self.backbone(x_rgb)                       # [B, C, Hf, Wf]
        Feat_att, att_map, p_resized = self.att_module(Feat, pimg)
        z = self.post_conv(Feat_att).flatten(1)
        logits = self.fc(z)
        return logits, att_map, p_resized


In [None]:
def mixup_data(x, y, alpha=0.4):
    """
    Mixup: blend each sample with a randomly permuted partner.

    Args:
        x: batch tensor, first dimension is the batch.
        y: targets aligned with x.
        alpha: Beta(alpha, alpha) concentration; alpha <= 0 disables mixing (lam = 1).

    Returns:
        (mixed_x, y_a, y_b, lam) where mixed_x = lam * x + (1 - lam) * x[perm],
        y_a = y and y_b = y[perm].
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    perm = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[perm, :]
    return blended, y, y[perm], lam

# ------------------ TRAIN / VAL ------------------
def train_one_epoch(model, dataloader, optimizer, criterion_ce, lambda_mask, device):
    """
    Run one mixup training epoch.

    Loss = mixup CE + lambda_mask * MSE(attention map, resized persistence image).
    Images and their persistence maps are mixed with the SAME lambda and
    permutation, so the topological prior seen by the attention module (and
    used as the mask-loss target) matches the mixed image.

    Returns:
        (mean total loss, mean CE loss, mean mask loss, mixup-weighted accuracy),
        averaged per sample; all 0.0 when the loader yields no samples.
    """
    model.train()
    total, correct = 0, 0.0
    running_loss = running_ce = running_mask = 0.0
    for imgs, tda_imgs, labels, _ in tqdm(dataloader, desc="Train"):
        imgs, tda_imgs, labels = imgs.to(device), tda_imgs.to(device), labels.to(device).long()
        optimizer.zero_grad()

        # ---- MIXUP on images and persistence maps together ----
        # Concatenating along channels makes one mixup_data call apply identical
        # lam / permutation to both; align spatial sizes first if they differ.
        if tda_imgs.shape[2:] != imgs.shape[2:]:
            tda_imgs = F.interpolate(tda_imgs, size=imgs.shape[2:], mode="bilinear", align_corners=False)
        n_img_ch = imgs.size(1)
        mixed, y_a, y_b, lam = mixup_data(torch.cat([imgs, tda_imgs], dim=1), labels, alpha=0.4)
        imgs_mix, tda_mix = mixed[:, :n_img_ch], mixed[:, n_img_ch:]

        logits, att_pred, pimg_resized = model(imgs_mix, tda_mix)

        ce_loss = lam * criterion_ce(logits, y_a) + (1 - lam) * criterion_ce(logits, y_b)
        mask_loss = F.mse_loss(att_pred, pimg_resized)

        loss = ce_loss + lambda_mask * mask_loss
        loss.backward()
        optimizer.step()

        batch = imgs.size(0)
        running_loss += loss.item() * batch
        running_ce += ce_loss.item() * batch
        running_mask += mask_loss.item() * batch

        # Mixup-aware accuracy: credit each prediction by the weight of the target it hits.
        preds = logits.argmax(dim=1)
        hits = lam * (preds == y_a).float() + (1 - lam) * (preds == y_b).float()
        correct += hits.sum().item()
        total += batch

    if total == 0:
        return 0.0, 0.0, 0.0, 0.0
    return running_loss / total, running_ce / total, running_mask / total, correct / total

def quadratic_weighted_kappa(y_true, y_pred, num_classes=5):
    """
    Quadratic weighted Cohen's kappa for integer labels 0 .. num_classes - 1.

    Samples whose true or predicted label lies outside that range are ignored,
    matching `confusion_matrix(..., labels=range(num_classes))`.

    Args:
        y_true: sequence of true labels.
        y_pred: sequence of predicted labels, same length as y_true.
        num_classes: number of ordinal classes (5 for the APTOS grades).

    Returns:
        float kappa. NaN when kappa is undefined: no in-range samples, or zero
        expected disagreement (e.g. every label and prediction is one class).

    Raises:
        ValueError: if y_true and y_pred have different shapes.
    """
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same length")

    k = int(num_classes)
    valid = (y_true >= 0) & (y_true < k) & (y_pred >= 0) & (y_pred < k)
    # Observed confusion matrix O[true, pred] via a flat bincount.
    O = np.bincount(y_true[valid] * k + y_pred[valid], minlength=k * k).reshape(k, k).astype(np.float64)
    N = O.sum()
    if N == 0:
        return float("nan")

    # Quadratic disagreement weights (i - j)^2 / (k - 1)^2.
    idx = np.arange(k)
    w = (idx[:, None] - idx[None, :]) ** 2 / float(max(k - 1, 1) ** 2)

    # Expected matrix under independence of the two marginals.
    E = np.outer(O.sum(axis=1), O.sum(axis=0)) / N

    denom = np.sum(w * E)
    if denom == 0:
        return float("nan")
    return float(1.0 - np.sum(w * O) / denom)

def validate_one_epoch(model, dataloader, criterion_ce, lambda_mask, device):
    """
    Evaluate the model on `dataloader` without mixup or gradients.

    Returns:
        (mean total loss, mean CE loss, mean mask loss, accuracy, QWK), where the
        total loss is CE + lambda_mask * MSE(attention map, resized persistence image).
    """
    model.eval()
    n_seen, n_correct = 0, 0
    loss_sum = ce_sum = mask_sum = 0
    pred_list, label_list = [], []

    with torch.no_grad():
        for imgs, tda_imgs, labels, _ids in dataloader:
            imgs = imgs.to(device)
            tda_imgs = tda_imgs.to(device)
            labels = labels.to(device).long()

            logits, att_pred, pimg_resized = model(imgs, tda_imgs)
            ce = criterion_ce(logits, labels)
            mask = F.mse_loss(att_pred, pimg_resized)
            combined = ce + lambda_mask * mask

            bs = imgs.size(0)
            loss_sum += combined.item() * bs
            ce_sum += ce.item() * bs
            mask_sum += mask.item() * bs

            preds = logits.max(dim=1).indices
            pred_list += preds.cpu().tolist()
            label_list += labels.cpu().tolist()

            n_correct += (preds == labels).sum().item()
            n_seen += bs

    accuracy = n_correct / n_seen
    kappa = quadratic_weighted_kappa(label_list, pred_list)
    return loss_sum / n_seen, ce_sum / n_seen, mask_sum / n_seen, accuracy, kappa

In [None]:
# ------------------ TEST EVALUATION ------------------
def evaluate_test(model, dataloader, device):
    """
    Score the model on a held-out loader.

    Returns:
        (accuracy, quadratic weighted kappa) of the argmax predictions.
    """
    model.eval()
    predictions, targets = [], []
    hits = seen = 0

    with torch.no_grad():
        for imgs, tda_imgs, labels, _ids in dataloader:
            imgs = imgs.to(device)
            tda_imgs = tda_imgs.to(device)
            labels = labels.to(device).long()

            logits = model(imgs, tda_imgs)[0]
            preds = logits.max(dim=1).indices

            predictions += preds.cpu().tolist()
            targets += labels.cpu().tolist()

            hits += (preds == labels).sum().item()
            seen += imgs.size(0)

    return hits / seen, quadratic_weighted_kappa(targets, predictions)

In [None]:
# ------------------ MAIN ------------------
def main():
    """
    Train, validate and test the topology-guided attention model.

    - reads the train/val/test CSVs (`id_code`, `diagnosis` columns);
    - trains with mixup, AdamW and a per-epoch linear-warmup + cosine LR schedule;
    - keeps the checkpoint with the best validation QWK and stops early after
      PATIENCE epochs without improvement;
    - reloads that checkpoint and reports test accuracy / QWK.
    """
    # Load three CSVs directly
    train_df = pd.read_csv(CSV_TRAIN)[['id_code', 'diagnosis']]
    val_df   = pd.read_csv(CSV_VAL)[['id_code', 'diagnosis']]
    test_df  = pd.read_csv(CSV_TEST)[['id_code', 'diagnosis']]

    # Basic transforms (ImageNet mean/std, matching the ImageNet-pretrained backbone)
    normalize = T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    transforms = T.Compose([T.ToTensor(), normalize])

    # Datasets
    # NOTE(review): ds_val / ds_test use the dataset's default behaviour, which
    # applies a random flip and rotation in __getitem__, so validation and test
    # metrics vary between runs. TODO: disable augmentation for these splits.
    ds_train = AptosTDA_Dataset(train_df, DATA_DIR_TRAIN, img_size=IMG_SIZE,
                                compute_tda_on_the_fly=(not PRECOMPUTE_TDA),
                                tda_out_size=TDA_OUT_SIZE, precompute_dir=TDA_DIR,
                                transforms=transforms)

    ds_val = AptosTDA_Dataset(val_df, DATA_DIR_VAL, img_size=IMG_SIZE,
                              compute_tda_on_the_fly=(not PRECOMPUTE_TDA),
                              tda_out_size=TDA_OUT_SIZE, precompute_dir=TDA_DIR,
                              transforms=transforms)

    ds_test = AptosTDA_Dataset(test_df, DATA_DIR_TEST, img_size=IMG_SIZE,
                               compute_tda_on_the_fly=(not PRECOMPUTE_TDA),
                               tda_out_size=TDA_OUT_SIZE, precompute_dir=TDA_DIR,
                               transforms=transforms)

    # DataLoaders
    pin = (DEVICE == "cuda")
    train_loader = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=pin)

    val_loader = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=pin)

    test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=NUM_WORKERS, pin_memory=pin)

    # Model
    backbone = CNNBackboneSpatial(pretrained=True)
    model = TDA_CNN_Attention_Model(backbone=backbone, num_classes=5).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    criterion_ce = nn.CrossEntropyLoss()

    best_val_qwk = -math.inf
    best_val_acc = 0.0
    best_epoch = 0          # 0 => no checkpoint selected by QWK yet
    patience_counter = 0
    best_path = "best_tda_att_model_fixed.pth"

    # LR SCHEDULER (linear warmup, then cosine decay), stepped once per epoch
    warmup_epochs = 3
    total_epochs = NUM_EPOCHS

    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            # Linear warmup 1/w, 2/w, ..., 1. Starting at 0 would make the whole
            # first epoch a no-op, since LambdaLR applies lambda(0) immediately.
            return float(current_epoch + 1) / float(max(1, warmup_epochs))
        # Cosine decay after warmup
        progress = float(current_epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # TRAINING!
    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\nEpoch {epoch}/{NUM_EPOCHS}")

        tr_loss, tr_ce, tr_mask, tr_acc = train_one_epoch(
            model, train_loader, optimizer, criterion_ce, LAMBDA_MASK, DEVICE
        )
        val_loss, val_ce, val_mask, val_acc, val_qwk = validate_one_epoch(
            model, val_loader, criterion_ce, LAMBDA_MASK, DEVICE
        )

        scheduler.step()

        print(f"Train: loss={tr_loss:.4f} acc={tr_acc:.4f}")
        print(f"Val  : loss={val_loss:.4f} acc={val_acc:.4f} qwk={val_qwk:.4f}")

        best_val_acc = max(best_val_acc, val_acc)

        # ---------------- SAVE BEST BASED ON QWK ----------------
        # NaN QWK (undefined kappa) never counts as an improvement.
        if not math.isnan(val_qwk) and val_qwk > best_val_qwk:
            best_val_qwk = val_qwk
            best_epoch = epoch
            torch.save(model.state_dict(), best_path)
            patience_counter = 0
            print("Saved BEST model (by QWK).")
        else:
            if best_epoch == 0:
                # No QWK-selected checkpoint yet: save the latest weights so the
                # final test evaluation always has a checkpoint to load.
                torch.save(model.state_dict(), best_path)
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print("Early stopping triggered.")
                break

    print("\nTraining finished.")
    print("Best validation accuracy:", best_val_acc)
    print(f"Best validation QWK: {best_val_qwk:.4f} (epoch {best_epoch})")

    # ------------------ FINAL TEST EVAL ------------------
    print("\nLoading best model for FINAL TEST evaluation...")
    model.load_state_dict(torch.load(best_path, map_location=DEVICE))

    test_acc, test_qwk = evaluate_test(model, test_loader, DEVICE)
    print("\n=======================================")
    print(f" FINAL TEST ACCURACY: {test_acc:.4f} ")
    print(f" FINAL TEST QWK:      {test_qwk:.4f} ")
    print("=======================================")

In [None]:
# Entry point: run the full train -> validate -> test pipeline defined in main().
if __name__ == "__main__":
    main()