# Training for Stylizer module
This notebook will handle the training the stylizer module.

## Load Pretrained model

In [1]:
import matplotlib

matplotlib.use('Agg')

import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy

from frames_dataset import FramesDataset

from modules.generator import OcclusionAwareGenerator # LCH: refer here for generator
from modules.discriminator import MultiScaleDiscriminator # LCH: refer here for discriminator
from modules.keypoint_detector import KPDetector # LCH: refer here for key point detector

import torch

from train import train # LCH: For training process, everything in this module
from reconstruction import reconstruction
from animate import animate

In [3]:
config_path = "config/anim-256.yaml"
with open(config_path) as f:
        # read in the config file
        config = yaml.load(f) # config file contains code directions, including training details

checkpoint_path = "pre_trains/vox-cpk.pth.tar"
log_dir = "MyLog/"
if not os.path.exists(log_dir):
    os.mkdir(log_dir)
# Copy the config file (*.yaml) into the logging path
if not os.path.exists(os.path.join(log_dir, os.path.basename(config_path))):
    copy(config_path, log_dir)

# initialize generator
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
# initialize discriminator
discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
                                            **config['model_params']['common_params'])
# initialize kp detector
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])

# If GPU Available, adapt to it
if torch.cuda.is_available():
    generator.to(0)
    discriminator.to(0)
    kp_detector.to(0)

  after removing the cwd from sys.path.


In [6]:
# load in the pretrained modules
from logger import Logger

train_params = config['train_params']

if not torch.cuda.is_available():
    # remember to adapt to cpu version
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
else:
    checkpoint = torch.load(checkpoint_path)

generator.load_state_dict(checkpoint['generator'])
discriminator.load_state_dict(checkpoint['discriminator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])

## Dataset Preparation

In [15]:
from frames_dataset import DatasetRepeater
from torch.utils.data import DataLoader


# load original target data
frame_dataset = FramesDataset(is_train=True, **config['dataset_params'])
print("Dataset size: {}, repeat number: {}".format(len(dataset), config['train_params']['num_repeats']))

if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
    # Augment the dataset according to "num_reapeat"
    frame_dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    print("Repeated Dataset size: {}, repeat number: {}".format(len(dataset), config['train_params']['num_repeats']))

dataloader = DataLoader(frame_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=2, drop_last=True)

Use predefined train-test split.
Dataset size: 1080, repeat number: 4
Repeated Dataset size: 1080, repeat number: 4
