In [1]:
import os

# Expose only the GPU with index 1 to CUDA-based libraries loaded later in this kernel.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [2]:
import sys

# Parent of the notebook's working directory; put first on sys.path so the
# local checkout of the package is imported ahead of any installed copy.
# (The previous form passed `__name__` -- a module name, not a file path -- to
# abspath(), which only resolved relative to the working directory by accident.)
SOURCE_DIR = os.path.dirname(os.getcwd())
sys.path.insert(0, SOURCE_DIR)

In [3]:
import tensorflow as tf
import malaya_speech
import malaya_speech.train
from malaya_speech.train.model import unet
from malaya_speech.utils import tf_featurization
from tensorflow.keras.layers import Multiply
import IPython.display as ipd
import numpy as np






The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [74]:
from itertools import permutations
import numpy as np

# Divisor that turns `Model.length` into `Model.lengths`, the number of unmasked
# steps along the time axis of the loss.
# NOTE(review): presumably the duration in seconds of one feature frame --
# confirm how this value was derived.
reduce_time = 0.02307655849224936
# Small constant; not referenced anywhere in the visible cells.
EPS = 1e-8

def get_stft(X):
    """Apply ``tf_featurization.get_stft`` to every entry of ``X`` along axis 0.

    Each entry ``X[i]`` is featurized independently inside a ``tf.while_loop``
    and the per-entry results are stacked back into batched tensors.

    Args:
        X: tensor whose first axis is the batch axis.

    Returns:
        A ``(stft, D)`` tuple:
        - ``stft``: complex64 tensor, static shape ``(None, None, 2049, 1)``.
        - ``D``: float32 tensor, static shape ``(None, None, 512, 1024, 1)``.
    """
    n_items = tf.shape(X)[0]

    def _new_buffer(dtype):
        # Fixed-size buffer with one slot per batch entry; element shapes are
        # not inferred, the static shapes are attached after stacking.
        return tf.TensorArray(
            dtype = dtype,
            size = n_items,
            dynamic_size = False,
            infer_shape = False,
        )

    def _has_more(step, stft_buffer, d_buffer):
        return step < n_items

    def _featurize_one(step, stft_buffer, d_buffer):
        stft_item, d_item = tf_featurization.get_stft(X[step])
        next_step = step + 1
        return (
            next_step,
            stft_buffer.write(step, stft_item),
            d_buffer.write(step, d_item),
        )

    loop_vars = (0, _new_buffer(tf.complex64), _new_buffer(tf.float32))
    _, stft_buffer, d_buffer = tf.while_loop(_has_more, _featurize_one, loop_vars)

    stacked_stft = stft_buffer.stack()
    stacked_stft.set_shape((None, None, 2049, 1))
    stacked_d = d_buffer.stack()
    stacked_d.set_shape((None, None, 512, 1024, 1))
    return stacked_stft, stacked_d

def log10(x):
    """Base-10 logarithm of ``x``, computed as ``ln(x) / ln(10)``.

    The constant 10 is created with the dtype of ``tf.log(x)`` so the division
    does not mix dtypes.
    """
    ln_x = tf.log(x)
    ln_10 = tf.log(tf.constant(10, dtype = ln_x.dtype))
    return ln_x / ln_10

