In [18]:
import os
import sys
import copy
import h5py
import json
import numpy as np
import scipy.signal
import tensorflow as tf

import util_cochlea
import util_network
import util_stimuli


In [25]:
"""
Specify model directory and load config / architecture
"""

# Directory holding the trained model to load. Swap in the commented line to
# load the sound localization model instead.
dir_model = "models/spkr_word_recognition/simplified_IHC3000/arch0_0000"
# dir_model = "models/sound_localization/simplified_IHC3000_delayed_integration/arch01"

# Config, architecture description and best checkpoint all live in `dir_model`.
fn_config, fn_arch, fn_ckpt = (
    os.path.join(dir_model, basename)
    for basename in ("config.json", "arch.json", "ckpt_BEST")
)

with open(fn_config, "r") as f_config:
    CONFIG = json.load(f_config)
with open(fn_arch, "r") as f_arch:
    list_layer_dict = json.load(f_arch)
n_classes_dict = CONFIG["n_classes_dict"]

# The input shape (excluding the batch axis) depends on two things: whether the
# config carries cochlear-model settings, and whether the model directory name
# marks a localization model.
if not CONFIG.get("kwargs_cochlea", {}):
    # No cochlear-model settings: the model takes pre-generated nervegrams.
    if "localization" in dir_model:
        # 50 freq channels, 1-second at 10 kHz, 3 spont rates, binaural
        input_shape = [50, 10000, 3, 2]
    else:
        # 50 freq channels, 2-seconds at 10 kHz, 3 spont rates, monaural
        input_shape = [50, 20000, 3]
else:
    # Cochlear-model settings present: the model takes audio sampled at `sr`.
    # NOTE: `sr` is only assigned on this branch.
    sr = CONFIG["kwargs_cochlea"]["sr_input"]
    if "localization" in dir_model:
        # 1.3-second binaural input for localization model
        input_shape = [int(1.3 * sr), 2]
    else:
        # 2-second monaural input for word + voice recognition model
        input_shape = [int(2 * sr)]

print(f"Model input shape: {input_shape}")
print(f"Model output shape(s): {n_classes_dict}")


Model input shape: [40000]
Model output shape(s): {'label_speaker_int': 433, 'label_word_int': 794}


In [26]:
"""
Build TensorFlow model object and load pre-trained weights
"""

def model_io_function(x):
    """
    Map a model input tensor to the network outputs.

    When `CONFIG` has non-empty "kwargs_cochlea", the input is first passed
    through `util_cochlea.cochlea` using those settings. If "label_loc_int" is
    one of the keys of `n_classes_dict`, the last input axis must have size 2:
    each of its two entries is passed through the cochlea separately and the
    results are concatenated along the last axis. If axis 2 of that result is
    longer than 10000, it is cut down with `util_cochlea.random_slice`.
    Without "kwargs_cochlea", the input goes to the network unchanged.

    Args
    ----
    x: input tensor (symbolic Keras input in this notebook)

    Returns
    -------
    First value returned by `util_network.build_network` for the given
    architecture (`list_layer_dict`) and `n_classes_dict`.
    """

    def apply_cochlea(signal):
        # A fresh deep copy of the settings is handed to every call, so the
        # shared CONFIG entry is never passed to `cochlea` directly.
        response, _ = util_cochlea.cochlea(signal, **copy.deepcopy(CONFIG["kwargs_cochlea"]))
        return response

    features = x
    if CONFIG.get("kwargs_cochlea", {}):
        if "label_loc_int" not in n_classes_dict:
            features = apply_cochlea(features)
        else:
            msg = "expected [batch, freq, time, spont, channel=2] or [batch, time, channel=2]"
            assert (len(features.shape) in [3, 5]) and (features.shape[-1] == 2), msg
            # Process index 0 and then index 1 of the last axis, in that order.
            per_channel = [apply_cochlea(features[..., channel]) for channel in (0, 1)]
            features = tf.concat(per_channel, axis=-1)
            if features.shape[2] > 10000:
                features = util_cochlea.random_slice(
                    features,
                    slice_length=10000,
                    axis=2,  # Time axis
                    buffer=500,
                )
    outputs, _ = util_network.build_network(
        features,
        list_layer_dict,
        n_classes_dict=n_classes_dict,
    )
    return outputs

# Symbolic float32 input with an unconstrained batch size and the per-example
# shape chosen in the config cell.
inputs = tf.keras.Input(
    shape=input_shape,
    batch_size=None,
    dtype=tf.float32,
)

# Trace the input -> output mapping once to define the Keras model, then load
# the weights stored at `fn_ckpt` into it.
model = tf.keras.Model(
    inputs=inputs,
    outputs=model_io_function(inputs),
)
model.load_weights(fn_ckpt)
model.summary()


