In [None]:
from __future__ import absolute_import, division, print_function

import numpy as np
import tensorflow as tf
import utils
from scipy.signal import spectrogram, istft
from librispeech_mixer import LibriSpeechMixer
from keras.layers import Input, Dense, Conv1D, MaxPooling2D, Conv2DTranspose, UpSampling2D, Reshape, Flatten, Dropout, BatchNormalization
from tensorflow.contrib.layers import flatten
import IPython
from os import listdir
from keras import backend as K


def mask_to_outputs(mixed_real, mixed_imag, scaled_mask_real, scaled_mask_imag, C, K):
    """Apply a compressed complex mask to the mixture and return both sources.

    Each scaled mask component s is mapped back to m = -log((K - s) / (K + s)) / C,
    the inverse of s = K * (1 - exp(-C*m)) / (1 + exp(-C*m)). The small constants
    keep the division and the logarithm finite when s reaches -K or +K.

    Returns:
        (sep1, sep2): sep1 = mask * mixture (complex product) and
        sep2 = mixture - sep1.
    """
    def _uncompress(scaled):
        return -tf.log((K - scaled) / (K + scaled + 1e-10) + 1e-30) / C

    mixture = tf.complex(mixed_real, mixed_imag)
    mask = tf.complex(_uncompress(scaled_mask_real), _uncompress(scaled_mask_imag))
    first_source = tf.multiply(mask, mixture)
    return first_source, mixture - first_source

# Start from an empty default graph, then put Keras in training mode
# (learning phase 1) before any layer below is built.
tf.reset_default_graph()
K.set_learning_phase(1) #set learning phase

#Create the LibriSpeech mixer (supplies spec_length, nb_freq, C, K and segment counts used below).
# NOTE(review): dataset_built=True presumably means the TFRecords already exist on disk — confirm in LibriSpeechMixer.
mixer = LibriSpeechMixer(dataset_built=True)

# Decode one serialized TFRecord example into the (network input, target) pair.
def _parse_function(example_proto):
    """Parse one example into ``(mixture, mask)`` float32 tensors.

    Every stored feature is a (mixer.spec_length, mixer.nb_freq) array. The real
    and imaginary parts are concatenated along axis 1, so both returned tensors
    have shape (spec_length, 2 * nb_freq): real half first, imaginary half second.
    """
    feature_shape = (mixer.spec_length, mixer.nb_freq)
    spec = {name: tf.FixedLenFeature(feature_shape, tf.float32)
            for name in ('mixed_real', 'mixed_imag', 'mask_real', 'mask_imag')}
    features = tf.parse_single_example(example_proto, spec)

    def _real_then_imag(real_key, imag_key):
        return tf.concat([features[real_key], features[imag_key]], axis=1)

    mixture = _real_then_imag('mixed_real', 'mixed_imag')
    mask = _real_then_imag('mask_real', 'mask_imag')
    return mixture, mask

# Input pipeline. `filenames` is a placeholder, so the same initializable
# iterator can be pointed at either the training or the validation TFRecords.
batch_size = 64

filenames = tf.placeholder(tf.string, shape=[None])
dataset = (tf.data.TFRecordDataset(filenames)
           .map(_parse_function)
           .shuffle(buffer_size=2500)
           .batch(batch_size)
           .repeat())
iterator = dataset.make_initializable_iterator()
# Network input and target mask; every sess.run that evaluates them pulls a new batch.
x_pl, y_pl = iterator.get_next()

training_filenames = ["/mnt/train/" + name for name in listdir("/mnt/train/")]
validation_filenames = ["/mnt/dev/" + name for name in listdir("/mnt/dev/")]


# Network hyper-parameters. height, width and nchannels are not read by the
# layers below (nor elsewhere in the visible notebook); kept for compatibility.
height, width, nchannels = mixer.nb_freq, mixer.spec_length, 1
padding = 'same'

# x_pl holds the real and imaginary halves side by side (2 * nb_freq channels);
# every layer keeps that width.
filters = mixer.nb_freq*2
kernel_size = 3

print('Trace of the tensors shape as it is propagated through the network.')
print('Layer name \t Output size')
print('----------------------------')

