# ASPIRE-IO Demo Pipeline

**Before running**, update the following in each cell:
- `TOKENIZER_PATH`: Path to your MUSK tokenizer.spm file
- `HF_TOKEN`: Your HuggingFace token for MUSK model access


In [None]:
#!/usr/bin/env python3
"""
Step 1: Pre-training Demo
"""

import subprocess
import sys
import os
import warnings
warnings.filterwarnings("ignore")

# ============================================================================
# Configuration - UPDATE THESE PATHS
# ============================================================================

SAMPLE_DATA_DIR = "./sample_data"
OUTPUT_DIR = "./demo_outputs"
TOKENIZER_PATH = "<path-to-MUSK>/musk/models/tokenizer.spm"  # UPDATE THIS
# NOTE(review): the token is forwarded to the child process as a command-line
# argument (see --hf_token below); avoid committing a real token in this cell.
HF_TOKEN = "<your-huggingface-token>"  # UPDATE THIS

# Paths that appear both on the command line and in the banner are named once
# so the two cannot drift apart.
TRAIN_CSV = f"{SAMPLE_DATA_DIR}/sample_train.csv"
PRETRAIN_CHECKPOINT = f"{OUTPUT_DIR}/pretrain_checkpoint.pt"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Arguments for pretrain.py, launched with the interpreter running this cell
# (sys.executable) so the child uses the same Python environment.
args = [
    sys.executable,
    "./pretrain.py",
    "--train_csv", TRAIN_CSV,
    "--val_csv", f"{SAMPLE_DATA_DIR}/sample_val.csv",
    "--patch_dir", f"{SAMPLE_DATA_DIR}/patches",
    "--gene_dir", f"{SAMPLE_DATA_DIR}/gene_data",
    "--text_dir", f"{SAMPLE_DATA_DIR}/text_descriptions",
    "--tokenizer_path", TOKENIZER_PATH,
    "--hf_token", HF_TOKEN,
    "--save_path", PRETRAIN_CHECKPOINT,
    "--gpu", "0",
    "--num_epochs", "3",
    "--batch_size", "14",
]

print("=" * 70)
print("ASPIRE-IO Demo: Step 1 - Pre-training")
print("=" * 70)
print(f"\nTrain CSV: {TRAIN_CSV}")
# Single backslash: "\\n" would print the two characters '\' and 'n'
# instead of emitting a blank line.
print(f"Output: {PRETRAIN_CHECKPOINT}\n")

# Run the script in a copy of the current environment, with Python warnings
# suppressed in the child as well as in this cell.
env = os.environ.copy()
env["PYTHONWARNINGS"] = "ignore"
result = subprocess.run(args, env=env)

# Report the outcome; a non-zero exit code is printed rather than raised so
# the notebook cell itself does not error out.
if result.returncode != 0:
    print(f"\n[ERROR] pretrain.py failed with exit code {result.returncode}")
else:
    print("\n[SUCCESS] Pre-training complete!")


ASPIRE-IO Demo: Step 1 - Pre-training

Train CSV: ./sample_data/sample_train.csv
Output: ./demo_outputs/pretrain_checkpoint.pt\n


Loading MUSK backbone...
Load ckpt from hf_hub:xiangjx/musk
Frozen 20/24 transformer layers
Dataset: 56 spots, organs: ['Lung' 'Prostate' 'Bladder' 'Breast' 'Lymph node' 'Bowel' 'Skin']
Dataset: 14 spots, organs: ['Lung' 'Prostate' 'Breast' 'Skin' 'Bowel']



Training:   0%|          | 0/4 [00:00<?, ?it/s]


Training:  25%|██▌       | 1/4 [00:02<00:07,  2.65s/it]


Training:  50%|█████     | 2/4 [00:03<00:03,  1.59s/it]


Training:  75%|███████▌  | 3/4 [00:04<00:01,  1.35s/it]


Training: 100%|██████████| 4/4 [00:05<00:00,  1.23s/it]
Training: 100%|██████████| 4/4 [00:05<00:00,  1.43s/it]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]


Validation: 100%|██████████| 1/1 [00:00<00:00,  1.03it/s]
Validation: 100%|██████████| 1/1 [00:01<00:00,  1.06s/it]


Epoch   1: train_loss=376725222921.1875, val_loss=8.6558
  -> Saved best model (val_loss=8.6558)



Training:   0%|          | 0/4 [00:00<?, ?it/s]


Training:  25%|██▌       | 1/4 [00:01<00:05,  1.83s/it]


Training:  50%|█████     | 2/4 [00:02<00:02,  1.29s/it]


