# Time-Series Modeling with VAE

This notebook demonstrates how to use PyHealth's VAE model in time-series mode (`input_type="timeseries"`) to analyze and generate sequential medical data. Image VAEs model spatial patterns; a time-series VAE models the order of events in a patient record.

## Key Concepts:
- **Sequence Encoding**: RNN-based encoder captures temporal dependencies
- **Latent Representation**: Compressed representation of patient trajectories
- **Sequence Generation**: Decoder maps latent vectors back to sequence representations (a linear decoder in this demo, not an RNN)

## Applications:
- Patient trajectory modeling and generation
- Medical sequence anomaly detection
- Synthetic data generation for rare conditions
- Treatment pattern analysis

In [None]:
from pyhealth.datasets import SampleDataset, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.models import VAE

import torch
import numpy as np
import matplotlib.pyplot as plt

## Create Time-Series Medical Data

We'll create sample patient trajectories showing disease progression and treatment sequences.

In [None]:
# Create sample time-series medical data
ts_samples = [
    {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "visits": ["diabetes", "metformin", "hba1c_test", "insulin"],
        "label": 1.0,
    },
    {
        "patient_id": "patient-1",
        "visit_id": "visit-1",
        "visits": ["hypertension", "lisinopril", "bp_check", "followup"],
        "label": 0.5,
    },
    {
        "patient_id": "patient-2",
        "visit_id": "visit-2",
        "visits": ["asthma", "albuterol", "peak_flow", "steroids"],
        "label": 0.8,
    },
    {
        "patient_id": "patient-3",
        "visit_id": "visit-3",
        "visits": ["depression", "sertraline", "therapy", "counseling"],
        "label": 0.3,
    },
]

# Create dataset
ts_dataset = SampleDataset(
    samples=ts_samples,
    input_schema={"visits": "sequence"},
    output_schema={"label": "regression"},
    dataset_name="timeseries_demo",
)

print("Time-series dataset created")
print(f"Number of samples: {len(ts_dataset)}")
print(f"Vocabulary size: {len(ts_dataset.input_processors['visits'].code_vocab)}")

## Create and Train Time-Series VAE

The VAE will learn to encode patient trajectories into a latent space and reconstruct them.
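
For intuition about what "encode and reconstruct" optimizes, here is a minimal sketch of the standard VAE objective: a reconstruction term plus the KL divergence between the encoder's diagonal Gaussian and a standard normal prior. The function names (`kl_to_standard_normal`, `mse`, `vae_loss`), the choice of mean squared error, and the `beta` weight are assumptions for illustration; this notebook does not show how PyHealth's `VAE` combines these terms internally.

```python
import math

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one latent vector."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))

def mse(x, x_rec):
    """Mean squared error between a target vector and its reconstruction."""
    return sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)

def vae_loss(x, x_rec, mu, logvar, beta=1.0):
    """Reconstruction error plus a beta-weighted KL term (beta is hypothetical)."""
    return mse(x, x_rec) + beta * kl_to_standard_normal(mu, logvar)

# A posterior equal to the prior contributes no KL penalty.
kl_at_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])
# Moving the mean away from zero increases the penalty.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])
print(kl_at_prior, kl_shifted)
```

The KL term pulls every encoded trajectory toward the prior, which is what later makes sampling from a standard normal a reasonable way to generate new latent vectors.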

In [None]:
# Create time-series VAE model
ts_model = VAE(
    dataset=ts_dataset,
    feature_keys=["visits"],
    label_key="label",
    mode="regression",
    input_type="timeseries",  # Key parameter for time-series mode
    hidden_dim=32,  # Smaller latent dimension for sequences
)

print("Time-series VAE created")
print(f"Input type: {ts_model.input_type}")
print(f"Has embedding model: {hasattr(ts_model, 'embedding_model')}")
print(f"Has RNN encoder: {hasattr(ts_model, 'encoder_rnn')}")
print(f"Latent dimension: {ts_model.hidden_dim}")

## Understanding the Time-Series VAE Architecture

Instead of an image pipeline, the time-series VAE chains four components:

1. **EmbeddingModel**: Converts categorical sequences to dense vectors
2. **RNN Encoder**: Processes sequential embeddings, capturing temporal patterns
3. **Latent Space**: Fixed-size representation of the entire sequence
4. **Linear Decoder**: Maps the latent vector back to the sequence's compressed representation rather than to individual tokens

This architecture can learn patterns like "diabetes → metformin → insulin" or "asthma → albuterol → steroids".
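
The four components above can be sketched in plain PyTorch. This is a hypothetical stand-in, not PyHealth's implementation: the class name `TinySeqVAE`, the layer sizes, the choice of a GRU, and the decision to have the decoder target the RNN summary vector are all assumptions made for illustration.

```python
import torch
import torch.nn as nn

class TinySeqVAE(nn.Module):
    """Hypothetical sketch: embedding -> RNN encoder -> latent -> linear decoder."""

    def __init__(self, vocab_size, embed_dim=16, rnn_dim=32, latent_dim=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.encoder_rnn = nn.GRU(embed_dim, rnn_dim, batch_first=True)
        self.to_mu = nn.Linear(rnn_dim, latent_dim)
        self.to_logvar = nn.Linear(rnn_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, rnn_dim)

    def forward(self, codes):
        embedded = self.embedding(codes)             # (batch, seq_len, embed_dim)
        _, last_hidden = self.encoder_rnn(embedded)  # (1, batch, rnn_dim)
        summary = last_hidden.squeeze(0)             # (batch, rnn_dim)
        mu, logvar = self.to_mu(summary), self.to_logvar(summary)
        # Reparameterization: z = mu + sigma * eps keeps sampling differentiable.
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        reconstruction = self.decoder(z)             # (batch, rnn_dim)
        return reconstruction, summary, mu, logvar

model = TinySeqVAE(vocab_size=20)
codes = torch.randint(1, 20, (2, 4))  # 2 patients, 4 code indices each
reconstruction, summary, mu, logvar = model(codes)
print(reconstruction.shape, mu.shape)
```

In this sketch the whole sequence collapses into one `summary` vector before the latent step, so the decoder can only reconstruct that summary and not the individual codes. The sketch also ignores padding when taking the last hidden state, which is harmless only because both example sequences have the same length.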

In [None]:
# Prepare data for training
train_dataloader = get_dataloader(ts_dataset, batch_size=1, shuffle=True)

# Create trainer
trainer = Trainer(
    model=ts_model, 
    device="cuda" if torch.cuda.is_available() else "cpu",
    metrics=["kl_divergence", "mse", "mae"]
)

# Train the model (reduced epochs for demo)
print("Training time-series VAE...")
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=train_dataloader,  # Using same data for demo
    epochs=5,
    monitor="kl_divergence",
    monitor_criterion="min",
    optimizer_params={"lr": 1e-3},
)

print("Training completed!")

## Evaluate Reconstruction Performance

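Check how well the VAE reconstructs the original sequences. Because this demo evaluates on the training data, the metrics measure fit rather than generalization.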

In [None]:
# Evaluate on training data
eval_results = trainer.evaluate(train_dataloader)
print("Evaluation Results:")
for metric, value in eval_results.items():
    print(f"{metric}: {value:.4f}")

# Get reconstruction examples (eval mode disables train-only behavior such as dropout)
ts_model.eval()
data_batch = next(iter(train_dataloader))
with torch.no_grad():
    output = ts_model(**data_batch)
    
print(f"\nReconstruction shape: {output['y_prob'].shape}")
print(f"Original shape: {output['y_true'].shape}")
print(f"Loss: {output['loss'].item():.4f}")

## Generate New Medical Sequences

Sample from the latent space to generate new patient trajectories.
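
There are two common ways to draw a latent vector, and the difference matters for what gets generated. The sketch below contrasts sampling from the prior with sampling near one encoded sequence. The helper names `sample_prior` and `sample_posterior` and the example `mu` and `logvar` values are hypothetical and do not come from the trained model.

```python
import math
import random

def sample_prior(latent_dim, rng):
    """Draw z ~ N(0, I): an unconditioned new point in latent space."""
    return [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]

def sample_posterior(mu, logvar, rng):
    """Draw z = mu + sigma * eps: a variation close to one encoded sequence."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
z_new = sample_prior(4, rng)
# Hypothetical encoder outputs; a very negative logvar means a tight posterior.
z_near = sample_posterior([0.5, -0.2, 0.0, 1.0], [-4.0, -4.0, -4.0, -4.0], rng)
print(len(z_new), len(z_near))
```

Prior samples explore the whole latent space, while posterior samples stay close to a known trajectory and are useful for producing small variations of an existing patient record.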

In [None]:
# Generate new sequences by sampling from latent space
ts_model.eval()
with torch.no_grad():
    # Sample random latent vectors
    latent_samples = torch.randn(3, ts_model.hidden_dim).to(ts_model.device)
    
    # Decode to get sequence representations
    generated_sequences = ts_model.decoder(latent_samples)
    
    print("Generated sequence representations:")
    print(f"Shape: {generated_sequences.shape}")
    print(f"Sample values: {generated_sequences[0, :5].cpu().numpy()}")
    
    # The generated sequences represent points in the embedded space
    # In a full implementation, you might use a decoder RNN to generate
    # actual token sequences, but here we show the latent generation concept

## Key Insights

### How Time-Series VAE Works:
1. **Input Processing**: Categorical sequences are embedded using the EmbeddingModel
2. **Sequence Encoding**: RNN processes the embedded sequence to capture temporal patterns
3. **Latent Compression**: Variable-length sequences become fixed-size latent vectors
4. **Reconstruction**: Decoder attempts to recreate the embedded sequence representation
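
Steps 1 and 3 depend on turning codes into integer indices and batching sequences of different lengths. The sketch below shows one common way to do that with a vocabulary and right-padding. The helpers `build_vocab` and `encode_and_pad` are hypothetical; PyHealth's own sequence processor handles this step in the notebook and may differ in detail.

```python
def build_vocab(sequences):
    """Map each distinct code to an integer index; 0 is reserved for padding."""
    vocab = {"<pad>": 0}
    for seq in sequences:
        for code in seq:
            vocab.setdefault(code, len(vocab))
    return vocab

def encode_and_pad(sequences, vocab):
    """Convert codes to indices and right-pad every sequence to the longest length."""
    max_len = max(len(seq) for seq in sequences)
    return [[vocab[c] for c in seq] + [0] * (max_len - len(seq)) for seq in sequences]

visits = [
    ["diabetes", "metformin", "hba1c_test", "insulin"],
    ["asthma", "albuterol"],  # shorter sequence, padded below
]
vocab = build_vocab(visits)
padded = encode_and_pad(visits, vocab)
print(padded)  # [[1, 2, 3, 4], [5, 6, 0, 0]]
```

Once every row has the same length, the batch can be embedded and passed to the RNN, whose final state gives the fixed-size summary regardless of the original sequence length.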

### Medical Applications:
- **Trajectory Analysis**: Understand typical patient progression patterns
- **Synthetic Data**: Generate realistic patient histories for research
- **Anomaly Detection**: Identify unusual treatment sequences
- **Outcome Prediction**: Learn sequence patterns that correlate with outcomes
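
As one illustration of the anomaly detection idea, sequences that the model reconstructs poorly can be flagged with a simple threshold on reconstruction error. The function `flag_anomalies`, the `mean + k * std` rule, and the error values are all hypothetical; the errors below were not produced by the model in this notebook.

```python
import statistics

def flag_anomalies(errors, k=2.0):
    """Return indices whose reconstruction error exceeds mean + k * std."""
    mean = statistics.mean(errors)
    std = statistics.pstdev(errors)
    threshold = mean + k * std
    return [i for i, e in enumerate(errors) if e > threshold]

# Hypothetical per-sequence reconstruction errors for six patients.
errors = [0.10, 0.12, 0.09, 0.11, 0.10, 0.95]
print(flag_anomalies(errors))  # [5]
```

With only a handful of samples a single outlier inflates the standard deviation, so in practice a threshold taken from a held-out set of normal sequences or a percentile rule is often more stable.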

### Differences from Image VAE:
- **Temporal vs Spatial**: Captures time-ordered dependencies instead of spatial patterns
- **Variable Length**: Handles sequences of different lengths
- **Categorical Data**: Works with medical codes, diagnoses, treatments
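- **Generation**: Sampling the latent space yields new sequence representations; decoding them into token sequences would require a sequence decoder such as an RNN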
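
To turn generated vectors back into readable codes, one option is a nearest-neighbor lookup against the embedding table. The sketch below assumes a decoder that emits one vector per time step in the same space as the code embeddings, which the linear decoder in this notebook does not do. The two-dimensional `embedding_table` and the `generated` vectors are made-up values for illustration.

```python
import math

def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

def nearest_code(vector, embedding_table):
    """Return the code whose embedding is most similar to the given vector."""
    return max(embedding_table, key=lambda code: cosine(vector, embedding_table[code]))

# Hypothetical embeddings; a trained model would supply its own table.
embedding_table = {
    "metformin": [1.0, 0.0],
    "albuterol": [0.0, 1.0],
    "sertraline": [-1.0, 0.0],
}
generated = [[0.9, 0.2], [0.1, 0.8]]  # one made-up vector per time step
decoded = [nearest_code(v, embedding_table) for v in generated]
print(decoded)  # ['metformin', 'albuterol']
```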