# Chest X-Ray Image Generation using VAE

This notebook shows how to use the PyHealth `VAE` model to generate chest X-ray images, and introduces its new support for conditional generation and time-series inputs.

We use the COVID-19 CXR dataset as a starting point. This dataset is freely available on Kaggle and contains chest X-ray images labeled with one of four classes: COVID, Lung Opacity, Normal, and Viral Pneumonia.

## Download Data

If the data is not already available locally, download and unzip it from Kaggle with the following commands:

In [1]:
# Download command (uncomment to run)
# !curl -L -o ~/Downloads/covid19-radiography-database.zip https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database
# !unzip ~/Downloads/covid19-radiography-database.zip -d ~/Downloads/COVID-19_Radiography_Dataset
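The "already available locally" check can also be scripted so the download only runs when needed. The following is a minimal sketch: the helper name `dataset_is_present` is hypothetical, and treating "directory exists and is non-empty" as "downloaded" is an assumption of this example, not something PyHealth provides.

```python
from pathlib import Path


def dataset_is_present(root: str) -> bool:
    """Return True if `root` is an existing directory with at least one entry.

    Hypothetical helper for illustration; the notebook does not define it.
    """
    path = Path(root).expanduser()
    return path.is_dir() and any(path.iterdir())


root = "~/Downloads/COVID-19_Radiography_Dataset"
if dataset_is_present(root):
    print("Dataset found, skipping download.")
else:
    print("Dataset not found, run the download commands above.")
```

A stricter check could look for specific files, but the exact archive layout is not stated in this notebook, so the sketch does not assume one.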

## Load Data with PyHealth Datasets

Use the `COVID19CXRDataset` class to load this data. For custom datasets, see the `BaseImageDataset` class.

In [2]:
from pyhealth.datasets import COVID19CXRDataset, get_dataloader, split_by_visit
from pyhealth.models import VAE
from pyhealth.processors import ImageProcessor
from pyhealth.trainer import Trainer
from torchvision import transforms

import torch
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Step 1: Load data (adjust `root` to the directory where the dataset was unzipped)
root = "/home/ubuntu/Downloads/COVID-19_Radiography_Dataset"
base_dataset = COVID19CXRDataset(root)

# Step 2: Set task with custom image processing for VAE
image_processor = ImageProcessor(image_size=128, mode="L")  # Resize to 128x128 for VAE
sample_dataset = base_dataset.set_task(input_processors={"image": image_processor})

No config path provided, using default config
Initializing covid19_cxr dataset from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset (dev mode: False)
Scanning table: covid19_cxr from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset/covid19_cxr-metadata-pyhealth.csv
Setting task COVID19CXRClassification for covid19_cxr base dataset...
Generating samples with 1 worker(s)...
Collecting global event dataframe...
Collected dataframe with shape: (21165, 6)


Generating samples for COVID19CXRClassification with 1 worker: 100%|██████████| 21165/21165 [00:08<00:00, 2354.79it/s]

Label disease vocab: {'COVID': 0, 'Lung Opacity': 1, 'Normal': 2, 'Viral Pneumonia': 3}



Processing samples: 100%|██████████| 21165/21165 [01:30<00:00, 233.88it/s]

Generated 21165 samples for task COVID19CXRClassification





AttributeError: 'SampleDataset' object has no attribute 'set_transform'

In [5]:
# Split dataset
train_dataset, val_dataset, test_dataset = split_by_visit(
    sample_dataset, [0.6, 0.2, 0.2]
)
train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

# Check data
data = next(iter(train_dataloader))
print("Data keys:", data.keys())
print("Image shape:", data["image"][0].shape)
print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

Data keys: dict_keys(['image', 'disease'])
Dataset sizes - Train: 12699, Val: 4233, Test: 4233
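The reported sizes follow from cutting the 21,165 samples at the 60% and 80% marks. The sketch below reproduces that arithmetic on plain indices. It is an illustration only: `split_indices` is a hypothetical helper, it assumes each image counts as its own visit, and the actual shuffling and grouping inside `split_by_visit` may differ.

```python
import random


def split_indices(n, percents=(60, 20, 20), seed=0):
    """Shuffle range(n) and cut it into consecutive train/val/test chunks.

    Integer percentages avoid floating-point rounding at the boundaries.
    """
    assert sum(percents) == 100
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    train_end = n * percents[0] // 100
    val_end = n * (percents[0] + percents[1]) // 100
    return indices[:train_end], indices[train_end:val_end], indices[val_end:]


train_idx, val_idx, test_idx = split_indices(21165)
print(len(train_idx), len(val_idx), len(test_idx))  # 12699 4233 4233
```

Because the three chunks are slices of one shuffled list, no sample can appear in more than one split.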


## Basic VAE Training (Image Generation)

Train a standard VAE for unconditional image generation.
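A VAE is typically trained by minimizing a reconstruction error plus the KL divergence between the encoder's Gaussian posterior and a standard normal prior, which is why `kl_divergence` and `mse` appear as metrics below. The following plain-Python sketch computes both terms for one sample. The function names and the unweighted sum are assumptions for illustration; the exact loss used inside the PyHealth `VAE` is not shown in this notebook and may be weighted differently.

```python
import math


def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one latent vector."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))


def mse(x, x_rec):
    """Mean squared reconstruction error over flattened pixels."""
    return sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)


def vae_loss(x, x_rec, mu, logvar):
    """Unweighted sum of reconstruction error and KL term (an assumption)."""
    return mse(x, x_rec) + kl_to_standard_normal(mu, logvar)


# A posterior identical to the prior contributes zero KL.
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0
# Shifting the mean by 1 in one dimension costs 0.5.
print(kl_to_standard_normal([1.0], [0.0]))  # 0.5
```

Monitoring the KL term alone, as the training cell does, tracks how close the posterior is to the prior but says nothing about reconstruction quality, so it is worth reading it together with `mse` and `mae`.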

In [6]:
# Define model
model = VAE(
    dataset=sample_dataset,
    feature_keys=["image"],
    label_key="image",
    mode="regression",
    input_type="image",
    input_channel=1,  # Grayscale images from COVID dataset
    input_size=128,  # Resized for VAE
    hidden_dim=128,
)

# Define trainer
trainer = Trainer(
    model=model,
    device="cuda" if torch.cuda.is_available() else "cpu",
    metrics=["kl_divergence", "mse", "mae"],
)

# Train (reduce epochs for demo)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,  # Reduced for demo
    monitor="kl_divergence",
    monitor_criterion="min",
    optimizer_params={"lr": 1e-3},
)

VAE(
  (encoder1): Sequential(
    (0): ResBlock2D(
      (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ELU(alpha=1.0)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (downsampler): Sequential(
        (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): ResBlock2D(
      (conv1): Conv2d(16, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ELU(al

Epoch 0 / 5:   0%|          | 0/50 [00:00<?, ?it/s]


KeyError: 'path'

In [None]:
# Evaluate
print("Evaluation results:")
eval_results = trainer.evaluate(test_dataloader)
print(eval_results)

## Experiment 1: Real vs Reconstructed Images

In [None]:
# Get real and reconstructed images
X, X_rec, _ = trainer.inference(test_dataloader)

# Plot comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(X[0].reshape(128, 128), cmap="gray")
ax1.set_title("Real Chest X-Ray")
ax1.axis('off')

ax2.imshow(X_rec[0].reshape(128, 128), cmap="gray")
ax2.set_title("Reconstructed by VAE")
ax2.axis('off')

plt.tight_layout()
plt.savefig("chestxray_vae_comparison.png", dpi=150)
plt.show()

## Experiment 2: Random Image Generation

Generate new images by sampling from the latent space.
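Generation differs from training in where the latent vector comes from: during training it is drawn around the encoder's output, while at generation time it is drawn directly from the prior. The sketch below contrasts the two in plain Python. The names `sample_prior` and `reparameterize` are illustrative, and the latent size of 128 simply mirrors the `hidden_dim` used in this notebook.

```python
import math
import random


def sample_prior(latent_dim, rng):
    """Draw z ~ N(0, I); used at generation time, when there is no input image."""
    return [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]


def reparameterize(mu, logvar, rng):
    """Draw z ~ N(mu, diag(exp(logvar))) as mu + sigma * eps; used in training."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


rng = random.Random(0)
z = sample_prior(128, rng)
print(len(z))  # 128
```

Writing the training-time sample as `mu + sigma * eps` keeps the randomness in `eps`, so gradients can flow through `mu` and `logvar`.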

In [None]:
# Generate synthetic images
model = trainer.model
model.eval()

with torch.no_grad():
    # Sample from latent space
    z = torch.randn(1, model.hidden_dim).to(model.device)
    
    # Reshape for decoder (add spatial dims)
    z = z.unsqueeze(2).unsqueeze(3)
    
    # Generate image
    generated = model.decoder(z).detach().cpu().numpy()
    
    # Plot
    plt.figure(figsize=(5, 5))
    plt.imshow(generated[0].reshape(128, 128), cmap="gray")
    plt.title("Generated Chest X-Ray")
    plt.axis('off')
    plt.savefig("chestxray_vae_synthetic.png", dpi=150)
    plt.show()

## New Feature: Conditional VAE

Generate images conditioned on additional features (e.g., diagnosis codes).
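One common way to condition a decoder is to append an encoding of the condition to the latent vector, so that the decoder sees both. The sketch below shows that idea with a multi-hot encoding in plain Python. This is a simplification and an assumption: the PyHealth model in the next cell uses an embedding model for the conditions, and how it combines the embedding with the latent vector is not shown in this notebook. All helper names are hypothetical.

```python
def build_vocab(samples, key="conditions"):
    """Map each distinct condition string to an index (sorted for determinism)."""
    tokens = sorted({tok for s in samples for tok in s[key]})
    return {tok: i for i, tok in enumerate(tokens)}


def multi_hot(tokens, vocab):
    """Return a vector with 1.0 at the index of every token present."""
    vec = [0.0] * len(vocab)
    for tok in tokens:
        vec[vocab[tok]] = 1.0
    return vec


def condition_latent(z, tokens, vocab):
    """Concatenate latent z with the multi-hot condition vector."""
    return list(z) + multi_hot(tokens, vocab)


samples = [
    {"conditions": ["COVID-19", "pneumonia"]},
    {"conditions": ["normal"]},
]
vocab = build_vocab(samples)
print(vocab)  # {'COVID-19': 0, 'normal': 1, 'pneumonia': 2}
print(condition_latent([0.1, -0.3], ["normal"], vocab))  # [0.1, -0.3, 0.0, 1.0, 0.0]
```

At generation time the same latent sample can be paired with different condition vectors to request different kinds of output.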

In [None]:
# Create dataset with conditional features
samples_with_conditions = [
    {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "path": torch.rand(3, 128, 128),  # Dummy image
        "conditions": ["COVID-19", "pneumonia"],
        "label": 0,
    },
    {
        "patient_id": "patient-1",
        "visit_id": "visit-1",
        "path": torch.rand(3, 128, 128),
        "conditions": ["normal"],
        "label": 1,
    },
]

from pyhealth.datasets import SampleDataset

cond_dataset = SampleDataset(
    samples=samples_with_conditions,
    input_schema={"path": "tensor", "conditions": "sequence"},
    output_schema={"label": "binary"},
    dataset_name="conditional_demo",
)

# Conditional VAE model
cond_model = VAE(
    dataset=cond_dataset,
    feature_keys=["path"],
    label_key="label",
    mode="binary",
    input_type="image",
    input_channel=3,
    input_size=128,
    hidden_dim=64,
    conditional_feature_keys=["conditions"],  # New parameter
)

print("Conditional VAE created with embedding model for conditions")
print(f"Has embedding model: {hasattr(cond_model, 'embedding_model')}")

## New Feature: Time-Series VAE

Use VAE for time-series data reconstruction and generation.
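Sequential inputs such as lists of codes have variable length, so they are usually mapped to integer indices and padded to a common length before being embedded and passed to a recurrent encoder. The following is a minimal sketch of that preprocessing in plain Python. The helper names and the choice of index 0 for padding are assumptions; PyHealth's own sequence processor handles this internally and may use a different convention.

```python
PAD = 0  # index reserved for padding (an assumption of this sketch)


def build_token_index(sequences):
    """Assign indices starting at 1 in order of first appearance; 0 is padding."""
    index = {}
    for seq in sequences:
        for tok in seq:
            index.setdefault(tok, len(index) + 1)
    return index


def encode_and_pad(sequences, index):
    """Convert tokens to indices and right-pad every sequence to the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [[index[t] for t in seq] + [PAD] * (max_len - len(seq)) for seq in sequences]


visits = [["diagnosis1", "diagnosis2", "procedure1"], ["diagnosis3"]]
index = build_token_index(visits)
print(encode_and_pad(visits, index))  # [[1, 2, 3], [4, 0, 0]]
```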

In [None]:
# Create time-series dataset
ts_samples = [
    {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "visits": ["diagnosis1", "diagnosis2", "procedure1"],
        "label": 1.0,
    },
    {
        "patient_id": "patient-1",
        "visit_id": "visit-1",
        "visits": ["diagnosis3"],
        "label": 0.5,
    },
]

ts_dataset = SampleDataset(
    samples=ts_samples,
    input_schema={"visits": "sequence"},
    output_schema={"label": "regression"},
    dataset_name="timeseries_demo",
)

# Time-series VAE model
ts_model = VAE(
    dataset=ts_dataset,
    feature_keys=["visits"],
    label_key="label",
    mode="regression",
    input_type="timeseries",  # New parameter
    hidden_dim=64,
)

print("Time-series VAE created")
print(f"Input type: {ts_model.input_type}")
print(f"Has embedding model: {hasattr(ts_model, 'embedding_model')}")
print(f"Has RNN encoder: {hasattr(ts_model, 'encoder_rnn')}")

## Summary

The enhanced VAE now supports:
- **Image generation**: Unconditional and conditional
- **Time-series modeling**: For sequential medical data
- **Flexible embeddings**: Integrated with PyHealth's `EmbeddingModel`

Key new parameters:
- `input_type`: 'image' or 'timeseries'
- `conditional_feature_keys`: List of keys for conditional generation

This enables more sophisticated generative models for medical data analysis and synthesis.