In [None]:
!pip install medigan
!pip install torchxrayvision

In [None]:
# Mount Google Drive at /content/drive so the project files under MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

#!pip install medigan
#!pip install torchxrayvision
import torch
import numpy as np
import matplotlib.pyplot as plt
from medigan import Generators
from torch.utils.data import DataLoader
import torchvision
from PIL import Image
import torchxrayvision as xrv
import torchvision.transforms as T
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
import os
# Project-local PGGAN code that lives on the mounted Drive.
# NOTE(review): importing `drive.MyDrive...` as a package presumably relies on the
# working directory being the parent of the mount point (/content) — confirm.
from drive.MyDrive.ECE661_Project.PGGAN_CHEST_XRAY.model20.network import ProGenerator, ProDiscriminator, ProGAN
from drive.MyDrive.ECE661_Project.PGGAN_CHEST_XRAY.model20.configuration import hparams
import torchvision.transforms as transforms
import torch.autograd as autograd

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Ask medigan for a one-sample dataloader of model 20 with install_dependencies=True.
# NOTE(review): this call is presumably made for its side effect of fetching the
# pretrained PGGAN files and installing the model's dependencies — confirm.
# The `dataloader` name bound here is reassigned later by the CheXpert data-loading cell.
gens = Generators()
dataloader = gens.get_as_torch_dataloader(
    model_id=20,
    num_samples=1,
    install_dependencies=True,
    prefetch_factor=None
)

# Same project imports as in the imports cell, repeated after the medigan call.
# NOTE(review): presumably repeated because they only succeed once the dependencies
# above are installed — confirm, and keep a single copy if not needed.
from drive.MyDrive.ECE661_Project.PGGAN_CHEST_XRAY.model20.network import ProGenerator, ProDiscriminator, ProGAN
from drive.MyDrive.ECE661_Project.PGGAN_CHEST_XRAY.model20.configuration import hparams

In [None]:
class XrayDataset(Dataset):
    """In-memory dataset of single-channel images paired with per-image labels.

    Args:
        images_raw: Array-like of shape [N, H, W]. A channel axis is inserted so
            items are [1, H, W]. If the largest value exceeds 1.0 the whole
            array is divided by 255; otherwise values are kept as given.
        labels: Array-like whose first dimension is N (one row per image).
        transform: Optional callable applied to each image tensor on access.

    Raises:
        ValueError: If ``images_raw`` and ``labels`` do not contain the same
            number of samples.
    """

    def __init__(self, images_raw, labels, transform=None):
        # as_tensor accepts arrays and tensors alike and skips the extra copy
        # torch.tensor() would make when the data is already float32. In that
        # case the stored images may share memory with the caller's array.
        images = torch.as_tensor(images_raw, dtype=torch.float32).unsqueeze(1)  # [N,1,H,W]
        # max() raises on an empty tensor, so only inspect the value range when
        # there is at least one element.
        if images.numel() > 0 and images.max() > 1.0:
            images = images.div(255.0)

        labels = torch.as_tensor(labels, dtype=torch.float32)
        # Fail early instead of surfacing as an IndexError (or silently unused
        # labels) in the middle of an epoch.
        if labels.dim() == 0 or labels.size(0) != images.size(0):
            raise ValueError(
                f"images and labels must have the same number of samples, "
                f"got {images.size(0)} images and labels of shape {tuple(labels.shape)}"
            )

        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of images in the dataset."""
        return self.images.size(0)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, with ``transform`` applied to the image."""
        img = self.images[idx]    # [1, H, W]
        lbl = self.labels[idx]
        if self.transform:
            img = self.transform(img)

        return img, lbl

