In [2]:
import copy
import glob
import importlib
import json
import os
import pdb
import sys
import time

import h5py
import numpy as np
import pandas as pd
import scipy.io.wavfile
import scipy.signal
import scipy.special
import tensorflow as tf

%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd

import util_tfrecords
import util_signal
import util_cochlea
import util_network
import util_optimize
import util_evaluate
import util_figures
import util_stimuli
import util_misc


def azim_elev_to_label(azim, elev):
    """
    Convert azimuth / elevation angles (degrees) to integer location labels.

    Locations are binned in 5-degree azimuth steps (72 bins per elevation)
    and 10-degree elevation steps, with `label = elev_bin * 72 + azim_bin`.
    For on-grid angles this is the inverse of `label_to_azim_elev`.

    Args
    ----
    azim (scalar or array-like): azimuth in degrees (wrapped to [0, 360))
    elev (scalar or array-like): elevation in degrees (broadcast against `azim`)

    Returns
    -------
    label (np.ndarray of int): class label(s) with the broadcast input shape
    """
    # Quantize each coordinate to its nearest bin separately, so that an
    # off-grid elevation or float round-off (e.g. 39.9999) can neither leak
    # into the azimuth index nor be truncated into the bin below.
    # `floor(x + 0.5)` rounds ties upward consistently.
    azim_bin = np.floor(np.asarray(azim, dtype=float) / 5 + 0.5).astype(int)
    elev_bin = np.floor(np.asarray(elev, dtype=float) / 10 + 0.5).astype(int)
    # Azimuth is circular: 360 degrees is the same direction as 0 degrees and
    # must not spill over into the next elevation's block of labels.
    azim_bin = azim_bin % 72
    return np.array(elev_bin * 72 + azim_bin).astype(int)


def label_to_azim_elev(label):
    """
    Convert integer location labels back to azimuth / elevation (degrees).

    Inverts `label = elev_bin * 72 + azim_bin`, where azimuth bins are
    5 degrees wide (72 per elevation) and elevation bins are 10 degrees wide.

    Args
    ----
    label (int or array-like of int): class label(s)

    Returns
    -------
    azim (np.ndarray of float): azimuth in degrees
    elev (np.ndarray of float): elevation in degrees
    """
    # `np.asarray` lets plain Python lists / tuples be decoded as well as
    # scalars and arrays (`//` and `%` on a list would raise a TypeError).
    # `np.divmod` returns the floor quotient and remainder in one call.
    elev_bin, azim_bin = np.divmod(np.asarray(label), 72)
    azim = np.array(azim_bin * 5, dtype=float)
    elev = np.array(elev_bin * 10, dtype=float)
    return azim, elev


In [17]:
# Locate the saved model's architecture, configuration, and best checkpoint
dir_model = 'saved_models/tf2_model/archFrancl01'
fn_arch, fn_config, fn_ckpt = [
    os.path.join(dir_model, basename)
    for basename in ('arch.json', 'config.json', 'ckpt_BEST')
]

# Architecture: passed to `util_network.build_network` as the layer description
with open(fn_arch, 'r') as f:
    list_layer_dict = json.load(f)
# Configuration: read below for 'kwargs_cochlea' and 'n_classes_dict'
with open(fn_config, 'r') as f:
    CONFIG = json.load(f)


def cochlea_model_io_function(x):
    """
    Apply the cochlear model (when configured) to a binaural sound tensor.

    `x` is expected to have shape [batch, time, channel=2]. Each ear is passed
    through `util_cochlea.cochlea` separately and the two outputs are joined
    along a new trailing channel axis: [batch, freq, time, channel=2].
    If CONFIG has no (or empty) 'kwargs_cochlea', `x` is returned unchanged.
    """
    if not CONFIG.get('kwargs_cochlea', {}):
        # No cochlear model configured: pass the input straight through
        return x
    msg = "expected input with shape [batch, time, channel=2]"
    assert (len(x.shape) == 3) and (x.shape[-1] == 2), msg
    # Run the cochlear model on ear index 0, then ear index 1; each call gets
    # its own deep copy of the keyword arguments
    list_y_ear = []
    for ear in (0, 1):
        y_ear, _ = util_cochlea.cochlea(
            x[..., ear],
            **copy.deepcopy(CONFIG['kwargs_cochlea']))
        list_y_ear.append(y_ear)
    # Binaural cochlear model representation with shape [batch, freq, time, channel=2]
    y = tf.concat([y_ear[..., tf.newaxis] for y_ear in list_y_ear], axis=-1)
    msg = "expected cochlear model output with shape [batch, freq, time, channel=2]"
    assert (len(y.shape) == 4) and (y.shape[-1] == 2), msg
    return y


def network_model_io_function(x):
    """
    Build the network tensorflow graph on input `x` and return its output.

    The architecture is read from `list_layer_dict` and the output classes from
    CONFIG['n_classes_dict']; the second value returned by
    `util_network.build_network` is discarded.
    """
    outputs, _ = util_network.build_network(
        x,
        list_layer_dict,
        n_classes_dict=CONFIG['n_classes_dict'])
    return outputs


