In [None]:
import sys
import os
import json
import glob
import h5py

import numpy as np
import scipy.signal
import scipy.io.wavfile
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd

import importlib

sys.path.append('/packages/bez2018model')
import bez2018model
importlib.reload(bez2018model)

sys.path.append('/packages/msutil')
import util_figures
importlib.reload(util_figures)
import util_stimuli
importlib.reload(util_stimuli)
import util_misc
importlib.reload(util_misc)

# Output directory for figure PDFs. In the cells shown here it is only
# referenced by commented-out fig.savefig(...) lines, so nothing is written
# unless those lines are re-enabled.
save_dir = '/om2/user/msaddler/pitchnet/assets_psychophysics/figures/archive_2020_09_26_pitchnet_paper_figures_v03/'


In [None]:
# fn = '/om/scratch/Mon/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/PND_sr32000_v08_2079000-2100000.hdf5'
# fn = '/om/scratch/Mon/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/sr2000_cf1000_species002_spont070_BW10eN1_IHC0050Hz_IHC7order/bez2018meanrates_000_000000-003500.hdf5'
# key_list = util_misc.get_hdf5_dataset_key_list(fn)

# f = h5py.File(fn, 'r')
# for key in key_list:
#     print(key, f[key])
# f.close()


In [None]:
# Example stimulus: a 150 ms harmonic complex tone (harmonics 1-999, sine
# phase) plus modified uniform masking noise. The tone is set to 60 dB SPL and
# combined with the noise at SNR = np.inf.
random_seed = 858
np.random.seed(random_seed)

sr, dur, f0 = 32000, 0.150, 200
phase_mode = 'sine'

tone_kwargs = dict(harmonic_numbers=np.arange(1, 1000),
                   phase_mode=phase_mode,
                   offset_start=False,
                   strict_nyquist=False)
noise_kwargs = dict(dBHzSPL=15.0,
                    attenuation_start=600.0,
                    attenuation_slope=2.0)

signal = util_stimuli.complex_tone(f0, sr, dur, **tone_kwargs)
noise = util_stimuli.modified_uniform_masking_noise(sr, dur, **noise_kwargs)
signal = util_stimuli.set_dBSPL(signal, 60.0)
y = util_stimuli.combine_signal_and_noise(signal, noise, np.inf)


# Alternative stimulus sources, kept for reference:

# fn = '/om/user/msaddler/data_pitchnet/oxenham2004/transposedtones_v01/stim.hdf5'
# with h5py.File(fn, 'r') as f:
#     f0_list = np.unique(f['f0'][:])
#     arg = np.argmin(np.abs(f0_list - 200))      
#     IDX = np.argwhere(np.logical_and(f['f0'][:] == f0_list[arg], f['f_carrier'][:] == 0.0))[0][0]
#     y = f['stimuli/signal'][IDX]
#     sr = f['sr'][0]

# fn = '/om/user/msaddler/data_pitchnet/mooremoore2003/freqshifted_v01/stim.hdf5'
# with h5py.File(fn, 'r') as f:
#     f0_list = np.unique(f['f0'][:])
#     arg = np.argmin(np.abs(f0_list - 200))  
#     print(np.unique(f['spectral_envelope_centered_harmonic'][:]))
#     print(np.unique(f['f0_shift'][:]))
    
#     IDX = np.argwhere(
#         np.logical_and(
#             f['f0'][:] == f0_list[arg],
#             np.logical_and(
#                 f['spectral_envelope_centered_harmonic'][:] == 16,
#                 f['f0_shift'][:] == 0.24
#             )
#         )
#     )[0][0]
#     print(IDX)
#     y = f['stimuli/signal'][IDX]
#     sr = f['sr'][0]

# fn = '/om/scratch/Sun/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/PND_sr32000_v08_0000000-0021000.hdf5'
# f = h5py.File(fn, 'r')
# IDX = 14184#np.random.randint(f['stimuli/signal'].shape[0])
# signal = f['stimuli/signal'][IDX]
# noise = f['stimuli/noise'][IDX]
# f0 = f['nopad_f0_mean'][IDX]
# print(IDX, f0)
# sr = f['sr'][0]
# f.close()

# snr = np.inf
# dBSPL = 60
# noise = np.random.randn(noise.shape[0])

# y = util_stimuli.combine_signal_and_noise(signal, noise, snr)
# y = util_stimuli.set_dBSPL(y, dBSPL)

ipd.display(ipd.Audio(y, rate=sr))


In [None]:
np.random.seed(random_seed)

# Alternative configuration (IHC cutoff / bandwidth sweep), kept for reference:
# IHC_cutoff = 3000.0
# BW_scale_factor = 1.0/4
# kwargs_nervegram = {
#     'nervegram_dur': 0.050,
#     'nervegram_fs': 20000,
#     'buffer_start_dur': 0.070,
#     'buffer_end_dur': 0.010,
#     'pin_fs': 100e3,
#     'pin_dBSPL_flag': 0,
#     'pin_dBSPL': None,
#     'species': 2,
#     'bandwidth_scale_factor': BW_scale_factor,
#     'cf_list': None,
#     'num_cf': 100,
#     'min_cf': 125.0,
#     'max_cf': 14e3,
#     'max_spikes_per_train': 500,
#     'num_spike_trains': 1,
#     'cohc': 1.0,
#     'cihc': 1.0,
#     'IhcLowPass_cutoff': IHC_cutoff,
#     'IhcLowPass_order': 7,
#     'spont': 70.0,
#     'noiseType': 0,
#     'implnt': 0,
#     'tabs': 6e-4,
#     'trel': 6e-4,
#     'random_seed': None,
#     'return_vihcs': False,
#     'return_meanrates': True,
#     'return_spike_times': False,
#     'return_spike_tensor': False,
# }

# Auditory-nerve model settings for the example nervegram.
# NOTE(review): 'spont' is 0.1 here but 70.0 in the other cells, and this cell
# passes 'return_spike_tensor_sparse' / 'return_spike_tensor_dense' whereas the
# other cells pass 'return_spike_tensor' -- confirm which keywords the
# installed bez2018model.nervegram accepts.
kwargs_nervegram = dict(
    nervegram_dur=0.050,
    nervegram_fs=20000,
    buffer_start_dur=0.070,
    buffer_end_dur=0.010,
    pin_fs=100e3,
    pin_dBSPL_flag=0,
    pin_dBSPL=None,
    species=2,  # alt: 4
    bandwidth_scale_factor=1.0,  # alt: [80.0] * 100
    cf_list=None,  # alt: np.arange(125.0, 8125.0, 80)
    num_cf=100,
    min_cf=125.0,
    max_cf=14e3,
    max_spikes_per_train=500,
    num_spike_trains=1,
    cohc=1.0,
    cihc=1.0,
    IhcLowPass_cutoff=3000,
    IhcLowPass_order=7,
    spont=0.1,  # alt: 70.0
    noiseType=0,
    implnt=0,
    tabs=6e-4,
    trel=6e-4,
    random_seed=None,
    return_vihcs=False,
    return_meanrates=True,
    return_spike_times=False,
    return_spike_tensor_sparse=False,
    return_spike_tensor_dense=False,
)
out_dict = bez2018model.nervegram(y, sr, **kwargs_nervegram)
out_dict['nervegram_meanrates'].shape

# out_dict['cf_list'], out_dict['bandwidth_scale_factor']


In [None]:
import util_figures
importlib.reload(util_figures)

# 2x3 panel grid: narrow | wide | narrow columns, short top row
figsize = (6, 3.5)
nrows, ncols = 2, 3
gridspec_kw = dict(
    wspace=0.15,
    hspace=0.15,
    width_ratios=[1, 6, 1],
    height_ratios=[1, 4],
)
fig, ax_arr = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, gridspec_kw=gridspec_kw)

# Drop singleton dimensions; work on a copy so out_dict stays untouched
nervegram = np.squeeze(out_dict['nervegram_meanrates']).copy()
# Optional per-channel flattening of the nervegram, kept for reference:
# print(nervegram.max(), nervegram.min(), nervegram.mean())
# mean_exc = np.mean(nervegram, axis=1)
# mean_nervegram = np.mean(nervegram)
# NZIDX = mean_exc > 0
# nervegram[NZIDX] = nervegram[NZIDX] / np.expand_dims(mean_exc[NZIDX], axis=1)
# nervegram[NZIDX] = mean_nervegram * nervegram[NZIDX]

util_figures.make_stimulus_summary_plot(
    ax_arr,
    ax_idx_waveform=1,
    ax_idx_spectrum=3,
    ax_idx_nervegram=4,
    ax_idx_excitation=5,
    waveform=out_dict['signal'],
    nervegram=nervegram,
    sr_waveform=out_dict['signal_fs'],
    sr_nervegram=out_dict['nervegram_fs'],
    cfs=out_dict['cf_list'],
    tmin=None,  # alt: 0
    tmax=None,  # alt: 0.015
    treset=True,
    vmin=None,
    vmax=None,
    erb_freq_axis=True,
    spines_to_hide_waveform=[],
    spines_to_hide_spectrum=[],
    spines_to_hide_excitation=[],
    nxticks=6,
    nyticks=6,
    kwargs_plot={},
    limits_buffer=0.2,
    ax_arr_clear_leftover=True,
)

# Fixed x-range and ticks for the bottom-right panel (flat index 5)
xticks = [0, 50]
ax_bottom_right = ax_arr[1, 2]
ax_bottom_right.set_xlim(xticks)
ax_bottom_right.set_xticks(xticks)
ax_bottom_right.set_xticklabels(xticks)

plt.show()


In [None]:
# import util_figures
# importlib.reload(util_figures)

# figsize=(6, 2.6) #(6, 3.5)
# nrows=1 #2
# ncols=3
# gridspec_kw = {
#     'wspace': 0.15,
#     'hspace': 0.15,
#     'width_ratios': [1, 6, 1],
# #         'height_ratios': [1, 4],
# }
# fig, ax_arr = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, gridspec_kw=gridspec_kw)

# util_figures.make_stimulus_summary_plot(ax_arr,
#                                         ax_idx_waveform=None, #1,
#                                         ax_idx_spectrum=0, #3,
#                                         ax_idx_nervegram=1, #4,
#                                         ax_idx_excitation=2, #5,
#                                         waveform=out_dict['signal'],
#                                         nervegram=out_dict['nervegram_meanrates'],
#                                         sr_waveform=out_dict['signal_fs'],
#                                         sr_nervegram=out_dict['nervegram_fs'],
#                                         cfs=out_dict['cf_list'],
#                                         tmin=None,
#                                         tmax=None,
#                                         treset=True,
#                                         vmin=None,
#                                         vmax=None,
#                                         erb_freq_axis=False,
#                                         spines_to_hide_waveform=[],
#                                         spines_to_hide_spectrum=[],
#                                         spines_to_hide_excitation=[],
#                                         nxticks=6,
#                                         nyticks=6,
#                                         kwargs_plot={},
#                                         limits_buffer=0.2,
#                                         ax_arr_clear_leftover=True)

# xticks = [0, 200]
# ax_arr = ax_arr.reshape([-1])
# ax_arr[-1].set_xlim(xticks)
# ax_arr[-1].set_xticks(xticks)
# ax_arr[-1].set_xticklabels(xticks)

# plt.show()


# save_fn = os.path.join(save_dir, 'nervegram_full_pitch{:03.0f}Hz_spont1eN1_BW{:02.0f}eN1_IHC{:04.0f}Hz_flat_exc_mean.pdf'.format(f0, 10*BW_scale_factor, IHC_cutoff))
# save_fn = os.path.join(save_dir, 'nervegram_full_pitch{:03.0f}Hz_spont070_BW{:02.0f}eN1_IHC{:04.0f}Hz.pdf'.format(f0, 10*BW_scale_factor, IHC_cutoff))
# save_fn = os.path.join(save_dir, 'nervegram_partial_pitch{:03.0f}Hz_species004_spont070_BWlinear_IHC3000Hz.pdf'.format(f0))

# print(save_fn)
# fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)

# fig.savefig('tmp.pdf', bbox_inches='tight', pad_inches=0, transparent=False)



In [None]:
# for BW_scale_factor in [0.5, 1.0, 2.0]:
import util_figures
importlib.reload(util_figures)