class GANDiscriminator(nn.Module):
    """Critic made of a frozen pretrained DenseNet feature extractor and a trainable head.

    The feature extractor comes from torchxrayvision's
    ``densenet121-res224-chex`` weights with every parameter set to
    ``requires_grad=False``. Only the two-layer head is trained; it maps the
    1024 pooled features to a single unbounded score per sample.
    """

    def __init__(self):
        super().__init__()
        base = xrv.models.DenseNet(weights="densenet121-res224-chex")  # pretrained DenseNet121
        # freeze feature extractor layers
        for p in base.parameters():
            p.requires_grad = False

        # extract feature and pool
        self.features = base.features
        self.pool = base.pool if hasattr(base, 'pool') else nn.AdaptiveAvgPool2d((1,1))

        # binary head
        self.head = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1))

        # Start with the frozen backbone in inference mode (see train()).
        self.features.eval()

    def train(self, mode=True):
        """Switch the head between train/eval mode while keeping the backbone in eval mode.

        ``requires_grad=False`` stops weight updates but does not stop
        normalization layers from updating their running statistics or from
        using per-batch statistics in training mode. Keeping the backbone in
        eval mode makes it genuinely frozen and makes each sample's score
        independent of the rest of the batch.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.features.eval()
        return self

    def forward(self, x):
        """Score a batch of images; returns a tensor of shape [N, 1]."""
        f = self.features(x)
        f = self.pool(f)
        f = f.flatten(1)
        return self.head(f)

In [None]:
# Build the ProGAN wrapper, load the pretrained checkpoint (mapped to CPU) and
# move its generator to the active device.
pgan = ProGAN(hparams).load_model(
    "models/00020_PGGAN_CHEST_XRAY/Final_Full_Model.pth",
    map_location='cpu',
    image_size=1024,
)
generator = pgan.generator.to(device)

# Fine-tune only the parameters whose names start with one of these prefixes;
# every other generator parameter is frozen. Trainable names are printed as we go.
trainable_prefixes = ("ScaleBlocks.6.", "ScaleBlocks.7.", "toRGB.7.", "toRGB.8.")
for name, param in generator.named_parameters():
    param.requires_grad = name.startswith(trainable_prefixes)
    if param.requires_grad:
        print(name, param.requires_grad)

discriminator = GANDiscriminator().to(device)

# The generator optimizer only receives the parameters left trainable above.
trainable_gen_params = [p for p in generator.parameters() if p.requires_grad]
gen_optimizer = optim.Adam(trainable_gen_params, lr=1e-4, betas=(0.5,0.9))

disc_optimizer = optim.Adam(discriminator.parameters(), lr=2e-5, betas=(0.5,0.9))

In [None]:
# Load the training image array from the mounted Drive and report its shape.
images_gray_full = np.load("/content/drive/MyDrive/ECE661_Project/CheXpert-v1.0-small/train_image.npy")
print(images_gray_full.shape)

# Use the full arrays; the commented-out slices would restrict both to the first 5000 samples.
images_gray = images_gray_full#[:5000]
# NOTE(review): `labels_full` is not assigned in any visible cell, so this line raises
# NameError on a fresh kernel — load the label array before this point.
labels = labels_full#[:5000]


# Mean / std used to standardize the single-channel images.
XRAY_MEAN_1CH, XRAY_STD_1CH = [0.502], [0.290]

transform = transforms.Compose([
    transforms.Resize((224, 224), antialias=True),
    # Standardize the [0, 1] input as (x - mean) / std. With these values the result
    # spans roughly [-1.73, 1.72], not [-1, 1].
    transforms.Normalize(mean=XRAY_MEAN_1CH, std=XRAY_STD_1CH)
])

b_size = 32

# Wrap the arrays in the dataset and build the shuffled training loader.
# This rebinds `dataloader`, replacing the one obtained from medigan earlier.
dataset = XrayDataset(images_gray, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=b_size, shuffle=True)

In [None]:
lambda_gp = 20  # weight of the gradient-penalty term in the critic loss
d_steps = 2     # critic updates per generator update


def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """Return the WGAN-GP gradient penalty for one batch.

    Defined in this cell so the training loop runs on a fresh kernel.

    Random per-sample interpolations between ``real_samples`` and
    ``fake_samples`` are scored by ``critic``; the penalty is the mean squared
    deviation of the critic's input-gradient L2 norm from 1.

    Args:
        critic: Module mapping a batch of images to one score per sample.
        real_samples: Batch of real images (treated as constants).
        fake_samples: Batch of generated images with the same shape as
            ``real_samples`` (treated as constants).
        device: Device on which the interpolation coefficients are created.

    Returns:
        Scalar tensor that is differentiable w.r.t. the critic's parameters.
    """
    batch = real_samples.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dimensions.
    alpha = torch.rand([batch] + [1] * (real_samples.dim() - 1), device=device)
    interpolates = (alpha * real_samples + (1.0 - alpha) * fake_samples).requires_grad_(True)
    critic_out = critic(interpolates)
    # create_graph=True keeps the gradient itself differentiable so the penalty
    # can be back-propagated into the critic's parameters.
    grads = torch.autograd.grad(
        outputs=critic_out,
        inputs=interpolates,
        grad_outputs=torch.ones_like(critic_out),
        create_graph=True,
        retain_graph=True,
    )[0]
    grad_norm = grads.reshape(batch, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


def _fake_critic_input(gen, z, channel_weights, size):
    """Generate images from ``z`` and convert them to the critic's input format.

    Steps: tanh on the generator output, bilinear resize to ``size``, weighted
    sum over the channel axis (3 channels -> 1), then a second tanh.

    NOTE(review): real images are standardized with mean/std while fakes pass
    through a second tanh, so the two value ranges differ — confirm this is
    intended.
    """
    fake_imgs = torch.tanh(gen(z))
    fake_imgs_resized = F.interpolate(fake_imgs, size=size, mode='bilinear', align_corners=False)
    fake_gray = (fake_imgs_resized * channel_weights).sum(dim=1, keepdim=True)
    return torch.tanh(fake_gray)


generator.train()
discriminator.train()
# Standard RGB-to-grayscale weights, used to collapse the generator's 3 output
# channels into the single channel the critic expects.
weights = torch.tensor([0.2989, 0.5870, 0.1140], device=device).view(1, 3, 1, 1)
global_step = 0  # counts generator updates (named to avoid shadowing the builtin `iter`)
epoch = 30
target_size = (224, 224)

print(f"Starting Training: d_steps={d_steps}, lambda_gp={lambda_gp}")

for e in range(epoch):
    for real_imgs_normalized, _labels in dataloader:
        real_imgs_normalized = real_imgs_normalized.to(device)
        bs = real_imgs_normalized.size(0)

        # ---- critic updates ----
        for _ in range(d_steps):
            disc_optimizer.zero_grad()

            # The critic step never back-propagates into the generator, so skip
            # building the generator's autograd graph.
            z = torch.randn(bs, 512, device=device)
            with torch.no_grad():
                fake_gray_normalized = _fake_critic_input(generator, z, weights, target_size)

            real_logits = discriminator(real_imgs_normalized)
            fake_logits = discriminator(fake_gray_normalized)
            gradient_penalty = compute_gradient_penalty(
                discriminator, real_imgs_normalized, fake_gray_normalized, device
            )

            # Wasserstein critic loss plus gradient penalty.
            d_loss = torch.mean(fake_logits) - torch.mean(real_logits) + lambda_gp * gradient_penalty

            d_loss.backward()
            disc_optimizer.step()

        # ---- generator update ----
        gen_optimizer.zero_grad()

        z = torch.randn(bs, 512, device=device)
        fake_gray_normalized = _fake_critic_input(generator, z, weights, target_size)

        fake_logits = discriminator(fake_gray_normalized)
        g_loss = -torch.mean(fake_logits)

        g_loss.backward()
        gen_optimizer.step()

        global_step += 1
        if global_step % 300 == 0:
            print(f"Iter {global_step}: d_loss={d_loss.item():.4f}, g_loss={g_loss.item():.4f}, gp={gradient_penalty.item():.4f}")

    print(f"Epoch {e}: d_loss={d_loss.item():.4f}, g_loss={g_loss.item():.4f}")


print(f"Final d_loss={d_loss.item():.4f}, g_loss={g_loss.item():.4f}")

# Sampling
# NOTE(review): training feeds tanh(generator(z)) to the critic, but the raw
# generator output is sampled here — confirm whether tanh should be applied.
generator.eval()
with torch.no_grad():
    z = torch.randn(8, 512, device=device)
    samples = generator(z)

samples_np = samples.cpu().numpy()
print("Generated sample batch shape:", samples_np.shape)