In [12]:
import os
import sys
import glob
import numpy as np
import tensorflow as tf
import importlib
import IPython.display as ipd

import util_audio_preprocess
import util_audio_transform
import util_auditory_model_loss
import util_cochlear_model
import util_recognition_network


In [13]:
# glob.glob returns matches in arbitrary (filesystem-dependent) order. Sort the
# shard list so the order in which records are read is the same on every run
# and every machine, which the fixed random seed set below relies on.
filenames = sorted(glob.glob('data/toy_dataset*.tfrecords'))
# Number of examples grouped into each training batch.
batch_size = 8
# Every record stores its two waveforms as serialized byte strings under the
# same pair of keys, so both lookup tables are generated from one key tuple.
# `feature_description` tells the example parser to read each key as a scalar
# string; `bytes_description` gives the dtype and shape used to decode and
# reshape those raw bytes in `parse_tfrecord`.
feature_description = {
    key: tf.io.FixedLenFeature([], tf.string, default_value=None)
    for key in ('background/signal', 'foreground/signal')
}
bytes_description = {
    key: {'dtype': tf.float32, 'shape': [40000]}
    for key in ('background/signal', 'foreground/signal')
}


def parse_tfrecord(tfrecord):
    """
    Parse one serialized example into a dictionary of waveform tensors.

    The example is parsed with `feature_description`, then every key listed
    in `bytes_description` is decoded from raw bytes to the configured dtype
    and reshaped to the configured shape.

    Args
    ----
    tfrecord: a single serialized example, as yielded by the TFRecord dataset

    Returns
    -------
    example (dict): parsed features, with the byte-string entries replaced
        by decoded, reshaped tensors
    """
    example = tf.parse_single_example(tfrecord, features=feature_description)
    for name, spec in bytes_description.items():
        decoded = tf.decode_raw(example[name], spec['dtype'])
        example[name] = tf.reshape(decoded, spec['shape'])
    return example


def preprocess_audio_batch(batch):
    """
    Mix foreground (speech) and background (noise) waveforms at random
    signal-to-noise ratios.

    One SNR value per example (shape [batch, 1]) is drawn uniformly between
    -20 and +10 dB and handed to `util_audio_preprocess.tf_set_snr` together
    with the two waveform batches.

    Args
    ----
    batch (dict): must contain 'foreground/signal' and 'background/signal'

    Returns
    -------
    dict with the keys:
        'snr': the sampled signal-to-noise ratios
        'waveform_noisy': the speech-in-noise mixture
        'waveform_clean': the clean speech signal returned by `tf_set_snr`
    """
    speech = batch['foreground/signal']
    noise = batch['background/signal']
    num_examples = tf.shape(speech)[0]
    snr = tf.random.uniform(
        [num_examples, 1], minval=-20.0, maxval=10.0, dtype=speech.dtype)
    # The third return value (the scaled noise) is not needed downstream.
    mixture, clean_speech, _ = util_audio_preprocess.tf_set_snr(speech, noise, snr)
    return {
        'snr': snr,
        'waveform_noisy': mixture,
        'waveform_clean': clean_speech,
    }


# Start from an empty default graph so re-running this cell does not pile up
# duplicate ops, and fix the graph-level seed for the random ops below.
tf.reset_default_graph()
tf.random.set_random_seed(0)

# Input pipeline: read GZIP-compressed TFRecords -> decode waveforms ->
# batch -> mix speech and noise at random SNRs.
dataset = tf.data.TFRecordDataset(filenames=filenames, compression_type='GZIP')
dataset = dataset.map(parse_tfrecord)
dataset = dataset.batch(batch_size)
dataset = dataset.map(preprocess_audio_batch)
# NOTE: shuffle is applied after batching, so it reorders whole batches
# (up to 32 buffered at a time), not individual examples.
dataset = dataset.shuffle(buffer_size=32)
# Repeat indefinitely; the training loop decides how many steps to run.
dataset = dataset.repeat(count=None)
# Prefetch must be the final transformation so that producing the next
# elements (including the shuffle/repeat stages) overlaps with the training
# step that consumes the current one.
dataset = dataset.prefetch(buffer_size=4)

iterator = dataset.make_one_shot_iterator()
# Dictionary of tensors ('snr', 'waveform_noisy', 'waveform_clean') that
# yields a new batch every time it is evaluated in a session.
input_tensor_dict = iterator.get_next()


