In [0]:
# Version tag embedded in every checkpoint filename (see GEN_NAME/DIS_NAME usage).
VERSION = '0.9.2'
# CUDA device index used to build DEVICE.
GPU = 0

# target: used in checkpoint names, uploaded media filenames and tweet hashtags
TARGET = 'landscape'
# Japanese display name of the target ("landscape"); only used in the tweet text.
TARGET_JP = '風景'

# hyper params: one entry per growth stage, indexed by stage (EPOCHS[stage], BATCH_SIZES[stage])
EPOCHS = (8, 8, 8, 16, 16, 32, 32)
BATCH_SIZES = (256, 256, 256, 256, 128, 128, 64)

# extension params: iteration intervals for the periodic hooks in round_dataset;
# None/0 disables DISPLAY/TWEET/SNAPSHOT.
LOG_INTERVAL = 2000
DISPLAY_INTERVAL = None
TWEET_INTERVAL = 8000
SNAPSHOT_INTERVAL = 2000

# model params
SA_GAMMA = 1.       # target self-attention gamma passed to dis.set_gamma (ramped over epoch 0)
CH_SIZE = 32        # feature channels of both networks; must be divisible by 4 (generator reshape / PixelShuffle)
LATENT_SIZE = 256   # dimensionality of the generator latent z
STAGE = 7           # train_loop runs stages START .. STAGE-1

# learning params
LEARNING_RATE = 2e-4       # Adam lr for both networks
GRAD_CLIP = None           # max grad norm for clip_grad_norm_; falsy disables clipping
GEN_WEIGHT_DECAY = 0
DIS_WEIGHT_DECAY = 0
RUNNING_MEAN_RATE = 1e-4   # running-mean blend factor is RUNNING_MEAN_RATE * epoch
RECONST_RATE = 1.          # weight of the reconstruction loss against the previous stage
START = 0                  # first stage to train

# initialize params
LOAD_GEN = None      # stage number of a generator checkpoint to resume from (None: train from scratch)
LOAD_DIS = None      # stage number of a discriminator checkpoint; reloaded for every new stage's discriminator
STRICT = False       # `strict` flag for the generator's load_state_dict
INIT_LAYER = (4, -1) # (start, stop) slice of generator blocks re-initialised after loading LOAD_GEN

# paths
OUT = './result/'                                  # checkpoint directory prefix (string-concatenated)
DATAROOT = './picture/train_pic/flickr/landscape'  # torchvision ImageFolder root

# names: checkpoint filename prefixes
GEN_NAME = '{}_gen'.format(TARGET)
DIS_NAME = '{}_dis'.format(TARGET)

In [0]:
import numpy as np
import io
import random
import uuid
import pickle
import tweepy
from tqdm.notebook import tqdm

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, clip_grad_norm_
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

import twitter_api_key

In [0]:
DEVICE = torch.device('cuda', GPU)  # CUDA device selected by the GPU index from the config cell

In [0]:
def zeropad(x, ch):
    """Pad dimension 1 (channels) of a 4-D tensor with zeros at the end so it has ``ch`` entries."""
    missing = ch - x.size(1)
    # F.pad pads from the last dim backwards: (W_l, W_r, H_l, H_r, C_l, C_r, N_l, N_r)
    pad_spec = (0, 0,
                0, 0,
                0, missing,
                0, 0)
    return F.pad(x, pad_spec)

def gap(x):
    """Global average pooling over the last two (spatial) dims, keeping them as size 1."""
    spatial = x.size()[-2:]
    return F.avg_pool2d(x, spatial)

In [0]:
class SNConv(nn.Module):
    """Spectrally normalised ``nn.Conv2d`` with an optional pre-activation.

    The activation (if truthy) is applied to the *input* before the convolution;
    pass ``activation=False`` or ``None`` to disable it.

    ``running_mean(gamma)`` blends the current weight/bias towards the value cached
    at the previous call (``torch.lerp(current, cache, gamma)``, so ``gamma`` is the
    weight given to the previous value) and writes the blend back.
    """
    def __init__(self, in_ch, out_ch, kernel=1, stride=1, padding=0, bias=True, groups=1, activation=nn.LeakyReLU(0.2)):
        super(SNConv, self).__init__()
        self.conv = spectral_norm(nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=bias, groups=groups))

        # The default LeakyReLU instance is shared by every SNConv built with the default
        # argument; it has no parameters, so sharing it is harmless.
        self.activation = activation

        self._cache_weight = None
        self._cache_bias = None

    def forward(self, x):
        h = self.activation(x) if self.activation else x
        return self.conv(h)

    def weight_init(self):
        """Kaiming-normal (a=0.2) init of the conv weight and zero init of the bias."""
        # spectral_norm keeps the trainable tensor in ``weight_orig``; ``weight`` is
        # recomputed from it on each forward, so initialising ``weight`` after a forward
        # pass would have no lasting effect.
        weight = getattr(self.conv, 'weight_orig', self.conv.weight)
        nn.init.kaiming_normal_(weight, 0.2)
        if self.conv.bias is not None:
            nn.init.constant_(self.conv.bias.data, 0)

    @staticmethod
    def _smoothed(param, cache, gamma):
        """Blend ``param`` towards ``cache`` in place and return the new cache.

        The returned cache never shares storage with ``param``, so in-place optimizer
        updates of the parameter cannot silently overwrite the cached previous value.
        """
        if cache is None:
            return param.data.detach().clone()
        blended = torch.lerp(param.data, cache, gamma)
        param.data = blended
        return blended.clone()

    def running_mean(self, gamma):
        """Blend trainable weight/bias with their cached previous values; drop caches of frozen tensors."""
        # NOTE(review): with spectral_norm, ``conv.weight`` is recomputed from
        # ``conv.weight_orig`` by a forward pre-hook, so the blended weight written here is
        # presumably replaced on the next forward — TODO confirm whether the average
        # should target ``weight_orig`` instead.
        weight = self.conv.weight
        if weight.requires_grad:
            self._cache_weight = self._smoothed(weight, self._cache_weight, gamma)
        else:
            self._cache_weight = None

        bias = self.conv.bias
        if bias is not None and bias.requires_grad:
            self._cache_bias = self._smoothed(bias, self._cache_bias, gamma)
        else:
            self._cache_bias = None

