In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.transforms as transforms
import tensorboardX
from torch.utils.data import DataLoader
import medmnist
from torch_fidelity import calculate_metrics
import os
import shutil
import torch.autograd as autograd

In [6]:
def compute_gradient_penalty(critic, real_images, fake_images):
    alpha = torch.rand(real_images.size(0), 1, 1, 1).cuda()
    interpolates = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
    critic_interpolates = critic(interpolates)
    gradients = autograd.grad(outputs=critic_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(critic_interpolates.size()).cuda(),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

class WGAN_GP_Loss(nn.Module):
    def __init__(self):
        super(WGAN_GP_Loss, self).__init__()
    def forward(self, pred, target):
        return -torch.mean(pred) if target else torch.mean(pred)



In [7]:
dataset_name = "pneumoniamnist"
info = medmnist.INFO[dataset_name]
DataClass = getattr(medmnist, info["python_class"])

image_size = 28  
nChannels = 1  

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = DataClass(split="train", transform=transform, download=True)
batch_size = 128
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Downloading https://zenodo.org/records/10519652/files/pneumoniamnist.npz?download=1 to /root/.medmnist/pneumoniamnist.npz


100%|██████████| 4.17M/4.17M [00:00<00:00, 4.74MB/s]


In [8]:
nz = 100
ngf = 64
ndf = 64
lr = 0.0001  
num_epochs = 50
n_critic = 5
gp_lambda = 10  

In [9]:
class Generator(nn.Module):
    def __init__(self, nz, ngf, nChannels):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),  # Output: 4x4
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),  # Output: 8x8
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, bias=False),  # Output: 16x16
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, nChannels, 4, 2, 1, bias=False),  # Output: 32x32
            nn.Tanh(),

            nn.Upsample(size=(28, 28), mode='bilinear', align_corners=True)  # Resize to 28x28
        )

    def forward(self, input):
        return self.main(input)


In [10]:
class Critic(nn.Module):
    def __init__(self, ndf=64, nChannels=1):
        super(Critic, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nChannels, ndf, 3, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False)
        )
    def forward(self, input):
        return self.main(input).view(-1)

In [11]:
generator = Generator(nz, ngf, nChannels).cuda()
critic = Critic(ndf, nChannels).cuda()
optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(0, 0.9))
optimizerD = optim.Adam(critic.parameters(), lr=lr, betas=(0, 0.9))


In [12]:
writer = tensorboardX.SummaryWriter("runs/WGAN_GP_Pneumonia")


In [13]:
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("generated_images", exist_ok=True)

In [14]:
for epoch in range(num_epochs):
    for i, (data, _) in enumerate(dataloader):
        real_images = data.cuda()
        batch_size = real_images.size(0)

        # Train Critic
        for _ in range(n_critic):
            optimizerD.zero_grad()
            noise = torch.randn(batch_size, nz, 1, 1).cuda()
            fake_images = generator(noise).detach()

            loss_D_real = WGAN_GP_Loss()(critic(real_images), True)
            loss_D_fake = WGAN_GP_Loss()(critic(fake_images), False)
            gradient_penalty = compute_gradient_penalty(critic, real_images, fake_images)
            loss_D = loss_D_real + loss_D_fake + gp_lambda * gradient_penalty
            loss_D.backward()
            optimizerD.step()

        # Train Generator
        optimizerG.zero_grad()
        fake_images = generator(noise)
        loss_G = WGAN_GP_Loss()(critic(fake_images), True)
        loss_G.backward()
        optimizerG.step()

        # Logging
        writer.add_scalar("Loss/Discriminator", loss_D.item(), epoch * len(dataloader) + i)
        writer.add_scalar("Loss/Generator", loss_G.item(), epoch * len(dataloader) + i)

    # Save Checkpoints
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(generator.state_dict(), f'checkpoints/generator_epoch_{epoch}.pth')
    torch.save(critic.state_dict(), f'checkpoints/critic_epoch_{epoch}.pth')

    # Save Generated Images
    os.makedirs("generated_images", exist_ok=True)
    vutils.save_image(fake_images[:16], f"generated_images/epoch_{epoch}.png", normalize=True)
    writer.add_image('Generated Images', vutils.make_grid(fake_images[:16], normalize=True, scale_each=True), global_step=epoch)


    # Evaluate FID & IS every 10 epochs
    if (epoch + 1) % 10 == 0:
        fake_images_dir = "generated_images/fake"
        real_images_dir = "generated_images/real"
        os.makedirs(fake_images_dir, exist_ok=True)
        os.makedirs(real_images_dir, exist_ok=True)

        # Clear old images
        shutil.rmtree(fake_images_dir)
        shutil.rmtree(real_images_dir)
        os.makedirs(fake_images_dir)
        os.makedirs(real_images_dir)

        for j, img in enumerate(fake_images[:batch_size]):
            vutils.save_image(img, os.path.join(fake_images_dir, f"fake_{epoch}_{j}.png"), normalize=True)

        for j, img in enumerate(real_images[:batch_size]):
            vutils.save_image(img, os.path.join(real_images_dir, f"real_{epoch}_{j}.png"), normalize=True)

        # Compute Metrics
        metrics = calculate_metrics(
            input1=fake_images_dir, 
            input2=real_images_dir, 
            cuda=True, 
            isc=True, 
            fid=True
        )

        print(f"\n=== Epoch {epoch+1} ===")
        print(f"Inception Score (IS)   : {metrics['inception_score_mean']:.4f} ± {metrics.get('inception_score_std', 0):.4f}")
        print(f"Fréchet Inception Distance (FID): {metrics['frechet_inception_distance']:.4f}")
        print("=========================")

        writer.add_scalar("Metrics/IS", metrics['inception_score_mean'], epoch)
        writer.add_scalar("Metrics/FID", metrics['frechet_inception_distance'], epoch)

writer.close()

Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 250MB/s]
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.5273970943593898 ± 0.12476322417810928
Frechet Inception Distance: 315.8864258190897



=== Epoch 10 ===
Inception Score (IS)   : 1.5274 ± 0.1248
Fréchet Inception Distance (FID): 315.8864


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.6529498620246321 ± 0.13759774574990646
Frechet Inception Distance: 312.0202794101396



=== Epoch 20 ===
Inception Score (IS)   : 1.6529 ± 0.1376
Fréchet Inception Distance (FID): 312.0203


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.7129441604877846 ± 0.150796951161115
Frechet Inception Distance: 319.5476884041691



=== Epoch 30 ===
Inception Score (IS)   : 1.7129 ± 0.1508
Fréchet Inception Distance (FID): 319.5477


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.7187430765533385 ± 0.0859940585032735
Frechet Inception Distance: 310.2978322549363



=== Epoch 40 ===
Inception Score (IS)   : 1.7187 ± 0.0860
Fréchet Inception Distance (FID): 310.2978


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.8842514485646145 ± 0.1369871670951578



=== Epoch 50 ===
Inception Score (IS)   : 1.8843 ± 0.1370
Fréchet Inception Distance (FID): 359.8660


Frechet Inception Distance: 359.8659668463393
