# Progressive growing of GANs prototype

To add:
- smooth transitions -> hard
- pixel-wise feature normalisation in generator -> easy

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable, grad
import torch.nn.functional as F
from torchvision  import transforms, datasets

In [2]:
# TO IMPLEMENT
# you can add these to a Sequential

class MinibatchSDLayer(nn.Module):
    """Append one feature map holding the mean standard deviation over the batch.

    The standard deviation of every feature/position is computed across the
    batch dimension (dim 0), averaged into a single scalar, and that scalar is
    broadcast into one extra channel concatenated to the input. The output
    therefore has one more channel than the input.
    """

    def __init__(self):
        super(MinibatchSDLayer, self).__init__()

    def forward(self, x):
        """Return ``x`` with the extra statistic channel concatenated on dim 1.

        ``x`` is indexed as a 4-D tensor (batch, channels, height, width).
        """
        if x.size(0) > 1:
            mean_batch_std = x.std(0).mean()
        else:
            # A single sample has no spread across the batch; the unbiased
            # estimator would produce NaN here, so use 0 instead.
            mean_batch_std = (x * 0).sum()
        # Use height and width separately so non-square inputs are supported.
        stat_map = mean_batch_std.view(1, 1, 1, 1).expand(
            x.size(0), 1, x.size(2), x.size(3))
        return torch.cat([x, stat_map], 1)
    
    
