In [1]:
DATA_CATEGORY = 'capsule'

import os

checkpoint_dir = 'checkpoint/'+DATA_CATEGORY+'/'
sample_dir     = 'sample/'+DATA_CATEGORY+'/'

# If folder doesn't exist, then create it.
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)
if not os.path.isdir(sample_dir):
    os.makedirs(sample_dir)

In [2]:
#coding: future_fstrings     # should work even without -*-
#location = fromstr(f'POINT({longitude} {latitude})', srid=4326)

import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import sys
sys.argv = [sys.argv[0], '--sched', 'mvtec_single_out/'+DATA_CATEGORY]# '--loss', 'r1', '--mixing'

import argparse
import random
import math

from tqdm import tqdm
import numpy as np
from PIL import Image

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

from dataset import MultiResolutionDataset
from model import StyledGenerator, Discriminator, Encoder


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(1 - decay, par2[k].data)


def sample_data(dataset, batch_size, image_size=4):
    dataset.resolution = image_size
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=1, drop_last=True)

    return loader


def adjust_lr(optimizer, lr):
    for group in optimizer.param_groups:
        mult = group.get('mult', 1)
        group['lr'] = lr * mult



In [3]:
code_size = 512
base_batch_size = 1 #16
n_critic = 1

parser = argparse.ArgumentParser(description='Progressive Growing of GANs')

parser.add_argument('path', type=str, help='path of specified dataset')
parser.add_argument(
    '--phase',
    type=int,
    default= 30000, #600_000 // 4,
    help='number of samples used for each training phases',
)
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--sched', action='store_true', help='use lr scheduling')
parser.add_argument('--init_size', default=8, type=int, help='initial image size') #8
parser.add_argument('--max_size', default=512, type=int, help='max image size')
parser.add_argument(
    '--ckpt', default=None, type=str, help='load from previous checkpoints'
)
parser.add_argument(
    '--no_from_rgb_activate',
    action='store_true',
    help='use activate in from_rgb (original implementation)',
)
parser.add_argument(
    '--mixing', action='store_true', help='use mixing regularization'
)
parser.add_argument(
    '--loss',
    type=str,
    default='wgan-gp',
    choices=['wgan-gp', 'r1'],
    help='class of gan loss',
)

args = parser.parse_args()

encoder = nn.DataParallel(
    Encoder(from_rgb_activate=not args.no_from_rgb_activate)
).cuda()
generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
discriminator = nn.DataParallel(
    Discriminator(from_rgb_activate=not args.no_from_rgb_activate)
).cuda()
g_running = StyledGenerator(code_size).cuda()
g_running.train(False)

e_running = Encoder(from_rgb_activate=not args.no_from_rgb_activate).cuda()
e_running.train(False)

g_optimizer = optim.Adam(
    generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)
)
g_optimizer.add_param_group(
    {
        'params': generator.module.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01,
    }
)
d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))
#beta1=0.9, beta2=0.99, epsilon=1e-8)
e_optimizer = optim.Adam(encoder.parameters(), lr=args.lr, betas=(0.0, 0.99))

accumulate(g_running, generator.module, 0)
accumulate(e_running, encoder.module, 0)

#assert args.ckpt is not None
if args.ckpt is not None:
    ckpt = torch.load(args.ckpt)

    generator.module.load_state_dict(ckpt['generator'])
    discriminator.module.load_state_dict(ckpt['discriminator'])
    g_running.load_state_dict(ckpt['g_running'])
    g_optimizer.load_state_dict(ckpt['g_optimizer'])
    d_optimizer.load_state_dict(ckpt['d_optimizer'])


transform = transforms.Compose(
    [
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ]
)

dataset = MultiResolutionDataset(args.path, transform)

if args.sched:
    args.lr = {128: 0.0015, 256: 0.002, 512: 0.002, 1024: 0.002}
    #args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32}
    args.batch = {4: base_batch_size*64, 8: base_batch_size*32, 16: base_batch_size*32, 32: base_batch_size*16,
                  64: base_batch_size*4, 128: base_batch_size*2, 256: base_batch_size, 512: base_batch_size}

else:
    args.lr = {}
    args.batch = {}

args.gen_sample = {512: (8, 4), 1024: (4, 2)}

args.batch_default = base_batch_size



	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha)


In [4]:
loss_g = []
loss_e = []
loss_d = []

In [5]:
#train(args, dataset, generator, discriminator, 0)


#def train(args, dataset, generator, discriminator, initial_num):
step = int(math.log2(args.init_size)) - 2
resolution = 4 * 2 ** step
loader = sample_data(
    dataset, args.batch.get(resolution, args.batch_default), resolution
)
data_loader = iter(loader)

adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

pbar = tqdm(range(3_000_000)) #3_000_000

requires_grad(generator, False)
requires_grad(discriminator, True)

disc_loss_val = 0
gen_loss_val = 0
grad_loss_val = 0

alpha = 0
used_sample = 0

max_step = int(math.log2(args.max_size)) - 2
final_progress = False