In [0]:
class SNDense(nn.Module):
    """Spectrally normalised ``nn.Linear`` with an optional pre-activation.

    The activation (if truthy) is applied to the *input* before the linear map;
    pass ``activation=False`` or ``None`` to disable it.

    ``running_mean(gamma)`` blends the current weight/bias towards the value cached
    at the previous call (``torch.lerp(current, cache, gamma)``, so ``gamma`` is the
    weight given to the previous value) and writes the blend back.
    """
    def __init__(self, in_ch, out_ch, bias=True, activation=nn.LeakyReLU(0.2)):
        super(SNDense, self).__init__()
        self.linear = spectral_norm(nn.Linear(in_ch, out_ch, bias=bias))

        # The default LeakyReLU instance is shared across instances; it is parameter-free.
        self.activation = activation

        self._cache_weight = None
        self._cache_bias = None

    def forward(self, x):
        h = self.activation(x) if self.activation else x
        return self.linear(h)

    def weight_init(self):
        """Kaiming-normal (a=0.2) init of the weight and zero init of the bias."""
        # Initialise the trainable ``weight_orig`` kept by spectral_norm; ``weight`` is
        # recomputed from it on every forward.
        weight = getattr(self.linear, 'weight_orig', self.linear.weight)
        nn.init.kaiming_normal_(weight, 0.2)
        if self.linear.bias is not None:
            nn.init.constant_(self.linear.bias.data, 0)

    @staticmethod
    def _smoothed(param, cache, gamma):
        """Blend ``param`` towards ``cache`` in place and return the new cache.

        The returned cache never shares storage with ``param``, so in-place optimizer
        updates of the parameter cannot silently overwrite the cached previous value.
        """
        if cache is None:
            return param.data.detach().clone()
        blended = torch.lerp(param.data, cache, gamma)
        param.data = blended
        return blended.clone()

    def running_mean(self, gamma):
        """Blend trainable weight/bias with their cached previous values; drop caches of frozen tensors."""
        # NOTE(review): ``linear.weight`` is recomputed from ``weight_orig`` by
        # spectral_norm's forward pre-hook, so the blended weight is presumably
        # overwritten on the next forward — TODO confirm the intended target.
        weight = self.linear.weight
        if weight.requires_grad:
            self._cache_weight = self._smoothed(weight, self._cache_weight, gamma)
        else:
            self._cache_weight = None

        bias = self.linear.bias
        if bias is not None and bias.requires_grad:
            self._cache_bias = self._smoothed(bias, self._cache_bias, gamma)
        else:
            self._cache_bias = None

In [0]:
class SNSelfAttentionBlock(nn.Module):
    """Self-attention over spatial positions built from spectrally normalised 1x1 convs.

    Returns ``x + gamma * attended``. The final reshape to ``x``'s shape means the
    block only works when ``out_ch == in_ch``.
    """
    def __init__(self, in_ch, out_ch, gamma=1.):
        super(SNSelfAttentionBlock, self).__init__()
        self.gamma = gamma

        reduced_ch = out_ch // 8
        self.cf = SNConv(in_ch, reduced_ch)
        self.cg = SNConv(in_ch, reduced_ch)
        self.ch = SNConv(in_ch, out_ch)

        self.softmax = nn.Softmax(2)

    def _convs(self):
        # Fixed order f, g, h shared by every per-layer helper below.
        return (self.cf, self.cg, self.ch)

    def forward(self, x):
        # Flatten spatial dims: (B, C, H, W) -> (B, C, H*W).
        proj_f = self.cf(x).flatten(2)
        proj_g = self.cg(x).flatten(2)
        proj_h = self.ch(x).flatten(2)

        scores = proj_f.transpose(1, 2).bmm(proj_g)          # (B, N, N)
        weights = self.softmax(scores)
        attended = proj_h.bmm(weights.transpose(1, 2)).view(x.size())

        return x + attended * self.gamma

    def set_gamma(self, gamma):
        self.gamma = gamma

    def weight_init(self):
        for conv in self._convs():
            conv.weight_init()

    def running_mean(self, gamma):
        for conv in self._convs():
            conv.running_mean(gamma)

    def freeze_param(self):
        for conv in self._convs():
            conv.requires_grad_(False)

