In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import sys

SOURCE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__name__)))
sys.path.insert(0, SOURCE_DIR)

In [3]:
import tensorflow as tf
import malaya_speech
import malaya_speech.augmentation.waveform as augmentation

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import IPython.display as ipd

In [5]:
tf.compat.v1.enable_eager_execution()

In [6]:
x = tf.random.normal(shape = (1, 178, 128))
x

<tf.Tensor: id=5, shape=(1, 178, 128), dtype=float32, numpy=
array([[[-1.4894083 , -1.08747   ,  0.23945677, ...,  0.62369645,
          0.09033953, -0.50241685],
        [-0.43885314, -0.5180802 , -2.0874403 , ...,  0.05223629,
          2.452981  ,  0.46216166],
        [ 1.6546882 ,  0.41157594, -0.8293056 , ...,  0.20968488,
          0.46494552, -0.85549325],
        ...,
        [-0.46001434, -0.57955754, -0.3742274 , ..., -0.42055252,
          0.55562997,  0.51984334],
        [-0.21256934, -0.7085946 ,  1.9723598 , ..., -1.3315401 ,
          0.55218446,  0.8064594 ],
        [-0.57879204,  0.71113425, -0.19437844, ...,  0.8976132 ,
          0.10807681, -0.23783444]]], dtype=float32)>

In [7]:
tf.math.count_nonzero(x, axis = 2)

