In [1]:
## MODEL CODE + EXTRACTION

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.layers import trunc_normal_
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option

SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}

__all__ = ["MaskedAutoEncoderViT"]


class MaskedAutoEncoderViT(nn.Module):
    """
    Masked Autoencoder (ViT), based on: "Kaiming et al.,
    Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
    Only a subset of the patches passes through the encoder. The decoder tries to reconstruct
    the masked patches, resulting in improved training speed.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Sequence[int] | int,
        patch_size: Sequence[int] | int,
        hidden_size: int = 768,
        mlp_dim: int = 512,
        num_layers: int = 12,
        num_heads: int = 12,
        masking_ratio: float = 0.75,
        decoder_hidden_size: int = 384,
        decoder_mlp_dim: int = 512,
        decoder_num_layers: int = 4,
        decoder_num_heads: int = 12,
        proj_type: str = "conv",
        pos_embed_type: str = "sincos",
        decoder_pos_embed_type: str = "sincos",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels or the number of channels for input.
            img_size: dimension of input image.
            patch_size: dimension of patch size
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 512.
            num_layers:  number of transformer blocks. Defaults to 12.
            num_heads: number of attention heads. Defaults to 12.
            masking_ratio: ratio of patches to be masked. Defaults to 0.75.
            decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
            decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
            decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
            decoder_num_heads: number of attention heads for decoder. Defaults to 12.
            proj_type: position embedding layer type. Defaults to "conv".
            pos_embed_type: position embedding layer type. Defaults to "sincos".
            decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos".
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dimensions. Defaults to 3.
            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn: to make accessible the attention in self attention block. Defaults to False.
        Examples::
            # for single channel input with image size of (96,96,96), and sin-cos positional encoding
            >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
            pos_embed_type='sincos')
            # for 3-channel with image size of (128,128,128) and a learnable positional encoding
            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
            # for 3-channel with image size of (224,224) and a masking ratio of 0.25
            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
            spatial_dims=2)
        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        if decoder_hidden_size % decoder_num_heads != 0:
            raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")

        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        self.img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.spatial_dims = spatial_dims
        for m, p in zip(self.img_size, self.patch_size):
            if m % p != 0:
                raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.")

        self.decoder_hidden_size = decoder_hidden_size

        if masking_ratio <= 0 or masking_ratio >= 1:
            raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")

        self.masking_ratio = masking_ratio
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            proj_type=proj_type,
            pos_embed_type=pos_embed_type,
            dropout_rate=dropout_rate,
            spatial_dims=self.spatial_dims,
        )
        blocks = [
            TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
            for _ in range(num_layers)
        ]
        self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))

        # decoder
        self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)

        self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

        self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
        self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))

        decoder_blocks = [
            TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
            for _ in range(decoder_num_layers)
        ]
        self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
        self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)

        self._init_weights()

    def _init_weights(self):
        """
        similar to monai/networks/blocks/patchembedding.py for the decoder positional encoding and for mask and
        classification tokens
        """
        if self.decoder_pos_embed_type == "none":
            pass
        elif self.decoder_pos_embed_type == "learnable":
            trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
        elif self.decoder_pos_embed_type == "sincos":
            grid_size = []
            for in_size, pa_size in zip(self.img_size, self.patch_size):
                grid_size.append(in_size // pa_size)

            self.decoder_pos_embedding = build_sincos_position_embedding(
                grid_size, self.decoder_hidden_size, self.spatial_dims
            )

        else:
            raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.")

        # initialize patch_embedding like nn.Linear (instead of nn.Conv2d)
        trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)
        trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)

    def _masking(self, x, masking_ratio: float | None = None):
        batch_size, num_tokens, _ = x.shape
        percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio
        selected_indices = torch.multinomial(
            torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False
        )
        x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices]  # gather the selected tokens
        mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)
        mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0

        return x_masked, selected_indices, mask

    def forward(self, x, masking_ratio: float | None = None):
        x = self.patch_embedding(x)
        x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = self.blocks(x)

        # decoder
        x = self.decoder_embed(x)

        x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)
        x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :]  # no cls token
        x_ = x_ + self.decoder_pos_embedding
        x = torch.cat([x[:, :1, :], x_], dim=1)
        x = self.decoder_blocks(x)
        x = self.decoder_pred(x)

        x = x[:, 1:, :]
        return x, mask


######### NEW #########
    def get_encoder_features(self, x):
        """
        Returns the encoder output (after transformer blocks, before decoder).
        """
        x = self.patch_embedding(x)
        x, selected_indices, mask = self._masking(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.blocks(x)
        return x  # shape: (batch, num_patches+1, hidden_size)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Imports

import os
import glob
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from monai.transforms import (
    LoadImage, EnsureChannelFirst, NormalizeIntensity, Compose
)
import nibabel as nib
import matplotlib.pyplot as plt
import umap
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, roc_auc_score
from imblearn.over_sampling import SMOTE
from sklearn.linear_model import LogisticRegression

In [6]:
# ---- SETTINGS ----
PDGM_DIR = "/oak/stanford/groups/ogevaert/data/brain_mri_tumor_project/UCSF-PDGM-v3"
METADATA_CSV = "/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/metadata/PGDM/UCSF-PDGM-metadata_v2.csv"
CHECKPOINT_PATH = "/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/runs/mae_pdgm_bbox_lr1e4/checkpoint_epoch4000.pt"
CROP_SIZE = (80, 96, 80)
MARGIN = [10, 10, 10]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_PATH = "/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/features/pdgm_mae_bbox_features_lr1e4_4000.npz"

# ---- Load metadata ----
meta = pd.read_csv(METADATA_CSV).set_index("ID")

# ---- Model ----
model = MaskedAutoEncoderViT(
    in_channels=1, img_size=CROP_SIZE, patch_size=(16,16,16),
    hidden_size=1152, mlp_dim=4608, num_layers=12, num_heads=16,
    masking_ratio=0.30, decoder_hidden_size=1152,
    decoder_mlp_dim=4608, decoder_num_layers=6, decoder_num_heads=16,
    spatial_dims=3
)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE)["model"])
model.to(DEVICE)
model.eval()

# ---- Helper: Crop to tumor bbox + margin ----
def crop_to_bbox_with_margin(img, mask, margin):
    coords = np.array(np.where(mask > 0))
    if coords.shape[1] == 0:
        # No tumor, return center crop
        center = [s // 2 for s in img.shape]
        half = [min(s // 2, c // 2) for s, c in zip(img.shape, CROP_SIZE)]
        slices = tuple(slice(c - h, c + h) for c, h in zip(center, half))
        return img[slices]
    zmin, ymin, xmin = coords.min(axis=1)
    zmax, ymax, xmax = coords.max(axis=1)
    zmin = max(0, zmin - margin[0])
    zmax = min(img.shape[0] - 1, zmax + margin[0])
    ymin = max(0, ymin - margin[1])
    ymax = min(img.shape[1] - 1, ymax + margin[1])
    xmin = max(0, xmin - margin[2])
    xmax = min(img.shape[2] - 1, xmax + margin[2])
    return img[zmin:zmax+1, ymin:ymax+1, xmin:xmax+1]

# ---- Transforms ----
transforms = Compose([
    NormalizeIntensity(nonzero=True, channel_wise=True),
])

# ---- Feature Extraction ----
features, case_ids = [], []
case_dirs = sorted(glob.glob(os.path.join(PDGM_DIR, "UCSF-PDGM-*_*")))

for case_dir in tqdm(case_dirs):
    case_folder = os.path.basename(case_dir)
    num4 = case_folder.split("-")[-1].split("_")[0]
    num3 = f"{int(num4):03d}"
    csv_case_id = f"UCSF-PDGM-{num3}"
    file_case_id = f"UCSF-PDGM-{num4}"

    t1c_path = os.path.join(case_dir, f"{file_case_id}_T1c_bias.nii.gz")
    mask_path = os.path.join(case_dir, f"{file_case_id}_tumor_segmentation.nii.gz")
    if not (os.path.exists(t1c_path) and os.path.exists(mask_path)):
        continue
    if csv_case_id not in meta.index:
        continue

    img = nib.load(t1c_path).get_fdata()
    mask = nib.load(mask_path).get_fdata()
    # Crop to tumor bbox + margin
    cropped = crop_to_bbox_with_margin(img, mask, MARGIN)
    # Pad/crop to CROP_SIZE
    pad_width = [(0, max(0, CROP_SIZE[i] - cropped.shape[i])) for i in range(3)]
    cropped = np.pad(cropped, pad_width, mode='constant')
    cropped = cropped[:CROP_SIZE[0], :CROP_SIZE[1], :CROP_SIZE[2]]
    # Add channel dimension
    cropped = np.expand_dims(cropped, axis=0)
    # Apply transforms
    cropped = transforms(cropped)
    img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        enc = model.get_encoder_features(img_tensor)
        cls_token = enc[:, 0, :]
        features.append(cls_token.cpu().numpy().squeeze())
        case_ids.append(csv_case_id)

features = np.stack(features)
np.savez(SAVE_PATH, features=features, case_ids=np.array(case_ids))
print("Saved features:", SAVE_PATH)

  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(DEVICE)
  img_tensor = t

Saved features: /oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/features/pdgm_mae_bbox_features_lr1e4_4000.npz





In [7]:
# ---- UMAP Visualization ----
data = np.load(SAVE_PATH, allow_pickle=True)
X = data["features"]
case_ids = data["case_ids"]
meta = pd.read_csv(METADATA_CSV).set_index("ID")

X_scaled = StandardScaler().fit_transform(X)
reducer = umap.UMAP(random_state=42)
X_umap = reducer.fit_transform(X_scaled)

columns = [
    "ID", "Sex", "Age at MRI", "WHO CNS Grade", "Final pathologic diagnosis (WHO 2021)",
    "MGMT status", "MGMT index", "1p/19q", "IDH", "1-dead 0-alive", "OS", "EOR",
    "Biopsy prior to imaging", "BraTS21 ID", "BraTS21 Segmentation Cohort", "BraTS21 MGMT Cohort"
]

os.makedirs("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/umap_plots/PDGM_mae_bbox", exist_ok=True)

for column in columns:
    try:
        values = meta.loc[case_ids, column].values
    except KeyError:
        print(f"Column {column} not found in metadata.")
        continue

    values_str = np.array([str(v) if pd.notna(v) else "NA" for v in values])

    plt.figure(figsize=(7, 6))
    for val in np.unique(values_str):
        idx = values_str == val
        plt.scatter(X_umap[idx, 0], X_umap[idx, 1], label=str(val), alpha=0.7, s=20)
    plt.legend(markerscale=2, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title(f"UMAP colored by {column}")
    plt.xlabel("UMAP-1")
    plt.ylabel("UMAP-2")
    plt.tight_layout()
    plt.savefig(f"/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/umap_plots/PDGM_mae_bbox/umap_{column.replace(' ', '_').replace('/', '_')}.png")
    plt.close()
    print(f"Saved UMAP for {column}")

print("All UMAP plots saved")

  warn(


Column ID not found in metadata.
Saved UMAP for Sex


  plt.tight_layout()


Saved UMAP for Age at MRI
Saved UMAP for WHO CNS Grade
Saved UMAP for Final pathologic diagnosis (WHO 2021)
Saved UMAP for MGMT status
Saved UMAP for MGMT index
Saved UMAP for 1p/19q
Saved UMAP for IDH
Saved UMAP for 1-dead 0-alive


  plt.tight_layout()


Saved UMAP for OS
Saved UMAP for EOR
Saved UMAP for Biopsy prior to imaging


  plt.tight_layout()


Saved UMAP for BraTS21 ID
Saved UMAP for BraTS21 Segmentation Cohort
Saved UMAP for BraTS21 MGMT Cohort
All UMAP plots saved


In [8]:
# ---- t-SNE Visualization ----
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
X_tsne = tsne.fit_transform(X_scaled)

os.makedirs("/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/tsne_plots/PDGM_mae_bbox", exist_ok=True)

for column in columns:
    try:
        values = meta.loc[case_ids, column].values
    except KeyError:
        print(f"Column {column} not found in metadata.")
        continue

    values_str = np.array([str(v) if pd.notna(v) else "NA" for v in values])

    plt.figure(figsize=(7, 6))
    for val in np.unique(values_str):
        idx = values_str == val
        plt.scatter(X_tsne[idx, 0], X_tsne[idx, 1], label=str(val), alpha=0.7, s=20)
    plt.legend(markerscale=2, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title(f"t-SNE colored by {column}")
    plt.xlabel("t-SNE-1")
    plt.ylabel("t-SNE-2")
    plt.tight_layout()
    plt.savefig(f"/oak/stanford/groups/ogevaert/maxvpuyv/projects/brain/data/tsne_plots/PDGM_mae_bbox/tsne_{column.replace(' ', '_').replace('/', '_')}.png")
    plt.close()
    print(f"Saved t-SNE for {column}")

print("All t-SNE plots saved")

Column ID not found in metadata.
Saved t-SNE for Sex


  plt.tight_layout()


Saved t-SNE for Age at MRI
Saved t-SNE for WHO CNS Grade
Saved t-SNE for Final pathologic diagnosis (WHO 2021)
Saved t-SNE for MGMT status
Saved t-SNE for MGMT index
Saved t-SNE for 1p/19q
Saved t-SNE for IDH
Saved t-SNE for 1-dead 0-alive


  plt.tight_layout()


Saved t-SNE for OS
Saved t-SNE for EOR
Saved t-SNE for Biopsy prior to imaging


  plt.tight_layout()


Saved t-SNE for BraTS21 ID
Saved t-SNE for BraTS21 Segmentation Cohort
Saved t-SNE for BraTS21 MGMT Cohort
All t-SNE plots saved


In [9]:
# ---- Random Forest Classifier (with SMOTE) ----
idh_status = meta.loc[case_ids, "IDH"].values
labels = np.array([0 if str(v).strip().lower() == "wildtype" else 1 for v in idh_status])

X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.2, random_state=42, stratify=labels
)

sm = SMOTE(random_state=42)
X_train_res, y_train_res = sm.fit_resample(X_train, y_train)

clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train_res, y_train_res)

y_pred = clf.predict(X_test)
y_prob = clf.predict_proba(X_test)[:, 1]

print("Random Forest Classification report:")
print(classification_report(y_test, y_pred))
print("ROC AUC:", roc_auc_score(y_test, y_prob))



Random Forest Classification report:
              precision    recall  f1-score   support

           0       0.80      0.86      0.83        78
           1       0.27      0.19      0.22        21

    accuracy                           0.72        99
   macro avg       0.53      0.52      0.52        99
weighted avg       0.68      0.72      0.70        99

ROC AUC: 0.5961538461538463


In [10]:
# ---- L2 Logistic Regression ----
clf = LogisticRegression(
    penalty='l2',
    max_iter=1000,
    class_weight='balanced'
)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
y_prob = clf.predict_proba(X_test)[:, 1]

print("Logistic Regression Classification report:")
print(classification_report(y_test, y_pred))
print("ROC AUC:", roc_auc_score(y_test, y_prob))

Logistic Regression Classification report:
              precision    recall  f1-score   support

           0       0.84      0.73      0.78        78
           1       0.32      0.48      0.38        21

    accuracy                           0.68        99
   macro avg       0.58      0.60      0.58        99
weighted avg       0.73      0.68      0.70        99

ROC AUC: 0.681929181929182
