In [None]:
#default_exp gan

In [None]:
#hide
from IPython.display import clear_output
from nbdev.export import notebook2script
from dotenv import load_dotenv
%reload_ext autoreload
%autoreload 2

# Load variables from a `.env` file into the process environment
# (python-dotenv); the return value is deliberately discarded.
_ = load_dotenv()

In [None]:
#export
import torch
import logging
import functools
import torchvision
import pytorch_lightning as pl
from torch import nn
from more_itertools import pairwise
from collections import OrderedDict
from practical_ai.layers import Identity
from practical_ai.data import get_dataset, get_data_loader
from practical_ai.gan.loss import get_adversarial_losses_fns


# Module logger. NOTE(review): getLogger() without a name returns the *root*
# logger, so the INFO threshold set here applies process-wide.
logger = logging.getLogger()
logger.setLevel(level="INFO")

# gan

> 生成對抗網路（Generative Adversarial Network, GAN）的相關模組。

In [None]:
#export 
# --- architecture defaults ---------------------------------------------------
LATENT_DIM = 128  # length of the noise vector fed to the generator
DIM = 32  # default image height/width in pixels
CHANNELS = 1  # default number of image channels
KERNEL_SIZE = 4  # default kernel size of the (transposed) convolutions
DIM_CHANNEL_MULTIPLIER = 8  # default widest channel count = image dim * this
NORM = "batch"  # default normalization name (see get_norm2d)

# --- training defaults -------------------------------------------------------
BATCH_SIZE = 32

## Utilities

In [None]:
#export
def get_n_samplings(dim):
    """Number of stride-2 resampling stages for a square image of side ``dim``.

    Computed as ``log2(dim) - 2``; the ``- 2`` accounts for the 4x4 (2**2)
    bottleneck feature map the default-kernel generator/discriminator use.
    The logarithm is evaluated in float32 and truncated toward zero, so
    non-power-of-two sizes are rounded down (e.g. 48 -> 3).
    """
    log_dim = torch.log2(torch.tensor(dim, dtype=torch.float32))
    return int(log_dim.item()) - 2

In [None]:
# image side length -> number of stride-2 stages, i.e. log2(dim) - 2
dim = 32
assert get_n_samplings(dim) == 3
for size, expected in [(4, 0), (64, 4), (128, 5)]:
    assert get_n_samplings(size) == expected, (size, expected)

In [None]:
#export
def get_norm2d(name):
    """Return a factory that builds a 2-D normalization layer.

    The returned callable is called with the number of feature channels
    (see ``UpsampleConv2d`` / ``DownsampleConv2d``) and returns a module.

    Args:
        name: one of ``"identity"``, ``"batch"``, ``"instance"`` or ``"layer"``.
            ``"instance"`` uses learnable affine parameters; ``"layer"`` is
            implemented as ``GroupNorm`` with a single group.

    Raises:
        NotImplementedError: if ``name`` is not a supported normalization.
    """
    if name == "identity":
        return Identity
    elif name == "batch":
        return nn.BatchNorm2d
    elif name == "instance":
        return functools.partial(nn.InstanceNorm2d, affine=True)
    elif name == "layer":
        # one group normalizes each sample over all of (C, H, W), i.e. layer norm
        return lambda num_features: nn.GroupNorm(1, num_features)
    else:
        raise NotImplementedError(
            "Unknown norm {!r}; expected one of "
            "'identity', 'batch', 'instance', 'layer'".format(name))

In [None]:
# "identity" resolves to the project's Identity layer class (not an instance)
norm = get_norm2d("identity")
assert norm is Identity
norm

practical_ai.layers.Identity

