In [0]:
# --- Target dataset ---------------------------------------------------------
# `target` is used to build the checkpoint names below and the uploaded image
# file name / hashtag in the tweet helpers.
target = 'landscape'
# Japanese label for the same category; inserted into the tweet text.
target_JP = '風景'

# --- Training hyper-parameters ----------------------------------------------
# Epoch counts indexed by the blur radius k of the training driver:
# EPOCHS[k] epochs are run with GaussianBlur(k), for k = 3, 2, 1, 0.
EPOCHS = [32, 16, 8, 4]
# Mini-batch size given to the DataLoader.
BATCH_SIZE = 64
# Gain passed to xavier_normal_ in weights_init.
initial_scale = 1.
# CUDA device index used to build `device`.
GPU = 0

# --- Reporting intervals (in iterations; a falsy value disables the guarded ones)
# Iterations between report_log calls.
log_interval = 1000
# Iterations between inline previews (make_image); None turns them off.
display_interval = None
# Iterations between tweets (post_image); also gates creation of `api`.
tweet_interval = 1000
# Iterations between checkpoint writes (save_model).
snapshot_interval = 1000

# --- Model size ---------------------------------------------------------------
# NOTE(review): this value is not passed to Generator/Discriminator in the
# visible cells (both fall back to their own default of 1.) — confirm intent.
sa_gamma = 1.
# Side length of training crops and of generated images.
IMG_SIZE = 256
# Base channel multiplier of the generator.
CH_SIZE = 8
# Dimension of the latent vector sampled in Generator.forward.
LATENT_SIZE = 64
# Base width of the per-block latent mapping in the generator.
HIDDEN_SIZE = 16

# --- Optimisation -------------------------------------------------------------
# Adam learning rate shared by both optimizers.
learning_rate = 2e-4
# Max gradient norm for clip_grad_norm_; None disables clipping.
grad_clip = None
# Adam weight decay for generator / discriminator.
gen_weight_decay = 0
dis_weight_decay = 0
# NOTE(review): not referenced anywhere in the visible cells — TODO confirm use.
sa_endpoint = 15000

# --- Files --------------------------------------------------------------------
# Epoch tag of a checkpoint to restore; None starts from weights_init.
load_gen = None
load_dis = None

# Prefix (directory) for checkpoint files and root folder of the training images.
OUT = './result/'
dataroot = './picture/train_pic/flickr/landscape'

# Base names used inside the checkpoint file names.
gen_name = '{}_gen'.format(target)
dis_name = '{}_dis'.format(target)

In [0]:
import numpy as np
import io
import uuid
import pickle
import tweepy
from tqdm.notebook import tqdm

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, clip_grad_norm_
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

import twitter_api_key

In [0]:
def gaussian(size):
    """Return a tensor of shape ``size`` sampled from a standard normal.

    The sample is drawn with ``torch.normal`` from an all-zero mean tensor and
    an all-one standard-deviation tensor, both created with the default dtype
    and device.
    """
    mean = torch.zeros(size)
    std = torch.ones(size)
    return torch.normal(mean, std)

def zeropad(x, ch):
    """Pad ``x`` with zeros at the end of dim 1 until it has ``ch`` entries.

    ``x`` is padded as a 4-D tensor: the pad spec lists (before, after) pairs
    starting from the last dimension, so only the "after" slot of dim 1 is
    non-zero. A negative difference is passed to ``F.pad`` unchanged.
    """
    extra = ch - x.size(1)
    pad_spec = (0, 0, 0, 0, 0, extra, 0, 0)
    return F.pad(x, pad_spec)

def split_zeropad(x, ch):
    """Split dim 1 of ``x`` into 4 groups, halve each group, and zero-pad the halves.

    Each group is cut into a front half and a back half along the group axis.
    Front halves are padded with zeros *after* their content, back halves
    *before* it, so every half grows by ``ch // 8 - half`` entries. The padded
    front halves (flattened back to dim 1) are concatenated with the padded
    back halves along dim 1.
    """
    batch = x.size(0)
    spatial = x.size()[2:]

    grouped = x.view(batch, 4, -1, *spatial)
    half = grouped.size(2) // 2
    extra = ch // 8 - half

    # The pad spec runs from the last dimension backwards; the third pair
    # addresses the group axis (dim 2 of the 5-D view).
    front = F.pad(grouped[:, :, :half, :, :], (0, 0, 0, 0, 0, extra, 0, 0, 0, 0))
    back = F.pad(grouped[:, :, half:, :, :], (0, 0, 0, 0, extra, 0, 0, 0, 0, 0))

    front = front.view(batch, -1, *spatial)
    back = back.view(batch, -1, *spatial)
    return torch.cat((front, back), dim=1)

def gap(x):
    """Global average pooling: average ``x`` over its last two dimensions.

    The pooling window equals the full size of the last two dimensions, so
    both are reduced to length 1.
    """
    window = x.size()[-2:]
    return F.avg_pool2d(x, window)

def instance_var(x):
    """Variance of every (sample, channel) slice of ``x``, expanded to ``x``'s shape.

    Everything after the first two dimensions is flattened and reduced with
    ``torch.var`` (default estimator); the result is viewed as (N, C, 1, 1) and
    expanded, so every position of a slice holds that slice's variance.
    """
    full_shape = x.size()
    n, c = full_shape[0], full_shape[1]
    per_slice = torch.var(x.view(n, c, -1), dim=2)
    return per_slice.view(n, c, 1, 1).expand(*full_shape)

