In [0]:
# --- target -------------------------------------------------------------
# Subject being generated: English name (file names, hashtags) and
# Japanese name ("landscape", used in the tweet text).
target = 'landscape'
target_JP = '風景'

# --- hyper params -------------------------------------------------------
# One entry per progressive stage, starting at 4px and doubling up to 256px.
EPOCHS = (8, 16, 32, 32, 32, 32, 32)
BATCH_SIZES = (256, 128, 64, 32, 16, 8, 4)
initial_scale = 1.   # Xavier gain used by weights_init
GPU = 0              # CUDA device index

# --- extension params (in iterations) -----------------------------------
# display / tweet / snapshot intervals accept None to disable that feature.
log_interval = 1000
display_interval = None
tweet_interval = 4000
snapshot_interval = 4000

# --- model params -------------------------------------------------------
sa_gamma = 1.        # value the self-attention gamma ramp reaches in round_dataset
gen_noise = 1e-2     # noise-injection scale used by the generator's SynthBlocks
SCALEUP_ALPHA = 1    # final fade-in weight when a new resolution is added
START = 0            # first stage index; training starts at 4 * 2**START px
IMG_SIZE = 256
LATENT_SIZE = 512

# --- learning controller ------------------------------------------------
learning_rate = 2e-4     # generator LR (the discriminator uses half of it)
grad_clip = None         # max gradient norm, or None for no clipping
gen_weight_decay = 0
dis_weight_decay = 0
sa_endpoint = 1000       # iterations over which the attention gamma is ramped
sg_endpoint = 1          # epochs over which a new resolution is faded in

# --- file names ---------------------------------------------------------
load_weight = None       # (img_size, epoch) of a checkpoint in OUT to resume, or None

OUT = './result/'
dataroot = './picture/train_pic/flickr/landscape'

gen_name = target + '_gen'
dis_name = target + '_dis'

In [0]:
import numpy as np
import io
import uuid
import pickle
import tweepy
from tqdm.notebook import tqdm

from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, clip_grad_norm_
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

import twitter_api_key

In [0]:
def gaussian(size):
    """Return a float32 tensor of shape ``size`` drawn from the standard normal N(0, 1)."""
    mean = torch.zeros(size)
    std = torch.ones(size)
    return torch.normal(mean, std)

def pixel_noise(x, k):
    """Perturb each element of ``x`` with Gaussian noise of std ``k * |x|``.

    Elements equal to zero get zero std and are therefore left unchanged.
    """
    std = x.abs() * k
    return torch.normal(mean=x, std=std)

def zeropad(x, ch):
    """Append zero channels to dim 1 of ``x`` so it has ``ch`` channels.

    Assumes a 4-D (N, C, H, W) input; N, H and W are unchanged.
    """
    extra = ch - x.size(1)
    # F.pad takes (before, after) pairs starting from the LAST dim: W, H, C, N.
    pad_w, pad_h, pad_c, pad_n = (0, 0), (0, 0), (0, extra), (0, 0)
    return F.pad(x, pad_w + pad_h + pad_c + pad_n)

def gap(x):
    """Global average pooling over the last two dims, keeping them as size-1 dims."""
    spatial = x.size()[-2:]
    return F.avg_pool2d(x, spatial)

def noise_injection(x, k):
    """Add Gaussian noise to ``x``.

    The std at every position is ``k`` times the spatial variance of that
    (sample, channel) map, as computed by ``instance_var``.
    """
    scale = instance_var(x) * k
    return torch.normal(x, scale)

def instance_var(x):
    """Per-(sample, channel) variance over all spatial positions, broadcast to ``x``'s shape.

    Assumes a 4-D (N, C, H, W) input. The result is an expanded view of an
    (N, C, 1, 1) tensor of unbiased variances.
    """
    n, c = x.size(0), x.size(1)
    per_map = x.view(n, c, -1).var(dim=2)
    return per_map[:, :, None, None].expand(x.size())

def pixel_norm(x):
    """Scale ``x`` so the mean of squares along dim 1 is ~1.

    The 1e-8 term guards against division by zero for all-zero rows.
    """
    mean_sq = x.pow(2).mean(dim=1, keepdim=True)
    inv_norm = torch.rsqrt(mean_sq + 1e-8)
    return x * inv_norm

In [0]:
class SNConv(nn.Module):
    """Pre-activation conv: LeakyReLU(0.2) followed by a spectrally normalised Conv2d."""

    def __init__(self, in_ch, out_ch, kernel=1, stride=1, padding=0):
        super().__init__()
        act = nn.LeakyReLU(0.2)
        conv = spectral_norm(nn.Conv2d(in_ch, out_ch, kernel, stride, padding))
        # (activation, conv) order inside `main` keeps state_dict keys as `main.1.*`.
        self.main = nn.Sequential(act, conv)

    def forward(self, x):
        return self.main(x)

In [0]:
class SNRes(nn.Module):
    """Residual wrapper around an SNConv: returns ``x + SNConv(x)``.

    The sum requires ``out_ch == in_ch``, and geometry that preserves the
    spatial size (the default 3/1/1 does).
    """

    def __init__(self, in_ch, out_ch, kernel=3, stride=1, padding=1):
        super().__init__()
        self.main = SNConv(in_ch, out_ch, kernel, stride, padding)

    def forward(self, x):
        residual = self.main(x)
        return x + residual

