In [1]:
import os
from os.path import join as ospj
import time
from time import gmtime, strftime
import datetime
from munch import Munch
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from collections import deque

from network.model_zt import build_model
from core.checkpoint import CheckpointIO
from dataset.frame_dataset import FramesDataset, MotionDataset, DatasetRepeater
import network.utils as utils
import yaml
import random
from utils import Logger, AverageMeter, center
from tqdm import tqdm
from torch.nn.parallel.data_parallel import DataParallel
import pytorch_ssim
from torchvision import transforms
from network.vgg import VGG
from torchvision.models import vgg19
import imp
from network.wing import FAN, HighPass
from network.interp import AntiAliasInterpolation2d
import copy
from torch.optim.lr_scheduler import MultiStepLR

import warnings
warnings.filterwarnings("ignore")
import imageio

In [2]:
with open('config/train_transformer7.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
config = Munch(config)

In [3]:
# GPU Device
gpu_id = '0,1,2'
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device " , use_cuda)

GPU device  True


In [4]:
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
np.random.seed(config.seed)

In [5]:
resume = False

In [6]:
class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict

In [7]:
class Solver(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.args.lr = float(self.args.lr)
        self.args.weight_decay = float(self.args.weight_decay)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.start = 10
        self.replay_memory = 10000
        self.replay_buffer = deque(maxlen=self.replay_memory)

        self.nets = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)

        if args.mode == 'train':
            self.optims = Munch()
            self.scheduler = []
            for net in self.nets.keys():
                self.optims[net] = torch.optim.Adam(params=self.nets[net].parameters(), lr=float(args.lr), betas=[args.beta1, args.beta2],
                                                   weight_decay=0)
                sd = MultiStepLR(self.optims[net], args.lr_decay, gamma=0.1, last_epoch=-1)
                self.scheduler.append(sd)

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]
    

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ('ema' not in name) and ('fan' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)
        
        # landmark
        self.fan = FAN(fname_pretrained=config.fname_fan)
        self.fan.to(self.device)
        self.fan.eval()
        self.hpf = HighPass(config.w_hpf, self.device)
        self.masking = transforms.RandomErasing(p=1.0, scale=(0.25, 0.33), ratio=(0.3, 0.33))
        self.clip_num = config.clip_num
        
        # perceptual loss
        self.vgg = VGG()
        MainModel = imp.load_source("MainModel", args.fname_ir)
        weight = torch.load(args.fname_vgg, map_location='cpu')
        self.vgg.load_state_dict(weight.state_dict(), strict=False)
        self.vgg.eval()
        self.vgg.to(self.device)
        self.vgg19 = vgg19(pretrained=True)
        self.vgg19.eval()
        self.vgg19.to(self.device)
        
        # multiscale
        self.image_pyramid = ImagePyramide(args.scales, 3).to(self.device)
        

    def _save_checkpoint(self, step):
        for ckptio in self.ckptios:
            ckptio.save(step)

    def _load_checkpoint(self, step):
        for ckptio in self.ckptios:
            ckptio.load(step)

    def _reset_grad(self):
        for optim in self.optims.values():
            optim.zero_grad()
            
    def train(self, loaders):
        args = self.args
        nets = self.nets
 
        for name in nets:
            nets[name] = DataParallel(nets[name])
            nets[name] = nets[name].to(self.device)
        optims = self.optims

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # batch
        for epoch in range(args.resume_iter, args.epochs):
            bar = tqdm(total=len(loaders.src), leave=False)
            wgan_loss, d_reg_loss = AverageMeter(), AverageMeter()
            g_latent_loss, vgg_loss, fm_loss, cm_loss = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
            for i, inputs in enumerate(loaders.src):
                x_source, y_drive = inputs['source'], inputs['target']
                x_source_land = self.fan.get_landmark(x_source.to(self.device))
                num_frame = len(y_drive)
                x_source_mb = x_source.unsqueeze(1).repeat(1, args.K, 1, 1, 1)
                x_source_land = F.interpolate(self.fan.get_landmark(x_source.cuda()), size=256, mode='bilinear')
                x_source_land_mb = x_source_land.unsqueeze(1).repeat(1, args.K, 1, 1, 1)
                y_drive_mb = y_drive[0].unsqueeze(1)
                y_drive_land = F.interpolate(self.fan.get_landmark(y_drive[0].cuda()), size=256, mode='bilinear')
                y_drive_land = y_drive_land.unsqueeze(1)

                for f in range(1, num_frame):
                    y_drive_ld = self.fan.get_landmark(y_drive[f].cuda())
                    y_drive_ld = F.interpolate(y_drive_ld, size=y_drive[f].size(2), mode='bilinear') # (bs, K, 1, w, h)
                    y_drive_mb = torch.cat([y_drive_mb, y_drive[f].unsqueeze(1)], dim=1) # (bs, K, 3, w, h)
                    y_drive_land = torch.cat([y_drive_land, y_drive_ld.unsqueeze(1)], dim=1) # (bs, K, 1, w,h)
                
