### Data fitting analysis
The goal of this notebook is to assess how well the different models fit the data.
The main way to do so is to generate data from the fitted models and compare the statistics of the generated data with the statistics of the real data.

The analysis below is done for only one fold of the cross-validation.

In [1]:
%matplotlib inline

import os
import sys
import pickle
sys.path.append('../')

import numpy as np
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt

from src.FullModel.model import Model as full_model
from src.FullModel.epsilon_param import EpsParam as epsilon_object
from src.FullModel.xi_param import XiParam as xi_object

from src.LocalGlobalAttentionModel.b_param import BParam as b_object
from src.LocalGlobalAttentionModel.s0_param import S0Param as s_0_object

from src.LocalChoiceModel.model import Model as local_choice_model

from src.FixedChoiceModel.model import Model as fixed_choice_model
from src.FixedChoiceModel.rho_param import RhoParam as rho_object

from src.LocalSaliencyModel.model import Model as local_saliency_model
from src.LocalSaliencyModel.xi_param import XiParam as xi_object_local_saliency

#### These are helper functions for reading and processing the results

In [2]:
def read_samples_files(file_name, burnin):
    """
    Read the MCMC samples of one inference run (one subject, one chain).

    The pickle is expected to hold a list with one entry per parameter. Each entry is
    the chain of samples of that parameter: 1-D for a scalar parameter, or 2-D with one
    column per dimension for a multi-dimensional parameter.

    :param file_name: path to the pickle file containing the inference results.
    :param burnin: number of leading samples discarded from every chain.
    :return: numpy array with one row per (parameter, dimension) chain, in file order;
             each row holds the samples after truncation and burn-in (empty if
             ``burnin`` exceeds the effective chain length).
    """
    # NOTE: pickle.load can execute arbitrary code - only use it on trusted result files.
    with open(file_name, 'rb') as f:
        res = pickle.load(f)

    # A run that stopped early leaves trailing zeros in the pre-allocated chains.
    # Cut every chain at the first (row-wise) zero found in the first parameter's chain.
    samps_0 = np.array(res[0])
    zero_rows = np.where(samps_0 == 0)[0]
    eff_len = zero_rows[0] if zero_rows.size else samps_0.shape[0]

    processed_result = []
    for result in res:
        chain = np.array(result)[:eff_len]
        if chain.ndim > 1:
            # Multi-dimensional parameter: one separate chain per dimension.
            processed_result.extend(chain[burnin:, j] for j in range(chain.shape[1]))
        else:
            processed_result.append(chain[burnin:])
    return np.array(processed_result)


def read_folder(folder_path, burnin):
    """
    Read the inference results of every subject stored in an experiment folder.

    Only files whose name contains 'sub' are read. The subject index is parsed from the
    two characters that precede the last two characters of the file name
    (e.g. ``..._sub12.p`` -> 12), falling back to a single digit (``..._sub3.p`` -> 3).

    :param folder_path: path to the folder containing the results.
    :param burnin: number of leading samples discarded from every chain.
    :return: dict mapping subject index -> result of ``read_samples_files``, ordered by
             increasing subject index. ``os.listdir`` order is arbitrary, and later cells
             enumerate the subjects and pair them with the test data by position.
    """
    res = {}
    for file_name in os.listdir(folder_path):
        if 'sub' not in file_name:
            continue
        try:
            sub_ind = int(file_name[-4:-2])
        except ValueError:
            # Single-digit subject index.
            sub_ind = int(file_name[-3:-2])
        res[sub_ind] = read_samples_files(os.path.join(folder_path, file_name), burnin)
    return dict(sorted(res.items()))

