In [12]:
## IMPORTS

import os
import torch
import numpy as np
import nibabel as nib
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, roc_auc_score



import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import os


import umap


from sklearn.ensemble import RandomForestClassifier
from imblearn.over_sampling import SMOTE


In [2]:

# --- Paths and Model ---
DINO_MODEL_NAME = "facebook/dinov2-base"
CHECKPOINT_PATH = "/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/runs/dino_full_public/checkpoints/checkpoint_step160000_epoch9.pt"
PDGM_DIR = "/oak/stanford/groups/ogevaert/data/brain_mri_tumor_project/UCSF-PDGM-v3"
SAVE_PATH = "/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/features/dino_pdgm_features.npz"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# --- Projection Head definition ---
class ProjectionHead(nn.Module):
    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.proj(x)


# --- Load DINO backbone and projection head ---
processor = AutoImageProcessor.from_pretrained(DINO_MODEL_NAME)
vit = AutoModel.from_pretrained(DINO_MODEL_NAME).to(DEVICE)
proj_head = ProjectionHead(vit.config.hidden_size).to(DEVICE)

ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
vit.load_state_dict(ckpt["backbone"])
proj_head.load_state_dict(ckpt["head"])
vit.eval()
proj_head.eval()

# --- Preprocessing helper ---
def preprocess_slice(slice_2d):
    # Normalize and convert to uint8
    slice_2d = np.uint8(255 * np.clip(slice_2d, 0, 1))
    # DINO expects 3 channels; stack if needed
    if slice_2d.ndim == 2:
        slice_2d = np.stack([slice_2d]*3, axis=-1)
    return processor(images=slice_2d, return_tensors="pt")["pixel_values"].squeeze(0)

# --- Feature extraction ---
features, case_ids = [], []
case_dirs = sorted([d for d in os.listdir(PDGM_DIR) if d.startswith("UCSF-PDGM-")])

for case_folder in tqdm(case_dirs):
    case_path = os.path.join(PDGM_DIR, case_folder)
    num4 = case_folder.split("-")[-1].split("_")[0]
    file_case_id = f"UCSF-PDGM-{num4}"
    t1c_path = os.path.join(case_path, f"{file_case_id}_T1c_bias.nii.gz")
    if not os.path.exists(t1c_path):
        continue

    img = nib.load(t1c_path).get_fdata().astype(np.float32)
    img = img / (np.max(img) + 1e-8)
    slice_embeds = []
    for i in range(img.shape[2]):
        slice_2d = img[:, :, i]
        if np.std(slice_2d) < 1e-5:
            continue
        input_tensor = preprocess_slice(slice_2d).to(DEVICE)
        with torch.no_grad():
            embed = vit(input_tensor.unsqueeze(0)).last_hidden_state[:, 0, :]
            proj = proj_head(embed)
            slice_embeds.append(proj.cpu().numpy().squeeze())
    if len(slice_embeds) == 0:
        continue
    case_embed = np.mean(np.stack(slice_embeds), axis=0)
    features.append(case_embed)
    case_ids.append(file_case_id)

features = np.stack(features)
np.savez(SAVE_PATH, features=features, case_ids=np.array(case_ids))
print(f"Saved DINO features to {SAVE_PATH}")

100%|██████████| 501/501 [15:08<00:00,  1.81s/it]

Saved DINO features to /oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/features/dino_pdgm_features.npz





In [6]:
# L2 log reg classifier

# --- Load features and metadata ---
data = np.load("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/features/dino_pdgm_features.npz", allow_pickle=True)
X = data["features"]
case_ids = data["case_ids"]
meta = pd.read_csv("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/metadata/PGDM/UCSF-PDGM-metadata_v2.csv").set_index("ID")

# --- Standardize case_ids to match metadata index (three digits) ---
case_ids_fixed = np.array([f"UCSF-PDGM-{int(cid.split('-')[-1]):03d}" for cid in case_ids])

# --- Filter to only cases present in metadata ---
valid_mask = np.isin(case_ids_fixed, meta.index)
X = X[valid_mask]
case_ids_fixed = case_ids_fixed[valid_mask]

# --- Prepare labels for IDH status ---
idh_status = meta.loc[case_ids_fixed, "IDH"].values
labels = np.array([0 if str(v).strip().lower() == "wildtype" else 1 for v in idh_status])

# --- Train/test split ---


X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.2, random_state=42, stratify=labels
)