# Sweep the IHC lowpass cutoff. Everything except 'IhcLowPass_cutoff' is
# loop-invariant, so the model settings and figure layout are built once
# instead of on every iteration.
BW_scale_factor = 1.0
kwargs_nervegram_base = {
    'nervegram_dur': 0.050,
    'nervegram_fs': 20000,
    'buffer_start_dur': 0.070,
    'buffer_end_dur': 0.010,
    'pin_fs': 100e3,
    'pin_dBSPL_flag': 0,
    'pin_dBSPL': None,
    'species': 2,
    'bandwidth_scale_factor': BW_scale_factor,
    'cf_list': None,
    'num_cf': 100,
    'min_cf': 125.0,
    'max_cf': 14e3,
    'max_spikes_per_train': 500,
    'num_spike_trains': 1,
    'cohc': 1.0,
    'cihc': 1.0,
    'IhcLowPass_cutoff': None,  # set per iteration below
    'IhcLowPass_order': 7,
    'spont': 70.0,
    'noiseType': 0,
    'implnt': 0,
    'tabs': 6e-4,
    'trel': 6e-4,
    'random_seed': None,
    'return_vihcs': False,
    'return_meanrates': True,
    'return_spike_times': False,
    'return_spike_tensor': False,
}

# One row of panels: spectrum | nervegram | excitation pattern
figsize = (6, 2.6)  # (6, 3.5) when a waveform row is included
nrows = 1
ncols = 3
gridspec_kw = {
    'wspace': 0.15,
    'hspace': 0.15,
    'width_ratios': [1, 6, 1],
}

for IHC_cutoff in [50, 250, 320, 1000, 3000, 6000, 9000]:
    # Reseed so every cutoff starts from the same RNG state
    np.random.seed(random_seed)
    kwargs_nervegram = {**kwargs_nervegram_base, 'IhcLowPass_cutoff': IHC_cutoff}
    out_dict = bez2018model.nervegram(y, sr, **kwargs_nervegram)

    fig, ax_arr = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, gridspec_kw=gridspec_kw)
    util_figures.make_stimulus_summary_plot(ax_arr,
                                            ax_idx_waveform=None,
                                            ax_idx_spectrum=0,
                                            ax_idx_nervegram=1,
                                            ax_idx_excitation=2,
                                            waveform=out_dict['signal'],
                                            nervegram=out_dict['nervegram_meanrates'],
                                            sr_waveform=out_dict['signal_fs'],
                                            sr_nervegram=out_dict['nervegram_fs'],
                                            cfs=out_dict['cf_list'],
                                            tmin=None,
                                            tmax=None,
                                            treset=True,
                                            vmin=None,
                                            vmax=None,
                                            spines_to_hide_waveform=[],
                                            spines_to_hide_spectrum=[],
                                            spines_to_hide_excitation=[],
                                            nxticks=6,
                                            nyticks=6,
                                            kwargs_plot={},
                                            limits_buffer=0.2,
                                            ax_arr_clear_leftover=True)

    # Fixed x-range and ticks for the last (right-most) panel
    xticks = [0, 200]
    ax_arr = ax_arr.reshape([-1])
    ax_arr[-1].set_xlim(xticks)
    ax_arr[-1].set_xticks(xticks)
    ax_arr[-1].set_xticklabels(xticks)

    plt.show()

    # save_fn = os.path.join(save_dir, 'nervegram_partial_pitch{:03.0f}Hz_spont070_BW{:02.0f}eN1_IHC{:04.0f}Hz.pdf'.format(f0, 10*BW_scale_factor, IHC_cutoff))
    # print(save_fn)
    # fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)


In [None]:
fn = '/om/scratch/Sun/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/PND_sr32000_v08_0000000-0021000.hdf5'
# Read one training example. The context manager closes the file even if a
# read fails (the previous open/close pair leaked the handle on error).
with h5py.File(fn, 'r') as f:
    IDX = 14184  # alt: np.random.randint(f['stimuli/signal'].shape[0])
    signal = f['stimuli/signal'][IDX]
    noise = f['stimuli/noise'][IDX]
    f0 = f['nopad_f0_mean'][IDX]
    print(IDX, f0)
    sr = f['sr'][0]

# NOTE(review): `out_dict` (used for `cfs` and for the nervegram panel below)
# is NOT computed from this signal/noise. On a top-to-bottom run it is left
# over from the IHC-cutoff sweep cell (last cutoff 9000 Hz, stimulus `y`), so
# the "signal_in_noise" nervegram does not depict the loaded example --
# confirm this is intended or recompute it from `signal` and `noise`.

# Signal and noise waveforms, one small figure each with all spines hidden
figsize = (1.8, 0.6)
tmin = 0.0
tmax = 0.05
for waveform_name, waveform in [('signal', signal), ('noise', noise)]:
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    ax = util_figures.make_stimulus_summary_plot(
        ax,
        ax_idx_waveform=0,
        ax_idx_spectrum=None,
        ax_idx_nervegram=None,
        ax_idx_excitation=None,
        waveform=waveform,
        nervegram=None,
        sr_waveform=sr,
        sr_nervegram=None,
        cfs=out_dict['cf_list'],
        tmin=tmin,
        tmax=tmax,
        treset=True,
        vmin=None,
        vmax=None,
        spines_to_hide_waveform=['top', 'bottom', 'left', 'right'],
        spines_to_hide_spectrum=[],
        spines_to_hide_excitation=[],
        nxticks=6,
        nyticks=6,
        kwargs_plot={},
        limits_buffer=0.1,
        ax_arr_clear_leftover=True)
    plt.show()
    # save_fn = os.path.join(save_dir, 'tmp_schematic_training_stimulus_{}_waveform_{}.pdf'.format(IDX, waveform_name))
    # fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=True)


# Nervegram panel without tick labels or axis labels
figsize=(2.4, 1.6)
nrows=1
ncols=1
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
ax = util_figures.make_stimulus_summary_plot(
    ax,
    ax_idx_waveform=None,
    ax_idx_spectrum=None,
    ax_idx_nervegram=0,
    ax_idx_excitation=None,
    waveform=None,
    nervegram=out_dict['nervegram_meanrates'],
    sr_waveform=sr,
    sr_nervegram=out_dict['nervegram_fs'],
    cfs=out_dict['cf_list'],
    tmin=tmin,
    tmax=tmax,
    treset=True,
    vmin=None,
    vmax=None,
    spines_to_hide_waveform=[],
    spines_to_hide_spectrum=[],
    spines_to_hide_excitation=[],
    nxticks=0,
    nyticks=0,
    kwargs_plot={},
    limits_buffer=0.1,
    ax_arr_clear_leftover=True)
ax[0].set_xlabel(None)
ax[0].set_ylabel(None)
plt.show()
# save_fn = os.path.join(save_dir, 'tmp_schematic_training_stimulus_{}_nervegram_signal_in_noise.pdf'.format(IDX))
# fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=True)


In [None]:
sys.path.append('assets_datasets/')
import stimuli_generate_BernsteinOxenhamFixedFilter
import matplotlib.cm


def get_level_cmap(level,
                   level_min=None,
                   level_max=None,
                   cmap_name='Greys',
                   cmap_n=None):
    '''
    '''
    # Truncate level values before fitting cmap
    level = np.array(level)
    if level_min is not None:
        level[level < level_min] = level_min
    if level_max is not None:
        level[level > level_max] = level_max
    # Normalize level values and define cmap
    level_normalized = level - np.min(level)
    level_normalized = level_normalized / np.max(level_normalized)
    cmap = matplotlib.cm.get_cmap(cmap_name, cmap_n)
    # Return a function that applies the same normalization to
    # new level values and calls the defined cmap
    def level_cmap(x):
        x -= np.min(level)
        x /= np.max(level - np.min(level))
        return cmap(x)
    return level_cmap


def _is_single_color(colors):
    '''
    Return True if `colors` is one color spec: a string, or a flat sequence
    of numbers such as an RGB/RGBA tuple. Otherwise it is a list of colors.
    '''
    if isinstance(colors, str):
        return True
    try:
        arr = np.asarray(colors)
    except ValueError:
        # Ragged nesting (e.g. strings mixed with tuples): a list of colors
        return False
    return arr.ndim <= 1 and np.issubdtype(arr.dtype, np.number)


def schematic_spectrogram(ax, times, freqs,
                          colors='k',
                          kwargs_plot_update=None):
    '''
    Draw one horizontal line per frequency, spanning `times`, on `ax`.

    Args:
        ax: matplotlib Axes (anything with a `plot` method).
        times: x-values of each line (e.g. [onset, offset]).
        freqs: y-value of each line.
        colors: a single color (string or RGB/RGBA sequence) used for every
            line, or a sequence with one color per frequency. A single color
            whose length happens to equal len(freqs) is still treated as a
            single color.
        kwargs_plot_update: optional dict overriding the default line kwargs
            (lw=3, ls='-', marker='').

    Returns:
        The same `ax`.
    '''
    kwargs_plot = {
        'lw': 3,
        'ls': '-',
        'marker': '',
    }
    if kwargs_plot_update:
        kwargs_plot.update(kwargs_plot_update)
    # A per-line sequence of the wrong length falls back to being treated as
    # one color, as before.
    if _is_single_color(colors) or len(colors) != len(freqs):
        colors = [colors] * len(freqs)
    for f, c in zip(freqs, colors):
        kwargs_plot['color'] = c
        ax.plot(times, np.ones_like(times) * f, **kwargs_plot)
    return ax


def _bernox_components(f0, fs, lh, freq_response):
    '''
    Pick the harmonics of f0 to draw, together with their filter levels.

    All harmonics of f0 below fs / 2 are evaluated with `freq_response`.
    Kept are the harmonics at or above (lh - 1) * f0, plus any lower
    harmonic whose level is at least that of the lowest harmonic at or
    above (lh - 1) * f0.

    Returns:
        (freqs, levels) of the kept harmonics.
    '''
    candidate_freqs = np.arange(f0, fs / 2, f0)
    candidate_levels = freq_response(candidate_freqs)
    above_edge = candidate_freqs >= (lh - 1) * f0
    keep = np.logical_or(above_edge, candidate_levels >= candidate_levels[above_edge][0])
    return candidate_freqs[keep], candidate_levels[keep]


def bernox_schematic_spectrogram(fs=32000,
                                 highpass_filter_cutoff=2.5e3,
                                 lowpass_filter_cutoff=3.5e3,
                                 filter_order=4,
                                 threshold_dBSPL=33.3,
                                 component_dBSL=15.0,
                                 base_f0=300,
                                 delta_f0_list=[1.12, 1.0],
                                 label_list=['Tone 1', 'Tone 2'],
                                 lh=5,
                                 figsize=(2.25, 2.25),
                                 dur=1.0,
                                 gap=0.6,
                                 limits_buffer=0.25,
                                 ylimits=[0, 7e3],
                                 max_harm_labels_freq=None,
                                 include_harm_labels=True,
                                 fontsize_harm=6,
                                 fontsize_label=10,
                                 kwargs_plot_update={'lw':2}):
    '''
    Draw a schematic spectrogram of filtered harmonic tones, shading each
    component by its level under a fixed bandpass filter.

    The filter comes from stimuli_generate_BernsteinOxenhamFixedFilter: a
    bandpass response (highpass/lowpass cutoffs, filter_order) passed to
    shift_bandpass_filter_frequency_response with desired_fl = lh * base_f0
    and desired_fl_gain_in_dB = -component_dBSL. One tone is drawn per entry
    of delta_f0_list (F0 = base_f0 * delta), labelled with the matching entry
    of label_list; harmonic numbers are optionally written to its left.
    Shading is normalized to the components selected for base_f0, so all
    tones share one grey scale. `threshold_dBSPL` is currently unused.

    Returns:
        (fig, ax) for the new figure.
    '''
    filter_module = stimuli_generate_BernsteinOxenhamFixedFilter
    baseline_freq_response = filter_module.get_bandpass_filter_frequency_response(
        highpass_filter_cutoff,
        lowpass_filter_cutoff,
        fs=fs,
        order=filter_order)
    fixed_freq_response = filter_module.shift_bandpass_filter_frequency_response(
        base_f0 * lh,
        -1 * component_dBSL,
        fs=fs,
        unshifted_passband=None,
        frequency_response_in_dB=baseline_freq_response)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    # The grey scale is fitted once, to the base-F0 components
    _, base_levels = _bernox_components(base_f0, fs, lh, fixed_freq_response)
    level_cmap = get_level_cmap(base_levels, level_min=base_levels[0])

    if include_harm_labels and max_harm_labels_freq is None:
        max_harm_labels_freq = ylimits[-1] * 0.875
    label_yval = ylimits[-1] * 0.95

    times = np.array([0.0, dur])
    for tone_idx, f0_delta in enumerate(delta_f0_list):
        f0 = base_f0 * f0_delta
        tone_freqs, tone_levels = _bernox_components(f0, fs, lh, fixed_freq_response)
        ax = schematic_spectrogram(ax, times, tone_freqs,
                                   colors=level_cmap(tone_levels),
                                   kwargs_plot_update=kwargs_plot_update)

        if include_harm_labels:
            label_x = times[0] - gap / 3
            harmonic_freqs = np.arange(f0, max_harm_labels_freq, f0)
            for harmonic_number, f in enumerate(harmonic_freqs, start=1):
                ax.text(label_x, f, '{}'.format(harmonic_number),
                        ha='center', va='center', fontsize=fontsize_harm)

        ax.text(np.mean(times), label_yval, label_list[tone_idx],
                ha='center', va='center', fontsize=fontsize_label)

        # Next tone starts one gap after this one ends
        times = times + gap + dur

    x_start, _, x_span, _ = ax.dataLim.bounds
    xlimits = [x_start - limits_buffer * x_span, x_start + (1 + limits_buffer) * x_span]

    ax = util_figures.format_axes(ax,
        str_xlabel='Time',
        str_ylabel='Frequency',
        xlimits=xlimits,
        ylimits=ylimits,
        xticklabels=[],
        yticklabels=[],
        spines_to_hide=['right', 'top'],
        minor_tick_params_kwargs_update={},
        major_tick_params_kwargs_update={'length':0})

    return fig, ax


