# Chest X-Ray Image Generation using GAN

In [None]:
# pip install pyhealth

### Load Libraries

In [None]:
from pyhealth.datasets import split_by_visit, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.datasets import COVID19CXRDataset
from pyhealth.models import VAE
from pyhealth.processors import ImageProcessor
from torchvision import transforms
from pyhealth.processors import SequenceProcessor

import torch
import numpy as np
import matplotlib.pyplot as plt

  from tqdm.autonotebook import trange
  import pkg_resources


## STEP 1: Load the chest X-ray data

We also prepare the data:
- resize images to 128x128
- split train/test/validation

In [None]:
# Download command (uncomment to run)
# !curl -L -o ~/Downloads/covid19-radiography-database.zip https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database
# !unzip ~/Downloads/covid19-radiography-database.zip -d ~/Downloads/COVID-19_Radiography_Dataset

In [2]:
image_size = 128
covid19cxr_path = "~/Downloads/COVID-19_Radiography_Dataset"

# Image processor used by the task below: RGB mode, sized by `image_size`
# (128x128 here) so the samples match what the GAN is built for.
image_processor = ImageProcessor(image_size=image_size, mode="RGB")

# Load the raw dataset from disk and print its summary statistics.
base_dataset = COVID19CXRDataset(covid19cxr_path)
base_dataset.stats()

# Set the task, overriding the processor for the "image" input field.
sample_dataset = base_dataset.set_task(
    input_processors={"image": image_processor},
)

No config path provided, using default config
Initializing covid19_cxr dataset from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset (dev mode: False)
Scanning table: covid19_cxr from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset/covid19_cxr-metadata-pyhealth.csv
Collecting global event dataframe...
Collected dataframe with shape: (21165, 6)
Dataset: covid19_cxr
Dev mode: False
Number of patients: 21165
Number of events: 21165
Setting task COVID19CXRClassification for covid19_cxr base dataset...
Generating samples with 1 worker(s)...


Generating samples for COVID19CXRClassification with 1 worker: 100%|██████████| 21165/21165 [00:08<00:00, 2637.68it/s]

Label disease vocab: {'COVID': 0, 'Lung Opacity': 1, 'Normal': 2, 'Viral Pneumonia': 3}



Processing samples: 100%|██████████| 21165/21165 [01:18<00:00, 270.15it/s]

Generated 21165 samples for task COVID19CXRClassification





In [None]:

# Split the samples 80/10/10 into train/val/test. The seed is fixed so that a
# fresh "Restart & Run All" reproduces the same split membership.
split_seed = 42
train_dataset, val_dataset, test_dataset = split_by_visit(
    sample_dataset, [0.8, 0.1, 0.1], seed=split_seed
)

# Only the training loader is shuffled; val/test keep a stable order.
train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

# Sanity check: pull one training batch and show the shape of its first image.
data = next(iter(train_dataloader))

print(data["image"][0].shape)

# Note: these are dataset (sample) counts, not numbers of batches.
print(
    "loader size: train/val/test",
    len(train_dataset),
    len(val_dataset),
    len(test_dataset),
)

torch.Size([3, 128, 128])
loader size: train/val/test 16932 2116 2117


### STEP 3: Define the GAN model

In [None]:
from pyhealth.models import GAN

# Build the GAN for 3-channel (RGB) inputs. The spatial size is taken from
# `image_size` (the value given to the ImageProcessor when the dataset was
# prepared) instead of a second hard-coded 128, so the two cannot drift apart.
model = GAN(
    input_channel=3,
    input_size=image_size,
    hidden_dim=256,
)

### STEP 4: Train the GAN model adversarially

In [None]:
import torch
from tqdm import tqdm

# Binary cross-entropy between the discriminator's output and 1 ("real") /
# 0 ("fake") targets. BCELoss expects probabilities in [0, 1], so this assumes
# the discriminator ends in a sigmoid -- TODO confirm against the GAN model.
loss = torch.nn.BCELoss()

# One optimizer per network so each step only updates its own parameters.
opt_G = torch.optim.AdamW(model.generator.parameters(), lr=1e-3)
opt_D = torch.optim.AdamW(model.discriminator.parameters(), lr=1e-4)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Enter training mode explicitly: the sampling cell calls model.eval(), so
# re-running this cell afterwards would otherwise train in eval mode.
model.train()

# Per-epoch loss totals (summed over batches, not averaged).
curve_D, curve_G = [], []

for epoch in range(50):
    curve_G.append(0)
    curve_D.append(0)
    for batch in tqdm(train_dataloader):
        real_imgs = batch["image"].to(device)

        # Read the size from the batch itself so the targets always match it.
        batch_size = real_imgs.shape[0]

        # Create the targets directly on the device (no CPU allocation plus
        # host-to-device copy on every step) and reuse them for both updates.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # --- train discriminator ---
        opt_D.zero_grad()

        fake_imgs = model.generate_fake(batch_size, device)

        real_loss = loss(model.discriminator(real_imgs), real_labels)
        # detach(): the discriminator update must not backpropagate into the
        # generator.
        fake_loss = loss(model.discriminator(fake_imgs.detach()), fake_labels)
        loss_D = (real_loss + fake_loss) / 2

        loss_D.backward()
        opt_D.step()

        # --- train generator ---
        # The generator is scored against "real" targets: its loss is low when
        # the discriminator labels the fakes as real.
        opt_G.zero_grad()
        loss_G = loss(model.discriminator(fake_imgs), real_labels)

        loss_G.backward()
        opt_G.step()

        curve_G[-1] += loss_G.item()
        curve_D[-1] += loss_D.item()

    print(f"epoch: {epoch} --- loss of G: {curve_G[-1]}, loss of D: {curve_D[-1]}")

### EXP 2: synthesize random images

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Switch the model to evaluation mode before sampling.
model.eval()

# Generate a single image without tracking gradients and move it to the CPU.
with torch.no_grad():
    fake_imgs = model.generate_fake(1, device).detach().cpu()

# Move axis 0 to the end for imshow (presumably channels-first -> channels-last;
# confirm against the generator's output layout) and clamp values into [0, 1].
sample_hwc = fake_imgs[0].permute(1, 2, 0).clamp(0, 1)

plt.imshow(sample_hwc)  # RGB image
plt.title("Generated Chest X-Ray")
plt.axis('off')
plt.show()