# Progressive growing of GANs prototype

- smooth transitions
- minibatch STD
- pixel-wise feature vector normalisation (layer implemented below, but not used in the networks)
- spectral normalisation
- He initialisation
- chrome tire rim
- 7.1 Dolby Surround

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules import conv
from torch.nn.modules.utils import _pair
import torch.optim as optim
from torch.autograd import Variable, grad
import torch.nn.functional as F
from torchvision  import transforms, datasets

In [2]:
class MinibatchSDLayer(nn.Module):
    """Append one constant feature map holding the mean minibatch std.

    The standard deviation of every feature at every spatial position is
    computed across the batch dimension, averaged down to a single scalar,
    and that scalar is broadcast into one extra channel that is concatenated
    to the input.
    """

    def __init__(self):
        super(MinibatchSDLayer, self).__init__()

    def forward(self, x):
        """Return ``x`` with one additional channel.

        Args:
            x: 4-D tensor laid out as (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, channels + 1, height, width).
        """
        n_samples, _, height, width = x.size()
        # With a single sample the unbiased estimator divides by zero and
        # yields NaN; fall back to the biased estimator (which gives 0) there.
        mean_batch_std = x.std(0, unbiased=(n_samples > 1)).mean()
        # Height and width are taken separately so non-square maps work too.
        mean_batch_std = mean_batch_std.view(1, 1, 1, 1).expand(n_samples, 1, height, width)
        return torch.cat([x, mean_batch_std], 1)
    
    
class PixelWiseFeatureNormLayer(nn.Module):
    """Scale the feature vector at every pixel to unit L2 norm.

    Args:
        eps: small constant added to the norm so an all-zero feature vector
            does not cause a division by zero.
    """

    def __init__(self, eps=1e-8):
        super(PixelWiseFeatureNormLayer, self).__init__()
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its L2 norm taken over dimension 1.

        ``keepdim=True`` keeps the reduced channel axis as size 1 so the
        division broadcasts per pixel; the result has the same shape as ``x``.
        """
        return x / (x.norm(2, 1, keepdim=True) + self.eps)
    
    
def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its norm; ``eps`` keeps the divisor non-zero."""
    divisor = v.norm() + eps
    return v / divisor

def max_singular_value(W, u=None, Ip=1):
    """
    Estimate the largest singular value of a weight matrix by power iteration.

    W  -- 2-D tensor; its first dimension must match the second dimension of u
    u  -- optional (1, W.size(0)) starting vector; a random one is drawn when
          omitted
    Ip -- number of power-iteration steps, at least 1

    Returns the pair (sigma, u): the singular-value estimate and the updated
    left vector, which the caller can pass back in on the next call.

    Raises ValueError if Ip is smaller than 1.
    """
    if Ip < 1:
        # Without at least one step the right vector is never computed.
        raise ValueError("Ip must be at least 1, got {}".format(Ip))

    if u is None:
        # Allocate from W so the vector shares its tensor type and device.
        u = W.new(1, W.size(0)).normal_()

    _u = u
    for _ in range(Ip):
        _v = l2normalize(torch.matmul(_u, W), eps=1e-12)
        _u = l2normalize(torch.matmul(_v, torch.transpose(W, 0, 1)), eps=1e-12)

    # sigma = u W v^T
    sigma = torch.sum(torch.matmul(_u, W) * _v)
    return sigma, _u
    
    
