In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import medmnist
from torch_fidelity import calculate_metrics

In [11]:
writer = SummaryWriter("runs/LS_GAN_PneumoniaMNIST")

In [12]:
class LSLoss(nn.Module):
    def __init__(self):
        super(LSLoss, self).__init__()
    
    def forward(self, pred, target):
        return torch.mean((pred - target) ** 2)

In [13]:
dataset_name = "pneumoniamnist"
info = medmnist.INFO[dataset_name]
DataClass = getattr(medmnist, info["python_class"])

image_size = 28  # PneumoniaMNIST images are 28x28
nChannels = 1  # Grayscale images

data_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = DataClass(split="train", transform=data_transform, download=True)
batch_size = min(128, len(dataset)) if len(dataset) >= 128 else len(dataset)  # Adjust batch size dynamically
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


Downloading https://zenodo.org/records/10519652/files/pneumoniamnist.npz?download=1 to /root/.medmnist/pneumoniamnist.npz


100%|██████████| 4.17M/4.17M [00:00<00:00, 4.71MB/s]


In [15]:
nz = 100
ngf = 64
ndf = 64
lr = 0.0002
num_epochs = 50

In [16]:
class Generator(nn.Module):
    def __init__(self, nz, ngf, nChannels):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf, nChannels, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, input):
        return self.main(input)

In [17]:
class Discriminator(nn.Module):
    def __init__(self, ndf, nChannels):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nChannels, ndf, 3, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf, ndf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf * 2, 1, 7, 1, 0, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, input):
        return self.main(input)

In [18]:
generator = Generator(nz, ngf, nChannels).cuda()
discriminator = Discriminator(ndf, nChannels).cuda()

# Optimizers
optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = LSLoss()


In [20]:
import os
os.makedirs("generated_images", exist_ok=True)


In [22]:
import datetime

for epoch in range(num_epochs):
    
    # Train the model
    for i, (data, _) in enumerate(dataloader):
        real_images = data.cuda()
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1, 1, 1).cuda()
        fake_labels = torch.zeros(batch_size, 1, 1, 1).cuda()

        # Train Discriminator
        optimizerD.zero_grad()
        output_real = discriminator(real_images)
        loss_real = criterion(output_real, real_labels)

        noise = torch.randn(batch_size, nz, 1, 1).cuda()
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach())
        loss_fake = criterion(output_fake, fake_labels)

        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizerD.step()

        # Train Generator
        optimizerG.zero_grad()
        output_fake = discriminator(fake_images)
        loss_G = criterion(output_fake, real_labels)
        loss_G.backward()
        optimizerG.step()

        writer.add_scalar("Loss/Discriminator", loss_D.item(), epoch * len(dataloader) + i)
        writer.add_scalar("Loss/Generator", loss_G.item(), epoch * len(dataloader) + i)


        # Save generated images every 100 batches
        if i % 100 == 0:
            writer.add_images("Generated Images", fake_images[:16], global_step=epoch)
            vutils.save_image(fake_images[:16], f"generated_images/epoch_{epoch}_batch_{i}.png", normalize=True)

    # Save Model Checkpoints
    torch.save(generator.state_dict(), f'generator_epoch_{epoch}.pth')
    torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch}.pth')

    # Save generated images for FID/IS evaluation
    fake_images_dir = "generated_images/fake"
    real_images_dir = "generated_images/real"
    os.makedirs(fake_images_dir, exist_ok=True)
    os.makedirs(real_images_dir, exist_ok=True)

    for j, img in enumerate(fake_images[:batch_size]):
        vutils.save_image(img, os.path.join(fake_images_dir, f"fake_{epoch}_{j}.png"), normalize=True)

    for j, img in enumerate(real_images[:batch_size]):
        vutils.save_image(img, os.path.join(real_images_dir, f"real_{epoch}_{j}.png"), normalize=True)

    # Print only after every 10 epochs
    if (epoch + 1) % 10 == 0:
        

        timestamp = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"\n{'='*50}")
        print(f" Epoch {epoch+1}/{num_epochs} | Training Progress ")
        print(f"{'='*50}")
        print(f"[{timestamp}] Loss D: {loss_D.item():.4f} | Loss G: {loss_G.item():.4f}")

        # Evaluate IS and FID
        metrics = calculate_metrics(
            input1=fake_images_dir, 
            input2=real_images_dir, 
            cuda=True, isc=True, fid=True
        )

        print(f"\n{'-'*50}")
        print(f" Evaluation Metrics (Epoch {epoch+1})")
        print(f"{'-'*50}")
        print(f" Inception Score (IS) : {metrics['inception_score_mean']:.4f} ± {metrics['inception_score_std']:.4f}")
        print(f" Fréchet Inception Distance (FID): {metrics['frechet_inception_distance']:.4f}")
        print(f"{'='*50}\n")

        writer.add_scalar("Metrics/Inception Score", metrics['inception_score_mean'], epoch)
        writer.add_scalar("Metrics/FID", metrics['frechet_inception_distance'], epoch)



 Epoch 10/50 | Training Progress 