with tf.variable_scope('convLayer1'):
    # Four 'same'-padded ReLU 1-D convolutions over axis 1 of x_pl.
    conv1 = Conv1D(round(filters), kernel_size, padding=padding, activation='relu')
    print('x_pl \t\t', x_pl.get_shape())
    x = conv1(x_pl)
    print('conv1 \t\t', x.get_shape())

    conv2 = Conv1D(round(filters), kernel_size, padding=padding, activation='relu')
    x = conv2(x)
    print('conv2 \t\t', x.get_shape())

    conv3 = Conv1D(round(filters), kernel_size, padding=padding, activation='relu')
    x = conv3(x)
    print('conv3 \t\t', x.get_shape())

    conv4 = Conv1D(round(filters), kernel_size, padding=padding, activation='relu')
    x = conv4(x)
    print('conv4 \t\t', x.get_shape())

    # ReLU GRU; dynamic_rnn (time_major=False by default) steps over axis 1.
    enc_cell = tf.nn.rnn_cell.GRUCell(mixer.nb_freq*2, activation = tf.nn.relu)
    x, enc_state = tf.nn.dynamic_rnn(cell=enc_cell, inputs=x,
                                     dtype=tf.float32)
    print('gru \t\t', x.get_shape())

    # 1x1 convolution with tanh, scaled by mixer.K so y lies in [-K, K]: the
    # range of scaled masks that mask_to_outputs() converts back.
    convend = Conv1D(round(filters), 1, padding=padding, activation='tanh')
    y = convend(x)
    print('convend \t\t', y.get_shape())

    y = y * mixer.K
print('Model consists of ', utils.num_params(), 'trainable parameters.')
# GPU memory cap. By default TensorFlow grabs all GPU memory; this fraction lets
# it take up to 99%, so it only keeps a small margin free rather than truly
# restricting usage.
gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=.99)

# Optional (disabled): launch TensorBoard-style visualization of the TF graph.
# with tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts)) as sess:
#     tmp_def = utils.rename_nodes(sess.graph_def, lambda s:"/".join(s.split('_',1)))
#     utils.show_graph(tmp_def)


with tf.variable_scope('loss'):
    # Turn both the target masks (y_pl) and the predicted masks (y) into complex
    # source spectrograms of the same mixture, so the error is measured on the
    # separated signals rather than on the masks.
    y_target1, y_target2 = mask_to_outputs(
        x_pl[:, :, :mixer.nb_freq], x_pl[:, :, mixer.nb_freq:],
        y_pl[:, :, :mixer.nb_freq], y_pl[:, :, mixer.nb_freq:],
        mixer.C, mixer.K)
    y_pred1, y_pred2 = mask_to_outputs(
        x_pl[:, :, :mixer.nb_freq], x_pl[:, :, mixer.nb_freq:],
        y[:, :, :mixer.nb_freq], y[:, :, mixer.nb_freq:],
        mixer.C, mixer.K)

    # Squared error of the real parts (source 1, source 2), then of the
    # imaginary parts, averaged over every element.
    mean_square_error = tf.reduce_mean(sum(
        (component(target) - component(prediction))**2
        for component in (tf.real, tf.imag)
        for target, prediction in ((y_target1, y_pred1), (y_target2, y_pred2))))

    # L2 regularization (disabled):
    # reg_scale = 0.00001
    # regularize = tf.contrib.layers.l2_regularizer(reg_scale)
    # params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    # reg_term = sum([regularize(param) for param in params])
    # mean_square_error += reg_term


with tf.variable_scope('training'):
    # Adam with a fixed learning rate minimizes the separation error.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.minimize(mean_square_error)

# Smoke test: one forward pass on a training batch with freshly initialized
# weights, then check that the output matches the target shape (batch excluded).
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts)) as sess:
    sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
    sess.run(tf.global_variables_initializer())
    y_pred = sess.run(y)

assert y_pred.shape[1:] == y_pl.shape[1:], (
    "ERROR the output shape is not as expected! Output shape should be {} but was {}"
    .format(y_pl.shape, y_pred.shape))

print('Forward pass successful!')


In [None]:
# ## Training
#
# Training-loop configuration and the per-epoch histories it fills.

max_epochs = 25  # number of passes over the training files

# Per-epoch mean losses appended by trainingLoop(); test_loss is not written
# by the visible code.
valid_loss, train_loss, test_loss = [], [], []

def _nb_batches(nb_examples):
    """Return how many batches of ``batch_size`` cover ``nb_examples`` once (at least 1).

    The dataset is batched before it is repeated, so one pass over N examples
    yields ceil(N / batch_size) batches, the last one possibly partial.
    """
    return max(1, int(np.ceil(nb_examples / float(batch_size))))


