In [1]:
import os
from os.path import join as ospj
import time
from time import gmtime, strftime
import datetime
from munch import Munch
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from network.model_p import build_model
from core.checkpoint import CheckpointIO
# from dataset.data_loader import InputFetcher
from dataset.frame_dataset import FramesDataset, MotionDataset, DatasetRepeater
import network.utils as utils
import yaml
import random
from utils import Bar, Logger, AverageMeter
from tqdm import tqdm
from torch.nn.parallel.data_parallel import DataParallel
import pytorch_ssim
from torchsummary import summary

In [2]:
with open('config/train_transformer2.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
config = Munch(config)

In [3]:
# GPU Device
gpu_id = '0,1,2'
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device " , use_cuda)

GPU device  True


In [4]:
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
np.random.seed(config.seed)

In [5]:
resume = True

In [6]:
class Solver(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.args.lr = float(self.args.lr)
        self.args.weight_decay = float(self.args.weight_decay)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.start = 0

        self.nets, self.nets_ema = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=float(args.lr),
                    betas=[args.beta1, args.beta2])

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]
        else:
            self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema)]
    

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ('ema' not in name) and ('fan' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)

    def _save_checkpoint(self, step):
        for ckptio in self.ckptios:
            ckptio.save(step)

    def _load_checkpoint(self, step):
        for ckptio in self.ckptios:
            ckptio.load(step)

    def _reset_grad(self):
        for optim in self.optims.values():
            optim.zero_grad()

    def train(self, loaders):
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        
        for name in nets:
            nets[name] = DataParallel(nets[name])
        optims = self.optims

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)


        # batch
#         for i in range(args.resume_iter, args.total_iters):
        for epoch in range(args.resume_iter, args.epochs):
            bar = tqdm(total=len(loaders.src), leave=False)
            wgan_loss, d_reg_loss = AverageMeter(), AverageMeter()
            g_latent_loss, g_cycle_loss, tf_l1_loss = AverageMeter(), AverageMeter(), AverageMeter()
            for i, inputs in enumerate(loaders.src):
                x_source, y_drive = inputs['source'].to(self.device), inputs['target']
                num_frame = len(y_drive)
                for f in range(num_frame):
                    # train the discriminator
                    d_loss, d_losses_latent = compute_d_loss(nets, args, x_source, y_drive[f].to(self.device), device=self.device)
                    self._reset_grad()
                    d_loss.backward()
                    optims.discriminator.step()

                    # train the generator
                    g_loss, g_losses_latent = compute_g_loss(nets, args, x_source, y_drive[f].to(self.device), device=self.device)
                    self._reset_grad()
                    g_loss.backward()
                    optims.generator.step()
                    optims.transformer.step()

                    moving_average(nets.generator, nets_ema.generator, beta=0.999)

                    wgan_loss.update(float(d_losses_latent.wgangp), x_source.size(0))
                    d_reg_loss.update(float(d_losses_latent.reg), x_source.size(0))
                    g_latent_loss.update(float(g_losses_latent.adv), x_source.size(0))
                    g_cycle_loss.update(float(g_losses_latent.cyc), x_source.size(0))
