In [1]:
import os
import sys
import copy
import json
import glob
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

sys.path.append('../assets_psychophysics')
import util_human_model_comparison
import util_figures_psychophysics

sys.path.append('/packages/msutil')
import util_misc


def load_results_dict(results_dict_fn, pop_key_list=['psychometric_function']):
    """
    Read a JSON results file and return its contents minus unwanted keys.

    Args
    ----
    results_dict_fn (str): path to a JSON file holding a dictionary
    pop_key_list (list): top-level keys to discard after loading; keys that
        are not present in the file are silently skipped

    Returns
    -------
    results_dict (dict): the loaded dictionary without the discarded keys
    """
    with open(results_dict_fn) as f:
        loaded = json.load(f)
    for unwanted_key in pop_key_list:
        if unwanted_key not in loaded.keys():
            continue
        del loaded[unwanted_key]
    return loaded


def calc_best_metric(valid_metrics_fn, metric_key='f0_label:accuracy', maximize=True):
    """
    Return the best value recorded for one metric in a validation metrics file.

    Args
    ----
    valid_metrics_fn (str): path to a JSON file mapping metric names to values
    metric_key (str): metric to look up; if it is not an exact key, the first
        key containing every ':'-separated part of `metric_key` is used instead
    maximize (bool): if True return the maximum value, otherwise the minimum

    Returns
    -------
    best_metric_value: np.max / np.min of the stored values, or None when
        `valid_metrics_fn` does not exist

    Raises
    ------
    KeyError: if neither `metric_key` nor a similarly named key is in the file
    """
    if not os.path.exists(valid_metrics_fn):
        return None
    with open(valid_metrics_fn) as f:
        metrics = json.load(f)
    if metric_key not in metrics:
        # No exact match: fall back to the first key that contains all parts
        # of the requested name (the requested name is kept if none qualifies)
        required_parts = metric_key.split(':')
        metric_key = next(
            (
                candidate for candidate in metrics.keys()
                if all(part in candidate for part in required_parts)
            ),
            metric_key,
        )
    reducer = np.max if maximize else np.min
    return reducer(metrics[metric_key])


def flatten_dict(d, parent_key=None, sep='/'):
    """
    Collapse a nested dictionary into a single-level dictionary.

    Nested keys are joined with `sep` (e.g. {'a': {'b': 1}} -> {'a/b': 1}).
    Keys are visited in sorted order at every level, and empty nested
    dictionaries contribute no entries.

    Args
    ----
    d (dict): possibly nested dictionary with string keys
    parent_key (str or None): prefix prepended (with `sep`) to every key
    sep (str): separator placed between nesting levels

    Returns
    -------
    d_flat (dict): flat dictionary mapping joined key paths to leaf values
    """
    flattened = {}
    for name in sorted(d):
        path = name if parent_key is None else parent_key + sep + name
        value = d[name]
        if not isinstance(value, dict):
            flattened[path] = value
            continue
        # Recurse into the sub-dictionary, carrying the path built so far
        nested = flatten_dict(value, parent_key=path, sep=sep)
        for nested_path, nested_value in nested.items():
            flattened[nested_path] = nested_value
    return flattened


def all_equal(iterator):
    """
    Check whether every element of an iterable equals the first element.

    Elements are compared with `np.array_equal`, so lists and arrays are
    compared by shape and content.

    Args
    ----
    iterator (iterable): elements to compare

    Returns
    -------
    bool: True if all elements are equal to the first one. An empty iterable
        returns True (previously this leaked a StopIteration, which turns
        into a RuntimeError when raised inside a generator).
    """
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        # Nothing to compare: vacuously all equal
        return True
    return all(np.array_equal(first, rest) for rest in iterator)


