In [1]:
import cv2
from pathlib import Path
from itertools import chain
import os
import random
from munch import Munch
from PIL import Image
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import Dataset
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
import argparse
from core.data_loader import get_train_loader
from core.data_loader import get_test_loader
from core.solver import Solver, compute_d_loss

In [2]:
# Command-line style configuration, built with argparse so the notebook can
# share defaults with the training script.
parser = argparse.ArgumentParser()

# model arguments
parser.add_argument('--img_size', type=int, default=256,
                    help='Image resolution')
parser.add_argument('--num_domains', type=int, default=2,
                    help='Number of domains')
parser.add_argument('--latent_dim', type=int, default=16,
                    help='Latent vector dimension')
parser.add_argument('--latent_channels', type=int, default=64,
                    help='latent channels for coarse')
parser.add_argument('--norm', type=str, default='in',
                    help='normalization type for coarse')
parser.add_argument('--pad_type', type=str, default='zero',
                    help='the padding type for coarse')
parser.add_argument('--activation', type=str, default='elu',
                    help='the activation type for coarse')  # elu
parser.add_argument('--hidden_dim', type=int, default=512,
                    help='Hidden dimension of mapping network')
parser.add_argument('--style_dim', type=int, default=64,
                    help='Style code dimension')

# weight for objective functions
parser.add_argument('--lambda_reg', type=float, default=1,
                    help='Weight for R1 regularization')
parser.add_argument('--lambda_cyc', type=float, default=1,
                    help='Weight for cyclic consistency loss')
parser.add_argument('--lambda_sty', type=float, default=1,
                    help='Weight for style reconstruction loss')
parser.add_argument('--lambda_ds', type=float, default=1,
                    help='Weight for diversity sensitive loss')
parser.add_argument('--lambda_pl1', type=float, default=1, help='the parameter of parsing L1Loss')
parser.add_argument('--lambda_il1', type=float, default=1, help='the parameter of image L1Loss')
parser.add_argument('--lambda_perceptual', type=float, default=2,
                    help='the parameter of FML1Loss (perceptual loss)')
parser.add_argument('--lambda_gan', type=float, default=1,
                    help='the parameter of valid loss of AdaReconL1Loss; 0 is recommended')
parser.add_argument('--s_loss', type=float, default=2.5, help='STYLE_LOSS_WEIGHT')
parser.add_argument('--p_loss', type=float, default=0.1, help='CONTENT_LOSS_WEIGHT')
parser.add_argument('--ds_iter', type=int, default=100000,
                    help='Number of iterations to optimize diversity sensitive loss')
parser.add_argument('--w_hpf', type=float, default=1,
                    help='weight for high-pass filtering')

# training arguments
parser.add_argument('--randcrop_prob', type=float, default=0.5,
                    help='Probabilty of using random-resized cropping')
parser.add_argument('--total_iters', type=int, default=100000,
                    help='Number of total iterations')
parser.add_argument('--resume_iter', type=int, default=0,
                    help='Iterations to resume training/testing')
parser.add_argument('--batch_size', type=int, default=8,
                    help='Batch size for training')
parser.add_argument('--val_batch_size', type=int, default=32,
                    help='Batch size for validation')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='Learning rate for D, E and G')
parser.add_argument('--f_lr', type=float, default=1e-6,
                    help='Learning rate for F')
parser.add_argument('--beta1', type=float, default=0.0,
                    help='Decay rate for 1st moment of Adam')
parser.add_argument('--beta2', type=float, default=0.99,
                    help='Decay rate for 2nd moment of Adam')
parser.add_argument('--weight_decay', type=float, default=1e-4,
                    help='Weight decay for optimizer')
parser.add_argument('--num_outs_per_domain', type=int, default=10,
                    help='Number of generated images per domain during sampling')

# misc
parser.add_argument('--mode', type=str, required=True,
                    choices=['train', 'eval', 'sample'],
                    help='This argument is used in solver')
parser.add_argument('--num_workers', type=int, default=4,
                    help='Number of workers used in DataLoader')
parser.add_argument('--seed', type=int, default=777,
                    help='Seed for random number generator')

# directory for training
parser.add_argument('--train_img_dir', type=str, default='data/celeba_hq/train',
                    help='Directory containing training images')
parser.add_argument('--val_img_dir', type=str, default='data/celeba_hq/val',
                    help='Directory containing validation images')
parser.add_argument('--sample_dir', type=str, default='expr/samples',
                    help='Directory for saving generated images')
parser.add_argument('--checkpoint_dir', type=str, default='expr/checkpoints',
                    help='Directory for saving network checkpoints')

# directory for calculating metrics
parser.add_argument('--eval_dir', type=str, default='expr/eval',
                    help='Directory for saving metrics, i.e., FID and LPIPS')

