## Import

In [None]:
%matplotlib inline
import numpy as np
import os

import menpo.io as mio
from menpo.feature import fast_dsift, double_igo, no_op, dsift
from menpo.visualize import print_dynamic, print_progress, visualize_images, visualize_pointclouds

from menpofit.fittingresult import compute_error
from menpofit.visualize import visualize_shape_model, visualize_fitting_result, plot_ced

from alabortijcv2015.aam import PartsAAMBuilder, PartsAAMFitter
from alabortijcv2015.aam.algorithm import SIC, BSC
from alabortijcv2015.utils import pickle_load, pickle_dump
from alabortijcv2015.result import SerializableResult

## Visualize data

In [None]:
# Upper bound on the number of images imported for this preview.
max_images = 150

# Materialise the image paths so that only the first `max_images` get imported.
all_paths = list(mio.image_paths('/vol/atlas/databases/body/FashionPose/Trainset/rescaled_img_train/'))

# Directory holding one initial-shape `.ljson` file per image (same file stem).
path_initial = '/vol/atlas/databases/body/FashionPose/Trainset/InitialShapes/'
images = []
for path in print_progress(all_paths[:max_images]):
    im = mio.import_image(path)
    # Convert 3-channel images to greyscale.
    if im.n_channels == 3:
        im = im.as_greyscale(mode='luminosity')

    # Build the landmark file name from the image stem. splitext copes with
    # extensions of any length (e.g. '.jpeg'), unlike slicing off 3 characters.
    stem = os.path.splitext(os.path.basename(str(path)))[0]
    sh = mio.import_landmark_file(path_initial + stem + '.ljson')
    # Attach the initial shape to the image under the 'CNN' landmark group.
    im.landmarks['CNN'] = sh
    images.append(im)

In [None]:
# Visual sanity check of the images loaded in the previous cell.
visualize_images(images)

## Plot Initial Shapes Curve

In [None]:
# Pair every initial (CNN) shape with the landmark file of the same name that
# lives next to the training images.
initial_dir = '/vol/atlas/databases/body/FashionPose/Trainset/InitialShapes/'
groundtruth_dir = '/vol/atlas/databases/body/FashionPose/Trainset/rescaled_img_train/'

cnn_shapes, gt_shapes = [], []
for cnn_sh in mio.import_landmark_files(initial_dir, verbose=True):
    cnn_shapes.append(cnn_sh)
    # Both directories use identical file names, so the basename is the key.
    file_name = os.path.basename(str(cnn_sh.path))
    gt_shapes.append(mio.import_landmark_file(groundtruth_dir + file_name))

In [None]:
# Error metric to use: 'me_norm', 'me' or 'rmse'.
error_type = 'me_norm'

# One error value per (initial shape, ground-truth shape) pair.
errors = []
for cnn_sh, gt_sh in zip(cnn_shapes, gt_shapes):
    errors.append(compute_error(cnn_sh.lms, gt_sh.lms, error_type=error_type))

In [None]:
# Plot the initial-shape errors as a single curve (plot_ced receives a list of
# error lists, hence the one-element wrapper).
plot_ced([errors])

In [None]:
# Summary statistics of the initial-shape errors.
cnn_stats = (np.mean(errors), np.median(errors), np.std(errors))
print("CNN: mean: {:1.4f}, median: {:1.4f}, std: {:1.4f}".format(*cnn_stats))

## Load Data

In [None]:
# Number of leading images used to build the model.
n_training_images = 500
# Number of images that follow the training images and are used for fitting.
n_testing_images = 20

In [None]:
# Import the training and testing images in one pass; any 3-channel image is
# converted to greyscale, everything else is kept as imported.
n_images = n_training_images + n_testing_images
images = [
    im.as_greyscale(mode='luminosity') if im.n_channels == 3 else im
    for im in mio.import_images('/vol/atlas/databases/body/FashionPose/Trainset/rescaled_img_train/',
                                verbose=True, max_images=n_images)
]

In [None]:
# Visual sanity check of the training/testing images loaded in the previous cell.
visualize_images(images)

In [None]:
# Gather the 'LJSON' landmarks of every image and display them.
pointclouds = []
for im in images:
    pointclouds.append(im.landmarks['LJSON'].lms)
visualize_pointclouds(pointclouds)

## Train Model

In [None]:
# Settings forwarded to PartsAAMBuilder when the model is built.
# `patch_shape` is passed as `parts_shape` and also fixes the shape of the
# sampling mask created for the fitter.
patch_shape = (24, 24)
# Feature function; its __name__ is also embedded in the pickle file names.
features = fast_dsift
diagonal = 150
normalize_parts = no_op
scales = (1, .5)
max_shape_components = 50
max_appearance_components = 200

# Output directory for pickled models/results. File names are appended by plain
# string concatenation, so the trailing slash is required.
save_path = '/vol/atlas/homes/mej114/'

In [None]:
# Configure the builder first, then train it on the leading
# `n_training_images` images.
builder = PartsAAMBuilder(
    parts_shape=patch_shape,
    features=features,
    diagonal=diagonal,
    normalize_parts=normalize_parts,
    scales=scales,
    max_shape_components=max_shape_components,
    max_appearance_components=max_appearance_components,
)
aam = builder.build(images[:n_training_images], verbose=True)

