In [1]:
import os
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

# --- configuration shared by the cells below ---
PATH_DATASETS = '.'  # directory MNIST is downloaded to and read from
# Use a single GPU when CUDA is available, otherwise fall back to CPU (0 GPUs).
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 64
# os.cpu_count() returns None when the core count cannot be determined;
# treat that case as a single core instead of crashing on None / 2.
NUM_WORKERS = (os.cpu_count() or 1) // 2
N_CRITIC = 5  # used as the critic optimizer's 'frequency' in GAN.configure_optimizers
CLIP_VALUE = .01  # critic weights are clamped to [-CLIP_VALUE, CLIP_VALUE]

# dataset

In [2]:
class MNISTDataModule(LightningDataModule):
    """LightningDataModule that serves MNIST as train / validation / test loaders.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Number of worker processes used by all three dataloaders.
        split_seed: Seed for the generator that draws the 55000/5000
            train/validation split, so the split is identical on every run
            and in every process that calls ``setup``.
    """

    def __init__(
            self,
            data_dir: str = PATH_DATASETS,
            batch_size: int = BATCH_SIZE,
            num_workers: int = NUM_WORKERS,
            split_seed: int = 42,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_seed = split_seed

        # Only ToTensor is applied (no mean/std normalisation).
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` if they are not there yet."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test", or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # A seeded generator keeps the split reproducible; without it the
            # partition changes on every call, letting validation samples leak
            # into training across runs or processes.
            split_rng = torch.Generator().manual_seed(self.split_seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=split_rng
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

# models

In [3]:
class ConvBlock(nn.Module):
    """Down-sampling unit: 3x3 stride-2 convolution -> batch norm -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The convolution carries no bias term; BatchNorm2d follows it directly.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(negative_slope=.2, inplace=True)
        self.main = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the conv / norm / activation stack."""
        features = self.main(x)
        return features


class DCGAN_D(nn.Module):
    """Critic network: three ConvBlocks followed by a 4x4 convolution to one channel.

    Args:
        in_channels: Channels of the input images.
        init_channels: Width of the first ConvBlock; each later block doubles it.
    """

    def __init__(self, in_channels, init_channels):
        super().__init__()
        # Channel widths seen between consecutive stages.
        widths = [in_channels] + [init_channels * factor for factor in (1, 2, 4)]
        stages = [ConvBlock(src, dst) for src, dst in zip(widths[:-1], widths[1:])]
        stages.append(nn.Conv2d(widths[-1], 1, kernel_size=4))
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Return a single scalar: the mean of the network output over all elements."""
        scores = self.main(x)
        return scores.mean()

In [4]:
class ConvTransposedBlock(nn.Module):
    """Up-sampling unit: 4x4 stride-2 transposed conv -> batch norm -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The transposed convolution carries no bias term; BatchNorm2d follows it directly.
        upconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(negative_slope=.2, inplace=True)
        self.main = nn.Sequential(upconv, norm, act)

    def forward(self, x):
        """Run ``x`` through the transposed-conv / norm / activation stack."""
        features = self.main(x)
        return features


class DCGAN_G(nn.Module):
    """Generator: four ConvTransposedBlocks, a final transposed conv and a Sigmoid.

    Args:
        z_channels: Size of the latent vector fed to the first block.
        init_channels: Width of the first block; later blocks use 1/2, 1/4 and 1/8 of it.
        out_channels: Channels of the generated image.
    """

    def __init__(self, z_channels, init_channels, out_channels):
        super().__init__()

        full, half, quarter, eighth = (init_channels // div for div in (1, 2, 4, 8))
        layers = [
            ConvTransposedBlock(z_channels, init_channels),
            ConvTransposedBlock(full, half),
            ConvTransposedBlock(half, quarter),
            ConvTransposedBlock(quarter, eighth),
            nn.ConvTranspose2d(
                eighth, out_channels, kernel_size=4, stride=2, padding=1, bias=False
            ),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Map latent vectors to images.

        Two singleton axes are inserted at position 2 so the latent vector
        becomes a 1x1 feature map; after the network runs, a 2-pixel border is
        cropped from each side of the last two axes.
        """
        latent_map = x[:, :, None, None]
        image = self.main(latent_map)
        return image[..., 2:-2, 2:-2]

# lightning module (WGAN training logic)

In [5]:
class GAN(LightningModule):
    """Wasserstein GAN with weight clipping, written as a LightningModule.

    The generator is a ``DCGAN_G`` and the critic (``self.discriminator``) a
    ``DCGAN_D``. Optimizer 0 updates the generator and optimizer 1 the critic;
    the critic optimizer is scheduled ``N_CRITIC`` times per generator step and
    the critic weights are clamped to ``[-CLIP_VALUE, CLIP_VALUE]`` at the
    start of every critic step.

    Args:
        channels: Image channels; sets the generator output and critic input channels.
        width: Image width. Recorded in ``hparams`` only; the spatial size of
            the networks is fixed by their architecture.
        height: Image height. Recorded in ``hparams`` only (see ``width``).
        latent_dim: Size of the noise vector fed to the generator.
        lr: Learning rate of both Adam optimizers.
        b1: Adam beta1 for both optimizers.
        b2: Adam beta2 for both optimizers.
        batch_size: Recorded in ``hparams``; not read by this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            channels,
            width,
            height,
            latent_dim: int = 100,
            lr: float = 5e-5,
            b1: float = 0.5,
            b2: float = 0.999,
            batch_size: int = BATCH_SIZE,
            **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        # networks
        self.generator = DCGAN_G(self.hparams.latent_dim, 64, channels)
        self.discriminator = DCGAN_D(channels, 64)

        # Fixed noise reused at every epoch end so logged samples are comparable.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        """Generate images from latent vectors ``z``."""
        return self.generator(z)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch (unused).
            optimizer_idx: 0 for the generator step, 1 for the critic step.

        Returns:
            An ``OrderedDict`` with the key ``"loss"``, or ``None`` if
            ``optimizer_idx`` is neither 0 nor 1.
        """
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        if optimizer_idx == 0:
            # Generate once and reuse the result for the loss.
            self.generated_imgs = self(z)

            # Wasserstein generator loss: maximise the critic score of fakes.
            g_loss = -self.discriminator(self.generated_imgs).mean()
            # self.log replaces the legacy "log"/"progress_bar" return keys.
            self.log("g_loss", g_loss, prog_bar=True)
            return OrderedDict({"loss": g_loss})

        # train critic
        if optimizer_idx == 1:
            # Weight clipping, applied before the critic is evaluated.
            with torch.no_grad():
                for p in self.discriminator.parameters():
                    p.clamp_(-CLIP_VALUE, CLIP_VALUE)

            real_loss = self.discriminator(imgs).mean()

            # detach() keeps critic gradients out of the generator.
            fake_loss = self.discriminator(self(z).detach()).mean()

            # Wasserstein critic loss: score of fakes minus score of reals.
            d_loss = fake_loss - real_loss
            self.log("d_loss", d_loss, prog_bar=True)
            return OrderedDict({"loss": d_loss})

    def configure_optimizers(self):
        """Return Adam optimizers for generator and critic with their step frequencies."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        return (
            {'optimizer': opt_g, 'frequency': 1},
            {'optimizer': opt_d, 'frequency': N_CRITIC}
        )

    def on_epoch_end(self):
        """Log a grid of images generated from the fixed validation noise."""
        # Match the noise to this module's own weights (device and dtype)
        # rather than to a notebook-level global.
        z = self.validation_z.type_as(self.generator.main[0].main[0].weight)

        # log sampled images; no autograd graph is needed for logging
        with torch.no_grad():
            sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

# train model

In [6]:
# Build the data module and the GAN, then train.
dm = MNISTDataModule()
model = GAN(*dm.size())
# Take the GPU count from the AVAIL_GPUS config constant instead of repeating a literal.
trainer = Trainer(gpus=AVAIL_GPUS, max_epochs=100, progress_bar_refresh_rate=1)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type    | Params | In sizes | Out sizes     
----------------------------------------------------------------------
0 | generator     | DCGAN_G | 145 K  | [2, 100] | [2, 1, 28, 28]
1 | discriminator | DCGAN_D | 374 K  | ?        | ?             
----------------------------------------------------------------------
519 K     Trainable params
0         Non-trainable params
519 K     Total params
2.080     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]