[05:57:50] Loss D: 0.0003 | Loss G: 0.9758


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 259MB/s]
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 1000 samples
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
Processing samples                                                           
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 1000 samples
Processing samples                                                           
Inception Score: 2.122189435235305 ± 0.08656892939482286
Frechet Inception Distance: 287.01383491032703



--------------------------------------------------
 Evaluation Metrics (Epoch 10)
--------------------------------------------------
 Inception Score (IS) : 2.1222 ± 0.0866
 Fréchet Inception Distance (FID): 287.0138


 Epoch 20/50 | Training Progress 
[05:58:23] Loss D: 0.1012 | Loss G: 0.8669


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 2000 samples
Processing samples                                                           
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 2000 samples
Processing samples                                                           
Inception Score: 2.4046504730089615 ± 0.06989617184050663
Frechet Inception Distance: 299.2838076640951



--------------------------------------------------
 Evaluation Metrics (Epoch 20)
--------------------------------------------------
 Inception Score (IS) : 2.4047 ± 0.0699
 Fréchet Inception Distance (FID): 299.2838


 Epoch 30/50 | Training Progress 
[05:59:04] Loss D: 0.0848 | Loss G: 0.8000


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 3000 samples
Processing samples                                                           
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 3000 samples
Processing samples                                                           
Inception Score: 3.082019988497799 ± 0.07072949314313365
Frechet Inception Distance: 215.4396624128606



--------------------------------------------------
 Evaluation Metrics (Epoch 30)
--------------------------------------------------
 Inception Score (IS) : 3.0820 ± 0.0707
 Fréchet Inception Distance (FID): 215.4397


 Epoch 40/50 | Training Progress 
[05:59:49] Loss D: 0.0734 | Loss G: 0.7932


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 4000 samples
Processing samples                                                           
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 4000 samples
Processing samples                                                           
Inception Score: 3.1534999774749624 ± 0.052210753024167104
Frechet Inception Distance: 180.27684714523156



--------------------------------------------------
 Evaluation Metrics (Epoch 40)
--------------------------------------------------
 Inception Score (IS) : 3.1535 ± 0.0522
 Fréchet Inception Distance (FID): 180.2768


 Epoch 50/50 | Training Progress 
[06:00:44] Loss D: 0.0719 | Loss G: 0.7946


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "generated_images/fake" with extensions png,jpg,jpeg
Found 5000 samples
Processing samples                                                           
Extracting features from input2
Looking for samples non-recursivelty in "generated_images/real" with extensions png,jpg,jpeg
Found 5000 samples
Processing samples                                                           
Inception Score: 3.0573660537646736 ± 0.07029686617252294



--------------------------------------------------
 Evaluation Metrics (Epoch 50)
--------------------------------------------------
 Inception Score (IS) : 3.0574 ± 0.0703
 Fréchet Inception Distance (FID): 160.0485



Frechet Inception Distance: 160.04852101666927