class PixelWiseFeatureNormLayer(nn.Module):
    """Normalise the feature vector of every pixel to unit average magnitude.

    For each spatial position the activations are divided by the root mean
    square of that position's values across the channel dimension (dim 1):

        y = x / sqrt(mean_c(x ** 2) + epsilon)

    Args:
        epsilon: small constant added under the square root to avoid a
            division by zero on all-zero feature vectors.
    """

    def __init__(self, epsilon=1e-8):
        super(PixelWiseFeatureNormLayer, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x`` rescaled per pixel; the shape is unchanged."""
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.epsilon)
    
    
class SpectralNormLayer(nn.Module):
    """Placeholder for spectral normalisation (not implemented yet).

    The layer currently passes its input through unchanged so that it can
    already be placed inside a ``Sequential`` without affecting the result.
    """

    def __init__(self):
        super(SpectralNormLayer, self).__init__()

    def forward(self, x):
        """Return ``x`` unchanged."""
        # TODO: implement spectral normalisation.
        return x

In [3]:
class GrowingGenerator(nn.Module):
    """Generator that starts at a low resolution and can be grown block by block.

    The network is a stack of convolutional layers (``main``) followed by a
    1x1 convolution (``to_rgb``) and a tanh. Each call to ``grow`` appends an
    upsampling block that doubles the output resolution and replaces
    ``to_rgb`` with a fresh layer matching the new channel count.

    Args:
        zdim: number of channels of the latent input, fed as (batch, zdim, 1, 1).
        init_size: resolution recorded as the starting size.
        final_size: resolution at which ``grow`` refuses to add more blocks.
        n_feature_maps: base channel count; the first block uses 8 times this.

    Attributes:
        new_parameters: list of the parameters created by the latest ``grow``
            call (empty before the first call), ready to be handed to
            ``optimizer.add_param_group``.
    """

    def __init__(self, zdim=100, init_size=4, final_size=128, n_feature_maps=128):
        super(GrowingGenerator, self).__init__()

        self.init_size = init_size
        self.final_size = final_size
        init_nfm = 8 * n_feature_maps

        # Kept as a plain list so that grow() can extend it and rebuild `main`.
        self.layers = [
            # 1x1
            nn.ConvTranspose2d(zdim, init_nfm, 4, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4
            nn.Conv2d(init_nfm, init_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4
        ]
        self.main = nn.Sequential(*self.layers)

        self.to_rgb = nn.Conv2d(init_nfm, 3, 1, 1, 0, bias=False)
        self.current_size = init_size
        self.current_nfm = init_nfm
        self.new_parameters = []

    def forward(self, x):
        """Map latent codes to images with values in [-1, 1]."""
        x = self.main(x)
        x = self.to_rgb(x)
        return torch.tanh(x)

    def grow(self):
        """Append one upsampling block, doubling the output resolution.

        Does nothing (apart from printing a message) once ``final_size`` has
        been reached. After a successful call ``new_parameters`` lists every
        parameter that did not exist before, including the new ``to_rgb``.
        """
        if self.current_size >= self.final_size:
            print("Network can't grow more")
            return

        if self.current_size in (8, 32):  # don't decrease everytime because otherwise it's too fast
            future_nfm = self.current_nfm
        else:
            future_nfm = self.current_nfm // 2

        block = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(self.current_nfm, future_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(future_nfm, future_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.layers = self.layers + block
        self.main = nn.Sequential(*self.layers)

        self.current_size *= 2
        self.current_nfm = future_nfm
        self.to_rgb = nn.Conv2d(self.current_nfm, 3, 1, 1, 0, bias=False)

        # The replaced to_rgb is brand new too: without it in this list the
        # optimiser would never update the output layer after growing.
        fresh_modules = block + [self.to_rgb]
        self.new_parameters = [
            p for module in fresh_modules for p in module.parameters()
        ]
        
        
class GrowingDiscriminator(nn.Module):
    """Critic that starts at a low resolution and can be grown block by block.

    Images enter through ``from_rgb`` (1x1 convolution + LeakyReLU) and are
    then processed by ``main``. Each call to ``grow`` prepends a block that
    halves the resolution and replaces ``from_rgb`` with a fresh one matching
    the new channel count.

    Args:
        init_size: input resolution expected before any growth.
        final_size: resolution at which ``grow`` refuses to add more blocks.
        n_feature_maps: base channel count; the last block uses 8 times this.

    Attributes:
        new_parameters: list of the parameters created by the latest ``grow``
            call (empty before the first call), ready to be handed to
            ``optimizer.add_param_group``.
    """

    def __init__(self, init_size=4, final_size=128, n_feature_maps=128):
        super(GrowingDiscriminator, self).__init__()
        self.init_size = init_size
        self.final_size = final_size
        init_nfm = 8 * n_feature_maps

        self.from_rgb = nn.Sequential(
            nn.Conv2d(3, init_nfm, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Kept as a plain list so that grow() can prepend to it and rebuild `main`.
        self.layers = [
            # 4x4
            nn.Conv2d(init_nfm, init_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1, one score per image
            nn.Conv2d(init_nfm, 1, 4, 1, 0, bias=False),
        ]
        self.main = nn.Sequential(*self.layers)

        self.current_size = init_size
        self.current_nfm = init_nfm
        self.new_parameters = []

    def forward(self, x):
        """Score a batch of images whose width equals ``current_size``.

        Raises:
            ValueError: if the input width does not match ``current_size``.
        """
        if x.size(3) != self.current_size:
            raise ValueError(
                "input is of the wrong size (should be {})".format(self.current_size))

        x = self.from_rgb(x)
        output = self.main(x)
        return output.view(-1, 1).squeeze()

    @staticmethod
    def _widen_conv_input(conv):
        """Return a copy of ``conv`` accepting one extra input channel.

        The existing weights are kept and the weights of the extra channel
        start at zero, so the layer initially computes what ``conv`` did.
        """
        wider = nn.Conv2d(conv.in_channels + 1, conv.out_channels, conv.kernel_size,
                          conv.stride, conv.padding, bias=conv.bias is not None)
        wider.weight.data.zero_()
        wider.weight.data[:, :conv.in_channels].copy_(conv.weight.data)
        if conv.bias is not None:
            wider.bias.data.copy_(conv.bias.data)
        return wider

    def grow(self):
        """Prepend one downsampling block, doubling the expected input size.

        Does nothing (apart from printing a message) once ``final_size`` has
        been reached. After a successful call ``new_parameters`` lists every
        parameter that did not exist before, including the new ``from_rgb``.
        """
        if self.current_size >= self.final_size:
            print("Network can't grow more")
            return

        if self.current_size in (8, 32):
            future_nfm = self.current_nfm
        else:
            future_nfm = self.current_nfm // 2

        fresh_modules = []

        # On the first growth, add the minibatch std layer in front of the
        # final block. It emits one extra channel, so the convolution that
        # follows it has to be widened by one input channel.
        if self.current_size == self.init_size:
            widened = self._widen_conv_input(self.layers[0])
            self.layers = [MinibatchSDLayer(), widened] + self.layers[1:]
            fresh_modules.append(widened)

        block = [
            nn.Conv2d(future_nfm, future_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(future_nfm, self.current_nfm, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AvgPool2d(2)
        ]
        self.layers = block + self.layers
        self.main = nn.Sequential(*self.layers)

        self.current_size *= 2
        self.current_nfm = future_nfm
        # Same structure as the initial from_rgb (convolution + activation).
        self.from_rgb = nn.Sequential(
            nn.Conv2d(3, self.current_nfm, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )

        fresh_modules += block
        fresh_modules.append(self.from_rgb)
        self.new_parameters = [
            p for module in fresh_modules for p in module.parameters()
        ]

In [62]:
batch_size = 64

# Convert images to tensors, then apply (value - 0.5) / 0.5 on each of the
# three channels.
normalise = transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
transform = transforms.Compose([transforms.ToTensor(), normalise])

# Images are read from the class sub-folders of 'paintings64/'.
dataset = datasets.ImageFolder('paintings64/', transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [125]:
# Hyper-parameters handed to GrowingGenerator / GrowingDiscriminator below.
zdim = 100  # number of channels of the latent input z
n_feature_maps = 128  # base channel count (the networks start with 8x this)
init_size = 4  # starting resolution
final_size = 64  # resolution at which the networks stop growing

def weights_init(m):
    """Re-initialise the weight of convolutional modules with Kaiming uniform.

    Meant to be passed to ``Module.apply``. Modules whose class name does not
    contain "Conv", or that carry no ``weight``, are left untouched.
    """
    if 'Conv' not in m.__class__.__name__:
        return
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    # Prefer the in-place initialiser; fall back to the legacy name on PyTorch
    # versions that predate it (the legacy name is deprecated on newer ones).
    init_fn = getattr(nn.init, 'kaiming_uniform_', None) or nn.init.kaiming_uniform
    init_fn(weight.data)
        
# Build each network at its starting resolution and immediately re-initialise
# its convolution weights with weights_init.
G = GrowingGenerator(
    zdim=zdim,
    init_size=init_size,
    final_size=final_size,
    n_feature_maps=n_feature_maps,
)
G.apply(weights_init)
D = GrowingDiscriminator(
    init_size=init_size,
    final_size=final_size,
    n_feature_maps=n_feature_maps,
)
D.apply(weights_init)

GrowingDiscriminator(
  (from_rgb): Sequential(
    (0): Conv2d (3, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): LeakyReLU(0.2, inplace)
  )
  (main): Sequential(
    (0): Conv2d (1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(0.2, inplace)
    (2): Conv2d (1024, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [37]:
# Adam settings shared by the generator and the discriminator.
lr = 1e-3
beta1 = 0
beta2 = 0.99
adam_settings = dict(lr=lr, betas=(beta1, beta2))
G_optimiser = optim.Adam(G.parameters(), **adam_settings)
D_optimiser = optim.Adam(D.parameters(), **adam_settings)

In [114]:
def get_gradient_penalty(real, fake, D, gamma=1):
    """Return the gradient penalty of ``D`` on random real/fake interpolations.

    Each sample of ``real`` is mixed with the matching sample of ``fake``
    using its own random weight. The penalty is the batch mean of
    ``(||grad_x D(x)|| - gamma) ** 2 / gamma ** 2``, where the norm is taken
    over all non-batch dimensions of each sample.

    Args:
        real: batch of real images.
        fake: batch of generated images with the same shape as ``real``.
        D: callable returning one score per sample.
        gamma: target gradient norm.

    Returns:
        A scalar tensor that can be back-propagated into the parameters of ``D``.
    """
    batch_size = real.size(0)
    # One mixing weight per sample, broadcast over the remaining dimensions.
    alpha_shape = [batch_size] + [1] * (real.dim() - 1)
    alpha = torch.rand(*alpha_shape).type_as(real)

    # Detach both inputs: the penalty regularises D only, so nothing should
    # flow back into the generator, and the interpolation must be a leaf that
    # requires grad regardless of where `real` and `fake` came from.
    interpolation = alpha * real.detach() + (1 - alpha) * fake.detach()
    interpolation.requires_grad_(True)

    D_itp = D(interpolation)
    gradients = grad(outputs=D_itp, inputs=interpolation,
                     grad_outputs=torch.ones_like(D_itp),
                     create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Flatten so the norm covers the whole sample, not just the channel axis.
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    GP = ((grad_norm - gamma) ** 2 / gamma ** 2).mean()
    return GP

In [53]:
n_epochs = 10
lambda_ = 10  # weight of the gradient penalty in the discriminator loss
gamma = 750  # target gradient norm handed to get_gradient_penalty
epsilon_drift = 1e-3  # weight of the mean(D_real ** 2) term
# we grow every 100K images. 600K in the paper, plus transitions, we'll see
grow_every = 1e5
examples_seen = 0
current_size = 4


def _register_new_parameters(optimiser, network):
    """Add to `optimiser` every parameter of `network` it does not track yet."""
    tracked = {id(p) for group in optimiser.param_groups for p in group['params']}
    fresh = [p for p in network.parameters() if id(p) not in tracked]
    if fresh:
        optimiser.add_param_group({'params': fresh})


for epoch in range(n_epochs):
    for img, _ in dataloader:
        x = Variable(img)
        if x.size(-1) > current_size:
            # Downscale by the ratio of spatial sizes (not the batch size).
            ratio = x.size(-1) // current_size
            x = F.avg_pool2d(x, ratio)
        # The last batch of an epoch can be smaller than batch_size.
        n_samples = x.size(0)

        # D training, n_critic=1
        for p in D.parameters():
            p.requires_grad = True

        D.zero_grad()
        D_real = D(x)

        z = torch.FloatTensor(n_samples, zdim, 1, 1).normal_()
        z = Variable(z)
        fake = G(z)
        D_fake = D(fake.detach())

        GP = get_gradient_penalty(x, fake, D, gamma)

        # Sign convention: D is pushed to score real images low and generated
        # ones high; the generator step below minimises the score of its samples.
        D_err = torch.mean(D_real) - torch.mean(D_fake) + lambda_*GP + epsilon_drift*torch.mean(D_real**2)
        D_err.backward()
        D_optimiser.step()

        # G training
        for p in D.parameters():
            p.requires_grad = False  # saves computation

        # Drop anything accumulated in G during the discriminator step.
        G.zero_grad()
        z = torch.FloatTensor(n_samples, zdim, 1, 1).normal_()
        z = Variable(z)
        fake = G(z)
        G_err = torch.mean(D(fake))
        G_err.backward()
        G_optimiser.step()

        examples_seen += n_samples

    # Checked once per epoch: grow as soon as enough images have been seen,
    # and never beyond the networks' final resolution.
    if examples_seen >= grow_every and current_size < G.final_size:
        examples_seen = 0
        current_size *= 2
        G.grow()
        _register_new_parameters(G_optimiser, G)
        D.grow()
        _register_new_parameters(D_optimiser, D)


 0.5407  0.4825  0.5390
 0.4180  0.5692  0.4473
 0.6007  0.4951  0.4712
 0.5359  0.6343  0.4671
 0.5979  0.5003  0.4840
 0.4008  0.5005  0.5363
 0.4923  0.4290  0.3370
 0.3789  0.4962  0.4651
 0.4435  0.6083  0.4501
 0.6282  0.5788  0.4825
 0.4221  0.4190  0.4181
 0.5917  0.4317  0.5368
 0.4120  0.3518  0.5131
 0.4476  0.4498  0.6718
 0.3705  0.5237  0.5341
 0.4924  0.4210  0.5256
 0.4675  0.4763  0.5332
 0.5430  0.4773  0.4204
 0.5775  0.4399  0.5060
 0.5643  0.5750  0.5323
 0.5221  0.4944  0.5235
 0.5078  0.4369  0.4284
 0.4526  0.4259  0.5561
 0.4584  0.5642  0.4732
 0.5480  0.4675  0.3675
 0.5367  0.6425  0.3746
 0.4588  0.4288  0.4503
 0.4756  0.4561  0.6355
 0.5365  0.4252  0.4482
[torch.FloatTensor of size 29x3]

In [4]:
# Rebuild a generator with the training hyper-parameters and restore its weights.
G = GrowingGenerator(zdim=100, init_size=4, final_size=64, n_feature_maps=128)


def _keep_storage(storage, loc):
    """map_location hook for torch.load: return each storage as it was deserialised."""
    return storage


generator_state = torch.load('results/saved_data/paintings64_PG_GAN_generator', map_location=_keep_storage)
G.load_state_dict(generator_state)

In [5]:
# Display the resolution the loaded generator currently produces.
G.current_size

4