In [0]:
class SNDense(nn.Module):
    """Pre-activation dense layer: LeakyReLU(0.2) followed by a spectrally normalised Linear."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        act = nn.LeakyReLU(0.2)
        linear = spectral_norm(nn.Linear(in_ch, out_ch))
        self.main = nn.Sequential(act, linear)

    def forward(self, x):
        return self.main(x)

In [0]:
class SelfAttentionBlock(nn.Module):
    """Self-attention over spatial positions with a residual connection scaled by ``gamma``.

    Uses 1x1 SNConv projections f, g (``out_ch // 8`` channels) and h (``out_ch``
    channels). Returns ``x + gamma * attended``. ``gamma`` is a plain float, not a
    learned parameter, and can be changed with ``set_gamma``. Reshaping the
    result back to ``x``'s shape requires ``out_ch == in_ch``.
    """

    def __init__(self, in_ch, out_ch, gamma=1.):
        super().__init__()
        self.gamma = gamma
        self.cf = SNConv(in_ch, out_ch // 8)
        self.cg = SNConv(in_ch, out_ch // 8)
        self.ch = SNConv(in_ch, out_ch)
        self.softmax = nn.Softmax(2)

    def forward(self, x):
        # Flatten spatial dims: (B, C, H, W) -> (B, C, N) with N = H*W.
        proj_f = self.cf(x).flatten(2)
        proj_g = self.cg(x).flatten(2)
        proj_h = self.ch(x).flatten(2)

        # scores[b, i, j] = f_i . g_j, normalised over j.
        scores = self.softmax(proj_f.transpose(1, 2).bmm(proj_g))
        # attended[b, :, i] = sum_j h_j * scores[b, i, j]
        attended = proj_h.bmm(scores.transpose(1, 2)).view(*x.size())
        return x + attended * self.gamma

    def set_gamma(self, gamma):
        """Replace the residual scale used by ``forward``."""
        self.gamma = gamma

In [0]:
class MinibatchDiscrimination(nn.Module):
    """Minibatch discrimination layer.

    Each sample is projected through a fixed (non-trainable) linear map into
    ``kernel`` rows of ``kernel_dims`` features. For every row, the L1 distance
    to every other sample in the batch is measured, and ``sum_j exp(-dist)``
    over the other samples is appended as ``kernel`` extra features.

    Args:
        in_ch: number of input features per sample.
        kernel: number of similarity features appended per sample.
        kernel_dims: dimensionality of each projected row.
        device: kept for backward compatibility. The self-exclusion mask is now
            built on the input's device, so this value is no longer needed.
    """

    def __init__(self, in_ch, kernel, kernel_dims, device):
        super(MinibatchDiscrimination, self).__init__()
        self.device = device
        self.kernel = kernel
        self.dim = kernel_dims
        self.t = nn.Linear(in_ch, self.kernel*self.dim, bias=False)
        # Freeze the projection: the optimizer never updates it.
        for param in self.t.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Append similarity features to ``x``.

        Defined as ``forward`` (not ``__call__``) so nn.Module hooks run as usual.

        Args:
            x: (batch, in_ch) tensor.

        Returns:
            (batch, in_ch + kernel) tensor: ``x`` followed by the similarity features.
        """
        batchsize = x.size(0)
        m = self.t(x).view(batchsize, self.kernel, self.dim, 1)
        m_T = torch.transpose(m, 0, 3)  # (1, kernel, dim, batch)
        # Pairwise L1 distance per kernel row, via broadcasting: (batch, kernel, batch).
        norm = torch.sum(torch.abs(m - m_T), dim=2)

        # Push each sample's distance to itself to 1e6 so exp(-.) contributes ~0.
        eraser = torch.eye(batchsize, device=x.device, dtype=norm.dtype).view(batchsize, 1, batchsize)
        c_b = torch.exp(-(norm + 1e6 * eraser))
        o_b = torch.sum(c_b, dim=2)
        return torch.cat((x, o_b), dim=1)