def pix_to_degs(fixations, shape=(64, 64), range_x=(1.035, 32.1), range_y=(0.82, 25.68)):
    """
    Convert fixations from the models' pixel coordinates to visual degrees.

    The pixel grid of ``shape[0] x shape[1]`` pixels is mapped linearly onto
    ``range_x`` x ``range_y`` degrees: a pixel coordinate ``p`` becomes
    ``range[0] + p * (range[1] - range[0]) / shape`` along each axis.

    :param fixations: nested sequence indexed as ``fixations[image][subject]``; each item
        is an array whose first two rows are the x and y coordinates (one column per
        fixation). If the array has one extra trailing row (so it does not broadcast
        against the 2-element conversion factor), that row is dropped.
    :param shape: size of the pixel grid along x and y (defaults to the models' 64x64 grid).
    :param range_x: (min, max) in degrees covered along x.
    :param range_y: (min, max) in degrees covered along y.
    :return: nested list with the same [image][subject] layout, holding the converted
        coordinates in the input orientation (coordinates in rows).
    """
    conv = np.array([(range_x[1] - range_x[0]) / shape[0],
                     (range_y[1] - range_y[0]) / shape[1]])
    offset = np.array([range_x[0], range_y[0]])

    deg_fixs = []
    for fixs_im in fixations:
        deg_fixs.append([])
        for fixs_sub in fixs_im:
            try:
                scaled = fixs_sub.T * conv
            except ValueError:
                # Extra trailing row: keep only the coordinate rows.
                scaled = fixs_sub[:-1, :].T * conv
            # The offset is applied in both branches so every output shares the same origin.
            deg_fixs[-1].append((scaled + offset).T)

    return deg_fixs

def get_kdes_estimates(data, positions):
    """
    Evaluate a Gaussian KDE for every [i, j] entry of a stack of sample sets.

    In this notebook ``data`` is (models, posterior samples, observations), so each
    ``data[i, j]`` is the pooled saccade statistic of one generated dataset.

    :param data: array whose first two axes index the sample sets; ``data[i, j]`` is
        passed to ``gaussian_kde``.
    :param positions: 1-D array of points where every density is evaluated.
    :return: array of shape (data.shape[0], data.shape[1], len(positions)).
    """
    n_outer, n_inner = data.shape[0], data.shape[1]
    densities = np.zeros((n_outer, n_inner, positions.shape[0]))
    for i, j in np.ndindex(n_outer, n_inner):
        densities[i, j] = gaussian_kde(data[i, j])(positions)
    return densities

#### Read the test data

In [3]:
# Held-out (test) fixations and saliency maps of the cross-validation folds.
test_fixations_path = '../DATA/processed_data/test_fixs_for_cross_validation.p'
test_saliencies_path = '../DATA/processed_data/test_sals_for_cross_validation.p'

# Index of the cross-validation fold analysed in this notebook.
fold = 0

# Both pickles are indexed by fold; keep only the entry of `fold`.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(test_fixations_path, 'rb') as fixs_file:
    test_fixations = pickle.load(fixs_file)[fold]

with open(test_saliencies_path, 'rb') as sals_file:
    test_saliencies = pickle.load(sals_file)[fold]

#### For each subject we create models with the parameter values inferred for this subject.
Since the inference yields a posterior distribution, we sample multiple parameter configurations for each subject.
Later, we will generate data from these models.

In [4]:
burnin = 5000  # number of leading MCMC samples discarded from every chain
num_samples = 50  # posterior parameter configurations drawn per subject and model

# Posterior draws below use numpy's global RNG (and generated data presumably does too);
# fix the seed so that Restart & Run All reproduces the same numbers and figures.
seed = 0
np.random.seed(seed)

#### Full Model

In [5]:
# Results folder of the analysed cross-validation fold (follows `fold`).
full_model_results = '../Results/cross_validation/full_model_fold_{}'.format(fold)
full_model_samples = read_folder(full_model_results, burnin)

full_model_objects_dict = {}
cov_ratio = 4