#                     self.replay_buffer.append((, y_drive[f]))           
#                     if len(self.replay_buffer) < self.start:
#                         continue
#                     minibatch = random.sample(self.replay_buffer, 1)
#                     x_source_mb, y_drive_mb = minibatch[0][0], minibatch[0][1]
                    
                k = np.random.randint(args.K)
                # Transformer computation
                e_hat = compute_tf_loss(nets, args, x_source_mb, y_drive_mb, x_source_land_mb, y_drive_land, device=self.device)

                # Generator
                g_loss, g_losses_latent = compute_g_loss(nets, args, x_source, y_drive_mb[:,k,:,:,:], e_hat, 
                                                         x_source_land, y_drive_land[:,k,:,:,:], 
                                                         vgg=self.vgg, vgg19=self.vgg19,
                                                         image_pyramid=self.image_pyramid, scales=args.scales,
                                                         device=self.device)        
                self._reset_grad()
                g_loss.backward()
                optims.generator.step()
                optims.transformer.step()

                # Discriminator
                d_loss, d_losses_latent = compute_d_loss(nets, args, x_source, y_drive_mb[:,k,:,:], e_hat, 
                                                               x_source_land, y_drive_land[:,k,:,:,:], 
                                                               image_pyramid=self.image_pyramid, scales=args.scales,
                                                               device=self.device) 
                self._reset_grad()
                d_loss.backward()
                optims.discriminator.step()

                wgan_loss.update(float(d_losses_latent.wgangp), x_source.size(0))
                d_reg_loss.update(float(d_losses_latent.reg), x_source.size(0))
                g_latent_loss.update(float(g_losses_latent.adv), x_source.size(0))
    #                     g_cycle_loss.update(float(g_losses_latent.cyc), x_source.size(0))
                vgg_loss.update(float(g_losses_latent.vgg), x_source.size(0))
                fm_loss.update(float(g_losses_latent.fm), x_source.size(0))
                cm_loss.update(float(g_losses_latent.cm), x_source.size(0))

                bar.set_description("Ep:{:d}, D: {:.6f}, R1: {:.2f}, G: {:.6f}, Vgg: {:.6f}, FM: {:.6f}, CM: {:.6f}".format(
                    epoch+1, wgan_loss.avg, d_reg_loss.avg, 
                    g_latent_loss.avg, vgg_loss.avg, fm_loss.avg, cm_loss.avg), refresh=True)
                bar.update()
            bar.close()
            
            for sc in self.scheduler:
                sc.step()

                # save model checkpoints
            logger.append([str(wgan_loss.avg)[:8], str(d_reg_loss.avg)[:8], 
                           str(g_latent_loss.avg)[:8], str(vgg_loss.avg)[:8], str(fm_loss.avg)[:8], str(cm_loss.avg)[:8]])
            if (epoch+1) % config.save_every == 0:
                self._save_checkpoint(step=epoch+1)

            # compute SSIM and FID in test_set
            if (epoch+1) % config.eval_every == 0:
                self.evaluate(args, epoch, nets, loaders.val)
                
            self.make_animation(args, nets, loaders)
        
        self.evaluate(args, epoch, nets, loaders.val)
            
    @torch.no_grad()
    def evaluate(self, args, epoch, nets, loader):
        if not os.path.isdir(args.result_dir):
            os.makedirs(args.result_dir)
        result_target = os.path.join(args.result_dir, 'tar')
        result_gen = os.path.join(args.result_dir,'gen')
        if not os.path.isdir(result_target):
            os.makedirs(result_target)
        if not os.path.isdir(result_gen):
            os.makedirs(result_gen)
        
        bar = tqdm(total=len(loader), leave=False)
        ssim_meter, fid_meter = AverageMeter(), AverageMeter()
        for iteration, x in enumerate(loader):
            try:
                test_video = torch.tensor(np.concatenate(x['video'])) # (frame, c, w, h)
            except:
                continue
            num_frame = test_video.shape[0]
            k_frame = np.random.choice(num_frame-args.K, size=2, replace=False)
            source = test_video[[k_frame[0]]].cuda() # (1, 3, 256, 256)
            target = test_video[k_frame[1]:k_frame[1]+args.K].cuda() # (8, 3, 256, 256)
            x_source_mb = source.unsqueeze(1).repeat(1,args.K,1,1,1) # (1, 8, 3, 256, 256)
            x_source_land = F.interpolate(self.fan.get_landmark(source.cuda()), size=256, mode='bilinear') # (1, 1, 256, 256)
            y_drive_mb = target[[0]].unsqueeze(1) # (1, 1, 3, 256, 256) first frame
            y_drive_land = F.interpolate(self.fan.get_landmark(target[[0]].cuda()), size=256, mode='bilinear') # (1, 1, 256, 256)
            y_drive_land = y_drive_land.unsqueeze(1) # (1, 1, 1, 256, 256)
            for i in range(1, args.K):
                y_drive_ld = self.fan.get_landmark(target[[i]]) # (1, 1, 64, 64)
                y_drive_ld = F.interpolate(y_drive_ld, size=target[[i]].size(2), mode='bilinear') # (1, 1, 256, 256)
                y_drive_mb = torch.cat([y_drive_mb, target[[i]].unsqueeze(1)], dim=1) # (1, K, 3, 256, 256)
                y_drive_land = torch.cat([y_drive_land, y_drive_ld.unsqueeze(1)], dim=1) # (1, K, 1, 256, 256)

            x_source_land_mb = x_source_land.unsqueeze(1).repeat(1, args.K, 1, 1, 1) # (1, K, 1, 256, 256)
            x_source_land_mb = x_source_land_mb.view(-1, 1, args.img_size, args.img_size) # (k, 1, 256, 256)
            y_drive_land = y_drive_land.view(-1, 1, args.img_size, args.img_size) # (k, 1, 256, 256)
            
            out = nets.transformer(x_source_land_mb, y_drive_land)
            out = out.view(-1, args.K, args.max_conv_dim, 1, 1) # (1, K, 512, 1, 1)
            e_hat = out.mean(dim=1) # (bs, 512, 1, 1)
            idx = np.random.randint(0, args.K) # choose random frame
            source_gen = nets.generator(source, x_source_land, y_drive_mb[:,idx,:,:,:], y_drive_land[idx,:,:,:], e_hat)
            ssim = float(pytorch_ssim.ssim(source_gen, y_drive_mb[:,idx,:,:,:]))
            ssim_meter.update(ssim, iteration+1)
            
            # save for FID
            gen = source_gen.squeeze().cpu().detach().numpy()
            target = y_drive_mb[:,idx,:,:,:].squeeze().cpu().detach().numpy()
            gen = gen.swapaxes(0, 1).swapaxes(1, 2)
            target = target.swapaxes(0, 1).swapaxes(1, 2)
            gen_img = Image.fromarray((gen*255).astype('uint8'))
            tar_img = Image.fromarray((target*255).astype('uint8'))
            gen_img.save(result_gen + '/{}.png'.format(iteration+1))
            tar_img.save(result_target + '/{}.png'.format(iteration+1))
            
            bar.set_description("Epoch:{:d}, SSIM: {:.8f}".format(epoch+1, ssim_meter.avg), refresh=True)
            bar.update()
        bar.close()
        val_logger.append([str(ssim_meter.avg)])
        return
    
    @torch.no_grad()
    def make_animation(self, args, nets, loaders):
        if not os.path.isdir(args.sample_dir):
            os.makedirs(args.sample_dir)
        K = 100
        random_list = np.random.choice(len(loaders.val.dataset), replace=False, size=2)
        source_image_idx = int(random_list[0])
        test_video_idx = int(random_list[1])
        train_video_idx = int(np.random.choice(len(loaders.src.dataset), size=1))
        # test animation
        source_image = loaders.val.dataset[source_image_idx]['video'][0]
        source_image = source_image.unsqueeze(0).cuda() # (1, 3, 256, 256)
        test_video = loaders.val.dataset[test_video_idx]['video'] # list [video](3, 256, 256)
        test_frame = len(test_video) if len(test_video) < K else K
        predict_test, predict_train = [], []
        
        x_source_mb = source_image.unsqueeze(1).repeat(1,K,1,1,1) # (1, K, 3, 256, 256)
        x_source_land = F.interpolate(self.fan.get_landmark(source_image), size=256, mode='bilinear') # (1, 1, 256, 256)
        y_drive_mb = test_video[0].unsqueeze(0) # (1, 3, 256, 256)
        y_drive_land = F.interpolate(self.fan.get_landmark(y_drive_mb.cuda()), size=256, mode='bilinear') # (1, 1, 256, 256)
        y_drive_land = y_drive_land.unsqueeze(1) # (1, 1, 1, 256, 256)
        y_drive_mb = y_drive_mb.unsqueeze(1) # (1, 1, 3, 256, 256)
        for i in range(1, test_frame):
            y_drive_ld = self.fan.get_landmark(test_video[i].cuda().unsqueeze(0)) # (1, 1, 64, 64)
            y_drive_ld = F.interpolate(y_drive_ld, size=test_video[i].size(2), mode='bilinear') # (1, 1, 256, 256)
            y_drive_land = torch.cat([y_drive_land, y_drive_ld.unsqueeze(1)], dim=1) # (1, K, 1, 256, 256)
            y_drive_mb = torch.cat([y_drive_mb, test_video[i].unsqueeze(0).unsqueeze(1)], dim=1) # (1, K, 3, 256, 256)
            
        y_drive_mb = y_drive_mb.cuda()
        x_source_land_mb = x_source_land.unsqueeze(1).repeat(1, test_frame, 1, 1, 1) # (1, K, 1, 256, 256)
        x_source_land_mb = x_source_land_mb.view(-1, 1, args.img_size, args.img_size) # (bs*k, 1, 256, 256)
        y_drive_land = y_drive_land.view(-1, 1, args.img_size, args.img_size) # (bs*k, 1, 256, 256)

        out = nets.transformer(x_source_land_mb, y_drive_land)
        out = out.view(-1, test_frame, args.max_conv_dim, 1, 1)
        e_hat = out.mean(dim=1) # (bs, 512, 1, 1)
        for i in range(test_frame):
            x_fake = nets.generator(source_image, x_source_land, y_drive_mb[:,i,:,:,:], y_drive_land[i,:,:,:], e_hat)
            predict_test.append(x_fake.cpu().detach().numpy().squeeze().swapaxes(0, 1).swapaxes(1,2))
        source_image = (source_image*255).cpu().squeeze().numpy().swapaxes(0,1).swapaxes(1,2).astype('uint8')
        source_img = Image.fromarray(source_image)
        source_img.save(config.sample_dir+ '/source.png')
        imageio.mimsave(os.path.join(config.sample_dir, 'test_gen.mp4'), [(frame*255).astype('uint8') for frame in predict_test], fps=24)
        imageio.mimsave(os.path.join(config.sample_dir, 'test_raw.mp4'), [(frame*255).numpy().astype('uint8').swapaxes(0,1).swapaxes(1,2) for frame in test_video], fps=24)

