In [1]:
import os
import sys
import importlib
import json
import glob
import numpy as np
import tensorflow as tf
import scipy.io.wavfile
import IPython.display as ipd

import util_recognition_network
import util_cochlear_model
import util_audio_transform


In [None]:
ARCH = 1
TASK = 'A'


def _ckpt_step(fn_ckpt_index):
    """Global step parsed from a '<prefix>.ckpt-<step>.index' filename (-1 if absent)."""
    step = os.path.splitext(fn_ckpt_index)[0].rpartition('ckpt-')[2]
    return int(step) if step.isdigit() else -1


# A TF checkpoint is spread over several files (.index, .meta, .data-*). Glob only
# the .index files so each match maps to exactly one checkpoint prefix. glob order
# is arbitrary, so choose the highest global step explicitly instead of the last match.
regex_fn_ckpt = 'models/recognition_networks/arch{}_task{}.ckpt*.index'.format(ARCH, TASK)
list_fn_ckpt_index = glob.glob(regex_fn_ckpt)
assert list_fn_ckpt_index, 'No checkpoint matches {}'.format(regex_fn_ckpt)
# Strip '.index' to get the prefix that tf.train.Saver.restore expects.
fn_ckpt = os.path.splitext(max(list_fn_ckpt_index, key=lambda fn: (_ckpt_step(fn), fn)))[0]

# Architecture description (list of layer dicts) shared by all tasks of this ARCH.
fn_arch = 'models/recognition_networks/arch{}.json'.format(ARCH)
with open(fn_arch, 'r') as f_arch:
    list_layer_dict = json.load(f_arch)

fn_ckpt, fn_arch


In [None]:
importlib.reload(util_recognition_network)

tf.reset_default_graph()

# The output head depends on the task encoded in the checkpoint filename.
n_classes_dict = (
    {"/stimuli/labels_binary_via_int": 517}
    if 'taskA' in fn_ckpt
    else {"/stimuli/word_int": 794}
)

# NOTE(review): presumably [batch, 40 cochlear channels, 20000 time samples, 1] — confirm.
tensor_input = tf.placeholder(tf.float32, [None, 40, 20000, 1])
tensor_output, tensors = util_recognition_network.build_network(
    tensor_input, list_layer_dict, n_classes_dict=n_classes_dict)

# Every global variable of the freshly built graph, also indexed by its TF name.
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=None)
var_dict = dict(zip([v.name for v in var_list], var_list))

saver = tf.train.Saver(var_list=var_list, max_to_keep=0)

# Snapshot all variables right after initialization and again after restoring the
# checkpoint, so the two can be compared below.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    var_dict_init = sess.run(var_dict)
    saver.restore(sess, fn_ckpt)
    var_dict_load = sess.run(var_dict)

# Mean absolute change per variable; 0.0 means the restore left its value unchanged.
for v in var_list:
    print(v.name, v.shape, np.abs(var_dict_load[v.name] - var_dict_init[v.name]).mean())


In [None]:
def get_rms(y):
    """Root-mean-square of `y` after removing its mean (equivalently, its standard deviation)."""
    return np.sqrt(np.mean(np.square(y - np.mean(y))))

def set_rms(y, rms):
    """Return a zero-mean copy of `y` rescaled so that its RMS equals `rms`.

    Args:
        y: array of samples (any numeric dtype; the result is floating point).
        rms: target root-mean-square value.

    Raises:
        ValueError: if `y` is constant (zero RMS after mean removal). Such a
            signal cannot be rescaled, and dividing by zero would silently
            produce NaNs.
    """
    y = y - np.mean(y)
    current_rms = get_rms(y)
    if current_rms == 0:
        raise ValueError('Cannot set RMS of a constant (silent) signal')
    return rms * y / current_rms


