In [1]:
import os
import sys
import importlib
import json
import glob
import numpy as np
import tensorflow as tf
import scipy.io.wavfile
import IPython.display as ipd

import util_recognition_network
import util_cochlear_model
import util_audio_transform


In [2]:
ARCH = 1
TASK = 'A'

# Match only the `.index` file of each checkpoint. The bare `ckpt*` pattern also matches
# `.meta` / `.data-*` files, and glob returns paths in arbitrary order, so taking `[-1]`
# could pick a file whose suffix is not `.index` and therefore is not stripped.
regex_fn_ckpt = 'models/recognition_networks/arch{}_task{}.ckpt*.index'.format(ARCH, TASK)
list_fn_ckpt = [fn_index[:-len('.index')] for fn_index in glob.glob(regex_fn_ckpt)]
if not list_fn_ckpt:
    raise FileNotFoundError('No checkpoint matches {}'.format(regex_fn_ckpt))


def _ckpt_step(fn_ckpt_prefix):
    """Return the global step from a `...ckpt-<step>` prefix, or -1 if it has none."""
    step = fn_ckpt_prefix.rpartition('-')[2]
    return int(step) if step.isdigit() else -1


# Use the checkpoint with the largest global step (numeric, not lexicographic, order).
fn_ckpt = max(list_fn_ckpt, key=_ckpt_step)
fn_arch = 'models/recognition_networks/arch{}.json'.format(ARCH)
# Layer-by-layer architecture description consumed by build_network.
with open(fn_arch, 'r') as f_arch:
    list_layer_dict = json.load(f_arch)

fn_ckpt, fn_arch


('models/recognition_networks/arch1_taskA.ckpt-550000',
 'models/recognition_networks/arch1.json')

In [3]:
importlib.reload(util_recognition_network)
tf.reset_default_graph()

# The output head size depends on which task the checkpoint was trained for.
n_classes_dict = (
    {"/stimuli/labels_binary_via_int": 517}
    if 'taskA' in fn_ckpt
    else {"/stimuli/word_int": 794}
)

tensor_input = tf.placeholder(tf.float32, [None, 40, 20000, 1])
tensor_output, tensors = util_recognition_network.build_network(
    tensor_input, list_layer_dict, n_classes_dict=n_classes_dict)

var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=None)
var_dict = dict((v.name, v) for v in var_list)
saver = tf.train.Saver(var_list=var_list, max_to_keep=0)

# Sanity check: fetch every variable at its initial value and again after restoring the
# checkpoint. A mean |difference| of ~0 could indicate a variable the checkpoint did
# not actually overwrite.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    var_dict_init = sess.run(var_dict)
    saver.restore(sess, fn_ckpt)
    var_dict_load = sess.run(var_dict)

for v in var_list:
    mean_abs_change = np.mean(np.abs(var_dict_init[v.name] - var_dict_load[v.name]))
    print(v.name, v.shape, mean_abs_change)


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000
batch_norm_data_layer/gamma:0 (1,) 0.3478464
batch_norm_data_layer/beta:0 (1,) 0.28489918
batch_norm_data_layer/moving_mean:0 (1,) 0.07441781
batch_norm_data_layer/moving_variance:0 (1,) 0.9963481
conv_0/kernel:0 (2, 42, 1, 32) 0.055686645
conv_0/bias:0 (32,) 0.077409275
batch_norm_0/gamma:0 (32,) 0.07441161
batch_norm_0/beta:0 (32,) 0.17043708
batch_norm_0/moving_mean:0 (32,) 8.825504
batch_norm_0/moving_variance:0 (32,) 87.00698
conv_1/kernel:0 (2, 18, 32, 64) 0.032826066
conv_1/bias:0 (64,) 0.07326804
batch_norm_1/gamma:0 (64,) 0.1428659
batch_norm_1/beta:0 (64,) 0.11341882
batch_norm_1/moving_mean:0 (64,) 28.610592
batch