def f0dl_schematic_spectrogram(f0_list=[400.0, 400*1.12],
                               label_list=['Tone 1', 'Tone 2'],
                               harmonics=[1],#[1,2,3,4,5,6,7,8,9,10],
                               colors='k',
                               figsize=(2.5, 2.5),
                               dur=1.0,
                               gap=0.6,
                               limits_buffer=0.25,
                               ylimits=[0, 1e3],
                               include_harm_labels=False,
                               include_noise=True,
                               noise_level_cmap=None,
                               noise_level=None,
                               noise_band=None,
                               fontsize_harm=6,
                               fontsize_label=10,
                               kwargs_plot_update={'lw':2},
                               level_cmap=None):
    '''
    Draw a schematic spectrogram of a sequence of harmonic tones, optionally
    over a shaded noise background.

    Tone i spans x in [i * (dur + gap), i * (dur + gap) + dur] (arbitrary
    units). It is drawn as one horizontal line per frequency
    f0_list[i] * harmonics, with label_list[i] written above it at
    0.95 * ylimits[-1].

    Args:
        f0_list: fundamental frequency of each tone.
        label_list: text label per tone; needs at least len(f0_list) entries.
        harmonics: harmonic numbers drawn for every tone.
        colors: line color(s), forwarded to schematic_spectrogram.
        figsize: matplotlib figure size.
        dur: tone duration in x-axis units.
        gap: gap between tones in x-axis units.
        limits_buffer: fractional padding around the tones' x-range.
        ylimits: y-axis (frequency) limits.
        include_harm_labels: if True, write each harmonic number to the left
            of its line.
        include_noise: if True, draw a noise image behind the tones.
        noise_level_cmap: callable mapping noise levels to RGBA colors; built
            from noise_level with get_level_cmap when None.
        noise_level: 1-D sequence of noise levels, one image row each
            (default np.arange(0, 1, 0.01)). The caller's array is not
            modified.
        noise_band: [low, high] frequency extent of the noise image (default
            [0, 0.875 * ylimits[-1]]). Any length-2 sequence is accepted.
        fontsize_harm: font size of the harmonic-number labels.
        fontsize_label: font size of the tone labels.
        kwargs_plot_update: line kwargs forwarded to schematic_spectrogram.
        level_cmap: unused; accepted only for backward compatibility.

    Returns:
        (fig, ax) for the new figure.
    '''
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    times = np.array([0.0, dur])
    for itr_f0, f0 in enumerate(f0_list):
        freqs = f0 * np.array(harmonics)
        ax = schematic_spectrogram(ax, times, freqs, colors=colors, kwargs_plot_update=kwargs_plot_update)

        if include_harm_labels:
            for h, f in zip(harmonics, freqs):
                ax.text(times[0]-gap/3, f, '{}'.format(h),
                        ha='center',  va='center', fontsize=fontsize_harm)

        label_yval = ylimits[-1]*0.95
        ax.text(np.mean(times), label_yval, label_list[itr_f0],
                ha='center',  va='center', fontsize=fontsize_label)

        # Next tone starts one gap after this one ends
        times = times + gap + dur

    # Padded x-limits from the tone lines only (taken before the noise image
    # is added to the axes)
    [xb, yb, dxb, dyb] = ax.dataLim.bounds
    xlimits = [xb - limits_buffer * dxb, xb + (1 + limits_buffer) * dxb]

    if include_noise:
        if noise_band is None:
            noise_band = [0, ylimits[-1]*0.875]
        # From half a gap before the first tone to half a gap after the last
        noise_time = [-gap/2, times[-1] - dur - gap/2]
        if noise_level is None:
            noise_level = np.arange(0, 1, 0.01)
        # Float copy: the colormap function may normalize its argument in
        # place (which would alter the caller's array, and fails outright on
        # integer arrays).
        noise_level = np.array(noise_level, dtype=float)
        if noise_level_cmap is None:
            noise_level_cmap = get_level_cmap(noise_level)
        noise_image = noise_level_cmap(noise_level)
        # One image column, one row per noise level
        noise_image = np.expand_dims(noise_image, 1)
        # NOTE(review): imshow's default origin='upper' places noise_level[0]
        # at the top of noise_band -- confirm this orientation is intended.
        # list(...) + list(...) builds the 4-element extent for any sequence
        # type; a tuple would raise and an ndarray would add elementwise.
        ax.imshow(noise_image,
                  aspect='auto',
                  extent=list(noise_time) + list(noise_band),
                  zorder=-1)

    ax = util_figures.format_axes(ax,
        str_xlabel='Time',
        str_ylabel='Frequency',
        xlimits=xlimits,
        ylimits=ylimits,
        xticklabels=[],
        yticklabels=[],
        spines_to_hide=['right', 'top'],
        minor_tick_params_kwargs_update={},
        major_tick_params_kwargs_update={'length':0})

    return fig, ax


In [None]:
# fig, ax = bernox_schematic_spectrogram(delta_f0_list=[1.12, 1.0])
# plt.show()
# save_fn = os.path.join(save_dir, 'schematic_stimulus_bernox_spectrogram.pdf')
# fig.savefig(save_fn, bbox_inches='tight', transparent=False)
# print(save_fn)

# for lh in [2, 12]:
#     fig, ax = bernox_schematic_spectrogram(lh=lh,
#                                            ylimits=[0, 1e4],
#                                            figsize=(2.0, 3.0),
#                                            include_harm_labels=True,
#                                            fontsize_harm=4)
#     plt.show()
    
#     save_fn = os.path.join(save_dir, 'schematic_stimulus_bernox_spectrogram_example_lh{:02d}.pdf'.format(lh))
#     fig.savefig(save_fn, bbox_inches='tight', transparent=False)
#     print(save_fn)



# # Define noise band, frequencies, level, and colormap
# noise_band = [0, 6e3]
# noise_freqs = np.arange(*noise_band, 1e1)
# noise_level = np.zeros_like(noise_freqs)
# nzidx = noise_freqs > 0
# noise_level[nzidx] = 2.0 * np.log2(np.abs(noise_freqs[nzidx]) / 600.0)
# noise_level_cmap = get_level_cmap(noise_level)

# for snr_idx, snr in enumerate([3, 11]):
#     fig, ax = f0dl_schematic_spectrogram(figsize=(2.0, 3.0),
#                                          colors='k',
#                                          noise_band=noise_band,
#                                          noise_level=noise_level-snr,
#                                          noise_level_cmap=noise_level_cmap)
#     plt.show()
    
#     save_fn = os.path.join(save_dir, 'schematic_stimulus_bernox_spectrogram_example_snr{:02d}.pdf'.format(snr_idx))
#     fig.savefig(save_fn, bbox_inches='tight', transparent=False)
#     print(save_fn)



# Line color and line width for each example; one schematic is drawn per pair
list_colors = [
    [0.75] * 3,
    [0.25] * 3,
    [0] * 3,
]
list_lw = [
    1.0,
    3.0,
    2.0,
]

for spl_idx, (spl_color, spl_lw) in enumerate(zip(list_colors, list_lw)):
    fig, ax = f0dl_schematic_spectrogram(figsize=(1.5, 3.0),
                                         limits_buffer=0.16,
                                         colors=spl_color,
                                         include_noise=False,
                                         kwargs_plot_update={'lw': spl_lw})
    plt.show()

    # save_fn = os.path.join(save_dir, 'schematic_stimulus_bernox_spectrogram_example_spl{:02d}.pdf'.format(spl_idx))
    # fig.savefig(save_fn, bbox_inches='tight', transparent=False)
    # print(save_fn)


In [None]:
f0 = 100
fs = 32000
dur = 1.0

# Harmonics 1-12 of f0 with one component mistuned upward. The array is
# float: assigning a non-integer mistuned frequency into an integer array
# (np.arange(...) * int f0) would silently truncate it.
freqs = np.arange(1, 12+1) * float(f0)
IDX = 5  # index of the mistuned component (harmonic 6)
mistuning_factor = 1.08
IDX_base_f = freqs[IDX]
IDX_mistuned_f = IDX_base_f * mistuning_factor
freqs[IDX] = IDX_mistuned_f

colors = ['k'] * len(freqs)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(1.25, 2.25))
times = np.array([0, dur])
ax = schematic_spectrogram(ax, times, freqs, colors=colors, kwargs_plot_update={'lw':2})

# Vertical arrow right of the tone, from the original to the mistuned frequency
kwargs_arrow = {
    'x': dur * 1.25,
    'dx': 0,
    'dy': IDX_mistuned_f - IDX_base_f,
    'y': IDX_base_f,
    'width': 0.01,
    'length_includes_head': False,
    'head_width': 0.1,
    'head_length': (IDX_mistuned_f - IDX_base_f)/2,
    'overhang': 0.0,
    'head_starts_at_zero': False,
    'fc': 'k',
    'ec': 'k',
}
ax.arrow(**kwargs_arrow)

# Include the arrow's start point in the data limits before padding the x-range
limits_buffer = 0.2
ax.update_datalim([[kwargs_arrow['x'], kwargs_arrow['y']]])
[xb, yb, dxb, dyb] = ax.dataLim.bounds
limits_x = [xb - limits_buffer * dxb, xb + (1 + limits_buffer) * dxb]

ax = util_figures.format_axes(ax,
    str_xlabel='Time',
    str_ylabel='Frequency',
    fontsize_labels=12,
    fontsize_ticks=12,
    fontweight_labels=None,
    xscale='linear',
    yscale='linear',
    xlimits=limits_x,
    ylimits=[0, freqs[-1] + f0/2],
    xticks=[],
    yticks=np.arange(f0, freqs[-1] + f0/2, f0),
    xticks_minor=None,
    yticks_minor=None,
    xticklabels=None,
    yticklabels=[],
    spines_to_hide=['right', 'top'],
    major_tick_params_kwargs_update={'length': 0},
    minor_tick_params_kwargs_update={})

# Grid lines at the harmonic frequencies show the mistuning
ax.grid(linewidth=0.2, color='k')

plt.show()

# save_fn = os.path.join(save_dir, 'schematic_stimulus_mistuned_spectrogram.pdf')
# fig.savefig(save_fn, bbox_inches='tight', transparent=False)


In [None]:
import stimuli_generate_FrequencyShiftedComplexes

