# ASPIRE-IO: Patient-Level Immunotherapy Response Prediction

This notebook demonstrates how to use `train_IO_response_prediction_mlp.py` to train a patient-level outcome prediction model from MUSK embeddings and spatial gene expression predictions.

## Overview

The IO response prediction pipeline:
1. **Loads patch-level MUSK embeddings** from `.pt` files
2. **Loads patch-level gene predictions** with spatial coordinates and attention weights
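3. **Selects the top-N patches by attention weight** and mean-pools their embeddings
4. **Aggregates features** from 4 spatial gene expression feature families (Spatial Architecture, Attention-Weighted, Distribution, Dispersion)
5. **Trains an MLP with two hidden layers** (32 and 16 units) for binary response classification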

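Step 3 can be illustrated in a few lines of NumPy. This is a minimal sketch, not the script's code: it assumes a patient's patch embeddings are already stacked into a `(num_patches, 1024)` array whose rows line up with the attention weights, and the function name `top_n_mean_embedding` is hypothetical.

```python
import numpy as np


def top_n_mean_embedding(embeddings: np.ndarray, attention: np.ndarray, top_n: int = 50) -> np.ndarray:
    """Mean-pool the embeddings of the top_n patches ranked by attention weight.

    embeddings: (num_patches, dim) array, e.g. dim=1024 for MUSK.
    attention:  (num_patches,) array of per-patch attention weights.
    """
    if embeddings.shape[0] != attention.shape[0]:
        raise ValueError("embeddings and attention must have the same number of patches")
    # Indices sorted by descending attention; slicing simply stops early
    # if the patient has fewer than top_n patches.
    order = np.argsort(attention)[::-1][:top_n]
    return embeddings[order].mean(axis=0)


# Usage with random stand-in data (50 patches, 1024-dim embeddings).
rng = np.random.default_rng(0)
pooled = top_n_mean_embedding(rng.normal(size=(50, 1024)), rng.random(50), top_n=10)
print(pooled.shape)  # (1024,)
```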
## Data Format

### Embeddings
`embeddings/{patient_id}/{patient_id}_{x}_{y}_embedding.pt`

Each `.pt` file contains:
- `patch_id`: string identifier
- `image_embedding`: 1024-dim MUSK image encoder output
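Because patient IDs such as `train_patient_000` themselves contain underscores, splitting the filename on every `_` is fragile. A small sketch that anchors on the trailing `_{x}_{y}_embedding.pt` instead (the helper and `PatchCoord` names are hypothetical, not part of the script):

```python
import re
from typing import NamedTuple


class PatchCoord(NamedTuple):
    patient_id: str
    x: int
    y: int


# Greedy ".+" takes the longest prefix that still leaves "_<digits>_<digits>_embedding.pt",
# so underscores inside the patient ID are kept in patient_id.
_PATCH_RE = re.compile(r"^(?P<pid>.+)_(?P<x>\d+)_(?P<y>\d+)_embedding\.pt$")


def parse_embedding_filename(filename: str) -> PatchCoord:
    match = _PATCH_RE.match(filename)
    if match is None:
        raise ValueError(f"unexpected embedding filename: {filename!r}")
    return PatchCoord(match["pid"], int(match["x"]), int(match["y"]))


print(parse_embedding_filename("train_patient_000_38645_24879_embedding.pt"))
# PatchCoord(patient_id='train_patient_000', x=38645, y=24879)
```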

### Gene Predictions
`gene_predictions/{patient_id}_predictions.csv`

| Column | Description |
|--------|-------------|
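| `x` | Patch x-coordinate (matches `{x}` in the embedding filename) |
| `y` | Patch y-coordinate (matches `{y}` in the embedding filename) |
| `immune_score` | Predicted immune infiltration score |
| `attention_weight` | Per-patch attention weight, used to rank patches for top-N selection |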

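To combine the two sources, each embedding has to be matched to its prediction row. The sketch below assumes `(x, y)` is the join key (this holds for the pseudo data, where filename coordinates equal the CSV columns) and uses a left merge so rows come back in embedding order; `align_gene_predictions` is a hypothetical helper, not the script's API.

```python
import pandas as pd


def align_gene_predictions(gene_pred: pd.DataFrame, coords: list[tuple[int, int]]) -> pd.DataFrame:
    """Reorder gene-prediction rows to follow the order of the stacked embeddings.

    coords: (x, y) pairs parsed from embedding filenames, in stacking order.
    """
    order = pd.DataFrame(coords, columns=["x", "y"])
    order["emb_idx"] = range(len(order))
    # how="left" preserves the embedding order; validate rejects duplicate coordinates;
    # indicator adds a "_merge" column that flags embeddings with no prediction row.
    merged = order.merge(gene_pred, on=["x", "y"], how="left", validate="one_to_one", indicator=True)
    missing = merged["_merge"] == "left_only"
    if missing.any():
        raise ValueError(f"{int(missing.sum())} embedding patches have no gene prediction row")
    return merged.drop(columns="_merge")
```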
### Outcomes
`train_outcomes.csv`, `val_outcomes.csv`

| Column | Description |
|--------|-------------|
| `patient_id` | Patient identifier |
| `outcome` | Binary response label (0/1) |
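Before training, it can help to confirm that every patient listed in an outcomes CSV has both an embedding folder and a predictions file in the layout described above. A minimal sketch (the function name is hypothetical; `patient_col` mirrors the script's `--patient_col` flag):

```python
import os

import pandas as pd


def find_missing_inputs(
    outcomes: pd.DataFrame, embedding_dir: str, gene_pred_dir: str, patient_col: str = "patient_id"
) -> dict[str, list[str]]:
    """List patients whose embeddings folder or predictions CSV is absent."""
    missing: dict[str, list[str]] = {"embeddings": [], "gene_predictions": []}
    for pid in outcomes[patient_col].astype(str):
        # embeddings/{patient_id}/ must exist as a directory
        if not os.path.isdir(os.path.join(embedding_dir, pid)):
            missing["embeddings"].append(pid)
        # gene_predictions/{patient_id}_predictions.csv must exist as a file
        if not os.path.isfile(os.path.join(gene_pred_dir, f"{pid}_predictions.csv")):
            missing["gene_predictions"].append(pid)
    return missing
```

For example, `find_missing_inputs(pd.read_csv("./pseudo_data_IO/train_outcomes.csv"), "./pseudo_data_IO/embeddings", "./pseudo_data_IO/gene_predictions")` should return two empty lists when the data is complete.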

In [1]:
# Explore the sample data structure
import pandas as pd
import numpy as np
import torch
import os
import warnings
warnings.filterwarnings("ignore")

DATA_DIR = "./pseudo_data_IO"

# 1. Check outcomes
train_outcomes = pd.read_csv(f"{DATA_DIR}/train_outcomes.csv")
val_outcomes = pd.read_csv(f"{DATA_DIR}/val_outcomes.csv")

print("=== Training Outcomes ===")
print(train_outcomes)
print(f"\nClass distribution: {dict(train_outcomes['outcome'].value_counts())}")

print("\n=== Validation Outcomes ===")
print(val_outcomes)
print(f"\nClass distribution: {dict(val_outcomes['outcome'].value_counts())}")

=== Training Outcomes ===
           patient_id  outcome
0   train_patient_000        0
1   train_patient_001        1
2   train_patient_002        0
3   train_patient_003        0
4   train_patient_004        1
5   train_patient_005        1
6   train_patient_006        1
7   train_patient_007        0
8   train_patient_008        0
9   train_patient_009        0
10  train_patient_010        1
11  train_patient_011        1
12  train_patient_012        0
13  train_patient_013        1
14  train_patient_014        1
15  train_patient_015        0
16  train_patient_016        0
17  train_patient_017        0
18  train_patient_018        1
19  train_patient_019        1

Class distribution: {0: 10, 1: 10}

=== Validation Outcomes ===
        patient_id  outcome
0  val_patient_000        0
1  val_patient_001        0
2  val_patient_002        0
3  val_patient_003        1
4  val_patient_004        0
5  val_patient_005        1
6  val_patient_006        1
7  val_patient_007        1
8  val

In [2]:
# 2. Check patch-level embeddings
patient_id = train_outcomes['patient_id'].iloc[0]
embed_dir = f"{DATA_DIR}/embeddings/{patient_id}"

print(f"=== Patch-level Embeddings for {patient_id} ===")
pt_files = [f for f in os.listdir(embed_dir) if f.endswith('.pt')]
print(f"Number of patches: {len(pt_files)}")
print(f"Sample files: {pt_files[:3]}")

# Load one embedding
sample_emb = torch.load(os.path.join(embed_dir, pt_files[0]), map_location='cpu', weights_only=False)
print(f"\nEmbedding keys: {list(sample_emb.keys())}")
print(f"Embedding shape: {sample_emb['image_embedding'].shape}")

=== Patch-level Embeddings for train_patient_000 ===
Number of patches: 50
Sample files: ['train_patient_000_38645_24879_embedding.pt', 'train_patient_000_48008_44320_embedding.pt', 'train_patient_000_27914_17895_embedding.pt']

Embedding keys: ['patch_id', 'image_embedding']
Embedding shape: (1024,)


In [3]:
# 3. Check gene predictions
gene_pred = pd.read_csv(f"{DATA_DIR}/gene_predictions/{patient_id}_predictions.csv")

print(f"=== Gene Predictions for {patient_id} ===")
print(f"Shape: {gene_pred.shape}")
print(f"Columns: {list(gene_pred.columns)}")
print(f"\nSample rows:")
print(gene_pred.head())

=== Gene Predictions for train_patient_000 ===
Shape: (50, 4)
Columns: ['x', 'y', 'immune_score', 'attention_weight']

Sample rows:
       x      y  immune_score  attention_weight
0  38645  24879      0.802586          0.025310
1  48008  44320      0.572049          0.036340
2  27914  17895      0.512668          0.024375
3  11640  20987      0.293489          0.034765
4  30990  20631      0.931754          0.028807


In [4]:
# Train the IO response prediction model
import os
import subprocess
import sys
import warnings
warnings.filterwarnings("ignore")

DATA_DIR = "./pseudo_data_IO"
OUTPUT_DIR = "./demo_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

args = [
    sys.executable,
    "./train_IO_response_prediction_mlp.py",
    "--train_csv", f"{DATA_DIR}/train_outcomes.csv",
    "--val_csv", f"{DATA_DIR}/val_outcomes.csv",
    "--embedding_dir", f"{DATA_DIR}/embeddings",
    "--gene_pred_dir", f"{DATA_DIR}/gene_predictions",
    "--outcome_col", "outcome",
    "--patient_col", "patient_id",
    "--top_n", "50",
    "--save_path", f"{OUTPUT_DIR}/io_response_model.pt",
    "--gpu", "0",
]

print("=" * 60)
print("ASPIRE-IO: Training Patient-Level Response Predictor")
print("=" * 60)
print(f"\nTrain: {DATA_DIR}/train_outcomes.csv")
print(f"Val: {DATA_DIR}/val_outcomes.csv")
print(f"Output: {OUTPUT_DIR}/io_response_model.pt\n")

result = subprocess.run(args, capture_output=False)
print(f"\nTraining completed with exit code: {result.returncode}")

ASPIRE-IO: Training Patient-Level Response Predictor

Train: ./pseudo_data_IO/train_outcomes.csv
Val: ./pseudo_data_IO/val_outcomes.csv
Output: ./demo_outputs/io_response_model.pt

Outcome Dataset: 20 patients
  - Flattened features: top-50 patches × 4 values = 200 features
  - Attention weighting: True
  - Coordinates re-centered to highest attention patch (Centroid)
  - 4 feature families: Spatial Architecture, Attention-Weighted, Distribution, Dispersion
Outcome Dataset: 10 patients
  - Flattened features: top-50 patches × 4 values = 200 features
  - Attention weighting: True
  - Coordinates re-centered to highest attention patch (Centroid)
  - 4 feature families: Spatial Architecture, Attention-Weighted, Distribution, Dispersion
Input feature dimension: 1486


                                                       

Epoch   1: train_loss=0.6909, val_loss=0.6860, val_auc=0.7600, val_acc=0.5000
  -> Saved best model (AUC: 0.7600)


                                                       

Epoch   2: train_loss=0.6152, val_loss=0.6892, val_auc=0.6400, val_acc=0.5000


                                                       

Epoch   3: train_loss=0.5695, val_loss=0.6924, val_auc=0.5200, val_acc=0.5000


                                                       

Epoch   4: train_loss=0.5060, val_loss=0.6932, val_auc=0.5600, val_acc=0.5000


                                                       

Epoch   5: train_loss=0.4745, val_loss=0.6936, val_auc=0.6000, val_acc=0.5000


                                                       

Epoch   6: train_loss=0.4774, val_loss=0.6950, val_auc=0.5200, val_acc=0.5000


                                                       

Epoch   7: train_loss=0.4595, val_loss=0.6957, val_auc=0.5600, val_acc=0.5000


                                                       

Epoch   8: train_loss=0.3551, val_loss=0.6961, val_auc=0.5600, val_acc=0.5000


                                                       

Epoch   9: train_loss=0.3634, val_loss=0.6989, val_auc=0.5200, val_acc=0.5000


                                                       

Epoch  10: train_loss=0.3574, val_loss=0.6985, val_auc=0.4000, val_acc=0.5000


                                                       

Epoch  11: train_loss=0.3242, val_loss=0.6996, val_auc=0.4400, val_acc=0.5000


                                                       

Epoch  12: train_loss=0.2756, val_loss=0.7020, val_auc=0.3600, val_acc=0.5000


                                                       

Epoch  13: train_loss=0.2851, val_loss=0.7075, val_auc=0.3200, val_acc=0.4000


                                                       

Epoch  14: train_loss=0.2393, val_loss=0.7071, val_auc=0.4000, val_acc=0.3000


                                                       

Epoch  15: train_loss=0.2114, val_loss=0.7105, val_auc=0.4400, val_acc=0.3000


                                                       

Epoch  16: train_loss=0.2567, val_loss=0.7129, val_auc=0.4400, val_acc=0.3000


                                                       

Epoch  17: train_loss=0.2385, val_loss=0.7109, val_auc=0.4800, val_acc=0.3000


                                                       

Epoch  18: train_loss=0.3243, val_loss=0.7217, val_auc=0.3600, val_acc=0.3000


                                                       

Epoch  19: train_loss=0.2140, val_loss=0.7201, val_auc=0.4000, val_acc=0.3000


                                                       

Epoch  20: train_loss=0.2199, val_loss=0.7137, val_auc=0.3600, val_acc=0.3000


                                                       

Epoch  21: train_loss=0.1958, val_loss=0.7179, val_auc=0.3600, val_acc=0.4000


                                                       

Epoch  22: train_loss=0.2047, val_loss=0.7193, val_auc=0.3600, val_acc=0.4000


                                                       

Epoch  23: train_loss=0.1700, val_loss=0.7220, val_auc=0.3600, val_acc=0.4000


                                                       

Epoch  24: train_loss=0.1601, val_loss=0.7255, val_auc=0.3600, val_acc=0.4000


                                                       

Epoch  25: train_loss=0.1896, val_loss=0.7218, val_auc=0.3600, val_acc=0.4000


                                                       

Epoch  26: train_loss=0.2031, val_loss=0.7296, val_auc=0.3200, val_acc=0.2000


                                                       

Epoch  27: train_loss=0.2057, val_loss=0.7330, val_auc=0.3200, val_acc=0.3000


                                                       

Epoch  28: train_loss=0.1687, val_loss=0.7352, val_auc=0.3200, val_acc=0.3000


                                                       

Epoch  29: train_loss=0.2018, val_loss=0.7338, val_auc=0.3200, val_acc=0.3000


                                                       

Epoch  30: train_loss=0.1854, val_loss=0.7337, val_auc=0.3200, val_acc=0.3000
Training complete. Best validation AUC: 0.7600

Training completed with exit code: 0
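In this run, validation AUC peaks at epoch 1 (0.76) and then falls while training loss keeps decreasing, so the saved checkpoint comes from the first epoch. To inspect such curves programmatically, the per-epoch lines can be parsed with a regular expression matching the log format shown above. This is a sketch; it assumes the epoch lines are written to stdout, which could then be captured with `subprocess.run(args, capture_output=True, text=True)` and passed as `result.stdout`.

```python
import re

# Matches lines such as:
# "Epoch   1: train_loss=0.6909, val_loss=0.6860, val_auc=0.7600, val_acc=0.5000"
EPOCH_RE = re.compile(
    r"Epoch\s+(?P<epoch>\d+): train_loss=(?P<train_loss>[\d.]+), "
    r"val_loss=(?P<val_loss>[\d.]+), val_auc=(?P<val_auc>[\d.]+), val_acc=(?P<val_acc>[\d.]+)"
)


def parse_epoch_log(text: str) -> list[dict[str, float]]:
    """Extract one metrics dict per epoch line; other lines are ignored."""
    rows = []
    for match in EPOCH_RE.finditer(text):
        row = {key: float(value) for key, value in match.groupdict().items()}
        row["epoch"] = int(row["epoch"])
        rows.append(row)
    return rows


def best_epoch(rows: list[dict[str, float]], metric: str = "val_auc") -> dict[str, float]:
    """Return the epoch with the highest value of `metric`."""
    return max(rows, key=lambda row: row[metric])
```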


In [5]:
# Inspect the trained model
import torch
import warnings
warnings.filterwarnings("ignore")

checkpoint_path = f"{OUTPUT_DIR}/io_response_model.pt"

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    
    print("=== Model Checkpoint ===")
    print(f"Keys: {list(checkpoint.keys())}")
    
    if 'val_metrics' in checkpoint:
        val_metrics = checkpoint['val_metrics']
        if isinstance(val_metrics, dict) and 'auc' in val_metrics:
            print(f"\nBest validation AUC: {val_metrics['auc']:.4f}")
        
    if 'in_dim' in checkpoint:
        print(f"Input dimension: {checkpoint['in_dim']}")
        
    if 'config' in checkpoint:
        print(f"Config: {checkpoint['config']}")
    
    if 'model_state_dict' in checkpoint:
        print("\n=== Model Architecture ===")
        for name, param in checkpoint['model_state_dict'].items():
            print(f"  {name}: {param.shape}")
else:
    print(f"Checkpoint not found at {checkpoint_path}")

=== Model Checkpoint ===
Keys: ['model_state_dict', 'in_dim', 'config', 'val_metrics']

Best validation AUC: 0.7600
Input dimension: 1486
Config: {'batch_size': 16, 'num_epochs': 30, 'hidden_dim_1': 32, 'hidden_dim_2': 16, 'learning_rate': 0.001, 'weight_decay': 0.0001, 'patience': 30}

=== Model Architecture ===
  input_norm.weight: torch.Size([1486])
  input_norm.bias: torch.Size([1486])
  input_norm.running_mean: torch.Size([1486])
  input_norm.running_var: torch.Size([1486])
  input_norm.num_batches_tracked: torch.Size([])
  net.0.weight: torch.Size([32, 1486])
  net.0.bias: torch.Size([32])
  net.2.weight: torch.Size([16, 32])
  net.2.bias: torch.Size([16])
  net.4.weight: torch.Size([1, 16])
  net.4.bias: torch.Size([1])
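The state-dict keys imply a normalization layer followed by three linear layers (1486 → 32 → 16 → 1), about 51k trainable parameters (1486×32+32 + 32×16+16 + 16+1 + 2×1486 = 51,101) for 20 training patients. The NumPy forward pass below mirrors those shapes for inference without rebuilding the PyTorch module. Assumptions, none confirmed by the script: `input_norm` behaves like `BatchNorm1d` in eval mode with its default `eps=1e-5` (its keys are exactly BatchNorm's), the parameterless `net.1`/`net.3` are ReLU activations, and the single output is a logit.

```python
import numpy as np


def mlp_forward(state: dict[str, np.ndarray], x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Return response probabilities for a (batch, in_dim) feature matrix."""
    # input_norm in eval mode: normalize with the stored running statistics, then scale/shift.
    h = (x - state["input_norm.running_mean"]) / np.sqrt(state["input_norm.running_var"] + eps)
    h = h * state["input_norm.weight"] + state["input_norm.bias"]
    # net.0, net.2, net.4 are Linear layers (weight shape is (out, in), as in PyTorch).
    for layer in ("net.0", "net.2", "net.4"):
        h = h @ state[f"{layer}.weight"].T + state[f"{layer}.bias"]
        if layer != "net.4":
            h = np.maximum(h, 0.0)  # assumed ReLU between layers
    # Sigmoid of the single output logit.
    return 1.0 / (1.0 + np.exp(-h[:, 0]))
```

To use it with the checkpoint loaded above: `state = {k: v.numpy() for k, v in checkpoint['model_state_dict'].items()}`, then call `mlp_forward(state, features)` on a `(num_patients, 1486)` array built the same way the script builds its inputs.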