In [53]:
def get_rms(y):
    """Return the root-mean-square of the NumPy array `y` after removing its mean.

    Integer input (e.g. int16 PCM) is promoted to float by the mean subtraction.
    """
    y_centered = y - np.mean(y)
    return np.sqrt(np.mean(np.square(y_centered)))

def set_rms(y, rms):
    """Return a zero-mean copy of `y` rescaled so that its RMS equals `rms`.

    Args:
        y: NumPy array holding the signal.
        rms: target root-mean-square value.

    Raises:
        ValueError: if `y` has zero or non-finite RMS after mean removal (e.g. a silent
            or NaN-containing signal). Such a signal cannot be rescaled, and dividing
            by its RMS would silently produce NaNs.
    """
    y = y - np.mean(y)
    rms_in = get_rms(y)
    if rms_in == 0 or not np.isfinite(rms_in):
        raise ValueError('Cannot rescale a signal with RMS {}'.format(rms_in))
    return rms * y / rms_in


# Target RMS applied to every example before it is fed to the models.
TARGET_RMS = 0.02

regex_fn_wav = 'audio/ex*_unprocessed_input.wav'
# Sorted so the example order is reproducible: later cells pair list_y[i] with
# list_fn_wav[i], and glob's order is filesystem-dependent.
list_fn_wav = sorted(glob.glob(regex_fn_wav))
if not list_fn_wav:
    raise FileNotFoundError('No files match {}'.format(regex_fn_wav))

list_y = []
for fn_wav in list_fn_wav:
    sr, y = scipy.io.wavfile.read(fn_wav)
    # Validate the file before transforming it.
    assert sr == 20e3, '{}: expected 20000 Hz audio, got {} Hz'.format(fn_wav, sr)
    if y.ndim != 1:
        raise ValueError('{}: expected mono audio, got shape {}'.format(fn_wav, y.shape))
    y = set_rms(y, TARGET_RMS)
    list_y.append(y)

# Stacking into one (n_examples, n_samples) array requires equal lengths; check this
# explicitly instead of relying on np.array's behavior for ragged input.
dict_len = {fn: len(w) for fn, w in zip(list_fn_wav, list_y)}
if len(set(dict_len.values())) != 1:
    raise ValueError('Waveforms have different lengths: {}'.format(dict_len))
list_y = np.array(list_y)
list_y.shape


(10, 40000)

In [7]:
importlib.reload(util_cochlear_model)
importlib.reload(util_audio_transform)

tf.reset_default_graph()

# Kept under its own name so this cell does not overwrite `fn_ckpt`, the
# recognition-network checkpoint selected earlier in the notebook.
fn_ckpt_unet = 'models/audio_transforms/unet_cochlear_reverse/model.ckpt-600000'

tensor_waveform = tf.placeholder(tf.float32, [None, 40000])

tensor_cochlear_representation, coch_container = util_cochlear_model.build_cochlear_model(
    tensor_waveform,
    signal_rate=20000,
    filter_type='half-cosine',
    filter_spacing='erb',
    HIGH_LIM=8000,
    LOW_LIM=20,
    N=40,
    SAMPLE_FACTOR=1,
    bandwidth_scale_factor=1.0,
    compression='stable_point3',
    include_highpass=False,
    include_lowpass=False,
    linear_max=1.0,
    rFFT=True,
    rectify_and_lowpass_subbands=True,
    return_subbands_only=True)

tensor_waveform_unet = util_audio_transform.build_unet(
    tensor_waveform,
    signal_rate=20000,
    UNET_PARAMS={})

var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=None)
saver = tf.train.Saver(var_list=var_list, max_to_keep=0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, fn_ckpt_unet)
    # Both tensors consume the same feed, so evaluate them in a single run.
    list_y_unet, list_coch_y = sess.run(
        [tensor_waveform_unet, tensor_cochlear_representation],
        feed_dict={tensor_waveform: list_y})
    # Cochlear representation of the U-Net output, for comparison with list_coch_y.
    list_coch_y_unet = sess.run(tensor_cochlear_representation, feed_dict={tensor_waveform: list_y_unet})


[make_cos_filters_nx] using filter_spacing=`erb`
INFO:tensorflow:Restoring parameters from models/audio_transforms/unet_cochlear_reverse/model.ckpt-600000