class SNConv2d(nn.Conv2d):
    """2-D convolution whose weight is divided by a spectral-norm estimate.

    The constructor takes the same arguments as ``nn.Conv2d``. Subclassing the
    public ``nn.Conv2d`` (itself a ``_ConvNd``) avoids calling the private
    ``_ConvNd`` constructor, whose positional signature is not stable across
    PyTorch versions.

    ``u`` holds the running power-iteration vector. It is stored as a
    ``Parameter`` with ``requires_grad=False``, so it shows up in
    ``parameters()`` and must be filtered out when building an optimiser.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(SNConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.u = nn.Parameter(torch.Tensor(1, out_channels).normal_(), requires_grad=False)

    @property
    def W_(self):
        """Weight divided by the current estimate of its largest singular value.

        Every access runs one power-iteration step and stores the updated
        vector back into ``u``. The estimate is computed on ``.data``, so it is
        treated as a constant by autograd.
        """
        # Flatten to (out_channels, everything else) for the matrix estimate.
        w_mat = self.weight.view(self.weight.size(0), -1).data
        sigma, _u = max_singular_value(w_mat, self.u.data)
        self.u.data = _u
        return self.weight / sigma

    def forward(self, input):
        """Convolve ``input`` with the normalised weight ``W_``."""
        return F.conv2d(input, self.W_, self.bias, self.stride, self.padding, self.dilation, self.groups)

In [3]:
class GrowingGenerator(nn.Module):
    """Generator that starts at 4x4 and doubles its output size on ``grow()``.

    Args:
        zdim: number of channels of the latent input, shaped (batch, zdim, 1, 1).
        init_size: resolution recorded as the starting ``current_size``.
        final_size: resolution at which ``grow()`` stops adding blocks.
        n_feature_maps: the first block uses ``8 * n_feature_maps`` channels.

    Attributes set for callers:
        transitioning: True while the newest block is being faded in.
        alpha: weight of the new branch during a transition (0 = old only).
        new_parameters: list of the parameters created by the last ``grow()``
            (the new block and the new ``to_rgb``), ready to be passed to
            ``optimizer.add_param_group``; empty if nothing was added.
    """

    def __init__(self, zdim=100, init_size=4, final_size=128, n_feature_maps=128):
        super(GrowingGenerator, self).__init__()

        self.init_size = init_size
        self.final_size = final_size
        init_nfm = 8*n_feature_maps

        self.layers = [
            #1x1
            nn.ConvTranspose2d(zdim, init_nfm, 4, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            #4x4
            nn.Conv2d(init_nfm, init_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
            #4x4
        ]
        self.main = nn.Sequential(*self.layers)
        self.old = self.main

        self.to_rgb = nn.Conv2d(init_nfm, 3, 1, 1, 0, bias=False)
        self.old_rgb = self.to_rgb
        self.current_size = init_size
        self.current_nfm = init_nfm

        self.transitioning = False
        # Defined up front so ``alpha`` is readable before the first grow().
        self._alpha = 1
        self.new_parameters = []

    @property
    def alpha(self):
        return self._alpha

    @alpha.setter
    def alpha(self, v):
        # Reaching 1 means the new branch carries all the weight: the
        # transition is over and the value is clamped so it cannot overshoot.
        if v >= 1:
            self.transitioning = False
            v = 1
        self._alpha = v

    def forward(self, x):
        """Map latent ``x`` to an image in [-1, 1] at the current resolution.

        While transitioning, the output is a blend of the new network and the
        previous network's output upsampled by a factor of two.
        """
        if self.transitioning:
            new = self.main(x)
            new = self.to_rgb(new)
            # F.upsample is the older name of F.interpolate; use whichever exists.
            upsample = getattr(F, 'interpolate', None) or F.upsample
            old = upsample(self.old(x), scale_factor=2)
            old = self.old_rgb(old)
            x = self.alpha*new + (1-self.alpha)*old

        else:
            x = self.main(x)
            x = self.to_rgb(x)

        return torch.tanh(x)

    def grow(self):
        """Append an upsampling block and start fading it in.

        Does nothing (apart from printing a message and clearing
        ``new_parameters``) once ``final_size`` has been reached.
        """
        if self.current_size >= self.final_size:
            print("Network can't grow more")
            # Nothing new to hand to the optimiser.
            self.new_parameters = []
            return

        self.transitioning = True
        self.alpha = 0

        if self.current_size in [8,32]: # don't decrease everytime because otherwise it's too fast
            future_nfm = self.current_nfm
        else:
            future_nfm = self.current_nfm // 2

        # Keep the previous network and RGB head for the fade-in blend.
        self.old = self.main
        self.old_rgb = self.to_rgb

        block = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(self.current_nfm, future_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(future_nfm, future_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        self.layers += block
        self.main = nn.Sequential(*self.layers)

        self.current_size *= 2
        self.current_nfm = future_nfm
        self.to_rgb = nn.Conv2d(self.current_nfm, 3, 1, 1, 0, bias=False)

        # A real list (not a one-shot generator) that also contains the new
        # RGB head, which would otherwise never reach the optimiser.
        self.new_parameters = [p for layer in block for p in layer.parameters()]
        self.new_parameters += list(self.to_rgb.parameters())
        
        
class GrowingDiscriminator(nn.Module):
    """Critic that starts at 4x4 and doubles its input size on ``grow()``.

    Args:
        init_size: resolution recorded as the starting ``current_size``.
        final_size: resolution at which ``grow()`` stops adding blocks.
        n_feature_maps: the last block uses ``8 * n_feature_maps`` channels.

    Attributes set for callers:
        transitioning: True while the newest block is being faded in.
        alpha: weight of the new branch during a transition (0 = old only).
        new_parameters: list of the parameters created by the last ``grow()``
            (the new block and the new ``from_rgb``). It includes the
            ``requires_grad=False`` vectors of the SNConv2d layers, so filter
            on ``requires_grad`` before giving it to an optimiser.
    """

    def __init__(self, init_size=4, final_size=128, n_feature_maps=128):
        super(GrowingDiscriminator, self).__init__()
        self.init_size = init_size
        self.final_size = final_size
        init_nfm = 8 * n_feature_maps

        self.from_rgb = nn.Sequential(
            SNConv2d(3, init_nfm, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.layers = [
            MinibatchSDLayer(),
            #4x4
            SNConv2d(init_nfm+1, init_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            #4x4
            SNConv2d(init_nfm, 1, 4, 1, 0, bias=False),
            #1x1
        ]
        self.main = nn.Sequential(*self.layers)

        self.current_size = init_size
        self.current_nfm = init_nfm

        self.transitioning = False
        # Defined up front so ``alpha`` is readable before the first grow().
        self._alpha = 1
        self.new_parameters = []

    @property
    def alpha(self):
        return self._alpha

    @alpha.setter
    def alpha(self, v):
        # Reaching 1 means the new branch carries all the weight: the
        # transition is over and the value is clamped so it cannot overshoot.
        if v >= 1:
            self.transitioning = False
            v = 1
        self._alpha = v

    def forward(self, x):
        """Score a batch of images; returns a 1-D tensor with one value each.

        Raises:
            ValueError: if the width of ``x`` differs from ``current_size``.
        """
        if x.size(3) != self.current_size:
            # Fail loudly instead of returning None, which only surfaces later
            # as an unrelated error in the caller.
            raise ValueError("input is of the wrong size (should be {})".format(self.current_size))

        if self.transitioning:
            # New branch: full-resolution input through the newest block.
            new = self.from_rgb(x)
            new = self.new(new)
            # Old branch: input downsampled by two, through the previous head.
            old = F.avg_pool2d(x, 2)
            old = self.old_rgb(old)
            x = self.alpha*new + (1-self.alpha)*old
            output = self.old(x)

        else:
            x = self.from_rgb(x)
            output = self.main(x)

        # view(-1) keeps the batch axis even when the batch holds one sample.
        return output.view(-1)

    def grow(self):
        """Prepend a downsampling block and start fading it in.

        Does nothing (apart from printing a message and clearing
        ``new_parameters``) once ``final_size`` has been reached.
        """
        if self.current_size >= self.final_size:
            print("Network can't grow more")
            # Nothing new to hand to the optimiser.
            self.new_parameters = []
            return

        self.transitioning = True
        self.alpha = 0

        if self.current_size in [8,32]:
            future_nfm = self.current_nfm
        else:
            future_nfm = self.current_nfm // 2

        # Keep the previous network and RGB head for the fade-in blend.
        self.old = self.main
        self.old_rgb = self.from_rgb

        block = [
            SNConv2d(future_nfm, future_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            SNConv2d(future_nfm, self.current_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AvgPool2d(2)
        ]
        self.new = nn.Sequential(*block)
        self.layers = block + self.layers
        self.main = nn.Sequential(*self.layers)

        self.current_size *= 2
        self.current_nfm = future_nfm
        # Same conv + LeakyReLU structure as the initial from_rgb, so both
        # branches of the blend are activated the same way.
        self.from_rgb = nn.Sequential(
            SNConv2d(3, self.current_nfm, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # A real list (not a one-shot generator) that also contains the new
        # RGB head, which would otherwise never reach the optimiser.
        self.new_parameters = list(self.new.parameters()) + list(self.from_rgb.parameters())

In [4]:
# Mini-batch size used by the data loader (and by the training loop).
batch_size = 64

# Turn each image into a tensor, then shift/scale every RGB channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.ImageFolder('paintings64/', transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [5]:
# Number of channels of the latent vector fed to the generator.
zdim = 100
# Channel base: both networks use 8 * n_feature_maps channels at 4x4.
n_feature_maps = 128
# Resolution the networks start at.
init_size = 4
# Resolution at which grow() refuses to add further blocks.
final_size = 64

def weights_init(m):
    """He-initialise (Kaiming uniform) any module whose class name contains 'Conv'.

    Meant to be passed to ``Module.apply``; other modules are left untouched.
    """
    if 'Conv' not in m.__class__.__name__:
        return
    nn.init.kaiming_uniform(m.weight.data)
        
# Build both networks at their starting resolution and run weights_init over
# every sub-module that exists at this point. Modules created later by
# grow() are not covered by these two apply() calls.
G = GrowingGenerator(zdim, init_size, final_size, n_feature_maps)
G.apply(weights_init)
D = GrowingDiscriminator(init_size, final_size, n_feature_maps)
D.apply(weights_init)

GrowingDiscriminator(
  (from_rgb): Sequential(
    (0): SNConv2d (3, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): LeakyReLU(0.2, inplace)
  )
  (main): Sequential(
    (0): MinibatchSDLayer(
    )
    (1): SNConv2d (1025, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (2): LeakyReLU(0.2, inplace)
    (3): SNConv2d (1024, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [6]:
# Adam settings shared by both optimisers.
lr = 1e-3
beta1 = 0
beta2 = 0.99
G_optimiser = optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
# The filter drops parameters with requires_grad=False, i.e. the
# power-iteration vectors `u` of the SNConv2d layers.
D_optimiser = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()), lr=lr, betas=(beta1, beta2))

In [7]:
def get_gradient_penalty(real, fake, D, gamma=1, gpu=True):
    """Gradient penalty of critic ``D`` on random real/fake interpolations.

    For each sample a point is drawn uniformly on the segment between the
    real and the fake image, the gradient of ``D`` with respect to that point
    is computed, and the penalty is the batch mean of
    ``(||gradient|| - gamma)**2 / gamma**2``, where the norm is taken over
    all elements of one sample.

    Args:
        real: batch of real images, first dimension is the batch.
        fake: batch of generated images with the same shape as ``real``.
        D: callable mapping a batch of images to one score per sample.
        gamma: target gradient norm.
        gpu: kept for backward compatibility; intermediate tensors are now
            allocated from ``real`` and therefore follow its type and device.

    Returns:
        Scalar penalty. It is built with ``create_graph=True`` so it can be
        back-propagated into the parameters of ``D``.
    """
    batch_size = real.size(0)
    real_data = real.data
    fake_data = fake.data

    # One mixing coefficient per sample, allocated like `real`.
    alpha = real_data.new(batch_size, 1, 1, 1).uniform_()
    alpha = alpha.expand_as(real_data)

    # The interpolation is a fresh leaf that requires grad; without this,
    # grad() cannot differentiate with respect to it.
    interpolation = Variable(alpha * real_data + (1 - alpha) * fake_data, requires_grad=True)
    D_itp = D(interpolation)

    # Tensor of ones shaped and typed like the critic output.
    grad_outputs = D_itp.data.clone().fill_(1)
    gradients = grad(outputs=D_itp, inputs=interpolation, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Flatten so the norm runs over every element of a sample rather than
    # over the channel axis alone.
    gradients = gradients.contiguous().view(batch_size, -1)
    GP = ((gradients.norm(2, dim=1) - gamma)**2 / gamma**2).mean()
    return GP

In [8]:
n_epochs = 10
# NOTE(review): lambda_, gamma and epsilon_drift are defined here but the
# losses below do not use them yet (no gradient penalty or drift term).
lambda_ = 10
gamma = 750
epsilon_drift = 1e-3
examples_seen = 0
current_size = 4
for epoch in range(n_epochs):
    for img, label in dataloader:
        print(examples_seen)
        x = Variable(img)
        n_samples = x.size(0)
        if x.size(-1) > current_size:
            # Downsampling factor comes from the image side length, not from
            # the batch dimension.
            ratio = x.size(-1) // current_size
            x = F.avg_pool2d(x, ratio)

        # D training, n_critic=1
        for p in D.parameters():
            p.requires_grad = True

        D.zero_grad()
        D_real = D(x)

        # Match the latent batch to the (possibly smaller, last) data batch.
        z = torch.FloatTensor(n_samples, zdim, 1, 1).normal_()
        z = Variable(z)
        fake = G(z)
        D_fake = D(fake.detach())

        D_err = torch.mean(D_real) - torch.mean(D_fake)
        D_err.backward()
        D_optimiser.step()

        # G training
        for p in D.parameters():
            p.requires_grad = False # saves computation

        G.zero_grad()
        z = torch.FloatTensor(n_samples, zdim, 1, 1).normal_()
        z = Variable(z)
        fake = G(z)
        G_err = torch.mean(D(fake))
        G_err.backward()
        G_optimiser.step()

        examples_seen += img.size(0)

        if G.transitioning:
            G.alpha += 1e-3
            D.alpha += 1e-3
            print(G.alpha, D.alpha)

        # Grow once enough images have been seen. The original note mentions
        # 100K images (600K in the paper, plus transitions); the threshold
        # actually used here is 200.
        # NOTE(review): growth is not gated on the previous transition having
        # finished -- confirm this is intended.
        if examples_seen > 200 and G.current_size < G.final_size:
            examples_seen = 0
            current_size *= 2
            G.grow()
            G_optimiser.add_param_group({'params': G.new_parameters})
            D.grow()
            D_optimiser.add_param_group({'params': filter(lambda p: p.requires_grad, D.new_parameters)})

0
64
128
192
0
0.001 0.001
64
0.002 0.002
128
0.003 0.003
192
0.004 0.004
0
0.001 0.001
64


Process Process-2:
Process Process-1:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/francois/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/francois/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/francois/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/francois/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/francois/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/home/francois/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/home/fran

KeyboardInterrupt: 

In [12]:
# Scratch cell: number of evenly spaced values in [0, 1] when taking one step
# per batch of 64 over 5e5 images -- presumably sizing an alpha schedule; confirm.
a = np.linspace(0,1,int(5e5/64)+1)
len(a)

7813

In [13]:
# Scratch cell: how many batches of 64 fit in 30000 images.
30000/64

468.75