Training:  75%|███████▌  | 3/4 [00:03<00:01,  1.11s/it]


Training: 100%|██████████| 4/4 [00:04<00:00,  1.07s/it]
Training: 100%|██████████| 4/4 [00:04<00:00,  1.19s/it]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]


Validation: 100%|██████████| 1/1 [00:01<00:00,  1.13s/it]
Validation: 100%|██████████| 1/1 [00:01<00:00,  1.23s/it]


Epoch   2: train_loss=743847647693.5000, val_loss=7.6476
  -> Saved best model (val_loss=7.6476)



Training:   0%|          | 0/4 [00:00<?, ?it/s]


Training:  25%|██▌       | 1/4 [00:01<00:05,  1.89s/it]


Training:  50%|█████     | 2/4 [00:02<00:02,  1.38s/it]


Training:  75%|███████▌  | 3/4 [00:03<00:01,  1.05s/it]


Training: 100%|██████████| 4/4 [00:04<00:00,  1.08it/s]
Training: 100%|██████████| 4/4 [00:04<00:00,  1.10s/it]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]


Validation: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Validation: 100%|██████████| 1/1 [00:01<00:00,  1.06s/it]


Epoch   3: train_loss=133765701464.0000, val_loss=7.9804
Training complete.


\n[SUCCESS] Pre-training complete!


In [None]:
#!/usr/bin/env python3
"""
Step 2: Fine-tuning Demo
"""

import subprocess
import sys
import os
import warnings
warnings.filterwarnings("ignore")

# ============================================================================
# Configuration - UPDATE THESE PATHS
# ============================================================================

SAMPLE_DATA_DIR = "./sample_data"
OUTPUT_DIR = "./demo_outputs"

TOKENIZER_PATH = "<path-to-MUSK>/musk/models/tokenizer.spm"  # UPDATE THIS
# NOTE(review): the token is forwarded to the child process as a command-line
# argument (see --hf_token below); avoid committing a real token in this cell.
HF_TOKEN = "<your-huggingface-token>"  # UPDATE THIS

PRETRAIN_CHECKPOINT = f"{OUTPUT_DIR}/pretrain_checkpoint.pt"
FINETUNE_CHECKPOINT = f"{OUTPUT_DIR}/finetune_checkpoint.pt"
TARGET_ORGAN = "Lung"  # Can change to any organ in sample data

# Fail early with a clear message instead of letting the child script crash
# on a missing checkpoint file.
if not os.path.exists(PRETRAIN_CHECKPOINT):
    raise FileNotFoundError("Pre-trained checkpoint not found. Run Step 1 first.")

# Arguments for finetune_mlp.py, launched with the interpreter running this
# cell (sys.executable) so the child uses the same Python environment.
args = [
    sys.executable,
    "./finetune_mlp.py",
    "--pretrained_checkpoint", PRETRAIN_CHECKPOINT,
    "--train_csv", f"{SAMPLE_DATA_DIR}/sample_train.csv",
    "--val_csv", f"{SAMPLE_DATA_DIR}/sample_val.csv",
    "--patch_dir", f"{SAMPLE_DATA_DIR}/patches",
    "--gene_dir", f"{SAMPLE_DATA_DIR}/gene_data",
    "--text_dir", f"{SAMPLE_DATA_DIR}/text_descriptions",
    "--tokenizer_path", TOKENIZER_PATH,
    "--hf_token", HF_TOKEN,
    "--organ", TARGET_ORGAN,
    "--save_path", FINETUNE_CHECKPOINT,
    "--gpu", "0",
    "--num_epochs", "5",
    "--batch_size", "7",
]

print("=" * 70)
print("ASPIRE-IO Demo: Step 2 - Fine-tuning")
print("=" * 70)
# Single backslash: "\\n" would print the two characters '\' and 'n'
# instead of emitting a blank line.
print(f"\nTarget organ: {TARGET_ORGAN}")
print(f"Output: {FINETUNE_CHECKPOINT}\n")

# Run the script in a copy of the current environment, with Python warnings
# suppressed in the child as well as in this cell.
env = os.environ.copy()
env["PYTHONWARNINGS"] = "ignore"
result = subprocess.run(args, env=env)

# Report the outcome; a non-zero exit code is printed rather than raised so
# the notebook cell itself does not error out.
if result.returncode != 0:
    print(f"\n[ERROR] finetune_mlp.py failed with exit code {result.returncode}")
else:
    print("\n[SUCCESS] Fine-tuning complete!")


ASPIRE-IO Demo: Step 2 - Fine-tuning
\nTarget organ: Lung
Output: ./demo_outputs/finetune_checkpoint.pt\n