# Example clips to pass through the audio transform. Sorted so that the order (and
# therefore each clip's index in list_y / list_fn_wav) is reproducible across machines.
regex_fn_wav = 'audio/ex*_unprocessed_input.wav'
list_fn_wav = sorted(glob.glob(regex_fn_wav))
assert list_fn_wav, 'No wav files match {}'.format(regex_fn_wav)

list_y = []
for fn_wav in list_fn_wav:
    sr, y = scipy.io.wavfile.read(fn_wav)
    # Validate the sample rate before processing: the models below assume 20 kHz.
    assert sr == 20e3, 'Expected 20 kHz audio, got {} Hz in {}'.format(sr, fn_wav)
    y = set_rms(y, 0.02)
    list_y.append(y)
# Stacking into one array requires equal-length clips.
assert len(set(len(y) for y in list_y)) == 1, 'All example clips must have the same length'
list_y = np.array(list_y)
list_y.shape


In [None]:
importlib.reload(util_cochlear_model)
importlib.reload(util_audio_transform)

tf.reset_default_graph()

# Dedicated name: rebinding `fn_ckpt` here would clobber the recognition-network
# checkpoint path read by the network-loading cell. Re-running that cell afterwards
# would then silently choose the wrong output head and restore the U-Net checkpoint.
fn_ckpt_unet = 'models/audio_transforms/unet_cochlear_reverse/model.ckpt-600000'

# 40000 samples per example at signal_rate=20000 (2 s); list_y must match this length.
tensor_waveform = tf.placeholder(tf.float32, [None, 40000])

tensor_cochlear_representation, coch_container = util_cochlear_model.build_cochlear_model(
    tensor_waveform,
    signal_rate=20000,
    filter_type='half-cosine',
    filter_spacing='erb',
    HIGH_LIM=8000,
    LOW_LIM=20,
    N=40,
    SAMPLE_FACTOR=1,
    bandwidth_scale_factor=1.0,
    compression='stable_point3',
    include_highpass=False,
    include_lowpass=False,
    linear_max=1.0,
    rFFT=True,
    rectify_and_lowpass_subbands=True,
    return_subbands_only=True)

tensor_waveform_unet = util_audio_transform.build_unet(
    tensor_waveform,
    signal_rate=20000,
    UNET_PARAMS={})

var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=None)
saver = tf.train.Saver(var_list=var_list, max_to_keep=0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, fn_ckpt_unet)
    # Transform the examples, then compute the cochlear representation of both the
    # original and the transformed waveforms (same graph, fed through one placeholder).
    list_y_unet = sess.run(tensor_waveform_unet, feed_dict={tensor_waveform: list_y})
    list_coch_y = sess.run(tensor_cochlear_representation, feed_dict={tensor_waveform: list_y})
    list_coch_y_unet = sess.run(tensor_cochlear_representation, feed_dict={tensor_waveform: list_y_unet})


In [None]:
# For each example: print the mean cochlear representation before and after the
# U-Net, then show audio players for the original and the transformed waveform.
# NOTE(review): `sr` is the sample rate left over from the wav-loading loop.
for itr0, y_unet in enumerate(list_y_unet):
    y = list_y[itr0]
    print(list_fn_wav[itr0], np.mean(list_coch_y[itr0]), np.mean(list_coch_y_unet[itr0]))
    ipd.display(ipd.Audio(y, rate=sr), ipd.Audio(y_unet, rate=sr))


In [23]:
# Discover the task-A recognition-network checkpoints (one .index file per checkpoint).
# Sorted so the discovery order is deterministic regardless of filesystem order.
regex_recognition_network_ckpt = 'models/recognition_networks/arch*taskA*ckpt*index'
fn_deep_feature_loss_weights= 'models/recognition_networks/deep_feature_loss_weights.json'

list_recognition_network_ckpt = [
    fn_ckpt.replace('.index', '') for fn_ckpt in sorted(glob.glob(regex_recognition_network_ckpt))
]

