<a href="https://colab.research.google.com/github/borundev/pytorch_lightning_examples/blob/decouple_gan_generator_discriminator/Gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -Uqq pytorch_lightning
!pip install -Uqq pytorch-lightning-bolts
!pip install -Uqq wandb

In [18]:
import pytorch_lightning as pl
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule
from torch import nn
import torch
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import torchvision.utils as vutils
import wandb
import torchvision
import psutil
import os
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

In [3]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.latent_dim=latent_dim
        self.img_shape=img_shape

    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()
        self.img_shape=img_shape

    def forward(self, img):
        return self.model(img)

class LambdaModule(nn.Module):
    def __init__(self, lambd):
        super().__init__()
        import types
        assert type(lambd) is types.LambdaType
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)

In [4]:
class GeneratorDCGAN(Generator):
    def __init__(self, latent_dim, img_shape,ngf=64):
        super().__init__(latent_dim, img_shape)
        nc=img_shape[0]
        self.model = nn.Sequential(
            # input is Z, going into a convolution
            LambdaModule(lambda x: x.unsqueeze(-1).unsqueeze(-1)),
            nn.ConvTranspose2d( latent_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            #nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.ConvTranspose2d(ngf*2, nc, 4, 2, 1, bias=False),
            #nn.BatchNorm2d(ngf),
            #nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            #nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    # custom weights initialization called on netG and netD
    @staticmethod
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)


In [5]:
class DiscriminatorDCGAN(Discriminator):
    def __init__(self, img_shape ,ndf = 64,):
        super().__init__(img_shape)
        nc=img_shape[0]
        self.model = nn.Sequential(
            # input is (nc) x 64 x 64
            #nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            #nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            #nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.Conv2d(nc, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            nn.Flatten(1)
        )
        self.apply(self.weights_init)


    # custom weights initialization called on netG and netD
    @staticmethod
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

In [6]:
class GAN(pl.LightningModule):

    def __init__(
            self,
            channels,
            width,
            height,
            generator,
            discriminator,
            latent_dim: int = 100,
            lr: float = 0.0002,
            b1: float = 0.5,
            b2: float = 0.999,
            batch_size: int = 64,
            **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        # networks
        self.data_shape = (channels, width, height)
        self.generator = generator
        self.discriminator = discriminator

        self.fixed_random_sample = None

    def print_summary(self):
        print(torchsummary.summary(self.generator, (self.hparams.latent_dim,), 1))
        print(torchsummary.summary(self.discriminator, self.data_shape, 1))

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):

        # if this is the first run make the fixed random vector
        if self.fixed_random_sample is None:
            imgs, _ = batch
            self.fixed_random_sample = torch.randn(6, self.hparams.latent_dim).type_as(imgs)

        # log images generatd from fixed random noise status of the fixed random noise generated images
        sample_imgs = self(self.fixed_random_sample)
        grid = torchvision.utils.make_grid(sample_imgs,padding=2, normalize=True).detach().cpu().numpy().transpose(1, 2, 0)
        self.logger.experiment.log(
            {'gen_images': [wandb.Image(grid, caption='{}:{}'.format(self.current_epoch, batch_idx))]}, commit=False)

        process = psutil.Process(os.getpid())
        self.log('memory',process.memory_info().rss/(1024**3))

        if optimizer_idx == 0:
            return self.training_step_generator(batch, batch_idx)
        elif optimizer_idx == 1:
            return self.training_step_discriminator(batch, batch_idx)

    def training_step_generator(self, batch, batch_idx):

        imgs, _ = batch
        batch_size = imgs.shape[0]

        # note z.type_as(imgs) not only type_casts but also puts on the same device
        z = torch.randn(batch_size, self.hparams.latent_dim).type_as(imgs)
        generated_imgs = self(z)
        generated_y_score = self.discriminator(generated_imgs)
        generated_y = torch.ones(imgs.size(0), 1).type_as(imgs)
        g_loss = self.adversarial_loss(generated_y_score, generated_y)

        fooling_fraction = (generated_y_score > 0.5).type(torch.float).flatten().mean()

        self.log('generator/g_loss', g_loss, prog_bar=True)
        self.log('generator/g_fooling_fraction', fooling_fraction, prog_bar=True)

        return {'loss': g_loss,
                'y_score': generated_y_score,
                'y': generated_y
                }

    def log(self,k,v,*args,**kwargs):
        try:
            v=v.detach().item()
        except:
            pass
        super().log(k,v,*args,**kwargs)

    def training_step_discriminator(self, batch, batch_idx):
        imgs, _ = batch
        batch_size = imgs.shape[0]

        # note z.type_as(imgs) not only type_casts but also puts on the same device
        z = torch.randn(batch_size, self.hparams.latent_dim).type_as(imgs)
        generated_imgs = self(z)
        generated_y_score = self.discriminator(generated_imgs)
        generated_y = torch.zeros(imgs.size(0), 1).type_as(imgs)
        generated_loss = self.adversarial_loss(generated_y_score, generated_y)

        real_y_score = self.discriminator(imgs)
        real_y = torch.ones(imgs.size(0), 1).type_as(imgs)
        real_loss = self.adversarial_loss(real_y_score, real_y)

        y_score = torch.cat([real_y_score, generated_y_score], 0)
        y = torch.cat([real_y, generated_y], 0)
        pred = (y_score > 0.5).type(torch.int).view(-1, 1)

        accuracy = (pred == y).type(torch.float).mean()
        loss = (real_loss + generated_loss) / 2.0

        self.log('discriminator/d_loss', loss, prog_bar=True)
        self.log('discriminator/d_accuracy', accuracy, prog_bar=True)

        return {'loss': loss,
                'y_score': y_score,
                'y': y
                }

    def training_epoch_end(self, outputs):
        discriminator_score = []
        discriminator_y = []

        for output in outputs[1]:
            discriminator_y.append(output['y'])
            discriminator_score.append(output['y_score'])
        discriminator_score = torch.cat(discriminator_score)
        discriminator_y = torch.cat(discriminator_y)
        discriminator_score = torch.cat([1 - discriminator_score, discriminator_score], 1)

        y_true = discriminator_y.cpu().numpy().flatten()
        y_score = discriminator_score.detach().cpu().numpy()

        #self.log("discriminator/discriminator_pr", wandb.plot.pr_curve(y_true, y_score, labels=['Fake', 'Real']))
        #self.log("discriminator/discriminator_roc", wandb.plot.roc_curve(y_true, y_score, labels=['Fake', 'Real']))
        self.log('discriminator/discriminator_confusion_matrix', wandb.plot.confusion_matrix(y_score,
                                                                                             y_true,
                                                                                             class_names=['Fake',
                                                                                                          'Real']))

        p, r, t = precision_recall_curve(1-y_true, y_score[:, 0])
        plt.plot(r,p,label='Fake')
        p, r, t = precision_recall_curve(y_true, y_score[:, 1])
        plt.plot(r,p, label='Real')
        plt.xlim(0,1)
        plt.ylim(0,1)
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.grid(True)
        plt.legend()
        self.log('discriminator/pr_curve',wandb.Image(plt,caption='Epoch: {}'.format(self.current_epoch)))
        plt.close()


    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []


In [7]:
Generator=GeneratorDCGAN
Discriminator=DiscriminatorDCGAN
DataModule=CIFAR10DataModule

In [8]:
torch.cuda.is_available()

True

In [9]:
dm = DataModule('.')
latent_dim=100
img_shape=dm.size()

generator=Generator(latent_dim=latent_dim, img_shape=img_shape)
discriminator=Discriminator(img_shape=img_shape)

model = GAN(*dm.size(),latent_dim=latent_dim, generator=generator, discriminator=discriminator)

In [27]:
logger = WandbLogger(project='gan_memory_profiling',name='colab-gpu')

dm.prepare_data()
dm.setup()
dataloader =dm.train_dataloader()
real_batch = next(iter(dataloader))

gpus=1
device = 'cuda:0'


real_images=np.transpose(vutils.make_grid(real_batch[0][:6], padding=2, normalize=True).detach().numpy(),(1,2,0))
logger.experiment.log({'real_sample':[wandb.Image(real_images, caption='Real Images')]})

Files already downloaded and verified
Files already downloaded and verified


In [28]:
trainer = pl.Trainer(gpus=gpus,
                     max_epochs=3,
                     logger=logger,
                     )

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [29]:
trainer.fit(model, dm)


  | Name          | Type               | Params
-----------------------------------------------------
0 | generator     | GeneratorDCGAN     | 3.4 M 
1 | discriminator | DiscriminatorDCGAN | 2.6 M 
-----------------------------------------------------
2.6 M     Trainable params
3.4 M     Non-trainable params
6.1 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [30]:
wandb.finish()

VBox(children=(Label(value=' 1.20MB of 1.20MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
_step,4217.0
_runtime,304.0
_timestamp,1610901240.0
memory,4.86593
generator/g_loss,3.23946
generator/g_fooling_fraction,0.0
discriminator/d_loss,0.11444
discriminator/d_accuracy,0.98438
epoch,2.0


0,1
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█
memory,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▅▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇█
generator/g_loss,▄▄▂▅▃▅▅▄▅▃▄▂▅▇▅▄▄▃▄█▆▄▃▃▄▃▃▆▆▆▁▄▅▆▄▄▅▁▅▃
generator/g_fooling_fraction,▁▁▄▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁▁█▁▁
discriminator/d_loss,▂▂▆▃▂▂▂▄▃▂▂▁█▆▁▂▂▃▁▁▂▁▂▁▂▁▁█▁▅▄▁▂▅▁▂▂▆▂▂
discriminator/d_accuracy,▇█▃▇███▆▆▇▇█▁▃█▇█▆██▇██████▂█▄▅██▂█▇█▂▇█
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅█████████████


In [31]:
logger = WandbLogger(project='gan_memory_profiling',name='colab-cpu')

dm.prepare_data()
dm.setup()
dataloader =dm.train_dataloader()
real_batch = next(iter(dataloader))

gpus=0
device = 'cpu'


real_images=np.transpose(vutils.make_grid(real_batch[0][:6], padding=2, normalize=True).detach().numpy(),(1,2,0))
logger.experiment.log({'real_sample':[wandb.Image(real_images, caption='Real Images')]})

Files already downloaded and verified
Files already downloaded and verified


In [32]:
trainer = pl.Trainer(gpus=gpus,
                     max_epochs=3,
                     logger=logger,
                     )

GPU available: True, used: False
TPU available: None, using: 0 TPU cores


In [33]:
generator=Generator(latent_dim=latent_dim, img_shape=img_shape)
discriminator=Discriminator(img_shape=img_shape)

model = GAN(*dm.size(),latent_dim=latent_dim, generator=generator, discriminator=discriminator)

In [None]:
trainer.fit(model, dm)


  | Name          | Type               | Params
-----------------------------------------------------
0 | generator     | GeneratorDCGAN     | 3.4 M 
1 | discriminator | DiscriminatorDCGAN | 2.6 M 
-----------------------------------------------------
6.1 M     Trainable params
0         Non-trainable params
6.1 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

In [None]:
wandb.finish()