def concatenate_dicts(list_d):
    """
    Merge a list of (possibly nested) dictionaries into one flat dictionary.

    Each dictionary is flattened with `flatten_dict`. The keys of the first
    flattened dictionary determine the output keys; only keys whose value in
    the first dictionary is a list, int, float, or np.ndarray are kept. For a
    kept key, the single shared value is stored if it is identical across all
    dictionaries; otherwise the values are stacked along a new first axis.

    Args
    ----
    list_d (list): dictionaries to merge

    Returns
    -------
    d_concatenated (dict): flat merged dictionary (empty if `list_d` is empty)
    """
    flat_dicts = [flatten_dict(d) for d in list_d]
    merged = {}
    if not flat_dicts:
        return merged
    reference = flat_dicts[0]
    mergeable_types = (list, int, float, np.ndarray)
    for key in sorted(reference):
        if not isinstance(reference[key], mergeable_types):
            continue
        values = [flat[key] for flat in flat_dicts]
        merged[key] = values[0] if all_equal(values) else np.stack(values, axis=0)
    return merged


In [5]:
# Basename of the results JSON file for each experiment key. These basenames
# are joined onto each model directory when the per-model results are loaded.
experiment_to_basename_map = {
    'bernox2005': 'EVAL_SOFTMAX_lowharm_v01_bestckpt_results_dict.json',
    'transposedtones': 'EVAL_SOFTMAX_transposedtones_v01_bestckpt_results_dict.json',
    'freqshiftedcomplexes': 'EVAL_SOFTMAX_freqshifted_v01_bestckpt_results_dict.json',
    'mistunedharmonics': 'EVAL_SOFTMAX_mistunedharm_v01_bestckpt_results_dict.json',
    'altphasecomplexes': 'EVAL_SOFTMAX_altphase_v01_bestckpt_results_dict.json',
    'f0dlspl': 'EVAL_SOFTMAX_testspl_v03_bestckpt_results_dict.json',
}

# Human results for each experiment. The getter functions are called once,
# when this cell runs. There is no entry for 'f0dlspl'.
experiment_to_human_results_map = {
    'bernox2005': util_human_model_comparison.get_human_results_dict_bernox2005(),
    'transposedtones': util_human_model_comparison.get_human_results_dict_transposedtones(),
    'freqshiftedcomplexes': util_human_model_comparison.get_human_results_dict_freqshiftedcomplexes(),
    'mistunedharmonics': util_human_model_comparison.get_human_results_dict_mistunedharmonics(),
    'altphasecomplexes': util_human_model_comparison.get_human_results_dict_altphasecomplexes(),
}

# Human-model comparison function for each experiment. Each function is
# called as compfunc(human_results, model_results_dict, **kwargs) and its
# return value is unpacked into two values (r, p). No entry for 'f0dlspl'.
experiment_to_compfunc_map = {
    'bernox2005': util_human_model_comparison.compare_bernox2005,
    'transposedtones': util_human_model_comparison.compare_transposedtones,
    'freqshiftedcomplexes': util_human_model_comparison.compare_freqshiftedcomplexes,
    'mistunedharmonics': util_human_model_comparison.compare_mistunedharmonics,
    'altphasecomplexes': util_human_model_comparison.compare_altphasecomplexes_hist,
}

# Extra keyword arguments forwarded to each comparison function (currently
# none for any experiment).
experiment_to_compfunc_kwargs_map = {
    'bernox2005': {},
    'transposedtones': {},
    'freqshiftedcomplexes': {},
    'mistunedharmonics': {},
    'altphasecomplexes': {},
}

# Experiments to process. 'f0dlspl' has a results basename but no human
# results or comparison function, so code using the comparison maps must
# check for membership before indexing them with this key.
experiment_keys = [
    'bernox2005',
    'altphasecomplexes',
    'freqshiftedcomplexes',
    'mistunedharmonics',
    'transposedtones',
    'f0dlspl',
]


