In [0]:
# target
target = 'landscape'
target_JP = '風景'

# hyper params
EPOCHS = (24, 32, 48, 64, 64, 64, 64)
BATCH_SIZES = (768, 512, 384, 256, 192, 128, 96)
initial_scale = 1.
GPU = 0

# extension params
log_interval = 500
display_interval = None
tweet_interval = 1000
snapshot_interval = 1000

# model params
sa_gamma = 1.
SCALEUP_ALPHA = 1
START = 0
IMG_SIZE = 256
LATENT_SIZE = 512
CH_SIZE = 64

# learning controller
learning_rate = 2e-4
grad_clip = None
gen_weight_decay = 0
dis_weight_decay = 0
sa_endpoint = 1000
sg_endpoint = 1

# file names
load_gen = None
load_dis = None

OUT = './result/'
dataroot = './picture/train_pic/flickr/landscape'

gen_name = '{}_gen'.format(target)
dis_name = '{}_dis'.format(target)

In [0]:
import numpy as np
import io
import uuid
import pickle
import tweepy
from tqdm.notebook import tqdm

from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, clip_grad_norm_
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

import twitter_api_key

In [0]:
def gaussian(size):
    return torch.normal(torch.zeros(size), torch.ones(size))

def pixel_noise(x, k):
    return torch.normal(x, torch.abs(x)*k)

def zeropad(x, ch):
    return F.pad(x, (0, 0, 0, 0, 0, ch-x.size(1), 0, 0))

def gap(x):
    return F.avg_pool2d(x, x.size()[-2:])

def noise_injection(x, k):
    return torch.normal(x, instance_var(x)*k)

def instance_var(x):
    _shape = x.size()
    _x = x.view(_shape[0], _shape[1], -1)
    _x = torch.var(_x, dim=2)
    _x = _x.view(*_x.size(), 1, 1)
    _x = _x.expand(*_shape)
    return _x

def pixel_norm(x):
    return x * torch.rsqrt((x**2).mean(1, keepdim=True) + 1e-8)

In [0]:
class SNConv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel=1, stride=1, padding=0):
        super(SNConv, self).__init__()
        self.main = nn.Sequential(
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(in_ch, out_ch, kernel, stride, padding))
        )
    
    def forward(self, x):
        return self.main(x)

In [0]:
class SNRes(nn.Module):
    def __init__(self, in_ch, out_ch, kernel=3, stride=1, padding=1):
        super(SNRes, self).__init__()
        self.main = SNConv(in_ch, out_ch, kernel, stride, padding)
    
    def forward(self, x):
        return x + self.main(x)

In [0]:
class SNDense(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(SNDense, self).__init__()
        self.main = nn.Sequential(
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(in_ch, out_ch))
        )
        
    def forward(self, x):
        return self.main(x)