# --- Train regularized logistic regression classifier ---
clf = LogisticRegression(
    penalty='l2',
    max_iter=1000,
    class_weight='balanced'
)
clf.fit(X_train, y_train)

# --- Evaluate ---
y_pred = clf.predict(X_test)
y_prob = clf.predict_proba(X_test)[:, 1]

print("Classification report:")
print(classification_report(y_test, y_pred))
print("ROC AUC:", roc_auc_score(y_test, y_prob))

Classification report:
              precision    recall  f1-score   support

           0       0.91      0.79      0.85        78
           1       0.48      0.71      0.58        21

    accuracy                           0.78        99
   macro avg       0.70      0.75      0.71        99
weighted avg       0.82      0.78      0.79        99

ROC AUC: 0.8266178266178267


In [13]:
# Forest (with SMOTE) for DINO features


# --- Load features and metadata ---
data = np.load("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/features/dino_pdgm_features_1e4lr160000.npz", allow_pickle=True)
X = data["features"]
case_ids = data["case_ids"]
meta = pd.read_csv("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/metadata/PGDM/UCSF-PDGM-metadata_v2.csv").set_index("ID")

# --- Standardize case_ids to three digits and filter ---
case_ids_fixed = np.array([f"UCSF-PDGM-{int(cid.split('-')[-1]):03d}" for cid in case_ids])
valid_mask = np.isin(case_ids_fixed, meta.index)
X = X[valid_mask]
case_ids_fixed = case_ids_fixed[valid_mask]

# --- Prepare labels for IDH status ---
idh_status = meta.loc[case_ids_fixed, "IDH"].values
labels = np.array([0 if str(v).strip().lower() == "wildtype" else 1 for v in idh_status])

# --- Train/test split ---
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.2, random_state=42, stratify=labels
)

# --- Oversample minority class in training set ---
sm = SMOTE(random_state=42)
X_train_res, y_train_res = sm.fit_resample(X_train, y_train)

# --- Train Random Forest classifier ---
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train_res, y_train_res)

# --- Evaluate ---
y_pred = clf.predict(X_test)
y_prob = clf.predict_proba(X_test)[:, 1]

print("Classification report:")
print(classification_report(y_test, y_pred))
print("ROC AUC:", roc_auc_score(y_test, y_prob))



Classification report:
              precision    recall  f1-score   support

           0       0.87      0.94      0.90        78
           1       0.67      0.48      0.56        21

    accuracy                           0.84        99
   macro avg       0.77      0.71      0.73        99
weighted avg       0.83      0.84      0.83        99

ROC AUC: 0.7780830280830281


In [9]:
data = np.load("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/features/dino_pdgm_features_1e4lr160000.npz", allow_pickle=True)
X = data["features"]
case_ids = data["case_ids"]
meta = pd.read_csv("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/metadata/PGDM/UCSF-PDGM-metadata_v2.csv").set_index("ID")

# --- Standardize case_ids to three digits and filter ---
case_ids_fixed = np.array([f"UCSF-PDGM-{int(cid.split('-')[-1]):03d}" for cid in case_ids])
valid_mask = np.isin(case_ids_fixed, meta.index)
X = X[valid_mask]
case_ids_fixed = case_ids_fixed[valid_mask]

# --- Standardize features and run t-SNE ---
X_scaled = StandardScaler().fit_transform(X)
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
X_tsne = tsne.fit_transform(X_scaled)

# --- Metadata columns to visualize ---
columns = [
    "ID", "Sex", "Age at MRI", "WHO CNS Grade", "Final pathologic diagnosis (WHO 2021)",
    "MGMT status", "MGMT index", "1p/19q", "IDH", "1-dead 0-alive", "OS", "EOR",
    "Biopsy prior to imaging", "BraTS21 ID", "BraTS21 Segmentation Cohort", "BraTS21 MGMT Cohort"
]

outdir = "/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/tsne_plots/PDGM_dino"
os.makedirs(outdir, exist_ok=True)