#                     tf_l1_loss.update(float(g_losses_latent.tf), x_source.size(0))
                    
                
                bar.set_description("Ep:{:d}, WGP: {:.6f}, R1: {:.2f}, G: {:.6f}, Cyc: {:.6f}".format(
                                    epoch+1, wgan_loss.avg, d_reg_loss.avg, 
                                    g_latent_loss.avg, g_cycle_loss.avg), refresh=True)
                bar.update()
            bar.close()

                # save model checkpoints
            logger.append([str(wgan_loss.avg)[:8], str(d_reg_loss.avg)[:8], 
                           str(g_latent_loss.avg)[:8], str(g_cycle_loss.avg)[:8]])
            if (epoch+1) % config.save_every == 0:
                self._save_checkpoint(step=epoch+1)

            # compute SSIM and FID in test_set
            if (epoch+1) % config.eval_every == 0:
                self.evaluate(epoch, nets, loaders.val)
        
        self.evaluate(epoch, nets, loaders.val)

    @torch.no_grad()
    def evaluate(self, epoch, nets, loader):
        if not os.path.isdir(config.result_dir):
            os.makedirs(config.result_dir)
        result_target = os.path.join(config.result_dir, 'tar')
        result_gen = os.path.join(config.result_dir,'gen')
        if not os.path.isdir(result_target):
            os.makedirs(result_target)
        if not os.path.isdir(result_gen):
            os.makedirs(result_gen)
        
        bar = tqdm(total=len(loader), leave=False)
        ssim_meter, fid_meter = AverageMeter(), AverageMeter()
        for iteration, x in enumerate(loader):
            test_video = torch.tensor(np.concatenate(x['video'])) # (frame, c, w, h)
            num_frame = test_video.shape[0]
            k_frame = np.random.choice(num_frame-config.K, size=2, replace=False)
            source = test_video[[k_frame[0]]]
            target = test_video[[k_frame[1]]]
            source_gen, _  = nets.generator(source.to(self.device), target.to(self.device), nets.transformer)
            
            ssim = float(pytorch_ssim.ssim(source_gen, target.to(self.device)))
            ssim_meter.update(ssim, iteration+1)
            
            # save for FID
            gen = source_gen.squeeze().cpu().detach().numpy()
            target = target.squeeze().cpu().detach().numpy()
            gen = gen.swapaxes(0, 1).swapaxes(1, 2)
            target = target.swapaxes(0, 1).swapaxes(1, 2)
            gen_img = Image.fromarray((gen*255).astype('uint8'))
            tar_img = Image.fromarray((target*255).astype('uint8'))
            gen_img.save(result_gen + '/{}.png'.format(iteration+1))
            tar_img.save(result_target + '/{}.png'.format(iteration+1))
            
            bar.set_description("Epoch:{:d}, SSIM: {:.8f}".format(epoch+1, ssim_meter.avg), refresh=True)
            bar.update()
        bar.close()
        val_logger.append([str(ssim_meter.avg)])
        return
    
    def make_animation(self, net, loader):
        if not os.path.isdir(config.result_dir):
            os.makedirs(config.result_dir)
        pass
            
def compute_d_loss(nets, args, x_real, y_org, device='cuda'):
#     assert (z_trg is None) != (x_ref is None)
    # with real images
    x_real.requires_grad = True
    y_org.requires_grad = True
    real_out = nets.discriminator(y_org)
     # with fake images
    with torch.no_grad():
        x_fake, _ = nets.generator(x_real, y_org, nets.transformer)
    fake_out = nets.discriminator(x_fake)
    loss_reg = r1_reg(real_out, y_org)
    
    loss = (torch.mean(fake_out) - torch.mean(real_out) + (args.drift * torch.mean(real_out ** 2)))
    x_fake.requires_grad = True
    gp = gradient_penalty(nets, y_org, x_fake, args.lambda_gp, device=device)
    loss += gp

    loss = loss + args.lambda_reg * loss_reg
    return loss, Munch(wgangp=loss.item(), reg=loss_reg.item())


def compute_g_loss(nets, args, x_real, y_org, device='cuda'):
    # adversarial loss
    x_fake, x_att = nets.generator(x_real, y_org, nets.transformer)
    out = nets.discriminator(x_fake)
    loss_adv = adv_loss(out, 1)
    
    # diversity sensitive loss: same real image --> different reference (y_trg, s_trg)
#     if z_trgs is not None:
#         s_trg2 = nets.mapping_network(z_trg2, y_trg)
#     else:
#         s_trg2 = nets.style_encoder(x_ref2, y_trg)
#     x_fake2 = nets.generator(x_real, s_trg2, masks=masks)
#     x_fake2 = x_fake2.detach()
#     loss_ds = torch.mean(torch.abs(x_fake - x_fake2))

    # cycle-consistency loss
    x_rec, x_tf = nets.generator(x_fake, x_real, nets.transformer)
    loss_cyc = torch.mean(torch.abs(x_rec - x_real)) + torch.mean(torch.abs(x_att - x_tf))

