# Chest X-Ray image generation using VAE

This notebook illustrates how to use the PyHealth `VAE` model to generate chest X-ray images.

We take the COVID-19 CXR dataset (the COVID-19 Radiography Database) as a starting point. It is freely available on Kaggle and contains chest X-ray images of COVID-19, lung opacity, viral pneumonia, and normal cases.

## Download data

The data is hosted on Kaggle. If it is not already available locally, download it with the following command and unzip the archive into `~/Downloads`:

In [None]:
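# Uncomment to download the archive and extract it next to the download (IPython `!` shell escape)
# !curl -L -o ~/Downloads/covid19-radiography-database.zip https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database
# !unzip -q ~/Downloads/covid19-radiography-database.zip -d ~/Downloads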
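If you prefer to stay in Python, the sketch below extracts the downloaded archive only when the dataset folder is missing. It assumes the archive sits at `~/Downloads/covid19-radiography-database.zip` (the path used by the `curl` command) and that it unpacks into a `COVID-19_Radiography_Dataset` folder; `extract_if_missing` is a hypothetical helper, not part of PyHealth.

```python
import zipfile
from pathlib import Path


def extract_if_missing(zip_path, dest_dir, expected_subdir):
    """Extract zip_path into dest_dir unless dest_dir/expected_subdir already exists.

    Hypothetical helper; returns the path of the dataset root folder.
    """
    dest_dir = Path(dest_dir).expanduser()
    root = dest_dir / expected_subdir
    if root.exists():
        return root  # already extracted, nothing to do
    zip_path = Path(zip_path).expanduser()
    if not zip_path.exists():
        raise FileNotFoundError(f"{zip_path} not found; download it first")
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)  # creates dest_dir and subfolders as needed
    return root


# Example with the (assumed) paths used in this notebook:
# root = extract_if_missing("~/Downloads/covid19-radiography-database.zip",
#                           "~/Downloads", "COVID-19_Radiography_Dataset")
```

Running the helper twice is harmless: the second call sees the folder and returns immediately.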

## Load data with PyHealth Datasets

Use the COVID19CXRDataset to load this data. For custom datasets, see the `BaseImageDataset` class.

In [None]:
from pyhealth.datasets import split_by_visit, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.datasets import COVID19CXRDataset
from pyhealth.models import VAE
from torchvision import transforms

import torch
import numpy as np


In [None]:
# STEP 1: load image data
root = "/home/ubuntu/Downloads/COVID-19_Radiography_Dataset"  # adjust to your local path
base_dataset = COVID19CXRDataset(root)

# STEP 2: set task (default: COVID19CXRClassification)
sample_dataset = base_dataset.set_task()

# pixel intensities are not rescaled here; this transform only fixes the channel count and resizes to 128 x 128
transform = transforms.Compose([
    # keep 3-channel images; otherwise replicate the first channel into 3 channels
    transforms.Lambda(lambda x: x if x.shape[0] == 3 else x[:1].repeat(3, 1, 1)),
    transforms.Resize((128, 128)),
])

def encode(sample):
    # apply the transform to the image tensor stored under the "path" key
    sample["path"] = transform(sample["path"])
    return sample

No config path provided, using default config
Initializing covid19_cxr dataset from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset (dev mode: False)
Scanning table: covid19_cxr from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset/covid19_cxr-metadata-pyhealth.csv
Setting task COVID19CXRClassification for covid19_cxr base dataset...
Generating samples with 1 worker(s)...
Collecting global event dataframe...
Collected dataframe with shape: (21165, 6)
Generating samples for COVID19CXRClassification with 1 worker: 100%|██████████| 21165/21165 [00:08<00:00, 2498.29it/s]
Label disease vocab: {'COVID': 0, 'Lung Opacity': 1, 'Normal': 2, 'Viral Pneumonia': 3}
Processing samples: 100%|██████████| 21165/21165 [01:26<00:00, 246.09it/s]
Generated 21165 samples for task COVID19CXRClassification

In [None]:
# register `encode` so that it is applied to the samples
sample_dataset.set_transform(encode)

AttributeError: 'SampleDataset' object has no attribute 'set_transform'

In [None]:
# split dataset
train_dataset, val_dataset, test_dataset = split_by_visit(
    sample_dataset, [0.6, 0.2, 0.2]
)
train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

data = next(iter(train_dataloader))
print(data)

print(data["path"][0].shape)  # (C, H, W) of the first image in the batch

print(
    "dataset size: train/val/test",
    len(train_dataset),
    len(val_dataset),
    len(test_dataset),
)
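Conceptually, a 60/20/20 ratio split shuffles indices and cuts them into consecutive chunks. The sketch below illustrates that arithmetic on plain sample indices using the 21,165 samples generated above; `ratio_split_indices` is a hypothetical helper, not PyHealth's implementation (`split_by_visit` groups samples by visit, and its seeding and rounding may differ).

```python
import numpy as np


def ratio_split_indices(n, ratios, seed=0):
    """Shuffle range(n) and cut it into consecutive chunks sized by `ratios`.

    Hypothetical helper for illustration; the last chunk absorbs rounding leftovers.
    """
    assert abs(sum(ratios) - 1.0) < 1e-8, "ratios must sum to 1"
    rng = np.random.default_rng(seed)
    index = rng.permutation(n)  # shuffled 0..n-1
    sizes = [round(n * r) for r in ratios[:-1]]
    cuts = np.cumsum(sizes)  # boundaries between chunks
    return np.split(index, cuts)


train_idx, val_idx, test_idx = ratio_split_indices(21165, [0.6, 0.2, 0.2])
print(len(train_idx), len(val_idx), len(test_idx))  # prints: 12699 4233 4233
```

Every index lands in exactly one chunk, so the three subsets are disjoint and together cover the whole dataset.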

In [None]:
# STEP 3: define model
model = VAE(
    dataset=sample_dataset,
    input_channel=3,
    input_size=128,
    feature_keys=["path"],
    label_key="path",
    mode="regression",
    input_type="image",
    hidden_dim=128,  # size of the latent vector
)

# STEP 4: define trainer
device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = Trainer(model=model, device=device, metrics=["kl_divergence", "mse", "mae"])
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=10,
    monitor="kl_divergence",
    monitor_criterion="min",
    optimizer_params={"lr": 1e-3},
)
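For intuition about what is being optimized, a VAE is usually trained on the negative evidence lower bound: a reconstruction error plus the KL divergence between the encoder's Gaussian q(z|x) = N(mu, diag(sigma^2)) and a standard normal prior, with z drawn via the reparameterization trick z = mu + sigma * eps. The NumPy sketch below shows this standard objective under the assumption of a squared-error reconstruction term; the function names are illustrative, and PyHealth's `VAE` may compute or weight these terms differently.

```python
import numpy as np


def reparameterize(mu, log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I) (the reparameterization trick)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1)


def vae_loss(x, x_rec, mu, log_var, beta=1.0):
    """Per-image squared error plus beta * KL, averaged over the batch."""
    rec = np.sum((x - x_rec) ** 2, axis=tuple(range(1, x.ndim)))
    return np.mean(rec + beta * kl_to_standard_normal(mu, log_var))


rng = np.random.default_rng(0)
mu = np.zeros((2, 128))       # batch of 2, latent size 128 as with hidden_dim above
log_var = np.zeros((2, 128))  # unit variance
z = reparameterize(mu, log_var, rng)
print(z.shape, kl_to_standard_normal(mu, log_var))
```

The KL term is zero only when the encoder outputs exactly the prior (mu = 0, log_var = 0); it is what keeps latent codes close to N(0, I), which in turn makes decoding random normal vectors meaningful.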

In [None]:
# STEP 5: evaluate on the test set (uncomment to run)
# print(trainer.evaluate(test_dataloader))

In [None]:
import matplotlib.pyplot as plt

# EXP 1: compare a real chest X-ray with its reconstruction
X, X_rec, _ = trainer.inference(test_dataloader)  # ground truth, reconstruction, mean loss

# images have 3 (replicated) channels, so show only the first one
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(X[0].reshape(-1, 128, 128)[0], cmap="gray")
plt.title("original")
plt.subplot(1, 2, 2)
plt.imshow(X_rec[0].reshape(-1, 128, 128)[0], cmap="gray")
plt.title("reconstruction")
plt.savefig("chestxray_vae_comparison.png")
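Looking at a single pair can be misleading, so it may help to score every test image. The sketch below computes a per-image mean squared error; `per_image_mse` is a hypothetical helper, and it assumes `X` and `X_rec` come back from `trainer.inference` as arrays with one image per row. It is demonstrated on random arrays shaped like 3 x 128 x 128 images.

```python
import numpy as np


def per_image_mse(x, x_rec):
    """Mean squared error for each image, given arrays of shape (N, ...)."""
    x = np.asarray(x, dtype=np.float64).reshape(len(x), -1)
    x_rec = np.asarray(x_rec, dtype=np.float64).reshape(len(x_rec), -1)
    return np.mean((x - x_rec) ** 2, axis=1)


# Toy check: four random "images", one of which is reconstructed badly
rng = np.random.default_rng(0)
x = rng.random((4, 3, 128, 128))
x_rec = x.copy()
x_rec[2] += 0.1  # corrupt one "reconstruction"
errors = per_image_mse(x, x_rec)
print(errors.argmax())  # prints: 2
```

With the notebook's outputs, `per_image_mse(X, X_rec).argmax()` would give the index of the worst reconstruction, which you could then plot the same way as above.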

In [None]:
# EXP 2: decode a random latent vector into a synthetic image
model = trainer.model

model.eval()
with torch.no_grad():
    # sample z ~ N(0, I) with the same size as hidden_dim (128)
    x = np.random.normal(0, 1, 128)
    x = x.astype(np.float32)
    x = torch.from_numpy(x).to(trainer.device)
    rec = model.decoder(x).detach().cpu().numpy()
    # keep the first of the 3 output channels
    rec = rec.reshape(-1, 128, 128)[0]
    plt.figure()
    plt.imshow(rec, cmap="gray")
    plt.savefig("chestxray_vae_synthetic.png")
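To inspect more than one synthetic image at a time, one option is to decode several latent vectors and tile the first channel of each decoded image into a single grid. The helper below is a hypothetical NumPy sketch that only handles the tiling (torchvision's `make_grid` is an alternative); it assumes the decoded images have already been reduced to shape (N, 128, 128).

```python
import numpy as np


def tile_images(images, ncols, pad=2, pad_value=0.0):
    """Tile N images of shape (N, H, W) into one 2-D array with `ncols` columns."""
    images = np.asarray(images)
    n, h, w = images.shape
    nrows = -(-n // ncols)  # ceiling division
    grid = np.full((nrows * (h + pad) + pad, ncols * (w + pad) + pad),
                   pad_value, dtype=images.dtype)
    for i, img in enumerate(images):
        r, c = divmod(i, ncols)  # row and column of this tile
        top = pad + r * (h + pad)
        left = pad + c * (w + pad)
        grid[top:top + h, left:left + w] = img
    return grid


tiles = np.random.default_rng(0).random((6, 128, 128))  # stand-ins for decoded images
print(tile_images(tiles, ncols=3).shape)  # prints: (262, 392)
```

For example, `plt.imshow(tile_images(samples, ncols=4), cmap="gray")` would show a batch at once, where `samples` is a hypothetical stack of decoded images.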