# Per-network, per-layer weights of the deep feature loss.
with open(fn_deep_feature_loss_weights, 'r') as f:
    deep_feature_loss_weights = json.load(f)

# Keyed by checkpoint basename up to the first '.', e.g. 'arch1_taskA'.
dict_recognition_network = {}
for fn_ckpt in list_recognition_network_ckpt:
    key = os.path.basename(fn_ckpt).split('.')[0]
    if 'taskA' in key:
        n_classes_dict = {"task_audioset": 517}
    else:
        n_classes_dict = {"task_word": 794}
    # The architecture file is named after the checkpoint prefix before '_task'
    # (e.g. arch1_taskA.ckpt-N -> arch1.json). Without '_task', rfind returns -1
    # and slicing would silently yield a bogus path.
    idx_task = fn_ckpt.rfind('_task')
    if idx_task < 0:
        raise ValueError('Cannot infer architecture file from checkpoint name: {}'.format(fn_ckpt))
    dict_recognition_network[key] = {
        'fn_ckpt': fn_ckpt,
        'fn_arch': fn_ckpt[:idx_task] + '.json',
        'n_classes_dict': n_classes_dict,
    }

dict_recognition_network


{'arch1_taskA': {'fn_arch': 'models/recognition_networks/arch1.json',
  'fn_ckpt': 'models/recognition_networks/arch1_taskA.ckpt-550000',
  'n_classes_dict': {'task_audioset': 517}},
 'arch2_taskA': {'fn_arch': 'models/recognition_networks/arch2.json',
  'fn_ckpt': 'models/recognition_networks/arch2_taskA.ckpt-950000',
  'n_classes_dict': {'task_audioset': 517}},
 'arch3_taskA': {'fn_arch': 'models/recognition_networks/arch3.json',
  'fn_ckpt': 'models/recognition_networks/arch3_taskA.ckpt-980000',
  'n_classes_dict': {'task_audioset': 517}}}

In [24]:
importlib.reload(util_audio_transform)
importlib.reload(util_cochlear_model)
importlib.reload(util_recognition_network)


# Forwarded to both cochlear model instances; empty means build_cochlear_model's defaults.
# NOTE(review): the U-Net cell passes explicit cochlear parameters; confirm the defaults match.
kwargs_build_cochlear_model = {}


def _build_restorable_network(scope_name, tensor_input, list_layer_dict, n_classes_dict):
    """Build one recognition network under `scope_name` plus a Saver for its variables.

    The Saver maps checkpoint variable names (scope prefix and ':0' suffix stripped)
    to the scoped graph variables. This lets the same checkpoint be restored into
    several copies of the network.

    Returns:
        (tensors, saver): the dict of named tensors returned by `build_network`
        and a `tf.train.Saver` over the variables created in the scope.
    """
    with tf.variable_scope(scope_name) as scope:
        _, tensors = util_recognition_network.build_network(
            tensor_input,
            list_layer_dict,
            n_classes_dict=n_classes_dict)
        var_list = {
            v.name.replace(scope.name + '/', '').replace(':0', ''): v
            for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)
        }
        saver = tf.train.Saver(var_list=var_list, max_to_keep=0)
    return tensors, saver


tf.reset_default_graph()
# Two parallel input paths: the loss compares features of waveform_0 against waveform_1.
tensor_waveform_0 = tf.placeholder(tf.float32, [None, 40000])
tensor_waveform_1 = tf.placeholder(tf.float32, [None, 40000])
tensor_deep_feature_loss = tf.zeros([], dtype=tf.float32)

tensor_coch_0, _ = util_cochlear_model.build_cochlear_model(
    tensor_waveform_0,
    **kwargs_build_cochlear_model)
tensor_coch_1, _ = util_cochlear_model.build_cochlear_model(
    tensor_waveform_1,
    **kwargs_build_cochlear_model)