for sub in full_model_samples.keys():
    try:
        samples_s_0, samples_b, samples_epsilon_x, samples_epsilon_y, samples_xi_x, samples_xi_y = full_model_samples[sub]
    except ValueError:
        # The result file does not hold exactly the six chains of the full model.
        print(sub)
        continue

    # Collect this subject's models locally so that a failure part-way through never
    # leaves a partially filled entry in full_model_objects_dict.
    sub_models = []
    try:
        for i in range(num_samples):
            # Uniformly random posterior sample index in [0, chain length).
            chosen_ind = np.random.randint(len(samples_s_0))
            s_0 = samples_s_0[chosen_ind]
            b = samples_b[chosen_ind]
            eps_x = samples_epsilon_x[chosen_ind]
            eps_y = samples_epsilon_y[chosen_ind]
            xi_x = samples_xi_x[chosen_ind]
            xi_y = samples_xi_y[chosen_ind]

            s_0_ob = s_0_object()
            s_0_ob.set_value(s_0)

            b_ob = b_object()
            b_ob.set_value(b)

            eps_ob = epsilon_object()
            eps_ob.set_value(np.array([eps_x, eps_y]))

            xi_ob = xi_object()
            xi_ob.set_value(np.array([xi_x, xi_y]))

            sub_models.append(full_model(test_saliencies,
                                         s_0_ob, b_ob, eps_ob, xi_ob, cov_ratio))
    except RuntimeWarning:
        # Only raised when warnings are escalated to errors; skip the whole subject.
        print(sub)
        continue

    full_model_objects_dict[sub] = sub_models

#### Local Saliency Model

In [6]:
# Results folder of the analysed cross-validation fold (follows `fold`).
local_saliency_model_result_folder = '../Results/cross_validation/local_saliency_model_fold_{}'.format(fold)
local_saliency_model_samples = read_folder(local_saliency_model_result_folder, burnin)
local_saliency_model_objects_dict = {}

for sub in local_saliency_model_samples.keys():
    try:
        samples_xi_x, samples_xi_y = local_saliency_model_samples[sub]
    except ValueError:
        # Unpacking fails when the file does not hold exactly the two xi chains.
        # (The lookup itself cannot raise KeyError: we iterate over the dict's keys.)
        print(sub)
        continue

    local_saliency_model_objects_dict[sub] = []
    for i in range(num_samples):
        # Uniformly random posterior sample index in [0, chain length).
        chosen_ind = np.random.randint(len(samples_xi_x))
        xi_x = samples_xi_x[chosen_ind]
        xi_y = samples_xi_y[chosen_ind]

        xi_ob = xi_object_local_saliency()
        xi_ob.set_value(np.array([xi_x, xi_y]))
        local_saliency_model_objects_dict[sub].append(local_saliency_model(test_saliencies, xi_ob))

#### Fixed Choice Model

In [7]:
# Results folder of the analysed cross-validation fold (follows `fold`).
fixed_choice_model_results_folder = '../Results/cross_validation/fixed_choice_model_fold_{}'.format(fold)
fixed_choice_model_samples = read_folder(fixed_choice_model_results_folder, burnin)

fixed_choice_model_objects_dict = {}
cov_ratio = 4

for sub in fixed_choice_model_samples.keys():
    try:
        samples_rho, samples_epsilon_x, samples_epsilon_y, samples_xi_x, samples_xi_y = fixed_choice_model_samples[sub]
    except ValueError:
        # The result file does not hold exactly the five chains of this model.
        print(sub)
        continue

    fixed_choice_model_objects_dict[sub] = []
    for i in range(num_samples):
        # Uniformly random posterior sample index in [0, chain length).
        chosen_ind = np.random.randint(len(samples_rho))
        rho = samples_rho[chosen_ind]
        eps_x = samples_epsilon_x[chosen_ind]
        eps_y = samples_epsilon_y[chosen_ind]
        xi_x = samples_xi_x[chosen_ind]
        xi_y = samples_xi_y[chosen_ind]

        rho_ob = rho_object()
        rho_ob.set_value(rho)
        eps_ob = epsilon_object()
        eps_ob.set_value(np.array([eps_x, eps_y]))
        xi_ob = xi_object()
        xi_ob.set_value(np.array([xi_x, xi_y]))

        fixed_choice_model_objects_dict[sub].append(fixed_choice_model(test_saliencies,
                                                                       rho_ob, eps_ob,
                                                                       xi_ob, cov_ratio))

#### Local choice model

In [8]:
# Results folder of the analysed cross-validation fold (follows `fold`).
local_choice_model_results_folder = '../Results/cross_validation/local_choice_model_fold_{}'.format(fold)
local_choice_model_samples = read_folder(local_choice_model_results_folder, burnin)

local_choice_model_objects_dict = {}
cov_ratio = 4