In [6]:
# List of (tag, name) pairs, one per model condition.
#   tag:  directory name substituted into the model directory glob pattern
#   name: short label used as the key prefix when results are stored
# Several tags appear under more than one name (e.g. the
# 'sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order' tag is listed
# as 'IHC3000Hz', 'BW10eN1', and 'natural'), so those directories are loaded
# once per listing.
# NOTE(review): the group descriptions below are inferred from the tag / name
# strings only — confirm against the experiment design.
list_model_tag_and_name = [    
    # Presumably varies the "IHC...Hz" setting (names IHC0050Hz ... IHC9000Hz)
    ('REDOsr2000_cf1000_species002_spont070_BW10eN1_IHC0050Hz_IHC7order', 'IHC0050Hz'),
    ('REDOsr20000_cf100_species002_spont070_BW10eN1_IHC0320Hz_IHC7order', 'IHC0320Hz'),
    ('REDOsr20000_cf100_species002_spont070_BW10eN1_IHC1000Hz_IHC7order', 'IHC1000Hz'),
    ('sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order', 'IHC3000Hz'),
    ('REDOsr20000_cf100_species002_spont070_BW10eN1_IHC6000Hz_IHC7order', 'IHC6000Hz'),
    ('REDOsr20000_cf100_species002_spont070_BW10eN1_IHC9000Hz_IHC7order', 'IHC9000Hz'),
    
    # Presumably varies the "BW..." setting (names BW05eN1 ... BWlinear)
    ('sr20000_cf100_species002_spont070_BW05eN1_IHC3000Hz_IHC7order', 'BW05eN1'),
    ('sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order', 'BW10eN1'),
    ('sr20000_cf100_species002_spont070_BW20eN1_IHC3000Hz_IHC7order', 'BW20eN1'),
    ('sr20000_cf100_species004_spont070_BWlinear_IHC3000Hz_IHC7order', 'BWlinear'),
    
    # "cochlearn" conditions
    ('cochlearn', 'cochlearn'),
    ('cochlearn_PND_v08inst_noise_TLAS_snr_neg10pos10', 'cochlearn_inst_only'),
    ('cochlearn_PND_v08spch_noise_TLAS_snr_neg10pos10', 'cochlearn_spch_only'),

    # Presumably varies the training sound set (names natural, inst_only, noise_low, ...)
    ('sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order', 'natural'),
    ('PND_mfcc_PNDv08PYSmatched12_TLASmatched12_snr_neg10pos10_phase3', 'natural_matched'),
    ('PND_mfcc_PNDv08PYSnegated12_TLASmatched12_snr_neg10pos10_phase3', 'natural_antimatched'),
    ('PND_v08inst_noise_TLAS_snr_neg10pos10', 'inst_only'),
    ('PND_v08spch_noise_TLAS_snr_neg10pos10', 'spch_only'),
    ('PND_v08_noise_TLAS_snr_neg10pos10_filter_signalHPv00', 'natural_hp'),
    ('PND_v08_noise_TLAS_snr_neg10pos10_filter_signalLPv01', 'natural_lp'),
    ('PND_v08_noise_TLAS_snr_pos10pos30', 'noise_low'),
    ('PND_v08_noise_TLAS_snr_posInf', 'noise_none'),
    
    # "f0_label_NNN" conditions
    ('f0_label_024', 'f0_label_024'),
    ('f0_label_048', 'f0_label_048'),
    ('f0_label_096', 'f0_label_096'),
    ('f0_label_192', 'f0_label_192'),
    ('f0_label_384', 'f0_label_384'),
    
    # Presumably varies the "cf..." setting at IHC0050Hz (names cf0100 ... cf1000)
    ('REDOsr2000_cfI100_species002_spont070_BW10eN1_IHC0050Hz_IHC7order', 'cf0100_IHC0050Hz'),
    ('REDOsr2000_cfI250_species002_spont070_BW10eN1_IHC0050Hz_IHC7order', 'cf0250_IHC0050Hz'),
    ('REDOsr2000_cfI500_species002_spont070_BW10eN1_IHC0050Hz_IHC7order', 'cf0500_IHC0050Hz'),
    ('REDOsr2000_cf1000_species002_spont070_BW10eN1_IHC0050Hz_IHC7order', 'cf1000_IHC0050Hz'),
    
    # "spont1eN1" condition
    ('sr20000_cf100_species002_spont1eN1_BW10eN1_IHC3000Hz_IHC7order', 'low_spont_rate'),
]

