#### This notebook provides a simple walk-through of:
1. Loading and evaluating a trained audio transform on example waveforms
2. Computing auditory model loss values on example waveforms

In [1]:
import os
import sys
import glob
import numpy as np
import tensorflow as tf
import scipy.io.wavfile
import IPython.display as ipd

import util_auditory_model_loss
import util_audio_transform


#### Load example audio waveforms (2-second speech-in-noise excerpts sampled at 20 kHz).

In [2]:
def get_rms(y):
    """Return the root-mean-square amplitude of `y` after removing its mean (DC offset).

    Args:
        y: array-like waveform (integer input is promoted to float by the mean subtraction).

    Returns:
        Scalar RMS of the mean-subtracted signal.
    """
    y = np.asarray(y)
    return np.sqrt(np.mean(np.square(y - np.mean(y))))

def set_rms(y, rms):
    """Remove the DC offset from `y` and rescale it to the requested RMS amplitude.

    Args:
        y: array-like waveform.
        rms: target RMS amplitude of the returned signal.

    Returns:
        Array with zero mean whose RMS equals `rms`.

    Raises:
        ValueError: if `y` has zero (or undefined, e.g. empty input) RMS; such a signal
            cannot be rescaled and dividing by its RMS would silently yield NaNs.
    """
    y_centered = np.asarray(y) - np.mean(y)
    current_rms = get_rms(y_centered)
    # `not > 0` also catches NaN (e.g. empty input), not just exact zero.
    if not current_rms > 0:
        raise ValueError('Cannot rescale a waveform with zero or undefined RMS (silent, constant, or empty signal)')
    return rms * y_centered / current_rms

# Sort the matches: glob order is arbitrary, and the batch index of each example must map
# reproducibly to its filename (later cells pair list_fn_wav[i] with list_wav[i]).
list_fn_wav = sorted(glob.glob('audio/ex*_unprocessed_input.wav'))
assert list_fn_wav, 'No example audio found matching audio/ex*_unprocessed_input.wav'
list_wav = []
for fn_wav in list_fn_wav:
    print('Loading audio example: {}'.format(fn_wav))
    sr, y = scipy.io.wavfile.read(fn_wav)
    # Validate the file before rescaling so a wrong file is rejected immediately.
    assert sr == 20e3, 'Expected 20 kHz audio, got {} Hz: {}'.format(sr, fn_wav)
    assert y.ndim == 1, 'Expected mono audio, got shape {}: {}'.format(y.shape, fn_wav)
    y = set_rms(y, 0.02) # Our trained audio transforms expect waveforms re-scaled to 60 dB SPL
    list_wav.append(y)
# Array of example waveforms with shape [batch, timesteps]; np.stack raises a clear error
# if the examples differ in length instead of silently building an object array.
list_wav = np.stack(list_wav)


Loading audio example: audio/ex01_unprocessed_input.wav
Loading audio example: audio/ex02_unprocessed_input.wav
Loading audio example: audio/ex03_unprocessed_input.wav
Loading audio example: audio/ex04_unprocessed_input.wav
Loading audio example: audio/ex05_unprocessed_input.wav
Loading audio example: audio/ex06_unprocessed_input.wav
Loading audio example: audio/ex07_unprocessed_input.wav
Loading audio example: audio/ex08_unprocessed_input.wav
Loading audio example: audio/ex09_unprocessed_input.wav
Loading audio example: audio/ex10_unprocessed_input.wav


#### 1. Evaluate trained audio transforms on example audio waveforms.

In [3]:
# List of audio transforms featured in our paper (see Table 1 and Table 2).
# Each tag names a directory under models/audio_transforms/ that holds the trained checkpoint.
list_audio_transform_tag = [
    'unet_A1',
    'unet_A123',
    'unet_A123W123',
    'unet_A1W1',
    'unet_R1',
    'unet_R123',
    'unet_W1',
    'unet_W123',
    'unet_cochlear_human',
    'unet_cochlear_reverse',
    'unet_germain_deep_features',
    'unet_waveform',
]
audio_transform_tag = list_audio_transform_tag[0] # Select one audio transform to run (change the index to try another)
checkpoint_filename = 'models/audio_transforms/{}/model.ckpt-600000'.format(audio_transform_tag)

# Build tensorflow graph for the audio transform (Wave-U-Net architecture).
# Resetting the default graph first lets this cell be re-run without duplicating graph variables.
tf.reset_default_graph()
tf.get_logger().setLevel('ERROR')

# Input batch of waveforms: 40000 samples per row = 2 s at 20 kHz, matching the example excerpts.
placeholder_wav = tf.placeholder(tf.float32, [None, 40000])
# NOTE(review): UNET_PARAMS={} presumably selects build_unet's default architecture -- confirm in util_audio_transform.
tensor_wav_unet = util_audio_transform.build_unet(
    placeholder_wav,
    signal_rate=20000,
    UNET_PARAMS={})
# Collect every global variable in the freshly reset graph (i.e. those created by build_unet) for restoring.
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=None)
saver = tf.train.Saver(var_list=var_list, max_to_keep=0)