#     loss = loss_adv + args.lambda_sty * loss_sty - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc
    loss = loss_adv + args.lambda_cyc * loss_cyc
    return loss, Munch(adv=loss_adv.item(), cyc=loss_cyc.item())

# def compute_tf_loss(nets, args, x_att, t_att, device='cuda'):
#     out_
    
#     loss_tf = torch.mean(torch.abs(x_att - t_att))
    
#     args.lambda_tf * tf_loss
#     return loss_tf

def gradient_penalty(nets, real, fake, reg_lambda=10, device='cuda'):
    from torch.autograd import grad
    batch_size = real.shape[0]
    epsilon = torch.rand(batch_size, 1, 1, 1).to(device)
    merged = (epsilon * real) + ((1 - epsilon) * fake)
    # forward
    op = nets.discriminator(merged)
    
    # merted gradient
    gradient = grad(outputs=op, inputs=merged, create_graph=True, grad_outputs=torch.ones_like(op), 
                    retain_graph=True, only_inputs=True)[0]
    
    # calc penalty
    penalty = reg_lambda * ((gradient.norm(p=2, dim=1) - 1) ** 2).mean()
    return penalty
    

def moving_average(model, model_test, beta=0.999):
    for param, param_test in zip(model.parameters(), model_test.parameters()):
        param_test.data = torch.lerp(param.data, param_test.data, beta)


def adv_loss(logits, target):
    assert target in [1, 0]
    targets = torch.full_like(logits, fill_value=target)
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    return loss


def r1_reg(d_out, x_in):
    # zero-centered gradient penalty for real images
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
    return reg

In [7]:
train_dataset = MotionDataset(config.root_dir, image_shape=config.frame_shape, id_sampling=True, is_train=True, random_seed=config.seed)
test_dataset = FramesDataset(config.root_dir, image_shape=config.frame_shape, id_sampling=True, is_train=False, random_seed=config.seed)

Use predefined train-test split.
Use predefined train-test split.


In [8]:
# train_dataset = DatasetRepeater(train_dataset, config.num_repeats)

In [9]:
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, 
                              num_workers=config.num_workers, pin_memory=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

In [10]:
loaders = Munch(src=train_loader, val=test_loader)

In [11]:
solver = Solver(config)

Number of parameters of generator: 37852815
Number of parameters of discriminator: 20851777
Number of parameters of transformer: 22069248
Initializing generator...
Initializing discriminator...
Initializing transformer...


In [12]:
# Resume
if resume:
    print('==> Resuming from checkpoint..')
    logger = Logger(os.path.join(config.checkpoint_dir, 'log.txt'), resume=True)
    val_logger = Logger(os.path.join(config.checkpoint_dir, 'val_log.txt'), resume=True)
else:
    logger = Logger(os.path.join(config.checkpoint_dir, 'log.txt'))
    val_logger = Logger(os.path.join(config.checkpoint_dir, 'val_log.txt'))
    logger.set_names(['WGAN-GP Loss', 'R1reg Loss', 'G-latent-adv Loss', 'G-cyclc Loss'])
    val_logger.set_names(['SSIM measure'])

==> Resuming from checkpoint..


In [13]:
solver.train(loaders)

Loading checkpoint from logs/tf2/checkpoints/000025_nets.ckpt...
Loading checkpoint from logs/tf2/checkpoints/000025_nets_ema.ckpt...
Loading checkpoint from logs/tf2/checkpoints/000025_optims.ckpt...


  0%|          | 0/26 [00:00<?, ?it/s]

RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/nips/experiment/network/model_p.py", line 175, in forward
    x = block(x, tf_x)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/nips/experiment/network/model_p.py", line 115, in forward
    out = self._residual(x, att)
  File "/home/nips/experiment/network/model_p.py", line 104, in _residual
    x = self.norm1(x, att)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/nips/experiment/network/model_p.py", line 73, in forward
    h = self.conv(s) # (bs, 1, 512)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/cutz/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 208, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 0 does not equal 1 (while checking arguments for cudnn_convolution)