# One glob pattern per (tag, name) entry; only the tag selects directories,
# so entries that share a tag produce identical patterns.
list_regex_model_dir = [
    '/saved_models/arch_search_v02_topN/{}/arch_0???/'.format(tag)
    for (tag, _) in list_model_tag_and_name
]
basename_valid_metrics = 'validation_metrics.json'

# Compile list of lists of model psychophysical data. Both lists end up with
# one entry per pattern, in the same order as list_model_tag_and_name.
list_list_model_dir = []
list_dict_super = []

# For each entry in list_regex_model_dir, grab all of the models that are globbed by the regex
for regex_model_dir in list_regex_model_dir:
    list_model_dir = []
    list_valid_metric = []
    dict_results_dicts = {ek: [] for ek in experiment_keys}
    dict_human_model_comparison = {
        ek: {
            'human_model_similarity_pval': [],
            'human_model_similarity_coef': [],
        }
        for ek in experiment_keys
    }
    for model_dir in sorted(glob.glob(regex_model_dir)):
        fn_valid_metric = os.path.join(model_dir, basename_valid_metrics)
        fn_result_dict = {
            ek: os.path.join(model_dir, experiment_to_basename_map[ek]) for ek in experiment_keys
        }
        if 'snr_pos' in model_dir:
            # Model directories whose path contains 'snr_pos' read different
            # results files for these two experiments
            high_snr_basename = 'EVAL_SOFTMAX_lowharm_v04_bestckpt_results_dict.json'
            fn_result_dict['bernox2005'] = os.path.join(model_dir, high_snr_basename)
            high_snr_basename = 'EVAL_SOFTMAX_transposedtones_v02_bestckpt_results_dict.json'
            fn_result_dict['transposedtones'] = os.path.join(model_dir, high_snr_basename)

        # A model is included only if every results file exists, except that
        # experiments with 'spl' in their key are optional
        include_model_flag = all(
            os.path.exists(fn_result_dict[ek])
            for ek in experiment_keys
            if 'spl' not in ek
        )
        if not include_model_flag:
            continue

        list_valid_metric.append(calc_best_metric(fn_valid_metric))
        list_model_dir.append(model_dir)
        # Load results_dict for each model and experiment
        for ek, results_dict_fn in fn_result_dict.items():
            # NOTE(review): an optional ('spl') file that is missing for only
            # some models leaves that experiment with fewer rows than
            # 'validation_accuracy' has entries
            if not os.path.exists(results_dict_fn):
                continue
            results_dict = load_results_dict(results_dict_fn)
            dict_results_dicts[ek].append(results_dict)

            # Measure human-model similarity for each model and experiment
            # (skipped for experiments without a comparison function)
            if ek in experiment_to_compfunc_map:
                compfunc = experiment_to_compfunc_map[ek]
                compfunc_kwargs = experiment_to_compfunc_kwargs_map[ek]
                r, p = compfunc(
                    experiment_to_human_results_map[ek],
                    results_dict,
                    **compfunc_kwargs)
                dict_human_model_comparison[ek]['human_model_similarity_coef'].append(r)
                dict_human_model_comparison[ek]['human_model_similarity_pval'].append(p)

    # Merge the per-model results of each experiment and attach the
    # similarity statistics and validation metrics collected above
    dict_super = {}
    for ek in experiment_keys:
        dict_super[ek] = concatenate_dicts(dict_results_dicts[ek])
        dict_super[ek]['human_model_similarity_coef'] = dict_human_model_comparison[ek]['human_model_similarity_coef']
        dict_super[ek]['human_model_similarity_pval'] = dict_human_model_comparison[ek]['human_model_similarity_pval']
        dict_super[ek]['validation_accuracy'] = list_valid_metric

    # Add lists of model results to the master list
    list_list_model_dir.append(list_model_dir)
    list_dict_super.append(dict_super)

    print(regex_model_dir, len(list_model_dir))