def _validation_loss(sess, nb_batches):
    """Return the mean ``mean_square_error`` over ``nb_batches`` validation batches.

    Re-initializes the shared iterator on the validation files, then back on the
    training files afterwards, which also restarts the training pass.
    """
    sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
    losses = [sess.run(mean_square_error) for _step in range(nb_batches)]
    sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
    return np.mean(losses)


def _plot(image, title):
    """Show ``image`` with pcolormesh and the notebook's usual labels."""
    # NOTE(review): `plt` (matplotlib.pyplot) is not imported anywhere in the
    # visible notebook; it must be imported before this runs.
    plt.pcolormesh(image)
    plt.axis('tight')
    plt.ylabel('Frequency [Hz]')
    plt.xlabel('Time [sec]')
    plt.title(title)
    plt.colorbar()
    plt.show()


def _display_example(sess, framerate=16000):
    """Play and plot the first example of one batch: targets vs. predictions.

    Everything is fetched in a single run. x_pl is the iterator output, so
    separate sess.run calls would each see a different batch.
    """
    (x_batch, y_batch, mask_pred,
     targ1, targ2, rec1, rec2) = sess.run(
        [x_pl, y_pl, y, y_target1, y_target2, y_pred1, y_pred2])

    # Separated spectrograms are (batch, spec_length, nb_freq). Transpose the
    # first example so frequency precedes time, as scipy's istft expects by default.
    for title, spec in (('Speaker A target', targ1), ('Speaker A prediction', rec1),
                        ('Speaker B target', targ2), ('Speaker B prediction', rec2)):
        _, audio = istft(spec[0].T, fs=framerate)
        print(title)
        IPython.display.display(IPython.display.Audio(audio, rate=framerate))

    nb_freq = mixer.nb_freq
    mixed = x_batch[0, :, :nb_freq] + 1j * x_batch[0, :, nb_freq:]
    # The mixture is complex (real and imaginary halves), so plot its magnitude in dB.
    _plot(20 * np.log10(np.abs(mixed.T) + 1e-10), 'Input')
    # Scaled masks: real half stacked over imaginary half.
    _plot(y_batch[0].T, 'Real mask')
    _plot(mask_pred[0].T, 'Predicted mask')


def trainingLoop():
    """Train the separation network, save it, then audition one example.

    Runs up to ``max_epochs`` epochs of ``train_op``. Each epoch is one pass over
    the training files, followed by one pass over the validation files. The mean
    losses of each epoch are appended to the module-level ``train_loss`` and
    ``valid_loss`` lists (always together, so they stay the same length).
    Ctrl-C stops training early. The model is still saved to ./model.ckpt and an
    example is still displayed.
    """
    # NOTE(review): assumes mixer.nb_seg_train / mixer.nb_seg_test are the example
    # counts of the training / validation TFRecords. Validation reads /mnt/dev/
    # although the attribute is named "test" — confirm.
    batches_per_epoch = _nb_batches(mixer.nb_seg_train)
    batches_per_validation = _nb_batches(mixer.nb_seg_test)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts)) as sess:
        saver = tf.train.Saver()
        sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
        sess.run(tf.global_variables_initializer())
        print('Begin training loop')

        nb_epochs = 0
        try:
            while nb_epochs < max_epochs:
                # Collect every batch loss of the epoch (not just the last one).
                epoch_losses = []
                for _step in range(batches_per_epoch):
                    _, batch_loss = sess.run([train_op, mean_square_error])
                    epoch_losses.append(batch_loss)
                nb_epochs += 1

                # Validation re-initializes the iterator, so the next epoch
                # starts a fresh pass over the training files.
                epoch_valid_loss = _validation_loss(sess, batches_per_validation)
                train_loss.append(np.mean(epoch_losses))
                valid_loss.append(epoch_valid_loss)

                print("Epoch {} : Train Loss {:6.3f}, Valid loss {:6.3f}".format(
                    nb_epochs, train_loss[-1], valid_loss[-1]))
        except KeyboardInterrupt:
            # Deliberate: Ctrl-C ends training; saving and display still happen.
            pass

        saver.save(sess, "./model.ckpt")
        _display_example(sess)

    
trainingLoop()

