# Generative Adversarial Networks

```
pytorch
pytorch-lightning
```

In [1]:
from pathlib import Path
from collections import OrderedDict

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl



In [2]:
class Discriminator(nn.Module):
    """Fully connected classifier that scores images as real or generated.

    Inputs are flattened to 28*28 values, passed through three hidden
    layers (1024 -> 512 -> 256) with LeakyReLU(0.2) activations, and
    reduced to a single sigmoid output per sample.
    """

    def __init__(self):
        super().__init__()
        widths = (28 * 28, 1024, 512, 256)
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(in_features=n_in, out_features=n_out))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
        modules.append(nn.Linear(in_features=widths[-1], out_features=1))
        modules.append(nn.Sigmoid())
        # The attribute name and module order determine the state_dict keys
        # ("layers.0.weight", ...), so they are kept exactly as before.
        self.layers = nn.Sequential(*modules)

    def forward(self, inputs):
        r"""
        inputs: sample images; any shape that can be viewed as (-1, 28*28)

        Returns one value in [0, 1] per sample, shape (N, 1).
        """
        flat = inputs.view(-1, 28 * 28)
        return self.layers(flat)

class Generator(nn.Module):
    """Fully connected network that maps latent vectors to 1x28x28 images.

    The latent input is expanded through three hidden layers
    (256 -> 512 -> 1024) with LeakyReLU(0.2) activations; a final linear
    layer plus Tanh yields 28*28 values in [-1, 1].
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (latent_dim, 256, 512, 1024)
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(in_features=n_in, out_features=n_out))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
        modules.append(nn.Linear(in_features=widths[-1], out_features=28 * 28))
        modules.append(nn.Tanh())
        # The attribute name and module order determine the state_dict keys
        # ("layers.0.weight", ...), so they are kept exactly as before.
        self.layers = nn.Sequential(*modules)

    def forward(self, inputs):
        r"""
        inputs: random values from prior, z (last dimension = latent_dim)

        Returns generated images reshaped to (-1, 1, 28, 28).
        """
        flat_images = self.layers(inputs)
        return flat_images.view(-1, 1, 28, 28)

class GAN(pl.LightningModule):
    """LightningModule that trains a Generator/Discriminator pair on MNIST.

    Optimizer index 0 updates the generator and index 1 updates the
    discriminator, matching the order returned by `configure_optimizers`.
    """

    def __init__(self, hparams):
        r"""
        Code Reference: PyTorch-Lightning https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=ArrPXFM371jR

        hparams: namespace providing `latent_dim`, `lr`, `b1`, `b2` and
            `batch_size`
        """
        super().__init__()
        self.hparams = hparams
        self.D = Discriminator()
        self.G = Generator(self.hparams.latent_dim)
        # cache: fakes from the latest generator step and the latest real batch
        self.generated_imgs = None
        self.last_imgs = None

    def forward(self, z):
        """Generate images from latent vectors `z`."""
        return self.G(z)

    def criterion(self, y_hat, y):
        """Binary cross-entropy between predictions `y_hat` and targets `y`."""
        return F.binary_cross_entropy(y_hat, y)

    def _sample_z(self, bs, device):
        """Draw `bs` standard-normal latent vectors directly on `device`."""
        return torch.randn(bs, self.hparams.latent_dim, device=device)

    @staticmethod
    def _step_output(name, loss):
        """Package a loss in the dict layout returned by `training_step`."""
        tqdm_dict = {name: loss}
        return OrderedDict({
            'loss': loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })

    def training_step(self, batch, batch_idx, optimizer_idx):
        r"""
        batch: (images, labels) pair; the labels are ignored
        batch_idx: index of the batch (unused)
        optimizer_idx: 0 trains the generator, 1 trains the discriminator

        Returns an OrderedDict with 'loss', 'progress_bar' and 'log' entries,
        or None for any other optimizer index.
        """
        x, _ = batch
        bs = x.size(0)
        self.last_imgs = x

        # training G
        if optimizer_idx == 0:
            # Noise and labels are created on the batch's device, so no
            # separate transfer is needed.
            self.generated_imgs = self(self._sample_z(bs, x.device))
            valid_label = torch.ones(bs, 1, device=x.device)

            # Pass the generated inputs to Discriminator
            # Let the Generator learn how to make good fake images
            g_loss = self.criterion(self.D(self.generated_imgs), valid_label)
            return self._step_output('g_loss', g_loss)

        # training D
        if optimizer_idx == 1:
            valid_label = torch.ones(bs, 1, device=x.device)

            # Pass the real inputs to Discriminator
            # Let the Discriminator learn to determine whether inputs are real
            real_loss = self.criterion(self.D(x), valid_label)

            fake_label = torch.zeros(bs, 1, device=x.device)

            # Reuse the images cached by the generator step. If that step has
            # not produced a matching batch, generate fresh fakes instead of
            # failing on a missing or mis-sized cache.
            fake_imgs = self.generated_imgs
            if fake_imgs is None or fake_imgs.size(0) != bs:
                with torch.no_grad():
                    fake_imgs = self(self._sample_z(bs, x.device))
            # detach: the discriminator loss must not backpropagate into G
            fake_loss = self.criterion(self.D(fake_imgs.detach()), fake_label)

            # Discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            return self._step_output('d_loss', d_loss)

        return None

    def configure_optimizers(self):
        """Return one Adam optimizer each for G and D (in that order), no schedulers."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        G_optimizer = optim.Adam(self.G.parameters(), lr=lr, betas=(b1, b2))
        D_optimizer = optim.Adam(self.D.parameters(), lr=lr, betas=(b1, b2))
        return [G_optimizer, D_optimizer], []

    def train_dataloader(self):
        """Return a shuffled DataLoader over the MNIST training split.

        The dataset is downloaded if missing into a `data` directory two
        levels above the current working directory.
        """
        root_dir = Path(".").absolute().parent.parent
        # Normalize with mean 0.5 / std 0.5 computes (pixel - 0.5) / 0.5.
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize([0.5], [0.5])])
        dataset = MNIST(
            root_dir / "data",
            train=True,
            download=True,
            transform=transform
        )
        # Use the module's own hyperparameters rather than a notebook global.
        loader = DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True)
        return loader