[cochlea] converting audio to subbands using fir_gammatone_filterbank
[cochlea] half-wave rectified subbands
[tf_fir_resample] interpreted `tensor_input.shape` as [batch, freq=50, time=40000]
[tf_fir_resample] `kwargs_fir_lowpass_filter`: {'cutoff': 3000, 'fir_dur': 0.05, 'ihc_filter': True, 'order': 7}
[fir_lowpass_filter] sr_filt = 20000.0 Hz
[fir_lowpass_filter] numtaps = 1001 samples
[fir_lowpass_filter] fir_dur = 0.05 seconds
[fir_lowpass_filter] cutoff = 3000 Hz
[fir_lowpass_filter] order = 7 (bez2018model IHC filter)
[cochlea] resampled subbands from 20000 Hz to 10000 Hz with filter: {'cutoff': 3000, 'fir_dur': 0.05, 'ihc_filter': True, 'order': 7}
[cochlea] half-wave rectified resampled subbands
[cochlea] incorporated sigmoid_rate_level_function: {'dynamic_range': [20.0, 40.0, 80.0], 'dynamic_range_interval': 0.95, 'envelope_mode': True, 'rate_max': [250.0, 250.0, 250.0], 'rate_spont': [0.0, 0.0, 0.0], 'threshold': [0.0, 12.0, 28.0]}
[cochlea] inferring `sr=10000.0` for spike_g

In [27]:
"""
Load stimuli on which to evaluate model
"""
# Choose the evaluation stimuli for the task. This uses the same "localization"
# substring test as the cell that picks `input_shape`, so the stimuli and the
# model input shape can never disagree about which task is loaded.
if "localization" in dir_model:
    fn_stim = "stimuli/sound_localization/evaluation/speech_in_noise_in_reverb_v04/stim.hdf5"
else:
    fn_stim = "stimuli/spkr_word_recognition/evaluation/speech_in_synthetic_textures/stim.hdf5"

# Read the target sampling rate from the config rather than relying on `sr`
# left behind by an earlier cell: that variable is only assigned when the
# config contains cochlear-model settings, so a fresh kernel would otherwise
# fail here with an unhelpful NameError (or silently reuse a stale value).
kwargs_cochlea = CONFIG.get("kwargs_cochlea", {})
if not kwargs_cochlea:
    raise ValueError(
        "This cell resamples stimuli to CONFIG['kwargs_cochlea']['sr_input'], "
        "but the loaded config has no 'kwargs_cochlea' settings."
    )
sr = kwargs_cochlea["sr_input"]

with h5py.File(fn_stim, "r") as f:
    batch_size = 8
    idx0 = 0
    indexes = slice(idx0, idx0 + batch_size)
    y = f["signal"][indexes]
    # Take the source rate from the examples actually loaded (not from example
    # 0 of the file) and require it to be the same for the whole batch, because
    # a single resampling ratio is applied to every example below.
    sr_in_batch = np.unique(f["sr"][indexes])
    if sr_in_batch.size != 1:
        raise ValueError(
            f"Expected exactly one source sampling rate for examples "
            f"{idx0}:{idx0 + batch_size} of {fn_stim}, found {sr_in_batch.tolist()}"
        )
    sr_src = sr_in_batch[0]
    # Resample along axis 1 from the file's rate to the model's input rate.
    y = scipy.signal.resample_poly(y, up=sr, down=sr_src, axis=1)
    print(f"Batch of example inputs: {y.shape=} {y.dtype=}")
    # Show the per-example metadata (1-D datasets) for the loaded batch; HDF5
    # groups have no `ndim`, so they are skipped instead of raising.
    for k in f.keys():
        if isinstance(f[k], h5py.Dataset) and f[k].ndim == 1:
            print("|__", k, f[k].dtype, f[k][indexes])
    print("|__ signal: {} (resampled: {} --> {} Hz)".format(y.shape, sr_src, sr))


Batch of example inputs: y.shape=(8, 40000) y.dtype=dtype('float32')
|__ dur float64 [2. 2. 2. 2. 2. 2. 2. 2.]
|__ index_example int64 [0 1 2 3 4 5 6 7]
|__ index_foreground int64 [0 1 2 3 4 5 6 7]
|__ index_texture int64 [0 0 0 0 0 0 0 0]
|__ label_speaker_int int64 [204 279 225 287 364 364 263 265]
|__ label_word_int int64 [ 3  5  7  8 10 11 13 14]
|__ snr int64 [-3 -3 -3 -3 -3 -3 -3 -3]
|__ sr int64 [20000 20000 20000 20000 20000 20000 20000 20000]
|__ signal: (8, 40000) (resampled: 20000 --> 20000.0 Hz)


In [28]:
"""
Evaluate model on example stimuli
"""

# Run the loaded model on the batch of example stimuli. `out` is iterated below
# as a mapping from output name to a 2-D array of per-class scores.
out = model(y)

# First pass: report dtype and shape of every output head.
print("Model outputs (softmax cross entropy logits):")
for head_name, head_logits in out.items():
    print("|__", head_name, head_logits.dtype, head_logits.shape)

# Second pass: index of the highest-scoring class for each example in the batch.
print("Model outputs (class label predictions):")
for head_name, head_logits in out.items():
    predicted_labels = np.argmax(head_logits, axis=1)
    print("|__", head_name, predicted_labels)


Model outputs (softmax cross entropy logits):
|__ label_speaker_int <dtype: 'float32'> (8, 433)
|__ label_word_int <dtype: 'float32'> (8, 794)
Model outputs (class label predictions):
|__ label_speaker_int [204 279 225 214 364 364 263 264]
|__ label_word_int [302  95   7   8  10  10  18  14]