In [0]:
class MinibatchDiscrimination(nn.Module):
    """Minibatch discrimination on flat feature vectors.

    A fixed (non-trainable) spectrally normalised projection maps each sample to
    ``kernel`` rows of ``kernel_dims`` values. For every sample and row, the output is
    the sum over the *other* samples of ``exp(-L1 distance)``. These ``kernel``
    features are appended to the input, giving ``in_ch + kernel`` features.

    Args:
        in_ch: input feature size (input is expected to be (batch, in_ch)).
        kernel: number of appended features.
        kernel_dims: length of each projected row compared by L1 distance.
        device: kept for backward compatibility; the mask is built on the input's device.
    """
    def __init__(self, in_ch, kernel, kernel_dims, device):
        super(MinibatchDiscrimination, self).__init__()
        self.device = device
        self.kernel = kernel
        self.dim = kernel_dims

        self.t = SNDense(in_ch, self.kernel*self.dim, bias=False, activation=False)
        # The projection is deliberately never trained.
        self.t.requires_grad_(False)

    # Implemented as ``forward`` (not a ``__call__`` override) so that nn.Module's
    # call machinery, including hooks, runs as for every other layer.
    def forward(self, x):
        batchsize = x.size(0)
        m = self.t(x).view(batchsize, self.kernel, self.dim, 1)
        m_T = torch.transpose(m, 0, 3)
        m, m_T = torch.broadcast_tensors(m, m_T)                       # (B, K, D, B)
        norm = torch.sum(F.l1_loss(m, m_T, reduction='none'), dim=2)  # (B, K, B)

        # Large offset on the diagonal so a sample's distance to itself contributes exp(-1e6) ~ 0.
        eraser = torch.eye(batchsize, device=x.device, dtype=norm.dtype)
        eraser = eraser.view(batchsize, 1, batchsize).expand(norm.size())
        c_b = torch.exp(-(norm + 1e6 * eraser))
        o_b = torch.sum(c_b, dim=2)                                   # (B, K)
        return torch.cat((x, o_b), dim=1)

    def weight_init(self):
        self.t.weight_init()

In [0]:
class DiscriminatorBlock(nn.Module):
    """Residual discriminator block.

    Main path: optional self-attention on the input, then a spectrally normalised
    conv (4x4 stride 2 when ``downconv``, else 3x3 stride 1).
    Skip path: 2x average pooling when ``downconv``, then the channel count is
    adjusted to ``out_ch`` with ``zeropad``. The two paths are summed.
    """
    def __init__(self, in_ch, out_ch, sa_gamma=None, downconv=True):
        super(DiscriminatorBlock, self).__init__()
        self.out_ch = out_ch

        # Attention is applied to the block *input*, so it must be sized for in_ch
        # channels. Sizing it for out_ch only worked when in_ch == out_ch.
        self.sa = SNSelfAttentionBlock(in_ch, in_ch, gamma=sa_gamma) if sa_gamma else None
        self.extractor = SNConv(in_ch, out_ch, 4, 2, 1) if downconv else SNConv(in_ch, out_ch, 3, 1, 1)

        self.downsample = nn.AvgPool2d(2) if downconv else None

    def forward(self, x):
        h = x if self.sa is None else self.sa(x)
        h = self.extractor(h)

        _h = x if self.downsample is None else self.downsample(x)
        _h = zeropad(_h, self.out_ch)
        return h + _h

    def set_gamma(self, gamma):
        """Forward the attention strength to the self-attention layer, if any."""
        if self.sa is not None:
            self.sa.set_gamma(gamma)

    def weight_init(self):
        if self.sa is not None:
            self.sa.weight_init()
        self.extractor.weight_init()

In [0]:
class Discriminator(nn.Module):
    """Residual discriminator ending in minibatch discrimination and a linear head.

    ``self.main`` layout (indices matter for saved state dicts): 1x1 stem conv, two
    downsampling residual blocks, one same-size block, global average pooling,
    flatten, minibatch discrimination (+8 features), spectrally normalised linear head
    producing ``out_ch`` scores.

    ``sa_gamma`` is accepted but not used by this class, and ``set_gamma`` does nothing.
    """
    def __init__(self, ch_size, out_ch=1, device=None, sa_gamma=1.):
        super(Discriminator, self).__init__()
        stem = SNConv(3, ch_size, 1, 1, 0, activation=False)
        down_a = DiscriminatorBlock(ch_size, ch_size)
        down_b = DiscriminatorBlock(ch_size, ch_size)
        same_size = DiscriminatorBlock(ch_size, ch_size, downconv=False)
        pool = nn.AdaptiveAvgPool2d(1)
        flatten = nn.Flatten()
        minibatch = MinibatchDiscrimination(ch_size, 8, 8, device)
        head = SNDense(ch_size+8, out_ch)
        self.main = nn.Sequential(stem, down_a, down_b, same_size, pool, flatten, minibatch, head)

    def forward(self, x):
        return self.main(x)

    def set_gamma(self, gamma):
        return None

    def weight_init(self):
        # Pooling and Flatten have no weights; every other layer initialises itself.
        for layer in self.main:
            if hasattr(layer, 'weight_init'):
                layer.weight_init()

