In [25]:
import os
import sys
import glob
import numpy as np
import tensorflow as tf
import importlib
import IPython.display as ipd

import util_audio_transform
import util_auditory_model_loss
import util_cochlear_model
import util_recognition_network

importlib.reload(util_auditory_model_loss)


<module 'util_auditory_model_loss' from '/rdma/vast-rdma/vast/mcdermott/msaddler/auditory-model-denoising/util_auditory_model_loss.py'>

In [26]:
filenames = glob.glob('data/toy_dataset*.tfrecords')
batch_size = 8
feature_description = {
    'background/signal': tf.io.FixedLenFeature([], tf.string, default_value=None),
    'foreground/signal': tf.io.FixedLenFeature([], tf.string, default_value=None),
}
bytes_description = {
    'background/signal': {'dtype': tf.float32, 'shape': [40000]}, 
    'foreground/signal': {'dtype': tf.float32, 'shape': [40000]},
}

def parse_tfrecord(tfrecord):
    tfrecord = tf.parse_single_example(tfrecord, features=feature_description)
    for key in bytes_description.keys():
        tfrecord[key] = tf.decode_raw(tfrecord[key], bytes_description[key]['dtype'])
        tfrecord[key] = tf.reshape(tfrecord[key], bytes_description[key]['shape'])
    return tfrecord


tf.reset_default_graph()
tf.random.set_random_seed(0)

dataset = tf.data.TFRecordDataset(filenames=filenames, compression_type='GZIP')
dataset = dataset.map(parse_tfrecord)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=4)
dataset = dataset.shuffle(buffer_size=32)
dataset = dataset.repeat(count=None)

iterator = dataset.make_one_shot_iterator()
input_tensor_dict = iterator.get_next()


In [30]:
tensor_waveform_noisy = input_tensor_dict['foreground/signal']
tensor_waveform_clean = input_tensor_dict['foreground/signal']

with tf.variable_scope('audio_transform'):
    tensor_waveform_denoised = util_audio_transform.build_unet(
        tensor_waveform_noisy,
        signal_rate=20000,
        UNET_PARAMS={})

list_recognition_networks = [
    'arch1_taskA',
#     'arch1_taskR',
#     'arch1_taskW',
    'arch2_taskA',
#     'arch2_taskR',
#     'arch2_taskW',
    'arch3_taskA',
#     'arch3_taskR',
#     'arch3_taskW',
]
auditory_model_loss = util_auditory_model_loss.AuditoryModelLoss(
    list_recognition_networks=list_recognition_networks,
    tensor_wave0=tensor_waveform_clean,
    tensor_wave1=tensor_waveform_denoised)

auditory_model_loss.loss_waveform, auditory_model_loss.loss_cochlear_model, auditory_model_loss.loss_deep_features

# var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='audio_transform')
# saver = tf.train.Saver(var_list=var_list, max_to_keep=0)
# var_list


3 recognition networks included for deep feature loss:
|__ arch1_taskA: models/recognition_networks/arch1_taskA.ckpt-550000
|__ arch2_taskA: models/recognition_networks/arch2_taskA.ckpt-950000
|__ arch3_taskA: models/recognition_networks/arch3_taskA.ckpt-980000
Building waveform loss
Building cochlear model loss
[make_cos_filters_nx] using filter_spacing=`erb`
[make_cos_filters_nx] using filter_spacing=`erb`
Building deep feature loss (recognition network: arch1_taskA)
Building deep feature loss (recognition network: arch2_taskA)
Building deep feature loss (recognition network: arch3_taskA)


(<tf.Tensor 'Sum_25:0' shape=(?,) dtype=float32>,
 <tf.Tensor 'Sum_26:0' shape=(?,) dtype=float32>,
 <tf.Tensor 'add_51:0' shape=(?,) dtype=float32>)

In [32]:
with tf.Session() as sess:
#     example = sess.run(iterator.get_next())
#     for k in example.keys():
#         print(k, example[k].dtype, example[k].shape)
    sess.run(tf.global_variables_initializer())
    print(sess.run(auditory_model_loss.loss_waveform))


[0. 0. 0. 0. 0. 0. 0. 0.]