In [14]:
# Build the training graph: the noisy waveform is passed through the audio
# transform (U-Net) and the result is compared against the clean waveform by
# the auditory model loss.
tensor_waveform_noisy = input_tensor_dict['waveform_noisy']
tensor_waveform_clean = input_tensor_dict['waveform_clean']
tensor_waveform_denoised = util_audio_transform.build_unet(tensor_waveform_noisy)

# Recognition networks to include in the deep feature loss; the commented-out
# entries are additional networks that can be enabled.
list_recognition_networks = [
    'arch1_taskA',
#     'arch2_taskA',
#     'arch3_taskA',
]
# tensor_wave0 is the reference (clean) waveform and tensor_wave1 is the
# transform output being evaluated against it.
auditory_model = util_auditory_model_loss.AuditoryModelLoss(
    list_recognition_networks=list_recognition_networks,
    tensor_wave0=tensor_waveform_clean,
    tensor_wave1=tensor_waveform_denoised)

# Collect only the variables under the 'separator' scope -- presumably the
# scope build_unet creates its variables in (the listing printed by this cell
# shows 'separator/conv1d/...'); confirm in util_audio_transform.
transform_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='separator')
# Saver restricted to the transform variables; per the tf.train.Saver
# documentation, max_to_keep=0 means no old checkpoints are deleted.
transform_saver = tf.train.Saver(var_list=transform_var_list, max_to_keep=0)

# Choose which loss exposed by the auditory model to optimize (exactly one of
# the following three lines should be active).
# loss = auditory_model.loss_waveform
loss = auditory_model.loss_cochlear_model
# loss = auditory_model.loss_deep_features
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
# var_list limits the update to the transform variables, so any other
# variables in the graph are left unchanged by training.
train_op = optimizer.minimize(
    loss=loss,
    global_step=None,
    var_list=transform_var_list)

# Bare expression: display the list of trainable transform variables.
transform_var_list


1 recognition networks included for deep feature loss:
|__ arch1_taskA: models/recognition_networks/arch1_taskA.ckpt-550000
Building waveform loss
Building cochlear model loss
[make_cos_filters_nx] using filter_spacing=`erb`
[make_cos_filters_nx] using filter_spacing=`erb`
Building deep feature loss (recognition network: arch1_taskA)


[<tf.Variable 'separator/conv1d/kernel:0' shape=(15, 1, 24) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d/bias:0' shape=(24,) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_1/kernel:0' shape=(15, 24, 48) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_1/bias:0' shape=(48,) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_2/kernel:0' shape=(15, 48, 72) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_2/bias:0' shape=(72,) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_3/kernel:0' shape=(15, 72, 96) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_3/bias:0' shape=(96,) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_4/kernel:0' shape=(15, 96, 120) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_4/bias:0' shape=(120,) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_5/kernel:0' shape=(15, 120, 144) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_5/bias:0' shape=(144,) dtype=float32_ref>,
 <tf.Variable 'separator/conv1d_6/kernel:0' shape=(15, 144

In [15]:
# Train the audio transform for 101 steps (0..100), report the mean loss every
# fifth step, and write a checkpoint of the transform variables at the end.
with tf.Session() as sess:
    # Initialize every variable first, then let the auditory model load its
    # own variables into the session.
    sess.run(tf.global_variables_initializer())
    auditory_model.load_auditory_model_vars(sess)

    for step in range(0, 101):
        # Run one optimizer update and fetch the loss from the same call.
        _, step_loss = sess.run(fetches=[train_op, loss])
        if not step % 5:
            print("Loss after training step {:06d} = {:.02f}".format(step, step_loss.mean()))

    # `step` still holds the last loop index, which is passed as the
    # checkpoint's global_step; only the variables are written (no meta graph).
    transform_saver.save(
        sess, save_path='new_model.ckpt', global_step=step, write_meta_graph=False)


Loading `arch1_taskA` variables from models/recognition_networks/arch1_taskA.ckpt-550000
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000
Loss after training step 000000 = 552.91
Loss after training step 000005 = 652.04
Loss after training step 000010 = 456.93
Loss after training step 000015 = 491.50
Loss after training step 000020 = 488.89
Loss after training step 000025 = 457.64
Loss after training step 000030 = 455.87
Loss after training step 000035 = 448.33
Loss after training step 000040 = 447.42
Loss after training step 000045 = 460.21
Loss after training step 000050 = 436.42
Loss after training step 000055 = 493.47
Loss after training step 000060 = 457.38
Loss after training step 000065 = 441.01
Loss after training step 000070 = 420.04
Loss after training s