for sub in local_choice_model_samples.keys():
    try:
        samples_s_0, samples_b, samples_epsilon_x, samples_epsilon_y, samples_xi_x, samples_xi_y = local_choice_model_samples[sub]
    except ValueError:
        # The result file does not hold exactly the six chains of this model.
        print(sub)
        continue

    local_choice_model_objects_dict[sub] = []
    for i in range(num_samples):
        # Uniformly random posterior sample index in [0, chain length).
        chosen_ind = np.random.randint(len(samples_s_0))
        s_0 = samples_s_0[chosen_ind]
        b = samples_b[chosen_ind]
        eps_x = samples_epsilon_x[chosen_ind]
        eps_y = samples_epsilon_y[chosen_ind]
        xi_x = samples_xi_x[chosen_ind]
        xi_y = samples_xi_y[chosen_ind]

        s_0_ob = s_0_object()
        s_0_ob.set_value(s_0)

        b_ob = b_object()
        b_ob.set_value(b)

        eps_ob = epsilon_object()
        eps_ob.set_value(np.array([eps_x, eps_y]))

        xi_ob = xi_object()
        xi_ob.set_value(np.array([xi_x, xi_y]))

        local_choice_model_objects_dict[sub].append(local_choice_model(test_saliencies,
                                                                       s_0_ob, b_ob,
                                                                       eps_ob, xi_ob, cov_ratio))

In [9]:
# all_models and models_names share the same order: index k refers to the same model in
# both (0 full, 1 local saliency, 2 fixed choice, 3 local choice).
all_models = [full_model_objects_dict, local_saliency_model_objects_dict,
              fixed_choice_model_objects_dict, local_choice_model_objects_dict]
models_names = ['Full \n model', 'Local \n saliency \n model', 'Fixed \n choice \n model', 'Local \n choice \n model']

# Later cells enumerate `subjects` and use the position s to index the per-subject test
# data (test_fixations[image][s]), so the order must be deterministic: increasing subject
# index. NOTE(review): assumes the test data store subjects in that same order.
subjects = sorted(full_model_objects_dict.keys())

#### To process the data we put it in a dummy model.

In [10]:
# we use the local saliency model as it is the simplest
# The real test fixations are loaded into a model object so that its setter methods
# compute the data statistics exactly as they are computed for generated data.
# NOTE(review): xi_dummy keeps its default value; presumably it does not affect the
# statistics below - TODO confirm.
xi_dummy = xi_object_local_saliency()
dummy_data_model = local_saliency_model(test_saliencies, xi_dummy)
dummy_data_model.fixations = test_fixations
dummy_data_model.set_fix_dist_2()
dummy_data_model.set_saliencies_ts()
# Degree coordinates must exist before the degree-based distances are computed.
dummy_data_model.fixs_degs = pix_to_degs(dummy_data_model.fixations)
dummy_data_model.set_fix_dist_2_degs()
dummy_data_model.set_angles_between_saccades_ts()
dummy_data_model.set_angles_ts()

# Real-data statistics, nested as [image][subject], used by the comparisons below.
data_fixs_dists_2_deg = dummy_data_model.fix_dists_2_degs
data_fixs_dists_2 = dummy_data_model.fix_dists_2
data_sals_ts = dummy_data_model.saliencies_ts
data_dir_x = dummy_data_model.angles_x_ts
data_dir_change = dummy_data_model.angles_between_ts

#### Calculate NSS of the data for each model

In [None]:
# This will take a while if num_samples is big
# nss[k, s, ind]: score returned by calculate_likelihood_per_subject(..., for_nss=True),
# averaged over images, for model k (order of all_models), the subject at position s in
# `subjects`, and posterior sample ind.
nss = np.empty((len(all_models), len(subjects), num_samples))

