In [1]:
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.nn.functional import softplus
from torch.autograd import grad
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
from torchvision import transforms

import numpy as np
import random
from math import log2, ceil
import pandas as pd

from tqdm import tqdm_notebook as tqdm
from PIL import Image
import matplotlib.pyplot as plt

from datetime import datetime
from time import time

import os
import warnings
warnings.filterwarnings('ignore')

In [2]:
from dataset import CelebA, make_dataloader
from utils import sample_noise, find_alpha, allow_gradient, adjust_lr, linear_scale_lr, save_batch, save_reconstructions
from losses import zero_centered_gradient_penalty, loss_discriminator, loss_generator, loss_autoencoder
from net import MapingNetwork, StyleGenerator, Discriminator, Encoder, GeneratorBlock, ToRGB

## Set the variable `PATH` to the folder where the raw aligned CelebA images are located.

In [None]:
# Folder holding the raw aligned CelebA images. It is read from the
# CELEBA_PATH environment variable when set, otherwise the default below is
# used; previously `PATH` was only present as a commented-out line, so running
# this cell on a fresh kernel raised NameError.
PATH = os.environ.get('CELEBA_PATH', '/root/data/CelebA/img_align_celeba/')
celeba = CelebA(path = PATH)

## `Training loop`
* Please specify `DEVICE` and `experiment_name` before running the cell below.

In [None]:
# Device that every network and noise tensor below is moved to / created on
DEVICE = 'cuda:0'

# Name of the folder that receives image snapshots, loss stats and checkpoints
experiment_name = 'Final_Run1'

# Create the output folder. `exist_ok=True` makes re-running the cell harmless
# while still surfacing genuine failures (permissions, invalid path), which the
# former bare `except: pass` silently swallowed.
os.makedirs(experiment_name, exist_ok=True)

SAVE_IMAGES_EACH = 700            # snapshot images / stats every this many batches
EPOCHS_PER_SCALE = 12             # passes over the dataset at each resolution
TOTAL_IMAGES = len(celeba)

lc = 256                          # passed as `code=` to the networks and to sample_noise
# Batch size and learning rate, keyed by image resolution
bs_per_scale = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16}
lr_per_resolution = {4: 0.0015, 8: 0.0015, 16: 0.0015, 32: 0.0015, 64: 0.0015, 128: 0.0015}
BLUR_ACTIVATIONS = True           # forwarded as blur_upsample / blur_downsample to G and E

# --- Networks -------------------------------------------------------------
# F is applied to sampled noise in the training loop and its output is fed to G.
F = MapingNetwork(d=8, code=lc).to(DEVICE)

G = StyleGenerator(code=lc, max_fm=256, blur_upsample=BLUR_ACTIVATIONS).to(DEVICE)

# Second generator starting from G's weights; kept in eval mode and updated
# through `G_average.ema(...)` during training to track an EMA of G.
G_average = StyleGenerator(code=lc, max_fm=256, blur_upsample=BLUR_ACTIVATIONS).to(DEVICE)
G_average.load_state_dict(G.state_dict())
G_average.eval()

E = Encoder(code=lc, max_fm=256, fc_intital=True, blur_downsample=BLUR_ACTIVATIONS).to(DEVICE)
D = Discriminator(d=3, code=lc).to(DEVICE)

# --- Optimizers -----------------------------------------------------------
# All three start at the learning rate of the 4x4 scale; the training loop
# re-applies the per-resolution rate through `adjust_lr` at every scale.
optimizer_D = Adam(
    [*E.parameters(), *D.parameters()],
    lr=lr_per_resolution[4],
    betas=(0, 0.99),
)

optimizer_G = Adam(
    G.parameters(),
    lr=lr_per_resolution[4],
    betas=(0, 0.99),
)
# F joins the generator optimizer as its own group with a 100x smaller rate;
# the extra 'mult' entry is stored on the group next to the regular options.
optimizer_G.add_param_group(dict(
    params=F.parameters(),
    lr=lr_per_resolution[4] * 0.01,
    mult=0.01,
))

