## Data Loader

In [1]:
"""
modified from https://github.com/mlmed/torchxrayvision/blob/master/torchxrayvision/datasets.py
"""
from pathlib import Path
from tqdm import tqdm

# from toolz import *
# from toolz.curried import *
# from toolz.curried.operator import *

import pandas as pd

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class EmoryDataset(Dataset):
    """Emory CXR studies served as precomputed foundation-model embeddings.

    The finding-label table and the metadata table are joined on
    ``AccessionNumber_anon``, filtered to the requested view(s), optionally
    reduced to one row per patient, shuffled with ``seed``, truncated to
    ``max_rows`` rows and finally sliced into the requested split.

    Each item is a dict with the keys ``patient_id``, ``study_id``, ``lab``
    (float32 label vector ordered like ``pathologies``) and ``emb`` (the
    embedding array loaded from a ``.npy`` file).
    """

    # Fractions of the (truncated) table assigned to train / valid / test.
    split_ratio = [0.6, 0.2, 0.2]

    # Directory prefix of the per-model embedding folders ("<prefix>_<model>").
    embed_prefix = "embds"

    # Number of shuffled rows kept before the table is split.
    max_rows = 10000

    pathologies = ["Enlarged Cardiomediastinum",
                   "Cardiomegaly",
                   "Lung Opacity",
                   "Lung Lesion",
                   "Edema",
                   "Consolidation",
                   "Pneumonia",
                   "Atelectasis",
                   "Pneumothorax",
                   "Pleural Effusion",
                   "Pleural Other",
                   "Fracture",
                   "Support Devices"]

    embedding_d = {
        "BiomedCLIP": Path("~/fsx/embeddings/EmoryCXR/embds_BiomedCLIP"),
        "CheXagent": Path("~/fsx/embeddings/EmoryCXR/embds_CheXagent"),
        "MedGemma": Path("~/fsx/embeddings/EmoryCXR/embds_MedGemma"),
        "RAD-DINO": Path("~/fsx/embeddings/EmoryCXR/embds_RAD-DINO"),
    }

    csvpath = Path("~/fsx/embeddings/EmoryCXR/Tables/EmoryCXR_v2_FindingLabel_10162024.csv")
    metacsvpath = Path("~/fsx/embeddings/EmoryCXR/Tables/EmoryCXR_v2_metadata_08152025.csv")
    base_dicom_path = Path("/home/jupyter-oluwatunmise/fsx/embeddings/EmoryCXR/")

    def __init__(
        self,
        views: str = ["PA", "AP"][0],
        mode: str = ["train", "valid", "test"][0],
        embedding_type: str = ["BiomedCLIP", "CheXagent", "MedGemma", "RAD-DINO", "All"][0],
        unique_patients=True,
        seed: int = 0):
        """Load the tables and build the label matrix for one split.

        Args:
            views: View position (or list of them) to keep; ``"*"`` keeps all.
            mode: One of ``"train"``, ``"valid"`` or ``"test"``.
            embedding_type: A key of ``embedding_d``, or ``"All"`` to
                concatenate the embeddings of every model.
            unique_patients: Keep only the first row per ``empi_anon``.
            seed: Seed for the global NumPy RNG and for the row shuffle.

        Raises:
            ValueError: If ``mode`` is not one of the three split names.
            KeyError: If ``embedding_type`` is neither ``"All"`` nor a key of
                ``embedding_d``.
        """
        np.random.seed(seed)  # Reset the seed so all runs are the same.
        self.views = views
        self.mode = mode
        self.embedding_type = embedding_type
        self.unique_patients = unique_patients
        self.seed = seed

        self.embpath: Path | list[Path] = self.load_emb_path(embedding_type)

        self.csv = pd.read_csv(self.csvpath)
        self.metacsv = pd.read_csv(self.metacsvpath)
        self.csv = self.csv.set_index(["AccessionNumber_anon"])
        self.metacsv = self.metacsv.set_index(["AccessionNumber_anon"])
        self.csv = self.csv.join(self.metacsv).reset_index()

        # Keep only the desired view
        self.csv["view"] = self.csv["ViewPosition"]
        self.limit_to_selected_views(views)

        if unique_patients:
            self.csv = self.csv.groupby("empi_anon").first().reset_index()

        self.csv = self.csv.sample(frac=1, random_state=self.seed).reset_index(drop=True)
        self.csv = self.csv.fillna(0)
        self.csv = self.csv[:self.max_rows]

        n_row = self.csv.shape[0]
        train_end = int(n_row * self.split_ratio[0])
        valid_end = int(n_row * (self.split_ratio[0] + self.split_ratio[1]))
        n_test = int(n_row * self.split_ratio[-1])

        # Split data into one of train / valid / test.
        if self.mode == "train":
            self.csv = self.csv.iloc[:train_end]
        elif self.mode == "valid":
            self.csv = self.csv.iloc[train_end:valid_end]
        elif self.mode == "test":
            # Take the last n_test rows.  Computing the start explicitly keeps
            # n_test == 0 an empty split (a "-0:" slice would select every row).
            self.csv = self.csv.iloc[n_row - n_test:]
        else:
            raise ValueError(
                f"attr:mode has to be one of [train, valid, test] but your input is {self.mode}"
            )
        # Own the rows of this split so the label clean-up below writes to a
        # real frame rather than to a slice of the full table.
        self.csv = self.csv.copy()

        # Get our classes.
        healthy = self.csv["No Finding"] == 1
        labels = []
        for pathology in self.pathologies:
            if pathology in self.csv.columns:
                # A study flagged "No Finding" cannot carry a positive finding.
                self.csv.loc[healthy, pathology] = 0
                labels.append(self.csv[pathology].values)
            else:
                # Column absent from the table: treat the finding as negative
                # for every row instead of reusing another pathology's values.
                labels.append(np.zeros(len(self.csv), dtype=np.float32))
        self.labels = np.asarray(labels).T
        self.labels = self.labels.astype(np.float32)

        # Uncertain labels (-1) are treated as negatives.
        self.labels[self.labels == -1] = 0

        # patientid
        self.csv["empi_anon"] = self.csv["empi_anon"].astype(str)

    def __getitem__(self, i):
        """Return the sample dict for positional row ``i`` of this split."""
        row = self.csv.iloc[i]  # materialize the row once, not once per field
        return {
            "patient_id": int(float(row["empi_anon"])),
            "study_id": int(float(row["AccessionNumber_anon"])),
            "lab": self.labels[i],
            "emb": self.load_embedding(row["SOP"]),
        }

    def __len__(self):
        """Number of rows in this split."""
        return len(self.labels)

    def load_emb_path(self, embedding_type):
        """Return the embedding directory for ``embedding_type``.

        For ``"All"`` the directories of every model in ``embedding_d`` are
        returned as a list.
        """
        if embedding_type != "All":
            return self.embedding_d[embedding_type]
        return list(self.embedding_d.values())

    def limit_to_selected_views(self, views):
        """Filter ``self.csv`` by view based on the values in ``.csv['view']``.

        ``views`` may be a single value or a list/tuple/set of values; the
        wildcard ``"*"`` disables the filter.  Missing views are labelled
        ``"UNKNOWN"``.
        """
        if isinstance(views, (list, tuple, set)):
            views = list(views)
        else:
            views = [views]
        if "*" in views:
            # if you have the wildcard, the rest are irrelevant
            views = ["*"]
        self.views = views

        # missing data is unknown
        self.csv = self.csv.copy()
        self.csv["view"] = self.csv["view"].fillna("UNKNOWN")

        if "*" not in views:
            self.csv = self.csv[self.csv["view"].isin(self.views)]  # Select the view

    def _embedding_file(self, embedding_type, embedding_id):
        """Path of the ``.npy`` file holding one model's embedding of a study."""
        folder = self.base_dicom_path / f"{self.embed_prefix}_{embedding_type}"
        # ".npy" is appended as text so identifiers containing dots stay intact.
        return f"{folder / embedding_id}.npy"

    def load_embedding(self, embedding_id):
        """Load the embedding stored under ``embedding_id``.

        With ``embedding_type == "All"`` the embeddings of every model in
        ``embedding_d`` are concatenated along the first axis, in dict order.
        """
        if self.embedding_type == "All":
            # np.concatenate works on every NumPy release; np.concat is 2.0+.
            return np.concatenate(
                [np.load(self._embedding_file(name, embedding_id)) for name in self.embedding_d]
            )
        return np.load(self._embedding_file(self.embedding_type, embedding_id))

    def load_all(self):
        """Eagerly load every sample of this split into a list."""
        print(f"loading all {self.mode} data")
        return [self[i] for i in tqdm(range(len(self)))]