<tf.Tensor: id=10, shape=(1, 178), dtype=int64, numpy=
array([[128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 1

In [8]:
import random

sr = 8000
speakers_size = 4

def read_wav(f):
    return malaya_speech.load(f, sr = sr)

def random_sampling(s, length):
    return augmentation.random_sampling(s, sr = sr, length = length)

def combine_speakers(files, n = 5, limit = 4):
    w_samples = random.sample(files, n)
    w_samples = [
        random_sampling(
            read_wav(f)[0],
            length = min(
                random.randint(10000 // n, 20000 // n), 10000
            ),
        )
        for f in w_samples
    ]
    y = [w_samples[0]]
    left = w_samples[0].copy() * random.uniform(0.5, 1.0)
    start, end = [], []
    start.append(0)
    end.append(len(left))

    combined = None

    for i in range(1, n):
        right = w_samples[i].copy() * random.uniform(0.5, 1.0)
        overlap = random.uniform(0.1, 0.9)
        print(i, overlap, len(right))
        len_overlap = int(overlap * len(right))
        minus = len(left) - len_overlap
        padded_right = np.pad(right, (minus, 0))
        start.append(minus)
        end.append(len(padded_right))
        left = np.pad(left, (0, len(padded_right) - len(left)))

        left = left + padded_right

        if i >= (limit - 1):
            if combined is None:
                combined = padded_right
            else:
                combined = np.pad(
                    combined, (0, len(padded_right) - len(combined))
                )
                combined += padded_right

        else:
            y.append(padded_right)

    if combined is not None:
        y.append(combined)

    for i in range(len(y)):
        if len(y[i]) != len(left):
            y[i] = np.pad(y[i], (0, len(left) - len(y[i])))
            y[i] = y[i] / np.max(np.abs(y[i]))

    left = left / np.max(np.abs(left))
    
    while len(y) < limit:
        y.append(np.zeros((len(left))))
        start.append(0)
        end.append(0)
        
    return left, y, start[:limit], end[:limit]

# y, _ = malaya_speech.load('../speech/example-speaker/husein-zolkepli.wav')
# y = np.expand_dims(y, 0).astype(np.float32)
# y.shape

In [9]:
from glob import glob

wavs = glob('../speech/example-speaker/*.wav')
len(wavs)

8

In [10]:
left, y, start, end = combine_speakers(wavs, random.randint(1, len(wavs)))
len(left) / sr, len(y), start, end

1 0.3157270280609322 41288


(6.414125, 4, [0, 10025, 0, 0], [23060, 51313, 0, 0])

In [11]:
ipd.Audio(left[start[0]: end[0]], rate = sr)

In [12]:
left = np.array([left.astype(np.float32)])

In [13]:
y_pt = torch.from_numpy(left)
y_tf = tf.convert_to_tensor(left)

In [14]:
N = 128
L = 8
H = 128
R = 6
C = speakers_size
input_normalize = False
sample_rate = 8000
segment = 4
context_len = 2 * sr / 1000
context = int(sr * context_len / 1000)
layer = R
filter_dim = context * 2 + 1
num_spk = C
segment_size = int(np.sqrt(2 * sr * segment / (L/2)))
segment_size, filter_dim

(126, 257)

In [15]:
class Encoder_PT(nn.Module):
    def __init__(self, L, N):
        super(Encoder_PT, self).__init__()
        self.L, self.N = L, N
        # setting 50% overlap
        self.conv = nn.Conv1d(
            1, N, kernel_size=L, stride=L // 2, bias=False)

    def forward(self, mixture):
        mixture = torch.unsqueeze(mixture, 1)
        mixture_w = F.relu(self.conv(mixture))
        return mixture_w
    
class MulCatBlock_PT(nn.Module):

    def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False):
        super(MulCatBlock_PT, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

        self.rnn = nn.LSTM(input_size, hidden_size, 1, dropout=dropout,
                           batch_first=True, bidirectional=bidirectional)
        self.rnn_proj = nn.Linear(hidden_size * self.num_direction, input_size)

        self.gate_rnn = nn.LSTM(input_size, hidden_size, num_layers=1,
                                batch_first=True, dropout=dropout, bidirectional=bidirectional)
        self.gate_rnn_proj = nn.Linear(
            hidden_size * self.num_direction, input_size)

        self.block_projection = nn.Linear(input_size * 2, input_size)
    
    def forward(self, input):
        output = input
        # run rnn module
        rnn_output, _ = self.rnn(output)
        rnn_output = self.rnn_proj(rnn_output.contiguous(
        ).view(-1, rnn_output.shape[2])).view(output.shape).contiguous()
        # run gate rnn module
        gate_rnn_output, _ = self.gate_rnn(output)
        gate_rnn_output = self.gate_rnn_proj(gate_rnn_output.contiguous(
        ).view(-1, gate_rnn_output.shape[2])).view(output.shape).contiguous()
        # apply gated rnn
        gated_output = torch.mul(rnn_output, gate_rnn_output)
        gated_output = torch.cat([gated_output, output], 2)
        gated_output = self.block_projection(
            gated_output.contiguous().view(-1, gated_output.shape[2])).view(output.shape)
        return gated_output
        
class ByPass_PT(nn.Module):
    def __init__(self):
        super(ByPass_PT, self).__init__()

    def forward(self, input):
        return input


class DPMulCat_PT(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_spk,
                 dropout=0, num_layers=1, bidirectional=True, input_normalize=False):
        super(DPMulCat_PT, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.in_norm = input_normalize
        self.num_layers = num_layers

        self.rows_grnn = nn.ModuleList([])
        self.cols_grnn = nn.ModuleList([])
        self.rows_normalization = nn.ModuleList([])
        self.cols_normalization = nn.ModuleList([])

        # create the dual path pipeline
        for i in range(num_layers):
            self.rows_grnn.append(MulCatBlock_PT(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.cols_grnn.append(MulCatBlock_PT(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            if self.in_norm:
                self.rows_normalization.append(
                    nn.GroupNorm(1, input_size, eps=1e-8))
                self.cols_normalization.append(
                    nn.GroupNorm(1, input_size, eps=1e-8))
            else:
                # used to disable normalization
                self.rows_normalization.append(ByPass_PT())
                self.cols_normalization.append(ByPass_PT())

        self.output = nn.Sequential(
            nn.PReLU(), nn.Conv2d(input_size, output_size * num_spk, 1))

    def forward(self, input):
        batch_size, _, d1, d2 = input.shape
        output = input
        output_all = []
        for i in range(self.num_layers):
            row_input = output.permute(0, 3, 2, 1).contiguous().view(batch_size * d2, d1, -1)
            
            row_output = self.rows_grnn[i](row_input)
            row_output = row_output.view(
                batch_size, d2, d1, -1).permute(0, 3, 2, 1).contiguous()
            row_output = self.rows_normalization[i](row_output)
            # apply a skip connection
            output = output + row_output
            
            print(i, row_input.shape, row_output.shape, output.shape)

            col_input = output.permute(0, 2, 3, 1).contiguous().view(
                batch_size * d1, d2, -1)
            col_output = self.cols_grnn[i](col_input)
            col_output = col_output.view(
                batch_size, d1, d2, -1).permute(0, 3, 1, 2).contiguous()
            col_output = self.cols_normalization[i](col_output).contiguous()
            # apply a skip connection
            output = output + col_output
            
            print(i, col_input.shape, col_output.shape, output.shape)

            output_i = self.output(output)
            output_all.append(output_i)
        return output_all

        
class Separator_PT(nn.Module):
    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                 layer=4, segment_size=100, input_normalize=False, bidirectional=True):
        super(Separator_PT, self).__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk
        self.input_normalize = input_normalize

        self.rnn_model = DPMulCat_PT(self.feature_dim, self.hidden_dim,
                                  self.feature_dim, self.num_spk, num_layers=layer, bidirectional=bidirectional, input_normalize=input_normalize)

    # ======================================= #
    # The following code block was borrowed and modified from https://github.com/yluo42/TAC
    # ================ BEGIN ================ #
    def pad_segment(self, input, segment_size):
        # input is the features: (B, N, T)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2
        rest = segment_size - (segment_stride + seq_len %
                               segment_size) % segment_size
        if rest > 0:
            pad = Variable(torch.zeros(batch_size, dim, rest)
                           ).type(input.type())
            input = torch.cat([input, pad], 2)

        pad_aux = Variable(torch.zeros(
            batch_size, dim, segment_stride)).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 2)
        return input, rest

    def create_chuncks(self, input, segment_size):
        # split the feature into chunks of segment size
        # input is the features: (B, N, T)

        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        segments1 = input[:, :, :-segment_stride].contiguous().view(batch_size,
                                                                    dim, -1, segment_size)
        segments2 = input[:, :, segment_stride:].contiguous().view(
            batch_size, dim, -1, segment_size)
        segments = torch.cat([segments1, segments2], 3).view(
            batch_size, dim, -1, segment_size).transpose(2, 3)
        return segments.contiguous(), rest

    def merge_chuncks(self, input, rest):

        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose(2, 3).contiguous().view(
            batch_size, dim, -1, segment_size*2)  # B, N, K, L

        input1 = input[:, :, :, :segment_size].contiguous().view(
            batch_size, dim, -1)[:, :, segment_stride:]
        input2 = input[:, :, :, segment_size:].contiguous().view(
            batch_size, dim, -1)[:, :, :-segment_stride]

        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]
        return output.contiguous()
    
    def forward(self, input):
        # create chunks
        enc_segments, enc_rest = self.create_chuncks(
            input, self.segment_size)
        output_all = self.rnn_model(enc_segments)
        
        output_all_wav = []
        for ii in range(len(output_all)):
            output_ii = self.merge_chuncks(
                output_all[ii], enc_rest)
            print(ii, output_all[ii].shape)
            output_all_wav.append(output_ii)
        return output_all_wav

In [16]:
# encoder_pt = Encoder_PT(L, N)
# separator_pt = Separator_PT(filter_dim + N, N, H,
#                       filter_dim, num_spk, layer, segment_size, input_normalize)
# e_pt = encoder_pt(y_pt)
# o_pt = separator_pt(e_pt)
# # l.shape, r

In [17]:
# [i.shape for i in o_pt]

In [18]:
def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

class Encoder(tf.keras.layers.Layer):
    def __init__(self, L, N, **kwargs):
        super(Encoder, self).__init__(name = 'Encoder', **kwargs)
        self.conv = tf.keras.layers.Conv1D(N, kernel_size=L, strides=L // 2, use_bias=False)
    
    def call(self, mixture):
        mixture = tf.expand_dims(mixture, -1)
        mixture_w = tf.nn.relu(self.conv(mixture))
        return mixture_w
    
class Decoder(tf.keras.layers.Layer):
    def __init__(self, L, **kwargs):
        super(Decoder, self).__init__(name = 'Decoder', **kwargs)
        self.L = L

    def call(self, est_source):
        # torch.Size([1, 256, 22521])
        # pt (1, 2, 128, 22521), tf (1, 22521, 2, 128)
        est_source = tf.transpose(est_source, (0, 1, 3, 2))
        est_source = tf.compat.v1.layers.average_pooling2d(est_source, 1, (1,8),
                                     padding = 'SAME')
        est_source = tf.signal.overlap_and_add(tf.transpose(est_source, (0, 3, 1, 2)), self.L // 2)

        return est_source
    
class MulCatBlock(tf.keras.layers.Layer):

    def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False, **kwargs):
        super(MulCatBlock, self).__init__(name = 'MulCatBlock', **kwargs)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1
        
        if bidirectional:
            self.rnn = tf.keras.layers.Bidirectional(
                tf.keras.layers.LSTM(hidden_size, return_sequences = True)
            )
            self.gate_rnn = tf.keras.layers.Bidirectional(
                tf.keras.layers.LSTM(hidden_size, return_sequences = True)
            )
        else:
            self.rnn = tf.keras.layers.LSTM(hidden_size, return_sequences = True)
            self.gate_rnn = tf.keras.layers.LSTM(hidden_size, return_sequences = True)
        
        self.rnn_proj = tf.keras.layers.Dense(input_size)
        self.gate_rnn_proj = tf.keras.layers.Dense(input_size)
        self.block_projection = tf.keras.layers.Dense(input_size)
    
    def call(self, input):
        output = input
        rnn_output = self.rnn(output)
        rnn_output = self.rnn_proj(rnn_output)
        gate_rnn_output = self.gate_rnn(output)
        gate_rnn_output = self.gate_rnn_proj(gate_rnn_output)
        gated_output = tf.multiply(rnn_output, gate_rnn_output)
        gated_output = tf.concat([gated_output, output], 2)
        gated_output = self.block_projection(gated_output)
        return gated_output
    
class GroupNorm(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(DPMulCat, self).__init__(name = 'GroupNorm', **kwargs)
    
    def call(self, input):
        return tf.contrib.layers.group_norm(x_tf, groups = 1, epsilon = 1e-8)
    
class ByPass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(ByPass, self).__init__(name = 'ByPass', **kwargs)

    def call(self, input):
        return input
        
class DPMulCat(tf.keras.layers.Layer):
    def __init__(self, input_size, hidden_size, output_size, num_spk,
                 dropout=0, num_layers=1, bidirectional=True, input_normalize=False, **kwargs):
        super(DPMulCat, self).__init__(name = 'DPMulCat', **kwargs)

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.in_norm = input_normalize
        self.num_layers = num_layers
        
        self.rows_grnn = []
        self.cols_grnn = []
        self.rows_normalization = []
        self.cols_normalization = []
        
        for i in range(num_layers):
            self.rows_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.cols_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            if self.in_norm:
                self.rows_normalization.append(GroupNorm())
                self.cols_normalization.append(GroupNorm())
            else:
                # used to disable normalization
                self.rows_normalization.append(ByPass())
                self.cols_normalization.append(ByPass())
                
        self.outputs = tf.keras.Sequential()     
        self.outputs.add(tf.keras.layers.PReLU())
        self.outputs.add(tf.keras.layers.Conv2D(output_size * num_spk, 1, padding = 'SAME'))
    
    def call(self, input):
        # original, [b, d3, d1, d2]
        print('input', input.shape)
        input = tf.transpose(input, (0, 2, 1, 3))
        batch_size, d3, d1, d2 = shape_list(input)
        output = input
        output_all = []
        for i in range(self.num_layers):
            row_input = tf.transpose(output, [0, 3, 2, 1])
            print('row_input', row_input.shape)
            row_input = tf.reshape(row_input, (batch_size * d2, d1, d3))
            row_output = self.rows_grnn[i](row_input)
            row_output = tf.reshape(row_output, (batch_size, d2, d1, d3))
            row_output = tf.transpose(row_output, (0, 3, 2, 1))
            row_output = self.rows_normalization[i](row_output)
            output = output + row_output
            
            print(i, row_input.shape, row_output.shape, output.shape)
            
            col_input = tf.transpose(output, [0, 2, 3, 1])
            col_input = tf.reshape(col_input, (batch_size * d1, d2, d3))
            col_output = self.cols_grnn[i](col_input)
            col_output = tf.reshape(col_output, (batch_size, d1, d2, d3))
            col_output = tf.transpose(col_output, (0, 3, 1, 2))
            col_output = self.cols_normalization[i](col_output)
            
            output = output + col_output
            
            print(i, col_input.shape, col_output.shape, output.shape)
            
            # torch.Size([1, 128, 126, 360]
            output_i = self.outputs(tf.transpose(output, [0, 2, 3, 1]))
            output_all.append(output_i)
        return output_all
    
class Separator(tf.keras.layers.Layer):
    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                 layer=4, segment_size=100, input_normalize=False, bidirectional=True, **kwargs):
        super(Separator, self).__init__(name = 'Separator', **kwargs)
        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk
        self.input_normalize = input_normalize
        
        self.rnn_model = DPMulCat(self.feature_dim, self.hidden_dim,
                                  self.feature_dim, self.num_spk, num_layers=layer, 
                                  bidirectional=bidirectional, input_normalize=input_normalize)
    
    def pad_segment(self, input, segment_size):
        # input is the features: (B, N, T)
        
        batch_size, seq_len, dim = shape_list(input)
        segment_stride = segment_size // 2
        rest = segment_size - (segment_stride + seq_len %
                               segment_size) % segment_size
        if rest > 0:
            pad = tf.Variable(tf.zeros(shape=(batch_size, rest, dim)))
            input = tf.concat([input, pad], 1)

        pad_aux = tf.Variable(tf.zeros(shape=(batch_size, segment_stride, dim)))
        input = tf.concat([pad_aux, input, pad_aux], 1)
        return input, rest
    
    def create_chuncks(self, input, segment_size):

        input, rest = self.pad_segment(input, segment_size)
        batch_size, seq_len, dim = shape_list(input)
        segment_stride = segment_size // 2
        segments1 = tf.reshape(input[:, :-segment_stride], (batch_size, -1, dim, segment_size))
        segments2 = tf.reshape(input[:, segment_stride:], (batch_size, -1, dim, segment_size))
        segments = tf.concat([segments1, segments2], axis = 3)
        segments = tf.reshape(segments, (batch_size, -1, dim, segment_size))
        segments = tf.transpose(segments, perm = [0, 3, 2, 1])
        return segments, rest
    
    def merge_chuncks(self, input, rest):
        # original, [b, dim, segment_size, _]
        # torch.Size([1, 256, 126, 360])
        # (1, 126, 360, 256)
        input = tf.transpose(input, perm = [0, 3, 1, 2])
        batch_size, dim, segment_size, _ = shape_list(input)
        segment_stride = segment_size // 2
        # original, [b, dim, _, segment_size]
        input = tf.transpose(input, perm = [0, 1, 3, 2])
        input = tf.reshape(input, (batch_size, dim, -1, segment_size * 2))
        
        input1 = tf.reshape(input[:, :, :, :segment_size], (batch_size, dim, -1))[:, :, segment_stride:]
        input2 = tf.reshape(input[:, :, :, segment_size:], (batch_size, dim, -1))[:, :, :-segment_stride]
        
        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]
            
        return tf.transpose(output, perm = [0, 2, 1])
        
    def call(self, input):
        # create chunks
        enc_segments, enc_rest = self.create_chuncks(
            input, self.segment_size)
        output_all = self.rnn_model(enc_segments)
        output_all_wav = []
        for ii in range(len(output_all)):
            print(ii, output_all[ii].shape)
            output_ii = self.merge_chuncks(
                output_all[ii], enc_rest)
            output_all_wav.append(output_ii)
        return output_all_wav
    
class Model(tf.keras.Model):
    def __init__(
        self,
        N = 128,
        L = 8,
        H = 128,
        R = 6,
        C = 2,
        input_normalize = False,
        sample_rate = 8000,
        segment = 4,
        context_len = 2 * sr / 1000,
        context = int(sr * context_len / 1000),
        layer = R,
        filter_dim = context * 2 + 1,
        num_spk = C,
        segment_size = int(np.sqrt(2 * sr * segment / (L / 2))),
        **kwargs
    ):
        super(Model, self).__init__(name = 'swave', **kwargs)
        self.C = C
        self.N = N
        self.encoder = Encoder(L, N)
        self.separator = Separator(
            filter_dim + N,
            N,
            H,
            filter_dim,
            num_spk,
            layer,
            segment_size,
            input_normalize,
        )
        self.decoder = Decoder(L)

    def call(self, mixture):
        mixture_w = self.encoder(mixture)
        output_all = self.separator(mixture_w)
        T_mix = tf.shape(mixture)[1]
        batch_size = tf.shape(mixture)[0]
        T_mix_w = tf.shape(mixture_w)[1]
        # generate wav after each RNN block and optimize the loss
        outputs = []
        for ii in range(len(output_all)):
            output_ii = tf.reshape(
                output_all[ii], (batch_size, T_mix_w, self.C, self.N)
            )
            output_ii = self.decoder(output_ii)
            output_ii = tf.cond(
                tf.shape(output_ii)[2] >= T_mix,
                lambda: output_ii[:, :, :T_mix],
                lambda: tf.pad(
                    output_ii,
                    [[0, 0], [0, 0], [0, T_mix - tf.shape(output_ii)[2]]],
                ),
            )
            outputs.append(output_ii)
        return outputs

In [19]:
model = Model(C = speakers_size)

In [20]:
y_tf.shape

TensorShape([Dimension(1), Dimension(51313)])

In [21]:
segment_size

126

In [22]:
outputs = model(y_tf)

input (1, 126, 128, 206)
row_input (1, 206, 126, 128)
0 (206, 126, 128) (1, 128, 126, 206) (1, 128, 126, 206)
0 (126, 206, 128) (1, 128, 126, 206) (1, 128, 126, 206)
row_input (1, 206, 126, 128)
1 (206, 126, 128) (1, 128, 126, 206) (1, 128, 126, 206)
1 (126, 206, 128) (1, 128, 126, 206) (1, 128, 126, 206)
row_input (1, 206, 126, 128)
2 (206, 126, 128) (1, 128, 126, 206) (1, 128, 126, 206)
2 (126, 206, 128) (1, 128, 126, 206) (1, 128, 126, 206)
row_input (1, 206, 126, 128)
3 (206, 126, 128) (1, 128, 126, 206) (1, 128, 126, 206)
3 (126, 206, 128) (1, 128, 126, 206) (1, 128, 126, 206)
row_input (1, 206, 126, 128)
4 (206, 126, 128) (1, 128, 126, 206) (1, 128, 126, 206)
4 (126, 206, 128) (1, 128, 126, 206) (1, 128, 126, 206)
row_input (1, 206, 126, 128)
5 (206, 126, 128) (1, 128, 126, 206) (1, 128, 126, 206)
5 (126, 206, 128) (1, 128, 126, 206) (1, 128, 126, 206)
0 (1, 126, 206, 512)
1 (1, 126, 206, 512)
2 (1, 126, 206, 512)
3 (1, 126, 206, 512)
4 (1, 126, 206, 512)
5 (1, 126, 206, 512)
Ins

In [23]:
outputs[0].shape

TensorShape([Dimension(1), Dimension(4), Dimension(51313)])

In [104]:
EPS = 1e-8

def log10(x):
    numerator = tf.log(x)
    denominator = tf.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator

def cal_si_snr_with_pit(source, estimate_source, source_lengths, C):
    B, _, T = shape_list(source)
    mask = tf.cast(tf.sequence_mask(source_lengths, tf.reduce_max(source_lengths)), source.dtype)
    print(mask)
    estimate_source *= mask
    
    num_samples = tf.cast(tf.reshape(source_lengths, (-1, 1, 1)), tf.float32)
    mean_target = tf.reduce_sum(source, axis = 2, keepdims = True) / num_samples
    mean_estimate = tf.reduce_sum(estimate_source, axis = 2, keepdims = True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    zero_mean_target *= mask
    zero_mean_estimate *= mask
    
    s_target = tf.expand_dims(zero_mean_target, 1)
    s_estimate = tf.expand_dims(zero_mean_estimate, 2)
    
    pair_wise_dot = tf.reduce_sum(s_estimate * s_target,
                              axis=3, keepdims=True)
    s_target_energy = tf.reduce_sum(
        s_target ** 2, axis=3, keepdims=True) + EPS
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy
    e_noise = s_estimate - pair_wise_proj
    pair_wise_si_snr = tf.reduce_sum(
        pair_wise_proj ** 2, axis=3) / (tf.reduce_sum(e_noise ** 2, axis=3) + EPS)
    pair_wise_si_snr = 10.0 * log10(pair_wise_si_snr + EPS)
    pair_wise_si_snr = tf.transpose(pair_wise_si_snr, perm = [0, 2, 1])
    
    perms = tf.constant(list(permutations(range(C))))
    perms_one_hot = tf.one_hot(perms, C)
    
    snr_set = tf.einsum('bij,pij->bp', pair_wise_si_snr, perms_one_hot)
    max_snr_idx = tf.argmax(snr_set, axis=1)
    max_snr = tf.reduce_max(snr_set, axis=1, keepdims=True)
    max_snr /= C
    
    return max_snr, max_snr_idx

In [105]:
from itertools import permutations

def cal_si_snr_with_pit_pt(source, estimate_source, source_lengths):
    """Calculate SI-SNR with PIT training.
    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B], each item is between [0, T]
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()
    # mask padding position along T
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2,
                              keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target,
                              dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(
        s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(
        pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]
    pair_wise_si_snr = torch.transpose(pair_wise_si_snr, 1, 2)

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    #  max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1))  # [B, 1]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
    max_snr /= C
    return max_snr, max_snr_idx

def get_mask(source, source_lengths):
    """
    Args:
        source: [B, C, T]
        source_lengths: [B]
    Returns:
        mask: [B, 1, T]
    """
    B, _, T = source.size()
    mask = source.new_ones((B, 1, T))
    for i in range(B):
        mask[i, :, source_lengths[i]:] = 0
    return mask

In [106]:
source = torch.from_numpy(np.expand_dims(y, 0))
estimated_source = torch.from_numpy(outputs[0].numpy())
source_lengths = torch.from_numpy(np.array([len(left[0])]))

In [107]:
source.shape

torch.Size([1, 4, 51313])

In [108]:
source_tf.shape

TensorShape([Dimension(1), Dimension(4), Dimension(51313)])

In [109]:
cal_si_snr_with_pit_pt(source, estimated_source, source_lengths)

(tensor([[-61.3037]], dtype=torch.float64), tensor([17]))

In [110]:
source_tf = tf.convert_to_tensor(np.expand_dims(y, 0).astype(np.float32))
source_lengths_tf = tf.convert_to_tensor(np.array([len(left[0])]))

In [111]:
cal_si_snr_with_pit(source_tf, outputs[0], source_lengths_tf, C)

tf.Tensor([[1. 1. 1. ... 1. 1. 1.]], shape=(1, 51313), dtype=float32)


(<tf.Tensor: id=680836, shape=(1, 1), dtype=float32, numpy=array([[-61.303654]], dtype=float32)>,
 <tf.Tensor: id=680832, shape=(1,), dtype=int64, numpy=array([16])>)

In [112]:
num_samples = source_lengths.view(-1, 1, 1).float()
num_samples.shape

torch.Size([1, 1, 1])

In [113]:
num_samples = source_lengths.view(-1, 1, 1).float()
num_samples

tensor([[[51313.]]])

In [None]:
s = tf.math.logical_not(tf.sequence_mask(start, 105299))
e = tf.sequence_mask(end, 105299)
tf.math.logical_and(s, e)

In [None]:
mask = tf.cast(tf.sequence_mask([80695], 80695), outputs[0].dtype)
outputs[0] * mask

In [None]:
np.array([[1, 3],[2,0]])

In [None]:
tf.sequence_mask([1,2,3,2])

In [None]:
r = tf.sequence_mask(np.concatenate([s, e]).T)
r = tf.cast(r, tf.int32).numpy()
# r = tf.reduce_sum(r, axis = 1).numpy()

In [None]:
r.shape

In [None]:
r[0].sum()

In [None]:
r[1]

In [None]:
r[0]

In [None]:
r[:,1].sum()

In [None]:
tf.sequence_mask([[     0], [44368]])

In [None]:
# 0 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 0 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 1 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 1 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 2 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 2 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 3 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 3 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 4 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 4 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 5 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 5 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])

In [None]:
# 0 torch.Size([1, 256, 126, 360])
# 1 torch.Size([1, 256, 126, 360])
# 2 torch.Size([1, 256, 126, 360])
# 3 torch.Size([1, 256, 126, 360])
# 4 torch.Size([1, 256, 126, 360])
# 5 torch.Size([1, 256, 126, 360])

# [torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521])]

In [None]:
est_source.shape

In [None]:
F.pad(output_ii, (0, 90199 - T_est)).shape

In [None]:
(0, 90099 - T_est)

In [None]:
p = [F.pad(output_ii, (0, 90199 - T_est)), F.pad(output_ii, (0, 90199 - T_est))]

In [None]:
torch.stack(p).shape

In [None]:
outputs[0]

In [None]:
tf.pad(outputs[0], [[0,0], [0,0], [0,3]])

In [None]:
tf.shape(outputs[0])

In [None]:
targets = torch.randn(10, 2, 32000)

In [None]:
est_targets = torch.randn(10, 2, 32000)

In [None]:
targets = targets.unsqueeze(1)
est_targets = est_targets.unsqueeze(2)

In [None]:
targets.shape

In [None]:
est_targets.shape

In [None]:
pw_loss = (targets - est_targets) ** 2

In [None]:
mean_over = list(range(3, pw_loss.ndim))

In [None]:
pw_loss.mean(dim=mean_over).shape