# directory for testing
parser.add_argument('--result_dir', type=str, default='expr/results',
                    help='Directory for saving generated images and videos')
parser.add_argument('--src_dir', type=str, default='assets/representative/celeba_hq/src',
                    help='Directory containing input source images')
parser.add_argument('--ref_dir', type=str, default='assets/representative/celeba_hq/ref',
                    help='Directory containing input reference images')
parser.add_argument('--inp_dir', type=str, default='assets/representative/custom/female',
                    help='input directory when aligning faces')
parser.add_argument('--out_dir', type=str, default='assets/representative/celeba_hq/src/female',
                    help='output directory when aligning faces')

# face alignment
parser.add_argument('--wing_path', type=str, default='expr/checkpoints/wing.ckpt')
parser.add_argument('--lm_path', type=str, default='expr/checkpoints/celeba_lm_mean.npz')

# step size
parser.add_argument('--print_every', type=int, default=10)
parser.add_argument('--sample_every', type=int, default=5000)
parser.add_argument('--save_every', type=int, default=10000)
parser.add_argument('--eval_every', type=int, default=50000)

# Parse an explicit argument list instead of sys.argv. In a notebook, sys.argv
# holds the kernel's own launch arguments and never contains the required
# --mode flag, so a bare parser.parse_args() exits with SystemExit before any
# later line can run.
args = parser.parse_args(['--mode', 'train'])

In [3]:
# Bundle the two data loaders used in this notebook:
#   src - training loader over args.train_img_dir (which='source', pre=True)
#   val - loader over args.val_img_dir, shuffled
pre_loaders = Munch(
    src=get_train_loader(
        root=args.train_img_dir,
        which='source',
        img_size=args.img_size,
        batch_size=args.batch_size,
        prob=args.randcrop_prob,
        num_workers=args.num_workers,
        pre=True,
    ),
    val=get_test_loader(
        root=args.val_img_dir,
        img_size=args.img_size,
        batch_size=args.val_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    ),
)

Preparing DataLoader to fetch source images during the training phase...
Preparing DataLoader for the generation phase...


In [4]:
from core.data_loader import InputFetcher

In [5]:
# Wrap the source training loader in an InputFetcher running in 'test' mode.
fetcher = InputFetcher(pre_loaders.src, latent_dim=args.latent_dim, mode='test')

In [6]:
# Wrap the validation loader in an InputFetcher running in 'test' mode.
# NOTE(review): the second positional argument is passed as None here;
# presumably it is the optional reference loader - confirm against InputFetcher.
fetcher_val = InputFetcher(pre_loaders.val, None, latent_dim=args.latent_dim, mode='test')

In [7]:
# Pull one item from the validation fetcher; its .x/.p/.y/.mask fields are read later.
inputs_val = next(fetcher_val)

In [8]:
from core.model import build_model

In [9]:
# Build the networks from the parsed args.
# NOTE(review): pre=True presumably selects the pre-training variant - confirm in core.model.
nets = build_model(args,pre=True)

In [10]:
# Give the fields of the fetched validation item short local names.
x_real = inputs_val.x
x_p = inputs_val.p
y_org = inputs_val.y
mask = inputs_val.mask

In [11]:
# Sanity check: display the shape of the fetched image tensor.
x_real.shape

torch.Size([32, 3, 256, 256])

In [12]:
def compute_d_loss(nets, args, x_real, x_p, y_org, y_trg=None, z_trg=None, x_ref=None, masks=None, pre=False):
    """Compute the discriminator loss for one batch.

    The loss is ``adv_loss(D(real), 1) + adv_loss(D(fake), 0)
    + args.lambda_reg * r1_reg(D(real), real)``.

    NOTE(review): ``adv_loss`` and ``r1_reg`` are not defined in any visible
    cell of this notebook; they must already be in scope when this function is
    called (presumably from ``core.solver`` - confirm).

    Args:
        nets: Object exposing ``discriminator`` and ``generator`` and, when
            ``pre`` is False, ``mapping_network`` / ``style_encoder``.
        args: Namespace providing ``lambda_reg``.
        x_real: Real images. ``requires_grad_()`` is called on it in place.
        x_p: Second image-like input; only used when ``pre`` is True, where it
            is masked the same way as ``x_real`` and passed to the generator.
        y_org: Domain labels of ``x_real``.
        y_trg: Domain labels used to score the fake images. When ``pre`` is
            True and this is None, ``y_org`` is used instead.
        z_trg: Latent code for ``nets.mapping_network`` (non-pre mode only).
        x_ref: Reference images for ``nets.style_encoder`` (non-pre mode only).
        masks: Mask passed to the discriminator and generator. It must be a
            tensor when ``pre`` is True because ``1 - masks`` is computed.
        pre: When True the generator is called with the masked inputs and no
            style code; otherwise exactly one of ``z_trg`` / ``x_ref`` must be
            given.

    Returns:
        ``(loss, Munch(real=..., fake=..., reg=...))`` where the Munch holds
        the Python-float value of each term.
    """
    if pre:
        # The pre-training generator call takes no style code and no target
        # domain is required from the caller, so score the fake batch with the
        # source labels rather than passing None to the discriminator.
        if y_trg is None:
            y_trg = y_org
    else:
        assert (z_trg is None) != (x_ref is None)

    # with real images
    x_real.requires_grad_()  # needed so r1_reg can differentiate w.r.t. the input
    out = nets.discriminator(x_real, masks, y_org)
    loss_real = adv_loss(out, 1)
    loss_reg = r1_reg(out, x_real)

    # with fake images (generator side is not trained here, hence no_grad)
    with torch.no_grad():
        if pre:
            # The masked inputs are only consumed by the pre-training generator call.
            x_mask = x_real * (1 - masks)
            x_p_mask = x_p * (1 - masks)
            x_fake = nets.generator(x_mask, None, x_p_mask, masks)
        else:
            if z_trg is not None:
                s_trg = nets.mapping_network(z_trg, y_trg)
            else:  # x_ref is not None
                s_trg = nets.style_encoder(x_ref, y_trg)
            x_fake = nets.generator(x_real, s_trg, masks=masks)
    out = nets.discriminator(x_fake, masks, y_trg)
    loss_fake = adv_loss(out, 0)

    loss = loss_real + loss_fake + args.lambda_reg * loss_reg
    return loss, Munch(real=loss_real.item(),
                       fake=loss_fake.item(),
                       reg=loss_reg.item())