In [None]:
# Persist the model as '<save_path><model class name>_<feature name>.pickle'.
aam_type = aam.__class__.__name__
model_file = '{}{}_{}.pickle'.format(save_path, aam_type, features.__name__)
pickle_dump(aam, model_file)

In [None]:
# Inspect the shape models of the trained AAM.
visualize_shape_model(aam.shape_models)

In [None]:
# Plot the eigenvalue ratios of the appearance model at index 1.
aam.appearance_models[1].plot_eigenvalues_ratio()

In [None]:
# Plot the cumulative eigenvalue ratios of the appearance model at index 1.
aam.appearance_models[1].plot_eigenvalues_cumulative_ratio()

## Fit Model

In [None]:
# Reload a previously trained model from `save_path`.
# NOTE(review): the file name is hard-coded. It only matches the file written
# by the training cells if the model class name is 'PartsAAM' and `features`
# is fast_dsift -- keep it in sync when either changes.
aam = pickle_load(save_path + 'PartsAAM_fast_dsift.pickle')

In [None]:
# Fitting algorithm class (BSC is the other imported option).
algorithm_cls = SIC
# Two entries each -- presumably one per model scale; confirm against
# PartsAAMFitter.
n_shape = [10, 20]
n_appearance = [30, 50]
# Stride of the sampling mask built for the fitter (1 selects every position).
sampling_step = 1

# Arguments passed to every `fitter.fit` call.
max_iters = 50
prior = False

In [None]:
# Boolean mask with the shape of one patch; True marks the positions selected
# by `sampling_step` along both axes.
# The builtin `bool` is used as dtype because the `np.bool` alias was
# deprecated in NumPy 1.20 and removed in NumPy 1.24.
sampling_mask = np.zeros(patch_shape, dtype=bool)
sampling_mask[::sampling_step, ::sampling_step] = True

fitter = PartsAAMFitter(aam, algorithm_cls=algorithm_cls, n_shape=n_shape,
                        n_appearance=n_appearance, sampling_mask=sampling_mask)

In [None]:
# Directory holding one initial-shape `.ljson` file per image (same file stem).
path_initial = '/vol/atlas/databases/body/FashionPose/Trainset/InitialShapes/'

# Held-out images: the `n_testing_images` images that directly follow the
# training images. The slice starts at `n_training_images` (not `+ 1`), since
# the model is built from `images[:n_training_images]`; starting one later
# would silently drop the first test image.
test_images = images[n_training_images:n_training_images + n_testing_images]

fitter_results = []
for j, im in enumerate(test_images):
    # Get groundtruth shape
    groundtruth_shape = im.landmarks['LJSON'].lms
    # Get initial shape; splitext copes with image extensions of any length
    # (e.g. '.jpeg'), unlike slicing off the last 3 characters.
    stem = os.path.splitext(os.path.basename(str(im.path)))[0]
    initial_shape = mio.import_landmark_file(path_initial + stem + '.ljson')
    # Fit
    fr = fitter.fit(im, initial_shape.lms, gt_shape=groundtruth_shape,
                    max_iters=max_iters, prior=prior)
    # Append fitting result
    fr.downscale = 0.5
    fitter_results.append(fr)
    # Print progress as "<zero-based index>/<last index>"
    print_dynamic("Image: {}/{}, Error: {:1.4f} -> {:1.4f}".format(j, len(test_images) - 1,
                                                                   fr.initial_error(), fr.final_error()))

In [None]:
# Wrap every fitting result in a picklable container.
results = [SerializableResult('none', fr.shapes(), fr.n_iters, 'FastSIC', fr.gt_shape)
           for fr in fitter_results]

# Derive the model name from the model itself so this cell also works when the
# model was reloaded from disk and the training cells were skipped.
aam_type = aam.__class__.__name__
# The original file name used `noise_std`, which is never defined in this
# notebook and raised a NameError. The fits are initialised from the imported
# initial (CNN) shapes rather than from noisy perturbations, so the suffix
# names that initialisation instead.
pickle_dump(results, save_path + aam_type + '_' + features.__name__ + '_cnn_init.pickle')

## Visualize results

In [None]:
# Display the loaded images. Note this shows every image in `images`, not only
# the subset that was fitted.
visualize_images(images)

In [None]:
# Inspect the fitting results collected for the test images.
visualize_fitting_result(fitter_results)

In [None]:
# Error metric to use: 'me_norm', 'me' or 'rmse'.
error_type = 'me_norm'

# Per-image error before fitting (initial shape) and after fitting.
initial_errors, final_errors = [], []
for fr in fitter_results:
    initial_errors.append(fr.initial_error(error_type=error_type))
    final_errors.append(fr.final_error(error_type=error_type))

# One curve per error list, labelled in the same order.
plot_ced([initial_errors, final_errors], legend_entries=['CNN', 'AAM'])

In [None]:
# Tabulate mean / median / standard deviation of the errors before and after
# fitting; each row pairs its format string with the errors it summarises.
print("               |  mean  | median |  std  ")
for row_format, row_errors in (
        ("Initialization | {:1.4f} | {:1.4f} | {:1.4f}", initial_errors),
        ("Fitting result | {:1.4f} | {:1.4f} | {:1.4f}", final_errors)):
    print(row_format.format(np.mean(row_errors),
                            np.median(row_errors),
                            np.std(row_errors)))