for k, model in enumerate(all_models):
    for s, sub in enumerate(subjects):
        # Test data of subject s only, keeping the [image][subject] nesting with one subject.
        fixs_sub = [[test_fixations[i][s]] for i in range(len(test_fixations))]
        sal_ts_sub = [[data_sals_ts[i][s]] for i in range(len(data_sals_ts))]
        fix_dists_2_sub = [[data_fixs_dists_2[i][s]] for i in range(len(data_fixs_dists_2))]
        for ind in range(num_samples):
            res = model[sub][ind].calculate_likelihood_per_subject(fix_dists_2_sub, sal_ts_sub, fixs_sub, per_fixation=False, for_nss=True, saliencies=test_saliencies)
            # One result per image: average within each image, then across images.
            nss[k, s, ind] = np.mean([res[im].mean() for im in range(len(res))])


In [None]:
# Mean NSS per model, averaged over subjects and posterior samples. The model names
# contain line breaks for plot labels; collapse them so each model prints on one line.
for name, mean_nss in zip(models_names, nss.mean(axis=(1, 2))):
    print('{}: {:.4f}'.format(' '.join(name.split()), mean_nss))

#### Generate data for all the models

In [None]:
# time_steps[s, k]: number of fixations of the subject at position s on test image k.
# Each generated scanpath gets the same length as the corresponding real one.
time_steps = np.array([[test_fixations[k][s].shape[1] for k in range(len(test_fixations))]
                       for s in range(len(subjects))], dtype=int)

for model in all_models:
    for s, sub in enumerate(subjects):
        for j in range(num_samples):
            sample_model = model[sub][j]
            # generate_dataset also returns a first value that is not used here.
            _, fixs = sample_model.generate_dataset(time_steps[s], 1)

            # Same statistics, in the same order, as for the real data (dummy model above).
            sample_model.set_fixations(fixs)
            sample_model.set_fix_dist_2()
            sample_model.set_saliencies_ts()
            sample_model.fixs_degs = pix_to_degs(sample_model.fixations)
            sample_model.set_fix_dist_2_degs()
            sample_model.set_angles_between_saccades_ts()
            sample_model.set_angles_ts()

#### Pile together everything so we can plot it

In [None]:
# Fixation counts of the real test data for the analysed subjects
# (position s in `subjects` indexes test_fixations[image][s]).
num_fixs_per_subject = np.array([np.sum([test_fixations[i][s].shape[1] for i in range(len(test_fixations))])
                                 for s in range(len(subjects))])
num_fixs = num_fixs_per_subject.sum()

# Every (subject, image) trial with T fixations contributes T - 1 saccades.
num_diffs = num_fixs - len(subjects) * len(test_fixations)

# Saccade lengths (deg) of every generated dataset, pooled over subjects and images,
# plus the per-subject mean and std: indexed [model, subject position, posterior sample].
flattend_fix_dist_2_deg = np.zeros((len(all_models), num_samples, num_diffs))
gen_data_means = np.zeros((len(all_models), len(subjects), num_samples))
gen_data_stds = np.zeros((len(all_models), len(subjects), num_samples))

for i, model in enumerate(all_models):
    for l in range(num_samples):
        all_subs = []
        for s, sub in enumerate(subjects):
            sample_dists = model[sub][l].fix_dists_2_degs
            # [im][0] selects the single generated subject; the square root of the last row
            # gives saccade lengths, as for the real data.
            sub_dat = np.concatenate([np.sqrt(sample_dists[im][0][-1, :]) for im in range(len(sample_dists))])
            all_subs.append(sub_dat)
            gen_data_means[i, s, l] = np.mean(sub_dat)
            gen_data_stds[i, s, l] = np.std(sub_dat)
        flattend_fix_dist_2_deg[i, l] = np.concatenate(all_subs)

In [None]:
# Real saccade lengths (deg), pooled over all images and subjects of the test data.
saccade_lengths_data_deg = np.hstack([
    np.sqrt(data_fixs_dists_2_deg[im][sub_pos][-1, :])
    for im in range(len(data_fixs_dists_2_deg))
    for sub_pos in range(len(data_fixs_dists_2_deg[im]))
]).flatten()

#### Get densities of saccade amplitudes

In [None]:
# Grid of saccade lengths (deg) on which every density is evaluated.
positions = np.arange(-1, 25, 0.01)

# Density of the real data.
kde_data = gaussian_kde(saccade_lengths_data_deg)
kde_res_data = kde_data(positions)