# Spectral-envelope conditions to draw ("INT" variant left commented out)
list_spectral_envelope_params = [
    {'spectral_envelope_centered_harmonic': 5, 'spectral_envelope_bandwidth_in_harmonics': 3}, # "RES"
#     {'spectral_envelope_centered_harmonic': 11, 'spectral_envelope_bandwidth_in_harmonics': 5}, # "INT"
    {'spectral_envelope_centered_harmonic': 16, 'spectral_envelope_bandwidth_in_harmonics': 5}, # "UNRES"
]

# One line color per condition: the two ends of the coolwarm colormap.
# (An earlier hard-coded ['k', 'r'] list was immediately overwritten; removed.)
list_colors = util_figures.get_color_list(1000, cmap_name='coolwarm')
list_colors = [list_colors[0], list_colors[-1]]
print(list_colors)

f0_shift = 0.24  # shift applied to every component, as a fraction of f0
f0 = 100
fs = 32000
dur = 1.0


for spectral_envelope_params, c in zip(list_spectral_envelope_params, list_colors):
    harmonic_centered = spectral_envelope_params['spectral_envelope_centered_harmonic']
    harmonic_bandwidth = spectral_envelope_params['spectral_envelope_bandwidth_in_harmonics']
    f_center = f0 * harmonic_centered
    f_bandwidth = f0 * harmonic_bandwidth
    spectral_envelope = stimuli_generate_FrequencyShiftedComplexes.get_MooreMoore2003_spectral_envelope(
        f0, f_center, f_bandwidth)

    # Shifted harmonics; keep only those with nonzero envelope amplitude
    frequencies = np.arange(f0, fs/2, f0, dtype=np.float32)
    frequencies = frequencies + f0*f0_shift
    amplitudes = spectral_envelope(frequencies)
    frequencies = frequencies[amplitudes > 0]

#     amplitude_level_cmap = get_level_cmap(amplitudes)
#     colors = amplitude_level_cmap(amplitudes[amplitudes > 0])
    # One color per component. Passing the single color `c` directly is
    # ambiguous to schematic_spectrogram's length check whenever len(c)
    # equals the number of components (e.g. an RGBA color and 4 components).
    colors = [c] * len(frequencies)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(1.25, 2.25))
    times = np.array([0, dur])
    ax = schematic_spectrogram(ax, times, frequencies, colors=colors, kwargs_plot_update={'lw':2})

#     for f in frequencies:
#         kwargs_arrow = {
#             'x': dur * 1.15,
#             'dx': 0,
#             'dy': f0*f0_shift,
#             'y': f - f0*f0_shift,
#             'width': 0.005,
#             'length_includes_head': False,
#             'head_width': 0.025,
#             'head_length': (f0*f0_shift)/2,
#             'overhang': 0.0,
#             'head_starts_at_zero': False,
#             'fc': 'k',
#             'ec': 'k',
#         }
#         ax.arrow(**kwargs_arrow)

    # Upward arrow (3 * f0 long) centred on the mean component frequency
    kwargs_arrow = {
        'x': dur * 1.5,
        'dx': 0,
        'dy': f0 * 3,
        'y': np.mean(frequencies) - f0*1.5,
        'width': 0.01,
        'length_includes_head': True,
        'head_width': 0.15,
        'head_length': (f0*f0_shift)*3,
        'overhang': 0.0,
        'head_starts_at_zero': False,
        'fc': 'k',
        'ec': 'k',
    }
    ax.arrow(**kwargs_arrow)
    ax.update_datalim([[kwargs_arrow['x'], kwargs_arrow['y']]])

    limits_buffer = 0.2
    [xb, yb, dxb, dyb] = ax.dataLim.bounds
    limits_x = [xb - limits_buffer * dxb, xb + (1 + limits_buffer/2) * dxb]

    ax = util_figures.format_axes(ax,
        str_xlabel='Time',
        str_ylabel='Frequency',
        fontsize_labels=12,
        fontsize_ticks=12,
        fontweight_labels=None,
        xscale='linear',
        yscale='linear',
        xlimits=limits_x,
        ylimits=[0, 20.5*f0],
        xticks=[],
        yticks=np.arange(f0, 20.5*f0, f0),
        xticks_minor=None,
        yticks_minor=None,
        xticklabels=None,
        yticklabels=[],
        spines_to_hide=['right', 'top'],
        major_tick_params_kwargs_update={'length': 0},
        minor_tick_params_kwargs_update={})

    # Grid at the unshifted harmonic frequencies makes the shift visible
    ax.grid(linewidth=0.2, color='k')

    plt.show()

#     ch = spectral_envelope_params['spectral_envelope_centered_harmonic']
#     save_fn = os.path.join(save_dir, 'schematic_stimulus_freqshift_spectrogram_{}.pdf'.format(ch))
#     fig.savefig(save_fn, bbox_inches='tight', transparent=False)


In [None]:
sys.path.append('assets_datasets/')
import stimuli_generate_TransposedTones

# Auditory-nerve model (bez2018model) settings shared by both example stimuli below.
kwargs_nervegram = {
    'nervegram_dur': 0.050,
    'nervegram_fs': 20000,
    'buffer_start_dur': 0.070,
    'buffer_end_dur': 0.010,
    'pin_fs': 100e3,
    'pin_dBSPL_flag': 0,
    'pin_dBSPL': None,
    'species': 2,
    'bandwidth_scale_factor': 1.0,
    'cf_list': None,
    'num_cf': 100,
    'min_cf': 125.0,
    'max_cf': 14e3,
    'max_spikes_per_train': 500,
    'num_spike_trains': 1,
    'cohc': 1.0,
    'cihc': 1.0,
    'IhcLowPass_cutoff': 3000,
    'IhcLowPass_order': 7,
    'spont': 70.0,
    'noiseType': 0,
    'implnt': 0,
    'tabs': 6e-4,
    'trel': 6e-4,
    'random_seed': None,
    'return_vihcs': False,
    'return_meanrates': True,
    'return_spike_times': False,
    'return_spike_tensor': False,
}

# 'PT' uses f_carrier 0.0 (plotted as "Pure tone"); 'TT' uses a 6350 Hz carrier.
# Both use a 200 Hz f_envelope.
dict_stimuli = {
    'PT': {'f_carrier': 0.0, 'f_envelope': 200},
    'TT': {'f_carrier': 6350.0, 'f_envelope': 200},
}

for key, stim in dict_stimuli.items():
    # Re-seed before every stimulus so numpy-random draws (presumably the
    # masking noise) repeat identically across conditions.
    np.random.seed(858)
    sr = 32000
    f_carrier, f_envelope = stim['f_carrier'], stim['f_envelope']
    signal = stimuli_generate_TransposedTones.get_Oxenham2004_transposed_tone(
        f_carrier,
        f_envelope,
        fs=sr,
        dur=0.150,
        buffer_dur=1.0,
        dBSPL=70.0,
        offset_start=False,
        lowpass_filter_envelope=True)
    noise = util_stimuli.modified_uniform_masking_noise(
        sr, dur=0.150, dBHzSPL=15.0, attenuation_start=600.0, attenuation_slope=2.0)
    # Mix signal and noise (SNR argument 30.0), then run the nerve model; the
    # result is stored back on the per-stimulus dict for the plotting cell.
    y = util_stimuli.combine_signal_and_noise(signal, noise, 30.0)
    stim['out_dict'] = bez2018model.nervegram(y, sr, **kwargs_nervegram)


In [None]:
for key in dict_stimuli:
    stim = dict_stimuli[key]
    out_dict = stim['out_dict']

    # Two stacked panels: waveform on top (height ratio 1), nervegram below (4).
    figsize = (1.8, 1.8)
    nrows, ncols = 2, 1
    gridspec_kw = {'hspace': 0.1, 'height_ratios': [1, 4]}
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, gridspec_kw=gridspec_kw)

    kwargs_summary = dict(
        ax_idx_waveform=0,
        ax_idx_spectrum=None,
        ax_idx_nervegram=1,
        ax_idx_excitation=None,
        waveform=out_dict['signal'],
        nervegram=out_dict['nervegram_meanrates'],
        sr_waveform=out_dict['signal_fs'],
        sr_nervegram=out_dict['nervegram_fs'],
        cfs=out_dict['cf_list'],
        tmin=None,
        tmax=None,
        treset=True,
        vmin=None,
        vmax=None,
        spines_to_hide_waveform=['top', 'bottom', 'left', 'right'],
        spines_to_hide_spectrum=[],
        spines_to_hide_excitation=[],
        nxticks=0,
        nyticks=0,
        kwargs_plot={},
        limits_buffer=0.1,
        ax_arr_clear_leftover=True,
    )
    ax = util_figures.make_stimulus_summary_plot(ax, **kwargs_summary)
    ax[1].set_xlabel('Time')
    ax[1].set_ylabel('Frequency')
    # A zero carrier frequency marks the pure-tone condition.
    ax[0].set_title('Pure tone' if stim['f_carrier'] == 0 else 'Transposed tone')
    plt.show()

#     save_fn = os.path.join(save_dir, 'schematic_stimulus_transposedtones_nervegram_{}.pdf'.format(key))
#     fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)


In [None]:
sys.path.append('assets_datasets/')
import stimuli_generate_AltPhaseHarmonics

np.random.seed(4078)

fs = 32000
dur = 0.150
f0 = 125
passband_component_dBSPL = 50.0

# Bandpass filter that shapes the harmonic levels (alternative passbands commented out).
list_filter_params = [
    {'fl': 125.0, 'fh': 625.0, 'fs': fs, 'order': 8},
#     {'fl': 1375.0, 'fh': 1875.0, 'fs': fs, 'order':8},
#     {'fl': 3900.0, 'fh': 5400.0, 'fs': fs, 'order':8},
]
filter_params = list_filter_params[0]
frequency_response_in_dB = stimuli_generate_AltPhaseHarmonics.get_bandpass_filter_frequency_response(**filter_params)

base_f0 = 125
list_stimuli_dict = [
    {'phase_mode': 'sine', 'f0': base_f0},
#     {'phase_mode': 'rand', 'f0': base_f0},
    {'phase_mode': 'alt', 'f0': base_f0},
    {'phase_mode': 'sine', 'f0': 2*base_f0},
]

for stim_dict in list_stimuli_dict:
    f0 = stim_dict['f0']
    phase_mode = stim_dict['phase_mode']

    # All harmonics below fs/2; each harmonic's level is the passband level plus
    # the filter's dB response, converted to pressure amplitude re 20e-6.
    harmonic_freqs = np.arange(f0, fs / 2, f0)
    harmonic_numbers = harmonic_freqs / f0
    harmonic_dBSPL = passband_component_dBSPL + frequency_response_in_dB(harmonic_freqs)
    amplitudes = 20e-6 * np.power(10, harmonic_dBSPL / 20)
    y = util_stimuli.complex_tone(
        f0, fs, dur,
        harmonic_numbers=harmonic_numbers,
        amplitudes=amplitudes,
        phase_mode=phase_mode,
        offset_start=False)

    # Waveform excerpt three periods of base_f0 long, starting at tmin.
    figsize = (1.75, 0.75)
    nrows, ncols = 1, 1
    tmin = 0.0125
    tmax = tmin + 3 / base_f0
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    kwargs_summary = dict(
        ax_idx_waveform=0,
        ax_idx_spectrum=None,
        ax_idx_nervegram=None,
        ax_idx_excitation=None,
        waveform=y,
        nervegram=None,
        sr_waveform=fs,
        sr_nervegram=None,
        cfs=None,
        tmin=tmin,
        tmax=tmax,
        treset=True,
        vmin=None,
        vmax=None,
        spines_to_hide_waveform=[],
        spines_to_hide_spectrum=[],
        spines_to_hide_excitation=[],
        nxticks=0,
        nyticks=0,
        kwargs_plot={},
        limits_buffer=0.1,
        ax_arr_clear_leftover=True,
    )
    ax = util_figures.make_stimulus_summary_plot(ax, **kwargs_summary)

    ax[0].set_title('{} phase\n({:.0f}Hz F0)'.format(phase_mode.upper(), f0))
    # Ticks every half period of base_f0, drawn inward on both top and bottom.
    ax[0].set_xticks(np.arange(0, tmax - tmin, 1 / (2 * base_f0)))
    ax[0].tick_params(axis='both',
                      which='major',
                      length=3,
                      direction='in',
                      top=True,
                      bottom=True)
    plt.show()

