In [1]:
import os
import argparse

from munch import Munch
from torch.backends import cudnn
import torch

from core.data_loader import get_train_loader
from core.data_loader import get_test_loader
from core.solver import Solver

In [2]:
cudnn.benchmark = True
torch.manual_seed(0)
torch.randn([3])

tensor([ 1.5410, -0.2934, -2.1788])

In [3]:
parser = argparse.ArgumentParser()

# model arguments
parser.add_argument('--img_size', type=int, default=256,
                    help='Image resolution')
parser.add_argument('--num_domains', type=int, default=2,
                    help='Number of domains')
parser.add_argument('--latent_dim', type=int, default=16,
                    help='Latent vector dimension')
parser.add_argument('--hidden_dim', type=int, default=512,
                    help='Hidden dimension of mapping network')
parser.add_argument('--style_dim', type=int, default=64,
                    help='Style code dimension')

# weight for objective functions
parser.add_argument('--lambda_reg', type=float, default=1,
                    help='Weight for R1 regularization')
parser.add_argument('--lambda_cyc', type=float, default=1,
                    help='Weight for cyclic consistency loss')
parser.add_argument('--lambda_sty', type=float, default=1,
                    help='Weight for style reconstruction loss')
parser.add_argument('--lambda_ds', type=float, default=2,
                    help='Weight for diversity sensitive loss')
parser.add_argument('--ds_iter', type=int, default=100000,
                    help='Number of iterations to optimize diversity sensitive loss')
parser.add_argument('--w_hpf', type=float, default=0,
                    help='weight for high-pass filtering')

# training arguments
parser.add_argument('--randcrop_prob', type=float, default=0.5,
                    help='Probabilty of using random-resized cropping')
parser.add_argument('--total_iters', type=int, default=100000,
                    help='Number of total iterations')
parser.add_argument('--resume_iter', type=int, default=0,
                    help='Iterations to resume training/testing')
parser.add_argument('--batch_size', type=int, default=8,
                    help='Batch size for training')
parser.add_argument('--val_batch_size', type=int, default=32,
                    help='Batch size for validation')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='Learning rate for D, E and G')
parser.add_argument('--f_lr', type=float, default=1e-6,
                    help='Learning rate for F')
parser.add_argument('--beta1', type=float, default=0.0,
                    help='Decay rate for 1st moment of Adam')
parser.add_argument('--beta2', type=float, default=0.99,
                    help='Decay rate for 2nd moment of Adam')
parser.add_argument('--weight_decay', type=float, default=1e-4,
                    help='Weight decay for optimizer')
parser.add_argument('--num_outs_per_domain', type=int, default=10,
                    help='Number of generated images per domain during sampling')

# misc
parser.add_argument('--mode', type=str,default='train', required=False,
                    choices=['train', 'sample', 'eval', 'align'],
                    help='This argument is used in solver')
parser.add_argument('--num_workers', type=int, default=4,
                    help='Number of workers used in DataLoader')
parser.add_argument('--seed', type=int, default=777,
                    help='Seed for random number generator')

# directory for training
parser.add_argument('--train_img_dir', type=str, default='data/afhq/train',
                    help='Directory containing training images')
parser.add_argument('--val_img_dir', type=str, default='data/afhq/val',
                    help='Directory containing validation images')
parser.add_argument('--sample_dir', type=str, default='expr/samples',
                    help='Directory for saving generated images')
parser.add_argument('--checkpoint_dir', type=str, default='expr/checkpoints',
                    help='Directory for saving network checkpoints')

# directory for calculating metrics
parser.add_argument('--eval_dir', type=str, default='expr/eval',
                    help='Directory for saving metrics, i.e., FID and LPIPS')

# directory for testing
parser.add_argument('--result_dir', type=str, default='expr/results',
                    help='Directory for saving generated images and videos')
parser.add_argument('--src_dir', type=str, default='assets/representative/afhq/src',
                    help='Directory containing input source images')
parser.add_argument('--ref_dir', type=str, default='assets/representative/afhq/ref',
                    help='Directory containing input reference images')
parser.add_argument('--inp_dir', type=str, default='assets/representative/custom/female',
                    help='input directory when aligning faces')
parser.add_argument('--out_dir', type=str, default='assets/representative/afhq/src/female',
                    help='output directory when aligning faces')

# face alignment
parser.add_argument('--wing_path', type=str, default='expr/checkpoints/wing.ckpt')
parser.add_argument('--lm_path', type=str, default='expr/checkpoints/celeba_lm_mean.npz')

# step size
parser.add_argument('--print_every', type=int, default=10)
parser.add_argument('--sample_every', type=int, default=5000)
parser.add_argument('--save_every', type=int, default=10000)
parser.add_argument('--eval_every', type=int, default=50000)

args = parser.parse_args([])

In [4]:
print(args)