class Dataloader(DataLoader):
    """``torch.utils.data.DataLoader`` subclass that only sets project defaults.

    Every argument is forwarded to the parent constructor unchanged; the
    defaults in this signature (for example ``num_workers=8`` and
    ``pin_memory=True``) apply when the caller does not override them.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=8,
        collate_fn=None,
        pin_memory=True,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        multiprocessing_context=None,
        generator=None,
        prefetch_factor=None,
        persistent_workers=False,
        **kwargs
    ):
        # Gather the named options once, then hand everything to the parent.
        loader_options = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "sampler": sampler,
            "batch_sampler": batch_sampler,
            "num_workers": num_workers,
            "collate_fn": collate_fn,
            "pin_memory": pin_memory,
            "drop_last": drop_last,
            "timeout": timeout,
            "worker_init_fn": worker_init_fn,
            "multiprocessing_context": multiprocessing_context,
            "generator": generator,
            "prefetch_factor": prefetch_factor,
            "persistent_workers": persistent_workers,
        }
        super().__init__(dataset=dataset, **loader_options, **kwargs)

In [3]:
# ================== 1) Load FM="All" per split ==================
# The three splits differ only in `mode`; everything else is shared.
_all_fm_kwargs = dict(views="PA", embedding_type="All", unique_patients=True, seed=0)

all_train = EmoryDataset(mode="train", **_all_fm_kwargs).load_all()
all_valid = EmoryDataset(mode="valid", **_all_fm_kwargs).load_all()
all_test = EmoryDataset(mode="test", **_all_fm_kwargs).load_all()

# Wire into the structure your pipeline expects
concat_splits = dict(train=all_train, valid=all_valid, test=all_test)

loading all train data


100%|██████████| 6000/6000 [00:38<00:00, 157.35it/s]


loading all valid data


100%|██████████| 2000/2000 [00:12<00:00, 158.79it/s]


loading all test data


100%|██████████| 2000/2000 [00:12<00:00, 155.79it/s]


## Run PCA on Concat

In [None]:
# ===== PCA-enabled trainer + safe PCA sweep + visuals =====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from typing import List, Dict, Tuple
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.decomposition import PCA
# ===== Labels (same names and order as EmoryDataset.pathologies) =====
# Metrics are matched to these names by column position, so the order must
# follow the columns of the label matrix.
label_cols = [
    "Enlarged Cardiomediastinum","Cardiomegaly","Lung Opacity","Lung Lesion","Edema",
    "Consolidation","Pneumonia","Atelectasis","Pneumothorax","Pleural Effusion",
    "Pleural Other","Fracture","Support Devices"
]

# ---------- utilities ----------
def stack_from_samples(samples: List[Dict]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Stack sample dicts into feature, label and patient-id arrays.

    Each sample must provide ``"emb"``, ``"lab"`` and ``"patient_id"``.
    Embeddings with more than one dimension are flattened.  NaN and
    +/-infinity are replaced by 0 in both matrices, and labels are binarized
    with a 0.5 threshold.

    Returns:
        ``(X, Y, ids)``: float32 features, integer 0/1 labels, integer ids.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    if not len(samples):
        raise ValueError("Empty sample list.")

    def _unpack(sample: Dict):
        # Same per-sample access order as a plain loop: emb, lab, patient_id.
        vec = np.asarray(sample["emb"])
        if vec.ndim > 1:
            vec = vec.reshape(-1)
        return (
            vec.astype(np.float32),
            np.asarray(sample["lab"], dtype=np.float32),
            int(sample["patient_id"]),
        )

    feature_rows, label_rows, patient_ids = zip(*(_unpack(s) for s in samples))

    features = np.nan_to_num(np.vstack(feature_rows), posinf=0.0, neginf=0.0)
    targets = np.nan_to_num(np.vstack(label_rows), posinf=0.0, neginf=0.0)
    targets = (targets > 0.5).astype(int)
    return features, targets, np.asarray(patient_ids)

# ---------- trainer with optional PCA ----------
def fit_eval_mlp_on_splits(
    train_samples: List[Dict],
    valid_samples: List[Dict],
    test_samples:  List[Dict],
    label_names: List[str],
    seed: int = 0,
    hidden=(512, 256),
    max_iter=120,
    pca_dim: int | None = None,   # NEW
    pca_whiten: bool = False,     # optional
) -> Dict[str, Dict[str, float]]:
    """Train one MLP per label on the train split and score valid and test.

    Features are standardized with statistics from the train split.  When
    ``pca_dim`` is set and lies strictly between 0 and the feature dimension,
    a PCA fitted on the scaled train features is applied to all three splits;
    otherwise the PCA step is skipped.  A label with a single class in the
    train split gets that class's mean as a constant prediction.

    Args:
        train_samples, valid_samples, test_samples: Sample dicts accepted by
            ``stack_from_samples``.
        label_names: Names assigned, by position, to the label columns.
        seed: Random state for the PCA and for every classifier.
        hidden: Hidden layer sizes of each MLP.
        max_iter: Iteration cap of each MLP.
        pca_dim: Number of PCA components, or ``None`` for no PCA.
        pca_whiten: Passed to ``PCA(whiten=...)``.

    Returns:
        ``{"valid": metrics, "test": metrics}`` where each ``metrics`` dict
        has ``AUROC_macro``, ``AP_macro``, ``AUROC_micro``, ``AP_micro``,
        ``per_label_AUROC`` and ``per_label_AP``.
    """
    X_tr, Y_tr, _ = stack_from_samples(train_samples)
    X_va, Y_va, _ = stack_from_samples(valid_samples)
    X_te, Y_te, _ = stack_from_samples(test_samples)

    # Standardize with statistics from the TRAIN split only.
    standardizer = StandardScaler()
    feats = {"train": standardizer.fit_transform(X_tr)}
    feats["valid"] = standardizer.transform(X_va)
    feats["test"] = standardizer.transform(X_te)

    # Optional PCA on the concatenated, scaled features.
    wants_pca = pca_dim is not None and 0 < pca_dim < feats["train"].shape[1]
    if wants_pca:
        reducer = PCA(n_components=pca_dim, whiten=pca_whiten, random_state=seed)
        feats["train"] = reducer.fit_transform(feats["train"])
        feats["valid"] = reducer.transform(feats["valid"])
        feats["test"] = reducer.transform(feats["test"])

    n_labels = Y_tr.shape[1]
    probs = {
        "valid": np.zeros((feats["valid"].shape[0], n_labels), dtype=float),
        "test": np.zeros((feats["test"].shape[0], n_labels), dtype=float),
    }

    for col in range(n_labels):
        target = Y_tr[:, col]

        if len(np.unique(target)) >= 2:
            # Balanced weights counter the class imbalance of each finding.
            # NOTE(review): MLPClassifier.fit only accepts sample_weight on
            # scikit-learn releases that support it - confirm the installed version.
            weights = compute_sample_weight("balanced", target)
            model = MLPClassifier(
                hidden_layer_sizes=hidden,
                activation="relu",
                solver="adam",
                alpha=1e-4,
                learning_rate_init=1e-3,
                batch_size=256,
                max_iter=max_iter,
                early_stopping=True,
                n_iter_no_change=10,
                validation_fraction=0.15,
                shuffle=True,
                random_state=seed,
            )
            model.fit(feats["train"], target, sample_weight=weights)
            probs["valid"][:, col] = model.predict_proba(feats["valid"])[:, 1]
            probs["test"][:, col] = model.predict_proba(feats["test"])[:, 1]
        else:
            # Only one class in train: nothing to learn, predict its mean.
            constant = float(target.mean())
            probs["valid"][:, col] = constant
            probs["test"][:, col] = constant

    def _score_split(Y_true: np.ndarray, P: np.ndarray) -> Dict[str, float]:
        """Macro, micro and per-label AUROC / AP for one split."""
        auc_by_label, ap_by_label = {}, {}
        scorable = []  # columns that contain both classes in this split
        for col, name in enumerate(label_names):
            truth = Y_true[:, col]
            if len(np.unique(truth)) >= 2:
                auc_by_label[name] = roc_auc_score(truth, P[:, col])
                ap_by_label[name] = average_precision_score(truth, P[:, col])
                scorable.append(col)
            else:
                auc_by_label[name] = np.nan
                ap_by_label[name] = np.nan

        macro_auc = float(np.nanmean(list(auc_by_label.values()))) if auc_by_label else np.nan
        macro_ap = float(np.nanmean(list(ap_by_label.values()))) if ap_by_label else np.nan

        micro_auc = np.nan
        micro_ap = np.nan
        if scorable:
            micro_auc = roc_auc_score(Y_true[:, scorable], P[:, scorable], average="micro")
            micro_ap = average_precision_score(Y_true[:, scorable], P[:, scorable], average="micro")

        return {
            "AUROC_macro": macro_auc,
            "AP_macro": macro_ap,
            "AUROC_micro": micro_auc,
            "AP_micro": micro_ap,
            "per_label_AUROC": auc_by_label,
            "per_label_AP": ap_by_label,
        }

    return {
        "valid": _score_split(Y_va, probs["valid"]),
        "test": _score_split(Y_te, probs["test"]),
    }

# ---------- SAFE PCA sweep over concatenated splits ----------
# expects: concat_splits["train"|"valid"|"test"] and label_cols already defined
X_tr_concat, _, _ = stack_from_samples(concat_splits["train"])
full_dim = X_tr_concat.shape[1]

# candidate dims; keep valid (< full_dim)
pca_dims = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
pca_dims = [d for d in pca_dims if d < full_dim]
if not pca_dims:
    print(f"No PCA dims < concatenated dim ({full_dim}). Skipping PCA sweep.")


def _pca_tag(dim: int) -> str:
    """Key under which the results for PCA dimension ``dim`` are stored.

    The sweep and the plotting code below must agree on this key, so both
    build it here.
    """
    return f"PCA-{dim}"


concat_grid_results: Dict[str, Dict[str, Dict[str, float]]] = {}
concat_grid_perlabel: Dict[str, Dict[str, float]] = {}

for d in pca_dims:
    tag = _pca_tag(d)
    try:
        res_d = fit_eval_mlp_on_splits(
            train_samples=concat_splits["train"],
            valid_samples=concat_splits["valid"],
            test_samples=concat_splits["test"],
            label_names=label_cols,
            seed=0, hidden=(512,256), max_iter=120,
            pca_dim=d,  # PCA after concat
            # pca_whiten=True,  # optional
        )
        # Keep only the scalar summary metrics here; per-label scores are
        # stored separately below.
        concat_grid_results[tag] = {
            "valid": {k:v for k,v in res_d["valid"].items() if not isinstance(v, dict)},
            "test":  {k:v for k,v in res_d["test"].items()  if not isinstance(v, dict)},
        }
        concat_grid_perlabel[tag] = res_d["test"]["per_label_AUROC"]
        print(f"[{tag}] TEST AUROC_macro={concat_grid_results[tag]['test']['AUROC_macro']:.4f} "
              f"| AP_macro={concat_grid_results[tag]['test']['AP_macro']:.4f} "
              f"| AUROC_micro={concat_grid_results[tag]['test']['AUROC_micro']:.4f} "
              f"| AP_micro={concat_grid_results[tag]['test']['AP_micro']:.4f}")
    except Exception as e:
        # Deliberate best-effort: one failing dimension must not abort the
        # whole sweep, so report it and move on to the next one.
        print(f"⚠️ Skipping {tag} due to error: {e}")

# ---------- Plot: macro metrics vs PCA dim (robust to missing dims) ----------
summary_metrics = ["AUROC_macro", "AP_macro", "AUROC_micro", "AP_micro"]
rows = []
for d in pca_dims:
    tag = _pca_tag(d)
    if tag not in concat_grid_results:
        print(f"⚠️ Skipping {tag} (no results)")
        continue
    test_scores = concat_grid_results[tag]["test"]
    rows.append({"pca_dim": d, **{m: test_scores[m] for m in summary_metrics}})
# Explicit columns keep sort_values valid even when no dimension succeeded.
df_pca = pd.DataFrame(rows, columns=["pca_dim", *summary_metrics]).sort_values("pca_dim")

if not df_pca.empty:
    for metric_col, y_label, plot_title in [
        ("AUROC_macro", "Test AUROC (macro)", "Concat → PCA: Test AUROC_macro vs. PCA dim"),
        ("AP_macro", "Test AP (macro)", "Concat → PCA: Test AP_macro vs. PCA dim"),
    ]:
        plt.figure(figsize=(7,4))
        plt.plot(df_pca["pca_dim"], df_pca[metric_col], marker="o")
        plt.xlabel("PCA dimension (after concat)")
        plt.ylabel(y_label)
        plt.title(plot_title)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
else:
    print("No PCA points to plot.")

# ---------- Optional: per-label heatmap over PCA variants ----------
def plot_per_label_heatmap(per_label_results: Dict[str, Dict[str, float]],
                           label_cols: List[str],
                           title="Per-label AUROC: Concat + PCA sweep"):
    """Draw an annotated heatmap of per-label scores.

    Rows are the keys of ``per_label_results`` (one per model variant) and
    columns follow ``label_cols``; labels missing from the results are shown
    as NaN cells.
    """
    scores = pd.DataFrame(per_label_results).T

    # Guarantee one column per requested label, then fix the column order.
    absent = [name for name in label_cols if name not in scores.columns]
    for name in absent:
        scores[name] = np.nan
    scores = scores[label_cols]

    _, ax = plt.subplots(figsize=(14, 6))
    sns.heatmap(
        scores,
        annot=True,
        fmt=".2f",
        cmap="viridis",
        cbar_kws={'label': 'AUROC'},
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Disease Label")
    ax.set_ylabel("Model Variant")
    plt.tight_layout()
    plt.show()

# Draw the per-label heatmap only when the sweep stored at least one result.
if concat_grid_perlabel:
    plot_per_label_heatmap(
        concat_grid_perlabel,
        label_cols,
        title="Per-label AUROC — Concat with PCA (various dims)",
    )



In [None]:
import numpy as np, pandas as pd, seaborn as sns, matplotlib.pyplot as plt
from typing import Dict, List

def plot_per_label_heatmap_ranked(
    per_label_results: Dict[str, Dict[str, float]],
    label_cols: List[str],
    title: str = "Per-label AUROC (bright = best, includes Average)",
    add_average: bool = True,
    sort_rows_by_avg: bool = False,
    star_char: str = "★",
    star_tol: float = 1e-12,   # treat near-equal as ties
):
    """Draw a per-label heatmap whose colors rank the variants within a label.

    Cell text shows the true score (two decimals); cell color shows the score
    min-max normalized within its column, so the brightest cell of a column
    is the best variant for that label.  The best variant(s) per label are
    marked with ``star_char``.

    Args:
        per_label_results: ``{variant: {label: score}}``.
        label_cols: Labels to show, in column order; labels missing from the
            results become empty (NaN) columns.
        title: Figure title.
        add_average: Append an ``"Average"`` column (row mean over labels).
        sort_rows_by_avg: Sort rows by that average, best first.
        star_char: Marker appended to the winning cell(s) of each label.
        star_tol: Scores within this distance of the column maximum tie.

    Returns:
        None.  Prints a message and draws nothing when there are no variants.
    """

    def _relative(values: np.ndarray) -> np.ndarray:
        """Min-max scale ``values`` to [0, 1], keeping NaN entries as NaN.

        Flat, all-NaN or non-finite ranges map to a neutral 0.5.
        """
        neutral = np.full(values.shape, 0.5)
        if np.isnan(values).all():
            # Checked first so nanmin/nanmax never see an all-NaN slice.
            return neutral
        lo, hi = np.nanmin(values), np.nanmax(values)
        if np.isfinite(lo) and np.isfinite(hi) and hi > lo:
            return (values - lo) / (hi - lo)
        return neutral

    # 1) Wide DF of true AUROCs (rows = models/variants, cols = labels)
    df_true = pd.DataFrame(per_label_results).T
    if df_true.shape[0] == 0:
        # Nothing to rank; min/max of an empty column is undefined.
        print("No per-label results to plot.")
        return
    for c in label_cols:
        if c not in df_true.columns:
            df_true[c] = np.nan
    df_true = df_true[label_cols]

    # 2) Average column
    if add_average:
        df_true["Average"] = df_true[label_cols].mean(axis=1)

    # 3) Optional sort by Average (desc)
    if sort_rows_by_avg and "Average" in df_true.columns:
        df_true = df_true.sort_values("Average", ascending=False)

    # 4) Column-wise normalization → color brightness shows rank per label.
    #    The Average column is scaled by the same rule as the label columns.
    df_norm = df_true.copy()
    for c in df_norm.columns:
        df_norm[c] = _relative(df_true[c].to_numpy(dtype=float))

    # 5) Build annotation text with stars on per-label winners.
    #    Built cell by cell so it does not rely on DataFrame.applymap, which
    #    newer pandas deprecates in favour of DataFrame.map.
    annot_text = pd.DataFrame(
        [
            ["" if np.isnan(v) else f"{v:.2f}" for v in row]
            for row in df_true.to_numpy(dtype=float)
        ],
        index=df_true.index,
        columns=df_true.columns,
    )

    for c in label_cols:
        col_vals = df_true[c].astype(float)
        if col_vals.notna().any():
            mx = col_vals.max()
            # treat near-equal within star_tol as ties
            winners = col_vals.index[(np.abs(col_vals - mx) <= star_tol)]
            for idx in winners:
                # append star to existing formatted value
                annot_text.loc[idx, c] = f"{annot_text.loc[idx, c]}{star_char}"

    # 6) Plot
    plt.figure(figsize=(14, 6))
    ax = sns.heatmap(
        df_norm, annot=annot_text, fmt="", cmap="viridis",
        cbar_kws={'label': 'Relative (per-label) performance'},
        vmin=0.0, vmax=1.0
    )
    ax.set_title(title)
    ax.set_xlabel("Disease Label" + (" + Average" if "Average" in df_true.columns else ""))
    ax.set_ylabel("Model / Variant")
    plt.tight_layout()
    plt.show()

##Run
    
# Skip the ranked heatmap when the sweep stored no per-label results
# (for example when every PCA dimension was skipped).
if concat_grid_perlabel:
    plot_per_label_heatmap_ranked(
        concat_grid_perlabel,   # e.g., {"PCA-16": {"Atelectasis": 0.84, ...}, ...}
        label_cols,
        title="Per-label AUROC — Concat + PCA (dims 4-> 1024; column-normalized, best=bright)",
        add_average=True,
        sort_rows_by_avg=False   # keep rows in the order the sweep stored them
    )
else:
    print("No per-label PCA results to plot.")