# Load variables from model checkpoint and run audio transform on example audio waveforms
with tf.Session() as sess:
    # NOTE(review): saver.restore below assigns every variable in var_list, so this initializer looks redundant.
    sess.run(tf.global_variables_initializer())
    print('Loading audio transform variables from: {}'.format(checkpoint_filename))
    saver.restore(sess, checkpoint_filename)
    print('Running audio transform: {}'.format(audio_transform_tag))
    # The whole batch is transformed in one call; list_wav_unet is read by the listening and loss cells below.
    list_wav_unet = sess.run(tensor_wav_unet, feed_dict={placeholder_wav: list_wav})


Loading audio transform variables from: models/audio_transforms/unet_A1/model.ckpt-600000
Running audio transform: unet_A1


#### Create IPython display objects for listening to input and processed audio.

In [4]:
# Walk the examples in batch order, pairing each filename with its unprocessed and
# transformed waveform so they can be played back side by side.
paired_examples = zip(list_fn_wav, list_wav, list_wav_unet)
for itr_wav, (fn_wav_in, wav_in, wav_out) in enumerate(paired_examples):
    example_header = '======================== Example {:02d} ========================'.format(itr_wav + 1)
    print(example_header)
    print('Unprocessed input audio ({}):'.format(fn_wav_in))
    ipd.display(ipd.Audio(wav_in, rate=20e3))
    print('Processed audio ({}):'.format(audio_transform_tag))
    ipd.display(ipd.Audio(wav_out, rate=20e3))


Unprocessed input audio (audio/ex01_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex02_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex03_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex04_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex05_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex06_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex07_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex08_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex09_unprocessed_input.wav):


Processed audio (unet_A1):


Unprocessed input audio (audio/ex10_unprocessed_input.wav):


Processed audio (unet_A1):


#### 2. Build the auditory model loss graph and compute waveform, cochlear model, and deep feature losses.

`list_recognition_networks` specifies which recognition networks to include in the deep feature loss.

Architectures:
- `arch1` : defined in "models/recognition_networks/arch1.json"
- `arch2` : defined in "models/recognition_networks/arch2.json"
- `arch3` : defined in "models/recognition_networks/arch3.json"

Tasks:
- `taskA` : recognition network optimized for the AudioSet environmental sound recognition task
- `taskW` : recognition network optimized for the word recognition task
- `taskR` : features from a randomly initialized (untrained) recognition network

In [5]:
list_recognition_networks = [
    'arch1_taskA',
#     'arch1_taskR',
#     'arch1_taskW',
    'arch2_taskA',
#     'arch2_taskR',
#     'arch2_taskW',
    'arch3_taskA',
#     'arch3_taskR',
#     'arch3_taskW',
]

tf.reset_default_graph()
tf.get_logger().setLevel('ERROR')

# The AuditoryModelLoss class builds the full auditory model loss graph
auditory_model_loss = util_auditory_model_loss.AuditoryModelLoss(
    list_recognition_networks=list_recognition_networks)


def report_auditory_model_losses(loss_model, wav_a, wav_b, header):
    """Print the header, then compute and print the three auditory model losses between two waveform batches.

    Args:
        loss_model: the AuditoryModelLoss instance (its variables must already be loaded).
        wav_a, wav_b: waveform batches to compare.
        header: line printed before the results.

    Returns:
        Tuple (waveform loss, cochlear model loss, deep feature loss) as returned by loss_model.
    """
    print(header)
    loss_waveform = loss_model.waveform_loss(wav_a, wav_b)
    loss_cochlear = loss_model.cochlear_model_loss(wav_a, wav_b)
    loss_deep = loss_model.deep_feature_loss(wav_a, wav_b)
    print("|__ Waveform losses: {}".format(loss_waveform))
    print("|__ Cochlear model losses: {}".format(loss_cochlear))
    print("|__ Deep feature losses: {}".format(loss_deep))
    return loss_waveform, loss_cochlear, loss_deep


# Compute auditory model losses for pairs of example audio waveforms
with tf.Session() as sess:
    auditory_model_loss.load_auditory_model_vars(sess)

    waveform_loss, cochlear_model_loss, deep_feature_loss = report_auditory_model_losses(
        auditory_model_loss, list_wav, list_wav_unet,
        "\nComputing auditory model losses between input and UNET-transformed audio")

    # Sanity check: comparing the input with itself; the last results stay bound to the globals.
    waveform_loss, cochlear_model_loss, deep_feature_loss = report_auditory_model_losses(
        auditory_model_loss, list_wav, list_wav,
        "\nComputing auditory model losses between `input audio` and `input audio`")


3 recognition networks included for deep feature loss:
|__ arch1_taskA: models/recognition_networks/arch1_taskA.ckpt-550000
|__ arch2_taskA: models/recognition_networks/arch2_taskA.ckpt-950000
|__ arch3_taskA: models/recognition_networks/arch3_taskA.ckpt-980000
Building waveform loss
Building cochlear model loss
[make_cos_filters_nx] using filter_spacing=`erb`
[make_cos_filters_nx] using filter_spacing=`erb`
Building deep feature loss (recognition network: arch1_taskA)
Building deep feature loss (recognition network: arch2_taskA)
Building deep feature loss (recognition network: arch3_taskA)
Loading `arch1_taskA` variables from models/recognition_networks/arch1_taskA.ckpt-550000
Loading `arch2_taskA` variables from models/recognition_networks/arch2_taskA.ckpt-950000
Loading `arch3_taskA` variables from models/recognition_networks/arch3_taskA.ckpt-980000

Computing auditory model losses between input and UNET-transformed audio
|__ Waveform losses: [3437.252  3433.5608 3420.795  3435.198 