In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.transforms as transforms
import tensorboardX
from torch.utils.data import DataLoader
import medmnist
from torch_fidelity import calculate_metrics
import os
import shutil
import datetime

In [6]:
writer = tensorboardX.SummaryWriter("runs/WGAN_Pneumonia")

In [7]:
class WGANLoss(nn.Module):
    def __init__(self):
        super(WGANLoss, self).__init__()
    
    def forward(self, pred, target):
        return -torch.mean(pred) if target else torch.mean(pred)

In [8]:
dataset_name = "pneumoniamnist"
info = medmnist.INFO[dataset_name]
DataClass = getattr(medmnist, info["python_class"])

image_size = 28  # PneumoniaMNIST images are 28x28
nChannels = 1  # Grayscale images

data_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = DataClass(split="train", transform=data_transform, download=True)
batch_size = min(128, len(dataset)) if len(dataset) >= 128 else len(dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


Downloading https://zenodo.org/records/10519652/files/pneumoniamnist.npz?download=1 to /root/.medmnist/pneumoniamnist.npz


100%|██████████| 4.17M/4.17M [00:00<00:00, 4.91MB/s]


In [9]:
nz = 100
ngf = 64
ndf = 64
lr = 0.00005  
num_epochs = 50
n_critic = 5
clip_value = 0.01  

In [10]:
class Generator(nn.Module):
    def __init__(self, nz, ngf, nChannels):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nChannels, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, input):
        return self.main(input)

In [11]:
class Critic(nn.Module):
    def __init__(self, ndf=64, nChannels=1):  # ndf: base channels, nChannels: input channels (grayscale=1)
        super(Critic, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nChannels, ndf, 3, 2, 1, bias=False),  # Output: (28x28) → (14x14)
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, 3, 2, 1, bias=False),  # Output: (14x14) → (7x7)
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, 3, 2, 1, bias=False),  # Output: (7x7) → (4x4)
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False)  # Output: (4x4) → (1x1)
        )

    def forward(self, input):
        return self.main(input).view(-1)

In [12]:
generator = Generator(nz, ngf, nChannels).cuda()
critic = Critic(ndf, nChannels).cuda()

optimizerG = optim.RMSprop(generator.parameters(), lr=lr)
optimizerD = optim.RMSprop(critic.parameters(), lr=lr)


In [13]:
os.makedirs("generated_images", exist_ok=True)

In [14]:
for epoch in range(num_epochs):
    for i, (data, _) in enumerate(dataloader):
        real_images = data.cuda()
        batch_size = real_images.size(0)

        # Train Critic
        for _ in range(n_critic):
            optimizerD.zero_grad()
            noise = torch.randn(batch_size, nz, 1, 1).cuda()
            fake_images = generator(noise).detach()

            loss_D_real = WGANLoss()(critic(real_images), True)
            loss_D_fake = WGANLoss()(critic(fake_images), False)
            loss_D = loss_D_real + loss_D_fake
            loss_D.backward()
            optimizerD.step()

            # Weight Clipping
            for p in critic.parameters():
                p.data.clamp_(-clip_value, clip_value)

        # Train Generator
        optimizerG.zero_grad()
        fake_images = generator(noise)
        loss_G = WGANLoss()(critic(fake_images), True)
        loss_G.backward()
        optimizerG.step()

        # Logging
        writer.add_scalar("Loss/Discriminator", loss_D.item(), epoch * len(dataloader) + i)
        writer.add_scalar("Loss/Generator", loss_G.item(), epoch * len(dataloader) + i)

    # Save Checkpoints
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(generator.state_dict(), f'checkpoints/generator_epoch_{epoch}.pth')
    torch.save(critic.state_dict(), f'checkpoints/critic_epoch_{epoch}.pth')

    # Save Generated Images
    os.makedirs("generated_images", exist_ok=True)
    vutils.save_image(fake_images[:16], f"generated_images/epoch_{epoch}.png", normalize=True)

    # ✅ Add image to TensorBoard (Fixed for single slider view)
    writer.add_image('Generated Images', vutils.make_grid(fake_images[:16], normalize=True, scale_each=True), global_step=epoch)

    # Evaluate FID & IS
    if (epoch + 1) % 10 == 0:
        fake_images_dir = "generated_images/fake"
        real_images_dir = "generated_images/real"
        os.makedirs(fake_images_dir, exist_ok=True)
        os.makedirs(real_images_dir, exist_ok=True)

        # Clear old images
        shutil.rmtree(fake_images_dir)
        shutil.rmtree(real_images_dir)
        os.makedirs(fake_images_dir)
        os.makedirs(real_images_dir)

        for j, img in enumerate(fake_images[:batch_size]):
            vutils.save_image(img, os.path.join(fake_images_dir, f"fake_{epoch}_{j}.png"), normalize=True)

        for j, img in enumerate(real_images[:batch_size]):
            vutils.save_image(img, os.path.join(real_images_dir, f"real_{epoch}_{j}.png"), normalize=True)

        # Compute Metrics
        metrics = calculate_metrics(
            input1=fake_images_dir, 
            input2=real_images_dir, 
            cuda=True, 
            isc=True, 
            fid=True
        )

        print(f"\n=== Epoch {epoch+1} ===")
        print(f"Inception Score (IS)   : {metrics['inception_score_mean']:.4f} ± {metrics.get('inception_score_std', 0):.4f}")
        print(f"Fréchet Inception Distance (FID): {metrics['frechet_inception_distance']:.4f}")
        print("=========================")

        writer.add_scalar("Metrics/IS", metrics['inception_score_mean'], epoch)
        writer.add_scalar("Metrics/FID", metrics['frechet_inception_distance'], epoch)

writer.close()


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 254MB/s]
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.3727737345510767 ± 0.07945076057430353
Frechet Inception Distance: 368.793522043068



=== Epoch 10 ===
Inception Score (IS)   : 1.3728 ± 0.0795
Fréchet Inception Distance (FID): 368.7935


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.2579949406021969 ± 0.07615106716101969
Frechet Inception Distance: 493.12889186798293



=== Epoch 20 ===
Inception Score (IS)   : 1.2580 ± 0.0762
Fréchet Inception Distance (FID): 493.1289


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.2713628503282954 ± 0.1034480635000023
Frechet Inception Distance: 404.52882882312156



=== Epoch 30 ===
Inception Score (IS)   : 1.2714 ± 0.1034
Fréchet Inception Distance (FID): 404.5288


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.2843373805005955 ± 0.06183867830910506
Frechet Inception Distance: 393.09944614076903



=== Epoch 40 ===
Inception Score (IS)   : 1.2843 ± 0.0618
Fréchet Inception Distance (FID): 393.0994


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                         
Inception Score: 1.2647722571411362 ± 0.060447090625046045



=== Epoch 50 ===
Inception Score (IS)   : 1.2648 ± 0.0604
Fréchet Inception Distance (FID): 368.7797


Frechet Inception Distance: 368.77974013272154