In [8]:
# For each example: print the file name and the mean cochlear representation before /
# after the U-Net transform, then play the original and the transformed audio.
for itr0, y_unet in enumerate(list_y_unet):
    y = list_y[itr0]
    print(list_fn_wav[itr0], np.mean(list_coch_y[itr0]), np.mean(list_coch_y_unet[itr0]))
    for waveform in (y, y_unet):
        ipd.display(ipd.Audio(waveform, rate=sr))


audio/ex01_unprocessed_input.wav 0.00070896186 0.0003957955


audio/ex02_unprocessed_input.wav 0.0011009874 0.00024696643


audio/ex03_unprocessed_input.wav 0.0009850197 0.00021446396


audio/ex04_unprocessed_input.wav 0.00089565513 0.0003951589


audio/ex05_unprocessed_input.wav 0.00080664817 0.00017993133


audio/ex06_unprocessed_input.wav 0.0007807299 0.00012441046


audio/ex07_unprocessed_input.wav 0.0009090955 0.00051825977


audio/ex08_unprocessed_input.wav 0.0010563683 0.00021237253


audio/ex09_unprocessed_input.wav 0.0010225364 0.00021431383


audio/ex10_unprocessed_input.wav 0.00086743536 0.00015796826


In [63]:
regex_recognition_network_ckpt = 'models/recognition_networks/*task*ckpt*.index'
fn_deep_feature_loss_weights = 'models/recognition_networks/deep_feature_loss_weights.json'

# One checkpoint prefix (path without `.index`) per recognition network. Sorted so the
# order is reproducible; glob order is filesystem-dependent.
list_recognition_network_ckpt = sorted(
    fn_index[:-len('.index')] for fn_index in glob.glob(regex_recognition_network_ckpt)
)

# Per-network, per-layer weights applied to the deep features below.
with open(fn_deep_feature_loss_weights, 'r') as f:
    deep_feature_loss_weights = json.load(f)

dict_recognition_network = {}
# A dedicated loop name avoids overwriting the notebook-level `fn_ckpt`.
for fn_ckpt_recog in list_recognition_network_ckpt:
    # e.g. '.../arch1_taskA.ckpt-550000' -> 'arch1_taskA'
    key = os.path.basename(fn_ckpt_recog).split('.')[0]
    if key in dict_recognition_network:
        # Otherwise one checkpoint would silently replace the other.
        raise ValueError('Multiple checkpoints found for {}: {} and {}'.format(
            key, dict_recognition_network[key]['fn_ckpt'], fn_ckpt_recog))
    # NOTE(review): these n_classes_dict keys differ from the ones used in the single-
    # network cell; both restores succeed, so presumably the keys do not affect
    # variable names — confirm.
    if 'taskA' in key:
        n_classes_dict = {"task_audioset": 517}
    else:
        n_classes_dict = {"task_word": 794}
    dict_recognition_network[key] = {
        'fn_ckpt': fn_ckpt_recog,
        # '.../arch1_taskA.ckpt-550000' -> '.../arch1.json'
        'fn_arch': fn_ckpt_recog[:fn_ckpt_recog.rfind('_task')] + '.json',
        'n_classes_dict': n_classes_dict,
    }


tf.reset_default_graph()

