# Time-Series Modeling with VAE

This notebook demonstrates how to use PyHealth's VAE model in time-series mode to analyze and generate medical sequences. Unlike image VAEs, which model spatial patterns, a time-series VAE models the sequential patterns in patients' medical codes.

## Key Concepts:
- **Sequence Encoding**: RNN-based encoder captures temporal dependencies
- **Latent Representation**: Compressed representation of patient trajectories
- **Sequence Generation**: A linear decoder reconstructs the encoder's fixed-size sequence summary; sampled vectors are mapped back to medical codes

## Applications:
- Patient trajectory modeling and generation
- Medical sequence anomaly detection
- Synthetic data generation for rare conditions
- Treatment pattern analysis

In [1]:
from pyhealth.datasets import split_by_visit, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.models import VAE
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import MortalityPredictionMIMIC4

import torch
import numpy as np
import matplotlib.pyplot as plt



## Load MIMIC4 Demo Dataset

We'll use the MIMIC4 demo dataset to demonstrate time-series VAE on real medical sequences.

**Setup Instructions:**
1. Download MIMIC4 demo data from: https://physionet.org/files/mimic-iv-demo/2.2/
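2. Create a local directory for it, e.g. `data/mimic4_demo` in your project root
3. Extract the downloaded files so that the tables sit in a `hosp/` subdirectory, e.g. `data/mimic4_demo/hosp/diagnoses_icd.csv.gz`
4. Set `ehr_root` below to the directory that contains `hosp/` (not to `hosp/` itself)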
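Before loading, it can help to confirm that the directory layout matches what the loader scans. The sketch below is a hypothetical helper (`missing_tables` and `REQUIRED_TABLES` are our names, not PyHealth's); the file names are taken from the loading log further down, which reads `hosp/<table>.csv.gz` and joins `admissions` automatically.

```python
from pathlib import Path

# Tables read in this notebook; admissions is joined in automatically (see the load log)
REQUIRED_TABLES = ["diagnoses_icd", "procedures_icd", "prescriptions", "admissions"]


def missing_tables(ehr_root: str) -> list:
    """Return the expected hosp/<table>.csv.gz files that are absent under ehr_root."""
    hosp = Path(ehr_root) / "hosp"
    return [name for name in REQUIRED_TABLES if not (hosp / f"{name}.csv.gz").is_file()]


missing = missing_tables("data/mimic4_demo")
print("All tables found" if not missing else f"Missing: {missing}")
```

An empty list means `ehr_root` points at the right level; if every table is reported missing, the path most likely points one level too deep (at `hosp/`) or at the wrong directory.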

In [2]:
# Load MIMIC4 demo dataset
# Download demo data from: https://physionet.org/files/mimic-iv-demo/2.2/
# and place in a local directory, then update ehr_root below
ehr_root = "/home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/"  # Update this path to your local MIMIC4 demo data

dataset = MIMIC4Dataset(
    ehr_root=ehr_root,
    ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
    dev=True,
)

# Set task for time-series modeling
task = MortalityPredictionMIMIC4()
ts_dataset = dataset.set_task(task, num_workers=2)

print("MIMIC4 demo dataset loaded")
print(f"Number of samples: {len(ts_dataset)}")
print(f"Input features: {list(ts_dataset.input_schema.keys())}")
print(f"Output features: {list(ts_dataset.output_schema.keys())}")

Memory usage Starting MIMIC4Dataset init: 812.4 MB
Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions'] (dev mode: True)
Using default EHR config: /home/ubuntu/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml
Memory usage Before initializing mimic4_ehr: 812.4 MB
Initializing mimic4_ehr dataset from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/ (dev mode: False)
Scanning table: diagnoses_icd from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/diagnoses_icd.csv.gz
Joining with table: /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz
Scanning table: procedures_icd from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/procedures_icd.csv.gz
Joining with table: /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz
Scanning table: prescriptions from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/prescrip

Collecting samples for MortalityPredictionMIMIC4 from 2 workers: 100%|██████████| 100/100 [00:00<00:00, 184.14it/s]

Label mortality vocab: {0: 0, 1: 1}

Processing samples: 100%|██████████| 108/108 [00:00<00:00, 20409.32it/s]

Generated 108 samples for task MortalityPredictionMIMIC4
MIMIC4 demo dataset loaded
Number of samples: 108
Input features: ['conditions', 'procedures', 'drugs']
Output features: ['mortality']

## Create and Train Time-Series VAE

The VAE will learn to encode patient trajectories into a latent space and reconstruct them.
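For reference, the standard VAE objective is written out below. We assume here that the time-series mode reconstructs the GRU's sequence summary $h$ (the `y_true` examined later) with a squared-error term, and that the `mu` and `log_std2` layers produce $\mu$ and $\log \sigma^2$; the exact reconstruction term and reduction PyHealth uses are not shown in this notebook. Here $\hat{h}$ is the decoder output for $z$ and $d$ is `hidden_dim`.

$$
\begin{aligned}
z &= \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
\mathcal{L} &= \lVert \hat{h} - h \rVert_2^2 \;-\; \frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
\end{aligned}
$$

The second term is the closed-form KL divergence between $\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and the standard normal prior; it is what makes sampling $z \sim \mathcal{N}(0, I)$ at generation time meaningful.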

In [3]:
# Create time-series VAE model
ts_model = VAE(
    dataset=ts_dataset,
    feature_keys=["conditions"],  # Single sequence feature for VAE
    label_key="mortality",
    mode="binary",  # Task mode of the mortality label; the VAE loss itself is reconstruction-based
    input_type="timeseries",  # Key parameter for time-series mode
    hidden_dim=64,  # Size of code embeddings, GRU state, and latent vector (all 64 in the printed model)
)

print("Time-series VAE created")
print(f"Input type: {ts_model.input_type}")
print(f"Has embedding model: {hasattr(ts_model, 'embedding_model')}")
print(f"Has RNN encoder: {hasattr(ts_model, 'encoder_rnn')}")
print(f"Latent dimension: {ts_model.hidden_dim}")

Time-series VAE created
Input type: timeseries
Has embedding model: True
Has RNN encoder: True
Latent dimension: 64


## Understanding the Time-Series VAE Architecture

The time-series VAE differs from an image VAE in four components:

1. **EmbeddingModel**: Converts categorical sequences to dense vectors
2. **RNN Encoder**: Processes sequential embeddings, capturing temporal patterns
3. **Latent Space**: Fixed-size representation of the entire sequence
4. **Linear Decoder**: Reconstructs the sequence's compressed representation

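Because the GRU reads codes in order, this architecture can in principle pick up ordered patterns such as "hypertension → diabetes → kidney disease". Note that with `feature_keys=["conditions"]` the encoder sees only diagnosis codes, so drug-related patterns such as "diabetes → metformin → insulin" are outside what this configuration can learn.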
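To make the four components concrete, here is a minimal plain-PyTorch sketch. `TinySeqVAE` is a hypothetical class, not PyHealth's implementation: it only mirrors the module names in the model summary the Trainer prints (`encoder_rnn`, `mu`, `log_std2`, `decoder_linear`), assumes `log_std2` outputs $\log \sigma^2$, assumes a summed squared-error reconstruction of the GRU's final hidden state, and ignores padding for simplicity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySeqVAE(nn.Module):
    """Hypothetical re-implementation for illustration, not PyHealth's VAE class."""

    def __init__(self, vocab_size: int, hidden_dim: int = 64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
        self.encoder_rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.mu = nn.Linear(hidden_dim, hidden_dim)
        self.log_std2 = nn.Linear(hidden_dim, hidden_dim)  # assumed to output log(sigma^2)
        self.decoder_linear = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, codes: torch.Tensor) -> dict:
        # codes: [batch, seq_len] integer code indices, 0 = padding
        emb = self.embedding(codes)                  # [B, T, H]
        _, h_n = self.encoder_rnn(emb)               # h_n: [1, B, H]
        h = h_n.squeeze(0)                           # [B, H] sequence summary, used as target
        mu, log_var = self.mu(h), self.log_std2(h)
        eps = torch.randn_like(mu)
        z = mu + eps * torch.exp(0.5 * log_var)      # reparameterization trick
        recon = self.decoder_linear(z)               # [B, H]
        recon_loss = F.mse_loss(recon, h, reduction="sum")
        kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return {"y_prob": recon, "y_true": h, "loss": recon_loss + kl}


torch.manual_seed(0)
model = TinySeqVAE(vocab_size=865)          # 865 rows, like the printed conditions embedding
batch = torch.randint(1, 865, (32, 12))     # 32 patients, 12 codes each (no padding)
out = model(batch)
print(out["y_prob"].shape, out["y_true"].shape)  # torch.Size([32, 64]) torch.Size([32, 64])
```

The output dictionary uses the same keys as the notebook's model (`y_prob`, `y_true`, `loss`), which is why both tensors have shape `[batch, hidden_dim]` rather than `[batch, seq_len, ...]`: the decoder reconstructs one summary vector per patient, not the code sequence itself.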

In [4]:
# Prepare data for training
train_dataloader = get_dataloader(ts_dataset, batch_size=32, shuffle=True)

# Create trainer
trainer = Trainer(
    model=ts_model, 
    device="cuda" if torch.cuda.is_available() else "cpu",
    metrics=[]  # VAE is unsupervised, no classification metrics needed
)

# Train the model (reduced epochs for demo)
print("Training time-series VAE...")
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=train_dataloader,  # Demo only: reusing training data, so "best" is not a held-out score
    epochs=10,
    monitor="loss",
    monitor_criterion="min",
    optimizer_params={"lr": 1e-4},
)

print("Training completed!")

VAE(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(865, 64, padding_idx=0)
    (procedures): Embedding(218, 64, padding_idx=0)
    (drugs): Embedding(486, 64, padding_idx=0)
  ))
  (encoder_rnn): GRU(64, 64, batch_first=True)
  (mu): Linear(in_features=64, out_features=64, bias=True)
  (log_std2): Linear(in_features=64, out_features=64, bias=True)
  (decoder_linear): Linear(in_features=64, out_features=64, bias=True)
)
Metrics: []
Device: cuda

Training time-series VAE...
Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7cd7dc4a4350>
Monitor: loss
Monitor criterion: min
Epochs: 10
Patience: None

Epoch 0 / 10: 100%|██████████| 4/4 [00:00<00:00, 10.86it/s]

--- Train epoch-0, step-4 ---
loss: 600.1951

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 712.44it/s]

--- Eval epoch-0, step-4 ---
loss: 596.4567
New best loss score (596.4567) at epoch-0, step-4

Epoch 1 / 10: 100%|██████████| 4/4 [00:00<00:00, 290.50it/s]

--- Train epoch-1, step-8 ---
loss: 570.5898

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 731.70it/s]

--- Eval epoch-1, step-8 ---
loss: 581.0278
New best loss score (581.0278) at epoch-1, step-8

Epoch 2 / 10: 100%|██████████| 4/4 [00:00<00:00, 288.28it/s]

--- Train epoch-2, step-12 ---
loss: 585.5507

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 743.87it/s]

--- Eval epoch-2, step-12 ---
loss: 591.7324

Epoch 3 / 10: 100%|██████████| 4/4 [00:00<00:00, 291.74it/s]

--- Train epoch-3, step-16 ---
loss: 571.2698

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 740.78it/s]

--- Eval epoch-3, step-16 ---
loss: 569.4090
New best loss score (569.4090) at epoch-3, step-16

Epoch 4 / 10: 100%|██████████| 4/4 [00:00<00:00, 293.35it/s]

--- Train epoch-4, step-20 ---
loss: 575.8144

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 740.13it/s]

--- Eval epoch-4, step-20 ---
loss: 593.6073

Epoch 5 / 10: 100%|██████████| 4/4 [00:00<00:00, 293.48it/s]

--- Train epoch-5, step-24 ---
loss: 554.4583

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 730.30it/s]

--- Eval epoch-5, step-24 ---
loss: 575.9142

Epoch 6 / 10: 100%|██████████| 4/4 [00:00<00:00, 289.78it/s]

--- Train epoch-6, step-28 ---
loss: 567.8894

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 743.44it/s]

--- Eval epoch-6, step-28 ---
loss: 577.6712

Epoch 7 / 10: 100%|██████████| 4/4 [00:00<00:00, 294.15it/s]

--- Train epoch-7, step-32 ---
loss: 590.4861

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 737.75it/s]

--- Eval epoch-7, step-32 ---
loss: 557.5274
New best loss score (557.5274) at epoch-7, step-32

Epoch 8 / 10: 100%|██████████| 4/4 [00:00<00:00, 291.25it/s]

--- Train epoch-8, step-36 ---
loss: 577.2799

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 751.47it/s]

--- Eval epoch-8, step-36 ---
loss: 552.2935
New best loss score (552.2935) at epoch-8, step-36

Epoch 9 / 10: 100%|██████████| 4/4 [00:00<00:00, 293.05it/s]

--- Train epoch-9, step-40 ---
loss: 536.7056

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 743.14it/s]

--- Eval epoch-9, step-40 ---
loss: 547.2967
New best loss score (547.2967) at epoch-9, step-40
Loaded best model

Training completed!
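The log is easier to judge with the numbers side by side. The sketch below copies the per-epoch losses from the log by hand and re-derives which epochs a `monitor_criterion="min"` monitor would flag as a new best (`improving_epochs` is a hypothetical helper). `matplotlib` is imported in the first cell but never used; the same two lists could be passed to `plt.plot` for a loss curve.

```python
# Losses copied by hand from the training log above (10 epochs)
train_loss = [600.1951, 570.5898, 585.5507, 571.2698, 575.8144,
              554.4583, 567.8894, 590.4861, 577.2799, 536.7056]
eval_loss = [596.4567, 581.0278, 591.7324, 569.4090, 593.6073,
             575.9142, 577.6712, 557.5274, 552.2935, 547.2967]


def improving_epochs(losses: list) -> list:
    """Epochs at which a monitor with criterion 'min' records a new best score."""
    best = float("inf")
    hits = []
    for epoch, value in enumerate(losses):
        if value < best:
            best = value
            hits.append(epoch)
    return hits


rel_drop = (eval_loss[0] - min(eval_loss)) / eval_loss[0]
print(improving_epochs(eval_loss))  # [0, 1, 3, 7, 8, 9]
print(round(rel_drop, 3))           # 0.082
```

The flagged epochs match the "New best loss score" lines, and the evaluation loss falls by roughly 8% over ten epochs while moving up and down between epochs. With only four batches per epoch and validation on the training data, this shows the optimizer is making progress, not that the model generalizes.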


## Evaluate Reconstruction Performance

Check how well the VAE reconstructs the original sequences.

**What the outputs represent:**
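- `y_prob`: the decoder's reconstruction of each patient's sequence summary
- `y_true`: the reconstruction target, i.e. the GRU encoder's final hidden state summarizing each patient's diagnosis sequence
- `loss`: the training objective, driven by the reconstruction error between `y_prob` and `y_true` (for a VAE, typically plus a KL term on the latent distribution)

Both are 64-dimensional vectors (`hidden_dim=64`). Because `y_true` is produced by the encoder itself, it is a learned summary of the medical history rather than the raw codes. It can encode temporal patterns such as disease progression (e.g., hypertension → diabetes → kidney disease), but nothing guarantees that any particular pattern is captured.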

In [5]:
# Evaluate on training data
eval_results = trainer.evaluate(train_dataloader)
print("Evaluation Results:")
for metric, value in eval_results.items():
    print(f"{metric}: {value:.4f}")

# Get reconstruction examples from a single batch
ts_model.eval()
data_batch = next(iter(train_dataloader))
with torch.no_grad():
    output = ts_model(**data_batch)

# Single-batch loss: expect it to differ from the dataset-level evaluation loss above
print(f"\nReconstruction shape: {output['y_prob'].shape}")
print(f"Original shape: {output['y_true'].shape}")
print(f"Loss: {output['loss'].item():.4f}")

Evaluation: 100%|██████████| 4/4 [00:00<00:00, 715.26it/s]

Evaluation Results:
loss: 541.9971

Reconstruction shape: torch.Size([32, 64])
Original shape: torch.Size([32, 64])
Loss: 654.7109
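For the anomaly-detection use case listed at the top, one simple approach is to rank patients by how poorly their summary is reconstructed. The sketch below is an assumption-laden illustration: `reconstruction_scores` is a hypothetical helper, the per-patient score is plain squared error (the model's own loss may reduce or weight terms differently), and random stand-in tensors with the printed `[32, 64]` shape replace `output["y_prob"]` and `output["y_true"]`.

```python
import torch


def reconstruction_scores(y_prob: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Per-patient squared reconstruction error, shape [batch]."""
    return ((y_prob - y_true) ** 2).sum(dim=1)


# Stand-in tensors shaped like the printed outputs ([32, 64]); in the notebook,
# pass output["y_prob"] and output["y_true"] instead.
torch.manual_seed(0)
y_true = torch.randn(32, 64)
y_prob = y_true + 0.1 * torch.randn(32, 64)
y_prob[5] += 3.0  # make patient 5 deliberately hard to reconstruct

scores = reconstruction_scores(y_prob, y_true)
top = torch.topk(scores, k=3).indices
print(top.tolist()[0])  # 5
```

Because `y_true` is itself a learned encoding, these scores are only meaningful relative to each other within one trained model; a high score flags a history the model finds unusual, which still needs clinical review.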

## Generate New Medical Sequences

Sample from the latent prior to generate new patient representations, then map them to human-readable medical codes by nearest-neighbor search in the code-embedding space.

In [None]:
import torch.nn.functional as F

# Generate new representations by sampling from the latent prior N(0, I)
ts_model.eval()
with torch.no_grad():
    # Sample random latent vectors
    latent_samples = torch.randn(3, ts_model.hidden_dim).to(ts_model.device)

    # Decode to get sequence representations
    generated_sequences = ts_model.decoder(latent_samples)

    print("Generated sequence representations:")
    print(f"Shape: {generated_sequences.shape}")
    print(f"Sample values: {generated_sequences[0, :5].cpu().numpy()}")

    # Convert embeddings to human-readable medical codes by finding the
    # closest rows of the learned 'conditions' embedding table.
    code_vocab = ts_dataset.input_processors["conditions"].code_vocab  # {code: index}
    idx_to_code = {idx: code for code, idx in code_vocab.items()}  # invert: index -> code
    code_embeddings = ts_model.embedding_model.embedding_layers["conditions"].weight.data  # [vocab_size, embed_dim]

    print("\nConverting to medical codes for generated sequence 0:")

    # The decoder output is a single embedding vector per sample, not a sequence
    seq_embed = generated_sequences[0]  # [embed_dim]

    # Cosine similarity: L2-normalize both sides before taking the dot product
    similarities = F.normalize(code_embeddings, dim=1) @ F.normalize(seq_embed, dim=0)  # [vocab_size]
    similarities[0] = float("-inf")  # row 0 is the padding embedding (padding_idx=0), not a code

    # Get the top 3 most similar codes
    top_k = 3
    top_similarities, top_indices = torch.topk(similarities, k=top_k)

    codes = [idx_to_code.get(idx, "<unknown>") for idx in top_indices.tolist()]
    sims = top_similarities.cpu().numpy()
    print(f"Top {top_k} similar medical codes: {codes}")
    print(f"Similarities: {sims}")

    print("\nNote: these are the codes closest to the generated vector in embedding space.")
    print("Because the decoder outputs one summary vector, they form a set of likely codes, not an ordered sequence.")

Generated sequence representations:
Shape: torch.Size([3, 64])
Sample values: [ 0.53002435  1.3221712  -0.48090675 -0.45167932 -0.49437132]

Converting to medical codes for generated sequence 0:

## Key Insights

### How Time-Series VAE Works:
1. **Input Processing**: Categorical sequences (diagnoses, procedures) are embedded using the EmbeddingModel
2. **Sequence Encoding**: RNN processes the embedded sequence to capture temporal patterns
3. **Latent Compression**: Variable-length sequences become fixed-size latent vectors
4. **Reconstruction**: The decoder attempts to recreate the encoder's fixed-size sequence summary (the GRU hidden state)
5. **Code Generation**: Generated embeddings are mapped back to medical codes using nearest neighbor search

### Medical Applications:
- **Trajectory Analysis**: Understand typical patient progression patterns from real MIMIC4 data
- **Synthetic Data**: Generate realistic patient histories for research and model training
- **Anomaly Detection**: Identify unusual treatment sequences in clinical practice
- **Outcome Prediction**: Use the learned latent vectors as features for outcomes such as mortality (the VAE itself is trained on reconstruction, not on the label)
- **Data Augmentation**: Create additional training samples for underrepresented conditions

### Key Improvements in This Version:
- **Real Data**: Uses MIMIC4 demo dataset instead of synthetic data for more realistic modeling
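- **Multiple Tables**: Loads diagnoses, procedures, and prescriptions; this demo encodes the diagnosis (`conditions`) sequence
- **Human-Readable Output**: Converts generated embeddings back to interpretable medical codes
- **Clinical Context**: Builds samples with the `MortalityPredictionMIMIC4` task, so each trajectory carries a mortality label for downstream analysis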

### Differences from Image VAE:
- **Temporal vs Spatial**: Captures time-ordered dependencies instead of spatial patterns
- **Variable Length**: Handles sequences of different lengths
- **Categorical Data**: Works with medical codes, diagnoses, treatments
- **Generation**: Samples new latent vectors and maps them back to interpretable medical codes
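Handling variable lengths correctly depends on how padding reaches the RNN. Whether PyHealth's encoder packs its inputs is not shown in this notebook; the sketch below only illustrates the general PyTorch technique, `torch.nn.utils.rnn.pack_padded_sequence`, on a toy embedding and GRU (all sizes are arbitrary). With packing, a short patient's summary is taken at its true last code; without it, the padding steps keep updating the hidden state.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
embed = nn.Embedding(10, 8, padding_idx=0)
gru = nn.GRU(8, 8, batch_first=True)

# Two patients: the second has only 2 real codes, padded with 0 to length 4
codes = torch.tensor([[3, 5, 7, 2],
                      [4, 6, 0, 0]])
lengths = torch.tensor([4, 2])


def summarize(codes: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Final GRU hidden state per patient, ignoring padded steps."""
    packed = pack_padded_sequence(embed(codes), lengths, batch_first=True, enforce_sorted=False)
    _, h_n = gru(packed)  # h_n: [1, batch, hidden], taken at each sequence's true end
    return h_n.squeeze(0)


with torch.no_grad():
    packed_summary = summarize(codes, lengths)
    _, h_alone = gru(embed(codes[1:, :2]))  # reference: patient 2 run alone, unpadded
    _, h_naive = gru(embed(codes))          # naive: padding steps also update the state

print(torch.allclose(packed_summary[1], h_alone[0, 0], atol=1e-5))  # True
```

Comparing `h_naive[0, 1]` with `h_alone[0, 0]` shows the naive summary drifting away from the unpadded result, which is why packing (or gathering the hidden state at each sequence's last index) matters when sequence lengths vary widely.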