def pixel_norm(x):
    """Scale ``x`` so that its mean square along dim 1 is (approximately) one.

    A small constant (1e-8) is added to the mean square before the reciprocal
    square root to avoid dividing by zero.
    """
    mean_square = (x ** 2).mean(1, keepdim=True)
    return x * torch.rsqrt(mean_square + 1e-8)

In [0]:
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def __init__(self):
        super().__init__()
        # Softplus followed by Tanh; the product with the input is taken in forward().
        gate_layers = [nn.Softplus(), nn.Tanh()]
        self.main = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Apply the activation element-wise and return a tensor shaped like ``x``."""
        gate = self.main(x)
        return x * gate

In [0]:
class SNConv(nn.Module):
    """Pre-activation convolution: Mish followed by a spectrally normalised Conv2d.

    Args:
        in_ch: input channels of the convolution.
        out_ch: output channels of the convolution.
        kernel, stride, padding: forwarded positionally to ``nn.Conv2d``.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, in_ch, out_ch, kernel=1, stride=1, padding=0, bias=True):
        super().__init__()
        activation = Mish()
        conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=bias)
        # Index 0 is the activation, index 1 the normalised convolution.
        self.main = nn.Sequential(activation, spectral_norm(conv))

    def forward(self, x):
        """Return ``conv(mish(x))``."""
        out = self.main(x)
        return out

In [0]:
class SNRes(nn.Module):
    """Residual wrapper around one ``SNConv``: returns ``x + SNConv(x)``.

    The convolution is built with ``in_ch`` inputs and ``out_ch`` outputs; the
    addition in ``forward`` needs the branch output to be addable to ``x``.
    """

    def __init__(self, in_ch, out_ch, kernel=3, stride=1, padding=1):
        super().__init__()
        self.main = SNConv(in_ch, out_ch, kernel, stride, padding)

    def forward(self, x):
        """Add the convolution branch to the identity path."""
        branch = self.main(x)
        return x + branch

In [0]:
class SNDense(nn.Module):
    """Pre-activation dense layer: Mish followed by a spectrally normalised Linear.

    Args:
        in_ch: number of input features.
        out_ch: number of output features.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        activation = Mish()
        linear = nn.Linear(in_ch, out_ch)
        # Index 0 is the activation, index 1 the normalised linear layer.
        self.main = nn.Sequential(activation, spectral_norm(linear))

    def forward(self, x):
        """Return ``linear(mish(x))``."""
        out = self.main(x)
        return out

In [0]:
class SNSelfAttentionBlock(nn.Module):
    """Self-attention over spatial positions with a residual connection.

    Three 1x1 ``SNConv`` projections produce f and g (``out_ch // 8`` channels
    each) and h (``out_ch`` channels). The attention-weighted h is scaled by
    ``gamma`` and added to the input; since it is viewed back to the input's
    size, h must have as many elements as ``x``.
    """

    def __init__(self, in_ch, out_ch, gamma=1.):
        super().__init__()

        self.gamma = gamma
        reduced_ch = out_ch // 8
        self.cf = SNConv(in_ch, reduced_ch)
        self.cg = SNConv(in_ch, reduced_ch)
        self.ch = SNConv(in_ch, out_ch)
        self.softmax = nn.Softmax(2)

    @staticmethod
    def _flatten_spatial(t):
        """View ``t`` as (batch, channels, positions)."""
        return t.view(t.size(0), t.size(1), -1)

    def forward(self, x):
        """Return ``x + gamma * attention(x)``."""
        f = self._flatten_spatial(self.cf(x))
        g = self._flatten_spatial(self.cg(x))
        h = self._flatten_spatial(self.ch(x))

        # scores[b, i, j] = <f[:, i], g[:, j]>, normalised over j.
        scores = torch.bmm(torch.transpose(f, 1, 2), g)
        weights = self.softmax(scores)

        attended = torch.bmm(h, torch.transpose(weights, 1, 2))
        attended = attended.view(*x.size())
        return x + attended * self.gamma

    def set_gamma(self, gamma):
        """Replace the scale applied to the attention branch."""
        self.gamma = gamma

In [0]:
class MinibatchDiscrimination(nn.Module):
    """Append minibatch-similarity features to a batch of feature vectors.

    Every sample is projected by a frozen linear map into ``kernel`` rows of
    ``kernel_dims`` values. For each row, the L1 distance to the same row of
    every *other* sample in the batch is turned into ``exp(-distance)`` and
    summed, giving ``kernel`` extra features per sample.

    Args:
        in_ch: number of input features per sample.
        kernel: number of similarity features appended to each sample.
        kernel_dims: length of each projected row.
        device: kept for backward compatibility and stored on the instance;
            the computation itself follows the device of the input.
    """

    def __init__(self, in_ch, kernel, kernel_dims, device):
        super(MinibatchDiscrimination, self).__init__()
        self.device = device
        self.kernel = kernel
        self.dim = kernel_dims
        self.t = nn.Linear(in_ch, self.kernel*self.dim, bias=False)
        # The projection is a fixed random map: exclude it from training.
        for param in self.t.parameters():
            param.requires_grad = False

    # Implemented as forward() (not __call__) so nn.Module's call machinery,
    # including registered hooks, keeps working.
    def forward(self, x):
        """Return ``x`` of shape (B, in_ch) extended to (B, in_ch + kernel)."""
        batchsize = x.size(0)
        m = self.t(x).view(batchsize, self.kernel, self.dim, 1)
        # Swapping the batch axis with the trailing singleton axis lets
        # broadcasting pair every sample with every other sample.
        m_T = torch.transpose(m, 0, 3)
        m, m_T = torch.broadcast_tensors(m, m_T)
        norm = torch.sum(F.l1_loss(m, m_T, reduction='none'), dim=2)

        # A huge penalty on the diagonal removes each sample's zero distance to
        # itself (exp(-1e6) underflows to 0). The mask is created on the same
        # device/dtype as the distances, so it also works when the module was
        # built with device=None.
        eraser = torch.eye(batchsize, device=norm.device, dtype=norm.dtype)
        eraser = eraser.view(batchsize, 1, batchsize).expand(norm.size())
        c_b = torch.exp(-(norm + 1e6 * eraser))
        o_b = torch.sum(c_b, dim=2)
        h = torch.cat((x, o_b), dim=1)
        return h

In [0]:
class InceptionBlock(nn.Module):
    """Three parallel ``SNConv`` branches concatenated and fused by a 1x1 ``SNConv``.

    Branches (each producing ``mid_ch`` channels):
      * ``layer_11``   — a single 1x1 convolution;
      * ``layer_33_1`` — 1x1 then 3x3;
      * ``layer_33_2`` — 1x1, 3x3, then a residual 3x3 (``SNRes``).
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()

        self.layer_11 = SNConv(in_ch, mid_ch)

        short_branch = [SNConv(in_ch, mid_ch), SNConv(mid_ch, mid_ch, 3, 1, 1)]
        self.layer_33_1 = nn.Sequential(*short_branch)

        long_branch = [
            SNConv(in_ch, mid_ch),
            SNConv(mid_ch, mid_ch, 3, 1, 1),
            SNRes(mid_ch, mid_ch),
        ]
        self.layer_33_2 = nn.Sequential(*long_branch)

        # Fuses the 3 * mid_ch concatenated channels down to out_ch.
        self.c = SNConv(mid_ch*3, out_ch)

    def forward(self, x):
        """Run the three branches on ``x``, concatenate along dim 1, and fuse."""
        branches = (self.layer_11(x), self.layer_33_1(x), self.layer_33_2(x))
        return self.c(torch.cat(branches, dim=1))