Namespace(batch_size=8, beta1=0.0, beta2=0.99, checkpoint_dir='expr/checkpoints', ds_iter=100000, eval_dir='expr/eval', eval_every=50000, f_lr=1e-06, hidden_dim=512, img_size=256, inp_dir='assets/representative/custom/female', lambda_cyc=1, lambda_ds=2, lambda_reg=1, lambda_sty=1, latent_dim=16, lm_path='expr/checkpoints/celeba_lm_mean.npz', lr=0.0001, mode='train', num_domains=2, num_outs_per_domain=10, num_workers=4, out_dir='assets/representative/afhq/src/female', print_every=10, randcrop_prob=0.5, ref_dir='assets/representative/afhq/ref', result_dir='expr/results', resume_iter=0, sample_dir='expr/samples', sample_every=5000, save_every=10000, seed=777, src_dir='assets/representative/afhq/src', style_dim=64, total_iters=100000, train_img_dir='data/afhq/train', val_batch_size=32, val_img_dir='data/afhq/val', w_hpf=0, weight_decay=0.0001, wing_path='expr/checkpoints/wing.ckpt')


In [5]:
loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                     which='source',
                                     img_size=args.img_size,
                                     batch_size=args.batch_size,
                                     prob=args.randcrop_prob,
                                     num_workers=args.num_workers),
                ref=get_train_loader(root=args.train_img_dir,
                                     which='reference',
                                     img_size=args.img_size,
                                     batch_size=args.batch_size,
                                     prob=args.randcrop_prob,
                                     num_workers=args.num_workers),
                val=get_test_loader(root=args.val_img_dir,
                                    img_size=args.img_size,
                                    batch_size=args.val_batch_size,
                                    shuffle=True,
                                    num_workers=args.num_workers))

Preparing DataLoader to fetch source images during the training phase...
Preparing DataLoader to fetch reference images during the training phase...
Preparing DataLoader for the generation phase...


In [6]:
for src in loaders.src:
    src_demo = src[0]
    print(src_demo.shape, src_demo.shape)
    break

torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])


In [7]:
for ref in loaders.ref:
    print(ref[0].shape, ref[1].shape, ref[2].shape)
    break

torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256]) torch.Size([8])


In [15]:
from torchvision.utils import save_image
batch_img = ((src_demo + 1)/2).clip(0, 1)
save_image(batch_img.cpu(), './test.jpg', padding=0)

In [8]:
solver = Solver(args)

Number of parameters of generator: 33892995
Number of parameters of mapping_network: 2438272
Number of parameters of style_encoder: 20916928
Number of parameters of discriminator: 20852290
Initializing generator...
Initializing mapping_network...
Initializing style_encoder...
Initializing discriminator...


In [39]:
args.sample_every, args.sample_dir

(5000, 'expr/samples')

In [9]:
solver.nets.keys()

dict_keys(['generator', 'mapping_network', 'style_encoder', 'discriminator'])

In [10]:
from torchviz import make_dot
x = torch.rand([1, 3, 256, 256]).to('cuda:0')
y = torch.randint(0, 3, [1]).to('cuda:0')
dis = solver.nets.discriminator.to('cuda:0')
dis_dot = make_dot(dis(x, y), params=dict(dis.named_parameters()), show_attrs=True, show_saved=True)

/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [0,0,0], thread: [0,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.


In [11]:
dis_dot.view(filename='dis')

'dis.pdf'

Unescaped left brace in regex is deprecated, passed through in regex; marked by <-- HERE in m/%{ <-- HERE (.*?)}/ at /usr/bin/run-mailcap line 528.
Error: no "view" rule for type "application/pdf" passed its test case
       (for more information, add "--debug=1" on the command line)
/usr/bin/xdg-open: line 778: www-browser: command not found
/usr/bin/xdg-open: line 778: links2: command not found
/usr/bin/xdg-open: line 778: elinks: command not found
/usr/bin/xdg-open: line 778: links: command not found
/usr/bin/xdg-open: line 778: lynx: command not found
/usr/bin/xdg-open: line 778: w3m: command not found
xdg-open: no method available for opening 'dis.pdf'


In [12]:
from torchviz import make_dot
map_net = solver.nets.mapping_network
z = torch.rand([1, 16]).to('cuda:0') # latent code
y = torch.randint(0, 3, [1]).to('cuda:0')
map_net_dot = make_dot(map_net(z, y), params=dict(map_net.named_parameters()), show_attrs=True, show_saved=True)

In [13]:
map_net_dot.view('map_net')

'map_net.pdf'

Unescaped left brace in regex is deprecated, passed through in regex; marked by <-- HERE in m/%{ <-- HERE (.*?)}/ at /usr/bin/run-mailcap line 528.
Error: no "view" rule for type "application/pdf" passed its test case
       (for more information, add "--debug=1" on the command line)
/usr/bin/xdg-open: line 778: www-browser: command not found
/usr/bin/xdg-open: line 778: links2: command not found
/usr/bin/xdg-open: line 778: elinks: command not found
/usr/bin/xdg-open: line 778: links: command not found
/usr/bin/xdg-open: line 778: lynx: command not found
/usr/bin/xdg-open: line 778: w3m: command not found
xdg-open: no method available for opening 'map_net.pdf'


In [24]:
a = torch.randn([1, 1024, 1,1])
gamma, beta = torch.chunk(a, chunks=2, dim=1)

In [25]:
gamma.shape, beta.shape

(torch.Size([1, 512, 1, 1]), torch.Size([1, 512, 1, 1]))