#     save_fn = os.path.join(save_dir, 'schematic_stimulus_waveform_{}phase_{:.0f}f0.pdf'.format(phase_mode.upper(), f0))
#     fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)


In [None]:
list_stim = [
    (125, 'SINE'),
    (125, 'ALT'),
    (250, 'SINE'),
]

gridspec_kw = {
    'hspace': 0.1,
    'height_ratios': [4, 1],
}
fig, ax_arr = plt.subplots(nrows=2, ncols=len(list_stim), figsize=(5.25, 2.25), gridspec_kw=gridspec_kw)

kwargs_grid = {
    'color': 'k',
    'ls': '--',
    'lw': 0.2
}
kwargs_plot = {
    'color': 'k',
    'lw': 0.6
}
kwargs_plot_phase_dot = util_misc.recursive_dict_merge(kwargs_plot, {'marker': '.', 'ms':2, 'ls':''})
base_f0 = 125
base_period = 1/base_f0
fs = 32000
dur = 2*base_period
# Two base periods of time, shifted to start a quarter period before zero.
t = np.arange(0, dur, 1/fs) - 0.25 * base_period
# Samples at integer multiples of half a base period get a dot and a vertical
# grid line. `t` is built from inexact float steps, so an exact test such as
# `t % (base_period/2) == 0` can silently miss some of these samples. Instead,
# compare against the nearest multiple with a tolerance far below one sample
# (one sample is 1/128 of a half period here).
half_period = base_period / 2
t_in_half_periods = t / half_period
IDX_xticks = np.isclose(t_in_half_periods, np.round(t_in_half_periods), rtol=0, atol=1e-6)
xticks = t[IDX_xticks]
harmonics = np.array([1,2,3,4,5,6])
harmonics_labels = ['$1^{st}$', '$2^{nd}$', '$3^{rd}$', '$4^{th}$', '$5^{th}$', '$6^{th}$']
# Vertical offsets that stack the individual harmonics in the top row.
offset_list = (harmonics-1)*2.5

for col, (f0, phase_mode) in enumerate(list_stim):
    # Top row: each unit-amplitude harmonic drawn separately and stacked. In ALT
    # phase, even harmonics are shifted by pi/2 (i.e. cosine phase).
    Y = np.zeros([len(harmonics), len(t)])
    for itr_h, h in enumerate(harmonics):
        phase = 0
        if (phase_mode.lower() == 'alt') and (h % 2 == 0):
            phase = np.pi/2
        ytmp = np.sin(2*np.pi*(f0*h)*t + phase)
        Y[itr_h] = ytmp

        ax_arr[0, col].plot(t, ytmp+offset_list[itr_h], **kwargs_plot)
        ax_arr[0, col].plot(t[IDX_xticks], ytmp[IDX_xticks]+offset_list[itr_h], **kwargs_plot_phase_dot)

    kwargs_format_axes = {
        'spines_to_hide': [],
        'xticks': xticks,
        'yticks': offset_list,
        'xticklabels': [],
        'yticklabels': [],
        'xlimits': [t[0], t[-1]],
        'major_tick_params_kwargs_update': {'axis': 'x', 'length': 0},
        'str_title': '{} phase\n({}Hz F0)'.format(phase_mode, f0)
    }

    # Only the first column carries harmonic labels and the y-axis label.
    if col == 0:
        kwargs_format_axes['yticklabels'] = harmonics_labels
        kwargs_format_axes['str_ylabel'] = 'Harmonic'

    ax_arr[0, col] = util_figures.format_axes(ax_arr[0, col], **kwargs_format_axes)
    ax_arr[0, col].xaxis.grid(**kwargs_grid)

    if col == 0:
        for tick in ax_arr[0, col].yaxis.get_major_ticks():
            tick.label1.set_horizontalalignment('left')
        ax_arr[0, col].tick_params(axis='y', pad=18)


    # Bottom row: the sum of all harmonics, reusing the top-row axis settings
    # with a single 'Sum' y tick.
    y = np.sum(Y, axis=0)
    ax_arr[1, col].plot(t, y, **kwargs_plot)
    ax_arr[1, col].plot(t[IDX_xticks], y[IDX_xticks], **kwargs_plot_phase_dot)

    kwargs_format_axes.update({'yticks': [0], 'str_ylabel': None, 'ylimits': [-5,5], 'str_title': None})
    if col == 0:
        kwargs_format_axes.update({'yticklabels': ['Sum']})
    ax_arr[1, col] = util_figures.format_axes(ax_arr[1, col], **kwargs_format_axes)
    ax_arr[1, col].xaxis.grid(**kwargs_grid)

plt.show()

# save_fn = os.path.join(save_dir, 'schematic_stimulus_alt_vs_sine_phase.pdf')
# fig.savefig(save_fn, bbox_inches='tight', pad_inches=0.05, transparent=False)
# print(save_fn)


In [None]:
import os
import sys
import json
import glob
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.cm
import matplotlib.ticker

sys.path.append('/om2/user/msaddler/python-packages/msutil')
import util_figures
import util_stimuli

sys.path.append('assets_datasets/')
import stimuli_f0_labels

sys.path.append('assets_psychophysics/')
import f0dl_bernox


regex_model_dir = '/nobackup/scratch/*/msaddler/pitchnet/saved_models/arch_search_v02/arch_0191/'
# Fail with a clear message instead of a bare IndexError if the glob finds nothing.
list_model_dir = glob.glob(regex_model_dir)
if len(list_model_dir) == 0:
    raise FileNotFoundError('No model directory matches {}'.format(regex_model_dir))
model_dir = list_model_dir[0]
fn_valid = os.path.join(model_dir, 'EVAL_validation_bestckpt.json')

f0_label_true_key = 'f0_label:labels_true'
f0_label_pred_key = 'f0_label:labels_pred'
kwargs_f0_bins = {}

expt_dict = f0dl_bernox.load_f0_expt_dict_from_json(fn_valid,
                                                    metadata_key_list=[],
                                                    f0_label_true_key=f0_label_true_key,
                                                    f0_label_pred_key=f0_label_pred_key)
expt_dict = f0dl_bernox.add_f0_estimates_to_expt_dict(expt_dict,
                                                      f0_label_true_key=f0_label_true_key,
                                                      f0_label_pred_key=f0_label_pred_key)
f0_bins = stimuli_f0_labels.get_f0_bins(**kwargs_f0_bins)
f0_label_true = expt_dict[f0_label_true_key]
f0_label_pred = expt_dict[f0_label_pred_key]

# Confusion matrix indexed [predicted label, true label], normalized so each
# column is the distribution of predicted labels for one true label.
num_f0_bins = f0_bins.shape[0]
confusion_matrix = np.zeros([num_f0_bins, num_f0_bins])
# np.add.at accumulates repeated (pred, true) index pairs (plain fancy-index
# += would count each repeated pair only once). Labels must be integers, as
# with the original scalar indexing.
np.add.at(confusion_matrix,
          (np.asarray(f0_label_pred), np.asarray(f0_label_true)),
          1)
# Each entry holds its column's total count (the number of stimuli per true label).
confusion_matrix_counts = np.zeros_like(confusion_matrix) + confusion_matrix.sum(axis=0)

# True labels with no stimuli stay NaN (as with 0/0 before), but no longer
# emit a RuntimeWarning.
confusion_matrix = np.divide(confusion_matrix,
                             confusion_matrix_counts,
                             out=np.full_like(confusion_matrix, np.nan),
                             where=confusion_matrix_counts > 0)


In [None]:
fontsize_title = 12
fontsize_labels = 12
fontsize_legend = 12
fontsize_ticks = 12
figsize = (4, 3)
nticks = 5

# Show proportions on a log10 scale. Anything below 10**floor_val, including
# exact zeros (log10 -> -inf), is clamped to the floor; NaN entries stay NaN.
floor_val = -4
log_confusion_matrix = np.log10(confusion_matrix)
tmp_confusion_matrix = np.where(log_confusion_matrix < floor_val, floor_val, log_confusion_matrix)

# Fixed color limits [floor_val, 0], labeled as powers of ten.
vmin, vmax = floor_val, 0
vticks = np.linspace(vmin, vmax, nticks)
vticklabels = [r'$10^{{{:.0f}}}$'.format(v) for v in vticks]

n_rows, n_cols = tmp_confusion_matrix.shape
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
# NOTE: `cmap` holds the AxesImage returned by imshow (used for the colorbar).
cmap = ax.imshow(tmp_confusion_matrix,
                 origin='lower',
                 aspect='auto',
                 extent=[0, n_cols, 0, n_rows],
                 cmap='Greens_r',
                 vmin=vmin,
                 vmax=vmax)
cbar = plt.colorbar(cmap, ax=ax, pad=0.05)
cbar.ax.set_ylabel('Proportion of stimuli', fontsize=fontsize_labels)
cbar.set_ticks(vticks)
cbar.set_ticklabels(vticklabels)
# cbar.ax.yaxis.set_major_formatter(formatter)
cbar.ax.tick_params(**{
    'direction': 'out',
    'axis': 'both',
    'which': 'both',
    'labelsize': fontsize_ticks,
    'length': fontsize_ticks/2,
})

# Tick positions are bin indices whose F0 value equals one of xtick_values.
xtick_values = np.array([80.0, 160.0, 320.0, 640.0])
xticks = [idx for idx, f0_bin in enumerate(f0_bins) if f0_bin in xtick_values]
xticklabels = ['{:.0f}'.format(t) for t in f0_bins[xticks]]

ax = util_figures.format_axes(ax,
                              str_xlabel='True F0 (Hz)',
                              str_ylabel='Reported F0 (Hz)',
                              fontsize_labels=fontsize_labels,
                              fontsize_ticks=fontsize_ticks,
                              fontweight_labels=None,
                              xscale='linear',
                              yscale='linear',
                              xlimits=None,
                              ylimits=None,
                              xticks=xticks,
                              yticks=xticks,
                              xticks_minor=None,
                              yticks_minor=None,
                              xticklabels=xticklabels,
                              yticklabels=xticklabels,
                              spines_to_hide=[],
                              major_tick_params_kwargs_update={},
                              minor_tick_params_kwargs_update={})

plt.tight_layout()
plt.show()

# save_fn = os.path.join(save_dir, 'schematic_confusion_matrix_arch0191.pdf')
# fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)


In [None]:
import os
import sys
import json
import glob
import copy
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

sys.path.append('/om2/user/msaddler/python-packages/msutil')
import util_figures
import util_stimuli
import util_misc

# Glob pattern for each architecture's validation_metrics.json in arch_search_v02;
# formatted with a zero-padded 4-digit architecture number.
valid_regex_format = '/om/scratch/*/msaddler/pitchnet/saved_models/arch_search_v02/arch_{:04d}/validation_metrics.json'


def get_valid_trace(valid_metrics_fn, metric_key='f0_label:accuracy', checkpoint_number_key='step'):
    '''
    Load one metric's validation trace from a JSON metrics file.

    Args:
        valid_metrics_fn: path to a JSON file holding a dict of lists.
        metric_key: key of the metric values to return.
        checkpoint_number_key: key of the checkpoint numbers paired with them.

    Returns:
        (checkpoint_numbers, metric_values), both as numpy arrays.
    '''
    with open(valid_metrics_fn) as json_file:
        trace_dict = json.load(json_file)
    values = np.array(trace_dict[metric_key])
    steps = np.array(trace_dict[checkpoint_number_key])
    return steps, values


def get_valid_metric(model_dir,
                     results_dict_basename='EVAL_validation_bestckpt_results_dict.json',
                     best_metric_key='f0_pct_error_median'):
    '''
    Read a single metric from a model directory's validation results JSON.

    Args:
        model_dir: directory containing the results file.
        results_dict_basename: file name of the JSON results dict in model_dir.
        best_metric_key: key of the value to return.

    Returns:
        The JSON value stored under best_metric_key.
    '''
    results_fn = os.path.join(model_dir, results_dict_basename)
    with open(results_fn) as json_file:
        return json.load(json_file)[best_metric_key]


def calc_num_layers(brain_arch_fn):
    '''
    Count the convolutional layers in a brain-architecture JSON file.

    Args:
        brain_arch_fn: path to a JSON list of layer dicts, each with a
            'layer_type' key.

    Returns:
        Number of entries whose 'layer_type' is 'tf.layers.conv2d'.
    '''
    with open(brain_arch_fn) as json_file:
        layers = json.load(json_file)
    return sum(1 for layer in layers if layer['layer_type'] == 'tf.layers.conv2d')