In [19]:
# requires_grad_() flips the flag in place and returns the same tensor, so the
# real batch can be handed to the discriminator in a single expression.
out = nets.discriminator(x_real.requires_grad_(), mask, y_org)

In [20]:
# All-zero target tensor with the same shape, dtype and device as the discriminator output.
targets = torch.zeros_like(out)

In [23]:
import torch.nn.functional as F

In [24]:
# `mse_loss` was never assigned in any remaining cell, so displaying it failed
# with NameError on a fresh kernel. Recompute it before showing it.
# NOTE(review): the recorded output has grad_fn=<MseLossBackward>; the target
# is assumed to be `targets` - confirm this matches the original experiment.
mse_loss = F.mse_loss(out, targets)
mse_loss

tensor(0.4137, device='cuda:0', grad_fn=<MseLossBackward>)

In [31]:
# Binary cross-entropy computed directly on the raw discriminator output
# against `targets`, for comparison with the MSE value.
F.binary_cross_entropy_with_logits(input=out, target=targets)

tensor(0.8252, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [None]:
def compute_g_loss_pre(nets, args, x_real, x_p, y_org, x_ref=None, masks=None):
    """Compute the generator loss for the pre-training setup.

    The generator receives the inputs with the masked region zeroed out
    (``x * (1 - masks)``) and no style code. The loss is the adversarial term
    plus, when ``x_ref`` is given, ``args.lambda_sty`` times the mean absolute
    difference between the style codes of the fake and reference images.

    NOTE(review): ``adv_loss`` is not defined in any visible cell of this
    notebook; it must already be in scope when this function is called
    (presumably from ``core.solver`` - confirm).

    Args:
        nets: Object exposing ``generator``, ``discriminator`` and, when
            ``x_ref`` is given, ``style_encoder``.
        args: Namespace providing ``lambda_sty``.
        x_real: Real images.
        x_p: Second image-like input, masked the same way as ``x_real``.
        y_org: Domain labels used for the discriminator and style encoder.
        x_ref: Optional reference images for the style reconstruction term.
        masks: Mask tensor; required because ``1 - masks`` is computed.

    Returns:
        ``(loss, Munch(adv=..., sty=...))`` with Python-float values. ``sty``
        is ``0.0`` when ``x_ref`` is None.
    """
    keep = 1 - masks
    x_mask = x_real * keep
    x_p_mask = x_p * keep
    x_fake = nets.generator(x_mask, None, x_p_mask, masks=masks)
    out = nets.discriminator(x_fake, masks, y_org)
    loss_adv = adv_loss(out, 1)
    loss = loss_adv

    # style reconstruction loss
    sty_value = 0.0
    if x_ref is not None:
        s_trg = nets.style_encoder(x_ref, y_org)
        s_pred = nets.style_encoder(x_fake, y_org)
        loss_sty = torch.mean(torch.abs(s_pred - s_trg))
        loss = loss + args.lambda_sty * loss_sty
        sty_value = loss_sty.item()

    # Only the terms computed here are reported. The previous return statement
    # read loss_ds / loss_cyc, which are never assigned (NameError), and called
    # .item() on the integer 0 when x_ref was None (AttributeError).
    return loss, Munch(adv=loss_adv.item(),
                       sty=sty_value)