In [None]:
# import general classes
%matplotlib inline
import os

import matplotlib.pyplot as plt

# import model classes for training and plotting
# class that trains all the models
from train_model.train_model import train_model
# class that visualizes the learned movemes
from utilities.utilities import utilities

In [None]:
# path variables
## DATASET_PATH: path to the dataset with the annotations
DATASET_PATH = '../inputs/CO_LSP_train2016.json'
## IMAGES_PATH: path of the actual images of LSP.
## The default below is machine-specific; set the LSP_IMAGES_PATH environment
## variable to point at your local copy instead of editing this cell
## (keep the trailing '/' like the default does).
IMAGES_PATH = os.environ.get('LSP_IMAGES_PATH',
                             '/home/mronchi/Datasets/lsp/lsp_dataset/images/')
## SAVE_PATH: path of the folder to save the resulting model data
SAVE_PATH = '../data'

In [None]:
# set parameters for model
# Each choice-valued setting is paired with a tuple of the options documented
# for it; the assertions at the end of the cell reject typos here instead of
# letting them fail deep inside training.

# flags to save the data used during training and the partial models
save_dataset        = True
save_partial_models = False

# number of latent factors to learn
num_factors = 10

# activities to include in the analysis (any subset of ACTIVITY_OPTIONS)
ACTIVITY_OPTIONS = ('athletics', 'badminton', 'baseball', 'gymnastics',
                    'parkour', 'soccer', 'tennis', 'volleyball', 'other')
activities = ['athletics', 'badminton', 'soccer', 'tennis', 'volleyball', 'baseball']

# model that should be used to learn the basis pose factorization
MODEL_TYPE_OPTIONS = ('svd', 'bucketed_svd_2d', 'bucketed_svd_3d', 'lfa_2d', 'lfa_3d')
model_type = 'lfa_3d'

# initialization for the U and V matrices in the lfa_3d model
INIT_TYPE_OPTIONS = ('random', 'svd')
init_type = 'random'

# type of angle annotations for initializing the angle of view of each pose
BUCKETING_METRIC_OPTIONS = ('random', 'heuristic', 'coarse', 'gt')
bucketing_metric = 'gt'

# number of pose clusters to use for discretizing the angles of view
num_buckets = 8

# objective function for the stochastic gradient descent
OBJECTIVE_F_TYPE_OPTIONS = ('l1_reg', 'l2_reg', 'l2_l1_ista_reg')
objective_f_type = 'l2_l1_ista_reg'

# fail fast on invalid settings
unknown_activities = set(activities) - set(ACTIVITY_OPTIONS)
assert not unknown_activities, 'unknown activities: %s' % sorted(unknown_activities)
assert model_type in MODEL_TYPE_OPTIONS, 'unknown model_type: %r' % (model_type,)
assert init_type in INIT_TYPE_OPTIONS, 'unknown init_type: %r' % (init_type,)
assert bucketing_metric in BUCKETING_METRIC_OPTIONS, \
    'unknown bucketing_metric: %r' % (bucketing_metric,)
assert objective_f_type in OBJECTIVE_F_TYPE_OPTIONS, \
    'unknown objective_f_type: %r' % (objective_f_type,)
assert num_factors > 0, 'num_factors must be positive'
assert num_buckets > 0, 'num_buckets must be positive'

In [None]:
# set hyper parameters for model
# Every optimization hyperparameter lives in this one dict, which the training
# cell hands to train_model through its `hyper_params` argument.
# NOTE(review): UV_batch_step / theta_batch_step are floats (1e3) while
# max_iter is explicitly cast to int -- confirm train_model accepts floats there.
hyper_params = {
    'l_rate_U':           1e-4,
    'l_rate_V':           1e-4,
    'm_rate_U':           1e-5,
    'm_rate_V':           1e-5,
    'l_rate_theta':       1e-5,
    'm_rate_theta':       1e-6,
    'positive_V_flag':    True,
    'rmse_tolerance':     1e-5,
    'lr_decay':           0.5,
    'obj_func_tolerance': 1e4,
    'UV_batch_step':      1e3,
    'theta_batch_step':   1e3,
    'lr_bound':           1e-6,
    'max_iter':           int(1e7),
    'error_window':       5,
}

In [None]:
# Model training (or loading of a previously saved model)

# build the trainer from the configuration cells above; the same object is
# used below either to train from scratch or to load a saved model
train_model_obj = train_model(
                    # path to the dataset json file and images
                    dataset_path=DATASET_PATH, images_path=IMAGES_PATH,
                    # path at which models are saved and flags
                    save_path=SAVE_PATH, save_dataset=save_dataset,
                    save_partial_models=save_partial_models,
                    # number of latent factors
                    num_factors=num_factors,
                    # activities to include in the analysis
                    activity_list=activities,
                    # model type trained
                    model_type=model_type,
                    # initialization for the lfa_3d model
                    init_type=init_type,
                    # type of angle bucketing ('random', 'heuristic', 'coarse' or 'gt')
                    bucketing_metric=bucketing_metric,
                    # number of buckets to cluster poses
                    num_buckets=num_buckets,
                    # objective function type for the SGD procedures
                    objective_f_type=objective_f_type,
                    # a dictionary containing all the optimization hyperparams
                    hyper_params=hyper_params,
                    # provide an input matrix for U (None: no matrix provided)
                    U_test=None
                )

# set to True to train from scratch or with different parameters;
# otherwise a previously saved model is loaded through train_model_obj.load()
train = False
if train:
    train_model_obj.train()
else:
    # NOTE(review): fill in the timestamp of a previously saved model; the
    # empty placeholder is passed to load() as-is -- confirm how
    # train_model.load handles an empty timestamp before relying on it
    pretrained_timestamp = ''
    train_model_obj.load(pretrained_timestamp)

In [None]:
# plot and save the learned movemes with the utilities class
# utilities(...) arguments:
#  - train_model_obj: the trained (or loaded) model object
#  - colors: list of matplotlib color codes used to color the skeleton
#  - basis_coeff: a multiplying factor for moveme strength
#                 (not used for lfa_3d or lfa_2d)
# plot_movemes() produces no output value (None); it saves the movemes at the
# path specified in train_model_obj
colors = ['g', 'g', 'y', 'y',
          'r', 'b', 'r', 'b',
          'y', 'y', 'm', 'c',
          'm', 'c']
utility_obj = utilities(train_model_obj, colors, basis_coeff=1)
utility_obj.plot_movemes()