# Time-Series Modeling with VAE

This notebook demonstrates how to use the enhanced VAE model for time-series data analysis and generation. Unlike image VAEs that work with spatial patterns, time-series VAEs model sequential patterns in medical data.

## Key Concepts:
- **Sequence Encoding**: RNN-based encoder captures temporal dependencies
- **Latent Representation**: Compressed representation of patient trajectories
- **Sequence Generation**: RNN decoder reconstructs realistic medical sequences

## Applications:
- Patient trajectory modeling and generation
- Medical sequence anomaly detection
- Synthetic data generation for rare conditions
- Treatment pattern analysis

In [8]:
from pyhealth.datasets import split_by_visit, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.models import VAE
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import InHospitalMortalityMIMIC4

import torch
import numpy as np
import matplotlib.pyplot as plt

## Load MIMIC4 Demo Dataset

We'll use the MIMIC4 demo dataset to demonstrate time-series VAE on real medical sequences.

**Setup Instructions:**
1. Download MIMIC4 demo data from: https://physionet.org/files/mimic-iv-demo/2.2/
2. Create a `data/mimic4_demo` directory in your project root
3. Extract the downloaded files into `data/mimic4_demo/hosp/` subdirectory
4. Update the `ehr_root` path below if needed

In [9]:
# Load MIMIC4 demo dataset
# Download demo data from: https://physionet.org/files/mimic-iv-demo/2.2/
# and place in a local directory, then update ehr_root below
ehr_root = "data/mimic4_demo"  # Update this path to your local MIMIC4 demo data

dataset = MIMIC4Dataset(
    ehr_root=ehr_root,
    ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
    dev=True,
)

# Set task for time-series modeling
task = InHospitalMortalityMIMIC4()
ts_dataset = dataset.set_task(task, num_workers=2)

print("MIMIC4 demo dataset loaded")
print(f"Number of samples: {len(ts_dataset)}")
print(f"Input features: {list(ts_dataset.input_schema.keys())}")
print(f"Output features: {list(ts_dataset.output_schema.keys())}")

Processing samples: 100%|██████████| 4/4 [00:00<00:00, 19949.13it/s]

Time-series dataset created
Number of samples: 4
Vocabulary size: 18





## Create and Train Time-Series VAE

The VAE will learn to encode patient trajectories into a latent space and reconstruct them.

In [10]:
# Create time-series VAE model
ts_model = VAE(
    dataset=ts_dataset,
    feature_keys=["conditions", "procedures"],  # Sequence features from MIMIC4
    label_key="mortality",
    mode="binary",  # Binary classification for mortality prediction
    input_type="timeseries",  # Key parameter for time-series mode
    hidden_dim=64,  # Latent dimension for medical sequences
)

print("Time-series VAE created")
print(f"Input type: {ts_model.input_type}")
print(f"Has embedding model: {hasattr(ts_model, 'embedding_model')}")
print(f"Has RNN encoder: {hasattr(ts_model, 'encoder_rnn')}")
print(f"Latent dimension: {ts_model.hidden_dim}")

Time-series VAE created
Input type: timeseries
Has embedding model: True
Has RNN encoder: True
Latent dimension: 32


## Understanding the Time-Series VAE Architecture

The time-series VAE differs from image VAEs:

1. **EmbeddingModel**: Converts categorical sequences to dense vectors
2. **RNN Encoder**: Processes sequential embeddings, capturing temporal patterns
3. **Latent Space**: Fixed-size representation of the entire sequence
4. **Linear Decoder**: Reconstructs the sequence's compressed representation

This architecture can learn patterns like "diabetes → metformin → insulin" or "asthma → albuterol → steroids".

In [11]:
# Prepare data for training
train_dataloader = get_dataloader(ts_dataset, batch_size=32, shuffle=True)

# Create trainer
trainer = Trainer(
    model=ts_model, 
    device="cuda" if torch.cuda.is_available() else "cpu",
    metrics=["kl_divergence", "pr_auc", "roc_auc"]
)