Fine-tuning organ-specific MLP for: Lung
Load ckpt from hf_hub:xiangjx/musk
Loaded pre-trained model from: ./demo_outputs/pretrain_checkpoint.pt
[TRAINABLE] Organ immune head: Lung.immune_head
Trainable parameters: 2,099,201 / 705,287,876 (0.30%)
Organ Dataset [Lung]: 7 spots
Organ Dataset [Lung]: 3 spots



Fine-tuning [Lung]:   0%|          | 0/1 [00:00<?, ?it/s]


Fine-tuning [Lung]: 100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
                                                                 

Epoch   1: train_loss=981.4384, val_loss=28.0932, immune_pcc=0.9573
  -> Saved fine-tuned model



Fine-tuning [Lung]:   0%|          | 0/1 [00:00<?, ?it/s]


Fine-tuning [Lung]: 100%|██████████| 1/1 [00:00<00:00,  1.61it/s]
                                                                 

Epoch   2: train_loss=2962.0574, val_loss=2.1984, immune_pcc=0.8060
  -> Saved fine-tuned model



Fine-tuning [Lung]:   0%|          | 0/1 [00:00<?, ?it/s]


Fine-tuning [Lung]: 100%|██████████| 1/1 [00:00<00:00,  2.02it/s]
                                                                 

Epoch   3: train_loss=503.0334, val_loss=1.3084, immune_pcc=-0.2978
  -> Saved fine-tuned model



Fine-tuning [Lung]:   0%|          | 0/1 [00:00<?, ?it/s]


Fine-tuning [Lung]: 100%|██████████| 1/1 [00:00<00:00,  1.93it/s]
                                                                 

Epoch   4: train_loss=576.7669, val_loss=1.8067, immune_pcc=-0.7951



Fine-tuning [Lung]:   0%|          | 0/1 [00:00<?, ?it/s]


Fine-tuning [Lung]: 100%|██████████| 1/1 [00:00<00:00,  1.70it/s]
                                                                 

Epoch   5: train_loss=308.8769, val_loss=12.1394, immune_pcc=-0.7954
Fine-tuning complete for [Lung]. Best val_loss: 1.3084


\n[SUCCESS] Fine-tuning complete!


In [None]:
#!/usr/bin/env python3
"""
Step 3: MoE Wiring Network Training Demo
"""

import subprocess
import sys
import os
import warnings
warnings.filterwarnings("ignore")

# ============================================================================
# Configuration - UPDATE THESE PATHS
# ============================================================================

SAMPLE_DATA_DIR = "./sample_data"
OUTPUT_DIR = "./demo_outputs"

TOKENIZER_PATH = "<path-to-MUSK>/musk/models/tokenizer.spm"  # UPDATE THIS
# NOTE(review): the token is forwarded to the child process as a command-line
# argument (see --hf_token below); avoid committing a real token in this cell.
HF_TOKEN = "<your-huggingface-token>"  # UPDATE THIS

# Named once so the command line and the banner always agree.
MOE_MODE = "multimodal"
MOE_CHECKPOINT = f"{OUTPUT_DIR}/moe_checkpoint.pt"

# Arguments for train_moe.py, launched with the interpreter running this cell
# (sys.executable) so the child uses the same Python environment.
args = [
    sys.executable,
    "./train_moe.py",
    "--train_csv", f"{SAMPLE_DATA_DIR}/sample_train.csv",
    "--val_csv", f"{SAMPLE_DATA_DIR}/sample_val.csv",
    "--patch_dir", f"{SAMPLE_DATA_DIR}/patches",
    "--gene_dir", f"{SAMPLE_DATA_DIR}/gene_data",
    "--text_dir", f"{SAMPLE_DATA_DIR}/text_descriptions",
    "--tokenizer_path", TOKENIZER_PATH,
    "--hf_token", HF_TOKEN,
    "--mode", MOE_MODE,
    "--save_path", MOE_CHECKPOINT,
    "--gpu", "0",
    "--num_epochs", "5",
    "--batch_size", "14",
]

print("=" * 70)
print("ASPIRE-IO Demo: Step 3 - MoE Training")
print("=" * 70)
# Single backslash: "\\n" would print the two characters '\' and 'n'
# instead of emitting a blank line.
print(f"\nMode: {MOE_MODE}")
print(f"Output: {MOE_CHECKPOINT}\n")

# Run the script in a copy of the current environment, with Python warnings
# suppressed in the child as well as in this cell.
env = os.environ.copy()
env["PYTHONWARNINGS"] = "ignore"
result = subprocess.run(args, env=env)