In [0]:
class InceptionResBlock(nn.Module):
    """Residual ``InceptionBlock`` whose shortcut is zero-padded to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.ch = out_ch
        # The inner branches use a quarter of the output width.
        self.inception = InceptionBlock(in_ch, self.ch//4, self.ch)

    def forward(self, x):
        """Return ``zeropad(x, out_ch) + inception(x)``."""
        residual = self.inception(x)
        shortcut = zeropad(x, self.ch)
        return shortcut + residual

In [0]:
class DiscriminatorBlock(nn.Module):
    """One discriminator stage, in one of two configurations.

    * ``downconv=True``: a strided 4x4 ``SNConv`` halves the resolution and is
      added to an average-pooled, channel-zero-padded copy of the input.
    * ``downconv=False``: a 3x3 ``SNConv`` extracts features at the same
      resolution.

    A self-attention block is inserted before the down-convolution only when
    ``sa_gamma`` is truthy.
    """

    def __init__(self, in_ch, out_ch, sa_gamma=None, downconv=True):
        super().__init__()
        self.out_ch = out_ch

        if downconv:
            self.extractor = None
        else:
            self.extractor = SNConv(in_ch, out_ch, 3, 1, 1)

        if sa_gamma:
            self.sa = SNSelfAttentionBlock(out_ch, out_ch, gamma=sa_gamma)
        else:
            self.sa = None

        self.downsample = nn.AvgPool2d(2)
        if downconv:
            self.downconv = SNConv(in_ch  , out_ch, 4, 2, 1)
        else:
            self.downconv = None

    def forward(self, x):
        """Apply the configured sub-layers to ``x`` and return the result."""
        h = x if self.extractor is None else self.extractor(x)

        if self.sa is not None:
            h = self.sa(h)

        if self.downconv is not None:
            # Shortcut: pooled input, zero-padded up to out_ch channels.
            shortcut = zeropad(self.downsample(x), self.out_ch)
            h = self.downconv(h) + shortcut

        return h

    def set_gamma(self, gamma):
        """Forward ``gamma`` to the attention block, if this stage has one."""
        if self.sa is not None:
            self.sa.set_gamma(gamma)

In [0]:
class Discriminator(nn.Module):
    """Image discriminator: stem conv, six down-sampling stages, one attention
    stage, pooling, minibatch discrimination and a dense head.

    Args:
        out_ch: number of outputs of the final dense layer.
        device: forwarded to ``MinibatchDiscrimination``.
        sa_gamma: attention scale of the last ``DiscriminatorBlock``.
    """

    # (in_ch, out_ch) of the six down-sampling stages, in order.
    _DOWN_STAGES = ((16, 32), (32, 64), (64, 128), (128, 256), (256, 512), (512, 512))

    def __init__(self, out_ch=2, device=None, sa_gamma=1.):
        super().__init__()

        layers = [spectral_norm(nn.Conv2d(3, 16, 1, 1, 0))]
        for stage_in, stage_out in self._DOWN_STAGES:
            layers.append(DiscriminatorBlock(stage_in, stage_out))
        layers.append(DiscriminatorBlock(512, 512, sa_gamma=sa_gamma, downconv=False))
        layers.append(nn.AvgPool2d(4))
        layers.append(nn.Flatten())
        layers.append(MinibatchDiscrimination(512, 64, 16, device))
        # 512 pooled features + 64 minibatch features.
        layers.append(SNDense(512+64, out_ch))

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the head output for the image batch ``x``."""
        scores = self.main(x)
        return scores

    def set_gamma(self, gamma):
        """Forward ``gamma`` to the blocks at positions 6 and 7 of ``main``."""
        for position in (6, 7):
            self.main[position].set_gamma(gamma)