# One density per (model, posterior sample) of the generated data.
kdes_res = get_kdes_estimates(flattend_fix_dist_2_deg, positions)

In [None]:
# Summaries over the generated datasets (axis 1 = posterior samples), one per model:
# the mean density and its pointwise 2.5th / 97.5th percentiles.
kdes_percentiles = np.percentile(kdes_res, q=[2.5, 97.5], axis=1)
kdes_mean = np.mean(kdes_res, axis=1)

In [None]:
# Saccade-amplitude densities. Top: real data vs. the mean density of every model.
# Bottom: full and local saliency models with their pointwise 95% bands.
# kdes_mean / kdes_percentiles follow the order of all_models:
# 0 full, 1 local saliency, 2 fixed choice, 3 local choice.
curve_styles = [('Full Model', 'C1'), ('Local Saliency Model', 'C2'),
                ('Fixed Choice Model', 'C3'), ('Local Choice Model', 'C4')]
band_colors = ['peachpuff', '#BFE2BF']

f, axarr = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

axarr[0].plot(positions, kde_res_data, label='Experimental Data', color='black')
for k, (label, color) in enumerate(curve_styles):
    axarr[0].plot(positions, kdes_mean[k], label=label, color=color)
axarr[0].legend(fontsize=15)
axarr[0].set_ylabel('Density', fontsize=30)
axarr[0].tick_params(labelsize=20)

axarr[1].plot(positions, kde_res_data, label='Experimental Data', color='black')
for k, band_color in enumerate(band_colors):
    label, color = curve_styles[k]
    axarr[1].plot(positions, kdes_mean[k], label=label, color=color)
    axarr[1].fill_between(positions, kdes_percentiles[0, k], kdes_percentiles[1, k], color=band_color)
axarr[1].legend(fontsize=15)
axarr[1].set_xlabel('Saccade length [deg]', fontsize=30)
axarr[1].set_ylabel('Density', fontsize=30)
axarr[1].tick_params(labelsize=20)
axarr[1].set_xlim((-1, 20))

plt.tight_layout()
plt.show()

#### Compare the mean and std of saccade amplitudes per subject

In [None]:
# Per subject, averaged over the posterior samples (axis 2): [model, subject position].
gen_data_std_means = gen_data_stds.mean(axis=2)
gen_data_means_means = gen_data_means.mean(axis=2)

# Standard error of each subject's mean saccade length. The sample size is the number of
# saccades, not fixations: each (subject, image) trial with T fixations has T - 1 saccades.
# np.std uses ddof=0, so dividing by sqrt(n - 1) equals the Bessel-corrected SEM.
num_saccs_per_subject = num_fixs_per_subject - len(test_fixations)
gen_data_means_errors = gen_data_std_means / np.sqrt(num_saccs_per_subject - 1)

In [None]:
# Real-data saccade-length mean and std per subject (position s in `subjects`), pooled over
# images the same way as gen_data_means / gen_data_stds are for the generated data.
data_sacc_lengths = [
    np.concatenate([np.sqrt(data_fixs_dists_2_deg[im][s][-1, :]) for im in range(len(data_fixs_dists_2_deg))])
    for s in range(len(subjects))
]
data_means = np.array([np.mean(lengths) for lengths in data_sacc_lengths])
data_stds = np.array([np.std(lengths) for lengths in data_sacc_lengths])

# Full model (index 0 of all_models) vs. real data, one marker per subject; the line is identity.
f, axarr = plt.subplots(1, 2, figsize=(10, 5))

axarr[0].plot([3.8, 8.5], [3.8, 8.5], linewidth=1)
axarr[0].plot(data_means, gen_data_means_means[0], '+', markersize=10)
axarr[0].set_xlabel('Subjects\' mean \n saccade length [deg]', fontsize=20)
axarr[0].set_ylabel('Models\' data mean \n saccade length [deg]',  fontsize=20)
axarr[0].set_xlim((3.8, 8.5))
axarr[0].set_ylim((3.8, 8.5))

