Dataset Source:
https://www.kaggle.com/datasets/andrewmvd/animal-faces

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.trainer import Trainer
import matplotlib.pyplot as plt
import wandb
import torchmetrics
import numpy as np


In [37]:
dataset = torchvision.datasets.ImageFolder(
    root='data/afhq/train',
    transform=transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
    ]),
)

# Dataset of only dogs
dataset = torch.utils.data.Subset(dataset, indices=[i for i in range(len(dataset)) if dataset.targets[i] == 1])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True, persistent_workers=True)


In [38]:
generator = nn.Sequential(
    nn.Unflatten(1, (100, 1, 1)),
    nn.BatchNorm2d(100),
    nn.ConvTranspose2d(100, 256, 4, 1, 0),
    nn.ReLU(),
    nn.BatchNorm2d(256),
    nn.ConvTranspose2d(256, 128, 4, 2, 1),
    nn.ReLU(),
    nn.BatchNorm2d(128),
    nn.ConvTranspose2d(128, 64, 4, 2, 1),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.ConvTranspose2d(64, 32, 4, 2, 1),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.ConvTranspose2d(32, 3, 4, 2, 1),
    nn.Tanh()
    
)

discriminator = nn.Sequential(
    nn.BatchNorm2d(3),
    nn.Conv2d(3, 32, 4, 2, 1),
    nn.LeakyReLU(0.2),
    nn.BatchNorm2d(32),
    nn.Conv2d(32, 64, 4, 2, 1),
    nn.LeakyReLU(0.2),
    nn.BatchNorm2d(64),
    nn.Conv2d(64, 128, 4, 2, 1),
    nn.LeakyReLU(0.2),
    nn.BatchNorm2d(128),
    nn.Conv2d(128, 256, 4, 2, 1),
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    nn.Linear(256*4*4, 1),
    nn.Sigmoid()
)

class GAN(L.LightningModule):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.automatic_optimization = False
        self.metric = torchmetrics.image.fid.FrechetInceptionDistance(feature=64, normalize=True)

        
    def training_step(self, batch, batch_idx):
        g_opt, d_opt = self.optimizers()
        real_imgs, real_labels = batch
        
        # Train discriminator
        for idx in range(1):
            d_opt.zero_grad()
            z = torch.randn(real_imgs.shape[0], 100, device=self.device)
            fake_imgs = self.generator(z)
            real_preds = self.discriminator(real_imgs)
            fake_preds = self.discriminator(fake_imgs)
            real_loss = F.binary_cross_entropy_with_logits(real_preds, torch.ones_like(real_preds))
            fake_loss = F.binary_cross_entropy_with_logits(fake_preds, torch.zeros_like(fake_preds))
            d_loss = real_loss + fake_loss
            d_loss.backward()
            d_opt.step()
        
        # Train generator
        for idx in range(2):
            g_opt.zero_grad()
            z = torch.randn(real_imgs.shape[0], 100, device=self.device)
            fake_imgs = self.generator(z)
            fake_preds = self.discriminator(fake_imgs)
            g_loss = F.binary_cross_entropy_with_logits(fake_preds, torch.ones_like(fake_preds))
            g_loss.backward()
            g_opt.step()

        
        if batch_idx % 100 == 0:
            z = torch.randn(real_imgs.shape[0], 100, device=self.device)
            fake_imgs = self.generator(z)
            real_preds = self.discriminator(real_imgs)
            fake_preds = self.discriminator(fake_imgs)
            real_loss = F.binary_cross_entropy_with_logits(real_preds, torch.ones_like(real_preds))
            fake_loss = F.binary_cross_entropy_with_logits(fake_preds, torch.zeros_like(fake_preds))
            d_loss = real_loss + fake_loss
            g_loss = F.binary_cross_entropy_with_logits(fake_preds, torch.ones_like(fake_preds))
            
            
            # Calculate FID
            fake_imgs = F.interpolate(fake_imgs, size=(299, 299), mode='nearest').detach()
            real_imgs = F.interpolate(real_imgs, size=(299, 299), mode='nearest').detach()

            self.metric.update(fake_imgs, real=False)
            self.metric.update(real_imgs, real=True)
            
            fid_score = self.metric.compute()
            
            self.metric.reset()
            self.logger.experiment.log({
                "Generator Loss": g_loss.item(),
                "Discriminator Loss": d_loss.item(),
                "Discriminator Accuracy": (real_preds > 0.5).float().mean().item(),
                "Discriminator Loss Real": real_loss.item(),
                "Discriminator Loss Fake": fake_loss.item(),
                "Generated Images": wandb.Image(torchvision.utils.make_grid(fake_imgs, nrow=32), caption=f"Generated Images Epoch {self.current_epoch}, Batch {batch_idx}"),
                "FID Score": fid_score
                })
        
    def configure_optimizers(self):
        g_opt = optim.Adam(self.generator.parameters(), lr=0.0002)
        d_opt = optim.Adam(self.discriminator.parameters(), lr=0.0002)
        return [g_opt, d_opt], []

gan = GAN(generator, discriminator)


In [41]:
logger = WandbLogger(project="gan", tags=["conv", "dcgan"])
trainer = Trainer(max_epochs=10, logger=logger, limit_val_batches=1, profiler="simple", limit_train_batches=10)
trainer.fit(gan, dataloader)
wandb.finish()


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.



  | Name          | Type                     | Params
-----------------------------------------------------------
0 | generator     | Sequential               | 1.1 M 
1 | discriminator | Sequential               | 694 K 
2 | metric        | FrechetInceptionDistance | 23.9 M
-----------------------------------------------------------
1.8 M     Trainable params
23.9 M    Non-trainable params
25.6 M    Total params
102.586   Total estimated model params size (MB)
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                        	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                   

VBox(children=(Label(value='5.861 MB of 5.861 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Discriminator Accuracy,▁▅▅▅▇▇████
Discriminator Loss,▄▂█▄▅▃▁▁▁▁
Discriminator Loss Fake,▃▁█▄▅▃▁▁▁▁
Discriminator Loss Real,█▄▄▅▃▂▁▁▁▁
FID Score,▁▁▂▃▅▅▇███
Generator Loss,▆▇▁▅▄▆████

0,1
Discriminator Accuracy,1.0
Discriminator Loss,1.00701
Discriminator Loss Fake,0.69319
Discriminator Loss Real,0.31382
FID Score,717.49011
Generator Loss,0.6931


In [10]:
wandb.finish()


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))