In [None]:
#export
def get_activation(name):
    """Instantiate a fresh activation module by name.

    Args:
        name: ``"relu"``, ``"leaky_relu"`` (negative slope 0.2) or ``"tanh"``.
            ``"none"``/``"linear"`` are NOT accepted here; callers that want
            no activation must skip calling this function.

    Returns:
        A new ``nn.Module`` instance on every call.

    Raises:
        NotImplementedError: if ``name`` is not a supported activation.
    """
    if name == "relu":
        return nn.ReLU()
    elif name == "leaky_relu":
        return nn.LeakyReLU(0.2)
    elif name == "tanh":
        return nn.Tanh()
    else:
        raise NotImplementedError(
            "Unknown activation {!r}; expected one of "
            "'relu', 'leaky_relu', 'tanh'".format(name))

In [None]:
# each call returns a new activation module instance
act = get_activation("relu")
assert isinstance(act, nn.ReLU)
act

ReLU()

## Modules

In [None]:
#export
class UpsampleConv2d(nn.Sequential):
    """Basic upsampling block: ConvTranspose2d -> Norm -> Activation.

    With the defaults (kernel 4, stride 2, padding 1) the spatial size doubles.

    Args:
        in_channels, out_channels: channel counts of the transposed convolution.
        kernel_size, stride, padding, bias: forwarded to ``nn.ConvTranspose2d``.
        norm: name passed to ``get_norm2d``; ``"none"`` omits the norm layer.
        act: name passed to ``get_activation``; ``"none"`` or ``"linear"``
            omits the activation layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=KERNEL_SIZE,
                 stride=2,
                 padding=1,
                 norm="batch",
                 act="relu",
                 bias=True):
        transposed_conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        stages = [transposed_conv]

        skip_norm = norm == "none"
        if not skip_norm:
            stages.append(get_norm2d(norm)(out_channels))

        is_linear = act in ("none", "linear")
        if not is_linear:
            stages.append(get_activation(act))

        super().__init__(*stages)

一般使用： Transposed Conv -> Norm -> Act

In [None]:
# Shape check for the default upsampling block (kernel 4, stride 2, padding 1).
# NOTE: in_channels / out_channels defined here are reused by the next cell.
in_channels = 3
out_channels = 128
h = w = 32
batch_size = 8

x = torch.randn(batch_size, in_channels, h, w)
upconv = UpsampleConv2d(in_channels, out_channels)
out = upconv(x)

# default block = ConvTranspose2d + norm + activation, and it doubles H and W
assert len(upconv) == 3
assert out.shape == (batch_size, out_channels, h * 2, w * 2)
upconv

UpsampleConv2d(
  (0): ConvTranspose2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

單純上採樣，不使用 Norm 以及非線性 activation

In [None]:
# Both "none" and "linear" disable the activation; with norm="none" as well,
# only the transposed conv remains. Reuses in_channels/out_channels from above.
linear_acts = ["none", "linear"]
for act in linear_acts:
    upconv_linear = UpsampleConv2d(in_channels, out_channels, norm="none", act=act)
    print(upconv_linear)
    assert len(upconv_linear) == 1

UpsampleConv2d(
  (0): ConvTranspose2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
UpsampleConv2d(
  (0): ConvTranspose2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)


In [None]:
#export
class UnsqueezeLatent(nn.Module):
    """Append two trailing singleton axes to a latent vector.

    A ``(batch, latent_dim)`` input becomes ``(batch, latent_dim, 1, 1)`` so it
    can be fed to a 2-D (transposed) convolution as a 1x1 feature map.
    """

    def forward(self, x):
        return x.unsqueeze(-1).unsqueeze(-1)

In [None]:
# UnsqueezeLatent turns (batch, latent) vectors into (batch, latent, 1, 1) maps.
batch_size = 32
latent_dim = 100

x = torch.randn(batch_size, latent_dim)
out = UnsqueezeLatent()(x)
assert out.shape == (batch_size, latent_dim, 1, 1)

In [None]:
#export
class SqueezeLogit(nn.Module):
    """Drop the two trailing spatial axes of a discriminator output.

    ``squeeze(-1)`` is applied twice, so a ``(batch, C, 1, 1)`` map becomes
    ``(batch, C)``. As with ``Tensor.squeeze``, an axis is only removed when
    its size is 1; otherwise the tensor is left unchanged by that step.
    """

    def forward(self, x):
        squeezed = x
        for _ in range(2):
            squeezed = squeezed.squeeze(-1)
        return squeezed

In [None]:
# SqueezeLogit removes trailing 1x1 spatial axes: (batch, dim, 1, 1) -> (batch, dim).
batch_size = 16
dim = 50

x = torch.randn(batch_size, dim, 1, 1)
out = SqueezeLogit()(x)
assert out.shape == (batch_size, dim)

In [None]:
#export
class DownsampleConv2d(nn.Sequential):
    """Basic downsampling block: Conv2d -> Norm -> Activation.

    With the defaults (kernel 4, stride 2, padding 1) the spatial size halves.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv2d``.
        norm: name passed to ``get_norm2d``; ``"none"`` omits the norm layer.
        act: name passed to ``get_activation``; ``"none"`` or ``"linear"``
            omits the activation layer (same convention as ``UpsampleConv2d``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=KERNEL_SIZE,
                 stride=2,
                 padding=1,
                 norm="batch",
                 act="leaky_relu",
                 bias=True):

        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)]

        if norm != "none":
            layers.append(get_norm2d(norm)(out_channels))

        # previously every act was passed to get_activation, so "none"/"linear"
        # raised NotImplementedError; honour them like UpsampleConv2d does
        if act not in ("none", "linear"):
            layers.append(get_activation(act))

        super().__init__(*layers)