In [0]:
class InceptionBlock(nn.Module):
    """Three parallel SNConv branches, concatenated and fused by a 1x1 SNConv to ``out_ch``.

    Branches, each producing ``mid_ch`` channels at the input's spatial size:
    1x1; 1x1 -> 3x3; and 1x1 -> 3x3 -> residual 3x3 (SNRes).
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.layer_11 = SNConv(in_ch, mid_ch)
        self.layer_33_1 = nn.Sequential(
            SNConv(in_ch, mid_ch),
            SNConv(mid_ch, mid_ch, 3, 1, 1),
        )
        self.layer_33_2 = nn.Sequential(
            SNConv(in_ch, mid_ch),
            SNConv(mid_ch, mid_ch, 3, 1, 1),
            SNRes(mid_ch, mid_ch),
        )
        self.c = SNConv(mid_ch * 3, out_ch)

    def forward(self, x):
        branches = [branch(x) for branch in (self.layer_11, self.layer_33_1, self.layer_33_2)]
        return self.c(torch.cat(branches, dim=1))

In [0]:
class InceptionResBlock(nn.Module):
    """Residual Inception block: ``zeropad(x, out_ch) + InceptionBlock(x)``.

    The shortcut zero-pads the channels of ``x`` up to ``out_ch``. The Inception
    branches use ``out_ch // 4`` channels each.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.ch = out_ch
        self.inception = InceptionBlock(in_ch, out_ch // 4, out_ch)

    def forward(self, x):
        branch = self.inception(x)
        shortcut = zeropad(x, self.ch)
        return shortcut + branch

In [0]:
class DiscriminatorBlock(nn.Module):
    """Two Inception-residual stages (in_ch -> 1.5*in_ch -> 2*in_ch), optional
    self-attention, then a 1x1 SNConv to ``out_ch``. Spatial size is preserved.

    NOTE(review): the attention block is created only when ``sa_gamma`` is truthy.
    ``sa_gamma=0`` therefore builds no attention at all, and ``set_gamma`` becomes a
    no-op, rather than creating an attention block that starts at gamma 0.
    Confirm this is intended.
    """

    def __init__(self, in_ch, out_ch, sa_gamma=None):
        super().__init__()
        wide = in_ch * 3 // 2
        double = in_ch * 2
        self.main = nn.Sequential(
            InceptionResBlock(in_ch, wide),
            InceptionResBlock(wide, double),
        )
        self.sa = None
        if sa_gamma:
            self.sa = SelfAttentionBlock(double, double, gamma=sa_gamma)
        self.c = SNConv(double, out_ch)

    def forward(self, x):
        h = self.main(x)
        if self.sa is not None:
            h = self.sa(h)
        return self.c(h)

    def set_gamma(self, gamma):
        """Forward ``gamma`` to the attention block, if one was built."""
        if self.sa is not None:
            self.sa.set_gamma(gamma)

In [0]:
class Discriminator(nn.Module):
    """Progressive-growing discriminator for square images from 4x4 up to 256x256.

    Each resolution R has a 1x1 "from-RGB" conv ``in_R`` and a stack ``layer_R``.
    Every ``layer_R`` except ``layer_4`` ends in an SNConv(kernel 4, stride 2,
    padding 1) that halves the spatial size, so features flow
    256 -> 128 -> ... -> 4. ``layer_4`` average-pools the 4x4 map, appends
    minibatch-discrimination features and maps to ``out_ch`` outputs.

    Args:
        out_ch: number of outputs per sample.
        alpha: stored as ``self.alpha``. NOTE(review): not read anywhere in this class.
        device: forwarded to MinibatchDiscrimination.
        sa_gamma: forwarded to the DiscriminatorBlocks of layer_16, layer_8 and layer_4.
    """

    def __init__(self, out_ch=2, alpha=0.5, device=None, sa_gamma=1.):
        super(Discriminator, self).__init__()
        self.alpha = alpha

        # 256px stage: from-RGB to 32 channels, block to 64, downsample to 128px.
        self.in_256 = spectral_norm(nn.Conv2d(3, 32, 1, 1, 0))
        self.layer_256 = nn.Sequential(
            DiscriminatorBlock(32, 64),
            SNConv(64, 64, 4, 2, 1)
        )

        self.in_128 = spectral_norm(nn.Conv2d(3, 64, 1, 1, 0))
        self.layer_128 = nn.Sequential(
            DiscriminatorBlock(64, 128),
            SNConv(128, 128, 4, 2, 1)
        )

        self.in_64 = spectral_norm(nn.Conv2d(3, 128, 1, 1, 0))
        self.layer_64 = nn.Sequential(
            DiscriminatorBlock(128, 256),
            SNConv(256, 256, 4, 2, 1)
        )

        self.in_32 = spectral_norm(nn.Conv2d(3, 256, 1, 1, 0))
        self.layer_32 = nn.Sequential(
            DiscriminatorBlock(256, 512),
            SNConv(512, 512, 4, 2, 1)
        )

        # From 16px down, blocks may carry self-attention (see DiscriminatorBlock).
        self.in_16 = spectral_norm(nn.Conv2d(3, 512, 1, 1, 0))
        self.layer_16 = nn.Sequential(
            DiscriminatorBlock(512, 512, sa_gamma=sa_gamma),
            SNConv(512, 512, 4, 2, 1)
        )

        self.in_8 = spectral_norm(nn.Conv2d(3, 512, 1, 1, 0))
        self.layer_8 = nn.Sequential(
            DiscriminatorBlock(512, 512, sa_gamma=sa_gamma),
            SNConv(512, 512, 4, 2, 1)
        )

        # Final 4x4 stage: pool to 1x1, flatten to (batch, 512), append 64
        # minibatch-discrimination features, project to out_ch.
        self.in_4 =  spectral_norm(nn.Conv2d(3, 512, 1, 1, 0))
        self.layer_4 = nn.Sequential(
            DiscriminatorBlock(512, 512, sa_gamma=sa_gamma),
            nn.AvgPool2d(4),
            nn.Flatten(),
            MinibatchDiscrimination(512, 64, 16, device),
            SNDense(512+64, out_ch)
        )
    
    def forward(self, x, img_size, delta=None):
        """Score a batch of images at the current training resolution.

        Args:
            x: image batch whose spatial size is ``img_size``. Assumed square, with
                img_size a power of two in [4, 256] — TODO confirm for other values.
            img_size: current resolution. Stages above it are skipped.
            delta: fade-in weight, or None. While truthy and ``img_size`` is one step
                above a stage R, that stage's input becomes
                ``lerp(in_R(avgpool(x)), h, delta)``. Small delta favours the
                downsampled image; delta = 1 uses only the higher-resolution path.

        Returns:
            Tensor of shape (batch, out_ch).
        """

        if img_size >= 256:
            h = self.in_256(x)
            h = self.layer_256(h)
        else:
            # Lower resolutions start from 0 and add their from-RGB input below.
            h = 0
        
        if img_size >= 128:
            if img_size == 128:
                # Entry stage for 128px (pool kernel is 1, i.e. identity).
                _x = F.avg_pool2d(x, img_size//128)
                h = h + self.in_128(_x)
            elif delta and img_size == 256:
                # Fading in 256px: blend the 2x-downsampled image with layer_256's output.
                _x = F.avg_pool2d(x, img_size//128)
                h = torch.lerp(self.in_128(_x), h, delta)
            h = self.layer_128(h)
        
        # The remaining stages repeat the same entry / fade-in pattern.
        if img_size >= 64:
            if img_size == 64:
                _x = F.avg_pool2d(x, img_size//64)
                h = h + self.in_64(_x)
            elif delta and img_size == 128:
                _x = F.avg_pool2d(x, img_size//64)
                h = torch.lerp(self.in_64(_x), h, delta)
            h = self.layer_64(h)
        
        if img_size >= 32:
            if img_size == 32:
                _x = F.avg_pool2d(x, img_size//32)
                h = h + self.in_32(_x)
            elif delta and img_size == 64:
                _x = F.avg_pool2d(x, img_size//32)
                h = torch.lerp(self.in_32(_x), h, delta)
            h = self.layer_32(h)
        
        if img_size >= 16:
            if img_size == 16:
                _x = F.avg_pool2d(x, img_size//16)
                h = h + self.in_16(_x)
            elif delta and img_size == 32:
                _x = F.avg_pool2d(x, img_size//16)
                h = torch.lerp(self.in_16(_x), h, delta)
            h = self.layer_16(h)

        if img_size >= 8:
            if img_size == 8:
                _x = F.avg_pool2d(x, img_size//8)
                h = h + self.in_8(_x)
            elif delta and img_size == 16:
                _x = F.avg_pool2d(x, img_size//8)
                h = torch.lerp(self.in_8(_x), h, delta)
            h = self.layer_8(h)

        # The 4x4 stage always runs.
        if img_size == 4:
            _x = F.avg_pool2d(x, img_size//4)
            h = h + self.in_4(_x)
        elif delta and img_size == 8:
            _x = F.avg_pool2d(x, img_size//4)
            h = torch.lerp(self.in_4(_x), h, delta)
        h = self.layer_4(h)

        return h

    def set_gamma(self, gamma, img_size):
        """Set the self-attention gamma of layer_4 (img_size 8) or layer_8 (img_size 16).

        Other sizes are ignored. NOTE(review): layer_16's attention gamma is never
        updated here. Also, DiscriminatorBlock only builds an attention block when
        sa_gamma is truthy, so for a model built with sa_gamma=0 this is a no-op.
        """
        if img_size == 8:
            self.layer_4[0].set_gamma(gamma)
        if img_size == 16:
            self.layer_8[0].set_gamma(gamma)

In [0]:
class BNDense(nn.Module):
    """Pre-activation dense layer: LeakyReLU(0.2) -> Linear -> BatchNorm1d."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.LeakyReLU(0.2),
            nn.Linear(in_ch, out_ch),
            nn.BatchNorm1d(out_ch),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)

In [0]:
class Affine(nn.Module):
    """Mapping network from a latent vector to a ``w_ch``-dim style vector.

    pixel_norm(x) -> Linear + BatchNorm1d -> six BNDense(mid_ch) -> BNDense(w_ch).
    """

    def __init__(self, in_ch, mid_ch=64, w_ch=128):
        super().__init__()
        layers = [nn.Linear(in_ch, mid_ch), nn.BatchNorm1d(mid_ch)]
        layers += [BNDense(mid_ch, mid_ch) for _ in range(6)]
        layers.append(BNDense(mid_ch, w_ch))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        normalized = pixel_norm(x)
        return self.main(normalized)

In [0]:
class SEBlock(nn.Module):
    """Gate a feature vector by a learned sigmoid mask: ``x * sigmoid(MLP(x))``.

    The MLP is BNDense(in_ch -> mid_ch) -> BNDense(mid_ch -> in_ch).
    """

    def __init__(self, in_ch, mid_ch):
        super().__init__()
        self.main = nn.Sequential(
            BNDense(in_ch, mid_ch),
            BNDense(mid_ch, in_ch),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.main(x)
        return x * gate

In [0]:
class AdaIN(nn.Module):
    """Adaptive instance normalisation conditioned on a style vector ``w``.

    ``x`` is instance-normalised, then scaled and shifted per channel. The scale
    and shift are computed from the SE-gated ``w`` by two SNDense layers
    (``in_ch`` -> ``out_ch``), where ``out_ch`` must equal x's channel count.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.se = SEBlock(in_ch, in_ch // 4)
        self.average_convert = SNDense(in_ch, out_ch)
        self.bias_convert = SNDense(in_ch, out_ch)

    def forward(self, x, w):
        normalized = F.instance_norm(x)
        style = self.se(w)
        scale = self.average_convert(style)[:, :, None, None].expand_as(normalized)
        shift = self.bias_convert(style)[:, :, None, None].expand_as(normalized)
        return normalized * scale + shift

In [0]:
class DepthwiseCondConv(nn.Module):
    """Depthwise convolution whose per-channel kernels are conditioned on ``w``.

    A bank of ``n_kernels`` learnable ``kernel x kernel`` filters is shared by all
    channels. For each sample, ``feature_convert(w)`` gives one coefficient in
    (-1, 1) per (bank filter, channel). Each channel's kernel is the
    coefficient-weighted sum of the bank. The kernels are applied to
    ``leaky_relu(x)`` through ``F.unfold`` and a batched matmul.

    Args:
        latent_w: size of the conditioning vector ``w``.
        in_ch: number of channels (input and output).
        n_kernels: number of filters in the shared bank.
        kernel, stride, padding: convolution geometry. NOTE(review): the final
            reshape back to ``x``'s spatial size only holds when the geometry
            preserves H and W (e.g. the default 3/1/1).
        device: kept for backward compatibility. The channel identity mask is now
            created on the input's device instead.
        k: second dim of the kernel-bank parameter. Only k=1 is used in this file.
    """

    def __init__(self, latent_w, in_ch, n_kernels, kernel=3, stride=1, padding=1, device=None, k=1):
        super(DepthwiseCondConv, self).__init__()
        self.device = device

        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.ch = in_ch
        self.kernels_size = (n_kernels, kernel, kernel)

        # Shared filter bank, stored flattened as (n_kernels*kernel*kernel, k, 1).
        self.kernels = nn.Parameter(torch.empty(n_kernels*kernel*kernel, k, 1))
        nn.init.xavier_normal_(self.kernels)

        # w -> one mixing coefficient in (-1, 1) per (bank filter, channel).
        self.feature_convert = nn.Sequential(
            SEBlock(latent_w, latent_w//4),
            SNDense(latent_w, in_ch*n_kernels),
            nn.Tanh()
        )

    def forward(self, x, w):
        """Apply the w-conditioned depthwise convolution to ``leaky_relu(x)``.

        Args:
            x: (batch, in_ch, H, W) feature map.
            w: (batch, latent_w) conditioning batch.

        Returns:
            (batch, in_ch, H, W) tensor.
        """
        b_size = x.size(0)

        h = F.leaky_relu(x, negative_slope=0.2)
        kernels = self.kernels.expand(-1, -1, self.ch)
        # Identity mask zeroes every cross-channel weight, so the dense
        # (ch x ch) kernel acts depthwise. It is built on x's device; the old
        # `.to(self.device)` failed whenever the constructor's device differed
        # from where the module actually ran.
        eye = torch.eye(self.ch, device=x.device, dtype=kernels.dtype)
        i = eye.view(1, self.ch, self.ch).expand(kernels.size(0), -1, -1)
        kernels = i*kernels
        # -> (n_kernels, ch_out, ch_in, kernel, kernel), then broadcast over the batch.
        kernels = kernels.view(*self.kernels_size, self.ch, -1)
        kernels = kernels.transpose(1, 3).transpose(2, 4)
        kernels = kernels.expand(b_size, -1, -1, -1, -1, -1)

        # Per-sample coefficients (batch, n_kernels, ch_out), then the weighted sum over the bank.
        f = self.feature_convert(w)
        f = f.view(*kernels.size()[:3], 1, 1, 1)
        f = f.expand(-1, -1, -1, self.ch, -1, -1)
        f = torch.sum(kernels*f, dim=1)

        # (batch, ch_out, ch_in*kernel*kernel): same column layout as F.unfold's rows.
        f = f.reshape(b_size, self.ch, -1)

        h = F.unfold(h, self.kernel, padding=self.padding, stride=self.stride)
        h = torch.bmm(f, h)
        h = h.view(b_size, self.ch, *x.size()[2:])
        return h

In [0]:
class SynthBlock(nn.Module):
    """Residual synthesis step conditioned on ``w``.

    Path: optional noise injection -> DepthwiseCondConv -> 1x1 SNConv -> AdaIN.
    The path output is added to the input, zero-padded to the output channel count.
    """

    def __init__(self, in_ch, out_ch, latent_w, kernels, noise=None, device=None):
        super().__init__()
        self.noise = noise
        self.dcc = DepthwiseCondConv(latent_w, in_ch, kernels, device=device)
        self.c = SNConv(in_ch, out_ch)
        self.adain = AdaIN(latent_w, out_ch)

    def forward(self, x, w):
        # noise of None or 0 disables noise injection.
        h = x if not self.noise else noise_injection(x, self.noise)
        h = self.adain(self.c(self.dcc(h, w)), w)
        shortcut = zeropad(x, h.size(1))
        return h + shortcut

In [0]:
class GeneratorBlock(nn.Module):
    """One generator stage: 1x1 SNConv + AdaIN, four SynthBlocks, optional self-attention.

    ``forward`` returns ``(features, image)``. The features feed the next stage;
    the image is a 3-channel map produced by the 1x1 SNConv ``c_out``.

    NOTE(review): the attention block is created only when ``sa_gamma`` is truthy.
    ``sa_gamma=0`` therefore builds no attention at all, and ``set_gamma`` becomes a
    no-op, rather than creating an attention block that starts at gamma 0.
    Confirm this is intended.
    """

    def __init__(self, in_ch, out_ch, latent_w, kernels, noise=None, device=None, sa_gamma=None):
        super().__init__()
        self.c_in = SNConv(in_ch, out_ch)
        self.adain_in = AdaIN(latent_w, out_ch)

        synth_args = (out_ch, out_ch, latent_w, kernels, noise, device)
        self.synth_0 = SynthBlock(*synth_args)
        self.synth_1 = SynthBlock(*synth_args)
        self.synth_2 = SynthBlock(*synth_args)
        self.synth_3 = SynthBlock(*synth_args)
        self.sa = None
        if sa_gamma:
            self.sa = SelfAttentionBlock(out_ch, out_ch, gamma=sa_gamma)

        self.c_out = SNConv(out_ch, 3)

    def forward(self, x, w):
        h = self.adain_in(self.c_in(x), w)
        for synth in (self.synth_0, self.synth_1, self.synth_2, self.synth_3):
            h = synth(h, w)
        if self.sa is not None:
            h = self.sa(h)
        return h, self.c_out(h)

    def set_gamma(self, gamma):
        """Forward ``gamma`` to the attention block, if one was built."""
        if self.sa is not None:
            self.sa.set_gamma(gamma)

In [0]:
class Generator(nn.Module):
    """Progressive-growing generator producing images from 4x4 up to 256x256.

    A learned constant 4x4 tensor (``btm``) is refined by ``res6``. Each further
    stage bilinearly upsamples by 2 and applies the next GeneratorBlock (res5 at
    8px, res4 at 16px, ..., res0 at 256px). Every block also emits a 3-channel
    image. The image of the last stage run is returned after tanh.

    Args:
        latent_size: size of the latent vector and of the style vector ``w``.
        alpha: stored as ``self.alpha``. NOTE(review): not read in this class.
        noise: noise-injection scale passed to every block's SynthBlocks.
        latent_blur: relative std of the per-stage latent perturbation (pixel_noise).
        device: forwarded to the blocks.
        sa_gamma: self-attention gamma for res6 and res5 (attention is only built when truthy).
        kernels: size of each DepthwiseCondConv filter bank.
    """

    def __init__(self, latent_size, alpha=0.5, noise=None, latent_blur=0.1, device=None, sa_gamma=1., kernels=4):
        super(Generator, self).__init__()
        self.alpha = alpha
        self.latent_blur = latent_blur

        # Learned constant 4x4 input, broadcast over the batch in forward().
        self.btm = nn.Parameter(torch.empty(1, 512, 4, 4))
        nn.init.normal_(self.btm)

        # Mapping network: latent -> style vector w (same size as the latent).
        self.affine = Affine(latent_size, latent_size, latent_size)

        # One block per resolution: res6 at 4px, res5 at 8px, ..., res0 at 256px.
        self.res6 = GeneratorBlock(512, 256, latent_size, kernels, noise, device, sa_gamma)
        self.res5 = GeneratorBlock(256, 128, latent_size, kernels, noise, device, sa_gamma)
        self.res4 = GeneratorBlock(128, 64, latent_size, kernels, noise, device)
        self.res3 = GeneratorBlock(64, 32, latent_size, kernels, noise, device)
        self.res2 = GeneratorBlock(32, 16, latent_size, kernels, noise, device)
        self.res1 = GeneratorBlock(16, 8, latent_size, kernels, noise, device)
        self.res0 = GeneratorBlock(8, 4, latent_size, kernels, noise, device)

        self.upbi = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, img_size, delta=None, mixing=False):
        """Generate a batch of images.

        Args:
            x: latent batch of shape (batch, latent_size).
            img_size: output resolution. Stages above it are skipped.
            delta: fade-in weight, or None. While truthy, the newest stage's image is
                ``lerp(upsample(previous image), new image, delta)``.
            mixing: if True, each stage after the first maps a perturbed copy of ``x``
                (``pixel_noise(x, latent_blur)``) through ``affine`` to get its own ``w``.

        Returns:
            tanh-squashed images of shape (batch, 3, img_size, img_size) for
            power-of-two sizes between 4 and 256.
        """
        w = self.affine(x)
        h = self.btm.expand(x.size(0), -1, -1, -1)
        h, out = self.res6(h, w)

        # (resolution, block) for every stage after the 4x4 base, in increasing order.
        stages = ((8, self.res5), (16, self.res4), (32, self.res3),
                  (64, self.res2), (128, self.res1), (256, self.res0))
        for stage_size, block in stages:
            if img_size < stage_size:
                break
            h = self.upbi(h)
            # Drawn at every stage regardless of `mixing` (unchanged RNG usage).
            _x = pixel_noise(x, self.latent_blur)
            w = self.affine(_x) if mixing else w
            # Every GeneratorBlock takes exactly (h, w). The 16px stage used to pass
            # a non-existent `self.dcc`, which raised AttributeError.
            h, _out = block(h, w)
            if delta and img_size == stage_size:
                # Fade the newest stage in over the upsampled previous image.
                out = torch.lerp(self.upbi(out), _out, delta)
            else:
                out = _out

        return torch.tanh(out)

    def set_gamma(self, gamma, img_size):
        """Set the attention gamma of res6 (img_size 8) or res5 (img_size 16). Other sizes are ignored."""
        if img_size == 8:
            self.res6.set_gamma(gamma)
        if img_size == 16:
            self.res5.set_gamma(gamma)

In [0]:
# Run everything on the CUDA device selected by GPU (equivalent to "cuda:<GPU>").
device = torch.device('cuda', GPU)

In [0]:
def weights_init(m):
    """Initialise one module; intended for ``net.apply(weights_init)``.

    - Exactly nn.Linear / nn.Conv2d: Xavier-normal weights with gain ``initial_scale``,
      zero bias.
    - Exactly nn.BatchNorm1d / nn.BatchNorm2d: weight ~ N(1, 0.02), zero bias.

    The exact-type check means subclasses of these layers are left untouched.
    """
    kind = type(m)
    if kind in (nn.Linear, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=initial_scale)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif kind in (nn.BatchNorm1d, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [0]:
# NOTE(review): GeneratorBlock / DiscriminatorBlock only create a SelfAttentionBlock
# when sa_gamma is truthy. With sa_gamma=0 both models are built without any
# self-attention, and the gamma ramp in round_dataset has no effect. Confirm intent.
gen = Generator(LATENT_SIZE, alpha=SCALEUP_ALPHA, noise=gen_noise, device=device, sa_gamma=0)
dis = Discriminator(out_ch=1, alpha=SCALEUP_ALPHA, device=device, sa_gamma=0)

if load_weight:
    # Resume from OUT/<name>_<img_size>px_<epoch>epoch.pkl. torch.load unpickles,
    # so only load checkpoints you trust.
    for net, name in ((gen, gen_name), (dis, dis_name)):
        net.load_state_dict(torch.load(OUT + '{}_{}px_{}epoch.pkl'.format(name, *load_weight)))
        net.eval()
else:
    for net in (gen, dis):
        net.apply(weights_init)

gen.to(device)
dis.to(device)

Discriminator(
  (in_256): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
  (layer_256): Sequential(
    (0): DiscriminatorBlock(
      (main): Sequential(
        (0): InceptionResBlock(
          (inception): InceptionBlock(
            (layer_11): SNConv(
              (main): Sequential(
                (0): LeakyReLU(negative_slope=0.2)
                (1): Conv2d(32, 12, kernel_size=(1, 1), stride=(1, 1))
              )
            )
            (layer_33_1): Sequential(
              (0): SNConv(
                (main): Sequential(
                  (0): LeakyReLU(negative_slope=0.2)
                  (1): Conv2d(32, 12, kernel_size=(1, 1), stride=(1, 1))
                )
              )
              (1): SNConv(
                (main): Sequential(
                  (0): LeakyReLU(negative_slope=0.2)
                  (1): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                )
              )
            )
            (layer_33_2): Sequential(
 

In [0]:
# Adam with beta1 = 0.5 for both nets; the discriminator learns at half the generator's rate.
opt_gen = optim.Adam(gen.parameters(), lr=learning_rate,
                     betas=(0.5, 0.999), weight_decay=gen_weight_decay)
opt_dis = optim.Adam(dis.parameters(), lr=learning_rate * 0.5,
                     betas=(0.5, 0.999), weight_decay=dis_weight_decay)

In [0]:
# Loss histories appended to by report_log and plotted by report_result.
gen_losses, dis_losses = [], []

In [0]:
# Authenticate only when tweeting is enabled. `api` is used by post_image, which
# round_dataset calls only when tweet_interval is set.
if tweet_interval:
    creds = twitter_api_key
    auth = tweepy.OAuthHandler(creds.CONSUMER_KEY, creds.CONSUMER_SECRET)
    auth.set_access_token(creds.ACCESS_TOKEN_KEY, creds.ACCESS_TOKEN_SECRET)
    api = tweepy.API(auth)

In [0]:
def report_log(i, epoch, g_loss, d_loss, g_mean, d_real_mean, d_fake_mean):
    """Print one progress line and record both losses for ``report_result``.

    Args:
        i: iteration index printed first (round_dataset passes the zero-based index).
        epoch: printed after ``i``; round_dataset passes the number of batches per pass.
        g_loss, d_loss: generator / discriminator losses.
        g_mean: mean D output on fakes during the generator step.
        d_real_mean, d_fake_mean: mean D outputs on reals / fakes during the D step.
    """
    template = '[{}/{}]\tLoss_D: {:.4f}\tLoss_G: {:.4f}\tD(x): {:.4f}\tD(G(z)): {:.4f}/{:.4f}'
    print(template.format(i, epoch, d_loss, g_loss, d_real_mean, d_fake_mean, g_mean))
    gen_losses.append(g_loss)
    dis_losses.append(d_loss)

In [0]:
def make_image(gen, img_size, device, delta=None, mixing=False):
    """Generate 8 samples and show them in a 2x4 grid, replacing the cell's output.

    NOTE(review): ``gen`` is not switched to eval(), so its BatchNorm layers use this
    8-sample batch's statistics and also update their running stats.
    """
    clear_output()
    with torch.no_grad():
        samples = gen(gaussian((8, LATENT_SIZE)).to(device), img_size, delta, mixing).detach().cpu()
    # (N, 3, H, W) in [-1, 1] -> (N, H, W, 3) numpy array for PIL.
    images = samples.numpy().reshape(-1, 3, img_size, img_size).transpose(0, 2, 3, 1)

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    for ax, img in zip(axes.flat, images):
        ax.axis('off')
        ax.imshow(Image.fromarray(np.uint8((img + 1.) / 2. * 255.)))

    plt.show()

In [0]:
def image_upload(image_array, api):
    """Encode one generated image as a 512x512 JPEG and upload it to Twitter.

    Args:
        image_array: (H, W, 3) array with values in [-1, 1] (the generator's tanh range).
        api: authenticated tweepy API object.

    Returns:
        The ``media_id`` of the uploaded media.
    """
    bin_io = io.BytesIO()
    img = Image.fromarray(np.uint8((image_array+1.)/2. *255.))
    img = img.resize((512, 512), resample=0)  # resample=0: nearest neighbour
    img.save(bin_io, format='JPEG')
    # `save` leaves the cursor at the end of the buffer. Rewind it so the uploader
    # reads the encoded image instead of an empty stream.
    bin_io.seek(0)
    result = api.media_upload(filename='{}_generated_{}.jpg'.format(target, uuid.uuid4()), file=bin_io)
    return result.media_id

In [0]:
def post_image(gen, api, img_size, iteration, epoch, device, delta=None, mixing=False):
    """Tweet 16 generated samples as four 2x2 mosaics (best effort).

    Failures while uploading or tweeting are reported but do not interrupt training.

    Args:
        gen: generator.
        api: authenticated tweepy API object.
        img_size: resolution to sample at.
        iteration, epoch: shown in the tweet text.
        device: device for the latent batch.
        delta, mixing: forwarded to ``gen``.
    """
    with torch.no_grad():
        generated = gen(gaussian((16, LATENT_SIZE)).to(device), img_size, delta, mixing).detach().cpu()
    # (16, 3, s, s) -> four mosaics of 2x2 samples, each (2s, 2s, 3).
    generated = generated.numpy().reshape(4, 2, 2, 3, img_size, img_size)
    generated = generated.transpose(0, 3, 1, 4, 2, 5)
    generated = generated.reshape(4, 3, img_size*2, img_size*2)
    generated = generated.transpose(0, 2, 3, 1)

    try:
        img_ids = [image_upload(img, api) for img in generated]
        # First line reads "making <target_JP> with AI" in Japanese.
        hash_tags = ['AIで{}を作る'.format(target_JP),
                    'iteration/epoch: {}/{}'.format(iteration, epoch),
                    '#makeing{}'.format(target),
                    '#nowlearning...',
                    '#AI',
                    '#人工知能',
                    '#DeepLearning',
                    '#GAN']

        api.update_status(
            status='\n'.join(hash_tags),
            media_ids=img_ids
            )
    except Exception as e:
        # Deliberately best-effort: a network/API failure must not stop training,
        # but it should not disappear silently either.
        print('post_image: tweet failed ({}: {})'.format(type(e).__name__, e))

In [0]:
def save_model(gen, dis, img_size, epoch):
    """Save both state_dicts to ``OUT + '<name>_<img_size>px_<epoch>epoch.pkl'``."""
    for net, name in ((gen, gen_name), (dis, dis_name)):
        path = OUT + '{}_{}px_{}epoch.pkl'.format(name, img_size, epoch)
        torch.save(net.state_dict(), path)

In [0]:
def report_result():
    """Plot the generator / discriminator losses recorded by report_log.

    Each point is one report_log call, i.e. one every ``log_interval`` iterations.
    """
    fig, ax = plt.subplots(figsize=(16, 8))
    ax.plot(gen_losses, label="G")
    ax.plot(dis_losses, label="D")
    ax.set_xlabel("iterations")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()

In [0]:
def round_dataset(gen, dis, opt_gen, opt_dis, dataloader, epoch, img_size, delta, device, mixing=False):
    """Run one pass over ``dataloader`` of least-squares GAN training.

    Each iteration does one discriminator update (real -> 1, fake -> 0, MSE) and then
    one generator update (fake -> 1, MSE). It also handles the periodic
    logging / preview / tweet / snapshot actions configured by the module-level intervals.

    Args:
        gen, dis: generator and discriminator.
        opt_gen, opt_dis: their optimizers.
        dataloader: yields batches whose element 0 is the image tensor (as ImageFolder does).
        epoch: zero-based epoch index within the current resolution stage.
        img_size: current training resolution.
        delta: fade-in weight for this epoch, or None. It is ramped linearly over the
            pass as ``delta * (i+1) / len(dataloader)``.
        device: device to train on.
        mixing: forwarded to ``gen``.
    """
    loss_fun = nn.MSELoss()
    data_len = len(dataloader)
    # Pass `total` explicitly: an enumerate object has no len(), so without it the
    # tqdm bar has no known length.
    for i, data in tqdm(enumerate(dataloader, 0), total=data_len):
        # At the 8/16px stages, ramp the attention gamma towards sa_gamma during the
        # first sa_endpoint iterations of epoch sg_endpoint+1.
        if 8 <= img_size <= 16 and epoch == sg_endpoint+1 and i < sa_endpoint:
            gamma = sa_gamma*(i+1)/sa_endpoint
            gen.set_gamma(gamma, img_size)
            dis.set_gamma(gamma, img_size)
        _delta = delta*(i+1)/data_len if delta else delta

        # --- discriminator step ---
        dis.zero_grad()
        x_real = data[0].to(device)
        b_size = x_real.size(0)
        y_real = dis(x_real, img_size, _delta).view(-1)
        real_loss = loss_fun(y_real, torch.ones(*y_real.size(), device=device))
        real_loss.backward()

        x_fake = gen(gaussian((b_size, LATENT_SIZE)).to(device), img_size, _delta, mixing)
        # detach: the discriminator loss must not backpropagate into the generator.
        y_fake = dis(x_fake.detach(), img_size, _delta).view(-1)
        fake_loss = loss_fun(y_fake, torch.zeros(*y_fake.size(), device=device))
        fake_loss.backward()

        if grad_clip:
            clip_grad_norm_(dis.parameters(), grad_clip)
        dis_loss = real_loss + fake_loss
        opt_dis.step()

        # --- generator step ---
        # The gradients this leaves on dis are cleared by dis.zero_grad() next iteration.
        gen.zero_grad()
        y_gen = dis(x_fake, img_size, _delta).view(-1)
        gen_loss = loss_fun(y_gen, torch.ones(*y_gen.size(), device=device))
        gen_loss.backward()

        if grad_clip:
            clip_grad_norm_(gen.parameters(), grad_clip)
        opt_gen.step()

        # --- periodic side effects ---
        if display_interval and (i+1) % display_interval == 0:
            make_image(gen, img_size, device, _delta, mixing)

        if (i+1) % log_interval == 0:
            report_log(
                i,
                data_len,
                gen_loss.item(),
                dis_loss.item(),
                y_gen.mean().item(),
                y_real.mean().item(),
                y_fake.mean().item()
                )

        if tweet_interval and (i+1) % tweet_interval == 0:
            post_image(gen, api, img_size, i, epoch, device, _delta, mixing)
        
        if snapshot_interval and (i+1) % snapshot_interval == 0:
            save_model(gen, dis, img_size, epoch)

In [0]:
def train_loop(gen, dis, opt_gen, opt_dis, dataloader, epoch, img_size, device, mixing=False):
    """Train for ``epoch`` passes over ``dataloader`` at resolution ``img_size``.

    During the first ``sg_endpoint`` epochs a fade-in weight
    ``SCALEUP_ALPHA * (ep+1) / sg_endpoint`` is passed down. Afterwards it is None.
    """
    for ep in range(epoch):
        delta = SCALEUP_ALPHA * (ep + 1) / sg_endpoint if ep < sg_endpoint else None
        gen.train()
        dis.train()
        round_dataset(gen, dis, opt_gen, opt_dis, dataloader, ep, img_size, delta, device, mixing)

In [0]:
def upscaling(gen, dis, opt_gen, opt_dis, device, mixing=False):
    """Progressive training driver, then plot the recorded losses.

    Runs one stage per (EPOCHS, BATCH_SIZES) entry starting at index START. The
    resolution starts at ``4 * 2**START`` and doubles each stage. Images are resized
    and randomly cropped to the stage resolution, randomly flipped, and normalised
    to [-1, 1].
    """
    stages = zip(EPOCHS[START:], BATCH_SIZES[START:])
    for stage, (epoch, batch_size) in enumerate(stages):
        img_size = 4 * 2 ** (START + stage)
        transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.RandomCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = dset.ImageFolder(root=dataroot, transform=transform)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        train_loop(gen, dis, opt_gen, opt_dis, dataloader, epoch, img_size, device, mixing)
    report_result()

In [0]:
# Run the full progressive schedule with per-stage latent mixing enabled.
upscaling(gen, dis, opt_gen, opt_dis, device=device, mixing=True)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/4072]	Loss_D: 0.5144	Loss_G: 0.2828	D(x): 0.4925	D(G(z)): 0.5013/0.4705
[1999/4072]	Loss_D: 0.4983	Loss_G: 0.3006	D(x): 0.5284	D(G(z)): 0.5224/0.4534
[2999/4072]	Loss_D: 0.5106	Loss_G: 0.2694	D(x): 0.4981	D(G(z)): 0.4962/0.4872
[3999/4072]	Loss_D: 0.5108	Loss_G: 0.2710	D(x): 0.5023	D(G(z)): 0.5078/0.4815



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/4072]	Loss_D: 0.4990	Loss_G: 0.2551	D(x): 0.4840	D(G(z)): 0.4808/0.4959
[1999/4072]	Loss_D: 0.4990	Loss_G: 0.2431	D(x): 0.4927	D(G(z)): 0.4767/0.5152
[2999/4072]	Loss_D: 0.5047	Loss_G: 0.2568	D(x): 0.5034	D(G(z)): 0.5015/0.4952
[3999/4072]	Loss_D: 0.4899	Loss_G: 0.2518	D(x): 0.5055	D(G(z)): 0.4935/0.4993



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/4072]	Loss_D: 0.4999	Loss_G: 0.2625	D(x): 0.5035	D(G(z)): 0.5014/0.4887
[1999/4072]	Loss_D: 0.5011	Loss_G: 0.2548	D(x): 0.5072	D(G(z)): 0.5064/0.4960
[2999/4072]	Loss_D: 0.4873	Loss_G: 0.2453	D(x): 0.5034	D(G(z)): 0.4871/0.5073
[3999/4072]	Loss_D: 0.5186	Loss_G: 0.3269	D(x): 0.5768	D(G(z)): 0.5777/0.4297



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/4072]	Loss_D: 0.5038	Loss_G: 0.2600	D(x): 0.5010	D(G(z)): 0.5028/0.4911
[1999/4072]	Loss_D: 0.5042	Loss_G: 0.2612	D(x): 0.5007	D(G(z)): 0.5043/0.4893
[2999/4072]	Loss_D: 0.4913	Loss_G: 0.3111	D(x): 0.5321	D(G(z)): 0.5192/0.4447
[3999/4072]	Loss_D: 0.5044	Loss_G: 0.2497	D(x): 0.4816	D(G(z)): 0.4841/0.5010



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/4072]	Loss_D: 0.5039	Loss_G: 0.2483	D(x): 0.5208	D(G(z)): 0.5217/0.5025
[1999/4072]	Loss_D: 0.5082	Loss_G: 0.2224	D(x): 0.5379	D(G(z)): 0.5412/0.5290
[2999/4072]	Loss_D: 0.5009	Loss_G: 0.2683	D(x): 0.4982	D(G(z)): 0.4983/0.4827
[3999/4072]	Loss_D: 0.5082	Loss_G: 0.2827	D(x): 0.5061	D(G(z)): 0.5082/0.4701



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/4072]	Loss_D: 0.5044	Loss_G: 0.2854	D(x): 0.5476	D(G(z)): 0.5441/0.4673
[1999/4072]	Loss_D: 0.5022	Loss_G: 0.2590	D(x): 0.4959	D(G(z)): 0.4964/0.4917
[2999/4072]	Loss_D: 0.4960	Loss_G: 0.2511	D(x): 0.5198	D(G(z)): 0.5132/0.5002
[3999/4072]	Loss_D: 0.4934	Loss_G: 0.2620	D(x): 0.4917	D(G(z)): 0.4831/0.4888



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/4072]	Loss_D: 0.5006	Loss_G: 0.2573	D(x): 0.4953	D(G(z)): 0.4952/0.4931
[1999/4072]	Loss_D: 0.5249	Loss_G: 0.1788	D(x): 0.4116	D(G(z)): 0.4195/0.5791
[2999/4072]	Loss_D: 0.5069	Loss_G: 0.2554	D(x): 0.4981	D(G(z)): 0.5034/0.4953
[3999/4072]	Loss_D: 0.5024	Loss_G: 0.2497	D(x): 0.5264	D(G(z)): 0.5263/0.5009



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/4072]	Loss_D: 0.5026	Loss_G: 0.2520	D(x): 0.5297	D(G(z)): 0.5292/0.4985
[1999/4072]	Loss_D: 0.5001	Loss_G: 0.2518	D(x): 0.5022	D(G(z)): 0.5012/0.4989
[2999/4072]	Loss_D: 0.4988	Loss_G: 0.2585	D(x): 0.4935	D(G(z)): 0.4903/0.4925
[3999/4072]	Loss_D: 0.4928	Loss_G: 0.2523	D(x): 0.4955	D(G(z)): 0.4866/0.4989



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/8143]	Loss_D: 0.5042	Loss_G: 0.2723	D(x): 0.4947	D(G(z)): 0.4969/0.4793
[1999/8143]	Loss_D: 0.5091	Loss_G: 0.2501	D(x): 0.5006	D(G(z)): 0.5078/0.5006
[2999/8143]	Loss_D: 0.5055	Loss_G: 0.2583	D(x): 0.5069	D(G(z)): 0.5098/0.4930
[3999/8143]	Loss_D: 0.4819	Loss_G: 0.2771	D(x): 0.4797	D(G(z)): 0.4493/0.4748
[4999/8143]	Loss_D: 0.4987	Loss_G: 0.2784	D(x): 0.4743	D(G(z)): 0.4690/0.4738
[5999/8143]	Loss_D: 0.5081	Loss_G: 0.2447	D(x): 0.4951	D(G(z)): 0.5008/0.5066
[6999/8143]	Loss_D: 0.4878	Loss_G: 0.2719	D(x): 0.5009	D(G(z)): 0.4828/0.4806
[7999/8143]	Loss_D: 0.5003	Loss_G: 0.2433	D(x): 0.5422	D(G(z)): 0.5344/0.5100



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/8143]	Loss_D: 0.4893	Loss_G: 0.2596	D(x): 0.5129	D(G(z)): 0.4968/0.4918
[1999/8143]	Loss_D: 0.4781	Loss_G: 0.2859	D(x): 0.5038	D(G(z)): 0.4741/0.4679
[2999/8143]	Loss_D: 0.4772	Loss_G: 0.2562	D(x): 0.4819	D(G(z)): 0.4530/0.4953
[3999/8143]	Loss_D: 0.4983	Loss_G: 0.2479	D(x): 0.5060	D(G(z)): 0.4988/0.5048
[4999/8143]	Loss_D: 0.4799	Loss_G: 0.2786	D(x): 0.5304	D(G(z)): 0.4945/0.4753
[5999/8143]	Loss_D: 0.5047	Loss_G: 0.2625	D(x): 0.4888	D(G(z)): 0.4833/0.4926
[6999/8143]	Loss_D: 0.4883	Loss_G: 0.2761	D(x): 0.5039	D(G(z)): 0.4882/0.4766
[7999/8143]	Loss_D: 0.5283	Loss_G: 0.2929	D(x): 0.4755	D(G(z)): 0.4991/0.4601



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/8143]	Loss_D: 0.4883	Loss_G: 0.2449	D(x): 0.4842	D(G(z)): 0.4688/0.5068
[1999/8143]	Loss_D: 0.5117	Loss_G: 0.2559	D(x): 0.4761	D(G(z)): 0.4838/0.4950
[2999/8143]	Loss_D: 0.4812	Loss_G: 0.4051	D(x): 0.5994	D(G(z)): 0.5532/0.3669
[3999/8143]	Loss_D: 0.5094	Loss_G: 0.2450	D(x): 0.5211	D(G(z)): 0.5274/0.5058
[4999/8143]	Loss_D: 0.4872	Loss_G: 0.2946	D(x): 0.5100	D(G(z)): 0.4915/0.4632
[5999/8143]	Loss_D: 0.5036	Loss_G: 0.2743	D(x): 0.5006	D(G(z)): 0.5014/0.4772
[6999/8143]	Loss_D: 0.4985	Loss_G: 0.2591	D(x): 0.4955	D(G(z)): 0.4924/0.4919
[7999/8143]	Loss_D: 0.4973	Loss_G: 0.2438	D(x): 0.5173	D(G(z)): 0.5102/0.5076



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[999/8143]	Loss_D: 0.4913	Loss_G: 0.2693	D(x): 0.4596	D(G(z)): 0.4406/0.4831
[1999/8143]	Loss_D: 0.4669	Loss_G: 0.2576	D(x): 0.5528	D(G(z)): 0.5137/0.4937
[2999/8143]	Loss_D: 0.4969	Loss_G: 0.2477	D(x): 0.5343	D(G(z)): 0.5262/0.5036
[3999/8143]	Loss_D: 0.5036	Loss_G: 0.2531	D(x): 0.4855	D(G(z)): 0.4869/0.4978
[4999/8143]	Loss_D: 0.4891	Loss_G: 0.2624	D(x): 0.5171	D(G(z)): 0.5049/0.4883
[5999/8143]	Loss_D: 0.4974	Loss_G: 0.2504	D(x): 0.5005	D(G(z)): 0.4938/0.5015
[6999/8143]	Loss_D: 0.5138	Loss_G: 0.2798	D(x): 0.4715	D(G(z)): 0.4802/0.4724
[7999/8143]	Loss_D: 0.4843	Loss_G: 0.2587	D(x): 0.5112	D(G(z)): 0.4916/0.4932



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

In [0]:
# Preview 8 samples at the final 256px resolution.
make_image(gen, img_size=256, device=device)