In [0]:
class CondConv(nn.Module):
    """Stride-1 convolution whose kernel is a per-sample mixture of candidate kernels.

    ``self.weight`` holds ``variations`` candidate kernels per (out, in) channel pair,
    shape (out_ch, in_ch, variations, kernel**2). For each sample, ``condition(z)``
    yields logits viewed as (out_ch, in_ch, variations); a softmax over
    ``variations`` mixes the candidates into one kernel per sample. That kernel is
    applied with unfold + batched matmul. The activation (if truthy) is applied to the
    input first.

    The output is reshaped to the input's height/width, so ``padding`` must keep the
    spatial size unchanged (e.g. ``padding=(kernel-1)//2`` for odd ``kernel``).
    """
    def __init__(self, in_ch, out_ch, hidden_ch, kernel=1, padding=0, variations=3,
                 activation=nn.LeakyReLU(0.2), bias=True):
        super(CondConv, self).__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.variations = variations

        self.weight = nn.Parameter(torch.empty(out_ch, in_ch, variations, kernel**2))
        self.bias = nn.Parameter(torch.empty(out_ch)) if bias else None

        # Maps the latent (hidden_ch features) to per-sample mixing logits.
        self.condition = SNDense(hidden_ch, out_ch*in_ch*variations, activation=None)

        self.activation = activation
        self.unfold = nn.Unfold(kernel, padding=padding)
        self.softmax = nn.Softmax(3)

        self._cache_weight = None
        self._cache_bias = None

    def forward(self, x, z):
        b_size, _, height, width = x.size()
        h = self.activation(x) if self.activation else x

        w = self.weight.expand(b_size, -1, -1, -1, -1)   # (B, out, in, var, k*k)

        f = self.condition(z)
        f = f.view(b_size, self.out_ch, self.in_ch, self.variations)
        f = self.softmax(f)                               # mixing weights over variations
        f = f.view(*f.size(), 1)

        w = w * f
        w = w.sum(dim=3)                                  # (B, out, in, k*k)
        w = w.view(b_size, self.out_ch, -1)               # (B, out, in*k*k) matches unfold's layout

        h = self.unfold(h)                                # (B, in*k*k, H*W)
        h = torch.bmm(w, h)
        h = h.view(b_size, self.out_ch, height, width)
        if self.bias is not None:
            _b = self.bias.view(1, -1, 1, 1)
            _b = _b.expand(b_size, -1, height, width)
            h = h + _b
        return h

    def weight_init(self):
        """Kaiming-normal init of the candidate kernels, zero bias, and init of the condition layer."""
        # When wrapped by spectral_norm the trainable tensor is ``weight_orig``; ``weight``
        # is recomputed from it on every forward, so initialise the original if present.
        nn.init.kaiming_normal_(getattr(self, 'weight_orig', self.weight))
        if self.bias is not None:
            nn.init.constant_(self.bias.data, 0)
        self.condition.weight_init()

    @staticmethod
    def _smoothed(param, cache, gamma):
        """Blend ``param`` towards ``cache`` in place and return the new cache.

        The returned cache never shares storage with ``param``, so in-place optimizer
        updates of the parameter cannot silently overwrite the cached previous value.
        """
        if cache is None:
            return param.data.detach().clone()
        blended = torch.lerp(param.data, cache, gamma)
        param.data = blended
        return blended.clone()

    def running_mean(self, gamma):
        """Blend trainable weight/bias with their cached previous values (lerp weight ``gamma``
        on the previous value); drop the caches of frozen tensors; recurse into ``condition``."""
        # NOTE(review): under spectral_norm ``self.weight`` is recomputed from
        # ``weight_orig`` on the next forward, so the blended weight is presumably
        # discarded — TODO confirm the intended target.
        weight = self.weight
        if weight.requires_grad:
            self._cache_weight = self._smoothed(weight, self._cache_weight, gamma)
        else:
            self._cache_weight = None

        bias = self.bias
        if bias is not None and bias.requires_grad:
            self._cache_bias = self._smoothed(bias, self._cache_bias, gamma)
        else:
            self._cache_bias = None

        self.condition.running_mean(gamma)

In [0]:
class GeneratorBlock(nn.Module):
    """Three spectrally normalised 3x3 CondConv layers conditioned on ``z``.

    ``forward`` returns ``(h, cached)``. ``h`` is the output of the trainable tail.
    ``cached`` is the output of the frozen tail copy made by ``cache_tail``, or None
    when no copy exists. Every layer keeps the spatial size (kernel 3, padding 1).
    """
    def __init__(self, in_ch, out_ch, latent_ch, activation=nn.LeakyReLU(0.2), device=None):
        super(GeneratorBlock, self).__init__()
        self.out_ch = out_ch
        self.latent_ch = latent_ch
        self.device = device

        self.extractor_0 = spectral_norm(
            CondConv(in_ch, out_ch, latent_ch, kernel=3, padding=1, activation=activation))
        self.extractor_1 = self._out_to_out_conv()
        self.tail = self._out_to_out_conv()

        self._cache_tail = None

    def _out_to_out_conv(self):
        # out_ch -> out_ch 3x3 conditional conv with CondConv's default activation.
        return spectral_norm(CondConv(self.out_ch, self.out_ch, self.latent_ch, kernel=3, padding=1))

    def forward(self, x, z):
        features = self.extractor_1(self.extractor_0(x, z), z)
        if self._cache_tail is None:
            cached = None
        else:
            cached = self._cache_tail(features, z)
        out = self.tail(features, z)
        return out, cached

    def weight_init(self):
        for layer in (self.extractor_0, self.extractor_1, self.tail):
            layer.weight_init()

    def running_mean(self, gamma):
        for layer in (self.extractor_0, self.extractor_1, self.tail):
            layer.running_mean(gamma)

    def freeze_param(self, freeze_tail=False):
        to_freeze = [self.extractor_0, self.extractor_1]
        if freeze_tail:
            to_freeze.append(self.tail)
        for layer in to_freeze:
            layer.requires_grad_(False)

    def cache_tail(self):
        """Store a frozen copy of the current tail on ``self.device``."""
        snapshot = self._out_to_out_conv()
        snapshot.load_state_dict(self.tail.state_dict())
        snapshot.to(self.device)
        snapshot.requires_grad_(False)
        self._cache_tail = snapshot

    def clear_cache(self):
        self._cache_tail = None

