In [2]:
import glob
import importlib
import json
import os
import sys

import IPython.display as ipd
import numpy as np
import scipy.io.wavfile
import tensorflow as tf

import util_recognition_network
import util_cochlear_model
import util_auditory_model_loss
import util_audio_transform


In [3]:
def get_rms(y):
    """Return the root-mean-square amplitude of `y` about its mean.

    The mean (DC offset) is subtracted before squaring, so a constant
    signal has an RMS of zero.

    Args:
        y: numeric NumPy array holding the signal samples.

    Returns:
        The RMS of `y - mean(y)`, computed over all elements.
    """
    return np.sqrt(np.mean(np.square(y - np.mean(y))))

def set_rms(y, rms):
    """Remove the mean from `y` and rescale it to the requested RMS level.

    Args:
        y: numeric NumPy array holding the signal samples.
        rms: target RMS amplitude of the returned signal.

    Returns:
        A zero-mean floating-point array with the same shape as `y`. Its RMS
        equals `rms`, except when the input is silent or constant: such a
        signal cannot be scaled to a non-zero level, so zeros are returned.
    """
    y = y - np.mean(y)
    current_rms = get_rms(y)
    if current_rms == 0:
        # Dividing by a zero RMS would turn the whole signal into NaNs, which
        # would then propagate silently into anything computed from it.
        return np.zeros_like(y)
    return rms * y / current_rms

# Shell-style glob pattern (not a regular expression, despite the name) that
# matches the unprocessed example clips.
regex_fn_wav = 'audio/ex*_unprocessed_input.wav'
# glob.glob returns matches in arbitrary, filesystem-dependent order. Sort so
# that example numbering and batch order are reproducible across machines.
list_fn_wav = sorted(glob.glob(regex_fn_wav))

list_wav = []
for fn_wav in list_fn_wav:
    print('Loading audio example: {}'.format(fn_wav))
    sr, y = scipy.io.wavfile.read(fn_wav)
    # Check the sampling rate before any processing; the rest of the notebook
    # hard-codes 20 kHz (signal_rate=20000 for the U-Net, rate=20e3 for playback).
    assert sr == 20e3, 'Expected sr=20000 but got sr={} for {}'.format(sr, fn_wav)
    # Normalize every clip to the same level (zero mean, RMS of 0.02).
    y = set_rms(y, 0.02)
    list_wav.append(y)
# Combine the clips into a single array with one clip per row.
list_wav = np.array(list_wav)


Loading audio example: audio/ex01_unprocessed_input.wav
Loading audio example: audio/ex02_unprocessed_input.wav
Loading audio example: audio/ex03_unprocessed_input.wav
Loading audio example: audio/ex04_unprocessed_input.wav
Loading audio example: audio/ex05_unprocessed_input.wav
Loading audio example: audio/ex06_unprocessed_input.wav
Loading audio example: audio/ex07_unprocessed_input.wav
Loading audio example: audio/ex08_unprocessed_input.wav
Loading audio example: audio/ex09_unprocessed_input.wav
Loading audio example: audio/ex10_unprocessed_input.wav


In [None]:
# Tags of the available audio transform models; each tag names a checkpoint
# directory under models/audio_transforms/.
list_audio_transform_tag = [
    'unet_A1',
    'unet_A123',
    'unet_A123W123',
    'unet_A1W1',
    'unet_R1',
    'unet_R123',
    'unet_W1',
    'unet_W123',
    'unet_cochlear_human',
    'unet_cochlear_reverse',
    'unet_germain_deep_features',
    'unet_waveform',
]
# Change the index to run a different audio transform.
audio_transform_tag = list_audio_transform_tag[0]
# The original, misspelled name is kept as an alias because other cells still
# refer to it.
augio_transform_tag = audio_transform_tag
checkpoint_filename = 'models/audio_transforms/{}/model.ckpt-600000'.format(audio_transform_tag)

# Discard any previously built graph so that re-running this cell starts clean.
tf.reset_default_graph()
tf.get_logger().setLevel('ERROR')

