In [22]:
import os
import sys
import copy
import json
import glob
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

sys.path.append('../assets_psychophysics')
import util_human_model_comparison
import util_figures_psychophysics


def load_results_dict(results_dict_fn, pop_key_list=('psychometric_function',)):
    """
    Load a results_dict from a JSON file and remove selected top-level keys.

    Args
    ----
    results_dict_fn (str): path to the JSON results file
    pop_key_list (iterable of str): top-level keys to remove when present
        (default: only 'psychometric_function'). A tuple default is used so
        the default value can never be shared/mutated across calls.

    Returns
    -------
    results_dict (dict): loaded JSON contents without the keys in `pop_key_list`
    """
    with open(results_dict_fn) as f:
        results_dict = json.load(f)
    for pop_key in pop_key_list:
        # pop with a default silently skips keys absent from this results file
        results_dict.pop(pop_key, None)
    return results_dict


def calc_best_metric(valid_metrics_fn, metric_key='f0_label:accuracy', maximize=True):
    """
    Return the best value of one validation metric stored in a JSON file.

    Args
    ----
    valid_metrics_fn (str): path to a JSON file mapping metric names to values
    metric_key (str): metric to look up. If it is not a key of the file, the
        first key (in file order) containing every ':'-separated part of
        `metric_key` is used instead.
    maximize (bool): if True return np.max of the metric values, else np.min

    Returns
    -------
    best_metric_value: np.max / np.min of the selected metric values, or None
        if `valid_metrics_fn` does not exist

    Raises
    ------
    KeyError: if neither `metric_key` nor a similarly named key is present
    """
    if not os.path.exists(valid_metrics_fn):
        return None
    with open(valid_metrics_fn) as f:
        valid_metrics_dict = json.load(f)
    if metric_key not in valid_metrics_dict:
        # Fall back to the first key that contains every ':'-separated part of metric_key
        key_parts = metric_key.split(':')
        similar_key = next(
            (k for k in valid_metrics_dict if all(part in k for part in key_parts)),
            None)
        if similar_key is None:
            raise KeyError('metric_key `{}` (or a similarly named key) not found in {}; available keys: {}'.format(
                metric_key, valid_metrics_fn, sorted(valid_metrics_dict.keys())))
        metric_key = similar_key
    metric_values = valid_metrics_dict[metric_key]
    return np.max(metric_values) if maximize else np.min(metric_values)


def flatten_dict(d, parent_key=None, sep='/'):
    """
    Recursively flatten a nested dict into a single-level dict.

    Nested keys are joined with `sep`, e.g. {'a': {'b': 1}} -> {'a/b': 1}.
    Keys are visited in sorted order at every nesting level.

    Args
    ----
    d (dict): possibly-nested dictionary to flatten
    parent_key (str or None): prefix prepended (with `sep`) to every key of `d`
    sep (str): separator placed between parent and child keys

    Returns
    -------
    (dict): flat dictionary holding every non-dict leaf of `d`
    """
    flat = {}
    for k in sorted(d):
        full_key = k if parent_key is None else parent_key + sep + k
        value = d[k]
        if not isinstance(value, dict):
            flat[full_key] = value
            continue
        flat.update(flatten_dict(value, parent_key=full_key, sep=sep))
    return flat


def all_equal(iterator):
    """
    Return True if every element of `iterator` is array-equal to the first one.

    Elements are compared with np.array_equal, so lists/arrays must match in
    shape and values and scalars are compared by value.

    Args
    ----
    iterator (iterable): values to compare

    Returns
    -------
    (bool): True if all elements are equal, or if `iterator` is empty
    """
    iterator = iter(iterator)
    sentinel = object()
    first = next(iterator, sentinel)
    if first is sentinel:
        # An empty iterable is vacuously all-equal (previously leaked StopIteration)
        return True
    return all(np.array_equal(first, rest) for rest in iterator)


def concatenate_dicts(list_d):
    """
    Merge a list of (possibly nested) dicts into one flat dict of combined values.

    Each dict is flattened with flatten_dict. For every key of the FIRST
    flattened dict whose value is a list / int / float / np.ndarray:
      - if the value is array-equal across all dicts, it is stored once;
      - otherwise the per-dict values are stacked along a new axis 0.
    Keys whose values have other types (e.g. str, None) are dropped.
    NOTE(review): every dict is assumed to contain the first dict's keys;
    a missing key raises KeyError.

    Args
    ----
    list_d (iterable of dict): dicts to merge, one per model

    Returns
    -------
    d_concatenated (dict): flat key -> shared value or stacked np.ndarray;
        an empty dict if `list_d` is empty
    """
    list_d_flat = [flatten_dict(d) for d in list_d]
    if not list_d_flat:
        # No dicts to merge (e.g. a glob matched no models); previously IndexError
        return {}
    d_concatenated = {}
    for key in sorted(list_d_flat[0].keys()):
        if isinstance(list_d_flat[0][key], (list, int, float, np.ndarray)):
            list_key_val = [d[key] for d in list_d_flat]
            if all_equal(list_key_val):
                d_concatenated[key] = list_key_val[0]
            else:
                d_concatenated[key] = np.stack(list_key_val, axis=0)
    return d_concatenated


In [23]:
# Per-experiment configuration used when loading models:
#   experiment_to_basename_map        -> results-file basename inside each model directory
#   experiment_to_human_results_map   -> human results dict to compare model results against
#   experiment_to_compfunc_map        -> comparison function returning (r, p)
#   experiment_to_compfunc_kwargs_map -> extra keyword arguments for that function
experiment_to_basename_map = {
    'bernox2005': 'EVAL_SOFTMAX_lowharm_v01_bestckpt_results_dict.json',
    'transposedtones': 'EVAL_SOFTMAX_transposedtones_v01_bestckpt_results_dict.json',
    'freqshiftedcomplexes': 'EVAL_SOFTMAX_freqshifted_v01_bestckpt_results_dict.json',
    'mistunedharmonics': 'EVAL_SOFTMAX_mistunedharm_v01_bestckpt_results_dict.json',
    'altphasecomplexes': 'EVAL_SOFTMAX_altphase_v01_bestckpt_results_dict.json',
}

experiment_to_human_results_map = dict(
    bernox2005=util_human_model_comparison.get_human_results_dict_bernox2005(),
    transposedtones=util_human_model_comparison.get_human_results_dict_transposedtones(),
    freqshiftedcomplexes=util_human_model_comparison.get_human_results_dict_freqshiftedcomplexes(),
    mistunedharmonics=util_human_model_comparison.get_human_results_dict_mistunedharmonics(),
    altphasecomplexes=util_human_model_comparison.get_human_results_dict_altphasecomplexes(),
)

experiment_to_compfunc_map = dict(
    bernox2005=util_human_model_comparison.compare_bernox2005,
    transposedtones=util_human_model_comparison.compare_transposedtones,
    freqshiftedcomplexes=util_human_model_comparison.compare_freqshiftedcomplexes,
    mistunedharmonics=util_human_model_comparison.compare_mistunedharmonics,
    altphasecomplexes=util_human_model_comparison.compare_altphasecomplexes_hist,
)

# No experiment currently passes extra kwargs; each gets its own empty dict
experiment_to_compfunc_kwargs_map = {ek: {} for ek in experiment_to_basename_map}

# Order in which experiments are processed and reported
experiment_keys = [
    'bernox2005',
    'altphasecomplexes',
    'freqshiftedcomplexes',
    'mistunedharmonics',
    'transposedtones',
]


In [24]:
basename_valid_metrics = 'validation_metrics.json'
# Each entry is a glob pattern, or a (glob pattern, prefix) tuple where `prefix`
# replaces 'EVAL_SOFTMAX' in the results-file basenames.
# NOTE: the first entry duplicates the IHC3000Hz entry further down the list;
# later cells use list_dict_super[0] as the reference model group.
list_regex_model_dir = [
    '/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0???/',

    '/saved_models/arch_search_v02_topN/REDOsr2000_cf1000_species002_spont070_BW10eN1_IHC0050Hz_IHC7order/arch_0???/',
    '/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC0050Hz_IHC7order/arch_0???/',
    '/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC0320Hz_IHC7order/arch_0???/',
    '/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC1000Hz_IHC7order/arch_0???/',
    '/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0???/',
    '/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC6000Hz_IHC7order/arch_0???/',
    '/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC9000Hz_IHC7order/arch_0???/',
]


def _get_fn_result_dict(model_dir, prefix=None):
    """
    Build the experiment -> results-file path mapping for one model directory.

    Model directories whose name contains 'snr_pos' use alternate results files
    for the bernox2005 (lowharm_v04) and transposedtones (transposedtones_v02)
    experiments. If `prefix` is given, it replaces 'EVAL_SOFTMAX' in each
    results-file BASENAME only, so directory names are never altered.
    """
    basename_result_dict = {ek: experiment_to_basename_map[ek] for ek in experiment_keys}
    if 'snr_pos' in model_dir:
        basename_result_dict['bernox2005'] = 'EVAL_SOFTMAX_lowharm_v04_bestckpt_results_dict.json'
        basename_result_dict['transposedtones'] = 'EVAL_SOFTMAX_transposedtones_v02_bestckpt_results_dict.json'
        print(model_dir)
    if prefix is not None:
        basename_result_dict = {
            ek: basename.replace('EVAL_SOFTMAX', prefix)
            for ek, basename in basename_result_dict.items()
        }
    return {ek: os.path.join(model_dir, basename) for ek, basename in basename_result_dict.items()}


# Compile list of lists of model psychophysical data
list_list_model_dir = []
list_dict_super = []

# For each entry in list_regex_model_dir, grab all of the models that are globbed by the regex
for regex_model_dir in list_regex_model_dir:
    prefix = None
    if isinstance(regex_model_dir, tuple):
        (regex_model_dir, prefix) = regex_model_dir

    list_model_dir = []
    list_valid_metric = []
    dict_results_dicts = {ek: [] for ek in experiment_keys}
    dict_human_model_comparison = {
        ek: {
            'human_model_similarity_pval': [],
            'human_model_similarity_coef': [],
        }
        for ek in experiment_keys
    }
    for model_dir in sorted(glob.glob(regex_model_dir)):
        fn_result_dict = _get_fn_result_dict(model_dir, prefix=prefix)
        # Only include models that have a results file for every experiment
        if not all(os.path.exists(fn_result_dict[ek]) for ek in experiment_keys):
            continue
        fn_valid_metric = os.path.join(model_dir, basename_valid_metrics)
        list_valid_metric.append(calc_best_metric(fn_valid_metric))
        list_model_dir.append(model_dir)
        # Load results_dict for each model and experiment
        for ek, results_dict_fn in fn_result_dict.items():
            results_dict = load_results_dict(results_dict_fn)
            dict_results_dicts[ek].append(results_dict)
            # Measure human-model similarity for each model and experiment
            compfunc = experiment_to_compfunc_map[ek]
            compfunc_kwargs = experiment_to_compfunc_kwargs_map[ek]
            r, p = compfunc(
                experiment_to_human_results_map[ek],
                results_dict,
                **compfunc_kwargs)
            dict_human_model_comparison[ek]['human_model_similarity_coef'].append(r)
            dict_human_model_comparison[ek]['human_model_similarity_pval'].append(p)

    dict_super = {}
    for ek in experiment_keys:
        dict_super[ek] = concatenate_dicts(dict_results_dicts[ek])
        # Copy lists so each experiment's dict owns its data; previously the same
        # validation-accuracy list object was shared by every experiment.
        dict_super[ek]['human_model_similarity_coef'] = list(dict_human_model_comparison[ek]['human_model_similarity_coef'])
        dict_super[ek]['human_model_similarity_pval'] = list(dict_human_model_comparison[ek]['human_model_similarity_pval'])
        dict_super[ek]['validation_accuracy'] = list(list_valid_metric)

    # Add lists of model results to the master list
    list_list_model_dir.append(list_model_dir)
    list_dict_super.append(dict_super)

    print(regex_model_dir, len(list_model_dir))


/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr2000_cf1000_species002_spont070_BW10eN1_IHC0050Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC0050Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC0320Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC1000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC6000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC9000Hz_IHC7order/arch_0???/ 10


In [27]:
# Print the shape of every field in the first model group's results, per experiment
# (fields of the mistunedharmonics experiment are not printed).
for ek in experiment_keys:
    print('_'*24, ek, '_'*24)
    if 'mistuned' in ek:
        continue
    experiment_results = list_dict_super[0][ek]
    for field_key in sorted(experiment_results.keys()):
        field_shape = np.array(experiment_results[field_key]).shape
        print(field_key, field_shape)


________________________ bernox2005 ________________________
f0dl (10, 60)
human_model_similarity_coef (10,)
human_model_similarity_pval (10,)
kwargs_f0_prior/octave_range (2,)
low_harm (60,)
phase_mode (60,)
validation_accuracy (10,)
________________________ altphasecomplexes ________________________
f0_bin_centers (12,)
f0_pred_ratio_results/f0_condition_list (9,)
f0_pred_ratio_results/f0_pred_ratio_list (10, 9)
f0_pred_ratio_results/filter_condition_list (9,)
f0_pred_ratio_results/kwargs_f0_pred_ratio/f0_bin_centers (3,)
f0_pred_ratio_results/kwargs_f0_pred_ratio/f0_bin_width ()
filter_fl_bin_means/125.0 (10, 12)
filter_fl_bin_means/1375.0 (10, 12)
filter_fl_bin_means/3900.0 (10, 12)
human_model_similarity_coef (10,)
human_model_similarity_pval (10,)
kwargs_f0_prior/octave_range (2,)
validation_accuracy (10,)
________________________ freqshiftedcomplexes ________________________
f0_max ()
f0_min ()
human_model_similarity_coef (10,)
human_model_similarity_pval (10,)
kwargs_f0_prior/o

In [19]:
# Load the psychophysics stats data JSON into DATA_DICT
fn_stats_data = 'pitchnet_paper_stats_data_psychophysics_2020AUG09.json'
with open(fn_stats_data) as f_stats:
    DATA_DICT = json.load(f_stats)


In [26]:
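# Disabled sanity check: for each experiment, compare field shapes (and f0dl values
# clipped at 100) of DATA_DICT['IHC3000Hz-<experiment>'] against list_dict_super[0].
# for ek in experiment_keys: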
#     print('____________________________', ek, '____________________________')
#     X = DATA_DICT['IHC3000Hz-{}'.format(ek)]
#     print('### OLD')
#     for k in sorted(X.keys()):
#         print(k, np.array(X[k]).shape)
#     print('### NEW')
#     for k in sorted(list_dict_super[0][ek].keys()):
#         print(k, np.array(list_dict_super[0][ek][k]).shape)
    
#     if 'f0dl' in X:
#         f0dl_OLD = np.array(X['f0dl'])
#         f0dl_NEW = np.array(list_dict_super[0][ek]['f0dl'])
        
#         f0dl_OLD[f0dl_OLD > 100.0] = 100.0
#         f0dl_NEW[f0dl_NEW > 100.0] = 100.0
#         print('COMPARING f0dl_NEW and f0dl_OLD')
#         print(np.max(np.abs(f0dl_OLD - f0dl_NEW)))
#         print(np.max(f0dl_OLD), np.max(f0dl_NEW), np.min(f0dl_NEW))
