In [1]:
import tensorflow as tf
import os
import sys
import data_generation
import networks
import scipy.io as sio
import param
import util
import truncated_vgg
from keras.backend.tensorflow_backend import set_session
from keras.optimizers import Adam

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
def train(model_name, gpu_id):
    """Train the pose-warping network and write periodic checkpoints.

    The network comes from ``networks.network_posewarp(params)``. It is
    compiled with Adam (lr=1e-4) and ``networks.vgg_loss`` computed through a
    frozen truncated VGG model. It is then trained for
    ``params['n_training_iter']`` steps, with one ``train_on_batch`` call per
    batch drawn from the training feed.

    Args:
        model_name: Sub-directory of ``params['model_save_dir']`` that
            receives the checkpoints. It is created (including missing
            parent directories) if needed.
        gpu_id: Exported as ``CUDA_VISIBLE_DEVICES`` to choose which GPU(s)
            TensorFlow may use.

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES`` and installs a new Keras/TensorFlow
        session. Writes ``<step>.h5`` into the model directory every
        ``params['model_save_interval']`` steps (skipping step 0). Also writes
        it once after the final step if that step was not already saved.
    """
    # Choose the GPU first: CUDA_VISIBLE_DEVICES is only honoured if it is set
    # before a CUDA context is created.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    params = param.get_general_params()

    network_dir = os.path.join(params['model_save_dir'], model_name)
    # makedirs (unlike mkdir) also creates a missing model_save_dir.
    if not os.path.isdir(network_dir):
        os.makedirs(network_dir)

    train_feed = data_generation.create_feed(params, params['data_dir'], 'train')

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    set_session(tf.Session(config=config))

    # Frozen VGG feature extractor, used only inside the perceptual loss.
    vgg_model = truncated_vgg.vgg_norm()
    networks.make_trainable(vgg_model, False)
    # Relative path: resolved against the current working directory (../data).
    response_weights = sio.loadmat('../data/vgg_activation_distribution_train.mat')
    model = networks.network_posewarp(params)
    # NOTE(review): the meaning of the third vgg_loss argument (12) is not
    # visible here; confirm against networks.vgg_loss and consider naming it.
    model.compile(optimizer=Adam(lr=1e-4),
                  loss=[networks.vgg_loss(vgg_model, response_weights, 12)])

    n_iters = params['n_training_iter']
    save_interval = params['model_save_interval']

    def _checkpoint(step):
        """Save the full model as <step>.h5 inside network_dir."""
        model.save(os.path.join(network_dir, str(step) + '.h5'))

    for step in range(n_iters):
        x, y = next(train_feed)
        train_loss = model.train_on_batch(x, y)
        util.printProgress(step, 0, train_loss)

        if step > 0 and step % save_interval == 0:
            _checkpoint(step)

    # Keep the last step's weights even when it does not fall on the save
    # interval; otherwise up to save_interval - 1 steps of training are lost.
    last_step = n_iters - 1
    if n_iters > 0 and not (last_step > 0 and last_step % save_interval == 0):
        _checkpoint(last_step)

In [3]:
# Load the shared configuration dict (paths, image size, training schedule).
params = param.get_general_params()

In [4]:
# Display the configuration.
# NOTE(review): 'data_dir' is still the '/path/to/dataset' placeholder and must
# point at the real dataset before training.
params

{'IMG_HEIGHT': 256,
 'IMG_WIDTH': 256,
 'obj_scale_factor': 1.14,
 'scale_max': 1.05,
 'scale_min': 0.9,
 'max_rotate_degree': 5,
 'max_sat_factor': 0.05,
 'max_px_shift': 10,
 'posemap_downsample': 2,
 'sigma_joint': 1.75,
 'n_joints': 14,
 'n_limbs': 10,
 'limbs': [[0, 1],
  [2, 3],
  [3, 4],
  [5, 6],
  [6, 7],
  [8, 9],
  [9, 10],
  [11, 12],
  [12, 13],
  [2, 5, 8, 11]],
 'n_training_iter': 200000,
 'test_interval': 500,
 'model_save_interval': 1000,
 'project_dir': '/home/jarvislam1999/posewarp-cvpr2018',
 'model_save_dir': '/home/jarvislam1999/posewarp-cvpr2018/models',
 'data_dir': '/path/to/dataset',
 'batch_size': 4}

/home/jarvislam1999/posewarp-cvpr2018/code


In [24]:
# NOTE(review): this is a checkpoint file name (.h5), whereas train() treats
# model_name as a sub-directory of params['model_save_dir'].
model_name = 'vgg_100000.h5'

In [25]:
# Path of the checkpoint under model_save_dir. Because model_name is a .h5 file
# name here, this is a file path rather than a model directory as in train().
network_dir = os.path.join(params['model_save_dir'], model_name)

In [26]:
# Display the resulting path; it ends in .h5, so it names a file, not a directory.
network_dir

'/home/jarvislam1999/posewarp-cvpr2018/models/vgg_100000.h5'

False

In [35]:
# Build the training batch generator; train() draws batches from it via next().
# NOTE(review): params['data_dir'] was the '/path/to/dataset' placeholder when
# this cell ran, so the feed cannot yield real data until it is configured.
train_feed = data_generation.create_feed(params, params['data_dir'], 'train')

In [30]:
# The feed is a Python generator; its body does not run until next() is called.
train_feed

<generator object warp_example_generator at 0x7f3d2602dca8>