In [0]:
class Generator(nn.Module):
    """Progressively grown generator built from conditional-convolution blocks.

    A latent ``z`` is projected to a 4x4 map (``ch_size // 4`` channels). A Gaussian
    noise channel is appended before every block. Each of the first ``layer_num``
    blocks is followed by PixelShuffle(2), which doubles the resolution. Block
    ``layer_num`` and a 1x1 conv then produce a 3-channel tanh image of side
    ``4 * 2**layer_num``.

    Args:
        ch_size: block output channels; must be divisible by 4 (4x4 reshape, PixelShuffle).
        latent_size: dimensionality of ``z``.
        device: device for sampled latents, noise channels and cached copies.
        num_blocks: number of GeneratorBlocks, i.e. growth stages (default 7).
    """
    def __init__(self, ch_size, latent_size, device=None, num_blocks=7):
        super(Generator, self).__init__()
        self.device = device
        self.ch_size = ch_size
        self.latent_size = latent_size

        self.btm = SNDense(latent_size, ch_size*4, activation=None)
        self.blocks = nn.ModuleList(
            [GeneratorBlock(ch_size//4+1, ch_size, latent_size, device=device) for _ in range(num_blocks)]
        )
        self.out = SNConv(ch_size, 3)

        self.upbi = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upps = nn.PixelShuffle(2)

        self._cache_out = None

    def forward(self, batch, manipulation=None, layer_num=6):
        """Generate ``batch`` images at stage ``layer_num``.

        Args:
            batch: number of samples; must equal ``manipulation.size(0)`` when given.
            manipulation: optional latent batch used instead of random latents.
            layer_num: number of upsampling blocks run before the final block.

        Returns:
            (image, prestage). ``prestage`` is the tanh RGB output of the cached
            previous-stage tail and output conv (half resolution), or None when no
            cache is set up.
        """
        z = manipulation if manipulation is not None else self.latents_generate(batch)

        _h = None
        h = self.btm(z).view(batch, self.ch_size//4, 4, 4)
        h = self._with_noise(h)

        for block in self.blocks[:layer_num]:
            h, _h = self.block_up(h, z, block)
            h = self._with_noise(h)

        # A prestage image needs both the cached output conv and a cached tail output.
        if self._cache_out is not None and _h is not None:
            _h = torch.tanh(self._cache_out(_h))

        h, _ = self.blocks[layer_num](h, z)
        h = self.out(h)
        return torch.tanh(h), _h

    def _with_noise(self, h):
        """Append one channel of per-pixel Gaussian noise to ``h``."""
        noise = torch.randn(h.size(0), 1, *h.size()[2:], device=self.device)
        return torch.cat((h, noise), dim=1)

    def block_up(self, x, z, block):
        """Run ``block`` and pixel-shuffle its output to double the resolution."""
        h, _h = block(x, z)
        h = self.upps(h)
        return h, _h

    def latents_generate(self, batch):
        return self.pixel_norm(torch.randn(batch, self.latent_size, device=self.device))

    def pixel_norm(self, x):
        return x * torch.rsqrt((x**2).mean(1, keepdim=True) + 1e-8)

    def weight_init(self, init_layer=(0, None)):
        """(Re)initialise parameters.

        Args:
            init_layer: (start, stop) slice into ``self.blocks``. When start == 0 the
                bottom projection is initialised too. Defaults to every block.
                The output conv is always re-initialised (Xavier normal, zero bias).
        """
        if init_layer[0] == 0:
            self.btm.weight_init()
        for b in self.blocks[init_layer[0]:init_layer[1]]:
            b.weight_init()

        # The trainable tensor of a spectrally normalised conv is ``weight_orig``.
        nn.init.xavier_normal_(getattr(self.out.conv, 'weight_orig', self.out.conv.weight))
        if self.out.conv.bias is not None:
            nn.init.constant_(self.out.conv.bias.data, 0)

    def running_mean(self, gamma):
        self.btm.running_mean(gamma)
        for b in self.blocks:
            b.running_mean(gamma)
        self.out.running_mean(gamma)

    def up_scale(self, layer_num):
        """Prepare for training stage ``layer_num + 1`` after stage ``layer_num``.

        Stage 0 has no previous block, so the bottom projection is frozen instead.
        Otherwise the previous block is frozen entirely and its tail cache is dropped.
        Without the ``layer_num > 0`` guard, index ``-1`` would freeze the LAST block
        before it is ever trained.

        The current block's extractors are then frozen, while its tail stays trainable.
        Frozen copies of that tail and of the output conv are cached to provide the
        prestage target.
        """
        if layer_num == 0:
            self.btm.requires_grad_(False)
        else:
            previous = self.blocks[layer_num-1]
            previous.freeze_param(freeze_tail=True)
            previous.clear_cache()

        current = self.blocks[layer_num]
        current.freeze_param()
        current.cache_tail()

        self._cache_out = SNConv(self.ch_size, 3)
        self._cache_out.load_state_dict(self.out.state_dict())
        self._cache_out.to(self.device)
        self._cache_out.requires_grad_(False)

    def clear_cache(self):
        self._cache_out = None

In [0]:
def pixel_norm(x):
    """Scale each sample so its mean squared value across dim 1 is ~1 (eps 1e-8)."""
    mean_sq = x.pow(2).mean(1, keepdim=True)
    scale = torch.rsqrt(mean_sq + 1e-8)
    return x * scale

In [0]:
# Fixed batch of 16 pixel-normalised latents reused by make_image/post_image, so successive
# sample grids during training are comparable. No seed is set in this notebook, so the
# batch differs between kernel runs.
CONST_LATENT = pixel_norm(torch.randn(16, LATENT_SIZE, device=DEVICE))

In [0]:
def create_discriminator(ch_size=CH_SIZE, device=DEVICE, sa_gamma=SA_GAMMA):
    """Build a discriminator (fresh init, or loaded from stage LOAD_DIS) and its Adam optimizer.

    Returns:
        (dis, opt_dis)
    """
    dis = Discriminator(ch_size, out_ch=1, device=device, sa_gamma=sa_gamma)
    if not LOAD_DIS:
        dis.weight_init()
    else:
        checkpoint = OUT + '{}_{}_stage{}.pkl'.format(DIS_NAME, VERSION, LOAD_DIS)
        dis.load_state_dict(torch.load(checkpoint))
        dis.eval()
    dis.to(device)

    adam_kwargs = dict(lr=LEARNING_RATE, betas=(0.5, 0.999), weight_decay=DIS_WEIGHT_DECAY)
    opt_dis = optim.Adam(dis.parameters(), **adam_kwargs)
    return dis, opt_dis

In [0]:
gen = Generator(CH_SIZE, LATENT_SIZE, device=DEVICE)
if LOAD_GEN:
    # Resume from the stage-LOAD_GEN checkpoint, then re-initialise only the blocks in the
    # INIT_LAYER (start, stop) slice.
    gen.load_state_dict(torch.load(OUT+'{}_{}_stage{}.pkl'.format(GEN_NAME, VERSION, LOAD_GEN)), strict=STRICT)
    gen.weight_init(INIT_LAYER)
    gen.eval()
else:
    # Fresh model: weight_init takes a (start, stop) slice over gen.blocks. Cover every
    # block; a start of 0 also initialises the bottom projection.
    gen.weight_init((0, len(gen.blocks)))
gen.to(DEVICE)

opt_gen = optim.Adam(
    gen.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
    weight_decay=GEN_WEIGHT_DECAY
    )

In [0]:
# Loss histories appended by report_log and plotted by report_result.
gen_losses, dis_losses = [], []

In [0]:
# Authenticate against the Twitter API only when tweeting is enabled. Credentials are read
# from the local twitter_api_key module (keep that module out of version control).
# NOTE: when TWEET_INTERVAL is falsy, `api` is never defined; round_dataset only uses it
# behind the same TWEET_INTERVAL check.
if TWEET_INTERVAL:
    auth = tweepy.OAuthHandler(twitter_api_key.CONSUMER_KEY, twitter_api_key.CONSUMER_SECRET)
    auth.set_access_token(twitter_api_key.ACCESS_TOKEN_KEY, twitter_api_key.ACCESS_TOKEN_SECRET)
    api = tweepy.API(auth)

In [0]:
def report_log(iteration, epoch, g_loss, d_loss, rec_loss, g_mean, d_real_mean, d_fake_mean):
    """Print one progress line and append the G/D losses to the plotted histories.

    The line shows D(real) and then D(fake) as seen by the discriminator / generator updates.
    """
    template = '[{}/{}]\tLoss_D: {:.4f}\tLoss_G: {:.4f}\tLoss_Grec: {:.4f}\tD(x): {:.4f}\tD(G(z)): {:.4f}/{:.4f}'
    values = (iteration, epoch, d_loss, g_loss, rec_loss, d_real_mean, d_fake_mean, g_mean)
    print(template.format(*values))
    gen_losses.append(g_loss)
    dis_losses.append(d_loss)

In [0]:
def make_image(gen, stage):
    """Show the 16 CONST_LATENT samples of ``stage`` as a 4x4 grid.

    The generator is put in eval mode for sampling and switched back to train mode.
    """
    img_size = 2**(stage+2)

    gen.eval()
    with torch.no_grad():
        generated, _ = gen(16, manipulation=CONST_LATENT, layer_num=stage)
    gen.train()

    # Convert explicitly to NumPy (N, H, W, C). Calling np.reshape/np.transpose on a torch
    # tensor only works through NumPy's fallback wrapping and still yields tensors.
    generated = generated.detach().cpu().numpy()
    generated = generated.reshape(-1, 3, img_size, img_size).transpose(0, 2, 3, 1)

    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
    for ax in axes.flat:
        ax.axis('off')
    for ax, img in zip(axes.flat, generated):
        # Map tanh output [-1, 1] to 8-bit pixels.
        ax.imshow(Image.fromarray(np.uint8((img+1.)/2. *255.)))

    plt.show()
    plt.close(fig)

In [0]:
def image_upload(image_array, api):
    """Upload one image via ``api.media_upload`` and return its media id.

    Args:
        image_array: (H, W, 3) array in [-1, 1], as produced by post_image.
        api: authenticated ``tweepy.API``.
    """
    bin_io = io.BytesIO()
    img = Image.fromarray(np.uint8((image_array+1.)/2. *255.))
    img = img.resize((512, 512), resample=0)  # 0 == nearest-neighbour
    img.save(bin_io, format='JPEG')
    # save() leaves the cursor at the end of the buffer. Rewind so an uploader that reads
    # from the current position does not send an empty body.
    bin_io.seek(0)
    result = api.media_upload(filename='{}_generated_{}.jpg'.format(TARGET, uuid.uuid4()), file=bin_io)
    return result.media_id

In [0]:
def post_image(gen, api, stage, iteration, epoch):
    """Tweet the 16 CONST_LATENT samples of ``stage`` as four 2x2 mosaics.

    Posting is best-effort: failures are reported and training continues.

    Args:
        gen: generator; switched to eval for sampling and back to train afterwards.
        api: authenticated ``tweepy.API``.
        stage: growth stage (image side is 2**(stage+2)).
        iteration, epoch: zero-based counters, shown 1-based in the tweet.
    """
    img_size = 2**(stage+2)

    gen.eval()
    with torch.no_grad():
        generated, _ = gen(16, manipulation=CONST_LATENT, layer_num=stage)
    gen.train()
    # Work on an explicit NumPy array instead of handing torch tensors to NumPy functions.
    generated = generated.detach().cpu().numpy()

    # (16, C, H, W) -> (4 mosaics, 2 rows, 2 cols, C, H, W) -> (4, C, 2H, 2W) -> (4, 2H, 2W, C)
    generated = generated.reshape(4, 2, 2, 3, img_size, img_size)
    generated = generated.transpose(0, 3, 1, 4, 2, 5)
    generated = generated.reshape(4, 3, img_size*2, img_size*2)
    generated = generated.transpose(0, 2, 3, 1)

    try:
        img_ids = [image_upload(img, api) for img in generated]
        # NOTE(review): tags are published verbatim, including the '#makeing' spelling.
        hash_tags = ['AIで{}を作る'.format(TARGET_JP),
                    'iteration/epoch: {}/{}'.format(iteration+1, epoch+1),
                    '#makeing{}'.format(TARGET),
                    '#nowlearning...',
                    '#AI',
                    '#人工知能',
                    '#DeepLearning',
                    '#GAN']

        api.update_status(
            status='\n'.join(hash_tags),
            media_ids=img_ids
            )
    except Exception as err:
        # Tweeting must not stop training, but a silent failure would hide broken
        # credentials or API changes, so report it.
        print('post_image failed: {!r}'.format(err))

In [0]:
def save_model(gen, dis, epoch):
    """Save generator then discriminator state dicts under OUT.

    ``epoch`` only fills the ``stage{}`` tag of the filename; round_dataset passes the stage.
    """
    for model, name in ((gen, GEN_NAME), (dis, DIS_NAME)):
        path = OUT + '{}_{}_stage{}.pkl'.format(name, VERSION, epoch)
        torch.save(model.state_dict(), path)

In [0]:
def report_result():
    """Plot the G and D loss histories recorded by report_log.

    One point is recorded per report_log call, so the x-axis counts log reports,
    despite its "iterations" label.
    """
    fig, ax = plt.subplots(figsize=(16, 8))
    for series, label in ((gen_losses, "G"), (dis_losses, "D")):
        ax.plot(series, label=label)
    ax.set_xlabel("iterations")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()

In [0]:
def update_discriminator(dis, opt_dis, real, fake, device=DEVICE):
    """One least-squares GAN step for the discriminator.

    D(real) is pushed towards 1 and D(fake) towards 0 with MSE. The real and fake
    halves are backpropagated separately. Gradients are optionally clipped to
    GRAD_CLIP, then ``opt_dis`` steps.

    Returns:
        (dis_loss, y_real, y_fake): summed loss and the flattened D outputs.
    """
    mse = nn.MSELoss()
    dis.zero_grad()

    def _score_and_backprop(batch, make_target):
        scores = dis(batch).view(-1)
        loss = mse(scores, make_target(scores.size(), device=device))
        loss.backward()
        return scores, loss

    y_real, real_loss = _score_and_backprop(real, torch.ones)
    y_fake, fake_loss = _score_and_backprop(fake, torch.zeros)

    if GRAD_CLIP:
        clip_grad_norm_(dis.parameters(), GRAD_CLIP)
    opt_dis.step()

    return real_loss + fake_loss, y_real, y_fake

In [0]:
def update_generator(gen, opt_gen, dis, fake, prestage, epoch, device=DEVICE):
    """One generator step: least-squares adversarial loss plus optional reconstruction loss.

    The adversarial term pushes D(fake) towards 1. When ``prestage`` is given, the
    2x average-pooled ``fake`` is also matched (MSE scaled by RECONST_RATE) to the
    detached ``prestage``. After the optimizer step, ``gen.running_mean`` is called
    with ``RUNNING_MEAN_RATE * epoch``.

    Returns:
        (gen_loss, y_gen, reconst_loss): ``reconst_loss`` is None without a prestage.
    """
    mse = nn.MSELoss()
    gen.zero_grad()

    y_gen = dis(fake).view(-1)
    gen_loss = mse(y_gen, torch.ones(y_gen.size(), device=device))

    reconst_loss = None
    has_prestage = prestage is not None
    # Keep the graph alive only when a second backward pass follows.
    gen_loss.backward(retain_graph=has_prestage)

    if has_prestage:
        pooled = F.avg_pool2d(fake, 2)
        reconst_loss = mse(pooled, prestage.detach()) * RECONST_RATE
        reconst_loss.backward()

    if GRAD_CLIP:
        clip_grad_norm_(gen.parameters(), GRAD_CLIP)
    opt_gen.step()

    gen.running_mean(RUNNING_MEAN_RATE*epoch)

    return gen_loss, y_gen, reconst_loss

In [0]:
def round_dataset(gen, dis, opt_gen, opt_dis, dataloader, epoch, stage, device=DEVICE):
    """Run one epoch of adversarial training at ``stage``.

    Each iteration samples a fake batch matching the real batch size. It then updates
    the discriminator on (real, detached fake), followed by the generator. Finally it
    fires the display / log / tweet / snapshot hooks on their configured intervals; an
    interval of None or 0 disables that hook.
    """
    data_len = len(dataloader)

    if epoch > 0:
        dis.set_gamma(SA_GAMMA)

    # An enumerate object has no len(), so pass total= to give tqdm a progress bar length.
    for iteration, data in tqdm(enumerate(dataloader), total=data_len):
        if epoch == 0:
            # Linear warm-up of the discriminator's attention gamma over the first epoch.
            dis.set_gamma(SA_GAMMA*iteration/data_len)

        x_real = data[0].to(device)
        b_size = x_real.size(0)
        x_fake, x_prestage = gen(b_size, layer_num=stage)

        dis_loss, y_real, y_fake = update_discriminator(dis, opt_dis, x_real, x_fake.detach(), device=device)
        gen_loss, y_gen, reconst_loss = update_generator(gen, opt_gen, dis, x_fake, x_prestage, epoch, device=device)

        step = iteration + 1

        if DISPLAY_INTERVAL and step % DISPLAY_INTERVAL == 0:
            make_image(gen, stage)

        if LOG_INTERVAL and step % LOG_INTERVAL == 0:
            report_log(
                iteration,
                epoch,
                gen_loss.item(),
                dis_loss.item(),
                reconst_loss.item() if reconst_loss is not None else 0,
                y_gen.mean().item(),
                y_real.mean().item(),
                y_fake.mean().item()
                )

        if TWEET_INTERVAL and step % TWEET_INTERVAL == 0:
            post_image(gen, api, stage, iteration, epoch)

        if SNAPSHOT_INTERVAL and step % SNAPSHOT_INTERVAL == 0:
            save_model(gen, dis, stage)

In [0]:
def grow_layer(gen, opt_gen, dataloader, stage, device=DEVICE):
    """Train ``gen`` for EPOCHS[stage] epochs against a new discriminator, then call ``gen.up_scale(stage)``."""
    dis, opt_dis = create_discriminator()
    dis.train()

    n_epochs = EPOCHS[stage]
    for epoch_idx in range(n_epochs):
        round_dataset(gen, dis, opt_gen, opt_dis, dataloader, epoch_idx, stage, device)

    gen.up_scale(stage)

In [0]:
def create_dataloader(stage):
    """Build a shuffled ImageFolder loader at this stage's resolution, 2**(stage+2).

    Each image is resized (int size -> shorter side), randomly cropped to a square,
    randomly flipped, and normalised to [-1, 1]. Incomplete final batches are dropped.
    """
    side = 2**(stage+2)
    pipeline = transforms.Compose([
        transforms.Resize(side),
        transforms.RandomCrop(side),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = dset.ImageFolder(root=DATAROOT, transform=pipeline)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZES[stage],
        shuffle=True,
        num_workers=6,
        drop_last=True,
    )

In [0]:
def train_loop(gen, opt_gen, device=DEVICE):
    """Grow the generator stage by stage over range(START, STAGE), then plot the losses.

    NOTE(review): when START > 0, up_scale is never called for the earlier stages in this
    run, so their freeze/cache setup is absent — confirm this is intended when resuming.
    """
    gen.train()
    for stage_idx in range(START, STAGE):
        loader = create_dataloader(stage_idx)
        grow_layer(gen, opt_gen, loader, stage_idx, device=device)
    report_result()

In [0]:
train_loop(gen, opt_gen, DEVICE)  # run every configured stage, then plot the loss history