In [8]:
def compute_tf_loss(nets, args, x_real, y_org, x_land, y_land, device='cuda'):
    x_real, y_org = x_real.to(device), y_org.to(device)
    x_land = x_land.view(-1, x_land.size(2), x_land.size(3), x_land.size(4))
    y_land = y_land.view(-1, y_land.size(2), y_land.size(3), y_land.size(4))
    
    out = nets.transformer(x_land, y_land) # (bs*K, 1, w, h) / (bs*K, 1, w, h)
    out = out.view(-1, args.K, args.max_conv_dim, 1, 1) # (bs, K, 512, 1, 1)
    e_hat = out.mean(dim=1) # (bs, 512, 1, 1)
    return e_hat

def compute_d_loss(nets, args, x_real, y_org, e_hat, x_landmark,  y_landmark, image_pyramid, scales, device='cuda'):
    # with real images
    x_real, y_org = x_real.to(device), y_org.to(device)
    x_real.requires_grad = True
    imp_real = image_pyramid(x_real)
    disc_real = nets.discriminator(imp_real)
    
    # R1-reg
    loss_reg = 0
    for scale in scales:
        key = 'output_%s' % scale
        real_key = 'prediction_%s' % scale
        value = r1_reg(disc_real[key], imp_real[real_key])
        loss_reg += value.mean()
    
    with torch.no_grad():
        x_gen = nets.generator(x_real, x_landmark, y_org, y_landmark, e_hat)
        img_gen = image_pyramid(x_gen)
    
    disc_gen = nets.discriminator(img_gen)
    loss_d = 0
    for scale in scales:
        key = 'output_%s' % scale
        value = (1 - disc_real[key])**2 + disc_gen[key]**2
        loss_d += value.mean()
    
    loss = loss_d + args.lambda_reg * loss_reg
    return loss, Munch(wgangp=loss_d.item(), reg=loss_reg.item())