/saved_models/arch_search_v02_topN/REDOsr2000_cf1000_species002_spont070_BW10eN1_IHC0050Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC0320Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC1000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC6000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/REDOsr20000_cf100_species002_spont070_BW10eN1_IHC9000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW05eN1_IHC3000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0???/ 10
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW20eN1_IHC3000Hz_IHC7o

In [8]:
# Collect every model condition's per-experiment results in one dictionary
# keyed by '<name>-<experiment key>'.
DATA_DICT_TO_STORE = {}

for (tag, name), dict_super in zip(list_model_tag_and_name, list_dict_super):
    DATA_DICT_TO_STORE.update(
        {name + '-' + ek: dict_super[ek] for ek in experiment_keys}
    )

# Serialize everything to a single JSON file in the working directory;
# util_misc.NumpyEncoder is passed as the JSON encoder class.
fn_stats = 'pitchnet_paper_stats_data_psychophysics_2021AUG05.json'
print('[START] {}'.format(fn_stats))
with open(fn_stats, 'w') as f:
    json.dump(
        DATA_DICT_TO_STORE,
        f,
        cls=util_misc.NumpyEncoder,
        sort_keys=True,
    )
print('[END] {}'.format(fn_stats))


[START] pitchnet_paper_stats_data_psychophysics_2021AUG05.json
[END] pitchnet_paper_stats_data_psychophysics_2021AUG05.json


In [None]:
# for ek in experiment_keys:
#     print('_'*24, ek, '_'*24)
#     if 'mistuned' not in ek:
#         for k in sorted(list_dict_super[0][ek].keys()):
#             print(k, np.array(list_dict_super[0][ek][k]).shape)


In [None]:
# for k in sorted(DATA_DICT_TO_STORE.keys()):
#     print(k)
# len(DATA_DICT_TO_STORE)


In [None]:
# with open('pitchnet_paper_stats_data_psychophysics_2020AUG09.json', 'r') as f:
#     DATA_DICT = json.load(f)

# for k in sorted(DATA_DICT.keys()):
#     print(k)
    

In [None]:
# for ek in experiment_keys:
#     print('____________________________', ek, '____________________________')
#     X = DATA_DICT['IHC3000Hz-{}'.format(ek)]
#     print('### OLD')
#     for k in sorted(X.keys()):
#         print(k, np.array(X[k]).shape)
#     print('### NEW')
#     for k in sorted(list_dict_super[0][ek].keys()):
#         print(k, np.array(list_dict_super[0][ek][k]).shape)
    
#     if 'f0dl' in X:
#         f0dl_OLD = np.array(X['f0dl'])
#         f0dl_NEW = np.array(list_dict_super[0][ek]['f0dl'])
        
#         f0dl_OLD[f0dl_OLD > 100.0] = 100.0
#         f0dl_NEW[f0dl_NEW > 100.0] = 100.0
#         print('COMPARING f0dl_NEW and f0dl_OLD')
#         print(np.max(np.abs(f0dl_OLD - f0dl_NEW)))
#         print(np.max(f0dl_OLD), np.max(f0dl_NEW), np.min(f0dl_NEW))


In [3]:
# Quick script for copying model checkpoint and results files to new destination for public release

import os
import sys
import glob
import json
import shutil
import numpy as np

sys.path.append('../')
import pitchnet_evaluate_best


# (source tag, destination tag) pairs: the source tag selects the model
# directories to copy and is replaced by the destination tag in the output path.
list_tag_src_dst = [
    ('sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order', 'default'),
]