axarr[1].plot([3., 5.2], [3., 5.2], linewidth=1)
axarr[1].plot(data_stds, gen_data_std_means[0], '+', markersize=10)
axarr[1].set_xlabel('Subjects\'  \n saccade length std [deg]', fontsize=20)
axarr[1].set_ylabel('Models\' data \n saccade length std [deg]',  fontsize=20)
axarr[1].set_xlim((3., 5.2))
axarr[1].set_ylim((3., 5.2))

plt.tight_layout()
plt.show()

## Saccade Direction

In [None]:
# Generated saccade directions (angles_x_ts) and direction changes between consecutive
# saccades (angles_between_ts), pooled over subjects: [model, posterior sample, value].
# Each (subject, image) trial with T fixations gives T - 1 directions and T - 2 changes.
n_trials = len(subjects) * len(test_fixations)
all_angs_x = np.zeros((len(all_models), num_samples, num_diffs))
all_angs_change = np.zeros((len(all_models), num_samples, num_diffs - n_trials))

for i, model in enumerate(all_models):
    for l in range(num_samples):
        all_subs_x = []
        all_subs_change = []
        for sub in subjects:
            sample_model = model[sub][l]
            # dat[0] selects the single generated subject of each image.
            all_subs_x.append(np.concatenate([dat[0] for dat in sample_model.angles_x_ts]))
            all_subs_change.append(np.concatenate([dat[0] for dat in sample_model.angles_between_ts]))
        all_angs_x[i, l] = np.concatenate(all_subs_x)
        all_angs_change[i, l] = np.concatenate(all_subs_change)

In [None]:
def _pool_images_and_subjects(per_image):
    """Stack every [image][subject] entry of a nested statistic into one flat array."""
    pieces = [per_image[im][sub_pos]
              for im in range(len(per_image))
              for sub_pos in range(len(per_image[im]))]
    return np.hstack(pieces).flatten()


# Real saccade directions and direction changes, pooled over images and subjects.
data_dir_x_flat = _pool_images_and_subjects(data_dir_x)
data_dir_change_flat = _pool_images_and_subjects(data_dir_change)

In [None]:
# Evaluation grid for the angle densities (presumably radians, covering [-pi, pi]).
positions = np.arange(-3.2, 3.2, 0.005)

# Real data: densities evaluated directly on the grid.
kde_data_angs_x = gaussian_kde(data_dir_x_flat)(positions)
kde_data_angs_change = gaussian_kde(data_dir_change_flat)(positions)

# Generated data: one density per (model, posterior sample).
kdes_angs_x = get_kdes_estimates(all_angs_x, positions)
kdes_angs_change = get_kdes_estimates(all_angs_change, positions)

In [None]:
# Mean density over the generated datasets (axis 1 = posterior samples), one curve per model.
kdes_mean_angs_x, kdes_mean_angs_change = (np.mean(kdes, axis=1)
                                           for kdes in (kdes_angs_x, kdes_angs_change))

In [None]:
# Left: saccade direction; right: change of direction between consecutive saccades.
# Real data vs. the mean density of each model. kdes_mean_angs_* follow the order of
# all_models: 0 full, 1 local saliency, 2 fixed choice, 3 local choice.
curve_styles = [('Full Model', 'C1'), ('Local Saliency Model', 'C2'),
                ('Fixed Choice Model', 'C3'), ('Local Choice Model', 'C4')]

f, axarr = plt.subplots(1, 2, figsize=(12, 4))

axarr[0].plot(positions, kde_data_angs_x, label='Experimental Data', color='black')
for k, (label, color) in enumerate(curve_styles):
    axarr[0].plot(positions, kdes_mean_angs_x[k], label=label, color=color)
axarr[0].legend(fontsize=10)
axarr[0].set_ylabel('Density', fontsize=20)
axarr[0].set_xlabel('saccade direction', fontsize=20)
axarr[0].tick_params(labelsize=10)

axarr[1].plot(positions, kde_data_angs_change, label='Experimental Data', color='black')
for k, (label, color) in enumerate(curve_styles):
    axarr[1].plot(positions, kdes_mean_angs_change[k], label=label, color=color)
axarr[1].set_ylabel('Density', fontsize=20)
axarr[1].set_xlabel('saccade change', fontsize=20)
axarr[1].tick_params(labelsize=10)

plt.tight_layout()
plt.show()