# Training for Stylizer module
This notebook handles training of the stylizer module.

## Load pretrained models

In [1]:
import matplotlib

matplotlib.use('Agg')

import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy

from frames_dataset import FramesDataset

from modules.generator import OcclusionAwareGenerator # LCH: refer here for generator
from modules.discriminator import MultiScaleDiscriminator # LCH: refer here for discriminator
from modules.keypoint_detector import KPDetector # LCH: refer here for key point detector

import torch

from train import train # LCH: For training process, everything in this module
from reconstruction import reconstruction
from animate import animate

In [2]:
# Read the experiment config, prepare the log directory, and build the three
# FOMM modules (generator, discriminator, key point detector) from it.
config_path = "config/anim-256.yaml"
with open(config_path) as f:
    # The config file drives model construction and training details.
    # `safe_load` is used because a bare `yaml.load(f)` is deprecated (a TypeError
    # in PyYAML >= 6) and can construct arbitrary Python objects.
    config = yaml.safe_load(f)

checkpoint_path = "pre_trains/vox-cpk.pth.tar"
log_dir = "MyLog/"
# exist_ok makes this idempotent on re-run and also creates missing parents.
os.makedirs(log_dir, exist_ok=True)
# Keep a copy of the config file (*.yaml) next to the logs for reproducibility.
if not os.path.exists(os.path.join(log_dir, os.path.basename(config_path))):
    copy(config_path, log_dir)

model_params = config['model_params']
common_params = model_params['common_params']

# Every module receives its own parameter section plus the shared common params.
generator = OcclusionAwareGenerator(**model_params['generator_params'], **common_params)
discriminator = MultiScaleDiscriminator(**model_params['discriminator_params'], **common_params)
kp_detector = KPDetector(**model_params['kp_detector_params'], **common_params)

# The modules are intentionally left on the CPU: the detected GPU (compute
# capability 3.0) is reported as unsupported by the installed PyTorch build.

  after removing the cwd from sys.path.


In [3]:
# Load the pretrained weights into the freshly constructed modules.
from logger import Logger

train_params = config['train_params']

# Always deserialize onto the CPU. The modules above were never moved to a GPU,
# and load_state_dict copies each tensor into the existing parameter's device
# anyway. This also avoids restoring tensors onto a CUDA device that torch reports
# as available but the installed build cannot actually use.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