# Learning curves. Each history is plotted against its own epoch index, so the
# plot still works if the lists differ in length (e.g. training interrupted
# after a train loss was recorded but before validation finished).
# NOTE(review): `plt` (matplotlib.pyplot) is not imported in the visible notebook.
epoch = np.arange(len(train_loss))
plt.figure()
plt.plot(epoch, train_loss, 'r', np.arange(len(valid_loss)), valid_loss, 'b')
plt.legend(['Train Loss','Val Loss'], loc=4)  # loc=4: lower right
plt.xlabel('Epochs')
plt.ylabel('Loss')

In [None]:
from scipy.io import wavfile
from separation import bss_eval_sources

def mask_to_outputs_np(mixed_real, mixed_imag, scaled_mask_real, scaled_mask_imag, C, K):
    """NumPy version of mask_to_outputs: apply a compressed complex mask.

    Each scaled mask component s is mapped back to m = -log((K - s) / (K + s)) / C,
    the inverse of s = K * (1 - exp(-C*m)) / (1 + exp(-C*m)). The small constants
    keep the division and the logarithm finite when s reaches -K or +K.

    Returns:
        (sep1, sep2): complex arrays with sep1 = mask * mixture and
        sep2 = mixture - sep1.
    """
    def _uncompress(scaled):
        return -np.log((K - scaled) / (K + scaled + 1e-10) + 1e-30) / C

    mixture = mixed_real + 1j * mixed_imag
    mask = _uncompress(scaled_mask_real) + 1j * _uncompress(scaled_mask_imag)
    first_source = np.multiply(mask, mixture)
    return first_source, mixture - first_source

# Load a saved model and evaluate it on one validation example: play the mixture
# and the target/predicted sources, write them as WAV files, print BSS-eval scores.
saver = tf.train.Saver()

# Restore into a CPU-only session (no GPU devices are exposed to it).
with tf.Session(config = tf.ConfigProto(device_count = {'GPU': 0})) as sess:
    sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
    # NOTE(review): this is not the "./model.ckpt" written by trainingLoop(); confirm
    # which checkpoint should be evaluated.
    saver.restore(sess, "complex3filters5convsamefilters.ckpt")
    print("Model restored.")

    # One run keeps input, target mask and prediction on the same batch and,
    # unlike calling iterator.get_next() here, adds no new op to the graph.
    x_batch, y_batch, y_pred = sess.run([x_pl, y_pl, y])

    # (batch, spec_length, 2*nb_freq) -> (batch, 2*nb_freq, spec_length): istft
    # expects the frequency axis before the time axis.
    x_batch = np.transpose(x_batch, axes=[0, 2, 1])
    y_batch = np.transpose(y_batch, axes=[0, 2, 1])
    y_pred = np.transpose(y_pred, axes=[0, 2, 1])

    mixed_real = x_batch[0, :mixer.nb_freq, :]
    mixed_imag = x_batch[0, mixer.nb_freq:, :]
    sp_y1_targ, sp_y2_targ = mask_to_outputs_np(
        mixed_real, mixed_imag,
        y_batch[0, :mixer.nb_freq, :], y_batch[0, mixer.nb_freq:, :], mixer.C, mixer.K)
    sp_y1_rec, sp_y2_rec = mask_to_outputs_np(
        mixed_real, mixed_imag,
        y_pred[0, :mixer.nb_freq, :], y_pred[0, mixer.nb_freq:, :], mixer.C, mixer.K)

    framerate = 16000

    tm, mixed = istft(mixed_real + 1j * mixed_imag, fs=framerate)
    t1, y1_targ = istft(sp_y1_targ, fs=framerate)
    t1, y1_rec = istft(sp_y1_rec, fs=framerate)
    t2, y2_targ = istft(sp_y2_targ, fs=framerate)
    t2, y2_rec = istft(sp_y2_rec, fs=framerate)

    for label, audio in (('Mixed', mixed),
                         ('Speaker A target', y1_targ),
                         ('Speaker A prediction', y1_rec),
                         ('Speaker B target', y2_targ),
                         ('Speaker B prediction', y2_rec)):
        print(label)
        IPython.display.display(IPython.display.Audio(audio, rate=framerate))

    # Each speaker's target and prediction share that speaker's target peak, so
    # their relative level is preserved. Speaker B previously used Speaker A's peak.
    wavfile.write('mixed.wav', framerate, mixed / np.max(np.abs(mixed)))
    wavfile.write('spkA_target.wav', framerate, y1_targ / np.max(np.abs(y1_targ)))
    wavfile.write('spkA_prediction.wav', framerate, y1_rec / np.max(np.abs(y1_targ)))
    wavfile.write('spkB_target.wav', framerate, y2_targ / np.max(np.abs(y2_targ)))
    wavfile.write('spkB_prediction.wav', framerate, y2_rec / np.max(np.abs(y2_targ)))

    indexes = bss_eval_sources( np.array([y1_targ, y2_targ]), np.array([y1_rec, y2_rec]) )
    print(indexes)