In [0]:
class SNSelfAttentionBlock(nn.Module):
    def __init__(self, in_ch, out_ch, gamma=1.):
        super(SNSelfAttentionBlock, self).__init__()

        self.gamma = gamma
        self.cf = SNConv(in_ch, out_ch//8)
        self.cg = SNConv(in_ch, out_ch//8)
        self.ch = SNConv(in_ch, out_ch)
        self.softmax = nn.Softmax(2)
            
    def forward(self, x):
        f = self.cf(x)
        g = self.cg(x)
        h = self.ch(x)
        f = f.view(f.size(0), f.size(1), -1)
        g = g.view(g.size(0), g.size(1), -1)
        h = h.view(h.size(0), h.size(1), -1)
        
        attention_map = torch.bmm(torch.transpose(f, 1, 2), g)
        attention_map = self.softmax(attention_map)
        feature_map = torch.bmm(h, torch.transpose(attention_map, 1, 2))
        feature_map = feature_map.view(*x.size())

        return x + feature_map*self.gamma

    def set_gamma(self, gamma):
        self.gamma = gamma

In [0]:
class MinibatchDiscrimination(nn.Module):
    def __init__(self, in_ch, kernel, kernel_dims, device):
        super(MinibatchDiscrimination, self).__init__()
        self.device = device
        self.kernel = kernel
        self.dim = kernel_dims
        self.t = nn.Linear(in_ch, self.kernel*self.dim, bias=False)
        for param in self.t.parameters():
            param.requires_grad = False

    def __call__(self, x):
        batchsize = x.size(0)
        m = self.t(x).view(batchsize, self.kernel, self.dim, 1)
        m_T = torch.transpose(m, 0, 3)
        m, m_T = torch.broadcast_tensors(m, m_T)
        norm = torch.sum(F.l1_loss(m, m_T, reduction='none'), dim=2)

        eraser = torch.eye(batchsize, device=self.device).view(batchsize, 1, batchsize).expand(norm.size())
        c_b = torch.exp(-(norm + 1e6 * eraser))
        o_b = torch.sum(c_b, dim=2)
        h = torch.cat((x, o_b), dim=1)
        return h

In [0]:
class InceptionBlock(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(InceptionBlock, self).__init__()

        self.layer_11 = SNConv(in_ch, mid_ch)
        self.layer_33_1 = nn.Sequential(
            SNConv(in_ch, mid_ch),
            SNConv(mid_ch, mid_ch, 3, 1, 1)
        )
        self.layer_33_2 = nn.Sequential(
            SNConv(in_ch, mid_ch),
            SNConv(mid_ch, mid_ch, 3, 1, 1),
            SNRes(mid_ch, mid_ch)
        )

        self.c = SNConv(mid_ch*3, out_ch)
            
    def forward(self, x):
        h_11 = self.layer_11(x)
        h_33_1 = self.layer_33_1(x)
        h_33_2 = self.layer_33_2(x)
        
        h = torch.cat((h_11, h_33_1, h_33_2), dim=1)
        h = self.c(h)
        return h

In [0]:
class InceptionResBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(InceptionResBlock, self).__init__()
        self.ch = out_ch
        self.inception = InceptionBlock(in_ch, self.ch//4, self.ch)
            
    def forward(self, x):
        _h = self.inception(x)
        h = zeropad(x, self.ch)
        return h + _h

In [0]:
class DiscriminatorBlock(nn.Module):
    def __init__(self, in_ch, out_ch, sa_gamma=None):
        super(DiscriminatorBlock, self).__init__()
        self.main = nn.Sequential(
            InceptionResBlock(in_ch, in_ch*3//2),
            InceptionResBlock(in_ch*3//2, in_ch*2)
        )
        self.sa = SNSelfAttentionBlock(in_ch*2, in_ch*2, gamma=sa_gamma) if sa_gamma else None

        self.c = SNConv(in_ch*2, out_ch)
            
    def forward(self, x):
        h = self.main(x)

        if self.sa:
            h = self.sa(h)

        h = self.c(h)
        return h

    def set_gamma(self, gamma):
        if self.sa:
            self.sa.set_gamma(gamma)

In [0]:
class Discriminator(nn.Module):

    def __init__(self, out_ch=2, alpha=0.5, device=None, sa_gamma=1.):
        super(Discriminator, self).__init__()
        self.alpha = alpha

        self.in_256 = spectral_norm(nn.Conv2d(3, 32, 1, 1, 0))
        self.layer_256 = nn.Sequential(
            DiscriminatorBlock(32, 64),
            SNConv(64, 64, 4, 2, 1)
        )

        self.in_128 = spectral_norm(nn.Conv2d(3, 64, 1, 1, 0))
        self.layer_128 = nn.Sequential(
            DiscriminatorBlock(64, 128),
            SNConv(128, 128, 4, 2, 1)
        )

        self.in_64 = spectral_norm(nn.Conv2d(3, 128, 1, 1, 0))
        self.layer_64 = nn.Sequential(
            DiscriminatorBlock(128, 256),
            SNConv(256, 256, 4, 2, 1)
        )

        self.in_32 = spectral_norm(nn.Conv2d(3, 256, 1, 1, 0))
        self.layer_32 = nn.Sequential(
            DiscriminatorBlock(256, 512),
            SNConv(512, 512, 4, 2, 1)
        )

        self.in_16 = spectral_norm(nn.Conv2d(3, 512, 1, 1, 0))
        self.layer_16 = nn.Sequential(
            DiscriminatorBlock(512, 512, sa_gamma=sa_gamma),
            SNConv(512, 512, 4, 2, 1)
        )

        self.in_8 = spectral_norm(nn.Conv2d(3, 512, 1, 1, 0))
        self.layer_8 = nn.Sequential(
            DiscriminatorBlock(512, 512, sa_gamma=sa_gamma),
            SNConv(512, 512, 4, 2, 1)
        )

        self.in_4 =  spectral_norm(nn.Conv2d(3, 512, 1, 1, 0))
        self.layer_4 = nn.Sequential(
            DiscriminatorBlock(512, 512, sa_gamma=sa_gamma),
            nn.AvgPool2d(4),
            nn.Flatten(),
            MinibatchDiscrimination(512, 64, 16, device),
            SNDense(512+64, out_ch)
        )
    
    def forward(self, x, img_size, delta=None):

        if img_size >= 256:
            h = self.in_256(x)
            h = self.layer_256(h)
        else:
            h = 0
        
        if img_size >= 128:
            if img_size == 128:
                _x = F.avg_pool2d(x, img_size//128)
                h = h + self.in_128(_x)
            elif delta and img_size == 256:
                _x = F.avg_pool2d(x, img_size//128)
                h = torch.lerp(self.in_128(_x), h, delta)
            h = self.layer_128(h)
        
        if img_size >= 64:
            if img_size == 64:
                _x = F.avg_pool2d(x, img_size//64)
                h = h + self.in_64(_x)
            elif delta and img_size == 128:
                _x = F.avg_pool2d(x, img_size//64)
                h = torch.lerp(self.in_64(_x), h, delta)
            h = self.layer_64(h)
        
        if img_size >= 32:
            if img_size == 32:
                _x = F.avg_pool2d(x, img_size//32)
                h = h + self.in_32(_x)
            elif delta and img_size == 64:
                _x = F.avg_pool2d(x, img_size//32)
                h = torch.lerp(self.in_32(_x), h, delta)
            h = self.layer_32(h)
        
        if img_size >= 16:
            if img_size == 16:
                _x = F.avg_pool2d(x, img_size//16)
                h = h + self.in_16(_x)
            elif delta and img_size == 32:
                _x = F.avg_pool2d(x, img_size//16)
                h = torch.lerp(self.in_16(_x), h, delta)
            h = self.layer_16(h)

        if img_size >= 8:
            if img_size == 8:
                _x = F.avg_pool2d(x, img_size//8)
                h = h + self.in_8(_x)
            elif delta and img_size == 16:
                _x = F.avg_pool2d(x, img_size//8)
                h = torch.lerp(self.in_8(_x), h, delta)
            h = self.layer_8(h)

        if img_size == 4:
            _x = F.avg_pool2d(x, img_size//4)
            h = h + self.in_4(_x)
        elif delta and img_size == 8:
            _x = F.avg_pool2d(x, img_size//4)
            h = torch.lerp(self.in_4(_x), h, delta)
        h = self.layer_4(h)

        return h

    def set_gamma(self, gamma, img_size):
        if img_size == 8:
            self.layer_4[0].set_gamma(gamma)
        if img_size == 16:
            self.layer_8[0].set_gamma(gamma)

In [0]:
class BNConv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel=1, stride=1, padding=0):
        super(BNConv, self).__init__()
        self.main = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_ch, out_ch, kernel, stride, padding),
            nn.BatchNorm2d(out_ch)
        )
    
    def forward(self, x):
        return self.main(x)

In [0]:
class BNDense(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(BNDense, self).__init__()
        self.main = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Linear(in_ch, out_ch),
            nn.BatchNorm1d(out_ch)
        )
        
    def forward(self, x):
        return self.main(x)

In [0]:
class Affine(nn.Module):
    def __init__(self, in_ch, mid_ch=64, w_ch=128):
        super(Affine, self).__init__()
        
        self.main = nn.Sequential(
            nn.Linear(in_ch, mid_ch),
            nn.BatchNorm1d(mid_ch),
            BNDense(mid_ch, mid_ch),
            BNDense(mid_ch, mid_ch),
            BNDense(mid_ch, mid_ch),
            BNDense(mid_ch, mid_ch),
            BNDense(mid_ch, mid_ch),
            BNDense(mid_ch, mid_ch),
            BNDense(mid_ch, w_ch)
        )

    def forward(self, x):
        return self.main(pixel_norm(x))

In [0]:
class BNSelfAttentionBlock(nn.Module):
    def __init__(self, in_ch, out_ch, gamma=1.):
        super(BNSelfAttentionBlock, self).__init__()

        self.gamma = gamma
        self.cf = BNConv(in_ch, out_ch//8)
        self.cg = BNConv(in_ch, out_ch//8)
        self.ch = BNConv(in_ch, out_ch)
        self.softmax = nn.Softmax(2)
            
    def forward(self, x):
        f = self.cf(x)
        g = self.cg(x)
        h = self.ch(x)
        f = f.view(f.size(0), f.size(1), -1)
        g = g.view(g.size(0), g.size(1), -1)
        h = h.view(h.size(0), h.size(1), -1)
        
        attention_map = torch.bmm(torch.transpose(f, 1, 2), g)
        attention_map = self.softmax(attention_map)
        feature_map = torch.bmm(h, torch.transpose(attention_map, 1, 2))
        feature_map = feature_map.view(*x.size())

        return x + feature_map*self.gamma

    def set_gamma(self, gamma):
        self.gamma = gamma

In [0]:
class SEBlock(nn.Module):
    def __init__(self, in_ch, mid_ch):
        super(SEBlock, self).__init__()

        self.main = nn.Sequential(
            BNDense(in_ch, mid_ch),
            BNDense(mid_ch, in_ch),
            nn.Sigmoid()
        )

    def forward(self, x):
        h = self.main(x)
        return x*h

In [0]:
class NoiseInjection(nn.Module):
    def __init__(self, in_ch, out_ch, device=None):
        super(NoiseInjection, self).__init__()
        self.device = device

        self.var = nn.Sequential(
            SEBlock(in_ch, in_ch//4),
            SNDense(in_ch, out_ch)
        )

    def forward(self, x, w):
        var = self.var(w)
        var = var.view(*var.size(), 1, 1).expand(*x.size())
        noise = gaussian(x.size()).to(self.device)
        return x + noise * var

In [0]:
class AdaIN(nn.Module):
    def __init__(self, in_ch, out_ch, device=None):
        super(AdaIN, self).__init__()

        self.noise_injection = NoiseInjection(in_ch, out_ch, device)
        self.se = SEBlock(in_ch, in_ch//4)
        self.average_convert = SNDense(in_ch, out_ch)
        self.bias_convert = SNDense(in_ch, out_ch)

    def forward(self, x, w):
        h = F.instance_norm(x)
        h = self.noise_injection(h, w)
        
        _w = self.se(w)
        a = self.average_convert(_w)
        a = a.view(*a.size(), 1, 1).expand(*h.size())
        b = self.bias_convert(_w)
        b = b.view(*b.size(), 1, 1).expand(*h.size())
        h = h * a + b
        return h

In [0]:
class DepthwiseCondConv(nn.Module):
    def __init__(self, in_ch, n_kernels, kernel=3, stride=1, padding=1, device=None, k=1):
        super(DepthwiseCondConv, self).__init__()
        self.device = device

        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.ch = in_ch
        self.kernels_size = (n_kernels, kernel, kernel)

        self.weight = nn.Parameter(torch.empty(n_kernels*kernel*kernel, k, 1))
        nn.init.xavier_normal_(self.weight)
    
    def forward(self, x, sigma, bias=None):
        b_size = x.size(0)

        h = F.leaky_relu(x, negative_slope=0.2)

        kernels = self.weight.expand(-1, -1, self.ch)
        i = torch.eye(self.ch).to(self.device).view(1, self.ch, self.ch).expand(kernels.size(0), -1, -1)
        kernels = i*kernels
        kernels = kernels.view(*self.kernels_size, self.ch, -1)
        kernels = kernels.transpose(1, 3).transpose(2, 4)
        kernels = kernels.expand(b_size, -1, -1, -1, -1, -1)

        _s = sigma.view(*kernels.size()[:3], 1, 1, 1)
        _s = _s.expand(-1, -1, -1, self.ch, -1, -1)
        _s = torch.sum(kernels*_s, dim=1)
        _s = _s.reshape(b_size, self.ch, -1)

        h = F.unfold(h, self.kernel, padding=self.padding, stride=self.stride)
        h = torch.bmm(_s, h)
        h = h.view(b_size, self.ch, *x.size()[2:])

        if bias is not None:
            _b = bias.view(-1, self.ch, 1, 1)
            h = h + _b
        
        return h

In [0]:
class SynthBlock(nn.Module):
    def __init__(self, in_ch, out_ch, latent_w, kernels, device=None):
        super(SynthBlock, self).__init__()

        self.se = SEBlock(latent_w, latent_w//4)
        self.sigma = nn.Sequential(
            SNDense(latent_w, in_ch*kernels),
            nn.Tanh()
        )
        self.bias = SNDense(latent_w, in_ch)

        self.dcc = spectral_norm(DepthwiseCondConv(in_ch, kernels, device=device))
        self.c = SNConv(in_ch, out_ch)
        self.adain = AdaIN(latent_w, out_ch, device)
            
    def forward(self, x, w):
        _w = self.se(w)
        sigma = self.sigma(_w)
        bias = self.bias(_w)

        h = self.dcc(x, sigma, bias)
        h = self.c(h)
        h = self.adain(h, w)
        return h

In [0]:
class GeneratorBlock(nn.Module):
    def __init__(self, in_ch, out_ch, latent_w, kernels, device=None, sa_gamma=None):
        super(GeneratorBlock, self).__init__()
        self.w_size = latent_w

        self.synth_0 = SynthBlock(in_ch, out_ch, latent_w, kernels, device)
        self.synth_1 = SynthBlock(out_ch, out_ch, latent_w, kernels, device)
        self.synth_2 = SynthBlock(out_ch, out_ch, latent_w, kernels, device)
        self.sa = SNSelfAttentionBlock(out_ch, out_ch, gamma=sa_gamma) if sa_gamma else None
        
        self.c_out = SNConv(out_ch, 3)
            
    def forward(self, x, w):
        _w = w[:, :self.w_size]

        h = self.synth_0(x, _w)
        h = h + self.synth_1(h, _w)
        h = h + self.synth_2(h, _w)
        if self.sa:
            h = self.sa(h)
            
        out = self.c_out(h)
        return h, out

    def set_gamma(self, gamma):
        if self.sa:
            self.sa.set_gamma(gamma)

In [0]:
class Generator(nn.Module):

    def __init__(self, latent_size, ch_size, alpha=0.5, latent_blur=0.1, device=None, sa_gamma=1., kernels=4):
        super(Generator, self).__init__()
        self.alpha = alpha
        self.latent_blur = latent_blur
        incremental_size = (latent_size-224)//32

        self.btm = nn.Parameter(torch.empty(1, ch_size, 4, 4))
        nn.init.normal_(self.btm)
        
        self.affine = Affine(latent_size, latent_size, latent_size)
        
        self.res6 = GeneratorBlock(ch_size, ch_size, 32, kernels, device, sa_gamma)
        self.res5 = GeneratorBlock(ch_size, ch_size, 64+incremental_size, kernels,  device, sa_gamma)
        self.res4 = GeneratorBlock(ch_size, ch_size, 96+incremental_size*2, kernels, device)
        self.res3 = GeneratorBlock(ch_size, ch_size, 128+incremental_size*4, kernels, device)
        self.res2 = GeneratorBlock(ch_size, ch_size, 160+incremental_size*8, kernels, device)
        self.res1 = GeneratorBlock(ch_size, ch_size, 192+incremental_size*16, kernels, device)
        self.res0 = GeneratorBlock(ch_size, ch_size, latent_size, kernels, device)

        self.upbi = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    
    def forward(self, x, img_size, delta=None, mixing=False):
        w = self.affine(x)
        h = self.btm.expand(x.size(0), -1, -1, -1)
        h, out = self.res6(h, w)
        
        if img_size >= 8:
            h = self.upbi(h)
            _x = pixel_noise(x, self.latent_blur)
            w = self.affine(_x) if mixing else w
            h, _out = self.res5(h, w)
            if delta and img_size == 8:
                out = torch.lerp(self.upbi(out), _out, delta)
            else:
                out = _out
        
        if img_size >= 16:
            h = self.upbi(h)
            _x = pixel_noise(x, self.latent_blur)
            w = self.affine(_x) if mixing else w
            h, _out = self.res4(h, w)
            if delta and img_size == 16:
                out = torch.lerp(self.upbi(out), _out, delta)
            else:
                out = _out

        if img_size >= 32:
            h = self.upbi(h)
            _x = pixel_noise(x, self.latent_blur)
            w = self.affine(_x) if mixing else w
            h, _out = self.res3(h, w)
            if delta and img_size == 32:
                out = torch.lerp(self.upbi(out), _out, delta)
            else:
                out = _out

        if img_size >= 64:
            h = self.upbi(h)
            _x = pixel_noise(x, self.latent_blur)
            w = self.affine(_x) if mixing else w
            h, _out = self.res2(h, w)
            if delta and img_size == 64:
                out = torch.lerp(self.upbi(out), _out, delta)
            else:
                out = _out
        
        if img_size >= 128:
            h = self.upbi(h)
            _x = pixel_noise(x, self.latent_blur)
            w = self.affine(_x) if mixing else w
            h, _out = self.res1(h, w)
            if delta and img_size == 128:
                out = torch.lerp(self.upbi(out), _out, delta)
            else:
                out = _out
        
        if img_size >= 256:
            h = self.upbi(h)
            _x = pixel_noise(x, self.latent_blur)
            w = self.affine(_x) if mixing else w
            h, _out = self.res0(h, w)
            if delta and img_size == 256:
                out = torch.lerp(self.upbi(out), _out, delta)
            else:
                out = _out

        return torch.tanh(out)

    def set_gamma(self, gamma, img_size):
        if img_size == 8:
            self.res6.set_gamma(gamma)
        if img_size == 16:
            self.res5.set_gamma(gamma)

In [0]:
device = torch.device("cuda:{}".format(GPU))

In [0]:
def weights_init(m):
    if type(m) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=initial_scale)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif type(m) in (nn.BatchNorm1d, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [0]:
gen = Generator(LATENT_SIZE, CH_SIZE, alpha=SCALEUP_ALPHA, device=device)
if load_gen:
    gen.load_state_dict(torch.load(OUT+'{}_{}px_{}epoch.pkl'.format(gen_name, *load_weight)))
    gen.eval()
else:
    gen.apply(weights_init)

dis = Discriminator(out_ch=1, alpha=SCALEUP_ALPHA, device=device)
if load_dis:
    dis.load_state_dict(torch.load(OUT+'{}_{}px_{}epoch.pkl'.format(dis_name, *load_weight)))
    dis.eval()
else:
    dis.apply(weights_init)

gen.to(device)
dis.to(device)

Discriminator(
  (in_256): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
  (layer_256): Sequential(
    (0): DiscriminatorBlock(
      (main): Sequential(
        (0): InceptionResBlock(
          (inception): InceptionBlock(
            (layer_11): SNConv(
              (main): Sequential(
                (0): LeakyReLU(negative_slope=0.2)
                (1): Conv2d(32, 12, kernel_size=(1, 1), stride=(1, 1))
              )
            )
            (layer_33_1): Sequential(
              (0): SNConv(
                (main): Sequential(
                  (0): LeakyReLU(negative_slope=0.2)
                  (1): Conv2d(32, 12, kernel_size=(1, 1), stride=(1, 1))
                )
              )
              (1): SNConv(
                (main): Sequential(
                  (0): LeakyReLU(negative_slope=0.2)
                  (1): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                )
              )
            )
            (layer_33_2): Sequential(
 

In [0]:
opt_gen = optim.Adam(
    gen.parameters(),
    lr=learning_rate,
    betas=(0.5, 0.999),
    weight_decay=gen_weight_decay
    )
opt_dis = optim.Adam(
    dis.parameters(),
    lr=learning_rate,
    betas=(0.5, 0.999),
    weight_decay=dis_weight_decay
    )

In [0]:
gen_losses = list()
dis_losses = list()

In [0]:
if tweet_interval:
    auth = tweepy.OAuthHandler(twitter_api_key.CONSUMER_KEY, twitter_api_key.CONSUMER_SECRET)
    auth.set_access_token(twitter_api_key.ACCESS_TOKEN_KEY, twitter_api_key.ACCESS_TOKEN_SECRET)
    api = tweepy.API(auth)

In [0]:
def report_log(i, epoch, g_loss, d_loss, g_mean, d_real_mean, d_fake_mean):
    print('[{}/{}]\tLoss_D: {:.4f}\tLoss_G: {:.4f}\tD(x): {:.4f}\tD(G(z)): {:.4f}/{:.4f}'.format(
        i, epoch, d_loss, g_loss, d_real_mean, d_fake_mean, g_mean
    ))
    gen_losses.append(g_loss)
    dis_losses.append(d_loss)

In [0]:
def make_image(gen, img_size, device, delta=None, mixing=False):
    clear_output()
    with torch.no_grad():
        generated = gen(gaussian((8, LATENT_SIZE)).to(device), img_size, delta, mixing).detach().cpu()
    generated = np.transpose(np.reshape(generated, (-1, 3, img_size, img_size)), (0, 2, 3, 1))

    plt.figure(figsize=(16, 8))
        
    for i, img in enumerate(generated):
        plt.subplot(2, 4, i+1).axis('off')
        plt.subplot(2, 4, i+1).imshow(Image.fromarray(np.uint8((img+1.)/2. *255.)))

    plt.show()

In [0]:
def image_upload(image_array, api):
    bin_io = io.BytesIO()
    img = Image.fromarray(np.uint8((image_array+1.)/2. *255.))
    img = img.resize((512, 512), resample=0)
    img.save(bin_io, format='JPEG')
    result = api.media_upload(filename='{}_generated_{}.jpg'.format(target, uuid.uuid4()), file=bin_io)
    return result.media_id

In [0]:
def post_image(gen, api, img_size, iteration, epoch, device, delta=None, mixing=False):
    with torch.no_grad():
        generated = gen(gaussian((16, LATENT_SIZE)).to(device), img_size, delta, mixing).detach().cpu()
    generated = np.reshape(generated, (4, 2, 2, 3, img_size, img_size))
    generated = np.transpose(generated, (0, 3, 1, 4, 2, 5))
    generated = np.reshape(generated, (4, 3, img_size*2, img_size*2))
    generated = np.transpose(generated, (0, 2, 3, 1))

    try:
        img_ids = [image_upload(img, api) for img in generated]
        hash_tags = ['AIで{}を作る'.format(target_JP),
                    'iteration/epoch: {}/{}'.format(iteration+1, epoch+1),
                    '#makeing{}'.format(target),
                    '#nowlearning...',
                    '#AI',
                    '#人工知能',
                    '#DeepLearning',
                    '#GAN']

        api.update_status(
            status='\n'.join(hash_tags),
            media_ids=img_ids
            )
    except Exception:
        pass

In [0]:
def save_model(gen, dis, img_size, epoch):
    torch.save(gen.state_dict(), OUT+'{}_{}px_{}epoch.pkl'.format(gen_name, img_size, epoch))
    torch.save(dis.state_dict(), OUT+'{}_{}px_{}epoch.pkl'.format(dis_name, img_size, epoch))

In [0]:
def report_result():
    plt.figure(figsize=(16, 8))
    plt.plot(gen_losses,label="G")
    plt.plot(dis_losses,label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

In [0]:
def round_dataset(gen, dis, opt_gen, opt_dis, dataloader, epoch, img_size, delta, device, mixing=False):
    loss_fun = nn.MSELoss()
    data_len = len(dataloader)
    for i, data in tqdm(enumerate(dataloader, 0)):
        _delta = None
        dis.zero_grad()
        x_real = data[0].to(device)
        b_size = x_real.size(0)
        y_real = dis(x_real, img_size, _delta).view(-1)
        real_loss = loss_fun(y_real, torch.ones(*y_real.size(), device=device))
        real_loss.backward()

        x_fake = gen(gaussian((b_size, LATENT_SIZE)).to(device), img_size, _delta, mixing)
        y_fake = dis(x_fake.detach(), img_size, _delta).view(-1)
        fake_loss = loss_fun(y_fake, torch.zeros(*y_fake.size(), device=device))
        fake_loss.backward()

        if grad_clip:
            clip_grad_norm_(dis.parameters(), grad_clip)
        dis_loss = real_loss + fake_loss
        opt_dis.step()
        
        gen.zero_grad()
        y_gen = dis(x_fake, img_size, _delta).view(-1)
        gen_loss = loss_fun(y_gen, torch.ones(*y_gen.size(), device=device))
        gen_loss.backward()

        if grad_clip:
            clip_grad_norm_(gen.parameters(), grad_clip)
        opt_gen.step()

        if display_interval and (i+1) % display_interval == 0:
            make_image(gen, img_size, device, _delta, mixing)

        if (i+1) % log_interval == 0:
            report_log(
                i,
                data_len,
                gen_loss.item(),
                dis_loss.item(),
                y_gen.mean().item(),
                y_real.mean().item(),
                y_fake.mean().item()
                )

        if tweet_interval and (i+1) % tweet_interval == 0:
            post_image(gen, api, img_size, i, epoch, device, _delta, mixing)
        
        if snapshot_interval and (i+1) % snapshot_interval == 0:
            save_model(gen, dis, img_size, epoch)

In [0]:
def train_loop(gen, dis, opt_gen, opt_dis, dataloader, epoch, img_size, device, mixing=False):
    for i in range(epoch):
        delta = None
        
        gen.train()
        dis.train()
        round_dataset(gen, dis, opt_gen, opt_dis, dataloader, i, img_size, delta, device, mixing)

In [0]:
def upscaling(gen, dis, opt_gen, opt_dis, device, mixing=False):
    img_size = 4*2**START
    for epoch, batch_size in zip(EPOCHS[START:], BATCH_SIZES[START:]):
        dataset = dset.ImageFolder(root=dataroot,
                                   transform=transforms.Compose([
                                                                 transforms.Resize(img_size),
                                                                 transforms.RandomCrop(img_size),
                                                                 transforms.RandomHorizontalFlip(),
                                                                 transforms.ToTensor(),
                                                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                                                 ]))
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        train_loop(gen, dis, opt_gen, opt_dis, dataloader, epoch, img_size, device, mixing)
        img_size = img_size*2
    report_result()

In [0]:
upscaling(gen, dis, opt_gen, opt_dis, device, mixing=True)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 4.3158	Loss_G: 1.9698	D(x): -0.1056	D(G(z)): -0.3846/0.7393
[999/1358]	Loss_D: 2.7611	Loss_G: 0.1319	D(x): -0.5354	D(G(z)): -0.5047/0.6992



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.0699	Loss_G: 0.8400	D(x): 0.9196	D(G(z)): 0.0606/0.0935
[999/1358]	Loss_D: 0.1587	Loss_G: 0.6076	D(x): 0.6786	D(G(z)): -0.0948/0.2353



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.2845	Loss_G: 0.3235	D(x): 0.5616	D(G(z)): 0.1038/0.4519
[999/1358]	Loss_D: 0.7720	Loss_G: 0.0929	D(x): 0.1513	D(G(z)): -0.1107/0.7280



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.4256	Loss_G: 0.1993	D(x): 0.4200	D(G(z)): 0.2362/0.5616
[999/1358]	Loss_D: 0.4590	Loss_G: 0.1894	D(x): 0.4228	D(G(z)): 0.3203/0.5713



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.6316	Loss_G: 0.6943	D(x): 0.8347	D(G(z)): 0.7689/0.1689
[999/1358]	Loss_D: 0.6596	Loss_G: 0.7539	D(x): 0.8739	D(G(z)): 0.7979/0.1332



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.8327	Loss_G: 0.8920	D(x): 0.9467	D(G(z)): 0.9076/0.0564
[999/1358]	Loss_D: 1.3884	Loss_G: 1.6413	D(x): 1.2091	D(G(z)): 1.1552/-0.2798



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 1.4222	Loss_G: 0.0324	D(x): -0.1730	D(G(z)): -0.2058/1.1728
[999/1358]	Loss_D: 0.7593	Loss_G: 0.0270	D(x): 0.1364	D(G(z)): 0.0913/0.8420



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.7711	Loss_G: 0.0219	D(x): 0.1302	D(G(z)): 0.0933/0.8628
[999/1358]	Loss_D: 0.6698	Loss_G: 0.0624	D(x): 0.2003	D(G(z)): 0.1546/0.7560



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.5127	Loss_G: 0.4894	D(x): 0.6746	D(G(z)): 0.6330/0.3022
[999/1358]	Loss_D: 0.4562	Loss_G: 0.2939	D(x): 0.5019	D(G(z)): 0.4400/0.4676



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.4764	Loss_G: 0.2512	D(x): 0.4867	D(G(z)): 0.4553/0.5022
[999/1358]	Loss_D: 0.6692	Loss_G: 0.8304	D(x): 0.8601	D(G(z)): 0.8000/0.0905



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.8580	Loss_G: 1.0570	D(x): 0.9727	D(G(z)): 0.9215/-0.0267
[999/1358]	Loss_D: 0.4674	Loss_G: 0.2421	D(x): 0.4664	D(G(z)): 0.4136/0.5115



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.4842	Loss_G: 0.2468	D(x): 0.4834	D(G(z)): 0.4607/0.5053
[999/1358]	Loss_D: 0.5699	Loss_G: 0.1137	D(x): 0.2869	D(G(z)): 0.2176/0.6730



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.4390	Loss_G: 0.2955	D(x): 0.5253	D(G(z)): 0.4510/0.4601
[999/1358]	Loss_D: 0.4365	Loss_G: 0.3180	D(x): 0.4497	D(G(z)): 0.3422/0.4400



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.4213	Loss_G: 0.3781	D(x): 0.5029	D(G(z)): 0.4051/0.3881
[999/1358]	Loss_D: 0.7864	Loss_G: 1.0583	D(x): 1.0199	D(G(z)): 0.8696/-0.0250



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.5932	Loss_G: 0.1516	D(x): 0.2532	D(G(z)): 0.0314/0.6279
[999/1358]	Loss_D: 0.4281	Loss_G: 0.7429	D(x): 0.7928	D(G(z)): 0.5821/0.1471



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.3334	Loss_G: 0.5240	D(x): 0.7147	D(G(z)): 0.4462/0.2887
[999/1358]	Loss_D: 0.2715	Loss_G: 0.5234	D(x): 0.6531	D(G(z)): 0.3087/0.2900



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.3773	Loss_G: 0.9534	D(x): 0.8768	D(G(z)): 0.5316/0.0352
[999/1358]	Loss_D: 0.3163	Loss_G: 0.3456	D(x): 0.5675	D(G(z)): 0.2335/0.4350



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[499/1358]	Loss_D: 0.3480	Loss_G: 0.3886	D(x): 0.5075	D(G(z)): 0.2242/0.3928
[999/1358]	Loss_D: 0.3184	Loss_G: 0.4193	D(x): 0.6056	D(G(z)): 0.3088/0.3726



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

In [0]:
make_image(gen, 256, device)