for recognition_network_key in sorted(dict_recognition_network.keys()):
    network_info = dict_recognition_network[recognition_network_key]
    with open(network_info['fn_arch'], 'r') as f:
        list_layer_dict = json.load(f)
    # Look up the head size for *this* network. Indexing with the `key` left over from
    # an earlier cell's loop would give every network the same (last) n_classes_dict.
    n_classes_dict = network_info['n_classes_dict']

    recognition_network_tensors_0, network_info['saver_0'] = _build_restorable_network(
        recognition_network_key + '_0', tensor_coch_0, list_layer_dict, n_classes_dict)
    recognition_network_tensors_1, network_info['saver_1'] = _build_restorable_network(
        recognition_network_key + '_1', tensor_coch_1, list_layer_dict, n_classes_dict)

    for feature_key in sorted(deep_feature_loss_weights[recognition_network_key].keys()):
        feature_weight = deep_feature_loss_weights[recognition_network_key][feature_key]
        feature_0 = recognition_network_tensors_0[feature_key]
        feature_1 = recognition_network_tensors_1[feature_key]

        # Sum |f0 - f1| over every axis except the first, giving one distance per example.
        feature_l1_distance = tf.reduce_sum(
            tf.math.abs(feature_0 - feature_1),
            axis=np.arange(1, len(feature_0.get_shape().as_list())))

        tensor_deep_feature_loss += feature_weight * feature_l1_distance


[make_cos_filters_nx] using filter_spacing=`erb`
[make_cos_filters_nx] using filter_spacing=`erb`


In [28]:
# NOTE(review): these helpers are also defined in an earlier cell; keep both copies
# in sync (TODO: define them once).
def get_rms(y):
    """Root-mean-square of `y` after removing its mean (equivalently, its standard deviation)."""
    return np.sqrt(np.mean(np.square(y - np.mean(y))))

def set_rms(y, rms):
    """Return a zero-mean copy of `y` rescaled so that its RMS equals `rms`.

    Args:
        y: array of samples (any numeric dtype; the result is floating point).
        rms: target root-mean-square value.

    Raises:
        ValueError: if `y` is constant (zero RMS after mean removal). Such a
            signal cannot be rescaled, and dividing by zero would silently
            produce NaNs.
    """
    y = y - np.mean(y)
    current_rms = get_rms(y)
    if current_rms == 0:
        raise ValueError('Cannot set RMS of a constant (silent) signal')
    return rms * y / current_rms


# Example clips for the deep-feature-loss sanity check. Sorted so each clip's row in
# example_waveforms is reproducible across machines.
regex_fn_wav = 'audio/ex*_unprocessed_input.wav'
list_fn_wav = sorted(glob.glob(regex_fn_wav))
assert list_fn_wav, 'No wav files match {}'.format(regex_fn_wav)

example_waveforms = []
for fn_wav in list_fn_wav:
    sr, waveform = scipy.io.wavfile.read(fn_wav)
    # Validate the sample rate before processing: the models assume 20 kHz.
    assert sr == 20e3, 'Expected 20 kHz audio, got {} Hz in {}'.format(sr, fn_wav)
    waveform = set_rms(waveform, 0.02)
    example_waveforms.append(waveform)
# Stacking into one array requires equal-length clips.
assert len(set(len(w) for w in example_waveforms)) == 1, 'All example clips must have the same length'
example_waveforms = np.array(example_waveforms)