# Parallel lists, one entry per architecture whose validation_metrics.json exists
# (architecture numbers 0-749 are probed; missing ones are skipped).
list_arch_nums = []
list_traces = []
list_best_metric = []
list_valid_metric = []
list_arch_metric = []

for idx in range(750):
    valid_metrics_fn_list = glob.glob(valid_regex_format.format(idx))
    if not valid_metrics_fn_list:
        continue
    valid_metrics_fn = valid_metrics_fn_list[0]
    checkpoint_numbers, metric_values = get_valid_trace(valid_metrics_fn)
    list_arch_nums.append(idx)
    list_traces.append((checkpoint_numbers, metric_values))
    list_best_metric.append(np.max(metric_values))
    list_valid_metric.append(get_valid_metric(os.path.dirname(valid_metrics_fn)))
    # Architecture "metric" = number of conv layers in the sibling brain_arch.json.
    brain_arch_fn = valid_metrics_fn.replace('validation_metrics', 'brain_arch')
    list_arch_metric.append(calc_num_layers(brain_arch_fn))

print(len(list_arch_nums))


In [None]:
# Order architectures by validation metric, largest value first.
# NOTE(review): list_valid_metric holds the median F0 % error (lower is better),
# so this lists the highest-error architectures first -- confirm intended.
SORT_IDX = np.argsort(list_valid_metric)[::-1]
sorted_list_arch_nums = np.array(list_arch_nums)[SORT_IDX]
sorted_list_valid_metric = np.array(list_valid_metric)[SORT_IDX]
sorted_list_arch_metric = np.array(list_arch_metric)[SORT_IDX]

# list_traces is aligned with list_arch_nums (one entry per architecture found),
# NOT indexed by architecture number, so look traces up by architecture number.
trace_by_arch_num = dict(zip(list_arch_nums, list_traces))

for arch_metric in [8, 2, 4, 6, 3, 7]:#np.unique(sorted_list_arch_metric):
    IDX = sorted_list_arch_metric == arch_metric
    print('=== {} layer networks ==='.format(arch_metric))
    # First 8 architectures (in sorted order) with this many conv layers.
    top_arch_nums = sorted_list_arch_nums[IDX][:8]
    top_valid_metrics = sorted_list_valid_metric[IDX][:8]
    for arch_num, valid_metric in zip(top_arch_nums, top_valid_metrics):
        print(arch_num, valid_metric, len(trace_by_arch_num[arch_num][0]))

# print(sorted_list_arch_nums)


In [None]:
# Line styles for the validation traces: a thin black line for every
# architecture, a thick dark-green line for the main model (arch_num_main), and
# thicker colored "accent" lines for the architectures in ARCH_N_list.
kwargs_plot = {
    'marker': '',
    'lw': 0.15,
    'color': [0.0] * 3
}

kwargs_plot_main = copy.deepcopy(kwargs_plot)
kwargs_plot_main['color'] = np.array([26, 66, 32])/256
kwargs_plot_main['lw'] = 3.0
kwargs_plot_main['zorder'] = 900
arch_num_main = 191

kwargs_plot_accent = copy.deepcopy(kwargs_plot)
kwargs_plot_accent['lw'] = 2.0
kwargs_plot_accent['zorder'] = 909
# ARCH_N_list = [191] + sorted([349, 244, 277, 307, 289, 76])
ARCH_N_list = [191] + sorted([188, 41, 254, 144, 361, 338])

# Accent colors, matched to ARCH_N_list by index: [0, 1, 0] first, followed by
# the 'Set1' colors with two of them removed by the pops below.
ARCH_color_list = util_figures.get_color_list(9, cmap_name='Set1')
ARCH_color_list = [[0, 1, 0]] + ARCH_color_list
ARCH_color_list.pop(3)
ARCH_color_list.pop(1)
# Used in the x-axis label below.
batch_size = 64


fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3.5, 3))

# Plot every architecture's validation trace (metric scaled to percent) and
# track the largest checkpoint number for the x-axis range.
max_steps = 0
for arch_num, (checkpoint_numbers, metric_values) in zip(list_arch_nums, list_traces):
    max_steps = max(max_steps, max(checkpoint_numbers))
    if arch_num == arch_num_main:
        ax.plot(checkpoint_numbers, 100*metric_values, **kwargs_plot_main)
    if arch_num in ARCH_N_list:
        # kwargs_plot_accent is mutated in place each iteration; the thin black
        # line is then redrawn on top of the accent line (zorder=999).
        kwargs_plot_accent['color'] = ARCH_color_list[ARCH_N_list.index(arch_num)]
#         kwargs_plot_accent['label'] = '{}:{}'.format(arch_num, list_arch_metric[arch_num])
        kwargs_plot_accent['label'] = None
        ax.plot(checkpoint_numbers, 100*metric_values, **kwargs_plot_accent)
        ax.plot(checkpoint_numbers, 100*metric_values, **kwargs_plot, zorder=999)
    else:
        kwargs_plot_accent['label'] = None
        ax.plot(checkpoint_numbers, 100*metric_values, **kwargs_plot)

# Major x ticks every 50000 steps (minor every 10000); y axis 0-25 %.
xlimits = [0, max_steps]
xticks = np.arange(xlimits[0], xlimits[-1], 50000)
xticks_minor = np.arange(xlimits[0], xlimits[-1], 10000)
xticklabels = ['{:.0f}'.format(t) for t in xticks]
ylimits = [0, 25]
yticks = np.arange(ylimits[0], ylimits[-1]+1, 5)
yticklabels = ['{:.0f}%'.format(t) for t in yticks]

# ax.legend(fontsize=8)
ax = util_figures.format_axes(ax,
                              str_xlabel='Training steps (batch size = {})'.format(batch_size),
                              str_ylabel='Validation set accuracy',
                              xscale='linear',
                              yscale='linear',
                              xlimits=xlimits,
                              ylimits=ylimits,
                              xticks=xticks,
                              yticks=yticks,
                              xticks_minor=xticks_minor,
                              yticks_minor=None,
                              xticklabels=xticklabels,
                              yticklabels=yticklabels,
                              spines_to_hide=[],
                              major_tick_params_kwargs_update={},
                              minor_tick_params_kwargs_update={})
plt.tight_layout()
plt.show()

# NOTE(review): unlike most cells in this notebook, this save is NOT commented
# out -- running the cell writes the PDF into save_dir.
save_fn = os.path.join(save_dir, 'schematic_valid_traces_arch_search_v02.pdf')
fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)



fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 3))

# Geometrically spaced histogram bins: edges start at 0.5 and each is 5% larger
# than the previous until an edge reaches at least 10. Values of
# list_valid_metric outside the bin range are not counted by np.histogram.
bin_limits = [5e-1, 1e1]
bin_step = 0.05
bins = [bin_limits[0]]
while bins[-1] < bin_limits[1]:
    bins.append(bins[-1] * (1.0+bin_step))
bins = np.array(bins)
bin_centers = (bins[:-1] + bins[1:]) / 2
bin_widths = bins[1:] - bins[:-1]
bin_counts, bin_edges = np.histogram(list_valid_metric, bins=bins)

ax.bar(bin_centers, bin_counts, width=bin_widths, align='center', color='k', fc='k', ec=[0.2]*3, lw=0.5)
ax = util_figures.format_axes(ax,
                              str_xlabel='Median F0 error (%)',
                              str_ylabel='Number of architectures',
                              xscale='log',
                              yscale='linear',
                              spines_to_hide=[],
                              xlimits=[0.4, 10],
                              xticks=[0.5, 1.0, 2.0, 4.0, 8.0],
                              xticklabels=[0.5, 1.0, 2.0, 4.0, 8.0],
                              xticks_minor=[],
                              yticks=np.linspace(0, 40, 5),
                              yticks_minor=[],
                              major_tick_params_kwargs_update={},
                              minor_tick_params_kwargs_update={})

plt.tight_layout()
plt.show()

# NOTE(review): this save is also active and writes into save_dir.
save_fn = os.path.join(save_dir, 'schematic_valid_hist_medianf0err_arch_search_v02.pdf')
fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)


In [None]:
import sys
import os
import json
import numpy as np
import glob

%matplotlib inline
import matplotlib.pyplot as plt

sys.path.append('/om2/user/msaddler/python-packages/msutil')
import util_figures

# Three colors from the 'gist_heat' colormap (indices 0, 2, 4 of a 6-color list).
color_list = util_figures.get_color_list(6, cmap_name='gist_heat')
color_list = [color_list[idx] for idx in [0, 2, 4]]
# color_list = util_figures.get_color_list(8, cmap_name='Accent')
# color_list = [color_list[idx] for idx in [4,5]]


# Datasets whose mean spectra are loaded in the next cell, as
# (results_dict.json glob pattern, legend label, line color). Only the first
# entry is active; the commented-out groups are alternative figure configurations.
# NOTE(review): the loader switches to the noise spectrum when a label contains
# 'noise' AND master_list has more than one entry. The active label below
# contains 'noise' ("...no background noise"), so adding a second entry would
# make this one load the noise spectrum -- confirm before extending the list.
master_list = [
    ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/SPECTRAL_STATISTICS_v00/results_dict.json', 'Speech + music (natural)\nwith no background noise', color_list[0]),
    
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/SPECTRAL_STATISTICS_v00/results_dict.json', 'Speech + music (natural)', color_list[0]),
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10_filter_signalLPv01/SPECTRAL_STATISTICS_v00/results_dict.json', 'Speech + music (lowpass)', color_list[1]),
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10_filter_signalHPv00/SPECTRAL_STATISTICS_v00/results_dict.json', 'Speech + music (highpass)', color_list[2]),
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/SPECTRAL_STATISTICS_v00/results_dict.json', 'Natural background noise', [0.75]*3),

#     ('/om/scratch/*/msaddler/data_pitchnet/PND_mfcc/PNDv08PYSmatched12_TLASmatched12_snr_neg10pos10_phase3/SPECTRAL_STATISTICS_v00/results_dict.json', 'Synthetic tones (matched)', color_list[0]),
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_mfcc/PNDv08PYSnegated12_TLASmatched12_snr_neg10pos10_phase3/SPECTRAL_STATISTICS_v00/results_dict.json', 'Synthetic tones (anti-matched)', color_list[2]),
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_mfcc/PNDv08PYSmatched12_TLASmatched12_snr_neg10pos10_phase3/SPECTRAL_STATISTICS_v00/results_dict.json', 'Synthetic noise (matched)', [0.75]*3),

#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08spch/noise_TLAS_snr_neg10pos10/SPECTRAL_STATISTICS_v00/results_dict.json', 'Speech only', util_figures.get_color_list(8, cmap_name='Accent')[4]),
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08inst/noise_TLAS_snr_neg10pos10/SPECTRAL_STATISTICS_v00/results_dict.json', 'Music only', util_figures.get_color_list(8, cmap_name='Accent')[5]),
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/SPECTRAL_STATISTICS_v00/results_dict.json', 'Natural background noise', [0.75]*3),

#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/SPECTRAL_STATISTICS_v00/results_dict.json', 'Natural speech + music', color_list[0]),
#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/SPECTRAL_STATISTICS_v00/results_dict.json', 'Natural background noise', [0.75]*3),
]

# master_list = [
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Speech + Music (natural)', color_list[0]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10_filter_signalBPv00/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Speech + Music (bandpass)', color_list[1]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10_filter_signalHPv00/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Speech + Music (highpass)', color_list[2]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/TMP_mean_spectrum_KEY_stimuli_noise.json', 'Natural background noise', [0.75]*3),
    
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08spch/noise_TLAS_snr_neg10pos10/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Speech only', color_list[0]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08inst/noise_TLAS_snr_neg10pos10/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Music only', color_list[1]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/TMP_mean_spectrum_KEY_stimuli_noise.json', 'Natural background noise', [0.75]*3),

# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_synthetic/noise_UMNm_snr_neg10pos10_phase01_filter_signalLPv00/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Synthetic tones (lowpass)', color_list[0]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_synthetic/noise_UMNm_snr_neg10pos10_phase01_filter_signalBPv00/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Synthetic tones (bandpass)', color_list[1]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_synthetic/noise_UMNm_snr_neg10pos10_phase01_filter_signalHPv00/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Synthetic tones (highpass)', color_list[2]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_synthetic/noise_UMNm_snr_neg10pos10_phase01_filter_signalHPv00/TMP_mean_spectrum_KEY_stimuli_noise.json', 'Synthetic background noise', [0.75]*3),

#     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/TMP_mean_spectrum_KEY_stimuli_signal.json', 'Speech + Music', color_list[0]),
# #     ('/om/scratch/*/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/TMP_mean_spectrum_KEY_stimuli_noise.json', 'Natural background noise', [0.75]*3),
# ]

# One dict per master_list entry: frequency axis, mean power spectrum, plot
# kwargs (label/color), and whether the noise spectrum was loaded.
list_data_dict = []
for fn_regex, label, color in master_list:
    fn = glob.glob(fn_regex)[0]
    with open(fn, 'r') as f:
        results_dict = json.load(f)
    key_fxx = 'mean_power_spectrum_freqs'
    key_pxx = 'mean_power_spectrum'
    key_n_fft = 'mean_power_spectrum_n_fft'
    # With more than one dataset, entries whose label mentions 'noise' use the
    # noise spectrum instead of the signal spectrum.
    use_noise_spectrum = ('noise' in label) and (len(master_list) > 1)
    key_sig = 'stimuli/noise' if use_noise_spectrum else 'stimuli/signal'
    spectra = results_dict[key_sig]
    data_dict = {
        'freqs': np.array(spectra[key_fxx]),
        'mean_spectrum': np.array(spectra[key_pxx]),
        'kwargs_plot': {'label': label, 'color': color},
        'is_noise': 'noise' in key_sig,
    }
    list_data_dict.append(data_dict)

print(len(list_data_dict))


In [None]:
fontsize_title = 12
fontsize_labels = 12
fontsize_legend = 10
fontsize_ticks = 12
figsize = (4, 3)
nticks = 5

kwargs_plot = {
    'lw': 3,
    'marker': '',
}

# Extra downward shift (dB) applied to noise spectra after peak normalization.
snr_change = 0

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

# Each mean spectrum is normalized to its own peak (0 dB). Noise spectra are
# additionally shifted by snr_change and filled down to -100 dB. Successive
# datasets are drawn one zorder level lower.
zorder=0
for data_dict in list_data_dict:
    xvals = data_dict['freqs']
    yvals = data_dict['mean_spectrum']
    kwargs_plot.update(data_dict['kwargs_plot'])
    if data_dict['is_noise']:
        yvals = yvals - yvals.max()
        yvals -= snr_change
        ax.plot(xvals, yvals, **kwargs_plot, zorder=zorder)
        ax.fill_between(xvals, -100*np.ones_like(xvals), yvals,
                        color=kwargs_plot['color'],
                        alpha=1.0,
                        zorder=zorder)
    else:
        yvals = yvals - yvals.max()
        ax.plot(xvals, yvals, **kwargs_plot, zorder=zorder)
    zorder -= 1

xlimits = [40, 16000]
ylimits = [-85, 5]
yticks = np.arange(ylimits[0]+5, ylimits[-1], 20)
yticks_minor = np.arange(ylimits[0], ylimits[-1], 5)

ax = util_figures.format_axes(ax,
                              str_xlabel='Frequency (Hz)',
                              str_ylabel='Mean power (dB)',
                              fontsize_labels=fontsize_labels,
                              fontsize_ticks=fontsize_ticks,
                              fontweight_labels=None,
                              xscale='log',
                              yscale='linear',
                              xlimits=xlimits,
                              ylimits=ylimits,
                              xticks=None,
                              yticks=yticks,
                              xticks_minor=None,
                              yticks_minor=yticks_minor,
                              xticklabels=None,
                              yticklabels=None,
                              spines_to_hide=[],
                              major_tick_params_kwargs_update={},
                              minor_tick_params_kwargs_update={})

kwargs_legend = {
    'loc': 'lower left',
    'frameon': True,
    'framealpha': 1.0,
    'facecolor': 'w',
    'edgecolor': 'k',
    'handlelength': 0.5,
    'markerscale': 0.0,
    'fontsize': fontsize_legend,
    'borderpad': 0.65,
    'borderaxespad': 0.25,
}
leg = ax.legend(**kwargs_legend)
# Legend.legendHandles was renamed Legend.legend_handles in Matplotlib 3.7 and
# the old name was later removed; support both.
legend_handles = getattr(leg, 'legend_handles', None)
if legend_handles is None:
    legend_handles = leg.legendHandles
# Thicken the legend lines so the colors read as swatches.
for legobj in legend_handles:
    legobj.set_linewidth(8.0)

plt.tight_layout()
plt.show()

# save_fn = os.path.join(save_dir, 'schematic_mean_spectra_natural.pdf')
# save_fn = os.path.join(save_dir, 'schematic_mean_spectra_synthetic.pdf')
# save_fn = os.path.join(save_dir, 'schematic_mean_spectra_speech_vs_music.pdf')
# save_fn = os.path.join(save_dir, 'schematic_mean_spectra_SNR_meanInfdB.pdf')
# fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)
# fig.savefig('tmp.pdf', bbox_inches='tight', pad_inches=0, transparent=False)


In [None]:
fontsize_title = 12
fontsize_labels = 12
fontsize_legend = 10
fontsize_ticks = 12
figsize = (4, 3)
nticks = 5

kwargs_plot = {
    'lw': 4,
    'marker': '',
}

assert len(list_data_dict) == 2, "ONLY SIGNAL + NOISE SUPPORTED"
data_dict_signal, data_dict_noise = list_data_dict
assert data_dict_noise['is_noise'], "NOISE DATA DICT MUST BE NOISE"

# The noise spectrum is drawn once per entry, shifted down by that many dB
# below its own peak; np.inf corresponds to the 'Noiseless' label.
list_mean_snr = [0, 20, np.inf]
list_color = [
    [0.6]*3,
    [0.7]*3,
    [0.8]*3,
]
list_label = [
    '$-10$ to $+10$ dB SNR',
    '$+10$ to $+30$ dB SNR',
    'Noiseless'
]

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
xlimits = [40, 16000]
ylimits = [-65, 5]
# Values below the plot floor (including -inf from np.inf shifts) are clamped here.
minusinf_val = ylimits[0] + 2
yticks = np.arange(ylimits[0]+5, ylimits[-1], 20)
yticks_minor = np.arange(ylimits[0], ylimits[-1], 5)

# Signal spectrum, normalized to its peak, drawn above all noise layers.
xvals = data_dict_signal['freqs']
yvals = data_dict_signal['mean_spectrum']
kwargs_plot.update(data_dict_signal['kwargs_plot'])
yvals = yvals - yvals.max()
ax.plot(xvals, yvals, **kwargs_plot, zorder=10)

for idx in range(len(list_mean_snr)):
    mean_snr = list_mean_snr[idx]
    color = list_color[idx]
    label = list_label[idx]
    xvals = data_dict_noise['freqs']
    yvals = data_dict_noise['mean_spectrum']
    kwargs_plot.update({'color': color, 'label': label})
    yvals = yvals - yvals.max() - mean_snr

    yvals[yvals < minusinf_val] = minusinf_val
    ax.plot(xvals, yvals, **kwargs_plot, zorder=-1)

    # Fill the band between this noise level and the next one; for all but the
    # last level, draw a white arrow at ~4 kHz pointing toward the next level.
    if idx < len(list_mean_snr) - 1:
        next_yvals = yvals - (list_mean_snr[idx+1] - list_mean_snr[idx])
        next_yvals[next_yvals < minusinf_val] = minusinf_val
        IDX = np.argmin(np.abs(xvals-4000))
        dy = 0.95 * (next_yvals[IDX] - yvals[IDX])
        y = yvals[IDX]
        dx = 0
        x = xvals[IDX]
        draw_arrow = True
    else:
        next_yvals = ylimits[0] * np.ones_like(yvals)
        draw_arrow = False
    if draw_arrow:
        ax.annotate("",
                    xy=(x+dx, y+dy),
                    xytext=(x, y),
                    arrowprops=dict(arrowstyle="->", color='w', lw=2))

    ax.fill_between(xvals, yvals, next_yvals,
                    color=kwargs_plot['color'],
                    alpha=1.0,
                    zorder=-10)


ax = util_figures.format_axes(ax,
                              str_xlabel='Frequency (Hz)',
                              str_ylabel='Mean PSD (dB/Hz)',
                              fontsize_labels=fontsize_labels,
                              fontsize_ticks=fontsize_ticks,
                              fontweight_labels=None,
                              xscale='log',
                              yscale='linear',
                              xlimits=xlimits,
                              ylimits=ylimits,
                              xticks=None,
                              yticks=yticks,
                              xticks_minor=None,
                              yticks_minor=yticks_minor,
                              xticklabels=None,
                              yticklabels=None,
                              spines_to_hide=[],
                              major_tick_params_kwargs_update={},
                              minor_tick_params_kwargs_update={})

kwargs_legend = {
    'loc': 'lower left',
    'frameon': True,
    'framealpha': 1.0,
    'facecolor': 'w',
    'edgecolor': 'k',
    'handlelength': 0.5,
    'markerscale': 0.0,
    'fontsize': fontsize_legend,
    'borderpad': 0.65,
    'borderaxespad': 0.25,
}
leg = ax.legend(**kwargs_legend)
# Legend.legendHandles was renamed Legend.legend_handles in Matplotlib 3.7 and
# the old name was later removed; support both.
legend_handles = getattr(leg, 'legend_handles', None)
if legend_handles is None:
    legend_handles = leg.legendHandles
# Thicken the legend lines so the colors read as swatches.
for legobj in legend_handles:
    legobj.set_linewidth(8.0)

plt.tight_layout()
plt.show()

# save_fn = os.path.join(save_dir, 'tmp_schematic_mean_spectra_snr_manipulation.pdf')
# fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)


In [None]:
import sys
import h5py
import numpy as np
import importlib

sys.path.append('/om2/user/msaddler/python-packages/bez2018model')
import bez2018model
importlib.reload(bez2018model)

sys.path.append('/om2/user/msaddler/python-packages/msutil')
import util_stimuli
importlib.reload(util_stimuli)


# Make nervegram image for CNN drawing: load one stimulus from the dataset, mix
# signal + noise, set the level, and run the bez2018 auditory nerve model on it.
kwargs_nervegram = {
    'nervegram_dur': 0.050,
    'nervegram_fs': 20000,
    'buffer_start_dur': 0.070,
    'buffer_end_dur': 0.010,
    'pin_fs': 100e3,
    'pin_dBSPL_flag': 0,
    'pin_dBSPL': None,
    'species': 2,
    'bandwidth_scale_factor': 1.0,
    'cf_list': None,
    'num_cf': 100,
    'min_cf': 125.0,
    'max_cf': 14e3,
    'max_spikes_per_train': 500,
    'num_spike_trains': 1,
    'cohc': 1.0,
    'cihc': 1.0,
    'IhcLowPass_cutoff': 3000,
    'IhcLowPass_order': 7,
    'spont': 70.0,
    'noiseType': 0,
    'implnt': 0,
    'tabs': 6e-4,
    'trel': 6e-4,
    'random_seed': None,
    'return_vihcs': False,
    'return_meanrates': True,
    'return_spike_times': False,
    'return_spike_tensor_sparse': False,
    'return_spike_tensor_dense': False,
}

fn = '/om/scratch/Tue/msaddler/data_pitchnet/PND_v08/noise_TLAS_snr_neg10pos10/PND_sr32000_v08_0000000-0021000.hdf5'
# Fixed example stimulus index (alternatively: np.random.randint(f['stimuli/signal'].shape[0]))
IDX = 14184
# Context manager guarantees the file is closed even if a read fails;
# indexing an h5py dataset returns in-memory numpy copies, so they stay valid after closing.
with h5py.File(fn, 'r') as f:
    signal = f['stimuli/signal'][IDX]
    noise = f['stimuli/noise'][IDX]
    f0 = f['nopad_f0_mean'][IDX]  # presumably Hz -- TODO confirm dataset units
    sr = f['sr'][0]