In [3]:
from argparse import Namespace

# Hyperparameters for the GAN: batch size, Adam settings (lr, b1, b2)
# and the size of the latent vector fed to the generator.
args = dict(
    batch_size=64,
    lr=0.0002,
    b1=0.5,
    b2=0.999,
    latent_dim=5,
)
hparams = Namespace(**args)

In [4]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoints are written under ./ckpts (relative to the working directory).
# Only the weights are stored, and the most recent checkpoint is kept too.
ckpts_path = str(Path(".").absolute().joinpath("ckpts"))
checkpoint_callback = ModelCheckpoint(
    filepath=ckpts_path,
    save_weights_only=True,
    save_last=True,
    verbose=True,
)



In [5]:
# Build the GAN from the hyperparameters defined in the earlier cell.
gan_model = GAN(hparams)

# Train for at most 15 epochs on one GPU; checkpoints are handled by
# `checkpoint_callback`, every other Trainer option keeps its default.
trainer = pl.Trainer(max_epochs=15, gpus=1, checkpoint_callback=checkpoint_callback)    
# Last expression of the cell: its return value is shown as the cell output.
trainer.fit(gan_model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type          | Params
---------------------------------------
0 | D    | Discriminator | 1 M   
1 | G    | Generator     | 1 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…






1

In [33]:
# Load onto the CPU so the checkpoint can be read on a machine without a GPU;
# the restored generator is only used on the CPU afterwards.
ckpt = torch.load(ckpts_path + "/last.ckpt", map_location="cpu")

# Keep only the generator's entries and drop the leading "G." from each key.
# `str.lstrip("G.")` would strip any run of 'G'/'.' characters rather than
# the prefix, and `"G." in k` would also match keys that merely contain it.
_G_PREFIX = "G."
G_weight = OrderedDict(
    (k[len(_G_PREFIX):], v)
    for k, v in ckpt["state_dict"].items()
    if k.startswith(_G_PREFIX)
)

In [43]:
# Rebuild a standalone generator with the same latent size used for training
# and restore the weights extracted from the checkpoint.
generator = Generator(latent_dim=hparams.latent_dim)
# Last expression of the cell: the key-matching report is shown as output.
generator.load_state_dict(G_weight)

<All keys matched successfully>

In [44]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
import matplotlib.pyplot as plt

In [45]:
# Each slider selects one of N_STEPS grid points for a single latent
# dimension; the slider range and the grid length must stay in sync.
N_STEPS = 100

# One integer slider per latent dimension, keyed "L0", "L1", ...
latent_dict = {
    f"L{i}": widgets.IntSlider(
        min=0, max=N_STEPS - 1, step=1, value=0, description=f"L{i}"
    ) for i in range(hparams.latent_dim)
}
# Stack the sliders vertically inside a bordered box.
ui = widgets.VBox(
    list(latent_dict.values()), 
    layout=widgets.Layout(display='inline-flex', flex_flow='column', border='solid 2px', justify_content='space-between')
)
# Latent vector that the sliders edit in place.
z = torch.zeros(hparams.latent_dim)
# Grid of latent values in [-1, 1] indexed by slider position. `steps` is
# passed explicitly: newer PyTorch releases no longer default it to 100.
zz = torch.linspace(-1, 1, steps=N_STEPS)

In [46]:
def generate_img(L0, L1, L2, L3, L4):
    """Render the image the generator produces for the given slider positions.

    L0..L4: integer indices into `zz`, one per latent dimension. Each index
    selects the value written into the matching entry of the shared latent
    vector `z`.

    The positions are taken from the arguments rather than re-read from the
    slider widgets, so the function can also be called directly.
    """
    slider_positions = (L0, L1, L2, L3, L4)
    for dim, position in enumerate(slider_positions):
        z[dim] = zz[position]
    # Inference only: no autograd graph is needed to draw the image.
    with torch.no_grad():
        output = generator(z).squeeze().numpy()
    plt.imshow(output)

In [47]:
# Bind the sliders to `generate_img`; each slider's value is passed as the
# keyword argument named by its key in `latent_dict` ("L0", "L1", ...).
out = widgets.interactive_output(generate_img, latent_dict)

In [48]:
# Show the slider panel together with the generated-image output area.
display(ui, out)

VBox(children=(IntSlider(value=0, description='L0', max=99), IntSlider(value=0, description='L1', max=99), Int…

Output(outputs=({'output_type': 'display_data', 'data': {'text/plain': '<Figure size 432x288 with 1 Axes>', 'i…