tensor_waveform = tf.placeholder(tf.float32, [None, 40000])
# NOTE(review): the cochlear model is built with its default parameters here; confirm
# they match the front end the recognition networks were trained with.
tensor_coch_rep, coch_container = util_cochlear_model.build_cochlear_model(tensor_waveform)
deep_feature_tensors = {}
for key in sorted(dict_recognition_network.keys()):
    network = dict_recognition_network[key]
    with open(network['fn_arch'], 'r') as f:
        list_layer_dict = json.load(f)
    # Each network lives in its own variable scope so all of them share one graph.
    with tf.variable_scope(key):
        _, tensors = util_recognition_network.build_network(
            tensor_coch_rep,
            list_layer_dict,
            n_classes_dict=network['n_classes_dict'])
        # Keep only layers that have a loss weight, scaled by that weight.
        deep_feature_tensors[key] = {}
        for tensor_name in sorted(tensors.keys()):
            if tensor_name in deep_feature_loss_weights[key]:
                deep_feature_tensors[key][tensor_name] = tf.math.multiply(
                    tensors[tensor_name],
                    deep_feature_loss_weights[key][tensor_name])
    # `scope` is a regex prefix match; the trailing '/' keeps it from also matching a
    # different scope whose name merely starts with `key`.
    scope_prefix = key + '/'
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_prefix)
    # The checkpoint names are unscoped (the mapping strips the scope prefix), so map
    # each stripped op name (`v.op.name` has no ':0' suffix) to the scoped variable.
    network['saver'] = tf.train.Saver(
        var_list={v.op.name[len(scope_prefix):]: v for v in var_list},
        max_to_keep=0)

    print(key, len(var_list))

deep_feature_tensors


[make_cos_filters_nx] using filter_spacing=`erb`
arch1_taskA 54
arch1_taskR 54
arch1_taskW 54
arch2_taskA 56
arch2_taskR 56
arch2_taskW 56
arch3_taskA 50
arch3_taskR 50
arch3_taskW 50


{'arch1_taskA': {'batch_norm_0': <tf.Tensor 'arch1_taskA/Mul_5:0' shape=(?, 20, 4986, 32) dtype=float32>,
  'batch_norm_1': <tf.Tensor 'arch1_taskA/Mul_6:0' shape=(?, 10, 1239, 64) dtype=float32>,
  'batch_norm_2': <tf.Tensor 'arch1_taskA/Mul_7:0' shape=(?, 10, 305, 128) dtype=float32>,
  'batch_norm_3': <tf.Tensor 'arch1_taskA/Mul_8:0' shape=(?, 10, 72, 256) dtype=float32>,
  'batch_norm_4': <tf.Tensor 'arch1_taskA/Mul_9:0' shape=(?, 10, 65, 512) dtype=float32>,
  'batch_norm_5': <tf.Tensor 'arch1_taskA/Mul_10:0' shape=(?, 10, 60, 512) dtype=float32>,
  'batch_norm_6': <tf.Tensor 'arch1_taskA/Mul_11:0' shape=(?, 5, 10, 512) dtype=float32>},
 'arch1_taskR': {'batch_norm_0': <tf.Tensor 'arch1_taskR/Mul_5:0' shape=(?, 20, 4986, 32) dtype=float32>,
  'batch_norm_1': <tf.Tensor 'arch1_taskR/Mul_6:0' shape=(?, 10, 1239, 64) dtype=float32>,
  'batch_norm_2': <tf.Tensor 'arch1_taskR/Mul_7:0' shape=(?, 10, 305, 128) dtype=float32>,
  'batch_norm_3': <tf.Tensor 'arch1_taskR/Mul_8:0' shape=(?, 1

In [65]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Restore every recognition network through its own per-network saver.
    for key, network in sorted(dict_recognition_network.items()):
        network['saver'].restore(sess, network['fn_ckpt'])
    # Deep features for the first example only; the slice keeps the batch axis.
    list_y_deep_feature_tensors = sess.run(deep_feature_tensors, feed_dict={tensor_waveform: list_y[0:1]})

for k0, features in sorted(list_y_deep_feature_tensors.items()):
    for k1, value in sorted(features.items()):
        print(k0, k1, value.shape)


INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskR.ckpt-0
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskW.ckpt-360000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch2_taskA.ckpt-950000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch2_taskR.ckpt-0
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch2_taskW.ckpt-610000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch3_taskA.ckpt-980000
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch3_taskR.ckpt-0
INFO:tensorflow:Restoring parameters from models/recognition_networks/arch3_taskW.ckpt-580000
arch1_taskA batch_norm_0 (1, 20, 4986, 32)
arch1_taskA batch_norm_1 (1, 10, 1239, 64)
arch1_taskA batch_norm_2 (1, 10, 305, 128)
arch1_taskA batch_norm_3 (1, 10, 72, 256