snr = 10  # passed as `snr` to combine_signal_and_noise (presumably dB)
dBSPL = 60

y = util_stimuli.combine_signal_and_noise(signal, noise, snr=snr)
y = util_stimuli.set_dBSPL(y, dBSPL)

out_dict = bez2018model.nervegram(y, sr, **kwargs_nervegram)
print(IDX, f0)


In [None]:
# CNN drawing

import os
import sys
import json
import copy
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

sys.path.append('/om2/user/msaddler/python-packages/msutil')
import util_figures
import util_figures_cnn
importlib.reload(util_figures_cnn)



# Template path to each architecture's config, formatted with the architecture number.
# brain_arch_fn = '/om/scratch/Mon/msaddler/pitchnet/saved_models/arch_search_v01_arch_0302_manipulations_v03/arch_0302_{:04.0f}/brain_arch.json'
brain_arch_fn = '/om/scratch/Tue/msaddler/pitchnet/saved_models/arch_search_v02/arch_{:04.0f}/brain_arch.json'
# ARCH_N_list = [0, 3, 6, 9, 19, 29]
# ARCH_N_list = [302] + np.random.randint(0, 302, 20).tolist()
# ARCH_N_list = [280, 373, 286, 208, 302, 291, 346, 259]
# ARCH_N_list = [302] + [185, 349, 76, 244, 340, 277, 307, 289]
# ARCH_N_list = [302] + [349, 244, 277, 307, 289, 76]
# ARCH_N_list = [302, 208, 373, 270, 97, 245, 287, 87, 325, 285]
# ARCH_N_list = [302] + sorted([349, 244, 277, 307, 289, 76])
ARCH_N_list = [191]
# ARCH_N_list = [191] + sorted([188, 41, 254, 144, 361, 338]) # For figure 1
# ARCH_N_list = [191, 302, 288, 335, 346, 286, 83, 154, 190, 338] # Top 10 archs in order (arch_search_v02)

# Kernel fill color for every architecture (all green).
ARCH_color_list = [[0, 1, 0]] * len(ARCH_N_list)

# Input image drawn at the front of each CNN; `out_dict` comes from the nervegram cell above.
nervegram_image = np.squeeze(out_dict['nervegram_meanrates'])

# enumerate gives each list position its own rank/color even if ARCH_N_list has duplicates
for ARCH_INDEX_IN_LIST, ARCH_N in enumerate(ARCH_N_list):
    arch_config_fn = brain_arch_fn.format(ARCH_N)
    layer_list = util_figures_cnn.process_cnn_layer_list(arch_config_fn)
    ARCH_NAME = 'rank{:02}_'.format(ARCH_INDEX_IN_LIST) + os.path.basename(os.path.dirname(arch_config_fn))

    print(ARCH_N, ARCH_NAME)

    kwargs_polygon_kernel_update = {
        'fc': ARCH_color_list[ARCH_INDEX_IN_LIST],
    }

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(1, 1))
    ax = util_figures_cnn.draw_cnn_from_layer_list(ax, layer_list,
                                                   deg_scale_x=60,
                                                   deg_skew_y=30,
                                                   scaling_w='log2',
                                                   scaling_h='log2',
                                                   scaling_n='log2',
                                                   input_image=nervegram_image,
                                                   kwargs_polygon_kernel_update=kwargs_polygon_kernel_update)

    # Resize the figure in proportion to the extent of the drawn data
    [xb, yb, dxb, dyb] = ax.dataLim.bounds
    fig_factor = 6
    fig.set_size_inches(dxb/fig_factor*1.2, dyb/fig_factor/1.2)

    plt.show()

#     save_fn = os.path.join(save_dir, 'schematic_cnn_{}.pdf'.format(ARCH_NAME))
#     save_fn = '/om2/user/msaddler/pitchnet/assets_psychophysics/figures/archive_2021_05_07_pitchnet_paper_figures_v04/schematic_cnn_arch_{:04d}_neurophysiology.pdf'.format(ARCH_N)
#     fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=True)
#     print(save_fn)

    # Release the figure so looping over many architectures does not accumulate memory
    plt.close(fig)


In [None]:
import sys
import os
import json
import numpy as np
import glob
import importlib

sys.path.append('/om2/user/msaddler/python-packages/msutil')
import util_figures_cnn

importlib.reload(util_figures_cnn)


def calc_best_metric(valid_metrics_fn, metric_key='f0_label:accuracy', maximize=True):
    '''
    Return the best value of one metric stored in a validation-metrics JSON file.

    Args
    ----
    valid_metrics_fn (str): path to a JSON file containing a dict of metric histories
    metric_key (str): key of the metric history to summarize
    maximize (bool): if True, return the largest value; otherwise the smallest

    Returns
    -------
    best_metric_value: np.max (maximize=True) or np.min of the stored values
    '''
    with open(valid_metrics_fn) as f_json:
        metrics_by_key = json.load(f_json)
    history = metrics_by_key[metric_key]
    reduce_fn = np.max if maximize else np.min
    return reduce_fn(history)

# Gather every model output directory to compare (the glob must match all of them),
# keeping only directories that contain both an architecture config and validation metrics.
regex_model_dir = '/om/scratch/*/msaddler/pitchnet/saved_models/arch_search_v02/arch_*/'
tmp_list_model_dir = sorted(glob.glob(regex_model_dir))

basename_valid_metrics = 'validation_metrics.json'
basename_arch_config = 'brain_arch.json'

list_valid_metric = []
list_model_dir = []
disp_step = 50
for idx, model_dir in enumerate(tmp_list_model_dir):
    fn_valid_metric = os.path.join(model_dir, basename_valid_metrics)
    fn_arch_config = os.path.join(model_dir, basename_arch_config)

    # A model counts only if both required files are present
    include_model_flag = all(os.path.exists(required_fn)
                             for required_fn in (fn_arch_config, fn_valid_metric))

    if include_model_flag:
        list_model_dir.append(model_dir)
        list_valid_metric.append(calc_best_metric(fn_valid_metric))

    # Periodic progress report
    show_progress = (idx % disp_step == 0)
    if show_progress:
        print(model_dir, include_model_flag)

print('Number of included networks:', len(list_valid_metric))


In [None]:
import re

# Rank all included networks by best validation metric (highest first)
sort_idx = np.flip(np.argsort(list_valid_metric))
sorted_list_valid_metric = list(np.array(list_valid_metric)[sort_idx])
sorted_list_model_dir = list(np.array(list_model_dir)[sort_idx])

model_dir_pattern = '/om/scratch/Fri/msaddler/pitchnet/saved_models/arch_search_v02/arch_{:04.0f}/'

# Look models up by their `arch_XXXX` directory name rather than by full path: the scraped
# directories come from `/om/scratch/*/`, so a hard-coded day folder in `model_dir_pattern`
# may not match and `list.index` would raise. If an architecture appears in several
# scratch folders, the best-ranked copy wins.
rank_by_arch_name = {}
for rank_idx, ranked_dir in enumerate(sorted_list_model_dir):
    rank_by_arch_name.setdefault(os.path.basename(os.path.normpath(ranked_dir)), rank_idx)


def _format_layer_name(raw_layer_name):
    '''Shorten a layer name and convert its trailing 0-based index (any number of digits) to 1-based.'''
    name = raw_layer_name.replace('top', 'out').replace('batch_norm', 'norm').replace('intermediate', '0')
    match = re.match(r'^(.*?)(\d+)$', name)
    if match:
        name = match.group(1) + str(int(match.group(2)) + 1)
    return name


# ARCH_N_list = [302, 208, 373, 270, 97, 245, 287, 87, 325, 285]
ARCH_N_list = [191, 302, 288, 335, 346, 286, 83, 154, 190, 338] # arch_search_v02
for ARCH_N in ARCH_N_list:
    arch_name = os.path.basename(os.path.normpath(model_dir_pattern.format(ARCH_N)))
    if arch_name not in rank_by_arch_name:
        raise ValueError('{} not found among included models'.format(arch_name))
    model_idx = rank_by_arch_name[arch_name]
    model_dir = sorted_list_model_dir[model_idx]
    model_vm = sorted_list_valid_metric[model_idx]
    model_rank = model_idx + 1

    print('ID: {}\trank: {} \tvalid_acc: {:.1f}'.format(
        arch_name, model_rank, 100.0*model_vm))

    brain_arch_fn = os.path.join(model_dir, 'brain_arch.json')
    layer_list = util_figures_cnn.process_cnn_layer_list(brain_arch_fn)

    # One line per layer: conv -> kernel shape, pool -> strides, others -> activation shape
    copy_str = ''
    for layer in layer_list:
        layer_name = _format_layer_name(layer['layer_name'])
        if not layer_name == 'flatten_end_conv':
            if 'conv' in layer_name:
                copy_str += '{} {}\n'.format(layer_name, layer['shape_kernel'])
            elif 'pool' in layer_name:
                copy_str += '{} {}\n'.format(layer_name, layer['strides'])
            else:
                copy_str += '{} {}\n'.format(layer_name, layer['shape_activations'])
    print(copy_str)


In [None]:
import sys
import os
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

import importlib

sys.path.append('/om2/user/msaddler/python-packages/msutil')
import util_figures
importlib.reload(util_figures)
import util_stimuli
importlib.reload(util_stimuli)
import util_misc
importlib.reload(util_misc)

sys.path.append('/python-packages/tfcochlearn')
import util_cochlear_filters
importlib.reload(util_cochlear_filters)


def get_roex_filterbank(signal_length, sr_signal,
                        kwargs_cochlear_filters=None,
                        kwargs_cochlear_filters_DIFF=None):
    '''
    Build a roex filterbank and return its frequency responses converted to dB.

    Args
    ----
    signal_length (int): signal length (samples) passed to make_roex_filters
    sr_signal (int): sampling rate (Hz) passed to make_roex_filters
    kwargs_cochlear_filters (dict or None): base keyword arguments for make_roex_filters
    kwargs_cochlear_filters_DIFF (dict or None): overrides merged on top of the base kwargs

    Returns
    -------
    freqs: frequency vector returned by make_roex_filters
    filts: 20*log10 of the filter responses; presumably shape (n_filters, len(freqs))
        -- TODO confirm against make_roex_filters. Zero-valued responses become -inf.
    '''
    # None defaults (instead of shared `{}` literals) so a merge helper that updates
    # its inputs in place cannot leak state between calls.
    if kwargs_cochlear_filters is None:
        kwargs_cochlear_filters = {}
    if kwargs_cochlear_filters_DIFF is None:
        kwargs_cochlear_filters_DIFF = {}
    kwargs_cochlear_filters = util_misc.recursive_dict_merge(kwargs_cochlear_filters,
                                                             kwargs_cochlear_filters_DIFF)
    filts, cfs, bws, freqs = util_cochlear_filters.make_roex_filters(signal_length, sr_signal,
                                                                     **kwargs_cochlear_filters)
    filts = 20*np.log10(filts)
    return freqs, filts



# Schematic of the cochlear filterbank: 20 roex filters with CFs from 125 Hz to 8 kHz
freqs, filts = get_roex_filterbank(
    4*4800,
    32000,
    kwargs_cochlear_filters={'n': 20, 'min_cf': 125, 'max_cf': 8e3},
    kwargs_cochlear_filters_DIFF={},
)

lw = 1.5
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.5, 1.25))

# Each row of `filts` is drawn as its own black curve against `freqs`
for filt_response in filts:
    ax.plot(freqs, filt_response, color='k', lw=lw)

ax = util_figures.format_axes(
    ax,
    xscale='log',
    yscale='linear',
    xlimits=[1e2, 10e3],
    ylimits=[-40, 4],
    xticks=[],
    xticks_minor=[],
    xticklabels=None,
    yticks=[],
    yticks_minor=[],
    yticklabels=None,
)
# [ax.spines[key].set_linewidth(1.5*lw) for key in ax.spines]

# Strip all axis decoration so only the filter shapes remain
ax.axis('off')
ax.margins(0,0)
for axis_obj in (ax.xaxis, ax.yaxis):
    axis_obj.set_major_locator(plt.NullLocator())
plt.show()

# save_fn = os.path.join(save_dir, 'schematic_AN_cochFilterBank.pdf')
# fig.savefig(save_fn, bbox_inches='tight', pad_inches=0, transparent=False)
# print(save_fn)
