# AAM Demo Script


## Initialization and loading data

In [None]:
import numpy as np
import scipy.io as sio
import scipy
import matplotlib.pyplot as plt
import math
%reload_ext autoreload
%autoreload 2

# Experimental image set to load (set 7 is the commented-out alternative).
IMG_SET_ID = 4 # 7
# Column of the colour PSF table that is replaced by the separately stored
# NIR measurements below.
NIR_CHANNEL = 3

path = "data/"
dist_name = path + "distances_" + str(IMG_SET_ID) + ".mat"
PSF_name  = path + "GaussStd2Color_" + str(IMG_SET_ID) + ".mat"
PSF_NIR_name = path + "GaussStd2Nir_" + str(IMG_SET_ID) + ".mat"

mat_dist = sio.loadmat(dist_name)
mat_PSF = sio.loadmat(PSF_name)
mat_PSF_NIR = sio.loadmat(PSF_NIR_name)

distances = mat_dist['distancesCol']
PSF       = mat_PSF['GaussStd2Color']
PSF_NIR   = np.squeeze(mat_PSF_NIR['GaussStd2Nir'])

# Validate shapes *before* overwriting the NIR column. Otherwise a length
# mismatch surfaces as an opaque NumPy broadcasting error, not these checks.
assert distances.shape[0] == PSF.shape[0], \
    'distances and colour PSF have different numbers of samples'
assert distances.shape[0] == PSF_NIR.shape[0], \
    'distances and NIR PSF have different numbers of samples'
assert PSF.shape[1] > NIR_CHANNEL, \
    'colour PSF table has no column {}'.format(NIR_CHANNEL)
PSF[:, NIR_CHANNEL] = PSF_NIR

print('loaded experimental data from {} distances and {} channels.'.format(*PSF.shape))

# Optional diagnostic (disabled): tabulate the spacing between consecutive
# measured distances next to its index. The hard-coded 50 assumes 51 samples.
#distances2 = np.squeeze(distances)
#diff_dist = [x - distances2[i - 1] for i, x in enumerate(distances2) if i > 0]
#b = np.zeros((50, 2))
#b[:,0] = diff_dist
#b[:,1] = range(len(diff_dist))
#print(b)

## Resample the data onto a uniform distance grid

In [None]:
from aam import make_uniform

# Raw measurements: one marker series per PSF channel (column) versus distance.
plt.plot(distances, PSF, '*')
plt.ylabel('Raw Experimental PSF')
plt.show()

# Default: resample onto a uniform distance grid.
distances_uniform, PSF_uniform = make_uniform(distances, PSF, 'uniform')

plt.figure()
plt.plot(distances_uniform, PSF_uniform, '*')
plt.ylabel('Interpolated PSF')
plt.show()

# The 'manual' resampling mode is intended for image set 4 only; other image
# sets keep the uniform result computed above.
if IMG_SET_ID == 4:
    distances_uniform, PSF_uniform = make_uniform(distances, PSF, 'manual')

    plt.figure()
    plt.plot(distances_uniform, PSF_uniform, '*')
    plt.ylabel('Manual resampled PSF')
    plt.show()

# NOTE(review): hard-coded and not used in this cell; its meaning (and whether
# it should track the resampled length) needs confirming against later cells.
num_samples = 50
num_channels = PSF.shape[1]

In [None]:
from aam import polynomial_fitting, get_focus_distances
from aam import compute_aam

# Scale the resampled distances by 1e-3 (presumably mm -> m — TODO confirm)
# and flatten to 1-D for the polynomial fit.
x = np.squeeze(0.001 * distances_uniform)
polyParams = polynomial_fitting(x, PSF_uniform)
# Focus distances, bounded by the first and last sample positions.
x0 = get_focus_distances(polyParams, bounds=(x[0], x[-1]))
print(x0)

num_colors = 4
num_alphas = 51
alphaList = np.linspace(0.2, 0.5, num_alphas)

# AAM using the first num_colors rows of the fit and their focus distances,
# evaluated over alphaList (presumably one scalar per alpha — TODO confirm).
AAM = compute_aam(polyParams[:num_colors, :], x0[:num_colors], alphaList)

# Report only the first, middle and last alpha. Floor division keeps the
# middle index an int; np.floor would return a float.
report_indices = {0, num_alphas // 2, num_alphas - 1}
for i, (alpha, aam) in enumerate(zip(alphaList, AAM)):
    if i in report_indices:
        print('{:2.2f} \t {:10.2f} \t'.format(alpha, aam))