# GRASP: Mortality Prediction on MIMIC-III (Baseline — No code_mapping)

This notebook runs the GRASP model for mortality prediction **without** `code_mapping`.
Raw ICD-9 and NDC codes are used as-is for the embedding vocabulary.

**Paper**: Liantao Ma et al. "GRASP: Generic Framework for Health Status Representation Learning Based on Incorporating Knowledge from Similar Patients." AAAI 2021.

GRASP encodes patient sequences with a backbone (ConCare, GRU, or LSTM), clusters patients via k-means, refines cluster representations with a 2-layer GCN, and blends cluster-level knowledge back into individual patient representations via a learned gating mechanism.
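The pipeline described above can be illustrated with a small pure-Python sketch. K-means assigns each patient vector to a cluster, a two-layer graph convolution mixes information between cluster centroids, and a sigmoid gate blends each patient's vector with its refined cluster vector. This is not PyHealth's `GRASP` implementation. The first-k centroid initialisation, the dot-product adjacency with self-loops, the fixed (identity) GCN weights, and the scalar dot-product gate are simplifying assumptions made here. In the real model the GCN weights and the gate are trained, and its clustering and graph construction may differ.

```python
import math


def matmul(a, b):
    """Multiply matrix a (n x m) by matrix b (m x p), both given as lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def kmeans(points, k, iters=10):
    """Plain k-means; centroids are initialised from the first k points (an assumption)."""
    centroids = [list(p) for p in points[:k]]
    assign = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest centroid by squared Euclidean distance.
        for i, p in enumerate(points):
            dists = [sum((x - c) ** 2 for x, c in zip(p, cen)) for cen in centroids]
            assign[i] = dists.index(min(dists))
        # Update step: each centroid becomes the mean of its members.
        for j in range(k):
            members = [points[i] for i in range(len(points)) if assign[i] == j]
            if members:
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return centroids, assign


def normalized_adjacency(centroids):
    """Row-normalised adjacency from non-negative centroid dot products, plus self-loops."""
    k = len(centroids)
    adj = [[max(0.0, sum(a * b for a, b in zip(centroids[i], centroids[j])))
            for j in range(k)] for i in range(k)]
    for i in range(k):
        adj[i][i] += 1.0  # self-loop guarantees every row sum is at least 1
    return [[v / sum(row) for v in row] for row in adj]


def gcn_two_layer(adj, feats, w1, w2):
    """Two graph-convolution layers: relu(A X W1), then A H W2."""
    hidden = [[max(0.0, v) for v in row] for row in matmul(matmul(adj, feats), w1)]
    return matmul(matmul(adj, hidden), w2)


def gated_blend(patient, cluster_vec):
    """Blend a patient vector with its cluster vector through a scalar sigmoid gate."""
    score = sum(p * c for p, c in zip(patient, cluster_vec))
    gate = 1.0 / (1.0 + math.exp(-score))
    return [gate * p + (1.0 - gate) * c for p, c in zip(patient, cluster_vec)]


def grasp_sketch(patients, k, w1, w2):
    """Cluster patients, refine centroids with a GCN, and blend them back per patient."""
    centroids, assign = kmeans(patients, k)
    refined = gcn_two_layer(normalized_adjacency(centroids), centroids, w1, w2)
    blended = [gated_blend(p, refined[assign[i]]) for i, p in enumerate(patients)]
    return blended, assign


identity = [[1.0, 0.0], [0.0, 1.0]]
patients = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
blended, clusters = grasp_sketch(patients, k=2, w1=identity, w2=identity)
print(clusters)  # [0, 0, 1, 1]
```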

**Model:** GRASP (GRU backbone + GCN cluster refinement)  
**Task:** In-hospital mortality prediction  
**Dataset:** Synthetic MIMIC-III (`dev=False`)

## Step 1: Load the MIMIC-III Dataset

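We load the MIMIC-III dataset using PyHealth's `MIMIC3Dataset` class, pointing it at the synthetic MIMIC-III data hosted on Google Cloud Storage (GCS), which requires no credentials.

- `root`: URL of the synthetic MIMIC-III data
- `tables`: Clinical tables to load (diagnoses, procedures, prescriptions)
- `cache_dir`: Directory where PyHealth caches processed data (a fresh temporary directory here)
- `dev`: `False` loads the full dataset; `True` loads a small subset for quick debugging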

In [None]:
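import tempfile

from pyhealth.datasets import MIMIC3Dataset

# mkdtemp() creates a directory that persists until it is deleted explicitly.
# A bare TemporaryDirectory().name is removed as soon as the unreferenced object is collected.
cache_dir = tempfile.mkdtemp()

base_dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    cache_dir=cache_dir,
    dev=False,
)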

base_dataset.stats()

## Step 2: Define the Mortality Prediction Task

The `MortalityPredictionMIMIC3` task converts each patient's raw EHR records into samples:
- Collects the diagnosis codes (ICD-9), procedure codes, and drug codes recorded for each visit
- Assigns each sample a binary label derived from the in-hospital mortality flag
- Skips visits that lack diagnosis, procedure, or drug codes
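The steps above can be sketched as a plain-Python function over visit dictionaries. The function name `build_mortality_samples` and the visit fields (`conditions`, `procedures`, `drugs`, `expire_flag`) are hypothetical, and this is not PyHealth's implementation. In particular, the label here uses the visit's own discharge flag; check the `MortalityPredictionMIMIC3` source for which admission's outcome PyHealth actually uses.

```python
def build_mortality_samples(patient_id, visits):
    """Turn one patient's visits into samples, skipping visits that lack any code type.

    Each visit is assumed to be a dict with "conditions", "procedures", and "drugs"
    (lists of code strings) plus "expire_flag" (0 or 1). These field names are
    hypothetical and chosen for this sketch only.
    """
    samples = []
    for visit_id, visit in enumerate(visits):
        codes = {key: visit.get(key, []) for key in ("conditions", "procedures", "drugs")}
        # Filter step: require at least one code of every type.
        if not all(codes.values()):
            continue
        samples.append({
            "patient_id": patient_id,
            "visit_id": visit_id,
            **codes,
            "mortality": int(visit.get("expire_flag", 0) == 1),
        })
    return samples


visits = [
    {"conditions": ["c1"], "procedures": ["p1"], "drugs": ["d1"], "expire_flag": 0},
    {"conditions": ["c2"], "procedures": [], "drugs": ["d2"], "expire_flag": 1},  # no procedures
]
print(len(build_mortality_samples("patient_1", visits)))  # 1
```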

In [None]:
from pyhealth.tasks import MortalityPredictionMIMIC3

task = MortalityPredictionMIMIC3()

samples = base_dataset.set_task(task)

print(f"Generated {len(samples)} samples")
print(f"\nInput schema: {samples.input_schema}")
print(f"Output schema: {samples.output_schema}")

## Step 3: Dataset Statistics

Each sample represents one hospital visit with:
- **conditions**: List of ICD-9 diagnosis codes
- **procedures**: List of ICD-9 procedure codes
- **drugs**: List of NDC drug codes
- **mortality**: Binary label (0 = survived, 1 = deceased)

In [None]:
print("Sample structure:")
print(samples[0])

print("\n" + "=" * 50)
print("Processor Vocabulary Sizes:")
print("=" * 50)
for key, proc in samples.input_processors.items():
    if hasattr(proc, 'code_vocab'):
        print(f"{key}: {len(proc.code_vocab)} codes (including <pad>, <unk>)")

mortality_count = sum(float(s.get("mortality", 0)) for s in samples)
print(f"\nTotal samples: {len(samples)}")
print(f"Mortality rate: {mortality_count / len(samples) * 100:.2f}%")
print(f"Positive samples: {int(mortality_count)}")
print(f"Negative samples: {len(samples) - int(mortality_count)}")
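The vocabulary sizes printed above include two special tokens. Because this notebook uses no `code_mapping`, every distinct raw ICD-9 or NDC string gets its own entry. A minimal sketch of such a vocabulary follows. The helpers `build_vocab` and `encode` are hypothetical, and reserving index 0 for `<pad>` and 1 for `<unk>` is an assumption, not necessarily the order PyHealth's processors use.

```python
def build_vocab(code_lists):
    """Map each distinct code to an integer index, reserving 0 for <pad> and 1 for <unk>."""
    vocab = {"<pad>": 0, "<unk>": 1}
    for codes in code_lists:
        for code in codes:
            if code not in vocab:
                vocab[code] = len(vocab)
    return vocab


def encode(codes, vocab):
    """Convert codes to indices; codes unseen when the vocabulary was built map to <unk>."""
    return [vocab.get(code, vocab["<unk>"]) for code in codes]


vocab = build_vocab([["c1", "c2"], ["c2", "c3"]])
print(len(vocab))                   # 5
print(encode(["c3", "c9"], vocab))  # [4, 1]
```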

## Step 4: Split the Dataset

We split the data by patient to prevent leakage: all samples (visits) from a given patient go into the same split, so the model is never evaluated on a patient it saw during training.
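A patient-level split can be sketched by shuffling patient IDs rather than samples, then routing every sample to the split that owns its patient. `split_by_patient_sketch` is a hypothetical helper operating on dicts with a `patient_id` key; PyHealth's `split_by_patient` may shuffle and round split sizes differently.

```python
import random


def split_by_patient_sketch(samples, ratios, seed=42):
    """Split samples so that each patient_id lands in exactly one of train/val/test."""
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)
    n_train = int(len(patient_ids) * ratios[0])
    n_val = int(len(patient_ids) * ratios[1])
    groups = [
        set(patient_ids[:n_train]),
        set(patient_ids[n_train:n_train + n_val]),
        set(patient_ids[n_train + n_val:]),  # remaining patients go to test
    ]
    return [[s for s in samples if s["patient_id"] in g] for g in groups]


# 10 patients with 3 visits each.
samples = [{"patient_id": f"p{i % 10}", "visit": i} for i in range(30)]
train, val, test = split_by_patient_sketch(samples, [0.8, 0.1, 0.1])
print(len(train), len(val), len(test))  # 24 3 3
```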

In [None]:
from pyhealth.datasets import split_by_patient

train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, [0.8, 0.1, 0.1], seed=42
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## Step 5: Create Data Loaders

Data loaders batch the samples and handle data feeding during training and evaluation.
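Visits contain different numbers of codes, so a batch of code sequences has to be padded to a common length before it can become a tensor. A minimal sketch of that padding step follows. The function name `collate_codes`, the pad index 0, and the mask output are assumptions for illustration, and PyHealth's own collate function may work differently. The last lines show why the batch counts printed below equal the number of samples divided by the batch size, rounded up, assuming the final partial batch is kept.

```python
import math


def collate_codes(batch, key, pad_index=0):
    """Pad the index lists stored under `key` to the longest sequence in the batch.

    Returns the padded rows plus a mask marking real (1) versus padded (0) positions.
    """
    max_len = max(len(sample[key]) for sample in batch)
    padded, mask = [], []
    for sample in batch:
        seq = sample[key]
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_index] * n_pad)
        mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask


batch = [{"conditions": [2, 3, 4]}, {"conditions": [5]}]
padded, mask = collate_codes(batch, "conditions")
print(padded)  # [[2, 3, 4], [5, 0, 0]]
print(mask)    # [[1, 1, 1], [1, 0, 0]]

# Number of batches when the last, partial batch is kept.
n_samples, batch_size = 100, 32
print(math.ceil(n_samples / batch_size))  # 4
```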

In [None]:
from pyhealth.datasets import get_dataloader

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

print(f"Training batches: {len(train_dataloader)}")
print(f"Validation batches: {len(val_dataloader)}")
print(f"Test batches: {len(test_dataloader)}")

## Step 6: Initialize the GRASP Model

The GRASP model automatically handles different feature types via `EmbeddingModel`.
Sequence features (diagnosis/procedure/drug codes) are embedded using learned embeddings,
and each feature gets its own `GRASPLayer`.

### Key Parameters:
- `embedding_dim`: Dimension of code embeddings (default: 128)
- `hidden_dim`: Hidden dimension of the backbone (default: 128)
- `cluster_num`: Number of patient clusters for knowledge sharing (default: 2)
- `block`: Backbone encoder — `"ConCare"`, `"GRU"`, or `"LSTM"` (default: `"ConCare"`)
- `dropout`: Dropout rate for regularization (default: 0.5)
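Because each feature has its own `GRASPLayer`, the per-feature outputs must be combined into one patient vector. The sketch below assumes, as several PyHealth models do, that they are concatenated, so the patient vector has `len(feature_keys) * hidden_dim` entries; the embedding shape printed in Step 9 can confirm this. A mean of code embeddings stands in for the GRASP layer only to keep the example small. `make_embedding_table`, `encode_feature`, and `patient_representation` are hypothetical helpers.

```python
import random


def make_embedding_table(vocab_size, dim, seed=0):
    """Random embedding table standing in for a learned embedding layer."""
    rng = random.Random(seed)
    return [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(vocab_size)]


def encode_feature(indices, table):
    """Stand-in for a per-feature GRASPLayer: mean of the non-pad code embeddings."""
    rows = [table[i] for i in indices if i != 0]  # index 0 assumed to be <pad>
    if not rows:
        return [0.0] * len(table[0])
    return [sum(col) / len(rows) for col in zip(*rows)]


def patient_representation(sample, tables):
    """Encode each feature separately, then concatenate the per-feature vectors."""
    parts = []
    for key in sorted(tables):
        parts.extend(encode_feature(sample[key], tables[key]))
    return parts


hidden_dim = 4
feature_keys = ["conditions", "procedures", "drugs"]
tables = {key: make_embedding_table(10, hidden_dim, seed=i) for i, key in enumerate(feature_keys)}
sample = {"conditions": [2, 3], "procedures": [4], "drugs": [5, 0]}
print(len(patient_representation(sample, tables)))  # 12, i.e. 3 features x hidden_dim
```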

In [None]:
from pyhealth.models import GRASP

model = GRASP(
    dataset=samples,
    embedding_dim=128,
    hidden_dim=128,
    cluster_num=12,
    block="GRU",
)

print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
print("\nModel architecture:")
print(model)

## Step 7: Train the Model

We use PyHealth's `Trainer` class, which handles:
- Training loop with automatic batching
- Validation during training
- Model checkpointing based on validation metrics

We monitor the **ROC-AUC** score on the validation set.
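ROC-AUC equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counted as half. The sketch below computes it that way and, assuming the trainer keeps the checkpoint with the best monitored score, picks the best epoch from a list of validation scores. Both functions are hypothetical illustrations; the pairwise loop is O(positives x negatives) and only suitable for small inputs. On the example below, scikit-learn's `roc_auc_score` gives the same 0.75.

```python
def roc_auc(y_true, y_score):
    """ROC-AUC as P(random positive outranks random negative), ties counted as 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


def best_epoch(val_scores):
    """Index of the epoch with the highest monitored validation score (first on ties)."""
    return max(range(len(val_scores)), key=lambda i: val_scores[i])


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
print(best_epoch([0.61, 0.68, 0.66]))                 # 1
```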

In [None]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["roc_auc", "pr_auc", "accuracy", "f1"],
)

trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="roc_auc",
    optimizer_params={"lr": 1e-3},
)

## Step 8: Evaluate on Test Set

After training, we evaluate the model on the held-out test set to measure its generalization performance.
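Unlike ROC-AUC, accuracy and F1 depend on a decision threshold. The sketch below computes both after thresholding predicted probabilities; the 0.5 threshold and the `>=` comparison are assumptions about how the reported metrics are derived. The second example shows why accuracy alone is misleading for an imbalanced label such as mortality: predicting survival for everyone scores an accuracy equal to one minus the mortality rate, yet an F1 of 0.

```python
def threshold_metrics(y_true, y_prob, threshold=0.5):
    """Accuracy and F1 for binary labels after thresholding predicted probabilities."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]  # >= is an assumed convention
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    accuracy = sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    return accuracy, f1


print(threshold_metrics([0, 0, 1, 1, 0], [0.2, 0.6, 0.7, 0.4, 0.1]))  # (0.6, 0.5)
print(threshold_metrics([0, 0, 0, 1], [0.1, 0.1, 0.1, 0.1]))          # (0.75, 0.0)
```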

In [None]:
test_results = trainer.evaluate(test_dataloader)

print("\n" + "=" * 50)
print("Test Set Performance (NO code_mapping)")
print("=" * 50)
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")

## Step 9: Extract Patient Embeddings

GRASP produces patient embeddings that encode health status enriched with knowledge from similar patients.
These embeddings can be used for downstream tasks like patient similarity search, cohort discovery, or transfer learning.
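As an illustration of patient similarity search, the sketch below ranks embedding vectors by cosine similarity to a query. The helpers `cosine` and `most_similar` are hypothetical, and cosine similarity is one common choice rather than something GRASP prescribes. In this notebook, `output["embed"].cpu().tolist()` from the next cell yields a list of lists in the format `most_similar` expects.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def most_similar(query, embeddings, top_k=2):
    """Return indices of the top_k embeddings most similar to query, best first."""
    ranked = sorted(range(len(embeddings)),
                    key=lambda i: cosine(query, embeddings[i]), reverse=True)
    return ranked[:top_k]


embeddings = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.7, 0.7]]
print(most_similar([1.0, 0.05], embeddings))  # [0, 2]
```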

In [None]:
import torch

model.eval()
test_batch = next(iter(test_dataloader))
test_batch["embed"] = True

with torch.no_grad():
    output = model(**test_batch)

print(f"Embedding shape: {output['embed'].shape}")
print(f"  - Batch size: {output['embed'].shape[0]}")
print(f"  - Embedding dim: {output['embed'].shape[1]}")

print("\n" + "=" * 50)
print("Sample Predictions:")
print("=" * 50)
predictions = output["y_prob"].cpu().numpy()
true_labels = output["y_true"].cpu().numpy()

for i in range(min(5, len(predictions))):
    prob = predictions[i][0]
    label = int(true_labels[i][0])
    # Each sample is one visit, so number samples rather than patients.
    print(f"Sample {i + 1}: P(mortality)={prob:.3f}, True={label}, "
          f"Prediction={'Mortality' if prob > 0.5 else 'Survival'}")