class Model:
    """``size`` U-Net branches over one input plus a permutation-invariant L1 loss.

    Attributes set by the constructor:
        X: float32 placeholder of shape (batch, samples); the input signals.
        Y: float32 placeholder of shape (batch, size, samples); one reference
            signal per branch.
        length: float32 placeholder of shape (batch,).
        lengths: ``length / reduce_time`` cast to int32; number of unmasked
            steps along the time axis of the loss.
        stft: list of ``get_stft`` results, one per reference in ``Y``.
        outputs: list of ``unet.Model3D(...).logits``, one per branch.
        loss: scalar tensor; batch mean of the smallest per-permutation
            mean absolute error, divided by ``size``.

    Args:
        size: number of branches / reference signals.
    """

    def __init__(self, size = 4):
        self.X = tf.placeholder(tf.float32, (None, None))
        self.Y = tf.placeholder(tf.float32, (None, size, None))
        self.length = tf.placeholder(tf.float32, (None,))
        self.lengths = tf.cast(self.length / reduce_time, tf.int32)

        # Only the second get_stft output of the input feeds the networks.
        _, D_X = get_stft(self.X)

        self.stft = [get_stft(self.Y[:, i]) for i in range(size)]

        # One independently-parameterised network per output, each under its
        # own variable scope.
        self.outputs = []
        for i in range(size):
            with tf.variable_scope(f'model_{i}'):
                output = unet.Model3D(
                    D_X, dropout = 0.0, training = True
                ).logits
                self.outputs.append(output)

        batch_size = tf.shape(self.outputs[0])[0]
        fft_size = self.outputs[0].shape[3]

        # References: stack along the last axis, then reshape/transpose to
        # (batch, size, steps, fft_size).
        labels = [i[1] for i in self.stft]
        labels = tf.concat(labels, axis = 4)
        labels = tf.reshape(labels, [batch_size, -1, fft_size, size])
        labels = tf.transpose(labels, perm = [0, 3, 1, 2])

        # Estimates: same layout as the references.
        concatenated = tf.concat(self.outputs, axis = 4)
        concatenated = tf.reshape(concatenated, [batch_size, -1, fft_size, size])
        concatenated = tf.transpose(concatenated, perm = [0, 3, 1, 2])

        # Zero out every step at or beyond `lengths` in both tensors; the mask
        # is broadcast over the `size` and `fft_size` axes.
        mask = tf.cast(
            tf.sequence_mask(self.lengths, tf.shape(concatenated)[2]),
            concatenated.dtype,
        )
        mask = tf.expand_dims(mask, 1)
        mask = tf.expand_dims(mask, -1)

        labels = labels * mask
        concatenated = concatenated * mask

        # Pairwise layout follows
        # https://github.com/asteroid-team/asteroid/blob/master/asteroid/losses/mse.py
        # but the distance used here is the absolute error, not the squared error.
        # pair_wise_abs[b, i, j] = mean |reference_j - estimate_i|.
        targets = tf.expand_dims(labels, 1)
        est_targets = tf.expand_dims(concatenated, 2)
        pw_loss = tf.abs(targets - est_targets)
        pair_wise_abs = tf.reduce_mean(pw_loss, axis = [3, 4])

        # perms_one_hot[p, i, j] == 1 selects the (estimate i, reference j)
        # pairs used by permutation p, so the einsum sums the pairwise losses
        # of each permutation.
        perms_one_hot = tf.convert_to_tensor(self._permutation_one_hot(size))

        abs_set = tf.einsum('bij,pij->bp', pair_wise_abs, perms_one_hot)
        min_abs = tf.reduce_min(abs_set, axis = 1, keepdims = True)
        min_abs /= size
        self.loss = tf.reduce_mean(min_abs)

    @staticmethod
    def _permutation_one_hot(size):
        """One-hot encode every permutation of ``range(size)``.

        Returns:
            float32 NumPy array ``a`` of shape (size!, size, size) where
            ``a[p, i, j] == 1`` iff the p-th permutation maps ``i`` to ``j``.
        """
        perms = np.array(
            list(permutations(range(size))), dtype = np.int64
        ).reshape(-1, size)
        # Row `k` of the identity matrix is the one-hot vector for index `k`.
        return np.eye(size, dtype = np.float32)[perms]

In [75]:
# Clear the default graph first so that re-running this cell builds the model
# into an empty graph instead of adding to the previous one.
tf.compat.v1.reset_default_graph()
model = Model()
# `sess` is reused by the later cells to evaluate the loss and to save a checkpoint.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [76]:
# Sample rate used both to load the clips and to convert sample counts to
# seconds. It was previously never defined in the notebook (the loader's second
# return value was discarded), so this cell failed on a fresh kernel.
sr = 44100
y, _ = malaya_speech.load('../speech/example-speaker/husein-zolkepli.wav', sr = sr)
y1, _ = malaya_speech.load('../speech/example-speaker/shafiqah-idayu.wav', sr = sr)
# Keep only the first three seconds of each clip.
y = y[:sr * 3]
y1 = y1[:sr * 3]
len(y) / sr, len(y1) / sr

(3.0, 3.0)

In [78]:
# Evaluate the loss once on a batch of two examples: both inputs are `y`, the
# references are four clips per example, and both lengths are the duration of `y`.
sess.run(
    model.loss,
    feed_dict = {
        model.X: [y, y],
        model.Y: [[y1, y, y1, y], [y1] * 4],
        model.length: [len(y) / sr] * 2,
    },
)

0.394099

In [None]:
# stft[0][1].shape

In [None]:
# outputs = sess.run(model.outputs, feed_dict = {model.X: y_})
# [o.shape for o in outputs]

In [None]:
# sess.run(model.loss, feed_dict = {model.X: y_, model.Y: [y, noise]})

In [None]:
# istft = sess.run(model.istft, feed_dict = {model.X: y_})
# [s.shape for s in istft]

In [None]:
# ipd.Audio(istft[0], rate = sr)

In [None]:
# ipd.Audio(istft[1], rate = sr)

In [None]:
# ipd.Audio(y_, rate = sr)

In [None]:
# Saver used by the next cell to write a checkpoint of the current graph's variables.
saver = tf.train.Saver()

In [None]:
# Write a checkpoint of the session's variables under the prefix 'test/model.ckpt'.
saver.save(sess, 'test/model.ckpt')

In [None]:
# List the files in the checkpoint directory with human-readable sizes.
!ls -lh test

In [None]:
# Display the trainable variables currently registered in the default graph.
tf.trainable_variables()

In [None]:
# Clean up: delete the temporary checkpoint directory written above.
!rm -rf test