In [0]:
class LatentGenerator(nn.Module):
    """Map a latent vector to a noisy, pixel-normalised intermediate code.

    Two spectrally normalised linear layers produce a mean and a scale; the
    output is ``pixel_norm(mean + scale * noise)`` with fresh Gaussian noise
    drawn on every call.
    """

    def __init__(self, in_ch, out_ch, device=None):
        super().__init__()
        self.device = device

        self.mu = spectral_norm(nn.Linear(in_ch, out_ch))
        self.var = spectral_norm(nn.Linear(in_ch, out_ch))

    def forward(self, w):
        """Return the sampled, normalised code for latent ``w``."""
        mean = self.mu(w)
        scale = self.var(w)
        # Noise is sampled with gaussian() and then moved to self.device.
        noise = gaussian(scale.size()).to(self.device)
        sample = mean + scale * noise
        return pixel_norm(sample)

In [0]:
class NoiseInjection(nn.Module):
    """Add Gaussian noise to a feature map, scaled per channel from the latent.

    A spectrally normalised linear layer maps the latent ``w`` to one scale
    per channel; the scale is broadcast over the feature map's size.
    """

    def __init__(self, in_ch, out_ch, device=None):
        super().__init__()
        self.device = device

        self.var = spectral_norm(nn.Linear(in_ch, out_ch))

    def forward(self, x, w):
        """Return ``x + noise * scale(w)``."""
        scale = self.var(w)
        # Append two singleton axes, then expand to the feature map's size.
        scale = scale.view(*scale.size(), 1, 1).expand(*x.size())
        noise = gaussian(x.size()).to(self.device)
        return x + noise * scale

In [0]:
class DepthwiseCondConv(nn.Module):
    """Per-sample convolution whose kernel is a mixture of learned base kernels.

    ``weight`` stores ``n_kernels`` base kernels of ``kernel x kernel`` taps.
    In ``forward`` the mixing coefficients ``sigma`` combine them into one
    kernel per sample, which is applied to the unfolded input patches with a
    batched matrix product.
    """

    def __init__(self, in_ch, n_kernels, kernel=3, stride=1, padding=1, device=None, k=1):
        """Store the convolution geometry and create the base-kernel parameter.

        Args:
            in_ch: channels of the input (and of the output).
            n_kernels: number of base kernels that ``sigma`` mixes.
            kernel, stride, padding: passed to ``F.unfold`` in ``forward``.
            device: device on which the identity mask is created in ``forward``.
            k: size of the middle axis of ``weight``.
                NOTE(review): the broadcast in ``forward`` looks like it needs
                k == 1 (or k == in_ch) — TODO confirm.
        """
        super(DepthwiseCondConv, self).__init__()
        self.device = device

        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.ch = in_ch
        self.kernels_size = (n_kernels, kernel, kernel)

        # One scalar per (base kernel, tap), drawn from a standard normal.
        self.weight = nn.Parameter(torch.empty(n_kernels*kernel*kernel, k, 1))
        nn.init.normal_(self.weight)
        self.activation = Mish()
    
    def forward(self, x, sigma, bias=None, activation=True):
        """Apply the sample-specific kernel to ``x``.

        Args:
            x: input feature map; its trailing dimensions are reused to shape
                the output.
            sigma: mixing coefficients, viewed as (batch, n_kernels, ch, 1, 1, 1).
            bias: optional tensor viewed as (-1, ch, 1, 1) and added to the output.
            activation: when True, Mish is applied to ``x`` first.

        Returns:
            Tensor viewed as (batch, ch, *x.size()[2:]).
        """
        b_size = x.size(0)

        h = self.activation(x) if activation else x
        
        # Broadcast every tap value across the channel axis.
        kernels = self.weight.expand(-1, -1, self.ch)

        # Identity mask: after the product only entries with equal channel
        # indices are non-zero.
        i = torch.eye(self.ch).to(self.device)
        i = i.view(1, self.ch, self.ch)
        i = i.expand(kernels.size(0), -1, -1)
        
        kernels = i*kernels
        # (n_kernels, kernel, kernel, ch, ch) -> (n_kernels, ch, ch, kernel, kernel),
        # then a leading batch axis is added by expand.
        kernels = kernels.view(*self.kernels_size, self.ch, -1)
        kernels = kernels.transpose(1, 3).transpose(2, 4)
        kernels = kernels.expand(b_size, -1, -1, -1, -1, -1)

        # Weight each base kernel by sigma and sum over the base-kernel axis,
        # leaving one kernel per sample; each row is L2-normalised.
        _s = sigma.view(*kernels.size()[:3], 1, 1, 1)
        _s = _s.expand(-1, -1, -1, self.ch, -1, -1)
        _s = torch.sum(kernels*_s, dim=1)
        _s = _s.reshape(b_size, self.ch, -1)
        _s = F.normalize(_s, dim=2)

        # Convolution as a matrix product over unfolded patches.
        # NOTE(review): the final view assumes the number of patches equals the
        # input's spatial size (true for the default 3/1/1 geometry) — confirm
        # before using other stride/padding values.
        h = F.unfold(h, self.kernel, padding=self.padding, stride=self.stride)
        h = torch.bmm(_s, h)
        h = h.view(b_size, self.ch, *x.size()[2:])

        if bias is not None:
            _b = bias.view(-1, self.ch, 1, 1)
            h = h + _b
        
        return h