for column in columns:
    try:
        values = meta.loc[case_ids_fixed, column].values
    except KeyError:
        print(f"Column {column} not found in metadata.")
        continue

    # Convert all values to string, replace nan with "NA"
    values_str = np.array([str(v) if pd.notna(v) else "NA" for v in values])

    plt.figure(figsize=(7, 6))
    for val in np.unique(values_str):
        idx = values_str == val
        plt.scatter(X_tsne[idx, 0], X_tsne[idx, 1], label=str(val), alpha=0.7, s=20)
    plt.legend(markerscale=2, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title(f"t-SNE colored by {column}")
    plt.xlabel("t-SNE-1")
    plt.ylabel("t-SNE-2")
    plt.tight_layout()
    fname = f"tsne_{column.replace(' ', '_').replace('/', '_')}.png"
    plt.savefig(os.path.join(outdir, fname))
    plt.close()
    print(f"Saved t-SNE for {column}")

print("All t-SNE plots saved")

Column ID not found in metadata.
Saved t-SNE for Sex


  plt.tight_layout()


Saved t-SNE for Age at MRI
Saved t-SNE for WHO CNS Grade
Saved t-SNE for Final pathologic diagnosis (WHO 2021)
Saved t-SNE for MGMT status
Saved t-SNE for MGMT index
Saved t-SNE for 1p/19q
Saved t-SNE for IDH
Saved t-SNE for 1-dead 0-alive


  plt.tight_layout()


Saved t-SNE for OS
Saved t-SNE for EOR
Saved t-SNE for Biopsy prior to imaging


  plt.tight_layout()


Saved t-SNE for BraTS21 ID
Saved t-SNE for BraTS21 Segmentation Cohort
Saved t-SNE for BraTS21 MGMT Cohort
All t-SNE plots saved


In [11]:


data = np.load("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/features/dino_pdgm_features_1e4lr160000.npz", allow_pickle=True)
X = data["features"]
case_ids = data["case_ids"]
meta = pd.read_csv("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/metadata/PGDM/UCSF-PDGM-metadata_v2.csv").set_index("ID")

# --- Standardize case_ids to three digits and filter ---
case_ids_fixed = np.array([f"UCSF-PDGM-{int(cid.split('-')[-1]):03d}" for cid in case_ids])
valid_mask = np.isin(case_ids_fixed, meta.index)
X = X[valid_mask]
case_ids_fixed = case_ids_fixed[valid_mask]

# --- Standardize features and run UMAP ---
X_scaled = StandardScaler().fit_transform(X)
reducer = umap.UMAP(random_state=42)
X_umap = reducer.fit_transform(X_scaled)

# --- Metadata columns to visualize ---
columns = [
    "ID", "Sex", "Age at MRI", "WHO CNS Grade", "Final pathologic diagnosis (WHO 2021)",
    "MGMT status", "MGMT index", "1p/19q", "IDH", "1-dead 0-alive", "OS", "EOR",
    "Biopsy prior to imaging", "BraTS21 ID", "BraTS21 Segmentation Cohort", "BraTS21 MGMT Cohort"
]

outdir = "/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/umap_plots/PDGM_dino"
os.makedirs(outdir, exist_ok=True)

for column in columns:
    try:
        values = meta.loc[case_ids_fixed, column].values
    except KeyError:
        print(f"Column {column} not found in metadata.")
        continue

    # Convert all values to string, replace nan with "NA"
    values_str = np.array([str(v) if pd.notna(v) else "NA" for v in values])

    plt.figure(figsize=(7, 6))
    for val in np.unique(values_str):
        idx = values_str == val
        plt.scatter(X_umap[idx, 0], X_umap[idx, 1], label=str(val), alpha=0.7, s=20)
    plt.legend(markerscale=2, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title(f"UMAP colored by {column}")
    plt.xlabel("UMAP-1")
    plt.ylabel("UMAP-2")
    plt.tight_layout()
    fname = f"umap_{column.replace(' ', '_').replace('/', '_')}.png"
    plt.savefig(os.path.join(outdir, fname))
    plt.close()
    print(f"Saved UMAP for {column}")

print("All UMAP plots saved")

  warn(


Column ID not found in metadata.
Saved UMAP for Sex


  plt.tight_layout()


Saved UMAP for Age at MRI
Saved UMAP for WHO CNS Grade
Saved UMAP for Final pathologic diagnosis (WHO 2021)
Saved UMAP for MGMT status
Saved UMAP for MGMT index
Saved UMAP for 1p/19q
Saved UMAP for IDH
Saved UMAP for 1-dead 0-alive


  plt.tight_layout()


Saved UMAP for OS
Saved UMAP for EOR
Saved UMAP for Biopsy prior to imaging


  plt.tight_layout()


Saved UMAP for BraTS21 ID
Saved UMAP for BraTS21 Segmentation Cohort
Saved UMAP for BraTS21 MGMT Cohort
All UMAP plots saved