In [None]:
def compute_indexes(x_batch, y_batch, y_pred) :
    """Return the bss_eval_sources scores for ONE example.

    ``x_batch``, ``y_batch`` and ``y_pred`` are single 2-D examples (despite the
    names). After transposing, their first mixer.nb_freq rows are the real part
    and the remaining rows the imaginary part. The target and predicted masks
    are applied to the mixture with mask_to_outputs_np, and both source pairs
    are inverted with istft at 16 kHz. The result is
    bss_eval_sources(references, estimates).
    """
    nb_freq = mixer.nb_freq
    mixture, true_mask, est_mask = (np.transpose(a) for a in (x_batch, y_batch, y_pred))

    def _separate(mask):
        return mask_to_outputs_np(mixture[:nb_freq, :], mixture[nb_freq:, :],
                                  mask[:nb_freq, :], mask[nb_freq:, :], mixer.C, mixer.K)

    def _to_audio(spectrogram):
        return istft(spectrogram, fs=16000)[1]

    references = np.array([_to_audio(spec) for spec in _separate(true_mask)])
    estimates = np.array([_to_audio(spec) for spec in _separate(est_mask)])
    return bss_eval_sources(references, estimates)

# Score the restored model with BSS-eval on a training sample and on the full
# validation set.
max_epoch = 1  # evaluation rounds; the global `max_epochs` of trainingLoop() is left untouched
# Debug shortcut (was `if True or ...`): score ONE training batch, then go
# straight to the validation pass. Set to False to score a full training
# epoch first (slow: one BSS-eval per batch).
quick_train_check = True


def _mean_indexes(rows):
    """Average a list of per-example score triples over examples and sources.

    Each row holds three per-source arrays (presumably SDR, SIR, SAR — confirm in
    `separation.bss_eval_sources`), so the result has three values.
    """
    return np.mean(np.mean(np.array(rows), axis=0), axis=1)


with tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts)) as sess:
    saver = tf.train.Saver()
    sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
    saver.restore(sess, "complex3filters5convsamefilters.ckpt")

    # NOTE(review): assumes mixer.nb_seg_test is the number of examples in /mnt/dev/.
    nb_valid_batches = max(1, int(np.ceil(mixer.nb_seg_test / float(batch_size))))
    nb_batches_processed = 0
    nb_epochs = 0
    _train_indexes = []
    train_indexes = []
    valid_indexes = []
    try:
        while nb_epochs < max_epoch:
            # Score (no training happens here) the first example of one training batch.
            # Fetching x_pl, y_pl and y in one run keeps them on the same batch and,
            # unlike sess.run(iterator.get_next()), adds no op to the graph per call.
            x_batch, y_batch, y_pred = sess.run([x_pl, y_pl, y])
            _indexes = compute_indexes(x_batch[0], y_batch[0], y_pred[0])
            _train_indexes.append([_indexes[0], _indexes[1], _indexes[2]])

            nb_batches_processed += 1
            print(nb_batches_processed, _mean_indexes(_train_indexes))

            if (quick_train_check
                    or nb_batches_processed * batch_size // mixer.nb_seg_train > nb_epochs):
                nb_epochs += 1
                train_indexes.append(_mean_indexes(_train_indexes))
                _train_indexes = []

                # One pass over the validation files, scoring every example.
                sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
                _valid_indexes = []
                for nb_test_batches_processed in range(nb_valid_batches):
                    x_batch, y_batch, y_pred = sess.run([x_pl, y_pl, y])
                    for x_ex, y_ex, pred_ex in zip(x_batch, y_batch, y_pred):
                        _indexes = compute_indexes(x_ex, y_ex, pred_ex)
                        _valid_indexes.append([_indexes[0], _indexes[1], _indexes[2]])
                    print(nb_test_batches_processed, _mean_indexes(_valid_indexes))

                valid_indexes.append(_mean_indexes(_valid_indexes))

                print("train indexes:", train_indexes[-1],
                      "valid indexes", valid_indexes[-1])
                sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C stops the (slow) evaluation early.
        pass