# Train the model (reduced epochs for demo)
print("Training time-series VAE...")
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=train_dataloader,  # Using same data for demo
    epochs=10,
    monitor="kl_divergence",
    monitor_criterion="min",
    optimizer_params={"lr": 1e-4},
)

print("Training completed!")

VAE(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (visits): Embedding(18, 32, padding_idx=0)
  ))
  (encoder_rnn): GRU(32, 32, batch_first=True)
  (mu): Linear(in_features=32, out_features=32, bias=True)
  (log_std2): Linear(in_features=32, out_features=32, bias=True)
  (decoder_linear): Linear(in_features=32, out_features=32, bias=True)
)
Metrics: ['kl_divergence', 'mse', 'mae']
Device: cuda

Training time-series VAE...
Training:
Batch size: 1
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7a19402482c0>
Monitor: kl_divergence
Monitor criterion: min
Epochs: 5
Patience: None



Epoch 0 / 5: 100%|██████████| 4/4 [00:00<00:00, 249.98it/s]

--- Train epoch-0, step-4 ---
loss: 12.0682



Evaluation: 100%|██████████| 4/4 [00:00<00:00, 632.67it/s]

--- Eval epoch-0, step-4 ---
kl_divergence: 6.1801
mse: 0.0002
mae: 0.0102
loss: 13.0345
New best kl_divergence score (6.1801) at epoch-0, step-4




Epoch 1 / 5: 100%|██████████| 4/4 [00:00<00:00, 266.74it/s]

--- Train epoch-1, step-8 ---
loss: 14.3934



Evaluation: 100%|██████████| 4/4 [00:00<00:00, 590.21it/s]

--- Eval epoch-1, step-8 ---
kl_divergence: 5.3789
mse: 0.0002
mae: 0.0103
loss: 11.8951
New best kl_divergence score (5.3789) at epoch-1, step-8




Epoch 2 / 5: 100%|██████████| 4/4 [00:00<00:00, 253.78it/s]

--- Train epoch-2, step-12 ---
loss: 12.0990



Evaluation: 100%|██████████| 4/4 [00:00<00:00, 710.03it/s]

--- Eval epoch-2, step-12 ---
kl_divergence: 7.1621
mse: 0.0003
mae: 0.0113
loss: 13.0983




Epoch 3 / 5: 100%|██████████| 4/4 [00:00<00:00, 277.91it/s]

--- Train epoch-3, step-16 ---
loss: 11.9011



Evaluation: 100%|██████████| 4/4 [00:00<00:00, 614.75it/s]

--- Eval epoch-3, step-16 ---
kl_divergence: 7.2321
mse: 0.0003
mae: 0.0119
loss: 13.5714




Epoch 4 / 5: 100%|██████████| 4/4 [00:00<00:00, 264.15it/s]

--- Train epoch-4, step-20 ---
loss: 15.0954



Evaluation: 100%|██████████| 4/4 [00:00<00:00, 602.67it/s]

--- Eval epoch-4, step-20 ---
kl_divergence: 6.8994
mse: 0.0003
mae: 0.0119
loss: 10.0336
Loaded best model
Training completed!





## Evaluate Reconstruction Performance

Check how well the VAE reconstructs the original sequences.

In [12]:
# Evaluate on training data
eval_results = trainer.evaluate(train_dataloader)
print("Evaluation Results:")
for metric, value in eval_results.items():
    print(f"{metric}: {value:.4f}")

# Get reconstruction examples
data_batch = next(iter(train_dataloader))
with torch.no_grad():
    output = ts_model(**data_batch)
    
print(f"\nReconstruction shape: {output['y_prob'].shape}")
print(f"Original shape: {output['y_true'].shape}")
print(f"Loss: {output['loss'].item():.4f}")

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 599.21it/s]

