# Chest X-Ray Image Generation using VAE

This notebook illustrates how to use the VAE module to generate X-Ray images, including the new features for conditional generation and time-series support.

We use the COVID-19 CXR dataset as a starting point. This dataset is freely available on Kaggle and contains chest X-ray images from COVID-19 patients.

## Download Data

Data is available from Kaggle. If it is not already available locally, download it with the following command:

In [None]:
# Download command (uncomment to run)
# !curl -L -o ~/Downloads/covid19-radiography-database.zip https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database
# !unzip ~/Downloads/covid19-radiography-database.zip -d ~/Downloads/COVID-19_Radiography_Dataset

## Load Data with PyHealth Datasets

Use `COVID19CXRDataset` to load this data. For custom datasets, see the `BaseImageDataset` class.

In [None]:
from pyhealth.datasets import COVID19CXRDataset, get_dataloader, split_by_visit
from pyhealth.trainer import Trainer
from pyhealth.models import VAE
from torchvision import transforms

import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Step 1: Load data (adjust root to wherever the dataset was extracted)
root = "/home/ubuntu/Downloads/COVID-19_Radiography_Dataset"
base_dataset = COVID19CXRDataset(root)

# Step 2: Set task
sample_dataset = base_dataset.set_task()

# Transformations: ensure three channels and resize to 128x128
# (these steps do not rescale pixel intensities)
transform = transforms.Compose([
    transforms.Lambda(lambda x: x if x.shape[0] == 3 else x.repeat(3, 1, 1)),  # Repeat a single channel to get 3 channels
    transforms.Resize((128, 128)),
])

def encode(sample):
    sample["path"] = transform(sample["path"])
    return sample

sample_dataset.set_transform(encode)
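
If the loaded pixel values are not already in [0, 1], a rescaling step could be added to the transform. The following is a minimal NumPy sketch of min-max scaling; `minmax_scale` is a hypothetical helper, not part of PyHealth or torchvision, and the random array only stands in for a real image.

```python
import numpy as np

def minmax_scale(image: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale an image array to the [0, 1] range (hypothetical helper)."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    # eps guards against division by zero for constant images
    return (image - lo) / (hi - lo + eps)

# Dummy 8-bit single-channel image standing in for a chest X-ray
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(1, 128, 128)).astype(np.uint8)
scaled = minmax_scale(raw)
print("range after scaling:", scaled.min(), scaled.max())
```

Whether such a step is needed depends on how the dataset loader decodes the images, so it is worth printing the minimum and maximum of one sample before adding it.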

In [None]:
# Split dataset
train_dataset, val_dataset, test_dataset = split_by_visit(
    sample_dataset, [0.6, 0.2, 0.2]
)
train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

# Check data
data = next(iter(train_dataloader))
print("Data keys:", data.keys())
print("Image shape:", data["path"][0].shape)
print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

## Basic VAE Training (Image Generation)

Train a standard VAE for unconditional image generation.
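
A VAE objective is commonly written as a reconstruction term plus the KL divergence between the approximate posterior and a standard normal prior. The sketch below is a NumPy illustration of that standard formulation, not PyHealth's implementation. The function names (`reparameterize`, `kl_divergence`, `vae_loss`), the log-variance parameterization, and the sum-of-squares reconstruction term are assumptions made for illustration.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_divergence(mu, log_var):
    """KL(N(mu, sigma^2) || N(0, I)): summed over latent dims, averaged over the batch."""
    kl_per_sample = -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=1)
    return float(kl_per_sample.mean())

def vae_loss(x, x_rec, mu, log_var):
    """Per-sample sum of squared errors (batch mean) plus the KL term."""
    rec = float(np.sum((x - x_rec) ** 2, axis=1).mean())
    return rec + kl_divergence(mu, log_var)

rng = np.random.default_rng(0)
mu = np.zeros((4, 8))        # batch of 4, latent dimension 8
log_var = np.zeros((4, 8))   # unit variance
z = reparameterize(mu, log_var, rng)
x = rng.random((4, 16))      # dummy flattened inputs

print("z shape:", z.shape)
print("KL when the posterior equals the prior:", kl_divergence(mu, log_var))
print("KL when the mean is shifted to 1:", kl_divergence(np.ones((4, 8)), log_var))
```

The KL term vanishes when the posterior matches the prior. A very low `kl_divergence` during training can therefore also mean that the encoder is ignoring its input, so it is best read alongside the reconstruction metrics.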

In [None]:
# Define model
model = VAE(
    dataset=sample_dataset,
    feature_keys=["path"],
    label_key="path",
    mode="regression",
    input_type="image",
    input_channel=3,
    input_size=128,
    hidden_dim=128,
)

# Define trainer
trainer = Trainer(model=model, device="cuda" if torch.cuda.is_available() else "cpu", 
                 metrics=["kl_divergence", "mse", "mae"])

# Train (reduce epochs for demo)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,  # Reduced for demo
    monitor="kl_divergence",
    monitor_criterion="min",
    optimizer_params={"lr": 1e-3},
)

In [None]:
# Evaluate
print("Evaluation results:")
eval_results = trainer.evaluate(test_dataloader)
print(eval_results)
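
For reference, the `mse` and `mae` metric names are read here as mean squared error and mean absolute error between the input and its reconstruction. The sketch below computes both with NumPy on toy arrays. The helper functions are illustrative, and the exact reduction used by the PyHealth metrics is an assumption.

```python
import numpy as np

def mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean squared error over all elements."""
    return float(np.mean((y_true - y_pred) ** 2))

def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute error over all elements."""
    return float(np.mean(np.abs(y_true - y_pred)))

# Two toy "images" with two pixels each
y_true = np.array([[0.0, 0.5], [1.0, 0.25]])
y_pred = np.array([[0.0, 1.0], [0.5, 0.25]])

print("mse:", mse(y_true, y_pred))
print("mae:", mae(y_true, y_pred))
```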

## Experiment 1: Real vs Reconstructed Images
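
When plotting, the memory layout of the arrays matters. The transform in this notebook produces channel-first tensors of shape `(3, 128, 128)`, so a flattened image has to be restored to `(C, H, W)` before a channel is selected. The sketch below uses a tiny NumPy stand-in to show the difference. It assumes the inference output keeps this channel-first ordering.

```python
import numpy as np

# Tiny stand-in image: 3 channels of 2x2 pixels, each channel filled with its index
image_chw = np.stack([np.full((2, 2), c, dtype=np.float32) for c in range(3)])
flat = image_chw.reshape(-1)

# Correct for channel-first data: restore (C, H, W), then take channel 0
first_channel = flat.reshape(3, 2, 2)[0]

# Incorrect for channel-first data: reading the buffer as (H, W, C) mixes channels
scrambled = flat.reshape(2, 2, 3)[:, :, 0]

# To display all channels with imshow, move the channel axis to the end instead
image_hwc = image_chw.transpose(1, 2, 0)

print("first channel:\n", first_channel)
print("scrambled:\n", scrambled)
print("HWC shape:", image_hwc.shape)
```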

In [None]:
# Get real and reconstructed images
X, X_rec, _ = trainer.inference(test_dataloader)

# Plot comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
# Images are channel-first (C, H, W); restore that layout and show the first channel
ax1.imshow(X[0].reshape(3, 128, 128)[0], cmap="gray")
ax1.set_title("Real Chest X-Ray")
ax1.axis('off')

ax2.imshow(X_rec[0].reshape(3, 128, 128)[0], cmap="gray")
ax2.set_title("Reconstructed by VAE")
ax2.axis('off')

plt.tight_layout()
plt.savefig("chestxray_vae_comparison.png", dpi=150)
plt.show()

## Experiment 2: Random Image Generation

Generate new images by sampling from the latent space.
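
To generate several images at once, a batch of latent vectors can be drawn in one call. The sketch below uses NumPy and only prepares the decoder input. It assumes the latent dimension equals the model's `hidden_dim` (128 in this notebook) and that the decoder expects two trailing spatial dimensions of size 1.

```python
import numpy as np

latent_dim = 128   # assumed equal to hidden_dim of the trained model
n_samples = 4      # number of images to generate

rng = np.random.default_rng(42)

# Each row is one draw from the standard normal prior N(0, I)
z = rng.standard_normal((n_samples, latent_dim)).astype(np.float32)

# Add two singleton spatial dimensions: (N, D) -> (N, D, 1, 1)
z_spatial = z[:, :, None, None]

print("decoder input shape:", z_spatial.shape)
```

The array could then be converted with `torch.from_numpy(z_spatial)` and moved to the model's device before being passed to the decoder.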

In [None]:
# Generate synthetic images
model = trainer.model
model.eval()

with torch.no_grad():
    # Sample from latent space
    z = torch.randn(1, 128).to(model.device)
    
    # Reshape for decoder (add spatial dims)
    z = z.unsqueeze(2).unsqueeze(3)
    
    # Generate image
    generated = model.decoder(z).detach().cpu().numpy()
    
    # Plot
    plt.figure(figsize=(5, 5))
    # Output is channel-first (C, H, W); show the first channel
    plt.imshow(generated[0].reshape(3, 128, 128)[0], cmap="gray")
    plt.title("Generated Chest X-Ray")
    plt.axis('off')
    plt.savefig("chestxray_vae_synthetic.png", dpi=150)
    plt.show()

## New Feature: Conditional VAE

Generate images conditioned on additional features (e.g., diagnosis codes).
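
One common way to condition a VAE is to append an encoding of the condition to the latent vector before decoding. The sketch below illustrates the idea with a multi-hot encoding in NumPy. The vocabulary, the `multi_hot` helper, and the use of concatenation are assumptions made for illustration, and PyHealth's VAE may combine the condition embedding differently.

```python
import numpy as np

# Hypothetical condition vocabulary for illustration
vocab = ["COVID-19", "pneumonia", "normal"]
index = {name: i for i, name in enumerate(vocab)}

def multi_hot(conditions):
    """Encode a list of condition names as a 0/1 vector over the vocabulary."""
    vec = np.zeros(len(vocab), dtype=np.float32)
    for name in conditions:
        vec[index[name]] = 1.0
    return vec

latent_dim = 64
rng = np.random.default_rng(0)
z = rng.standard_normal(latent_dim).astype(np.float32)

cond = multi_hot(["COVID-19", "pneumonia"])

# The decoder would receive the latent vector and the condition together
decoder_input = np.concatenate([z, cond])

print("condition vector:", cond)
print("decoder input shape:", decoder_input.shape)
```

Keeping `z` fixed and changing only `cond` is the usual way to check whether the decoder actually responds to the condition.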

In [None]:
# Create dataset with conditional features
samples_with_conditions = [
    {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "path": torch.rand(3, 128, 128),  # Dummy image
        "conditions": ["COVID-19", "pneumonia"],
        "label": 0,
    },
    {
        "patient_id": "patient-1",
        "visit_id": "visit-1",
        "path": torch.rand(3, 128, 128),
        "conditions": ["normal"],
        "label": 1,
    },
]

from pyhealth.datasets import SampleDataset

cond_dataset = SampleDataset(
    samples=samples_with_conditions,
    input_schema={"path": "tensor", "conditions": "sequence"},
    output_schema={"label": "binary"},
    dataset_name="conditional_demo",
)

# Conditional VAE model
cond_model = VAE(
    dataset=cond_dataset,
    feature_keys=["path"],
    label_key="label",
    mode="binary",
    input_type="image",
    input_channel=3,
    input_size=128,
    hidden_dim=64,
    conditional_feature_keys=["conditions"],  # New parameter
)

print("Conditional VAE created with embedding model for conditions")
print(f"Has embedding model: {hasattr(cond_model, 'embedding_model')}")

## New Feature: Time-Series VAE

Use VAE for time-series data reconstruction and generation.
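
Sequences of codes have different lengths, so they are usually mapped to integer indices and padded before being embedded and fed to a recurrent encoder. The sketch below shows that preprocessing step in NumPy under simple assumptions: index 0 is reserved for padding and the vocabulary is built from the samples themselves. It is not PyHealth's tokenizer.

```python
import numpy as np

sequences = [
    ["diagnosis1", "diagnosis2", "procedure1"],
    ["diagnosis3"],
]

# Build a vocabulary; index 0 is reserved for padding
tokens = sorted({code for seq in sequences for code in seq})
token_to_id = {tok: i + 1 for i, tok in enumerate(tokens)}

max_len = max(len(seq) for seq in sequences)
padded = np.zeros((len(sequences), max_len), dtype=np.int64)
mask = np.zeros((len(sequences), max_len), dtype=bool)

for row, seq in enumerate(sequences):
    ids = [token_to_id[tok] for tok in seq]
    padded[row, :len(ids)] = ids   # real tokens first, zeros after
    mask[row, :len(ids)] = True    # True marks real (non-padding) positions

print("padded ids:\n", padded)
print("mask:\n", mask)
```

The mask lets a downstream encoder ignore padded positions when summarizing each sequence.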

In [None]:
# Create time-series dataset
ts_samples = [
    {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "visits": ["diagnosis1", "diagnosis2", "procedure1"],
        "label": 1.0,
    },
    {
        "patient_id": "patient-1",
        "visit_id": "visit-1",
        "visits": ["diagnosis3"],
        "label": 0.5,
    },
]

ts_dataset = SampleDataset(
    samples=ts_samples,
    input_schema={"visits": "sequence"},
    output_schema={"label": "regression"},
    dataset_name="timeseries_demo",
)

# Time-series VAE model
ts_model = VAE(
    dataset=ts_dataset,
    feature_keys=["visits"],
    label_key="label",
    mode="regression",
    input_type="timeseries",  # New parameter
    hidden_dim=64,
)

print("Time-series VAE created")
print(f"Input type: {ts_model.input_type}")
print(f"Has embedding model: {hasattr(ts_model, 'embedding_model')}")
print(f"Has RNN encoder: {hasattr(ts_model, 'encoder_rnn')}")

## Summary

The enhanced VAE now supports:
- **Image generation**: Unconditional and conditional
- **Time-series modeling**: For sequential medical data
- **Flexible embeddings**: Integrated with PyHealth's EmbeddingModel

Key new parameters:
- `input_type`: 'image' or 'timeseries'
- `conditional_feature_keys`: List of keys for conditional generation

This enables more sophisticated generative models for medical data analysis and synthesis.