In [None]:
# Shape check for the default downsampling block: stride 2 halves H and W.
batch_size = 8
in_channels = 3
out_channels = 64
h = w = 32

x = torch.randn(batch_size, in_channels, h, w)
downconv = DownsampleConv2d(in_channels, out_channels)
out = downconv(x)
# integer division keeps the expected shape integral (h / 2 would be a float)
assert out.shape == (batch_size, out_channels, h // 2, w // 2)

## Generator

In [None]:
#export
class ConvGenerator(nn.Sequential):
    """Generator that upsamples a latent vector to a square image.

    Pipeline for ``x`` of shape ``(batch_size, latent_dim)``:

    1. ``UnsqueezeLatent`` -> ``(batch_size, latent_dim, 1, 1)``.
    2. A stride-1, padding-0 ``UpsampleConv2d`` projects it to a
       ``kernel_size x kernel_size`` map with ``max_channels`` channels.
    3. ``get_n_samplings(out_dim)`` stride-2 ``UpsampleConv2d`` blocks; hidden
       blocks halve the channel count, and the last one maps to
       ``out_channels`` with no norm and a ``tanh`` activation.

    With the default ``kernel_size=4`` and a power-of-two ``out_dim >= 8`` the
    output is ``(batch_size, out_channels, out_dim, out_dim)``.

    Args:
        latent_dim: length of the input latent vector.
        out_dim: side length of the generated image.
        out_channels: channels of the generated image.
        kernel_size: kernel size of every transposed convolution.
        max_channels: channels after the projection layer; defaults to
            ``out_dim * dim_channel_multiplier``.
        norm: normalization name for all but the last block (see ``get_norm2d``).
        act: activation name for all but the last block.
        dim_channel_multiplier: see ``max_channels``.
    """

    def __init__(self,
                 latent_dim=LATENT_DIM,
                 out_dim=DIM,
                 out_channels=CHANNELS,
                 kernel_size=KERNEL_SIZE,
                 max_channels=None,
                 norm=NORM,
                 act="relu",
                 dim_channel_multiplier=DIM_CHANNEL_MULTIPLIER):
        self.latent_dim = latent_dim
        self.out_dim = out_dim
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dim_channel_multiplier = dim_channel_multiplier
        self.norm = norm
        self.act = act
        self.max_channels = max_channels if max_channels else self.out_dim * self.dim_channel_multiplier

        # number of stride-2 upsampling blocks needed to reach out_dim
        self.n_upsamples = get_n_samplings(self.out_dim)

        # project the latent vector to a spatial feature map
        # x.shape == (batch_size, latent_dim)
        layers = [
            UnsqueezeLatent(),
            UpsampleConv2d(in_channels=self.latent_dim,
                           out_channels=self.max_channels,
                           kernel_size=self.kernel_size,
                           stride=1,  # no stride in the projection layer
                           padding=0,  # no padding in the projection layer
                           norm=self.norm,
                           act=self.act)]

        # channel schedule: max_channels, max/2, ..., then out_channels
        # x.shape == (batch_size, max_channels, kernel_size, kernel_size)
        chs = [self.max_channels // (2 ** i) for i in range(self.n_upsamples)]
        chs.append(self.out_channels)

        # Hidden convs drop their bias because a following batch/instance norm
        # subtracts the per-channel mean and cancels it. When normalization is
        # disabled the bias is needed, otherwise those convs have none at all.
        # ("identity" presumably resolves to a pass-through layer.)
        hidden_bias = self.norm in ("none", "identity")

        for i, (in_ch, out_ch) in enumerate(pairwise(chs), 1):
            is_last = i == self.n_upsamples
            layers.append(
                UpsampleConv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=self.kernel_size,
                               stride=2,
                               norm="none" if is_last else self.norm,
                               # final act: tanh -- a bounded activation lets the
                               # model learn quickly to saturate and cover the
                               # colour space of the training distribution
                               act="tanh" if is_last else self.act,
                               bias=True if is_last else hidden_bias))
        # out.shape == (batch_size, out_channels, out_dim, out_dim)

        super().__init__(*layers)

In [None]:
# Generator output shape for several (latent_dim, out_dim, out_channels) combos;
# the generator from the last iteration is displayed below.
batch_size = 8

for latent_dim, out_dim, out_ch in zip([128, 50, 100], [128, 64, 32], [3, 3, 1]):
    x = torch.randn(batch_size, latent_dim)
    g = ConvGenerator(latent_dim=latent_dim, out_dim=out_dim, out_channels=out_ch)
    out = g(x)
    assert out.shape == (batch_size, out_ch, out_dim, out_dim)
g

ConvGenerator(
  (0): UnsqueezeLatent()
  (1): UpsampleConv2d(
    (0): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (2): UpsampleConv2d(
    (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (3): UpsampleConv2d(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (4): UpsampleConv2d(
    (0): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Tanh()
  )
)

## Discriminator

In [None]:
#export
class ConvDiscriminator(nn.Sequential):
    """Discriminator that downsamples a square image to a single logit.

    Exported (``#export``) because the exported ``get_discriminator`` returns it.

    Pipeline for ``x`` of shape ``(batch_size, in_channels, in_dim, in_dim)``:

    1. ``get_n_samplings(in_dim)`` stride-2 ``DownsampleConv2d`` blocks whose
       channel counts double up to ``max_channels``; the first block has no
       normalization and keeps its bias.
    2. A final ``nn.Conv2d(max_channels, 1, kernel_size)`` without padding,
       followed by ``SqueezeLogit``.

    With the default ``kernel_size=4`` and a power-of-two ``in_dim >= 8`` the
    last feature map is ``kernel_size x kernel_size`` and the output is
    ``(batch_size, 1)``.

    Args:
        in_channels: channels of the input image.
        in_dim: side length of the input image.
        norm: normalization name for all but the first block (see ``get_norm2d``).
        kernel_size: kernel size of every convolution.
        max_channels: channels of the last downsampling block; defaults to
            ``in_dim * dim_channel_multiplier``.
        dim_channel_multiplier: see ``max_channels``.
    """

    def __init__(self,
                 in_channels=CHANNELS,
                 in_dim=DIM,
                 norm=NORM,
                 kernel_size=KERNEL_SIZE,
                 max_channels=None,
                 dim_channel_multiplier=DIM_CHANNEL_MULTIPLIER):
        self.in_channels = in_channels
        self.in_dim = in_dim
        self.norm = norm
        self.kernel_size = kernel_size
        self.n_downsamples = get_n_samplings(self.in_dim)
        self.dim_channel_multiplier = dim_channel_multiplier
        self.max_channels = max_channels if max_channels else self.in_dim * self.dim_channel_multiplier

        # channel schedule: in_channels, max/2**(n-1), ..., max_channels
        chs = [self.in_channels]
        chs += sorted([self.max_channels // (2 ** i) for i in range(self.n_downsamples)])

        # Hidden convs drop their bias only when a normalization layer follows
        # and cancels it; with normalization disabled they need the bias.
        # ("identity" presumably resolves to a pass-through layer.)
        hidden_bias = self.norm in ("none", "identity")

        # x.shape == (batch_size, in_channels, in_dim, in_dim)
        layers = []
        for i, (in_ch, out_ch) in enumerate(pairwise(chs), 1):
            is_first = i == 1
            layers.append(
                DownsampleConv2d(in_ch,
                                 out_ch,
                                 self.kernel_size,
                                 stride=2,
                                 norm="none" if is_first else self.norm,
                                 bias=True if is_first else hidden_bias))

        # compute logits
        # x.shape == (batch_size, max_channels, kernel_size, kernel_size)
        layers.extend([
            nn.Conv2d(chs[-1], 1, kernel_size=self.kernel_size),
            SqueezeLogit()
        ])
        # out.shape == (batch_size, 1)

        super().__init__(*layers)

In [None]:
# Discriminator emits one logit per image for several (channels, size) combos;
# the discriminator from the last iteration is displayed below.
batch_size = 8

for in_ch, in_dim in zip([3, 3, 1], [128, 64, 32]):
    x = torch.randn(batch_size, in_ch, in_dim, in_dim)
    d = ConvDiscriminator(in_ch, in_dim)
    out = d(x)
    assert out.shape == (batch_size, 1)
d

ConvDiscriminator(
  (0): DownsampleConv2d(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
  )
  (1): DownsampleConv2d(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (2): DownsampleConv2d(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (3): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1))
  (4): SqueezeLogit()
)

## Generative Adversarial Network

In [None]:
#export
def get_generater(_type):
    """Return the generator class registered under ``_type``.

    Args:
        _type: architecture name; only ``"conv"`` (``ConvGenerator``) exists.

    Raises:
        NotImplementedError: for any other name.
    """
    if _type == "conv":
        return ConvGenerator
    else:
        raise NotImplementedError(
            "Unknown generator type {!r}; expected 'conv'".format(_type))


# Correctly spelled alias; the original ``get_generater`` name is kept so
# existing callers keep working.
get_generator = get_generater


def get_discriminator(_type):
    """Return the discriminator class registered under ``_type``.

    Args:
        _type: architecture name; only ``"conv"`` (``ConvDiscriminator``) exists.

    Raises:
        NotImplementedError: for any other name.
    """
    if _type == "conv":
        return ConvDiscriminator
    else:
        raise NotImplementedError(
            "Unknown discriminator type {!r}; expected 'conv'".format(_type))

In [None]:
#export
class GAN(pl.LightningModule):
    """Basic generative adversarial network built on PyTorch Lightning.

    Connects a generator and a discriminator (chosen by name via
    ``get_generater`` / ``get_discriminator``) to the adversarial losses from
    ``get_adversarial_losses_fns``. Both networks are trained with their own
    Adam optimizer: optimizer index 0 is the discriminator, 1 the generator.

    Args:
        generator_type, discriminator_type: architecture names (``"conv"``).
        dataset_name: passed to ``get_dataset`` as ``dataset``.
        adversarial_loss_type: passed to ``get_adversarial_losses_fns``.
        batch_size: training batch size.
        lr, beta1, beta2: Adam hyper-parameters shared by both optimizers.
        latent_dim: length of the noise vector fed to the generator.
        image_shape: ``(channels, dim, dim)``; height and width must match.
        kernel_size, norm: forwarded to both networks.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 generator_type="conv",
                 discriminator_type="conv",
                 dataset_name="mnist",
                 adversarial_loss_type="gan",
                 batch_size=BATCH_SIZE,
                 lr=2e-4,
                 beta1=0.5,
                 beta2=0.999,
                 latent_dim=LATENT_DIM,
                 image_shape=(CHANNELS, DIM, DIM),
                 kernel_size=KERNEL_SIZE,
                 norm=NORM,
                 **kwargs):
        super().__init__()
        self.generator_type = generator_type
        self.discriminator_type = discriminator_type
        self.dataset_name = dataset_name
        self.adversarial_loss_type = adversarial_loss_type
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.image_shape = image_shape
        self.kernel_size = kernel_size
        self.norm = norm

        # adversarial losses
        self.g_loss_fn, self.d_loss_fn = \
            get_adversarial_losses_fns(self.adversarial_loss_type)

        # only square images are supported
        assert self.image_shape[-1] == self.image_shape[-2]
        self.channels, self.dim = self.image_shape[0], self.image_shape[-1]

        # initialize networks
        g = get_generater(self.generator_type)
        self.generator = g(latent_dim=self.latent_dim,
                           out_dim=self.dim,
                           out_channels=self.channels,
                           kernel_size=self.kernel_size,
                           norm=self.norm)

        d = get_discriminator(self.discriminator_type)
        self.discriminator = d(in_channels=self.channels,
                               in_dim=self.dim,
                               norm=self.norm,
                               kernel_size=self.kernel_size)

        # caches: last generated batch (detached) and last real batch
        self.generated_images = None
        self.last_real_images = None

    def prepare_data(self):
        """Build the train/valid datasets resized to ``(dim, dim)`` without labels."""
        # NOTE(review): depending on the Lightning version, prepare_data may run
        # on a single process only, so state assigned here can be missing in
        # distributed runs -- confirm. valid_dataset is currently unused.
        self.train_dataset = get_dataset(dataset=self.dataset_name,
                                         split="train",
                                         size=(self.dim, self.dim),
                                         return_label=False)

        self.valid_dataset = get_dataset(dataset=self.dataset_name,
                                         split="valid",
                                         size=(self.dim, self.dim),
                                         return_label=False)

    def train_dataloader(self):
        return get_data_loader(self.train_dataset, batch_size=self.batch_size)

    def configure_optimizers(self):
        """One Adam optimizer per network; order fixes ``optimizer_idx`` (0 = D, 1 = G)."""
        self.g_optim = torch.optim.Adam(self.generator.parameters(),
                                        lr=self.lr,
                                        betas=(self.beta1, self.beta2))
        self.d_optim = torch.optim.Adam(self.discriminator.parameters(),
                                        lr=self.lr,
                                        betas=(self.beta1, self.beta2))
        return [self.d_optim, self.g_optim], []

    def get_latent_vectors(self, n=8, on_gpu=True):
        """Sample ``n`` standard-normal latent vectors of shape ``(n, latent_dim)``.

        When ``on_gpu`` is true the vectors are moved to the device holding the
        generator's parameters. This does not depend on ``last_real_images``,
        which is ``None`` until the first training step.
        """
        z = torch.randn(n, self.latent_dim)
        if on_gpu:
            z = z.to(next(self.generator.parameters()).device)
        return z

    def training_step(self, batch, batch_idx, optimizer_idx):
        """One optimization step for the discriminator (idx 0) or generator (idx 1).

        ``batch`` is the tensor of real images (the datasets are built with
        ``return_label=False``). One latent vector is drawn per real image so
        the real and fake batches have the same size.
        """
        self.last_real_images = real_images = batch
        z = self.get_latent_vectors(n=real_images.size(0), on_gpu=self.on_gpu)

        # discriminator
        if optimizer_idx == 0:
            # detach: the discriminator step must not backprop into the generator
            fake_images = self.generator(z).detach()
            real_logits = self.discriminator(real_images)
            fake_logits = self.discriminator(fake_images)

            d_real_loss, d_fake_loss = self.d_loss_fn(real_logits, fake_logits,
                                                      on_gpu=self.on_gpu)
            d_loss = d_real_loss + d_fake_loss

            # TODO: gradient penalty

            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

        # generator
        if optimizer_idx == 1:
            fake_images = self.generator(z)
            # cache a graph-free copy so the autograd graph is not kept alive
            self.generated_images = fake_images.detach()
            fake_logits = self.discriminator(fake_images)
            g_loss = self.g_loss_fn(fake_logits)

            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

    def forward(self, z):
        """Generate images from latent vectors ``z``."""
        return self.generator(z)

    def on_epoch_end(self):
        """Log a grid of 8 generator samples via ``self.logger.experiment.add_image``."""
        z = self.get_latent_vectors(on_gpu=self.on_gpu)
        with torch.no_grad():
            sample_images = self.generator(z)
        # generator ends in tanh (values in [-1, 1]); min-max rescale to [0, 1]
        # so the logged image is not clipped
        grid = torchvision.utils.make_grid(sample_images, normalize=True)
        self.logger.experiment.add_image('sample_images', grid, self.current_epoch)

In [None]:
#server
from argparse import Namespace


# Optional overrides for GAN's constructor. Keys must match GAN.__init__
# parameter names (e.g. batch_size, lr, beta1, beta2, latent_dim).
args = {
#     'batch_size': 32,
#     'lr': 0.0002,
#     'beta1': 0.5,
#     'beta2': 0.999,
#     'latent_dim': 100
}
hparams = Namespace(**args)

# forward the overrides; with an empty dict this is GAN() with its defaults
gan = GAN(**vars(hparams))

# most basic trainer: one GPU, everything else left at Lightning's defaults
trainer = pl.Trainer(gpus=1)
trainer.fit(gan)

INFO:root:GPU available: True, used: True
INFO:root:VISIBLE GPUS: 0
INFO:root:MNIST will be resized to (32, 32).
INFO:root:MNIST will be resized to (32, 32).
INFO:root:
   | Name              | Type              | Params
----------------------------------------------------
0  | generator         | ConvGenerator     | 1 M   
1  | generator.0       | UnsqueezeLatent   | 0     
2  | generator.1       | UpsampleConv2d    | 525 K 
3  | generator.1.0     | ConvTranspose2d   | 524 K 
4  | generator.1.1     | BatchNorm2d       | 512   
5  | generator.1.2     | ReLU              | 0     
6  | generator.2       | UpsampleConv2d    | 524 K 
7  | generator.2.0     | ConvTranspose2d   | 524 K 
8  | generator.2.1     | BatchNorm2d       | 256   
9  | generator.2.2     | ReLU              | 0     
10 | generator.3       | UpsampleConv2d    | 131 K 
11 | generator.3.0     | ConvTranspose2d   | 131 K 
12 | generator.3.1     | BatchNorm2d       | 128   
13 | generator.3.2     | ReLU              | 0    

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

INFO:root:Detected KeyboardInterrupt, attempting graceful shutdown...





1

In [None]:
#hide
# Run nbdev's export so `#export` cells are written to the library module
# (this notebook targets `gan` via `#default_exp`). NOTE(review): called with
# no argument, nbdev presumably exports every notebook in the project --
# confirm. The cell's output is then cleared.
notebook2script()
clear_output()