# Report the outcome; a non-zero exit code is printed rather than raised so
# the notebook cell itself does not error out.
if result.returncode != 0:
    print(f"\n[ERROR] train_moe.py failed with exit code {result.returncode}")
else:
    print("\n[SUCCESS] MoE training complete!")


ASPIRE-IO Demo: Step 3 - MoE Training
\nMode: multimodal
Output: ./demo_outputs/moe_checkpoint.pt\n


ASPIRE-IO ORGAN CLASSIFIER TRAINING

Training organ classifier with mode: multimodal
Device: cuda:0

Loading MUSK backbone...
Load ckpt from hf_hub:xiangjx/musk
Froze MUSK backbone
MoE Organ Classifier initialized: mode='both', in_dim=2048, num_organs=7
Trainable parameters (classifier): 526,343

Loading datasets...
MoE Dataset: 56 samples, organs: {'Bladder': 10, 'Lymph node': 10, 'Bowel': 8, 'Skin': 8, 'Lung': 7, 'Prostate': 7, 'Breast': 6}
MoE Dataset: 14 samples, organs: {'Breast': 4, 'Lung': 3, 'Prostate': 3, 'Skin': 2, 'Bowel': 2}

Starting training for 5 epochs...



Training [both]:   0%|          | 0/4 [00:00<?, ?it/s]


Training [both]:  25%|██▌       | 1/4 [00:01<00:03,  1.29s/it]


Training [both]:  50%|█████     | 2/4 [00:01<00:01,  1.26it/s]


Training [both]:  75%|███████▌  | 3/4 [00:01<00:00,  1.83it/s]
                                                              

Validation:   0%|          | 0/1 [00:00<?, ?it/s]


Validation: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s]
                                                         

Training [both]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch   1 [multimodal]: train_loss=1.1503, train_acc=0.6071, val_loss=0.1048, val_acc=1.0000
  -> Saved best model to ./demo_outputs/moe_checkpoint.pt (accuracy=1.0000)



Training [both]:  25%|██▌       | 1/4 [00:00<00:02,  1.29it/s]


Training [both]:  50%|█████     | 2/4 [00:01<00:01,  1.71it/s]


Training [both]:  75%|███████▌  | 3/4 [00:01<00:00,  2.08it/s]
                                                              

Validation:   0%|          | 0/1 [00:00<?, ?it/s]


Validation: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]
                                                         

Training [both]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch   2 [multimodal]: train_loss=0.0237, train_acc=1.0000, val_loss=0.1004, val_acc=0.9286



Training [both]:  25%|██▌       | 1/4 [00:00<00:02,  1.45it/s]


Training [both]:  50%|█████     | 2/4 [00:01<00:01,  1.91it/s]


Training [both]:  75%|███████▌  | 3/4 [00:01<00:00,  1.98it/s]
                                                              

Validation:   0%|          | 0/1 [00:00<?, ?it/s]


Validation: 100%|██████████| 1/1 [00:00<00:00,  1.57it/s]
                                                         

Training [both]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch   3 [multimodal]: train_loss=0.0013, train_acc=1.0000, val_loss=0.0260, val_acc=1.0000



Training [both]:  25%|██▌       | 1/4 [00:00<00:02,  1.38it/s]


Training [both]:  50%|█████     | 2/4 [00:01<00:00,  2.14it/s]


Training [both]:  75%|███████▌  | 3/4 [00:01<00:00,  2.61it/s]
                                                              

Validation:   0%|          | 0/1 [00:00<?, ?it/s]


Validation: 100%|██████████| 1/1 [00:00<00:00,  1.60it/s]
                                                         

Training [both]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch   4 [multimodal]: train_loss=0.0003, train_acc=1.0000, val_loss=0.0327, val_acc=1.0000



Training [both]:  25%|██▌       | 1/4 [00:00<00:01,  1.54it/s]


Training [both]:  50%|█████     | 2/4 [00:01<00:01,  1.95it/s]


Training [both]:  75%|███████▌  | 3/4 [00:01<00:00,  2.06it/s]
                                                              

Validation:   0%|          | 0/1 [00:00<?, ?it/s]


Validation: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
                                                         

Epoch   5 [multimodal]: train_loss=0.0004, train_acc=1.0000, val_loss=0.0254, val_acc=1.0000
Training complete [multimodal]. Best val_accuracy: 1.0000
Model saved to: ./demo_outputs/moe_checkpoint.pt


\n[SUCCESS] MoE training complete!


In [None]:
#!/usr/bin/env python3
"""
Step 4: Inference Demo
"""

import subprocess
import sys
import os
import warnings
warnings.filterwarnings("ignore")