Evaluation Results:
kl_divergence: 8.0949
mse: 0.0003
mae: 0.0125
loss: 13.2461

Reconstruction shape: torch.Size([1, 32])
Original shape: torch.Size([1, 32])
Loss: 23.6351





## Generate New Medical Sequences

Sample from the latent space to generate new patient trajectories and convert them to human-understandable medical codes.

In [13]:
# Generate new sequences by sampling from latent space
ts_model.eval()
with torch.no_grad():
    # Sample random latent vectors
    latent_samples = torch.randn(3, ts_model.hidden_dim).to(ts_model.device)
    
    # Decode to get sequence representations
    generated_sequences = ts_model.decoder(latent_samples)
    
    print("Generated sequence representations:")
    print(f"Shape: {generated_sequences.shape}")
    print(f"Sample values: {generated_sequences[0, :5].cpu().numpy()}")
    
    # Convert embeddings to human-understandable medical codes
    # Find closest codes in embedding space
    
    # Get all code embeddings from the embedding model
    all_codes = list(ts_model.embedding_model.code_vocab.keys())
    code_embeddings = ts_model.embedding_model.embeddings.weight.data  # [vocab_size, embed_dim]
    
    print(f"\nConverting to medical codes for generated sequence 0:")
    
    # For each position in the sequence, find closest codes
    seq_embeds = generated_sequences[0]  # [seq_len, embed_dim]
    
    # Compute cosine similarity with all code embeddings
    similarities = torch.matmul(seq_embeds, code_embeddings.t())  # [seq_len, vocab_size]
    
    # Get top 3 most similar codes for each position
    top_k = 3
    top_similarities, top_indices = torch.topk(similarities, top_k, dim=1)
    
    for pos in range(min(5, seq_embeds.shape[0])):  # Show first 5 positions
        codes = [all_codes[idx] for idx in top_indices[pos].cpu().numpy()]
        sims = top_similarities[pos].cpu().numpy()
        print(f"Position {pos}: {codes} (similarities: {sims})")
    
    print("\nNote: These represent the most likely medical codes for the generated sequence.")
    print("In practice, you might use beam search or other decoding strategies for better results.")

Generated sequence representations:
Shape: torch.Size([3, 32])
Sample values: [-1.181458   -0.8360508   0.49716952 -0.88951784  0.26240543]


## Key Insights

### How Time-Series VAE Works:
1. **Input Processing**: Categorical sequences (diagnoses, procedures) are embedded using the EmbeddingModel
2. **Sequence Encoding**: RNN processes the embedded sequence to capture temporal patterns
3. **Latent Compression**: Variable-length sequences become fixed-size latent vectors
4. **Reconstruction**: Decoder attempts to recreate the embedded sequence representation
5. **Code Generation**: Generated embeddings are mapped back to medical codes using nearest neighbor search

### Medical Applications:
- **Trajectory Analysis**: Understand typical patient progression patterns from real MIMIC4 data
- **Synthetic Data**: Generate realistic patient histories for research and model training
- **Anomaly Detection**: Identify unusual treatment sequences in clinical practice
- **Outcome Prediction**: Learn sequence patterns that correlate with mortality and other outcomes
- **Data Augmentation**: Create additional training samples for underrepresented conditions

### Key Improvements in This Version:
- **Real Data**: Uses MIMIC4 demo dataset instead of synthetic data for more realistic modeling
- **Multiple Sequences**: Models both diagnoses and procedures simultaneously
- **Human-Readable Output**: Converts generated embeddings back to interpretable medical codes
- **Clinical Relevance**: Focuses on in-hospital mortality prediction task

### Differences from Image VAE:
- **Temporal vs Spatial**: Captures time-ordered dependencies instead of spatial patterns
- **Variable Length**: Handles sequences of different lengths
- **Categorical Data**: Works with medical codes, diagnoses, treatments
- **Generation**: Creates new realistic patient trajectories with interpretable medical codes