In [None]:
%matplotlib inline

from __future__ import division
import numpy as np
import menpo.io as mio
from menpo.image import Image
from menpo.feature import greyscale, centralize, normalize_norm, normalize_std
from menpo.landmark import labeller, ibug_face_66
from menpo.visualize import visualize_images, print_dynamic, progress_bar_str
from menpofit.visualize import visualize_fitting_results
from alaborticcv2015.alignment import PartsAAMBuilder, PartsAAMFitter
from alaborticcv2015.alignment.result import SerializableResult

# ICA-Network AAMs Parameter Optimization

Load the LFPW training images (cropped around their landmarks, rescaled, and labelled with the 66-point iBUG markup):

In [None]:
training_images = []
# Import the LFPW training set and preprocess every image before keeping it.
lfpw_trainset = mio.import_images('/data/PhD/DataBases/faces/lfpw/trainset/',
                                  max_images=None, verbose=True)
for image in lfpw_trainset:
    # Crop (in place) around the landmarks with a 0.5 proportional margin.
    image.crop_to_landmarks_proportion_inplace(0.5)
    # rescale_... returns a new image rather than modifying `image`.
    rescaled = image.rescale_landmarks_to_diagonal_range(200)
    # Derive the 'ibug_face_66' landmark group from the 'PTS' annotations;
    # the fitting cell later reads landmarks['ibug_face_66'].
    labeller(rescaled, 'PTS', ibug_face_66)
    training_images.append(rescaled)

In [None]:
# Visual sanity check of the preprocessed training images.
visualize_images(training_images)

Load the LFPW test images, preprocessed exactly like the training images:

In [None]:
test_images = []
# Import the LFPW test set and apply the same preprocessing as for training.
lfpw_testset = mio.import_images('/data/PhD/DataBases/faces/lfpw/testset/',
                                 max_images=None, verbose=True)
for image in lfpw_testset:
    # Crop (in place) around the landmarks with a 0.5 proportional margin.
    image.crop_to_landmarks_proportion_inplace(0.5)
    # rescale_... returns a new image rather than modifying `image`.
    rescaled = image.rescale_landmarks_to_diagonal_range(200)
    # Derive the 'ibug_face_66' landmark group (used as ground truth when
    # fitting) from the 'PTS' annotations.
    labeller(rescaled, 'PTS', ibug_face_66)
    test_images.append(rescaled)

In [None]:
# Visual sanity check of the preprocessed test images.
visualize_images(test_images)

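Define the grid of pre-trained ICA network configurations to evaluate (number of filters × size; each network is loaded from its `ica_net_<n_filters>_<size>.pkl.gz` file during the sweep) and the sampling mask used by the fitter: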

In [None]:
# Grid of pre-trained ICA network configurations to sweep over. Each
# (n_filters, size) pair selects the file ica_net_<n_filters>_<size>.pkl.gz;
# `shapes` is presumably the filter/patch size -- TODO confirm.
n_filters = range(4, 24, 2)
shapes = list(range(5, 21, 2))

# 17x17 boolean sampling mask with True at rows/cols 2, 6, 10, 14 (a 4x4
# grid of active samples). Use the builtin `bool`: the `np.bool` alias was
# deprecated in NumPy 1.20 and removed in 1.24.
sampling_mask = np.zeros((17, 17), dtype=bool)
sampling_mask[2::4, 2::4] = True
Image(sampling_mask).view();

In [None]:
# Parameter sweep: for every pre-trained ICA network configuration
# (n_filters `nf` x size `s`), build a parts-based AAM on the network's
# responses, fit it to every test image from a perturbed ground-truth
# initialisation, pickle the per-image results and print the mean / median /
# std of the final fitting errors.

MODELS_DIR = '/data/PhD/Models/alaborticcv2015/'


def _config_file(prefix, nf, s):
    """Return the ``.pkl.gz`` path for network configuration ``(nf, s)``.

    ``prefix`` is the file-name stem: ``'ica_net'`` for the pre-trained
    network, ``'results_ica_net'`` for its fitting results.
    """
    return '{}{}_{}_{}.pkl.gz'.format(MODELS_DIR, prefix, nf, s)


for nf in n_filters:
    for s in shapes:
        header = '- Network {}, {}: '.format(nf, s)

        print_dynamic('{}Building AAM {}'.format(
            header, progress_bar_str(0, show_bar=True)))

        net = mio.import_pickle(_config_file('ica_net', nf, s))

        def network_features(pixels):
            # Convert to greyscale, then use the ICA network's response as
            # the AAM appearance features.
            pixels = greyscale(pixels)
            pixels = net.network_response(pixels)
            return pixels

        builder = PartsAAMBuilder(features=network_features,
                                  diagonal=100,
                                  norm_func=None,
                                  max_appearance_components=100)
        aam = builder.build(training_images, group='ibug_face_66')

        fitter = PartsAAMFitter(aam,
                                n_shape=[3, 12],
                                n_appearance=100,
                                sampling_mask=sampling_mask)

        print_dynamic('{}Building AAM {}'.format(
            header, progress_bar_str(1, show_bar=True)))

        # Re-seed per configuration so every network is evaluated from the
        # same random perturbations (assumes perturb_shape draws from
        # NumPy's global RNG -- TODO confirm).
        np.random.seed(seed=1)
        serializable_results = []
        n_test = len(test_images)
        for j, image in enumerate(test_images):
            print_dynamic('{}Fitting images {}'.format(
                header, progress_bar_str(j / n_test, show_bar=True)))
            gt_shape = image.landmarks['ibug_face_66'].lms
            # Must NOT be called `s`: the network size `s` is still needed
            # below to name the results file.
            initial_shape = fitter.perturb_shape(gt_shape, noise_std=0.05)
            fitting_result = fitter.fit(image, initial_shape,
                                        gt_shape=gt_shape, max_iters=20,
                                        map_inference=False)
            serializable_results.append(SerializableResult(
                fitting_result.image, fitting_result.shapes(),
                fitting_result.n_iters, fitting_result.gt_shape))

        mio.export_pickle(serializable_results,
                          _config_file('results_ica_net', nf, s),
                          overwrite=True)

        errors = [sr.final_error() for sr in serializable_results]

        mean = np.mean(errors)
        median = np.median(errors)
        std = np.std(errors)

        print_dynamic(header +
                      '\tMean={0:.4f}  Median={1:.4f}  Std={2:.4f}\n'
                      .format(mean, median, std))