# Input batch: any number of clips, 40000 samples each (2 s at signal_rate=20000).
placeholder_wav = tf.placeholder(tf.float32, [None, 40000])
tensor_wav_unet = util_audio_transform.build_unet(
    placeholder_wav,
    signal_rate=20000,
    UNET_PARAMS={})
# Restore every global variable in the graph from the checkpoint.
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=None)
saver = tf.train.Saver(var_list=var_list, max_to_keep=0)

with tf.Session() as sess:
    # No tf.global_variables_initializer() is needed here: the saver covers all
    # global variables, and restoring a variable is itself a way to initialize it.
    print('Loading audio transform variables from: {}'.format(checkpoint_filename))
    saver.restore(sess, checkpoint_filename)
    print('Running audio transform: {}'.format(audio_transform_tag))
    list_wav_unet = sess.run(tensor_wav_unet, feed_dict={placeholder_wav: list_wav})


In [None]:
# Play each unprocessed clip next to its transformed version for comparison.
for itr_wav, fn_wav_input in enumerate(list_fn_wav):
    header = '======================== Example {:02d} ========================'
    print(header.format(itr_wav + 1))

    print('Unprocessed input audio ({}):'.format(fn_wav_input))
    player_input = ipd.Audio(list_wav[itr_wav], rate=20e3)
    ipd.display(player_input)

    print('Processed audio ({}):'.format(augio_transform_tag))
    player_output = ipd.Audio(list_wav_unet[itr_wav], rate=20e3)
    ipd.display(player_output)


In [19]:
# Discard any previously built graph before constructing the auditory model losses.
tf.reset_default_graph()
tf.get_logger().setLevel('ERROR')

# Re-import the helper module so that edits made to it since the kernel started
# take effect (`importlib` is imported with the other modules in the first cell).
importlib.reload(util_auditory_model_loss)


# Recognition networks used for the deep feature loss; uncomment an entry to
# include that network.
list_recognition_networks = [
#     'arch1_taskA',
    'arch1_taskR',
#     'arch1_taskW',
#     'arch2_taskA',
    'arch2_taskR',
#     'arch2_taskW',
#     'arch3_taskA',
    'arch3_taskR',
#     'arch3_taskW',
]

auditory_model_loss = util_auditory_model_loss.AuditoryModelLoss(
    list_recognition_networks=list_recognition_networks)


with tf.Session() as sess:
    auditory_model_loss.load_auditory_model_vars(sess)

    # Sanity checks: each loss is evaluated once on two different clips and
    # once on identical inputs. In the recorded output the identical-input
    # comparisons are all zero.
    print(auditory_model_loss.waveform_loss(list_wav[0:1], list_wav[1:2]))
    print(auditory_model_loss.waveform_loss(list_wav, list_wav))
    print(auditory_model_loss.cochlear_model_loss(list_wav[0:1], list_wav[1:2]))
    print(auditory_model_loss.cochlear_model_loss(list_wav, list_wav))
    print(auditory_model_loss.deep_feature_loss(list_wav[0:1], list_wav[1:2]))
    print(auditory_model_loss.deep_feature_loss(list_wav[0:1], list_wav[0:1]))


3 recognition networks included for deep feature loss:
|__ arch1_taskR: models/recognition_networks/arch1_taskR.ckpt-0
|__ arch2_taskR: models/recognition_networks/arch2_taskR.ckpt-0
|__ arch3_taskR: models/recognition_networks/arch3_taskR.ckpt-0
Building waveform loss
Building cochlear model loss
[make_cos_filters_nx] using filter_spacing=`erb`
[make_cos_filters_nx] using filter_spacing=`erb`
Building deep feature loss (recognition network: arch1_taskR)
Building deep feature loss (recognition network: arch2_taskR)
Building deep feature loss (recognition network: arch3_taskR)
Loading `arch1_taskR` variables from models/recognition_networks/arch1_taskR.ckpt-0
Loading `arch2_taskR` variables from models/recognition_networks/arch2_taskR.ckpt-0
Loading `arch3_taskR` variables from models/recognition_networks/arch3_taskR.ckpt-0
[872.1858]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1078.2817]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1.1398318]
[0.]
