# Brain Tumor CDSS Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/melmbrain/brain-tumor-cdss/blob/main/notebooks/demo.ipynb)

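This notebook demonstrates the Brain Tumor Clinical Decision Support System (CDSS): the MRI model (M1), the gene expression model (MG), and the multimodal fusion model (MM). The models are loaded with random weights in demo mode, so the predictions shown below are illustrative only.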

## 1. Installation

In [None]:
# Install dependencies (uncomment if running on Colab)
# !pip install torch monai nibabel SimpleITK gseapy lifelines

In [None]:
# Clone repository (if running on Colab)
# !git clone https://github.com/melmbrain/brain-tumor-cdss.git
# %cd brain-tumor-cdss

## 2. Import Libraries

In [None]:
import sys
# Put the parent directory first on the import path so the project packages
# imported in later cells can be resolved.
# NOTE(review): presumably '..' is the repository root when this notebook runs
# from notebooks/ — confirm if the notebook is moved or launched elsewhere.
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt

# Report the runtime environment up front: installed PyTorch build and
# whether a CUDA device is visible to it.
torch_version = torch.__version__
print(f"PyTorch version: {torch_version}")
cuda_visible = torch.cuda.is_available()
print(f"CUDA available: {cuda_visible}")

## 3. Load Models

In [None]:
from models.m1 import MRIMultiTaskModel, M1Inference
from models.mg import GeneExpressionCDSS, MGInference
from models.mm import MultimodalFusionModel, MMInference

# Build the three inference wrappers on the CPU.
# No pretrained weights are supplied here, so the wrappers run in demo mode
# with random weights (as the final message states).
inference_device = 'cpu'

print("Loading M1 model...")
m1 = M1Inference(device=inference_device)

print("Loading MG model...")
mg = MGInference(device=inference_device)

print("Loading MM model...")
mm = MMInference(device=inference_device)

print("\nAll models loaded (demo mode with random weights)")

## 4. Model Architectures

In [None]:
# M1 model architecture: instantiate the MRI multi-task network with
# segmentation enabled and report its size and I/O summary.
m1_model = MRIMultiTaskModel(in_channels=4, embed_dim=48, include_segmentation=True)
print("M1 Model (MRI Encoder):")
# Total number of scalar entries across all parameter tensors.
m1_param_count = sum(tensor.numel() for tensor in m1_model.parameters())
print(f"  Parameters: {m1_param_count:,}")
print("  Input: (B, 4, 128, 128, 128) - T1, T1ce, T2, FLAIR")
print("  Output: Segmentation mask + Classification logits")

In [None]:
# MG model architecture: instantiate the gene-expression model and report
# its size and I/O summary.
# Random 500 x 64 tensor used as a placeholder for the gene embeddings.
gene_embeddings = torch.randn(500, 64)
mg_model = GeneExpressionCDSS(gene_embeddings=gene_embeddings, n_pathways=48)
print("\nMG Model (Gene VAE Encoder):")
# Total number of scalar entries across all parameter tensors.
mg_param_count = sum(tensor.numel() for tensor in mg_model.parameters())
print(f"  Parameters: {mg_param_count:,}")
print("  Input: 500 genes + 48 pathways")
print("  Output: 64-dim latent + Task predictions")

In [None]:
# MM model architecture: instantiate the multimodal fusion network and report
# its size and I/O summary. The per-modality widths passed here match the
# synthetic feature tensors generated in the multimodal demo section.
mm_model = MultimodalFusionModel(mri_dim=768, gene_dim=64, protein_dim=167, clinical_dim=10)
print("\nMM Model (Multimodal Fusion):")
# Total number of scalar entries across all parameter tensors.
mm_param_count = sum(tensor.numel() for tensor in mm_model.parameters())
print(f"  Parameters: {mm_param_count:,}")
print("  Input: MRI (768) + Gene (64) + Protein (167) + Clinical (10)")
print("  Output: 7 classification tasks + Survival prediction")

## 5. Demo: Gene Expression Analysis (MG)

In [None]:
# Build a synthetic expression profile for the MG demo.
# Seed NumPy's global RNG so the generated values repeat across runs.
np.random.seed(42)

# 100 placeholder genes, each assigned one random draw (in name order).
gene_names = ['GENE_%d' % index for index in range(100)]
gene_expression = {}
for gene_id in gene_names:
    gene_expression[gene_id] = np.random.randn()

# Simulated scores for six Hallmark pathways, drawn after the gene values.
pathway_names = [
    'HALLMARK_APOPTOSIS',
    'HALLMARK_CELL_CYCLE',
    'HALLMARK_DNA_REPAIR',
    'HALLMARK_GLYCOLYSIS',
    'HALLMARK_HYPOXIA',
    'HALLMARK_P53_PATHWAY',
]
pathway_scores = {}
for pathway_id in pathway_names:
    pathway_scores[pathway_id] = np.random.randn()

# Preview the first five genes; the dict was filled in gene_names order.
print("Sample gene expression data:")
for gene_id in gene_names[:5]:
    print(f"  {gene_id}: {gene_expression[gene_id]:.3f}")

In [None]:
# Run the MG inference wrapper on the synthetic patient and print a summary
# of the fields read from the returned result dict.
result = mg.analyze(
    patient_id="demo_patient",
    gene_expression=gene_expression,
    pathway_scores=pathway_scores,
    include_explainability=True,
)

banner = "=" * 50
print("\n" + banner)
print("MG Analysis Results")
print(banner)

# Each section is looked up right before it is printed.
risk_info = result['survival_risk']
print(f"\nSurvival Risk: {risk_info['category']}")
print(f"  Score: {risk_info['score']:.4f}")

grade_info = result['grade_prediction']
print(f"\nGrade Prediction: {grade_info['predicted']}")
print(f"  Confidence: {grade_info['confidence']:.3f}")

survival_months = result['survival_time']['predicted_months']
print(f"\nSurvival Time: {survival_months:.1f} months")

recurrence_info = result['recurrence']
print(f"\nRecurrence: {recurrence_info['prediction']}")
print(f"  Probability: {recurrence_info['probability']:.3f}")

## 6. Demo: Multimodal Fusion (MM)

In [None]:
# Generate synthetic multimodal features
batch_size = 1

# Draw every input from a dedicated, seeded generator so the synthetic
# features are identical on every run (mirroring the fixed NumPy seed used for
# the gene expression demo) without altering torch's global RNG state.
# Note: this only fixes the inputs; a model with randomly initialized weights
# will still produce different outputs from run to run.
feature_rng = torch.Generator().manual_seed(42)

# Pre-extracted features (normally from M1 and MG encoders)
mri_features = torch.randn(batch_size, 768, generator=feature_rng)   # From M1
gene_features = torch.randn(batch_size, 64, generator=feature_rng)   # From MG
protein_data = torch.randn(batch_size, 167, generator=feature_rng)   # RPPA
clinical_data = torch.randn(batch_size, 10, generator=feature_rng)   # Clinical features

print("Input features:")
print(f"  MRI: {mri_features.shape}")
print(f"  Gene: {gene_features.shape}")
print(f"  Protein: {protein_data.shape}")
print(f"  Clinical: {clinical_data.shape}")

In [None]:
# Run MM model
# PyTorch modules are created in training mode. Switch to eval mode before
# inference so that any dropout / batch-norm layers use their inference
# behaviour (train-mode batch norm can also fail on a batch of one sample).
# NOTE(review): assumes mm_model is a torch.nn.Module — it exposes
# parameters() and is called directly; confirm against models.mm.
mm_model.eval()

# Gradients are not needed for a forward-only demo.
with torch.no_grad():
    outputs = mm_model(mri_features, gene_features, protein_data, clinical_data, return_attention=True)

# Process outputs: turn each head's logits into probabilities and take the
# first (only) sample of the batch.
grade_probs = torch.softmax(outputs['grade_logits'], dim=-1).numpy()[0]
idh_probs = torch.softmax(outputs['idh_logits'], dim=-1).numpy()[0]
mgmt_probs = torch.softmax(outputs['mgmt_logits'], dim=-1).numpy()[0]

print("\n" + "="*50)
print("MM Multimodal Fusion Results")
print("="*50)
print("\nGrade Prediction:")
# Pair each grade label with its probability instead of indexing by position.
for grade, grade_prob in zip(['G1', 'G2', 'G3', 'G4'], grade_probs):
    print(f"  {grade}: {grade_prob:.3f}")
print("\nIDH Mutation:")
print(f"  Wildtype: {idh_probs[0]:.3f}")
print(f"  Mutant: {idh_probs[1]:.3f}")
print("\nMGMT Methylation:")
print(f"  Unmethylated: {mgmt_probs[0]:.3f}")
print(f"  Methylated: {mgmt_probs[1]:.3f}")
print("\nSurvival:")
# The raw risk output is squashed to (0, 1) with a sigmoid for display.
print(f"  Risk Score: {torch.sigmoid(outputs['risk_score']).item():.3f}")

## 7. Visualize Attention Weights

In [None]:
# Plot the cross-modal attention returned by the MM model as an annotated
# heatmap, or say so when the model did not return attention weights.
if 'modal_attention' not in outputs:
    print("Attention weights not available")
else:
    modalities = ['MRI', 'Gene', 'Protein', 'Clinical']

    # First sample of the batch; layout noted in the notebook as [num_heads, 4, 4].
    attention = outputs['modal_attention'].numpy()[0]

    # Collapse the leading (head) axis by averaging.
    avg_attention = attention.mean(axis=0)

    fig, ax = plt.subplots(figsize=(8, 6))
    heatmap = ax.imshow(avg_attention, cmap='Blues', vmin=0, vmax=1)
    fig.colorbar(heatmap, ax=ax, label='Attention Weight')

    # Label both axes with the modality names.
    ax.set_xticks(range(4))
    ax.set_xticklabels(modalities)
    ax.set_yticks(range(4))
    ax.set_yticklabels(modalities)
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')
    ax.set_title('Cross-Modal Attention Weights')

    # Write each averaged weight in the centre of its cell (x = column, y = row).
    for row in range(4):
        for col in range(4):
            ax.text(col, row, f'{avg_attention[row,col]:.2f}', ha='center', va='center')

    fig.tight_layout()
    plt.show()

## 8. Performance Summary

In [None]:
# Reported performance metrics, one row per model/task:
# (model label, task, metric name, score).
_performance_rows = [
    ('M1-Seg', 'Segmentation', 'Dice', 0.766),
    ('M1-Cls (IDH)', 'IDH Mutation', 'AUC', 0.878),
    ('M1-Cls (Grade)', 'Grade Classification', 'Accuracy', 0.838),
    ('M1-Cls (Survival)', 'Survival', 'C-Index', 0.660),
    ('MG', 'Gene Survival', 'C-Index', 0.780),
    ('MM', 'Multimodal Survival', 'C-Index', 0.610),
]

# Keyed by model label, preserving the row order above.
performance = {
    label: {'Task': task, 'Metric': metric, 'Score': score}
    for label, task, metric, score in _performance_rows
}

# Print the metrics as a fixed-width, left-aligned table.
rule_width = 60
print("\n" + "=" * rule_width)
print("Performance Summary")
print("=" * rule_width)
print(f"{'Model':<20} {'Task':<25} {'Metric':<10} {'Score':<10}")
print("-" * rule_width)
for label, entry in performance.items():
    task, metric, score = entry['Task'], entry['Metric'], entry['Score']
    print(f"{label:<20} {task:<25} {metric:<10} {score:<10.3f}")

## 9. Next Steps

1. **Download pretrained weights** from [Releases](https://github.com/melmbrain/brain-tumor-cdss/releases)
2. **Prepare your data** using the preprocessing scripts
3. **Run inference** with real patient data
4. **Fine-tune models** on your own dataset