In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import scipy
import h5py
import skimage
import glob
from skimage import io,transform 
from skimage.io import imread,imsave
from collections import OrderedDict

# All folders are resolved against the directory the notebook is started from.
project_dir = os.getcwd()

# Image folders. Each path keeps its trailing slash so that file names can be
# appended to it by plain string concatenation.
art_dir, photo_dir, guide_dir, out_dir = (
    project_dir + '/images/' + folder + '/'
    for folder in ('styles', 'content', 'guides', 'outputs')
)

model_dir = project_dir + '/models/trained/'
data_dir = project_dir + '/data/'
tmp_dir = project_dir + '/tmp/'

# Missing model/data folders are only reported (they are not created here);
# the messages name the script that provides them.
if not os.path.isdir(model_dir):
    print('Model dir missing, to get the pretrained models execute the download_leon_models.sh script')
if not os.path.isdir(data_dir):
    print('Data dir missing, for training networks you first need to create a dataset using the make_style_dataset.py script')

# The scratch folder, in contrast, is created when absent.
if not os.path.isdir(tmp_dir):
    os.makedirs(tmp_dir)

Model dir missing, to get the pretrained models execute the download_leon_models.sh script
Data dir missing, for training networks you first need to create a dataset using the make_style_dataset.py script


## Train luminance network

For luminance training, one first needs to create a luminance dataset using the '--lum' flag in the make_style_dataset.py script.

The code below assumes that the dataset is saved under "fast-neural-style/data/" (the `data/` folder inside the directory the notebook is run from) and is named "ms-coco-{data_size}-lum.h5".

In [None]:
# Training parameters for the luminance network.

# Dataset file: <data_dir>ms-coco-<data_size>-lum.h5
data_size = 256
data_name = '{}-lum'.format(data_size)
h5_file = '{}ms-coco-{}.h5'.format(data_dir, data_name)

# Architecture string handed to train.lua; an alternative is kept for reference:
# arch = 'c9s1-32,d64,d128,R128,R128,R128,R128,R128,u64,u32,c9s1-3'
arch = 'c9s1-16,d32,d64,R64,R64,R64,R64,R64,u32,u16,c9s1-3'

# Loss network and style target
loss_network = 'models/vgg16.t7'
style_image = 'candy'
style_image_size = 256
style_weights = '5.0'  # kept as a string: it is also embedded in checkpoint_name

# Schedule and device
num_iterations = 40000
checkpoint_every = 100
gpu = 0

# Checkpoint path prefix inside model_dir, built from style, dataset and style weight.
checkpoint_name = '{}{}_{}_guidance_sw_{}'.format(model_dir, style_image, data_name, style_weights)

In [None]:
context = {
    'arch': arch,
    'h5_file': h5_file,
    'loss_network': loss_network,
    'style_image': art_dir + style_image + '.jpg',
    'style_image_size': style_image_size,
    'style_weights': style_weights,
    'num_iterations': num_iterations,
    'checkpoint_name': checkpoint_name,
    'checkpoint_every': checkpoint_every,
    'gpu': gpu
}

template = (
            '#!/bin/bash\n' +
            'time /usr/local/torch/install/bin/th train.lua ' + 
            '-arch {arch} ' +
            '-h5_file {h5_file} ' + 
            '-loss_network {loss_network} ' + 
            '-style_image {style_image} ' + 
            '-style_image_size {style_image_size} ' + 
            '-style_weights {style_weights} ' + 
            '-checkpoint_name {checkpoint_name} ' + 
            '-checkpoint_every {checkpoint_every} ' + 
            '-style_target_type gram ' + 
            '-gpu {gpu} '
           )

script_name = project_dir + '/train_fast.sh'
with open(script_name, 'w') as script:
    script.write(template.format(**context))
os.chmod(script_name, 0o755)
!{script_name}

## Train spatially guided network

For training the network, one first needs to create a dataset using the make_style_dataset.py script.