In [0]:
class SynthBlock(nn.Module):
    """One synthesis step: noise injection, latent-conditioned depthwise conv, 1x1 conv.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by the final 1x1 ``SNConv``.
        latent_ch: size of the latent vector ``w``.
        mid_ch: width of the intermediate code produced by ``LatentGenerator``.
        kernels: number of base kernels mixed by ``DepthwiseCondConv``.
        device: forwarded to the sub-modules that sample noise or build masks.
    """

    def __init__(self, in_ch, out_ch, latent_ch, mid_ch, kernels, device=None):
        super().__init__()

        self.lg = LatentGenerator(latent_ch, mid_ch, device)
        # Maps the intermediate code to one mixing weight per (kernel, channel).
        self.sigma = spectral_norm(nn.Linear(mid_ch, in_ch*kernels))

        self.noise_injection = NoiseInjection(latent_ch, in_ch, device)
        depthwise = DepthwiseCondConv(in_ch, kernels, device=device)
        self.dcc = spectral_norm(depthwise)
        self.c = SNConv(in_ch, out_ch, bias=False)

    def forward(self, x, w, dcc_activation=True):
        """Transform ``x`` conditioned on ``w``.

        ``dcc_activation`` is passed to the depthwise convolution as its
        ``activation`` flag.
        """
        mixing = self.sigma(self.lg(w))

        noisy = self.noise_injection(x, w)
        filtered = self.dcc(noisy, mixing, activation=dcc_activation)
        return self.c(filtered)