generator.load_state_dict(checkpoint['generator'])
discriminator.load_state_dict(checkpoint['discriminator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])

    Found GPU0 GeForce GTX 770 which is of cuda capability 3.0.
    PyTorch no longer supports this GPU because it is too old.
    


## Dataset Preparation

In [4]:
from frames_dataset import DatasetRepeater
from torch.utils.data import DataLoader


# Load the original target data (training split).
frame_dataset = FramesDataset(is_train=True, **config['dataset_params'])

# `num_repeats` is optional in the config; a missing key means "no repetition".
num_repeats = train_params.get('num_repeats', 1)
print("Dataset size: {}, repeat number: {}".format(len(frame_dataset), num_repeats))

# Only wrap the dataset when repetition is actually requested, so a missing key
# cannot raise KeyError and a value of 1 does not add a pointless wrapper.
if num_repeats != 1:
    # Augment the dataset according to `num_repeats`
    # (the recorded output shows 270 samples -> 1080 for 4 repeats).
    frame_dataset = DatasetRepeater(frame_dataset, num_repeats)
    print("Repeated Dataset size: {}, repeat number: {}".format(len(frame_dataset), num_repeats))

dataloader = DataLoader(frame_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=2, drop_last=True)

Use predefined train-test split.
Dataset size: 270, repeat number: 4
Repeated Dataset size: 1080, repeat number: 4


In [5]:
# Get the sparse motion field from the pretrained models
# (plan A: use it as training data for the stylizer).
dm_network = generator.dense_motion_network
if dm_network is None:
    # Nothing below can work without the dense motion module, so fail loudly
    # here instead of printing and crashing later with an AttributeError.
    raise RuntimeError("Error: dense motion network doesn't exist!")

# Only one batch is needed to inspect the data, so fetch it directly.
x = next(iter(dataloader))
print(x['source'].shape)

# First, detect key points for both the source and the driving frames.
kp_source = kp_detector(x['source'])
kp_driving = kp_detector(x['driving'])

# Second, pass through the motion predictor at its working resolution.
# `src_image` must be defined even when no downscaling is configured.
src_image = x['source']
if dm_network.scale_factor != 1:
    src_image = dm_network.down(src_image)

bs, _, h, w = src_image.shape
print("source image shape: {}".format(src_image.shape))
sparse_motion = dm_network.create_sparse_motions(src_image, kp_driving, kp_source)
# Drop the identity (background) grid and keep one motion field per key point.
# In FOMM's DenseMotionNetwork.create_sparse_motions the identity grid is
# concatenated FIRST along dim 1 (torch.cat([identity_grid, driving_to_source], dim=1)),
# so slicing off the last entry would discard a real key point instead.
# NOTE(review): confirm against modules/dense_motion.py if that module was modified.
sparse_motion = sparse_motion[:, 1:, :, :, :]
print("sparse motion shape: {}".format(sparse_motion.shape))

torch.Size([10, 3, 256, 256])
source image shape: torch.Size([10, 3, 64, 64])
sparse motion shape: torch.Size([10, 10, 64, 64, 2])


In [12]:
# Rearrange the sparse motion from (bs, num_kp, h, w, 2) to channels-first
# (bs, num_kp, 2, h, w). The result goes to a NEW name so that re-running this
# cell is idempotent and `sparse_motion` keeps its documented layout.
sparse_motion_cf = sparse_motion.permute(0, 1, 4, 2, 3).contiguous()
# Shape of the un-permuted tensor: (bs, num_kp, h, w, 2).
s = sparse_motion.shape

In [16]:
# Sanity check: flattening `sparse_motion` into a 4-D (bs, C, s[2], s[3]) tensor
# and reshaping back must reproduce the original tensor exactly.
# `.reshape` is used instead of `.view` because `sparse_motion` may be a
# non-contiguous slice/permutation, on which `.view` raises a RuntimeError.
print(s)
sparse_motion_flat = sparse_motion.reshape((s[0], -1, s[2], s[3]))
assert torch.equal(sparse_motion_flat.reshape(s), sparse_motion)
sparse_motion_flat.reshape(s).shape

torch.Size([10, 10, 64, 64, 2])


torch.Size([10, 10, 64, 64, 2])

## Code snippets

In [28]:
def make_normalized_grid(h: int, w: int) -> torch.Tensor:
    """Return an (h, w, 2) float grid whose last axis holds (x, y) in [-1, 1].

    ``grid[i, j] == (x_j, y_i)``: x varies along the width (columns) and y along
    the height (rows), each spaced evenly from -1 to 1. An axis of length 1 maps
    to 0, because ``2 * i / (n - 1) - 1`` would divide by zero there.

    Args:
        h: number of rows.
        w: number of columns.

    Returns:
        Tensor of shape (h, w, 2).
    """
    def _normalize(n: int) -> torch.Tensor:
        # Evenly spaced coordinates of a length-n axis, rescaled to [-1, 1].
        coords = torch.arange(n, dtype=torch.float)
        if n == 1:
            return torch.zeros_like(coords)
        return 2 * (coords / (n - 1)) - 1

    xs = _normalize(w)
    ys = _normalize(h)
    xx = xs.view(1, -1).repeat(h, 1)  # x constant down each column
    yy = ys.view(-1, 1).repeat(1, w)  # y constant along each row
    return torch.stack([xx, yy], dim=2)


# Dedicated names so this scratch cell does not clobber `x`, `h`, `w` from above.
GRID_H, GRID_W = 4, 4
grid_hw2 = make_normalized_grid(GRID_H, GRID_W)
print(grid_hw2.shape)

# Add singleton batch and key point axes: (1, 1, h, w, 2).
meshed = grid_hw2.view(1, 1, GRID_H, GRID_W, 2)
meshed.shape

torch.Size([4, 4, 2])


torch.Size([1, 1, 4, 4, 2])