In [1]:
# Start from a clean slate: move to /content and delete any previous clone of the lab repo,
# so the clone in the next cell always gets a fresh copy.
%cd /content
!rm -rf rsna-deep-learning-lab-2025
# List what is left in /content (confirms the old clone is gone)
!ls

/content
sample_data


In [2]:
# ----------------------------
# Clone your repo
# ----------------------------
!git clone https://github.com/quantivly/rsna-deep-learning-lab-2025.git
%cd rsna-deep-learning-lab-2025

# ----------------------------
# Downgrade numpy/scipy to avoid import errors
# ----------------------------
!pip install --force-reinstall numpy==1.26.4 scipy==1.10.1

# ----------------------------
# Install PyTorch 2.3.0 + CUDA 12.1
# ----------------------------
!pip install torch==2.3.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# ----------------------------
# Install PyG 2.6.1 + extensions (Python 3.12 compatible)
# ----------------------------
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
  -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

!pip install torch_geometric==2.6.1


Cloning into 'rsna-deep-learning-lab-2025'...
remote: Enumerating objects: 70, done.[K
remote: Counting objects: 100% (70/70), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 70 (delta 36), reused 49 (delta 19), pack-reused 0 (from 0)[K
Receiving objects: 100% (70/70), 109.23 KiB | 4.96 MiB/s, done.
Resolving deltas: 100% (36/36), done.
/content/rsna-deep-learning-lab-2025
Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: Ignored the following yanked versions: 1.11.0, 1.14.0rc1[0m[31m
[0m[31mERROR: Ignored the following versions that require a different python version: 1.10.0 Requires-Python <3.12,>=3.8; 1.10.0rc1 Requires-Python <3.12,>=3.8; 1.10.0rc2 Requires-Python <3.12,>=3.8; 1.10.1 Requires-Python <3.12,>=3.8; 1.21.2 Requires-Python >=

### Data Preprocessing and the DataCollection Class

**Why StandardScaler, OneHotEncoder, and ColumnTransformer?**

- **StandardScaler:** Normalizes numerical features to have mean 0 and variance 1
  - Helps the model learn efficiently, prevents domination by large-magnitude features
- **OneHotEncoder:** Converts categorical features into a binary vector representation
  - Allows the model to process non-numeric features without imposing arbitrary order
- **ColumnTransformer:** Combines multiple preprocessing steps for different feature types
  - Ensures numeric and categorical columns are processed appropriately in one pipeline

**Goal of `DataCollection` class:**

- Organizes **multi-modal patient data** from multiple sources:
  - Clinical metadata
  - Radiomic features
  - Gene assays
- Maps all patients, radiomic nodes, and gene nodes to **unique indices** for graph construction
- Builds **edges between patients and their corresponding radiomic/gene nodes**

**What will the raw features be used for?**

- Serve as **initial node features** for the heterogeneous graph
  - `patient` nodes: metadata features
  - `radiomic` nodes: extracted radiomic features
  - `gene` nodes: assay measurements
- Allow the GNN to **aggregate multi-modal information** to produce meaningful patient embeddings

**Why this logic is important:**

- Ensures consistent **node indexing** and **feature scaling**
- Sets up the graph so that downstream GNN operations (e.g., `HeteroConv`) can **correctly combine information across node types**
- Makes the graph **ready for contrastive learning and embedding generation**

## Exploring the Code

`self.metadata`, `self.radiomic`, and `self.gene_assay` hold the loaded data. The `@property` methods shown below link each node in the graph to a specific patient ID, so that the node can be traced back to its patient later.

```python
class DataCollection:
    """Class for loading and processing clinical metadata, radiomic features, and gene assay data.

    Reads three CSV files into pandas DataFrames and derives one sorted list of patient IDs
    (from the radiomic and gene-assay tables) that is used to index patient nodes in the graph.

    Args:
        metadata_path: CSV with clinical metadata (stored as self.metadata).
        radiomic_path: CSV with radiomic features; must contain a 'Lesion Name' column.
        gene_assay_path: CSV with gene-assay data; must contain a 'CLID' patient-ID column.
    """
    def __init__(self, metadata_path: str, radiomic_path: str, gene_assay_path: str):
        # Read by the feature getters: when True, the target column is removed from the features
        self.supervised = True
        self.metadata = pd.read_csv(metadata_path)
        self.radiomic = pd.read_csv(radiomic_path)
        self.gene_assay = pd.read_csv(gene_assay_path)
        # set_radiomic_ids is a @property evaluated purely for its side effect: it adds the
        # 'patient_id' column to self.radiomic, which unique_patient_ids (next line) reads.
        self.set_radiomic_ids # set patient_id in radiomic dataframe so that it can be matched
        self.unique_ids = self.unique_patient_ids
    
    @property
    def set_radiomic_ids(self):
        # Side-effecting property (returns None): derive a patient ID per lesion by taking the text
        # before the first '.', then keeping its first three '-'-separated fields. The ID format
        # (e.g. "TCGA-XX-XXXX") is assumed from the naming scheme — verify against the data.
        ids = ['-'.join(n.split('.')[0].split('-')[:3]) for n in self.radiomic['Lesion Name'].to_list()]
        self.radiomic['patient_id'] = ids

    @property
    def unique_patient_ids(self):
        # Sorted union of patient IDs from the radiomic table and the gene-assay 'CLID' column.
        # Recomputed on every access. Patients that appear only in self.metadata are not included.
        all_patient_ids = sorted(set(self.radiomic['patient_id'].to_list())
                                 | set(self.gene_assay['CLID'].to_list()))
        return all_patient_ids
    
    @property
    def patient_node_mapping(self):
        # patient ID -> patient node index (its position in unique_patient_ids).
        # Rebuilt (including the sort) on every access.
        patient_ids = self.unique_patient_ids
        mapping = {pid: idx for idx, pid in enumerate(patient_ids)}
        return mapping
    
    @property
    def radiomic_node_mapping(self):
        # row position -> DataFrame index label of each radiomic row. With read_csv's default
        # RangeIndex these coincide (0..n-1), so each radiomic row is its own node.
        return {i: idx for i, idx in enumerate(self.radiomic.index)}
    
    @property
    def gene_assay_node_mapping(self):
        # row position -> DataFrame index label of each gene-assay row (identity with the default
        # RangeIndex), so each gene-assay row is its own node.
        return {i: idx for i, idx in enumerate(self.gene_assay.index)}
```

The rest of the `DataCollection` class **builds edges** and **converts raw data into GNN-ready features**:

```python
@property
    def build_radiomic_to_patient_edges(self):
        """Edge list for radiomic -> patient edges, returned as (src_nodes, dst_nodes).

        One edge per radiomic row: the source is that row's radiomic node and the destination
        is the patient node of the row's 'patient_id'. patient_node_mapping is built from these
        same IDs (plus the gene-assay IDs), so every lookup has a match.
        """
        src_nodes = list(self.radiomic_node_mapping.values())
        # NOTE(review): patient_node_mapping is a property that rebuilds the sorted ID mapping on
        # every access, i.e. once per row in this comprehension; binding it to a local first
        # would avoid the repeated work.
        dst_nodes = [self.patient_node_mapping[pid] for pid in self.radiomic['patient_id'].to_list()]
        return (src_nodes, dst_nodes)
    
    @property
    def build_gene_assay_to_patient_edges(self):
        """Edge list for gene -> patient edges, returned as (src_nodes, dst_nodes).

        One edge per gene-assay row: the source is that row's gene node and the destination is
        the patient node of the row's 'CLID'.
        """
        src_nodes = list(self.gene_assay_node_mapping.values())
        # NOTE(review): as above, patient_node_mapping is rebuilt for every row here.
        dst_nodes = [self.patient_node_mapping[pid] for pid in self.gene_assay['CLID'].to_list()]
        return (src_nodes, dst_nodes)
    
    def get_radiomic_features(self, columns_to_drop: list=['Lesion Name', 'patient_id']):
        """Z-score every radiomic column except the identifier columns.

        Args:
            columns_to_drop: columns to exclude (missing ones are ignored). The default list is
                only read, never mutated.

        Returns:
            NumPy array with one standardized row per radiomic row. Assumes all remaining columns
            are numeric — StandardScaler would fail on text columns (verify against the CSV).
        """
        features = self.radiomic.drop(columns=columns_to_drop, errors='ignore')
        # Fitted on all rows at once; the fitted scaler is not kept.
        scaler = StandardScaler()
        return scaler.fit_transform(features.values)
    
    def get_gene_assay_features(self, columns_to_drop: list=['CLID', 'Unnamed: 16', 'Unnamed: 17', 'Unnamed: 18']):
        """Preprocess the gene-assay table into a dense feature matrix (one row per assay row).

        Numeric columns are z-scored with StandardScaler; the remaining columns are one-hot
        encoded. Columns listed in `columns_to_drop` are skipped (the default list is only read).

        Raises:
            ValueError: if self.supervised is True and 'Pam50.Call' is not a feature column.
        """
        feature_cols = [c for c in self.gene_assay.columns if c not in columns_to_drop]
        if self.supervised:
            feature_cols.remove('Pam50.Call')  # Exclude the target column from the features when supervised (avoids label leakage)
        numeric_cols = self.gene_assay[feature_cols].select_dtypes(include=[np.number]).columns.tolist()
        # NOTE(review): a set difference has arbitrary order, which can change between Python
        # processes (string hash randomisation), so the one-hot column order is not guaranteed to be
        # reproducible; sorting categorical_cols would make it deterministic.
        categorical_cols = list(set(feature_cols) - set(numeric_cols))
        preprocessor = ColumnTransformer(
            transformers=[
                ('num', StandardScaler(), numeric_cols),
                ('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), categorical_cols)
            ]
        )
        # Fitted on all rows at once; the fitted preprocessor is not kept.
        return preprocessor.fit_transform(self.gene_assay[feature_cols])
    
    def get_patient_metadata_features(self, columns_to_drop: list=['bcr_patient_barcode', 'patient_id']):
        """Preprocess the clinical metadata table into a dense feature matrix (one row per patient row).

        Numeric columns are z-scored with StandardScaler; the remaining columns are one-hot
        encoded. Columns listed in `columns_to_drop` are skipped (the default list is only read).

        Raises:
            ValueError: if self.supervised is True and 'ajcc_neoplasm_disease_stage' is not a
                feature column.
        """
        feature_cols = [c for c in self.metadata.columns if c not in columns_to_drop]
        if self.supervised:
            feature_cols.remove('ajcc_neoplasm_disease_stage')  # Exclude the target column from the features when supervised (avoids label leakage)
        numeric_cols = self.metadata[feature_cols].select_dtypes(include=[np.number]).columns.tolist()
        # NOTE(review): set-difference order is arbitrary (see get_gene_assay_features); sorting
        # would make the one-hot column order reproducible.
        categorical_cols = list(set(feature_cols) - set(numeric_cols))
        preprocessor = ColumnTransformer(
            transformers=[
                ('num', StandardScaler(), numeric_cols),
                ('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), categorical_cols)
            ]
        )
        return preprocessor.fit_transform(self.metadata[feature_cols])
    
    def get_target(self, target_column: str='ajcc_neoplasm_disease_stage'):
        """Integer-encode a metadata target column.

        Each distinct non-null value gets its position in sorted (lexicographic) order;
        missing values become -1. Returns a NumPy int array aligned with self.metadata rows.
        """
        target_series = self.metadata[target_column]
        target_mapping = {stage: idx for idx, stage in enumerate(sorted(target_series.dropna().unique()))}
        targets = target_series.map(target_mapping).fillna(-1).astype(int).values
        return targets
    
    def get_gene_target(self, target_column: str='Pam50.Call'):
        """Integer-encode a gene-assay target column (default: the PAM50 call).

        Each distinct non-null value gets its position in sorted order; missing values become -1.
        Returns a NumPy int array aligned with self.gene_assay rows.
        """
        target_series = self.gene_assay[target_column]
        target_mapping = {stage: idx for idx, stage in enumerate(sorted(target_series.dropna().unique()))}
        targets = target_series.map(target_mapping).fillna(-1).astype(int).values
        return targets
```

### Patient Representation GNN: Heterogeneous Graph Neural Network

**Purpose of this model:**

- Learn **low-dimensional patient embeddings** by aggregating information from:
  - `radiomic` nodes (image-derived features)
  - `gene` nodes (multi-gene assay features)
  - Other `patient` nodes via similarity edges
- Note that, because of the graph topology we've defined, the resulting embeddings focus only on **merging** multi-modal information.

**Key Components:**

1. **HeteroConv**
   - Handles **multiple edge types** in a heterogeneous graph
   - Aggregates messages **from different node types separately**, then sums them
   - In our model:
     - `radiomic -> patient`
     - `gene -> patient`
     - `patient -> patient` (similarity edges)

2. **GATConv (Graph Attention Convolution)**
   - Assigns **learned attention weights** to neighbors
   - Helps model **focus on more relevant neighbors** when aggregating features

3. **Linear Layer**
   - Reduces aggregated hidden features to **final embedding size**
   - Produces **patient embedding vector** that captures multi-modal information

**Forward pass overview:**

- Each node type sends messages along its edges
- Messages are **weighted and aggregated** for each destination node
- Aggregated features pass through **ReLU activation**
- Patient node features are projected through a **linear layer** to produce embeddings

**Why this design matters:**

- Allows **integration of heterogeneous data** while preserving node-type structure
- Produces embeddings that reflect **multi-modal patient information**
- Can be used for **downstream tasks** like similarity comparison, clustering, or predictive modeling

```python
class PatientRepresentationGNN(nn.Module):
    """One-layer heterogeneous GAT that produces patient embeddings.

    Messages flow radiomic -> patient, gene -> patient and patient -> patient (similarity
    edges). HeteroConv sums the per-edge-type outputs, a ReLU is applied, and a linear layer
    projects the patient representations to `out_dim`.

    Args:
        data: heterogeneous graph (presumably a PyG HeteroData) exposing data['patient'].x,
            data['radiomic'].x and data['gene'].x; only their feature widths are read here.
        hidden_dim: output width of each GATConv.
        out_dim: width of the final patient embedding.
    """
    def __init__(self, data, hidden_dim=64, out_dim=32):
        super().__init__()

        # --- Infer feature dims from graph ---
        # size(1) = number of input features per node type (feature matrices assumed 2-D)
        patient_in = data['patient'].x.size(1)
        radiomic_in = data['radiomic'].x.size(1)
        gene_in = data['gene'].x.size(1)

        # --- Build HeteroConv with correct per-edge dims ---
        # (src_dim, dst_dim) tuples let each GATConv accept different input widths on the source
        # and destination side. Self-loops are disabled on every edge type, including
        # patient -> patient.
        self.conv1 = HeteroConv({
            ('radiomic', 'to', 'patient'):
                GATConv((radiomic_in, patient_in), hidden_dim, add_self_loops=False),

            ('gene', 'to', 'patient'):
                GATConv((gene_in, patient_in), hidden_dim, add_self_loops=False),

            ('patient', 'similar', 'patient'):
                GATConv((patient_in, patient_in), hidden_dim, add_self_loops=False),
        }, aggr='sum')

        # Final projection: hidden_dim -> out_dim patient embedding
        self.lin = nn.Linear(hidden_dim, out_dim)

    def forward(self, x_dict, edge_index_dict):
        # x_dict: node type -> feature matrix; edge_index_dict: (src, relation, dst) -> edge_index
        x_dict = self.conv1(x_dict, edge_index_dict)
        # HeteroConv keys its output by destination node type; every edge type here targets
        # 'patient', so presumably only 'patient' remains in x_dict at this point — verify.
        x_dict = {k: torch.relu(v) for k, v in x_dict.items()}
        x_dict['patient'] = self.lin(x_dict['patient'])
        return x_dict
```

In [3]:
######################## CODING BLOCK 1 ########################

from rsna_deep_learning_lab_2025.train import load_data, Trainer, load_model_and_optimizer

#Use this cell if you need to navigate to content folder to initilize the variables
%cd rsna_deep_learning_lab_2025
dataset = load_data()
assert(dataset['data_collection'].supervised == True)
model, optimizer = load_model_and_optimizer(dataset['data'])

#Train model using pre-defined functions
trainer = Trainer(model, optimizer, dataset)

losses = trainer.train(500, verbose=False)



/content/rsna-deep-learning-lab-2025/rsna_deep_learning_lab_2025




Training complete.


In [4]:
######################## BONUS BLOCK (NOT NEEDED)  ########################
import pandas as pd
import plotly.express as px

# Compare learning curves of wider models against the baseline from Coding Block 1.
# The extra models get their own variable names so that `model`, `optimizer`, `trainer`
# and `losses` from Coding Block 1 are NOT overwritten: later cells call
# trainer.get_patient_embeddings and must keep analysing the baseline model even after
# this optional cell has been run.
loss_curves = [
    pd.DataFrame({
        'epoch': range(len(losses)),  # one entry per element of `losses` (presumably per epoch)
        'losses': losses,
        'model_name': ['model'] * len(losses),
    })
]
for hidden_dim, out_dim, name in [(128, 64, 'model_1'), (256, 128, 'model_2')]:
    bonus_model, bonus_optimizer = load_model_and_optimizer(dataset['data'], hidden_dim, out_dim)

    # Train model using pre-defined functions
    bonus_trainer = Trainer(bonus_model, bonus_optimizer, dataset)
    bonus_losses = bonus_trainer.train(500, verbose=False)
    loss_curves.append(pd.DataFrame({
        'epoch': range(len(bonus_losses)),
        'losses': bonus_losses,
        'model_name': [name] * len(bonus_losses),
    }))

# Build the long-form frame once instead of re-concatenating inside the loop
df = pd.concat(loss_curves, ignore_index=True)

# Here we can plot to see how the model's hidden channel depth will impact the learning curve.
# The explicit epoch column makes every model's curve start at 0 so the curves overlay.
px.line(df, x='epoch', y='losses', color='model_name')

Training complete.
Training complete.


## Building Similarity Matrices



The goal here is to fetch the newly learned patient embeddings, the raw genetic feature vectors, and the raw patient metadata vectors. We then compare all three similarity matrices to investigate whether we have successfully incorporated some of the genetic signal *into* the learned embeddings.

In [5]:
######################## CODING BLOCK 2 ########################

from rsna_deep_learning_lab_2025.utils import DataViewer

# --------------------------------------------------------
# DataViewer: helpers for patient tables, genetic features,
# graph-connected metadata and cosine-similarity matrices.
#   dataset['data']            -> heterogeneous graph tensors
#   dataset['data_collection'] -> DataCollection (labels, raw tables, ...)
# --------------------------------------------------------
viewer = DataViewer(dataset['data'], dataset['data_collection'])

# Learned embeddings from the trained GNN: one D-dimensional vector per patient
patient_embeddings = trainer.get_patient_embeddings

# Genetic information for the patients that actually have gene nodes in the graph:
#   'patient_idx_with_genes' -> indices of those patients
#   'G'                      -> their raw genetic feature matrix
gene_info = viewer.fetch_connected_gene_features
patients_with_genes = gene_info['patient_idx_with_genes']

# Restrict raw (non-graph-aggregated) patient features and embeddings to exactly the
# same patient subset, in the same order, as the genetic matrix.
raw_patient_features = viewer.fetch_all_patient_features[patients_with_genes]
patient_embeddings_filtered = patient_embeddings[patients_with_genes]

# --------------------------------------------------------
# NxN cosine-similarity matrices over that shared subset:
#   sim_P -> raw patient metadata
#   sim_G -> genetic profiles
#   sim_E -> learned GNN embeddings
# Comparing them shows whether the embeddings preserve, distort or
# improve biologically meaningful relationships.
# --------------------------------------------------------
sim_P, sim_G, sim_E = (
    viewer.fetch_cosine_similarity(features)
    for features in (raw_patient_features, gene_info['G'], patient_embeddings_filtered)
)

In [6]:
######################## CODING BLOCK 3 ########################
import pandas as pd
import numpy as np
import torch
import plotly.express as px
from scipy.stats import pearsonr, spearmanr

# ------------------------------------------------------------
# Helper function:
# Flatten the upper triangle of two pairwise similarity matrices
# ------------------------------------------------------------
# simA and simB are NxN matrices over the same patients in the same order.
# Only patient-to-patient comparisons are wanted, and each unordered pair
# should be counted once, so we keep entries with row < column:
#   - the diagonal (self-similarity = 1.0) is dropped
#   - the mirrored lower triangle (identical for symmetric cosine
#     similarities) is dropped
#
# Output:
#   two aligned 1-D vectors that can be directly compared or plotted
def flatten_pairwise(simA, simB):
    """Return the strictly upper-triangular entries of simA and simB as aligned 1-D arrays.

    Cosine-similarity matrices are symmetric, so every unordered patient pair appears twice
    off the diagonal. Keeping only i < j counts each pair once: the scatter plots are not
    double-plotted, and downstream correlation tests do not treat one pair as two
    observations (which would halve their p-values' honest sample size).

    Args:
        simA: NxN pairwise matrix.
        simB: NxN pairwise matrix over the same patients in the same order.

    Returns:
        (a, b): simA[i, j] and simB[i, j] for all i < j, in row-major order.

    Raises:
        ValueError: if the matrices are not square or do not share a shape.
    """
    if len(simA.shape) != 2 or simA.shape[0] != simA.shape[1]:
        raise ValueError(f"simA must be a square matrix, got shape {tuple(simA.shape)}")
    if tuple(simA.shape) != tuple(simB.shape):
        raise ValueError(
            f"simA and simB must have the same shape, got {tuple(simA.shape)} and {tuple(simB.shape)}"
        )

    # Row/column indices of the strictly upper triangle (k=1 excludes the diagonal)
    rows, cols = np.triu_indices(simA.shape[0], k=1)
    return simA[rows, cols], simB[rows, cols]


# ------------------------------------------------------------
# Pair up similarities: gene vs raw metadata, gene vs embedding
# ------------------------------------------------------------
#   sim_G : similarity from gene expression
#   sim_P : similarity from raw patient metadata
#   sim_E : similarity from learned GNN embeddings
# Both comparisons share the same gene-similarity axis, which lets us ask
# whether the embeddings track genetic relationships more (or less) closely
# than the raw metadata already does.
x_gene, y_raw = flatten_pairwise(sim_G, sim_P)
_, y_emb = flatten_pairwise(sim_G, sim_E)   # same gene values, paired with embedding sims


# ------------------------------------------------------------
# Long-form ("tidy") DataFrame for Plotly
# ------------------------------------------------------------
#   gene_sim    : genetic similarity of the pair (listed once per comparison)
#   patient_sim : raw-metadata or embedding similarity of the same pair
#   type        : "Raw" or "Embedding"
n_points = len(x_gene)
gene_column = [*x_gene, *x_gene]
patient_column = [*y_raw, *y_emb]
type_column = ["Raw"] * n_points + ["Embedding"] * n_points

df_combined = pd.DataFrame({
    "gene_sim": gene_column,
    "patient_sim": patient_column,
    "type": type_column,
})


# ------------------------------------------------------------
# Scatter plot with an OLS trendline per representation
# ------------------------------------------------------------
# Reading the plot:
#   - a steeper / tighter "Embedding" trend -> the GNN picked up genetic structure
#   - a stronger "Raw" trend -> metadata already aligns well with genetics
#   - much larger spread in one series -> a noisy modality or collapsed embeddings
fig = px.scatter(
    df_combined,
    x="gene_sim",
    y="patient_sim",
    color="type",
    trendline="ols",
    title="Patient Similarity vs Gene Similarity (Raw vs Embedding)"
)

fig.show()

In [7]:
######################## CODING BLOCK 4 ########################
# ------------------------------------------------------------
# Exploring trendlines but with Radiomics
# ------------------------------------------------------------
# Same analysis as the gene-based scatter, but the reference similarity now comes
# from radiomic features. The returned dict's keys used below are 'sim_R' (radiomic),
# 'sim_P' (raw patient features) and 'sim_E' (embeddings).
sim_matrices = viewer.compute_all_similarities_from_radiomics(patient_embeddings)

x_rad, y_raw = flatten_pairwise(sim_matrices['sim_R'], sim_matrices['sim_P'])
_, y_emb = flatten_pairwise(sim_matrices['sim_R'], sim_matrices['sim_E'])

# Combine data into a single DataFrame with a 'type' column
df_combined_radiomic = pd.DataFrame({
    "radiomic_sim": list(x_rad) * 2,
    "patient_sim": list(y_raw) + list(y_emb),
    "type": ["Raw"] * len(x_rad) + ["Embedding"] * len(x_rad)
})

fig = px.scatter(
    df_combined_radiomic,
    x="radiomic_sim",
    y="patient_sim",
    color="type",
    trendline="ols",
    # The x-axis is radiomic similarity (the previous title wrongly said "Gene Similarity")
    title="Patient Similarity vs Radiomic Similarity (Raw vs Embedding)"
)
fig.show()

In [8]:
######################## CODING BLOCK 5 ########################
import pandas as pd
import numpy as np
from scipy.stats import pearsonr, spearmanr, norm
from rich.console import Console
from rich.table import Table
from rich import box

console = Console()


# ------------------------------------------------------------
# Split df_combined into its Raw and Embedding halves
# ------------------------------------------------------------
# df_combined holds one row per (patient pair, representation):
#   gene_sim    -> similarity of the pair in gene space
#   patient_sim -> similarity of the pair in raw-metadata or embedding space
#   type        -> "Raw" or "Embedding"
is_raw_row = df_combined["type"] == "Raw"
is_embedding_row = df_combined["type"] == "Embedding"
raw_df, emb_df = df_combined[is_raw_row], df_combined[is_embedding_row]


# ------------------------------------------------------------
# Pearson r between gene similarity and patient similarity,
# computed separately for each representation
# ------------------------------------------------------------
# r near +1 -> strong positive linear relationship
# r near  0 -> no linear relationship
# r near -1 -> strong negative linear relationship
(r_raw, p_raw), (r_emb, p_emb) = (
    pearsonr(frame["gene_sim"], frame["patient_sim"]) for frame in (raw_df, emb_df)
)


# ------------------------------------------------------------
# Fisher r-to-z transform
# ------------------------------------------------------------
# Sample correlations have a bounded, skewed sampling distribution.
# The Fisher transform z = 0.5 * ln((1 + r) / (1 - r)) (= atanh(r)) maps them
# onto an approximately normal scale, on which a z-test for the difference
# between r_raw and r_emb can be built.
def fisher_r_to_z(r):
    """Fisher r-to-z transform of a correlation coefficient r (|r| < 1)."""
    odds = (1 + r) / (1 - r)
    return np.log(odds) / 2


z_raw = fisher_r_to_z(r_raw)
z_emb = fisher_r_to_z(r_emb)

n_raw = len(raw_df)
n_emb = len(emb_df)

# ------------------------------------------------------------
# z-test for the difference between two DEPENDENT correlations
# ------------------------------------------------------------
# r_raw and r_emb are computed on the SAME patient pairs and share the gene_sim
# variable (df_combined stacks y_raw and y_emb against one x_gene sequence), so the
# independent-samples formula (z1 - z2) / sqrt(1/(n1-3) + 1/(n2-3)) does not apply.
# Meng, Rosenthal & Rubin (1992) corrects for the correlation between the two
# compared variables (raw similarity vs embedding similarity).
assert n_raw == n_emb, "Raw and Embedding rows must describe the same patient pairs"
n_pairs = n_raw

# Correlation between raw-metadata similarity and embedding similarity over the same pairs.
# Rows line up positionally because df_combined lists y_raw and y_emb in the same pair order.
r_raw_emb = pearsonr(raw_df["patient_sim"].to_numpy(), emb_df["patient_sim"].to_numpy())[0]

r_sq_mean = (r_raw ** 2 + r_emb ** 2) / 2
f_mrr = min((1 - r_raw_emb) / (2 * (1 - r_sq_mean)), 1.0)   # f is capped at 1 by definition
h_mrr = (1 - f_mrr * r_sq_mean) / (1 - r_sq_mean)
z_diff = (z_emb - z_raw) * np.sqrt((n_pairs - 3) / (2 * (1 - r_raw_emb) * h_mrr))

# Two-tailed p-value; norm.sf avoids the 1 - cdf cancellation that rounds tiny p-values to 0
p_value = 2 * norm.sf(abs(z_diff))

# Caveat: entries of a similarity matrix are not independent observations (each patient
# takes part in N-1 pairs), so this p-value is still optimistic; a Mantel-style permutation
# test over patients would be the stricter alternative.


# ------------------------------------------------------------
# Rich output: per-representation correlations, then the comparison test
# ------------------------------------------------------------
table = Table(title="Correlation Comparison: Raw vs Embedding", box=box.HEAVY_HEAD)

for header, column_options in (
    ("Metric", {"justify": "left", "style": "bold cyan"}),
    ("Raw Features", {"justify": "center"}),
    ("Embeddings", {"justify": "center"}),
):
    table.add_column(header, **column_options)

for row_label, number_format, raw_value, emb_value in (
    ("Correlation (r)", ".3f", r_raw, r_emb),
    ("p-value", ".3g", p_raw, p_emb),
):
    table.add_row(row_label, format(raw_value, number_format), format(emb_value, number_format))

console.print(table)

# Second table: the statistical test comparing the two correlations
test_table = Table(title="Fisher r-to-z Comparison", box=box.SIMPLE_HEAVY)
test_table.add_column("Statistic", style="bold magenta")
test_table.add_column("Value", justify="center")
for statistic_name, statistic_text in (
    ("z-score (difference)", format(z_diff, ".3f")),
    ("p-value", format(p_value, ".3g")),
):
    test_table.add_row(statistic_name, statistic_text)

console.print(test_table)

# One-line interpretation at the conventional 0.05 level
console.print("\n[b]Interpretation:[/b]")
verdict = (
    "[green]✓ The difference between Raw vs Embedding correlations is statistically significant.[/green]"
    if p_value < 0.05
    else "[yellow]• No significant difference detected between the two correlation values.[/yellow]"
)
console.print(verdict)


## Investigating Point Movement in the Embedding Manifold

If genetic information is successfully incorporated into the patient embeddings, we should observe a meaningful restructuring of the manifold. Specifically, pairs of patients who are genetically similar but dissimilar in their raw metadata should move **closer together** in the learned embedding space. Conversely, pairs who are genetically dissimilar but appear similar in metadata should move **further apart**. This shift would indicate that the embedding is prioritizing biologically grounded relationships.

In [9]:
######################## CODING BLOCK 6 ########################
import numpy as np
from scipy.stats import ttest_rel
from rich.console import Console
from rich.table import Table
from rich import box

console = Console()


# ======================================================
# 1. Convert similarity matrices -> distance matrices
# ======================================================
# Every space gets the SAME transform, dist = (1 - cos_sim) / 2, which maps a cosine
# similarity in [-1, 1] onto a distance in [0, 1]. Applying one transform everywhere is what
# makes the raw ("before") and embedding ("after") distances in step 4 comparable. Previously
# only the embedding distance was halved, so a pair with identical cosine similarity in both
# spaces showed a spurious "decrease".
sim_E_normalized = (sim_E + 1) / 2

dist_P = (1 - sim_P) / 2       # raw patient-feature distance
dist_G = (1 - sim_G) / 2       # gene distance
dist_E = 1 - sim_E_normalized  # embedding distance, == (1 - sim_E) / 2


# ======================================================
# 2. Keep each unordered patient pair once
#    (pairwise matrices are symmetric; diagonal ignored)
# ======================================================
tri = np.triu_indices_from(dist_P, k=1)

dist_P_flat = dist_P[tri]
dist_G_flat = dist_G[tri]
dist_E_flat = dist_E[tri]


# ======================================================
# 3. Identify mismatch pairs:
#    A) Far in raw but close in gene
#    B) Close in raw but far in gene
# ======================================================
threshold_far = 0.75     # top 25% most distant
threshold_close = 0.25   # bottom 25% (most similar)

# Quantile cut-offs over distinct pairs only; the zero-distance diagonal would otherwise
# drag the "close" cut-off down.
raw_far_cut, raw_close_cut = np.quantile(dist_P_flat, [threshold_far, threshold_close])
gene_far_cut, gene_close_cut = np.quantile(dist_G_flat, [threshold_far, threshold_close])

mask_far_raw_close_gene = (dist_P > raw_far_cut) & (dist_G < gene_close_cut)
mask_close_raw_far_gene = (dist_P < raw_close_cut) & (dist_G > gene_far_cut)

mask_far_raw_close_gene_flat = mask_far_raw_close_gene[tri]
mask_close_raw_far_gene_flat = mask_close_raw_far_gene[tri]


# ======================================================
# Utility function: Perform paired t-test + print rich table
# ======================================================
def print_distance_change(before, after, pair_label):
    """Compare paired distances before (raw space) and after (embedding space); print a Rich summary.

    Args:
        before: 1-D array of distances for the selected pairs in the raw patient-feature space.
        after: 1-D array of distances for the same pairs, in the same order, in embedding space.
        pair_label: title for the printed table.

    A paired t-test (scipy.stats.ttest_rel) is used because both arrays describe the same pairs.
    With fewer than two pairs the test is undefined (it would return NaN and be reported as
    "no significant difference"), so a "too few pairs" message is printed instead.
    """
    n_pairs = len(before)
    if n_pairs < 2:
        console.print(
            f"[yellow]• {pair_label}: only {n_pairs} pair(s) selected - "
            "too few for a paired t-test.[/yellow]\n"
        )
        return

    mean_before = before.mean()
    mean_after = after.mean()
    delta = mean_after - mean_before
    if delta < 0:
        direction = "decreased"
    elif delta > 0:
        direction = "increased"
    else:
        direction = "unchanged"

    test = ttest_rel(before, after)

    # Rich table for clean display
    table = Table(
        title=f"{pair_label}",
        box=box.SIMPLE_HEAVY,
        show_header=True,
        header_style="bold magenta"
    )

    table.add_column("Metric", style="cyan")
    table.add_column("Value", justify="center")

    table.add_row("Pairs", f"{n_pairs}")
    table.add_row("Mean Before", f"{mean_before:.3f}")
    table.add_row("Mean After", f"{mean_after:.3f}")
    table.add_row("Change", f"{delta:.3f} ({direction})")
    table.add_row("t-statistic", f"{test.statistic:.3f}")
    table.add_row("p-value", f"{test.pvalue:.3g}")

    console.print(table)

    # Interpretation line
    if test.pvalue < 0.05:
        console.print("[green]✓ Statistically significant difference (p < 0.05)[/green]\n")
    else:
        console.print("[yellow]• No statistically significant difference (p ≥ 0.05)[/yellow]\n")


# ======================================================
# 4. Compute and display pairwise changes
# ======================================================
console.print("\n[bold underline]Evaluating Embedding Improvements in Pairwise Distance[/bold underline]\n")

# (mask, title) per mismatch case; each case compares raw vs embedding distance for the same pairs
#   Case A: far in raw but close in gene -> embedding distance should shrink
#   Case B: close in raw but far in gene -> embedding distance should grow
for pair_mask, pair_title in (
    (mask_far_raw_close_gene_flat, "Pairs Far in Raw but Close in Gene (Should Decrease)"),
    (mask_close_raw_far_gene_flat, "Pairs Close in Raw but Far in Gene (Should Increase)"),
):
    print_distance_change(dist_P_flat[pair_mask], dist_E_flat[pair_mask], pair_title)


## Investigating Neighbors in Latent Space

Another way to investigate whether we have successfully incorporated genetic information into the embedding space is to measure neighbor overlap. Neighbor overlap is *the fraction of a patient's top-K neighbors in genetic space that are also among its top-K neighbors in patient-metadata (or patient-embedding) space*. We compare it between the genetic and patient-metadata spaces, and between the genetic and patient-embedding spaces.

In [10]:
######################## CODING BLOCK 7 ########################
import numpy as np
import pandas as pd
import plotly.express as px

# ============================================================
# 1. Choose K (number of nearest neighbors)
# ============================================================
K = 10                   # Compare top-K neighbors across spaces
N = sim_E.shape[0]       # Number of patients (matrix is NxN)
assert K < N, f"K={K} must be smaller than the number of patients N={N}"


# ============================================================
# 2. Compute Top-K Nearest Neighbors in Each Space
# ============================================================
def top_k_neighbors(sim, k):
    """Indices of the k most similar *other* patients for each row of an NxN similarity matrix.

    The diagonal is masked with -inf instead of assuming the patient itself sorts first.
    If another patient ties at the maximal similarity (e.g. identical feature vectors),
    argsort may put the self-match at any of the tied positions, and slicing [1:k+1]
    would then keep the patient as its own neighbour while dropping a real one.

    Args:
        sim: NxN similarity matrix (larger = more similar).
        k: number of neighbours per row.

    Returns:
        (N, k) integer array of neighbour indices, most similar first.
    """
    masked = np.array(sim, dtype=float, copy=True)
    np.fill_diagonal(masked, -np.inf)
    return np.argsort(-masked, axis=1, kind="stable")[:, :k]


topk_E = top_k_neighbors(sim_E, K)   # learned embeddings
topk_G = top_k_neighbors(sim_G, K)   # gene similarity baseline
topk_P = top_k_neighbors(sim_P, K)   # raw patient features


# ============================================================
# 3. Neighbor overlap per patient
#    overlap = |shared top-K neighbors between two spaces| / K
# ============================================================
def _shared_fraction(neighbors_a, neighbors_b):
    """Fraction of K neighbours that two neighbour lists have in common."""
    return len(set(neighbors_a) & set(neighbors_b)) / K


overlaps_meta = [_shared_fraction(topk_P[p], topk_G[p]) for p in range(N)]        # raw metadata vs gene
overlaps_embedding = [_shared_fraction(topk_E[p], topk_G[p]) for p in range(N)]   # embedding vs gene


# ============================================================
# 4. Long-form DataFrame: one row per (patient, feature space)
# ============================================================
overlaps = (
    [[value, "embedding"] for value in overlaps_embedding]
    + [[value, "metadata_raw"] for value in overlaps_meta]
)

df = pd.DataFrame(overlaps, columns=["overlap", "group"])


# ============================================================
# 5. Histogram with a marginal box plot per feature space
# ============================================================
histogram_options = dict(
    x="overlap",
    color="group",
    nbins=20,
    barmode="overlay",
    marginal="box",
    title=f"Top-{K} Neighbor Overlap with Gene Space",
    labels={"overlap": "Neighbor Overlap (fraction of K)"},
)
fig = px.histogram(df, **histogram_options)

fig.update_layout(bargap=0.1, title_font_size=20, legend_title_text="Feature Space")

fig.show()


In [11]:
######################## CODING BLOCK 7 (B) ########################
# ============================================================
# Explore same analysis with Radiomics info
# ============================================================
radiomic_similarities = viewer.compute_all_similarities_from_radiomics(patient_embeddings)

# Returns a long-form frame; the columns plotted below are 'data' (overlap values)
# and 'group' (which feature space each value belongs to).
df_radiomic = viewer.neighborhood_overlap_analysis(
    radiomic_similarities['sim_P'],
    radiomic_similarities['sim_E'],
    radiomic_similarities['sim_R'],
)

fig = px.histogram(
    df_radiomic,
    x="data",
    color="group",
    nbins=20,
    barmode="overlay",
    marginal="box",
    # NOTE(review): K is the value set in Coding Block 7; confirm that
    # neighborhood_overlap_analysis uses the same K.
    title=f"Top-{K} Neighbor Overlap with Radiomic Space",
    # The labels key must match the plotted column ('data'); the old key 'overlap' matched nothing
    labels={"data": "Neighbor Overlap (fraction of K)"},
)

fig.update_layout(
    bargap=0.1,
    title_font_size=20,
    legend_title_text="Feature Space"
)

fig.show()

## Patient Similarity Search

Using our embeddings, we can check whether the patient similarity search retrieves neighbors that are close not **just** in patient-metadata space but in genetic space as well.

In [12]:
######################## CODING BLOCK 8 ########################
!pip install faiss-cpu
import faiss

# ========================================
# Construct similarity search index from embeddings
# ========================================
# Convert PyTorch tensor to NumPy array for FAISS compatibility
array = patient_embeddings_filtered.cpu().numpy()

# Create FAISS index for Inner Product (cosine similarity after L2 normalization)
# IndexFlatIP computes dot products between vectors
index = faiss.IndexFlatIP(array.shape[1])  # shape[1] = embedding dimension

# Normalize vectors to unit length (required for cosine similarity)
# After normalization, inner product = cosine similarity
faiss.normalize_L2(array)

# Add all patient embeddings to the index
index.add(array)

# Search for k=10 nearest neighbors for each patient embedding
# D: distances (similarity scores), shape [n_patients, 10]
# I: indices of nearest neighbors, shape [n_patients, 10]
# Note: I[:,0] is always the query itself (self-similarity)
D, I = index.search(array, 10)

# ========================================
# Find example of highest similarity neighbor pair
# ========================================
# D[:,1] contains similarities to the 1st nearest neighbor (excluding self at [:,0])
# argmax finds the patient with the highest similarity to their nearest neighbor
# src: index of the nearest neighbor of the most similar pair
# target: index of the query patient (the one with highest neighbor similarity)
src, target = I[D[:,1].argmax(), 1], D[:,1].argmax()

# Map from filtered embedding indices back to original patient node indices
src_patient_id = gene_info['patient_idx_with_genes'][src]
target_patient_id = gene_info['patient_idx_with_genes'][target]

# Map from patient node indices to actual patient IDs (e.g., "TCGA-XX-XXXX")
src_reference_id = dataset['data_collection'].unique_patient_ids[src_patient_id]
target_reference_id = dataset['data_collection'].unique_patient_ids[target_patient_id]

# ========================================
# Look at similarity evolution across different representations
# ========================================
# Compare similarity between this patient pair across three stages:
# sim_G: similarity in gene/genomic space (input features)
# sim_E_normalized: similarity in intermediate embedding space
# sim_P: similarity in final patient representation space (after GNN)
# This shows how the model transforms relationships through the network
print(sim_G[src, target], sim_E_normalized[src, target], sim_P[src, target])

Collecting faiss-cpu
  Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (23.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.6/23.6 MB[0m [31m74.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.13.0
0.98133016 0.8833941 0.3939415


In [13]:
# Raw clinical metadata row for the nearest-neighbor patient (src)
(
    dataset['data_collection']
    .metadata
    .query("bcr_patient_barcode == @src_reference_id")
)


Unnamed: 0,bcr_patient_barcode,age_at_initial_pathologic_diagnosis,ajcc_cancer_metastasis_stage_code,ajcc_neoplasm_disease_lymph_node_stage,ajcc_neoplasm_disease_stage,ajcc_tumor_stage_code,anatomic_organ_subdivision,axillary_lymph_node_stage_method_type,breast_cancer_optical_measurement_histologic_type,breast_cancer_surgery_margin_status,...,lymph_node_examined_count,margin_status,menopause_status,number_of_lymphnodes_positive_by_he,patient_id,pretreatment_history,prior_diagnosis,race,tissue_source_site,vital_status
100,TCGA-AR-A1AQ,49,M0,N0,Stage IIA,T2,Right Upper Inner Quadrant|Right Upper Outer Q...,Sentinel lymph node biopsy plus axillary disse...,Infiltrating Ductal,[Not Available],...,6,Negative,Post (prior bilateral ovariectomy OR >12 mo si...,0,A1AQ,NO,NO,WHITE,AR,LIVING


In [14]:
# Raw clinical metadata row for the query patient (target)
(
    dataset['data_collection']
    .metadata
    .query("bcr_patient_barcode == @target_reference_id")
)


Unnamed: 0,bcr_patient_barcode,age_at_initial_pathologic_diagnosis,ajcc_cancer_metastasis_stage_code,ajcc_neoplasm_disease_lymph_node_stage,ajcc_neoplasm_disease_stage,ajcc_tumor_stage_code,anatomic_organ_subdivision,axillary_lymph_node_stage_method_type,breast_cancer_optical_measurement_histologic_type,breast_cancer_surgery_margin_status,...,lymph_node_examined_count,margin_status,menopause_status,number_of_lymphnodes_positive_by_he,patient_id,pretreatment_history,prior_diagnosis,race,tissue_source_site,vital_status
110,TCGA-E2-A1B6,44,M0,N0,Stage IIA,T2,Left,Sentinel node biopsy alone,Infiltrating Ductal,[Not Available],...,5,Negative,Pre (<6 months since LMP AND no prior bilatera...,0,A1B6,NO,NO,WHITE,E2,LIVING


In [15]:
# Gene-assay row (matched on CLID) for the nearest-neighbor patient (src)
(
    dataset['data_collection']
    .gene_assay
    .query("CLID == @src_reference_id")
)

Unnamed: 0,CLID,Pam50.Call,ROR-S Group (Subtype Only),ROR-P Group (Subtype + Proliferation),GHI_RS_3Group,GHI_RS Score,Mammaprint Predict.type,Mammaprint Pcorr_NKI70_Good_Correlation_Nature.2002_PMID.11823860,UNC_Scorr_Basal_Correlation_JCO.2009_PMID.19204204,UNC_Scorr_Her2_Correlation_JCO.2009_PMID.19204204,UNC_Scorr_LumA_Correlation_JCO.2009_PMID.19204204,UNC_Scorr_LumB_Correlation_JCO.2009_PMID.19204204,UNC_Scorr_Norm_Correlation_JCO.2009_PMID.19204204,UNC_ROR_S_Model_JCO.2009_PMID.19204204,UNC_Proliferation_11_Mean_JCO.2009_PMID.19204204,ROR-P (Subtype + Proliferation),Unnamed: 16,Unnamed: 17,Unnamed: 18
15,TCGA-AR-A1AQ,Basal,high,high,High,100.0,NKI70_Bad,-0.748,0.856274,-0.141772,-0.610508,-0.248046,0.05254,64.498432,0.598377,69.152486,,,


In [16]:
# Gene-assay row (matched on CLID) for the query patient (target)
(
    dataset['data_collection']
    .gene_assay
    .query("CLID == @target_reference_id")
)

Unnamed: 0,CLID,Pam50.Call,ROR-S Group (Subtype Only),ROR-P Group (Subtype + Proliferation),GHI_RS_3Group,GHI_RS Score,Mammaprint Predict.type,Mammaprint Pcorr_NKI70_Good_Correlation_Nature.2002_PMID.11823860,UNC_Scorr_Basal_Correlation_JCO.2009_PMID.19204204,UNC_Scorr_Her2_Correlation_JCO.2009_PMID.19204204,UNC_Scorr_LumA_Correlation_JCO.2009_PMID.19204204,UNC_Scorr_LumB_Correlation_JCO.2009_PMID.19204204,UNC_Scorr_Norm_Correlation_JCO.2009_PMID.19204204,UNC_ROR_S_Model_JCO.2009_PMID.19204204,UNC_Proliferation_11_Mean_JCO.2009_PMID.19204204,ROR-P (Subtype + Proliferation),Unnamed: 16,Unnamed: 17,Unnamed: 18
97,TCGA-E2-A1B6,Basal,high,high,High,100.0,NKI70_Bad,-0.541,0.762484,-0.23122,-0.49403,-0.346179,0.175098,56.252547,0.263237,53.21368,,,


## Clinical Investigation

Until now, our primary focus has been on learning patient embeddings that *integrate multiple sources* of information—specifically **radiomic features** and **gene expression profiles**. As a next step, we examine whether these learned embeddings also capture clinically meaningful signals by testing how well a simple linear classifier trained on them predicts **relevant biological or diagnostic labels** (AJCC disease stage and PAM50 molecular subtype).


In [17]:
######################## CODING BLOCK 9 ########################
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from rich.console import Console
from rich.table import Table
from rich import box

console = Console()

# ------------------------------------------------------
# 1. Fetch metadata-based targets (ajcc_neoplasm_disease_stage)
# ------------------------------------------------------
# Linear probe: can a logistic regression on the learned embeddings recover
# the clinical disease stage as well as one trained on the raw patient
# features? Labels are taken for the patients that have gene data
# (gene_info['patient_idx_with_genes']), the same rows selected from the
# embeddings below.
targets = dataset['data_collection'].get_target()
y = np.asarray(targets[gene_info['patient_idx_with_genes']])

# ------------------------------------------------------
# 2. Remove rare classes (avoid classes with <2 samples)
# ------------------------------------------------------
# Stratified splitting requires at least 2 samples per class.
unique, counts = np.unique(y, return_counts=True)
bad_classes = unique[counts < 2]

mask = ~np.isin(y, bad_classes)

# ------------------------------------------------------
# 3. Filter embeddings, raw features, and labels
# ------------------------------------------------------
# Embedding rows of the patients with gene data. A torch tensor is converted
# to a detached CPU NumPy array so sklearn never receives a (possibly GPU) tensor.
emb_with_genes = patient_embeddings[gene_info['patient_idx_with_genes']]
if hasattr(emb_with_genes, "detach"):
    emb_with_genes = emb_with_genes.detach().cpu().numpy()
emb_with_genes = np.asarray(emb_with_genes)

# Rows are matched to labels positionally, so every input needs one row per label.
# NOTE(review): assumes raw_patient_features rows follow the same patient order
# as gene_info['patient_idx_with_genes'] — confirm where it is built.
assert emb_with_genes.shape[0] == len(y), "embedding rows and labels are misaligned"
assert raw_patient_features.shape[0] == len(y), "raw-feature rows and labels are misaligned"

X_emb_filtered = emb_with_genes[mask]
X_raw_filtered = raw_patient_features[mask]
y_filtered     = y[mask]

# ------------------------------------------------------
# 4. Train/Test split (stratified by disease stage)
# ------------------------------------------------------
X_emb_train, X_emb_test, X_raw_train, X_raw_test, y_train, y_test = train_test_split(
    X_emb_filtered,
    X_raw_filtered,
    y_filtered,
    test_size=0.2,
    random_state=42,
    stratify=y_filtered
)

# ------------------------------------------------------
# 5. Train logistic regression on raw features
# ------------------------------------------------------
clf_raw = LogisticRegression(max_iter=2000)
clf_raw.fit(X_raw_train, y_train)
pred_raw = clf_raw.predict_proba(X_raw_test)

# ------------------------------------------------------
# 6. Train logistic regression on learned embeddings
# ------------------------------------------------------
clf_emb = LogisticRegression(max_iter=2000)
clf_emb.fit(X_emb_train, y_train)
pred_emb = clf_emb.predict_proba(X_emb_test)

# ------------------------------------------------------
# 7. Compute AUC (multi-class One-vs-Rest)
# ------------------------------------------------------
auc_raw_clinical = roc_auc_score(y_test, pred_raw, multi_class='ovr')
auc_emb_clinical = roc_auc_score(y_test, pred_emb, multi_class='ovr')

# ------------------------------------------------------
# 8. Pretty Rich Output
# ------------------------------------------------------
table = Table(title="Prediction Performance on ajcc_neoplasm_disease_stage ", box=box.ROUNDED)

table.add_column("Model", style="bold cyan")
table.add_column("AUC (OvR)", style="bold magenta")

table.add_row("Raw Patient Features", f"{auc_raw_clinical:.4f}")
table.add_row("Learned Patient Embeddings", f"{auc_emb_clinical:.4f}")

console.print("\n[bold underline]Model Performance Comparison[/bold underline]")
console.print(table)



In [18]:
######################## CODING BLOCK 10 ########################
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from rich.console import Console
from rich.table import Table
from rich import box

console = Console()

# ------------------------------------------------------
# 1. Fetch gene-based targets (e.g., PAM50 subtypes)
# ------------------------------------------------------
# Same linear probe as the clinical-stage block, now asking whether the
# embeddings also capture genetic subtype information.
# Converted to a NumPy array (as in the clinical-stage block) so masking and
# splitting behave the same whether get_gene_target() returns a list-like,
# Series, or array.
targets = dataset['data_collection'].get_gene_target()
y = np.asarray(targets)

# ------------------------------------------------------
# 2. Remove rare classes (avoid classes with <2 samples)
# ------------------------------------------------------
# Stratified splitting requires at least 2 samples per class.
unique, counts = np.unique(y, return_counts=True)
bad_classes = unique[counts < 2]

mask = ~np.isin(y, bad_classes)

# ------------------------------------------------------
# 3. Filter embeddings, raw features, and labels
# ------------------------------------------------------
# Embedding rows of the patients with gene data. A torch tensor is converted
# to a detached CPU NumPy array so sklearn never receives a (possibly GPU) tensor.
emb_with_genes = patient_embeddings[gene_info['patient_idx_with_genes']]
if hasattr(emb_with_genes, "detach"):
    emb_with_genes = emb_with_genes.detach().cpu().numpy()
emb_with_genes = np.asarray(emb_with_genes)

# Rows are matched to labels positionally, so every input needs one row per label.
# NOTE(review): assumes get_gene_target() and raw_patient_features follow the
# same patient order as gene_info['patient_idx_with_genes'] — confirm.
assert emb_with_genes.shape[0] == len(y), "embedding rows and gene labels are misaligned"
assert raw_patient_features.shape[0] == len(y), "raw-feature rows and gene labels are misaligned"

X_emb_filtered = emb_with_genes[mask]
X_raw_filtered = raw_patient_features[mask]
y_filtered     = y[mask]

# ------------------------------------------------------
# 4. Train/Test split (stratified by gene subtype)
# ------------------------------------------------------
X_emb_train, X_emb_test, X_raw_train, X_raw_test, y_train, y_test = train_test_split(
    X_emb_filtered,
    X_raw_filtered,
    y_filtered,
    test_size=0.2,
    random_state=42,
    stratify=y_filtered
)

# ------------------------------------------------------
# 5. Train logistic regression on raw features
# ------------------------------------------------------
clf_raw = LogisticRegression(max_iter=2000)
clf_raw.fit(X_raw_train, y_train)
pred_raw = clf_raw.predict_proba(X_raw_test)

# ------------------------------------------------------
# 6. Train logistic regression on learned embeddings
# ------------------------------------------------------
clf_emb = LogisticRegression(max_iter=2000)
clf_emb.fit(X_emb_train, y_train)
pred_emb = clf_emb.predict_proba(X_emb_test)

# ------------------------------------------------------
# 7. Compute AUC (multi-class One-vs-Rest)
# ------------------------------------------------------
auc_raw_gene = roc_auc_score(y_test, pred_raw, multi_class='ovr')
auc_emb_gene = roc_auc_score(y_test, pred_emb, multi_class='ovr')

# ------------------------------------------------------
# 8. Pretty Rich Output
# ------------------------------------------------------
table = Table(title="Prediction Performance on Gene Subtypes (PAM50)", box=box.ROUNDED)

table.add_column("Model", style="bold cyan")
table.add_column("AUC (OvR)", style="bold magenta")

table.add_row("Raw Patient Features", f"{auc_raw_gene:.4f}")
table.add_row("Learned Patient Embeddings", f"{auc_emb_gene:.4f}")

console.print("\n[bold underline]Model Performance Comparison[/bold underline]")
console.print(table)



In [19]:
######################## CODING BLOCK 11 ########################
import plotly.graph_objects as go
import numpy as np

# Grouped bar chart of the AUCs computed in CODING BLOCK 9 (clinical stage)
# and CODING BLOCK 10 (gene subtype).
# The x-axis categories get their own name so the `targets` label array from
# those blocks is not overwritten by a list of strings.
target_names = ["Metadata", "Gene"]
raw_auc = [auc_raw_clinical, auc_raw_gene]
emb_auc = [auc_emb_clinical, auc_emb_gene]

# Create figure
fig = go.Figure(data=[
    go.Bar(name='Raw Features', x=target_names, y=raw_auc, marker_color='rgba(55, 83, 109, 0.7)'),
    go.Bar(name='Embeddings', x=target_names, y=emb_auc, marker_color='rgba(26, 118, 255, 0.7)')
])

# Improve layout
fig.update_layout(
    title="AUC Comparison: Raw Features vs Learned Patient Embeddings",
    yaxis_title="AUC Score",
    xaxis_title="Prediction Target",
    barmode='group',
    template='plotly_white',
    font=dict(size=16),
    yaxis=dict(range=[0, 1])  # because AUC always between 0–1
)

# Reference line: an AUC of 0.5 corresponds to a random ranking.
fig.add_hline(y=0.5, line_dash="dash", line_color="gray", annotation_text="chance (AUC = 0.5)")

fig.show()