def compute_g_loss(nets, args, x_real, y_org, e_hat, x_landmark, y_landmark, vgg, vgg19, image_pyramid, scales, device='cuda'):
    # adversarial loss
    # (bs, 3, w, h) / (bs, 3, w, h) / (bs, 1, w, h) / (bs, 1, w, h)
    x_real, y_org = x_real.to(device), y_org.to(device)
    x_gen = nets.generator(x_real, x_landmark, y_org, y_landmark, e_hat)
    x_gen = x_gen.to(device)
    # image pyramid for D
    imp_gen = image_pyramid(x_gen)
    imp_real = image_pyramid(x_real)
    disc_gen = nets.discriminator(imp_gen) # 'output_scale' / 'map_scale'
    with torch.no_grad():
        disc_real = nets.discriminator(imp_real) # 'output_scale' / 'map_scale'
    
    # adv loss: LSGAN
    loss_adv = 0
    for scale in scales:
        key = 'output_%s' % scale
        value = ((1 - disc_gen[key])**2).mean()
        loss_adv += value
        
    # feature-matching loss
    l1_loss = nn.L1Loss()
    loss_fm = 0
    for scale in scales:
        key = 'map_%s' % scale
        for i, (a, b) in enumerate(zip(disc_real[key], disc_gen[key])):
            value = l1_loss(a, b)
            loss_fm += value
    
    # embedding-matching loss
    key = 'map_%s' % scales[0]
    loss_cm = l1_loss(e_hat.squeeze(), disc_real[key].squeeze())
    
    # perceptual loss: vggface
    with torch.no_grad():
        vgg_x = vgg(y_org)
    with torch.autograd.enable_grad():
        vgg_xhat = vgg(x_gen)
        
    loss_vggface = 0
    for x_feat, xhat_feat in zip(vgg_x, vgg_xhat):
        loss_vggface += l1_loss(x_feat, xhat_feat)
        
    conv_idx_list = [2,7,12,21.30] # indexs of conv layers
    def vgg_x_hook(module, input, output):
            output.detach_() #no gradient compute
            vgg_x_features.append(output)
    def vgg_xhat_hook(module, input, output):
        vgg_xhat_features.append(output)
            
    vgg_x_features = []
    vgg_xhat_features = []
    vgg_x_handles = []
    conv_idx_iter = 0
        
    for i, m in enumerate(vgg19.features.modules()):
        if i == conv_idx_list[conv_idx_iter]:
            if conv_idx_iter < len(conv_idx_list)-1:
                conv_idx_iter += 1
            vgg_x_handles.append(m.register_forward_hook(vgg_x_hook))
    with torch.no_grad():
        vgg19(y_org)
    for h in vgg_x_handles:
        h.remove()
    
    vgg_xhat_handles = []
    conv_idx_iter = 0
    with torch.autograd.enable_grad():
        for i, m in enumerate(vgg19.features.modules()):
            if i == conv_idx_list[conv_idx_iter]:
                if conv_idx_iter < len(conv_idx_list)-1:
                    conv_idx_iter += 1
                vgg_xhat_handles.append(m.register_forward_hook(vgg_xhat_hook))
        vgg19(x_gen)
        
        for h in vgg_xhat_handles:
            h.remove()
    
    loss_vgg19 = 0
    for x_feat, xhat_feat in zip(vgg_x_features, vgg_xhat_features):
        loss_vgg19 += l1_loss(x_feat, xhat_feat)
    
    loss = loss_adv + args.lambda_vggface * loss_vggface + args.lambda_vgg19 * loss_vgg19 + args.lambda_fm * loss_fm + args.lambda_fm * loss_cm
    return loss, Munch(adv=loss_adv.item(), vgg=loss_vggface.item()+loss_vgg19.item(), fm=loss_fm.item(), cm=loss_cm.item())
    