optimizer_AE = Adam(
    [*G.parameters(), *E.parameters()],
    lr=lr_per_resolution[4],
    betas=(0, 0.99),
)

# --- Schedulers -----------------------------------------------------------
# Only stepped by the training loop when USE_SCHEDULING is True.
USE_SCHEDULING = False
scheduler_D, scheduler_G, scheduler_AE = (
    StepLR(optimizer, step_size=1, gamma=0.1)
    for optimizer in (optimizer_D, optimizer_G, optimizer_AE)
)

# One growing list of per-batch loss values for each of the three objectives
loss_stats = {key: [] for key in ('G', 'D', 'AE')}

# A fixed set of 100 noise samples, rendered again at every snapshot so that
# the Generator's progress can be followed on the same inputs
terracotta_army = sample_noise(100, device=DEVICE, code=lc)

# Progressive training: every resolution is trained for EPOCHS_PER_SCALE epochs,
# each batch performing a discriminator step, a generator step and an
# autoencoder step. Snapshots, loss statistics and a checkpoint per scale are
# written into the `experiment_name` folder.
for scale in [4, 8, 16, 32, 64, 128]:
    print('*' * 100, '\n', str(scale) * 50, '\n', '*' * 100)

    BS = bs_per_scale[scale]
    # Only used for the progress-bar length; constant for the whole scale
    total_batches = len(celeba)//BS

    # `tracked_images` counts the real images seen at this scale; together with
    # `limit` (half of the images processed per scale) it is handed to
    # `find_alpha` to obtain the blending factor `alpha`.
    tracked_images = 0
    limit = int(0.5 * TOTAL_IMAGES * EPOCHS_PER_SCALE)
    alpha = find_alpha(tracked_images, limit)

    warmup = True  # not read anywhere in this cell
    epoch_start = time()  # taken once per scale, so "Minutes spent" is measured per scale

    # Set necessary learning rate
    for opt in [optimizer_AE, optimizer_G, optimizer_D]:
        adjust_lr(opt, lr_per_resolution[scale])

    for epoch in range(EPOCHS_PER_SCALE):
        for batch_idx, real_samples in tqdm(enumerate(make_dataloader(celeba, BS, image_size=scale)), total=total_batches,
                                            desc=f'Scale: {scale}, Epoch: [{epoch}/{EPOCHS_PER_SCALE}]'):

            # In the paper 500k with blending & 500k with alpha=1 for each scale
            alpha = find_alpha(tracked_images, limit)

            # Discriminator loss
            z1, z2 = sample_noise(BS, code=lc, device=DEVICE), sample_noise(BS, code=lc, device=DEVICE)
            w = F(z1, scale, z2, p_mix=0.9)

            # Reals are flagged as requiring grad before being passed to the
            # discriminator loss (called with gamma=10)
            real_samples = real_samples.to(DEVICE).requires_grad_()
            fake_samples = G(w, scale, alpha)

            lossD = loss_discriminator(E, D, alpha, real_samples, fake_samples, gamma=10)
            optimizer_D.zero_grad()
            lossD.backward()
            optimizer_D.step()

            # Generator loss (fresh noise, fresh fakes)
            z1, z2 = sample_noise(BS, code=lc, device=DEVICE), sample_noise(BS, code=lc, device=DEVICE)
            w = F(z1, scale, z2, p_mix=0.9)

            fake_samples = G(w, scale, alpha)

            lossG = loss_generator(E, D, alpha, fake_samples)
            optimizer_G.zero_grad()
            lossG.backward()
            optimizer_G.step()

            # Autoencoder loss
            z = sample_noise(BS, code=lc, device=DEVICE)

            lossAE = loss_autoencoder(F, G, E, scale, alpha, z)
            optimizer_AE.zero_grad()
            lossAE.backward()
            optimizer_AE.step()

            loss_stats['D'].append(lossD.item())
            loss_stats['G'].append(lossG.item())
            loss_stats['AE'].append(lossAE.item())

            tracked_images += real_samples.shape[0]

            # Keep average version of Generator
            G_average.ema(G, beta=0.999)

            if (batch_idx % SAVE_IMAGES_EACH) == 0:
                print('Batch idx: ', batch_idx, 'Alpha: ', round(alpha, 3), 'Minutes spent: ',round((time()-epoch_start)/60, 3))

                name = f'Generation_Scale{scale}_Img{int(tracked_images//1000)}.png'

                with torch.no_grad():
                    # Encode reals and fakes, then repeat each code
                    # log2(scale)-1 times along a new axis before decoding with G
                    codes_real = E(real_samples, alpha)[:, None, :].repeat(1, int(log2(scale)-1), 1)
                    codes_fake = E(fake_samples, alpha)[:, None, :].repeat(1, int(log2(scale)-1), 1)

                    reconstructions_real = G(codes_real, scale, alpha).cpu().detach()
                    reconstructions_fake = G(codes_fake, scale, alpha).cpu().detach()

                real, fake = real_samples.cpu().detach(), fake_samples.cpu().detach()

                # Save generations, then rescale the written file in place
                save_batch(os.path.join(experiment_name, name), fake, real, nrows=6)
                Image.open(os.path.join(experiment_name, name)).resize((1024, 1024)).save(os.path.join(experiment_name, name))

                # Save reconstructions (of reals and fakes)
                name = f'Reconstruction_Scale{scale}_Img{int(tracked_images//1000)}.png'
                save_reconstructions(os.path.join(experiment_name, name),
                                     [real, fake],
                                     [reconstructions_real, reconstructions_fake], nrows=6)
                Image.open(os.path.join(experiment_name, name)).resize((1536, 512)).save(os.path.join(experiment_name, name))

                # Save terracotta_army
                name = f'Terracotta_army_Scale{scale}_Img{int(tracked_images//1000)}.png'

                with torch.no_grad():
                    # Render the fixed noise in 10 chunks and stitch the results
                    code = F(terracotta_army, scale, z2=None, p_mix=0).chunk(10)
                    generated_army = torch.cat([G(c, scale, alpha) for c in code], dim=0)

                # NOTE(review): newer torchvision releases renamed `range` to
                # `value_range` - confirm against the installed version.
                save_image(generated_army.cpu().detach(),
                           os.path.join(experiment_name, name), nrow=10, padding=0, normalize=True, range=(-1, 1))
                Image.open(os.path.join(experiment_name, name)).resize((1024, 1024)).save(os.path.join(experiment_name, name))

                # Save losses
                pd.DataFrame(loss_stats).to_csv(os.path.join(experiment_name, 'stats.csv'), index=False)

        # Save plot of loss (overwritten after every epoch)
        fig, ax = plt.subplots(figsize=(20, 10))
        ax.set_ylim([0, 5])
        ax.plot(loss_stats['D'], label='Disc', alpha=0.5, c='b')
        ax.plot(loss_stats['G'], label='Gen', alpha=0.5, c='r')
        ax.plot(loss_stats['AE'], label='AE', alpha=0.5, c='g')
        ax.set(xlabel='Batches tracked', ylabel='Loss')
        ax.legend()
        fig.savefig(os.path.join(experiment_name, 'stats.png'))
        # Close instead of merely clearing: a new figure is created every epoch
        # and would otherwise stay registered with pyplot for the whole run.
        plt.close(fig)

        # Plain boolean `and` (short-circuiting) instead of bitwise `&`
        if USE_SCHEDULING and (alpha > 0.99) and (epoch in [8, 10]):
            scheduler_D.step()
            scheduler_G.step()
            scheduler_AE.step()

    # Once all epochs of a scale are done, save the models and optimizers for that scale
    # TODO: Change to more smart saving
    torch.save({'F': F.state_dict(),
                'G': G.state_dict(),
                'E': E.state_dict(),
                'D': D.state_dict(),
                'G_average': G_average.state_dict(),

                'optD': optimizer_D.state_dict(),
                'optG': optimizer_G.state_dict(),
                'optAE': optimizer_AE.state_dict()},
                os.path.join(experiment_name, f'Scale{scale}.pt'))