# Gaussian noise added to the second loss input. Seeded so the printed loss values
# are reproducible across runs.
NOISE_SEED = 0
NOISE_STD = 1.0
rng_noise = np.random.RandomState(NOISE_SEED)
example_waveforms_noisy = example_waveforms + NOISE_STD * rng_noise.randn(*example_waveforms.shape)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Both copies of each network (scopes '<key>_0' and '<key>_1') restore the same checkpoint.
    for recognition_network_key in sorted(dict_recognition_network.keys()):
        saver_0 = dict_recognition_network[recognition_network_key]['saver_0']
        saver_0.restore(sess, dict_recognition_network[recognition_network_key]['fn_ckpt'])
        saver_1 = dict_recognition_network[recognition_network_key]['saver_1']
        saver_1.restore(sess, dict_recognition_network[recognition_network_key]['fn_ckpt'])

    # Weighted L1 distance between the features of the clean and noisy inputs, per example.
    example_deep_feature_loss = sess.run(
        tensor_deep_feature_loss,
        feed_dict={
            tensor_waveform_0: example_waveforms,
            tensor_waveform_1: example_waveforms_noisy,
        })

example_deep_feature_loss


INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch2_taskA.ckpt-950000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch2_taskA.ckpt-950000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch3_taskA.ckpt-980000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch3_taskA.ckpt-980000


array([49.191277, 48.231163, 48.53233 , 48.74688 , 48.85193 , 48.27629 ,
       48.18045 , 48.591633, 47.91723 , 47.967995], dtype=float32)

In [29]:
# Inspect the per-network configuration, which now also holds the two Savers per network.
dict_recognition_network

{'arch1_taskA': {'fn_arch': 'models/recognition_networks/arch1.json',
  'fn_ckpt': 'models/recognition_networks/arch1_taskA.ckpt-550000',
  'n_classes_dict': {'task_audioset': 517},
  'saver_0': <tensorflow.python.training.saver.Saver at 0x2b7792ce2978>,
  'saver_1': <tensorflow.python.training.saver.Saver at 0x2b723bea9a20>},
 'arch2_taskA': {'fn_arch': 'models/recognition_networks/arch2.json',
  'fn_ckpt': 'models/recognition_networks/arch2_taskA.ckpt-950000',
  'n_classes_dict': {'task_audioset': 517},
  'saver_0': <tensorflow.python.training.saver.Saver at 0x2b7793e8e898>,
  'saver_1': <tensorflow.python.training.saver.Saver at 0x2b7752e6e6d8>},
 'arch3_taskA': {'fn_arch': 'models/recognition_networks/arch3.json',
  'fn_ckpt': 'models/recognition_networks/arch3_taskA.ckpt-980000',
  'n_classes_dict': {'task_audioset': 517},
  'saver_0': <tensorflow.python.training.saver.Saver at 0x2b77c4ec1710>,
  'saver_1': <tensorflow.python.training.saver.Saver at 0x2b7793780be0>}}