# Build tensorflow Keras model objects (cochlea, network, combined)
tf.keras.backend.clear_session()
# Sound input: 48000 samples x 2 ears per example
inputs_sound = tf.keras.Input(shape=(48000, 2), batch_size=None, dtype=tf.float32)
# Cochlear representation input: must match the output shape of `cochlea_model`
inputs_coch = tf.keras.Input(shape=(39, 8000, 2), batch_size=None, dtype=tf.float32)
cochlea_model = tf.keras.Model(
    inputs=inputs_sound,
    outputs=cochlea_model_io_function(inputs_sound),
    name='cochlea_model')
network_model = tf.keras.Model(
    inputs=inputs_coch,
    outputs=network_model_io_function(inputs_coch),
    name='network_model')
# Combined model: sound waveform --> cochlear representation --> network output
model = tf.keras.Model(
    inputs=inputs_sound,
    outputs=network_model(cochlea_model(inputs_sound)))

# Only show ERROR-level tensorflow log messages while restoring the checkpoint.
# The `finally` clause guarantees the log level is set back to INFO even when
# `load_weights` raises, so a failed load does not leave logging silenced for
# the rest of the session.
tf.get_logger().setLevel('ERROR')
try:
    network_model.load_weights(fn_ckpt) # <-- Loading `network_model` weights will affect `model` as well
    print('Loaded: {}'.format(fn_ckpt))
finally:
    tf.get_logger().setLevel('INFO')


[cochlea] converting audio to subbands using half_cosine_filterbank
[cochlea] half-wave rectified subbands
[cochlea] resampled subbands from 48000 Hz to 8000 Hz with filter: {'cutoff': 4000, 'numtaps': 4097, 'window': ['kaiser', 5.0]}
[cochlea] half-wave rectified resampled subbands
[cochlea] applied 0.3 power compression to subbands
[cochlea] converting audio to subbands using half_cosine_filterbank
[cochlea] half-wave rectified subbands
[cochlea] resampled subbands from 48000 Hz to 8000 Hz with filter: {'cutoff': 4000, 'numtaps': 4097, 'window': ['kaiser', 5.0]}
[cochlea] half-wave rectified resampled subbands
[cochlea] applied 0.3 power compression to subbands
Loaded: saved_models/tf2_model/archFrancl01/ckpt_BEST


In [19]:
# Print a per-layer summary (output shapes and parameter counts) of the combined model
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 48000, 2)]        0         
                                                                 
 cochlea_model (Functional)  (None, 39, 8000, 2)       0         
                                                                 
 network_model (Functional)  {'label_loc_int': (None,  53956632  
                              504)}                              
                                                                 
Total params: 53,956,632
Trainable params: 53,953,752
Non-trainable params: 2,880
_________________________________________________________________


In [20]:
# NOTE(review): absolute, user-specific path; point this at your own copy of the data
fn = '/om2/user/msaddler/data_localize/FLDv01/valid/stim_000000-004763.hdf5'
sr = 48e3 # Target sampling rate (Hz) for the model input
IDX = slice(100, 116) # Examples to evaluate
CROP_START = 22800 # First sample of the excerpt passed to the model
CROP_LENGTH = 48000 # Number of samples in the model's sound input

# Read everything needed from the file first so the handle is closed before
# any computation happens
with h5py.File(fn, 'r') as f:
    sr_hdf5 = f['sr'][0]
    y = f['signal'][IDX]
    azim = f['foreground_azimuth'][IDX]
    elev = f['foreground_elevation'][IDX]

# Resample along the time axis from the stored sampling rate to `sr`
y = scipy.signal.resample_poly(y, up=sr, down=sr_hdf5, axis=1)
label = azim_elev_to_label(azim, elev)

# Fail with a clear message (instead of a silently shortened slice) if the
# resampled signals are too short for the requested excerpt
msg = "signal too short: need {} samples, found {}".format(CROP_START + CROP_LENGTH, y.shape[1])
assert y.shape[1] >= CROP_START + CROP_LENGTH, msg

# Network output for the location task, class probabilities, and predicted labels
y_out = model(y[:, CROP_START:CROP_START + CROP_LENGTH, :])['label_loc_int'].numpy()
y_pred = scipy.special.softmax(y_out, axis=-1)
label_pred = np.argmax(y_out, axis=-1)
azim_pred, elev_pred = label_to_azim_elev(label_pred)

print('True labels: {}'.format(label))
print('Pred labels: {}'.format(label_pred))
print('True azim: {}'.format(azim.astype(int)))
print('Pred azim: {}'.format(azim_pred.astype(int)))
print('True elev: {}'.format(elev.astype(int)))
print('Pred elev: {}'.format(elev_pred.astype(int)))
print('Correct: {}'.format((label == label_pred).astype(int)))


True labels: [296 170 170 224 443  51 306  84 152 260 332 478 498 347 440 429]
Pred labels: [296 171 170 244 443  51 306  96 225 260 332 478 499 347 441 428]
True azim: [ 40 130 130  40  55 255  90  60  40 220 220 230 330 295  40 345]
Pred azim: [ 40 135 130 140  55 255  90 120  45 220 220 230 335 295  45 340]
True elev: [40 20 20 30 60  0 40 10 20 30 40 60 60 40 60 50]
Pred elev: [40 20 20 30 60  0 40 10 30 30 40 60 60 40 60 50]
Correct: [1 0 1 0 1 1 1 0 0 1 1 1 0 1 0 0]