The code below assumes that the dataset is saved under "fast-neural-style/data/" (the `data/` folder inside the directory the notebook is run from) and is named "ms-coco-{data_size}.h5".

In [None]:
# Training parameters for the spatially guided network.

# Dataset file: <data_dir>ms-coco-<data_size>.h5
data_size = 256
data_name = '{}'.format(data_size)
h5_file = '{}ms-coco-{}.h5'.format(data_dir, data_name)

# Architecture string handed to train.lua; an alternative is kept for reference:
# arch = 'c9s1-32,d64,d128,R128,R128,R128,R128,R128,u64,u32,c9s1-3'
arch = 'c9s1-16,d32,d64,R64,R64,R64,R64,R64,u32,u16,c9s1-3'

# Loss network and style target
loss_network = 'models/vgg16.t7'
style_image = 'candy_over_feathers'
style_image_size = 512
style_weights = '5.0'  # kept as a string: it is also embedded in checkpoint_name

# Schedule and device
num_iterations = 40000
checkpoint_every = 100
gpu = 0

# Checkpoint path prefix inside model_dir, built from style, dataset and style weight.
checkpoint_name = '{}{}_{}_guidance_sw_{}'.format(model_dir, style_image, data_name, style_weights)

In [None]:
#define guidance channels for the style image
# One guide image per style region; file names are derived from the style image
# name plus a region suffix and are looked up in guide_dir.
guide_names = [style_image.replace('.jpg','')+'_candy.jpg',style_image.replace('.jpg','')+'_feathers.jpg']
guides = []
for name in guide_names:
    # Only the first channel of each guide image is used.
    # NOTE(review): the 3-index slice assumes the guide files load as
    # multi-channel (H, W, C) arrays - confirm for grayscale guides.
    guides.append(imread(guide_dir + name)[:,:,0])
# Stack the 2-D guides along a new last axis, then move that axis to the front,
# giving an array of shape (number of guides, height, width).
guides = np.dstack(guides).transpose(2,0,1)
# save guides
guides_file_name = tmp_dir + 'trainguides.hdf5'
# The context manager closes the HDF5 file even if writing the dataset fails.
with h5py.File(guides_file_name, 'w') as f:
    f.create_dataset('guides', data=guides)

In [None]:
context = {
    'arch': arch,
    'h5_file': h5_file,
    'loss_network': loss_network,
    'style_image': art_dir + style_image + '.jpg',
    'style_image_guides': guides_file_name,
    'style_image_size': style_image_size,
    'style_weights': style_weights,
    'num_iterations': num_iterations,
    'checkpoint_name': checkpoint_name,
    'checkpoint_every': checkpoint_every,
    'gpu': gpu
}

template = (
            '#!/bin/bash\n' +
            'time /usr/local/torch/install/bin/th train.lua ' + 
            '-arch {arch} ' +
            '-h5_file {h5_file} ' + 
            '-loss_network {loss_network} ' + 
            '-style_image {style_image} ' + 
            '-style_image_size {style_image_size} ' + 
            '-style_weights {style_weights} ' + 
            '-checkpoint_name {checkpoint_name} ' + 
            '-checkpoint_every {checkpoint_every} ' + 
            '-style_target_type guided_gram ' + 
            '-style_image_guides {style_image_guides} ' + 
            '-gpu {gpu} '
           )

script_name = project_dir + '/train_fast.sh'
with open(script_name, 'w') as script:
    script.write(template.format(**context))
os.chmod(script_name, 0o755)
!{script_name}

In [None]:
#show training log
# The log is read from the JSON file that shares its path prefix with the
# checkpoint (checkpoint_name + '.json').
train_log = checkpoint_name + '.json'
with open(train_log) as json_data:
    log = json.load(json_data)

# Plot the recorded validation losses. The x axis is simply the position of
# each entry in the list, not the training iteration.
plt.plot(log['val_loss_history'])
plt.xlabel('Entry in val_loss_history')
plt.ylabel('Validation loss')
plt.title('Validation loss history')
plt.show()  # render the figure without echoing the Line2D repr