In [10]:
class AuditoryModelLoss():
    """Configuration holder for an auditory-model-based loss (work in progress).

    On construction it discovers the recognition-network checkpoints in a
    directory. Each checkpoint is paired with its architecture JSON, its
    output-head definition, and its deep-feature loss weights. The loss methods
    are not implemented yet.
    """

    def __init__(self,
                 dir_recognition_network='models/recognition_networks'):
        """Discover recognition networks and their deep-feature loss weights.

        Args:
            dir_recognition_network: directory containing
                'deep_feature_loss_weights.json', the architecture JSON files,
                and the '*ckpt*.index' checkpoint files.

        Sets:
            self.dict_recognition_network: maps a network key (checkpoint
                basename up to the first '.', e.g. 'arch1_taskA') to a dict with
                'fn_ckpt', 'fn_arch', 'n_classes_dict' and 'weights'.

        Raises:
            ValueError: if a checkpoint name has no '_task' part, so its
                architecture file cannot be inferred.
            KeyError: if the weights file has no entry for a discovered network.
        """
        fn_weights = os.path.join(dir_recognition_network, 'deep_feature_loss_weights.json')
        with open(fn_weights, 'r') as f_weights:
            deep_feature_loss_weights = json.load(f_weights)
        # One .index file per checkpoint; its extension-less prefix is what
        # tf.train.Saver.restore expects. Sorted for a deterministic order.
        list_fn_ckpt = [
            os.path.splitext(fn_index)[0]
            for fn_index in sorted(glob.glob(os.path.join(dir_recognition_network, '*ckpt*.index')))
        ]
        dict_recognition_network = {}
        for fn_ckpt in list_fn_ckpt:
            dirname_ckpt, basename_ckpt = os.path.split(fn_ckpt)
            recognition_network_key = basename_ckpt.split('.')[0]
            # 'taskA' networks use the 'task_audioset' head (517 classes); others 'task_word' (794).
            if 'taskA' in recognition_network_key:
                n_classes_dict = {"task_audioset": 517}
            else:
                n_classes_dict = {"task_word": 794}
            # Architecture file is named after the basename prefix before '_task'
            # (e.g. arch1_taskA.ckpt-N -> arch1.json).
            idx_task = basename_ckpt.rfind('_task')
            if idx_task < 0:
                raise ValueError(
                    'Cannot infer architecture file from checkpoint name: {}'.format(fn_ckpt))
            if recognition_network_key not in deep_feature_loss_weights:
                raise KeyError('No deep feature loss weights for {} in {}'.format(
                    recognition_network_key, fn_weights))
            dict_recognition_network[recognition_network_key] = {
                'fn_ckpt': fn_ckpt,
                'fn_arch': os.path.join(dirname_ckpt, basename_ckpt[:idx_task] + '.json'),
                'n_classes_dict': n_classes_dict,
                'weights': deep_feature_loss_weights[recognition_network_key],
            }
        # Keep the discovered configuration on the instance for the loss methods.
        self.dict_recognition_network = dict_recognition_network
        print(json.dumps(dict_recognition_network, indent=4, sort_keys=True))


    def build_auditory_model(self):
        """Build the cochlear model and recognition networks (not implemented yet)."""
        raise NotImplementedError('build_auditory_model is not implemented yet')


    def load_auditory_model(self):
        """Restore recognition-network weights from checkpoints (not implemented yet)."""
        raise NotImplementedError('load_auditory_model is not implemented yet')


    def waveform_loss(self, y0, y1):
        """Loss computed directly on waveforms y0 and y1 (not implemented yet)."""
        raise NotImplementedError('waveform_loss is not implemented yet')


    def cochlear_model_loss(self, y0, y1):
        """Loss on the cochlear representations of y0 and y1 (not implemented yet)."""
        raise NotImplementedError('cochlear_model_loss is not implemented yet')


    def deep_feature_loss(self, y0, y1):
        """Weighted recognition-network feature loss between y0 and y1 (not implemented yet)."""
        raise NotImplementedError('deep_feature_loss is not implemented yet')

# Instantiate with the default checkpoint directory; the constructor prints the
# discovered recognition-network configuration as JSON.
auditory_model_loss = AuditoryModelLoss()


{
    "arch1_taskA": {
        "fn_arch": "models/recognition_networks/arch1.json",
        "fn_ckpt": "models/recognition_networks/arch1_taskA.ckpt-550000",
        "n_classes_dict": {
            "task_audioset": 517
        },
        "weights": {
            "batch_norm_0": 7.721951988869115e-07,
            "batch_norm_1": 2.180155442434395e-06,
            "batch_norm_2": 4.35400902509891e-06,
            "batch_norm_3": 9.663577649529144e-06,
            "batch_norm_4": 7.037957926197431e-06,
            "batch_norm_5": 8.04823966794683e-06,
            "batch_norm_6": 9.601534802363368e-05
        }
    },
    "arch1_taskR": {
        "fn_arch": "models/recognition_networks/arch1.json",
        "fn_ckpt": "models/recognition_networks/arch1_taskR.ckpt-0",
        "n_classes_dict": {
            "task_word": 794
        },
        "weights": {
            "batch_norm_0": 4.5313309082274834e-06,
            "batch_norm_1": 2.2921854969402775e-06,
            "batch_norm_2": 9.4444