In [0]:
class OutputGenerator(nn.Module):
    """Project a feature map to 3 channels through a 1x1 ``SNConv``.

    When the input does not already have ``ch_size`` channels, dim 1 is split
    into ``ch_size // 2`` groups and only the first and the last entry of each
    group are kept, giving ``ch_size`` channels for the projection.
    """

    def __init__(self, ch_size, bias=False):
        super().__init__()
        self.ch_size = ch_size

        self.c_out = SNConv(ch_size, 3, bias=bias)

    def forward(self, x):
        """Return the 3-channel projection of ``x``."""
        if x.size(1) == self.ch_size:
            selected = x
        else:
            grouped = x.view(x.size(0), self.ch_size//2, -1, *x.size()[2:])
            first = grouped[:, :, 0, :, :]
            last = grouped[:, :, -1, :, :]
            selected = torch.cat((first, last), dim=1)

        return self.c_out(selected)

In [0]:
class GeneratorBlock(nn.Module):
    """Generator stage: two ``SynthBlock``s (the second residual), optional
    self-attention, and an image head.

    The synthesis blocks work at ``out_ch * 4`` channels. ``forward`` returns
    both the 3-channel image from ``OutputGenerator`` and the feature map that
    produced it.
    """

    def __init__(self, in_ch, out_ch, latent_ch, mid_ch, kernels, ch_size,
                 device=None, sa_gamma=None, bias=False):
        super().__init__()
        wide_ch = out_ch*4

        self.synth_0 = SynthBlock(in_ch, wide_ch, latent_ch, mid_ch, kernels, device)
        self.synth_1 = SynthBlock(wide_ch, wide_ch, latent_ch, mid_ch*4, kernels, device)
        if sa_gamma:
            self.sa = SNSelfAttentionBlock(wide_ch, wide_ch, gamma=sa_gamma)
        else:
            self.sa = None
        self.out = OutputGenerator(ch_size, bias)

    def forward(self, x, w, input_activation=False):
        """Return ``(image, features)`` for input ``x`` and latent ``w``.

        ``input_activation`` tells the first synthesis block whether to apply
        its activation to ``x``.
        """
        features = self.synth_0(x, w, input_activation)
        features = features + self.synth_1(features, w)

        if self.sa is not None:
            features = self.sa(features)

        image = self.out(features)
        return image, features

    def set_gamma(self, gamma):
        """Forward ``gamma`` to the attention block, if this stage has one."""
        if self.sa is not None:
            self.sa.set_gamma(gamma)

In [0]:
class LastGeneratorBlock(nn.Module):
    """Final generator stage: two ``SynthBlock``s, the second one residual.

    Unlike ``GeneratorBlock`` it has no attention and no image head and returns
    only the feature map. ``ch_size`` is accepted but not used here.
    """

    def __init__(self, in_ch, out_ch, latent_ch, mid_ch, kernels, ch_size, device=None):
        super().__init__()
        self.synth_0 = SynthBlock(in_ch, out_ch, latent_ch, mid_ch, kernels, device)
        self.synth_1 = SynthBlock(out_ch, out_ch, latent_ch, mid_ch, kernels, device)

    def forward(self, x, w, input_activation=False):
        """Return the feature map for input ``x`` and latent ``w``."""
        features = self.synth_0(x, w, input_activation)
        refined = self.synth_1(features, w)
        return features + refined

In [0]:
class Generator(nn.Module):
    """Image generator growing a learned 4x4 constant through seven stages.

    Every stage except the last emits a 3-channel image; the images are summed
    coarse-to-fine after bilinear up-sampling, while the feature maps are
    up-sampled between stages with ``PixelShuffle``.

    Args:
        ch_size: base channel multiplier.
        latent_size: dimension of the latent vector sampled per image.
        hidden_size: base width of the per-stage latent mapping.
        device: device used for sampled noise.
        sa_gamma: attention scale handed to the first stage (``res6``).
        kernels: number of base kernels in each depthwise conditional conv.
    """

    def __init__(self, ch_size, latent_size, hidden_size, device=None, sa_gamma=1., kernels=6):
        super().__init__()
        self.device = device
        self.latent_size = latent_size

        self.upbi = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upps = nn.PixelShuffle(2)

        # Learned constant input, shared by the whole batch.
        self.btm = nn.Parameter(torch.empty(1, ch_size*32, 4, 4))
        nn.init.normal_(self.btm)

        # Arguments shared by every stage constructor.
        tail = (kernels, ch_size, device)

        self.res6 = GeneratorBlock(ch_size*32, ch_size*16, latent_size, hidden_size*16, *tail,
                                   sa_gamma, bias=True)
        self.res5 = GeneratorBlock(ch_size*16, ch_size*16, latent_size, hidden_size*16, *tail)
        self.res4 = GeneratorBlock(ch_size*16, ch_size*8, latent_size, hidden_size*8, *tail)
        self.res3 = GeneratorBlock(ch_size*8, ch_size*4, latent_size, hidden_size*4, *tail)
        self.res2 = GeneratorBlock(ch_size*4, ch_size*2, latent_size, hidden_size*2, *tail)
        self.res1 = GeneratorBlock(ch_size*2, ch_size, latent_size, hidden_size, *tail)
        self.res0 = LastGeneratorBlock(ch_size, ch_size, latent_size, hidden_size, *tail)

        self.out = SNConv(ch_size, 3)

    def forward(self, batch):
        """Generate ``batch`` images; the output passes through ``tanh``."""
        w = pixel_norm(gaussian((batch, self.latent_size)).to(self.device))
        h = self.btm.expand(batch, -1, -1, -1)
        o, h = self.res6(h, w, input_activation=False)

        # Each later stage doubles the resolution and adds its image to the
        # up-sampled running sum.
        for stage in (self.res5, self.res4, self.res3, self.res2, self.res1):
            stage_image, h = stage(self.upps(h), w)
            o = stage_image + self.upbi(o)

        h = self.res0(self.upps(h), w)
        final_image = self.out(h)
        o = final_image + self.upbi(o)

        return torch.tanh(o)

    def set_gamma(self, gamma):
        """Forward ``gamma`` to the two coarsest stages (``res6`` and ``res5``)."""
        for stage in (self.res6, self.res5):
            stage.set_gamma(gamma)

In [0]:
# Training device: the CUDA device whose index is given by GPU.
device = torch.device("cuda:" + str(GPU))

In [0]:
def weights_init(m):
    """Initialise one module in place; intended for ``Module.apply``.

    * ``nn.Linear`` / ``nn.Conv2d`` (exact type match): Xavier-normal weights
      with gain ``initial_scale`` and a zero bias when a bias exists.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weights ~ N(1.0, 0.02), zero bias.
    * Any other module is left untouched.

    NOTE(review): layers wrapped by ``spectral_norm`` are initialised through
    ``m.weight``; confirm this reaches the trainable tensor if the function is
    ever applied after a forward pass.
    """
    layer_type = type(m)

    if layer_type in (nn.Linear, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=initial_scale)
        if m.bias is not None:
            m.bias.data.fill_(0)
        return

    if layer_type in (nn.BatchNorm1d, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [0]:
# Build the generator. With `load_gen` set, restore the checkpoint for that
# epoch tag (same file-name pattern as save_model); otherwise run weights_init
# over every sub-module.
gen = Generator(CH_SIZE, LATENT_SIZE, HIDDEN_SIZE, device=device)
if load_gen:
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    gen.load_state_dict(torch.load(OUT+'{}_0.4.0_{}epoch.pkl'.format(gen_name, load_gen)))
    # train_loop switches the model back to train() at the start of every epoch.
    gen.eval()
else:
    gen.apply(weights_init)

# Same procedure for the discriminator, which emits a single score per image.
dis = Discriminator(out_ch=1, device=device)
if load_dis:
    dis.load_state_dict(torch.load(OUT+'{}_0.4.0_{}epoch.pkl'.format(dis_name, load_dis)))
    dis.eval()
else:
    dis.apply(weights_init)

# Move both networks to the training device.
gen.to(device)
dis.to(device)

In [0]:
# Adam optimizers for generator and discriminator. Both share the learning
# rate and betas; only the weight decay is configured separately.
opt_gen = optim.Adam(
    gen.parameters(),
    lr=learning_rate,
    betas=(0.5, 0.999),
    weight_decay=gen_weight_decay
    )
opt_dis = optim.Adam(
    dis.parameters(),
    lr=learning_rate,
    betas=(0.5, 0.999),
    weight_decay=dis_weight_decay
    )

In [0]:
# Loss histories appended to by report_log and plotted by report_result.
gen_losses, dis_losses = [], []

In [0]:
# Create the Twitter client only when tweeting is enabled. `api` is therefore
# undefined when tweet_interval is falsy; round_dataset checks the same flag
# before it calls post_image.
if tweet_interval:
    # Credentials come from the local `twitter_api_key` module.
    auth = tweepy.OAuthHandler(twitter_api_key.CONSUMER_KEY, twitter_api_key.CONSUMER_SECRET)
    auth.set_access_token(twitter_api_key.ACCESS_TOKEN_KEY, twitter_api_key.ACCESS_TOKEN_SECRET)
    api = tweepy.API(auth)

In [0]:
def report_log(i, epoch, g_loss, d_loss, g_mean, d_real_mean, d_fake_mean):
    """Print one progress line and record both losses.

    The line shows ``[i/epoch]``, the discriminator and generator losses, the
    mean discriminator output on real images, and its mean output on generated
    images before / after the discriminator update. ``g_loss`` and ``d_loss``
    are appended to the module-level ``gen_losses`` / ``dis_losses`` lists.
    """
    template = '[{}/{}]\tLoss_D: {:.4f}\tLoss_G: {:.4f}\tD(x): {:.4f}\tD(G(z)): {:.4f}/{:.4f}'
    line = template.format(i, epoch, d_loss, g_loss, d_real_mean, d_fake_mean, g_mean)
    print(line)

    gen_losses.append(g_loss)
    dis_losses.append(d_loss)

In [0]:
def make_image(gen, img_size):
    """Generate 8 images and show them inline as a 2x4 grid.

    The current cell output is cleared first. ``gen`` is called with the batch
    size and must return a tensor reshapeable to (8, 3, img_size, img_size)
    with values in [-1, 1]; they are mapped to 0..255 for display.
    """
    clear_output()
    with torch.no_grad():
        generated = gen(8).detach().cpu()

    # Convert to NumPy explicitly instead of relying on np.reshape/np.transpose
    # accepting a torch.Tensor, then go from channel-first to channel-last.
    images = generated.numpy().reshape(-1, 3, img_size, img_size).transpose(0, 2, 3, 1)

    plt.figure(figsize=(16, 8))

    for i, img in enumerate(images):
        ax = plt.subplot(2, 4, i+1)
        ax.axis('off')
        ax.imshow(Image.fromarray(np.uint8((img+1.)/2. *255.)))

    plt.show()

In [0]:
def image_upload(image_array, api):
    """Encode one image as JPEG in memory, upload it, and return its media id.

    Args:
        image_array: channel-last image with values in [-1, 1].
        api: client object exposing ``media_upload(filename=..., file=...)``.

    Returns:
        The ``media_id`` attribute of the upload result.
    """
    bin_io = io.BytesIO()
    img = Image.fromarray(np.uint8((image_array+1.)/2. *255.))
    # resample=0 selects nearest-neighbour resizing.
    img = img.resize((512, 512), resample=0)
    img.save(bin_io, format='JPEG')
    # save() leaves the stream positioned at its end; rewind so a consumer that
    # reads from the current position sees the whole JPEG.
    bin_io.seek(0)
    result = api.media_upload(filename='{}_generated_{}.jpg'.format(target, uuid.uuid4()), file=bin_io)
    return result.media_id

In [0]:
def post_image(gen, api, img_size, iteration, epoch):
    """Generate 16 images, tile them into four 2x2 collages, and tweet them.

    Posting is best-effort: a failure while uploading or tweeting never stops
    training, but it is reported on stdout instead of being silently dropped.

    Args:
        gen: callable returning a tensor reshapeable to
            (16, 3, img_size, img_size) when called with the batch size.
        api: Twitter client used by ``image_upload`` and ``update_status``.
        img_size: side length of one generated image.
        iteration, epoch: zero-based counters, shown one-based in the tweet.
    """
    with torch.no_grad():
        generated = gen(16).detach().cpu()

    # Convert to NumPy explicitly, then arrange the 16 samples as 4 collages of
    # 2x2 images: (collage, row, col, C, H, W) -> (collage, C, row*H, col*W)
    # -> channel-last.
    tiles = generated.numpy().reshape(4, 2, 2, 3, img_size, img_size)
    tiles = tiles.transpose(0, 3, 1, 4, 2, 5)
    tiles = tiles.reshape(4, 3, img_size*2, img_size*2)
    tiles = tiles.transpose(0, 2, 3, 1)

    try:
        img_ids = [image_upload(img, api) for img in tiles]
        hash_tags = ['AIで{}を作る'.format(target_JP),
                    'iteration/epoch: {}/{}'.format(iteration+1, epoch+1),
                    '#makeing{}'.format(target),
                    '#nowlearning...',
                    '#AI',
                    '#人工知能',
                    '#DeepLearning',
                    '#GAN']

        api.update_status(
            status='\n'.join(hash_tags),
            media_ids=img_ids
            )
    except Exception as err:
        # Keep training, but leave a trace of what went wrong.
        print('[post_image] tweet failed at iteration {} (epoch {}): {!r}'.format(
            iteration+1, epoch+1, err
        ))

In [0]:
def save_model(gen, dis, epoch):
    """Write the state dicts of ``gen`` and ``dis`` to checkpoint files.

    File names are ``OUT + '<name>_0.4.0_<epoch>epoch.pkl'`` with ``<name>``
    taken from ``gen_name`` / ``dis_name``. The directory part of ``OUT`` is
    created when it does not exist, so the first snapshot of a fresh run does
    not fail after hours of training.
    """
    import os

    out_dir = os.path.dirname(OUT)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for name, model in ((gen_name, gen), (dis_name, dis)):
        torch.save(model.state_dict(), OUT+'{}_0.4.0_{}epoch.pkl'.format(name, epoch))

In [0]:
def report_result():
    """Plot the recorded generator ("G") and discriminator ("D") loss histories.

    Reads the module-level ``gen_losses`` and ``dis_losses`` lists, which hold
    one value per ``report_log`` call.
    """
    plt.figure(figsize=(16, 8))
    for history, tag in ((gen_losses, "G"), (dis_losses, "D")):
        plt.plot(history, label=tag)
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

In [0]:
def round_dataset(gen, dis, opt_gen, opt_dis, dataloader, epoch, img_size, device):
    """Run one pass over ``dataloader``, alternating discriminator and generator steps.

    Both networks are trained with a least-squares objective (``nn.MSELoss``
    against targets of 1 for "real" and 0 for "fake"). Preview, logging,
    tweeting and checkpointing are each triggered every ``*_interval``
    iterations; an interval that is ``None`` or 0 disables that action.

    Args:
        gen, dis: generator and discriminator modules.
        opt_gen, opt_dis: their optimizers.
        dataloader: yields batches whose first element is the image tensor.
        epoch: zero-based epoch index, used for tweets and checkpoint names.
        img_size: side length of generated images, used by the image helpers.
        device: device the real images and loss targets are placed on.
    """
    loss_fun = nn.MSELoss()
    data_len = len(dataloader)
    # Wrapping the loader itself (not the enumerate object) lets tqdm read its
    # length and show a total and an ETA.
    for i, data in enumerate(tqdm(dataloader)):
        # ---- discriminator step: real batch towards 1, generated batch towards 0
        dis.zero_grad()
        x_real = data[0].to(device)
        b_size = x_real.size(0)
        y_real = dis(x_real).view(-1)
        real_loss = loss_fun(y_real, torch.ones(*y_real.size(), device=device))
        real_loss.backward()

        x_fake = gen(b_size)
        # detach(): this step must not push gradients into the generator.
        y_fake = dis(x_fake.detach()).view(-1)
        fake_loss = loss_fun(y_fake, torch.zeros(*y_fake.size(), device=device))
        fake_loss.backward()

        if grad_clip:
            clip_grad_norm_(dis.parameters(), grad_clip)
        dis_loss = real_loss + fake_loss
        opt_dis.step()

        # ---- generator step: make the updated discriminator answer 1 on fakes
        gen.zero_grad()
        y_gen = dis(x_fake).view(-1)
        gen_loss = loss_fun(y_gen, torch.ones(*y_gen.size(), device=device))
        gen_loss.backward()

        if grad_clip:
            clip_grad_norm_(gen.parameters(), grad_clip)
        opt_gen.step()

        if display_interval and (i+1) % display_interval == 0:
            make_image(gen, img_size)

        # Guarded like the other intervals so None/0 disables logging instead
        # of raising on the modulo.
        if log_interval and (i+1) % log_interval == 0:
            report_log(
                i,
                data_len,
                gen_loss.item(),
                dis_loss.item(),
                y_gen.mean().item(),
                y_real.mean().item(),
                y_fake.mean().item()
                )

        if tweet_interval and (i+1) % tweet_interval == 0:
            post_image(gen, api, img_size, i, epoch)

        if snapshot_interval and (i+1) % snapshot_interval == 0:
            save_model(gen, dis, epoch)

In [0]:
def train_loop(gen, dis, opt_gen, opt_dis, dataloader, epoch, img_size, device):
    """Train for ``epoch`` passes over ``dataloader``, then plot the loss curves.

    Both networks are put into training mode before every pass; the zero-based
    pass index is handed to ``round_dataset`` as its epoch number.
    """
    for epoch_idx in range(epoch):
        for net in (gen, dis):
            net.train()
        round_dataset(gen, dis, opt_gen, opt_dis, dataloader, epoch_idx, img_size, device)

    report_result()

In [0]:
class GaussianBlur():
    """Image transform that applies ``ImageFilter.GaussianBlur(k)``.

    A value of ``k`` equal to 0 turns the transform into a no-op that returns
    the image object it was given.
    """

    def __init__(self, k):
        self.k = k

    def __call__(self, img):
        """Return the blurred image, or ``img`` itself when ``k`` is 0."""
        if self.k == 0:
            return img
        blur = ImageFilter.GaussianBlur(self.k)
        return img.filter(blur)

In [0]:
# Curriculum over blur strength: train on strongly blurred images first
# (k = 3) and finish on unblurred ones (k = 0), running EPOCHS[k] epochs each.
for k in reversed(range(4)):
    preprocess = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.RandomCrop(IMG_SIZE),
        transforms.RandomHorizontalFlip(),
        GaussianBlur(k),
        transforms.ToTensor(),
        # Map every channel from [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = dset.ImageFolder(root=dataroot, transform=preprocess)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    train_loop(gen, dis, opt_gen, opt_dis, dataloader, EPOCHS[k], IMG_SIZE, device)

In [0]:
# Show a final 2x4 grid of samples from the trained generator.
make_image(gen, IMG_SIZE)