def moving_average(model, model_test, beta=0.999):
    for param, param_test in zip(model.parameters(), model_test.parameters()):
        param_test.data = torch.lerp(param.data, param_test.data, beta)

def adv_loss(logits, target):
    assert target in [1, 0]
    targets = torch.full_like(logits, fill_value=target)
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    return loss

def r1_reg(d_out, x_in):
    # zero-centered gradient penalty for real images
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
    return reg

In [9]:
train_dataset = MotionDataset(config.root_dir, image_shape=config.frame_shape, id_sampling=True, is_train=True, random_seed=config.seed)
test_dataset = FramesDataset(config.root_dir, image_shape=config.frame_shape, id_sampling=True, is_train=False, random_seed=config.seed)

Use predefined train-test split.
Use predefined train-test split.


In [10]:
train_dataset = DatasetRepeater(train_dataset, config.num_repeats)

In [11]:
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, 
                              num_workers=config.num_workers, pin_memory=False, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=config.num_workers)

In [12]:
loaders = Munch(src=train_loader, val=test_loader)

In [13]:
solver = Solver(config)

Number of parameters of generator: 33541923
Number of parameters of discriminator: 29173571
Number of parameters of transformer: 10586755
Initializing generator...
Initializing discriminator...
Initializing transformer...


In [14]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    logger = Logger(os.path.join(config.checkpoint_dir, 'log.txt'), resume=True)
    val_logger = Logger(os.path.join(config.checkpoint_dir, 'val_log.txt'), resume=True)
