In [13]:
import os
import sys
import glob
import numpy as np
import tensorflow as tf
import importlib
import IPython.display as ipd

import util_audio_transform
import util_auditory_model_loss
import util_cochlear_model
import util_recognition_network

importlib.reload(util_auditory_model_loss)


<module 'util_auditory_model_loss' from '/rdma/vast-rdma/vast/mcdermott/msaddler/auditory-model-denoising/util_auditory_model_loss.py'>

In [14]:
filenames = glob.glob('data/toy_dataset*.tfrecords')
batch_size = 8
feature_description = {
    'background/signal': tf.io.FixedLenFeature([], tf.string, default_value=None),
    'foreground/signal': tf.io.FixedLenFeature([], tf.string, default_value=None),
}
bytes_description = {
    'background/signal': {'dtype': tf.float32, 'shape': [40000]}, 
    'foreground/signal': {'dtype': tf.float32, 'shape': [40000]},
}

def parse_tfrecord(tfrecord):
    tfrecord = tf.parse_single_example(tfrecord, features=feature_description)
    for key in bytes_description.keys():
        tfrecord[key] = tf.decode_raw(tfrecord[key], bytes_description[key]['dtype'])
        tfrecord[key] = tf.reshape(tfrecord[key], bytes_description[key]['shape'])
    return tfrecord


tf.reset_default_graph()
tf.random.set_random_seed(0)

dataset = tf.data.TFRecordDataset(filenames=filenames, compression_type='GZIP')
dataset = dataset.map(parse_tfrecord)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=4)
dataset = dataset.shuffle(buffer_size=32)
dataset = dataset.repeat(count=None)

iterator = dataset.make_one_shot_iterator()
input_tensor_dict = iterator.get_next()


In [15]:
tensor_waveform_noisy = input_tensor_dict['foreground/signal']
tensor_waveform_clean = input_tensor_dict['foreground/signal']
with tf.variable_scope('audio_transform'):
    tensor_waveform_denoised = util_audio_transform.build_unet(tensor_waveform_noisy)

list_recognition_networks = [
#     'arch1_taskA',
#     'arch2_taskA',
#     'arch3_taskA',
]
auditory_model = util_auditory_model_loss.AuditoryModelLoss(
    list_recognition_networks=list_recognition_networks,
    tensor_wave0=tensor_waveform_clean,
    tensor_wave1=tensor_waveform_denoised)

transform_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='audio_transform')
# transform_saver = tf.train.Saver(var_list=transform_var_list, max_to_keep=0)

loss = auditory_model.loss_waveform
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
train_op = optimizer.minimize(
    loss=loss,
    global_step=None,
    var_list=transform_var_list)


0 recognition networks included for deep feature loss:
Building waveform loss
Building cochlear model loss
[make_cos_filters_nx] using filter_spacing=`erb`
[make_cos_filters_nx] using filter_spacing=`erb`


In [None]:
with tf.Session() as sess:
#     example = sess.run(iterator.get_next())
#     for k in example.keys():
#         print(k, example[k].dtype, example[k].shape)
    sess.run(tf.global_variables_initializer())
    auditory_model.load_auditory_model_vars(sess)
    
    for itr0 in range(1000):
        _, batch_loss = sess.run([train_op, loss])
        if itr0 % 50 == 0:
            print(itr0, batch_loss.sum())


0 4563.756
50 1410.4463
100 817.84216
150 814.09973
200 564.2228
250 618.8761