# Basenames to copy from each model directory. A '{}' placeholder is filled
# with the selected checkpoint number at copy time.
pre_list_basename_to_copy = [
    'config.json',
    'EVAL_SOFTMAX_lowharm_v01_bestckpt_results_dict.json',
    'EVAL_SOFTMAX_transposedtones_v01_bestckpt_results_dict.json',
    'EVAL_SOFTMAX_freqshifted_v01_bestckpt_results_dict.json',
    'EVAL_SOFTMAX_mistunedharm_v01_bestckpt_results_dict.json',
    'EVAL_SOFTMAX_altphase_v01_bestckpt_results_dict.json',
    'EVAL_SOFTMAX_testspl_v03_bestckpt_results_dict.json',
    'EVAL_validation_bestckpt.json',
    'EVAL_validation_bestckpt_results_dict.json',
    'validation_metrics.json',
    'brain_arch.json',
    'brain_model.ckpt-{}.index',
    'brain_model.ckpt-{}.data-00000-of-00001',
]

# Expand every '*_results_dict.json' entry with its two companion files (the
# same name without '_results_dict', and '*_f0_label_probs_out.npy').
# Duplicates are skipped so that no file is copied (or warned about) twice:
# 'EVAL_validation_bestckpt.json' is both listed explicitly and derived from
# 'EVAL_validation_bestckpt_results_dict.json'.
list_basename_to_copy = []
for bnc in pre_list_basename_to_copy:
    candidates = [bnc]
    if '_results_dict' in bnc:
        candidates.append(bnc.replace('_results_dict', ''))
        candidates.append(bnc.replace('_results_dict.json', '_f0_label_probs_out.npy'))
    for candidate in candidates:
        if candidate not in list_basename_to_copy:
            list_basename_to_copy.append(candidate)


# Glob pattern for the source model directories; '{}' is filled with the source tag
pattern_regex_dir_src = '/saved_models/arch_search_v02_topN/{}/arch_0???/'

for tag_src, tag_dst in list_tag_src_dst:
    # Sorted so that directories are processed (and logged) in a stable order;
    # glob.glob alone returns matches in arbitrary order
    list_dir_src = sorted(glob.glob(pattern_regex_dir_src.format(tag_src)))

    for dir_src in list_dir_src:
        # Destination mirrors the source layout under the release root, with
        # the source tag replaced by the destination tag
        dir_dst = dir_src.replace(tag_src, tag_dst)
        dir_dst = dir_dst.replace('/saved_models/arch_search_v02_topN/', '')
        dir_dst = os.path.join('/om2/user/msaddler/pitchnet/models', dir_dst)

        os.makedirs(dir_dst, exist_ok=True)

        ckpt_num = pitchnet_evaluate_best.get_best_checkpoint_number(
            os.path.join(dir_src, 'validation_metrics.json'),
            metric_key='f0_label:accuracy',
            maximize=True,
            checkpoint_number_key='step')

        for basename in list_basename_to_copy:
            # Fill the checkpoint number into the basename only, so braces
            # elsewhere in the directory path can never be misinterpreted
            if '{}' in basename:
                basename = basename.format(ckpt_num)
            fn_src = os.path.join(dir_src, basename)
            fn_dst = os.path.join(dir_dst, basename)

            # Missing files are reported but do not stop the copy
            if os.path.exists(fn_src):
                shutil.copyfile(fn_src, fn_dst)
            else:
                print('[WARNING]: no file `{}` found'.format(fn_src))

        print(dir_src)
        print(dir_dst)
        print('\n')


Selecting checkpoint 70000 (f0_label:accuracy=0.24030879139900208)
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0083/
/om2/user/msaddler/pitchnet/models/default/arch_0083/


Selecting checkpoint 55000 (f0_label:accuracy=0.20737770199775696)
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0154/
/om2/user/msaddler/pitchnet/models/default/arch_0154/


Selecting checkpoint 45000 (f0_label:accuracy=0.21219712495803833)
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0190/
/om2/user/msaddler/pitchnet/models/default/arch_0190/


Selecting checkpoint 60000 (f0_label:accuracy=0.24883323907852173)
/saved_models/arch_search_v02_topN/sr20000_cf100_species002_spont070_BW10eN1_IHC3000Hz_IHC7order/arch_0191/
/om2/user/msaddler/pitchnet/models/default/arch_0191/


Selecting checkpoint 60000 (f0_label:accuracy=0.23146286606788635)
/saved_mo