# ============================================================================
# Configuration - UPDATE THESE PATHS
# ============================================================================

SAMPLE_DATA_DIR = "./sample_data"
OUTPUT_DIR = "./demo_outputs"

TOKENIZER_PATH = "<path-to-MUSK>/musk/models/tokenizer.spm"  # UPDATE THIS
# NOTE(review): the token is forwarded to the child process as a command-line
# argument (see --hf_token below); avoid committing a real token in this cell.
HF_TOKEN = "<your-huggingface-token>"  # UPDATE THIS

FINETUNE_CHECKPOINT = f"{OUTPUT_DIR}/finetune_checkpoint.pt"
MOE_CHECKPOINT = f"{OUTPUT_DIR}/moe_checkpoint.pt"
INPUT_CSV = f"{SAMPLE_DATA_DIR}/sample_val.csv"
PREDICTIONS_CSV = f"{OUTPUT_DIR}/predictions.csv"

# The fine-tuned checkpoint is mandatory; fail early with a clear message
# instead of letting the child script crash on a missing file.
if not os.path.exists(FINETUNE_CHECKPOINT):
    raise FileNotFoundError("Fine-tuned checkpoint not found. Run Step 2 first.")

# MoE checkpoint is optional: pass it along when Step 3 produced one,
# otherwise explicitly tell the script to run without the classifier.
if os.path.exists(MOE_CHECKPOINT):
    classifier_args = ["--classifier_checkpoint", MOE_CHECKPOINT, "--use_classifier"]
    print("Using organ classifier checkpoint")
else:
    classifier_args = ["--no_classifier"]
    print("Classifier checkpoint not found - running without organ classifier")

# Arguments for spatial_gene_signature_prediction.py, launched with the
# interpreter running this cell (sys.executable); the classifier flags chosen
# above are appended at the end.
args = [
    sys.executable,
    "./spatial_gene_signature_prediction.py",
    "--checkpoint", FINETUNE_CHECKPOINT,
    "--input_csv", INPUT_CSV,
    "--patch_dir", f"{SAMPLE_DATA_DIR}/patches",
    "--text_dir", f"{SAMPLE_DATA_DIR}/text_descriptions",
    "--tokenizer_path", TOKENIZER_PATH,
    "--hf_token", HF_TOKEN,
    "--output_path", PREDICTIONS_CSV,
    "--batch_size", "16",
    "--gpu", "0",
] + classifier_args

print("=" * 70)
print("ASPIRE-IO Demo: Step 4 - Inference")
print("=" * 70)
# Single backslash: "\\n" would print the two characters '\' and 'n'
# instead of emitting a blank line.
print(f"\nInput: {INPUT_CSV}")
print(f"Output: {PREDICTIONS_CSV}\n")

# Run the script in a copy of the current environment, with Python warnings
# suppressed in the child as well as in this cell.
env = os.environ.copy()
env["PYTHONWARNINGS"] = "ignore"
result = subprocess.run(args, env=env)

# Report the outcome; a non-zero exit code is printed rather than raised so
# the notebook cell itself does not error out.
if result.returncode != 0:
    print(f"\n[ERROR] Inference failed with exit code {result.returncode}")
else:
    print("\n[SUCCESS] Inference complete!")


Using organ classifier checkpoint
ASPIRE-IO Demo: Step 4 - Inference
\nInput: ./sample_data/sample_val.csv
Output: ./demo_outputs/predictions.csv\n


ASPIRE-IO Inference
Loading checkpoint from: ./demo_outputs/finetune_checkpoint.pt
Checkpoint format: integrated=False, classifier=False, fuse_proj=False
Organs: ['Bladder', 'Bowel', 'Breast', 'Lung', 'Lymph node', 'Prostate', 'Skin']
Loading MUSK backbone...
Load ckpt from hf_hub:xiangjx/musk
Loaded MultiHeadModel
Loading classifier from: ./demo_outputs/moe_checkpoint.pt
Loaded OrganClassifier: in=2048, hidden=256, organs=7
Using organ classifier for automatic routing
Inference Dataset: 14 samples

Running inference...



Running inference:   0%|          | 0/2 [00:00<?, ?it/s]


Running inference:  50%|█████     | 1/2 [00:01<00:01,  1.24s/it]


Running inference: 100%|██████████| 2/2 [00:01<00:00,  1.50it/s]
Running inference: 100%|██████████| 2/2 [00:01<00:00,  1.21it/s]


Saved 14 predictions to: ./demo_outputs/predictions.csv

Inference complete!
Total samples: 14
Output saved to: ./demo_outputs/predictions.csv


\n[SUCCESS] Inference complete!