else:
    logger = Logger(os.path.join(config.checkpoint_dir, 'log.txt'))
    val_logger = Logger(os.path.join(config.checkpoint_dir, 'val_log.txt'))
    logger.set_names(['D Loss', 'R1reg Loss', 'G-latent-adv Loss', 'Perceptual Loss', 'Feature-matching Loss', 'Content-matching Loss'])
    val_logger.set_names(['SSIM measure'])

In [None]:
solver.train(loaders)

                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000001_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000001_optims.ckpt...


                                                                                                                                          

Saving checkpoint into logs/tf7/checkpoints/000002_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000002_optims.ckpt...


                                                                                                                                          

Saving checkpoint into logs/tf7/checkpoints/000003_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000003_optims.ckpt...


                                                                                                                                          

Saving checkpoint into logs/tf7/checkpoints/000004_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000004_optims.ckpt...


                                                                                                                                          

Saving checkpoint into logs/tf7/checkpoints/000005_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000005_optims.ckpt...


                                                                                                                                          

Saving checkpoint into logs/tf7/checkpoints/000006_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000006_optims.ckpt...


                                                                                                                                          

Saving checkpoint into logs/tf7/checkpoints/000007_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000007_optims.ckpt...


                                                                                                                                          

Saving checkpoint into logs/tf7/checkpoints/000008_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000008_optims.ckpt...


                                                                                                                                          

Saving checkpoint into logs/tf7/checkpoints/000009_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000009_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000010_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000010_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000011_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000011_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000012_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000012_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000013_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000013_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000014_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000014_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000015_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000015_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000016_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000016_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000017_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000017_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000018_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000018_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000019_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000019_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000020_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000020_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000021_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000021_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000022_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000022_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000023_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000023_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000024_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000024_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000025_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000025_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000026_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000026_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000027_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000027_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000028_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000028_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000029_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000029_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000030_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000030_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000031_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000031_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000032_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000032_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000033_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000033_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000034_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000034_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000035_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000035_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000036_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000036_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000037_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000037_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000038_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000038_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000039_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000039_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000040_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000040_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000041_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000041_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000042_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000042_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000043_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000043_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000044_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000044_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000045_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000045_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000046_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000046_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000047_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000047_optims.ckpt...


                                                                                                                                           

Saving checkpoint into logs/tf7/checkpoints/000048_nets.ckpt...
Saving checkpoint into logs/tf7/checkpoints/000048_optims.ckpt...


Ep:49, D: 0.113630, R1: 0.05, G: 2.845651, Vgg: 3.298123, FM: 4.534843, CM: 0.018213:  61%|██████    | 478/783 [1:52:04<1:12:13, 14.21s/it]