for i in pbar:
    discriminator.zero_grad()

    alpha = min(1, 1 / args.phase * (used_sample + 1))

    if (resolution == args.init_size and args.ckpt is None) or final_progress:
        alpha = 1

    if used_sample > args.phase * 2:
        used_sample = 0
        step += 1

        if step > max_step:
            step = max_step
            final_progress = True
            ckpt_step = step + 1

        else:
            alpha = 0
            ckpt_step = step

        resolution = 4 * 2 ** step

        del loader
        del data_loader 

        loader = sample_data(
            dataset, args.batch.get(resolution, args.batch_default), resolution
        )
        data_loader = iter(loader)

        torch.save(
            {
                'encoder': encoder.module.state_dict(),
                'generator': generator.module.state_dict(),
                'discriminator': discriminator.module.state_dict(),
                'g_optimizer': g_optimizer.state_dict(),
                'd_optimizer': d_optimizer.state_dict(),
                'g_running': g_running.state_dict(),
                'e_running': e_running.state_dict(),
            },
            ('checkpoint/capsule/train_step-' + str(ckpt_step) + '.model'),
        )

        adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
        adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

    try:
        real_image = next(data_loader)

    except (OSError, StopIteration):
        data_loader = iter(loader)
        real_image = next(data_loader)

    used_sample += real_image.shape[0]

    b_size = real_image.size(0)
    real_image = real_image.cuda()

    if args.loss == 'wgan-gp':
        real_predict = discriminator(real_image, step=step, alpha=alpha)
        real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
        (-real_predict).backward()

    elif args.loss == 'r1':
        real_image.requires_grad = True
        real_scores = discriminator(real_image, step=step, alpha=alpha)
        real_predict = F.softplus(-real_scores).mean()
        real_predict.backward(retain_graph=True)

        grad_real = grad(
            outputs=real_scores.sum(), inputs=real_image, create_graph=True
        )[0]
        grad_penalty = (
            grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
        ).mean()
        grad_penalty = 10 / 2 * grad_penalty
        grad_penalty.backward()
        if i%10 == 0:
            grad_loss_val = grad_penalty.item()

    if args.mixing and random.random() < 0.9:
        gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
            4, b_size, code_size, device='cuda'
        ).chunk(4, 0)
        gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
        gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]

    else:
        gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device='cuda').chunk(
            2, 0
        )
        gen_in1 = gen_in1.squeeze(0)
        gen_in2 = gen_in2.squeeze(0)

    #gen_in1 = encoder(real_image, step=step, alpha=alpha)

    _, fake_image = generator([gen_in1], step=step, alpha=alpha)
    fake_predict = discriminator(fake_image, step=step, alpha=alpha)

    if args.loss == 'wgan-gp':
        fake_predict = fake_predict.mean()
        fake_predict.backward()

        eps = torch.rand(b_size, 1, 1, 1).cuda()
        x_hat = eps * real_image.data + (1 - eps) * fake_image.data
        x_hat.requires_grad = True
        hat_predict = discriminator(x_hat, step=step, alpha=alpha)
        grad_x_hat = grad(
            outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
        )[0]
        grad_penalty = (
            (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
        ).mean()
        grad_penalty = 10 * grad_penalty
        grad_penalty.backward()
        if i%10 == 0:
            grad_loss_val = grad_penalty.item()
            disc_loss_val = (-real_predict + fake_predict).item()
            loss_d.append(disc_loss_val)

    elif args.loss == 'r1':
        fake_predict = F.softplus(fake_predict).mean()
        fake_predict.backward()
        if i%10 == 0:
            disc_loss_val = (real_predict + fake_predict).item()
            loss_d.append(disc_loss_val)

    d_optimizer.step()

    if (i + 1) % n_critic == 0:
        generator.zero_grad()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        #gen_in2 = encoder(real_image, step=step, alpha=alpha)
        _, fake_image = generator([gen_in2], step=step, alpha=alpha)
        predict = discriminator(fake_image, step=step, alpha=alpha)

        if args.loss == 'wgan-gp':
            loss = -predict.mean()
        elif args.loss == 'r1':
            loss = F.softplus(-predict).mean()
        if i%10 == 0:
            gen_loss_val = loss.item()
            loss_g.append(gen_loss_val)
        loss.backward()
        g_optimizer.step()
        accumulate(g_running, generator.module)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

    if (i + 1) % 100 == 0:
        images = []

        gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))

        images.append(real_image.data.detach().cpu())
        images.append(g_running(
                        [gen_in2], step=step, alpha=alpha
                    )[1].data.detach().cpu())
                      
        with torch.no_grad():
            for _ in range(gen_i):
                images.append(
                    g_running(
                        [torch.randn(gen_j, code_size).cuda()], step=step, alpha=alpha
                    )[1].data.detach().cpu()
                )

        utils.save_image(
            torch.cat(images, 0),
            'sample/capsule/'+str(i + 1).zfill(6)+ '.png',
            nrow=gen_i,
            normalize=True,
            range=(-1, 1),
        )

    if (i + 1) % 10000 == 0:
        torch.save(
            g_running.state_dict(), 'checkpoint/capsule/'+str(i + 1).zfill(6)+'.model'
        )

    state_msg = (
        f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
        f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}'
    )

    pbar.set_description(state_msg)



Size: 16; G: 11.020; D: -8.452; Grad: 1.316; Alpha: 1.00000:   0%|          | 3416/3000000 [09:26<137:57:16,  6.03it/s] 


KeyboardInterrupt: 