In [None]:
#default_exp gan

In [None]:
#hide
from IPython.display import clear_output
from nbdev.export import notebook2script
from dotenv import load_dotenv
%reload_ext autoreload
%autoreload 2

# Load variables from a local `.env` file into the process environment (python-dotenv);
# binding the result to `_` keeps the cell from echoing it.
_ = load_dotenv()

In [None]:
#export
import torch
import logging
import functools
from torch import nn
from more_itertools import pairwise
from practical_ai.layers import Identity


# Module-level logger (not used by the definitions visible in this notebook section).
# NOTE(review): `logging.getLogger()` without a name returns the *root* logger, so
# importing the exported module sets the root level to INFO for the whole process.
# A module logger (`logging.getLogger(__name__)`) would avoid that -- confirm no
# caller relies on the root-level side effect before changing it.
logger = logging.getLogger()
logger.setLevel("INFO")

# gan

> 生成對抗網路（Generative Adversarial Network, GAN）的相關模組。

In [None]:
#export 
# When `max_channels` is not given, ConvGenerator / ConvDiscriminator use
# `image side length * DIM_CHANNEL_MULTIPLIER` channels for their widest layer.
DIM_CHANNEL_MULTIPLIER = 8
# Default kernel size of the up/downsampling convolutions; with stride 2 and
# padding 1, a 4x4 kernel exactly doubles (transposed conv) or halves (conv,
# even-sized input) the spatial size.
KERNEL_SIZE = 4

## Utilities

In [None]:
#export
def get_n_samplings(dim):
    """Return the number of stride-2 sampling stages between a 4x4 map and a `dim` x `dim` image.

    Computed as ``int(log2(dim)) - 2`` in float32, e.g. 8 -> 1, 32 -> 3, 128 -> 5.
    Sizes that are not powers of two are rounded down (48 -> 3).
    """
    exponent = torch.tensor(dim, dtype=torch.float32).log2().item()
    return int(exponent) - 2

In [None]:
# A side of 2**k needs k - 2 sampling stages from a 4x4 base feature map.
dim = 32
assert get_n_samplings(dim) == 3
assert [get_n_samplings(d) for d in (8, 64, 128)] == [1, 4, 5]

In [None]:
#export
def get_norm2d(name):
    """Return a factory that builds a 2D normalization layer from a channel count.

    Callers in this module invoke the result as ``factory(num_features)``.

    Args:
        name: ``"identity"``, ``"batch"``, ``"instance"`` or ``"layer"``.

    Returns:
        ``"identity"`` -> the ``Identity`` layer class from ``practical_ai.layers``;
        ``"batch"`` -> ``nn.BatchNorm2d``;
        ``"instance"`` -> ``nn.InstanceNorm2d`` with learnable affine parameters;
        ``"layer"`` -> a callable building ``nn.GroupNorm`` with a single group.

    Raises:
        NotImplementedError: if `name` is not one of the supported options.
    """
    if name == "identity":
        return Identity
    elif name == "batch":
        return nn.BatchNorm2d
    elif name == "instance":
        return functools.partial(nn.InstanceNorm2d, affine=True)
    elif name == "layer":
        # GroupNorm with one group normalizes over (C, H, W) per sample, i.e. a LayerNorm
        # that does not need the spatial size up front.
        return lambda num_features: nn.GroupNorm(1, num_features)
    raise NotImplementedError(
        f"unsupported norm {name!r}; expected one of 'identity', 'batch', 'instance', 'layer'")

In [None]:
# "identity" returns the Identity layer class itself (a factory), not an instance.
norm = get_norm2d("identity")
assert norm is Identity
norm

practical_ai.layers.Identity

In [None]:
#export
def get_activation(name):
    """Return a newly constructed activation module.

    Args:
        name: ``"relu"``, ``"leaky_relu"`` (negative slope 0.2) or ``"tanh"``.

    Returns:
        A fresh ``nn.Module`` instance on every call.

    Raises:
        NotImplementedError: if `name` is not one of the supported options.
    """
    if name == "relu":
        return nn.ReLU()
    elif name == "leaky_relu":
        return nn.LeakyReLU(0.2)
    elif name == "tanh":
        return nn.Tanh()
    raise NotImplementedError(
        f"unsupported activation {name!r}; expected one of 'relu', 'leaky_relu', 'tanh'")

In [None]:
# get_activation returns a ready-to-use module instance.
act = get_activation("relu")
assert isinstance(act, nn.ReLU)
act

ReLU()

## Modules

In [None]:
#export
class UpsampleConv2d(nn.Sequential):
    """Upsampling block: ConvTranspose2d -> normalization -> activation.

    The normalization is skipped when ``norm == "none"``, and the activation is skipped
    when ``act`` is ``"none"`` or ``"linear"``, leaving a bare transposed convolution.
    With the defaults (4x4 kernel, stride 2, padding 1) the spatial size doubles.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
        norm: name accepted by ``get_norm2d`` (built with ``out_channels``), or ``"none"``.
        act: name accepted by ``get_activation``, or ``"none"`` / ``"linear"``.
        bias: whether the transposed convolution has a bias term.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=KERNEL_SIZE,
                 stride=2,
                 padding=1,
                 norm="batch",
                 act="relu",
                 bias=True):
        skip_norm = norm == "none"
        skip_act = act in ("none", "linear")

        blocks = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                     stride, padding, bias=bias)]
        if not skip_norm:
            blocks.append(get_norm2d(norm)(out_channels))
        if not skip_act:
            blocks.append(get_activation(act))

        super().__init__(*blocks)

一般用法：Transposed Conv -> Norm -> Activation（預設為 BatchNorm 與 ReLU，空間尺寸放大兩倍）

In [None]:
# Default UpsampleConv2d: ConvTranspose2d -> BatchNorm2d -> ReLU, doubling H and W.
in_channels = 3
out_channels = 128
h = w = 32
batch_size = 8

x = torch.randn(batch_size, in_channels, h, w)
upconv = UpsampleConv2d(in_channels, out_channels)
out = upconv(x)

assert len(upconv) == 3  # transposed conv, norm, activation
assert out.shape == (batch_size, out_channels, h * 2, w * 2)
upconv

UpsampleConv2d(
  (0): ConvTranspose2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

單純上採樣，不使用 Norm 以及非線性 activation

In [None]:
# With norm="none" and act "none" or "linear", the block is only the transposed conv.
linear_acts = ["none", "linear"]
for act in linear_acts:
    upconv_linear = UpsampleConv2d(in_channels, out_channels, norm="none", act=act)
    print(upconv_linear)
    assert len(upconv_linear) == 1

UpsampleConv2d(
  (0): ConvTranspose2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
UpsampleConv2d(
  (0): ConvTranspose2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)


In [None]:
#export
class UnsqueezeLatent(nn.Module):
    """Append two singleton spatial dims: ``(..., C)`` -> ``(..., C, 1, 1)``.

    Lets a flat latent batch of shape ``(batch, latent_dim)`` enter a transposed
    convolution as a 1x1 feature map.
    """

    def forward(self, x):
        return x.unsqueeze(-1).unsqueeze(-1)

In [None]:
# UnsqueezeLatent turns (batch, latent_dim) into (batch, latent_dim, 1, 1).
batch_size = 32
latent_dim = 100

x = torch.randn(batch_size, latent_dim)
out = UnsqueezeLatent()(x)
assert out.shape == (batch_size, latent_dim, 1, 1)

In [None]:
#export
class SqueezeLogit(nn.Module):
    """Remove trailing singleton dims so ``(batch, C, 1, 1)`` becomes ``(batch, C)``.

    The last dimension is squeezed twice in a row; each squeeze is a no-op when that
    dimension's size is not 1.
    """

    def forward(self, x):
        squeezed = x
        for _ in range(2):
            squeezed = squeezed.squeeze(-1)
        return squeezed

In [None]:
# SqueezeLogit drops the two trailing singleton dims: (batch, dim, 1, 1) -> (batch, dim).
batch_size = 16
dim = 50

x = torch.randn(batch_size, dim, 1, 1)
out = SqueezeLogit()(x)
assert out.shape == (batch_size, dim)

In [None]:
#export
class DownsampleConv2d(nn.Sequential):
    """Downsampling block: Conv2d -> normalization -> activation.

    The normalization is skipped when ``norm == "none"`` and the activation is skipped
    when ``act`` is ``"none"`` or ``"linear"`` (same convention as ``UpsampleConv2d``).
    With the defaults (4x4 kernel, stride 2, padding 1) even spatial sizes are halved.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        norm: name accepted by ``get_norm2d`` (built with ``out_channels``), or ``"none"``.
        act: name accepted by ``get_activation``, or ``"none"`` / ``"linear"``.
        bias: whether the convolution has a bias term.
    """
    
    def __init__(self, 
                 in_channels,
                 out_channels,
                 kernel_size=KERNEL_SIZE,
                 stride=2,
                 padding=1,
                 norm="batch",
                 act="leaky_relu",
                 bias=True):
        
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)]
        
        if norm != "none":
            layers.append(get_norm2d(norm)(out_channels))
        
        # "none"/"linear" mean no activation, matching UpsampleConv2d; passing them to
        # get_activation would raise NotImplementedError.
        if act not in ["none", "linear"]:
            layers.append(get_activation(act))
        
        super().__init__(*layers)

基本下採樣：Conv -> Norm -> Activation（預設為 BatchNorm 與 LeakyReLU，空間尺寸減半）

In [None]:
# Default DownsampleConv2d: Conv2d -> BatchNorm2d -> LeakyReLU, halving H and W.
batch_size = 8
in_channels = 3
out_channels = 64
h = w = 32

x = torch.randn(batch_size, in_channels, h, w)
downconv = DownsampleConv2d(in_channels, out_channels)
out = downconv(x)
assert len(downconv) == 3  # conv, norm, activation
assert out.shape == (batch_size, out_channels, h // 2, w // 2)

## Generator

In [None]:
#export
class ConvGenerator(nn.Sequential):
    """Generator that upsamples a random latent vector into an image of a given size.

    Layout: ``UnsqueezeLatent`` turns ``(batch, latent_dim)`` into a 1x1 map; a stride-1,
    unpadded transposed conv projects it to ``max_channels`` maps of size
    ``kernel_size x kernel_size``; then ``get_n_samplings(out_dim)`` stride-2
    ``UpsampleConv2d`` blocks follow, halving the channel count each time (with the default
    4x4 kernel each block doubles H and W). The last block maps to ``out_channels`` with
    ``tanh``, no normalization and a conv bias; hidden blocks use ``norm``/``act`` and
    bias-free convs.

    Args:
        latent_dim: size of the input noise vector.
        out_dim: side length of the square output image; must be a power of two >= 8.
        out_channels: number of image channels.
        kernel_size: kernel size of every transposed conv. NOTE(review): output sizes only
            line up for 4 (the default), because the stride-2 blocks use padding=1 and the
            stage count assumes a 4x4 base map.
        max_channels: channels of the first feature map; defaults to
            ``out_dim * dim_channel_multiplier``.
        norm: normalization name for hidden blocks (see ``get_norm2d``).
        act: activation name for hidden blocks (see ``get_activation``).
        dim_channel_multiplier: used to derive the default ``max_channels``.

    Raises:
        ValueError: if ``out_dim`` is not a power of two >= 8 (otherwise the network would
            silently produce images of a different size or channel count).
    """

    def __init__(self,
                 latent_dim=128,
                 out_dim=32,
                 out_channels=1,
                 kernel_size=KERNEL_SIZE,
                 max_channels=None,
                 norm='batch',
                 act="relu",
                 dim_channel_multiplier=DIM_CHANNEL_MULTIPLIER):
        self.latent_dim = latent_dim
        self.out_dim = out_dim
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dim_channel_multiplier = dim_channel_multiplier
        self.norm = norm
        self.act = act
        self.max_channels = max_channels if max_channels else self.out_dim * self.dim_channel_multiplier

        # decide the number of upsampling stages from the expected output size:
        # a 4x4 base map doubled n times gives 2 ** (n + 2)
        self.n_upsamples = get_n_samplings(self.out_dim)
        if self.n_upsamples < 1 or 2 ** (self.n_upsamples + 2) != self.out_dim:
            raise ValueError(
                f"out_dim must be a power of two >= 8, got {self.out_dim!r}")

        # project the latent vector to a spatial feature map
        # x.shape == (batch_size, latent_dim)
        layers = [
            UnsqueezeLatent(),
            UpsampleConv2d(in_channels=self.latent_dim,
                           out_channels=self.max_channels,
                           kernel_size=self.kernel_size,
                           stride=1,  # no need to stride in first layer
                           padding=0,  # no padding in first layer
                           norm=self.norm,
                           act=self.act)]

        # upsamples
        # x.shape == (batch_size, max_channels, kernel_size, kernel_size)
        chs = [self.max_channels // (2 ** i) for i in range(self.n_upsamples)]
        chs.append(self.out_channels)

        for i, (in_ch, out_ch) in enumerate(pairwise(chs), 1):
            is_last = i == self.n_upsamples
            layers.append(
                UpsampleConv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=self.kernel_size,
                               stride=2,
                               norm="none" if is_last else self.norm,
                               # final act: tanh. Using a bounded activation allowed the model
                               # to learn more quickly to saturate and cover the color space
                               # of the training distribution.
                               act="tanh" if is_last else self.act,
                               # hidden convs are bias-free; the output conv keeps its bias
                               bias=is_last))
        # out.shape == (batch_size, out_channels, out_dim, out_dim)

        super().__init__(*layers)

In [None]:
# Output shape must be (batch, out_channels, out_dim, out_dim) for several latent sizes,
# resolutions and channel counts; the last generator built (latent 100, 32x32, 1 channel)
# is displayed.
batch_size = 8

for latent_dim, out_dim, out_ch in zip([128, 50, 100], [128, 64, 32], [3, 3, 1]):
    x = torch.randn(batch_size, latent_dim)
    g = ConvGenerator(latent_dim=latent_dim, out_dim=out_dim, out_channels=out_ch)
    out = g(x)
    assert out.shape == (batch_size, out_ch, out_dim, out_dim)
g

ConvGenerator(
  (0): UnsqueezeLatent()
  (1): UpsampleConv2d(
    (0): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (2): UpsampleConv2d(
    (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (3): UpsampleConv2d(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (4): UpsampleConv2d(
    (0): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Tanh()
  )
)

## Discriminator

In [None]:
#export
class ConvDiscriminator(nn.Sequential):
    """Discriminator that downsamples a square image of a given size to one logit per image.

    Layout: ``get_n_samplings(in_dim)`` stride-2 ``DownsampleConv2d`` blocks with channel
    widths ascending up to ``max_channels`` (with the default 4x4 kernel each halves H and W),
    then an unpadded ``kernel_size`` conv to a single channel and ``SqueezeLogit``. The first
    block sees the raw image and has no normalization but keeps its conv bias; later blocks
    use ``norm`` with bias-free convs.

    Args:
        in_channels: number of image channels.
        in_dim: side length of the square input image; must be a power of two >= 4.
        norm: normalization name for all but the first block (see ``get_norm2d``).
        kernel_size: kernel size of every conv. NOTE(review): the final conv only reduces
            the map to 1x1 when the remaining map is ``kernel_size`` wide, i.e. for 4.
        max_channels: widest feature map; defaults to ``in_dim * dim_channel_multiplier``.
        dim_channel_multiplier: used to derive the default ``max_channels``.
        act: activation name for every downsampling block (see ``get_activation``);
            defaults to ``"leaky_relu"``, the ``DownsampleConv2d`` default.

    Raises:
        ValueError: if ``in_dim`` is not a power of two >= 4 (otherwise the output would
            silently not have shape ``(batch, 1)``).
    """
    
    def __init__(self, 
                 in_channels=1, 
                 in_dim=32, 
                 norm="batch",
                 kernel_size=KERNEL_SIZE,
                 max_channels=None,
                 dim_channel_multiplier=DIM_CHANNEL_MULTIPLIER,
                 act="leaky_relu"):
        self.in_channels = in_channels
        self.in_dim = in_dim
        self.norm = norm
        self.act = act
        self.kernel_size = kernel_size
        self.n_downsamples = get_n_samplings(self.in_dim)
        self.dim_channel_multiplier = dim_channel_multiplier
        self.max_channels = max_channels if max_channels else self.in_dim * self.dim_channel_multiplier

        # n halvings of a power-of-two side must leave the 4x4 map the final conv expects
        if self.n_downsamples < 0 or 2 ** (self.n_downsamples + 2) != self.in_dim:
            raise ValueError(
                f"in_dim must be a power of two >= 4, got {self.in_dim!r}")

        # downsample: in_channels, then ascending widths ending at max_channels
        chs = [self.in_channels]
        chs += sorted([self.max_channels // (2 ** i) for i in range(self.n_downsamples)])

        # x.shape == (batch_size, in_channels, in_dim, in_dim)
        layers = [
            DownsampleConv2d(in_ch, 
                             out_ch, 
                             self.kernel_size, 
                             stride=2, 
                             # no normalization on the first block, which sees the raw image
                             norm=self.norm if i != 1 else "none",
                             act=self.act,
                             bias=False if i != 1 else True)
            for i, (in_ch, out_ch) in enumerate(pairwise(chs), 1)]
        
        # compute logits
        # x.shape == (batch_size, max_channels, kernel_size, kernel_size)
        layers.extend([
            nn.Conv2d(chs[-1], 1, kernel_size=self.kernel_size),
            SqueezeLogit()
        ])
        # out.shape == (batch_size, 1)
        
        super().__init__(*layers)

In [None]:
# Each image must be reduced to a single logit, (batch, 1), for several sizes and
# channel counts; the last discriminator built (1 channel, 32x32) is displayed.
batch_size = 8

for in_ch, in_dim in zip([3, 3, 1], [128, 64, 32]):
    x = torch.randn(batch_size, in_ch, in_dim, in_dim)
    d = ConvDiscriminator(in_ch, in_dim)
    out = d(x)
    assert out.shape == (batch_size, 1)
d

ConvDiscriminator(
  (0): DownsampleConv2d(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
  )
  (1): DownsampleConv2d(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (2): DownsampleConv2d(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (3): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1))
  (4): SqueezeLogit()
)

In [None]:
# from practical_ai.data import get_data_loader, get_dataset
# dataset = get_dataset("mnist", split="train")
# data_loader = get_data_loader(dataset, batch_size=4)



# from torch.utils.tensorboard import SummaryWriter
# import torchvision

# # default `log_dir` is "runs" - we'll be more specific here
# writer = SummaryWriter('runs/test')

# dataiter = iter(data_loader)
# images, labels = next(dataiter)

# img_grid = torchvision.utils.make_grid(images)
# writer.add_image('four_mnist_images', img_grid)

# images.shape

# z = torch.randn(1, latent_dim)  # ConvGenerator takes (batch, latent_dim) and adds the 1x1 spatial dims itself

# writer.add_graph(g, z)
# writer.close()

## Generative Adversarial Network

In [None]:
#hide
# nbdev: regenerate the library modules from the `#export` cells, then clear this
# cell's output so the export log is not saved in the notebook.
notebook2script()
clear_output()