In [1]:
import os

# Expose only GPU index 1 to the CUDA-backed libraries imported in later cells.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import sys

# Put the parent of the current working directory at the front of the import
# path so modules located one level above this notebook can be imported.
SOURCE_DIR = os.path.dirname(os.getcwd())
sys.path.insert(0, SOURCE_DIR)

In [3]:
import malaya_speech
import tensorflow as tf

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [6]:
# Switch TensorFlow to eager execution so the ops in the cells below return concrete values.
tf.compat.v1.enable_eager_execution()

In [7]:
# Random PyTorch tensor with a size-6 axis at position 1, used to try out GroupNorm below.
x = torch.randn(20, 6, 10, 10)

In [8]:
# GroupNorm with 3 groups over 6 channels; the displayed shape matches the input shape.
nn.GroupNorm(3, 6)(x).shape

torch.Size([20, 6, 10, 10])

In [9]:
# Random TensorFlow tensor with the size-6 axis last, used for the contrib group_norm check below.
x_tf = tf.random.normal(shape = (20, 10, 10, 6))

In [10]:
# contrib group_norm with 3 groups; the displayed shape matches the input shape.
# The recorded output warns that tf.contrib is not included in TensorFlow 2.0.
tf.contrib.layers.group_norm(x_tf, groups = 3).shape

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



TensorShape([Dimension(20), Dimension(10), Dimension(10), Dimension(6)])

In [11]:
# Sample rate used by the hyper-parameter cell further down.
sr = 8000
# NOTE(review): `sr` is not passed to `load` and the second return value is
# discarded - confirm the waveform really is sampled at `sr` Hz.
y, _ = malaya_speech.load('../speech/example-speaker/husein-zolkepli.wav')
# Add a leading axis of size 1 and cast to float32.
y = np.expand_dims(y, 0).astype(np.float32)
y.shape

(1, 90090)

In [12]:
# Wrap the same NumPy waveform once as a PyTorch tensor and once as a TensorFlow tensor.
y_pt = torch.from_numpy(y)
y_tf = tf.convert_to_tensor(y)

In [13]:
# Display the PyTorch copy of the waveform.
y_pt

tensor([[ 9.6471e-05, -2.3510e-05, -4.9451e-05,  ..., -1.4349e-04,
         -9.6471e-05,  0.0000e+00]])

In [14]:
# Hyper-parameters shared by the PyTorch and TensorFlow model definitions below.
N = 128  # number of encoder filters / separator feature size
L = 8  # encoder kernel size (the encoder stride is L // 2)
H = 128  # LSTM hidden size
R = 6  # copied into `layer` below
C = 2  # copied into `num_spk` below
input_normalize = False
# NOTE(review): `sample_rate` is not read by the derived values below; they use `sr` instead.
sample_rate = 8000
segment = 4
context_len = 2 * sr / 1000
context = int(sr * context_len / 1000)
layer = R
filter_dim = context * 2 + 1
num_spk = C
segment_size = int(np.sqrt(2 * sr * segment / (L/2)))
# Display the two derived sizes.
segment_size, filter_dim

(126, 257)

In [15]:
class Encoder_PT(nn.Module):
    """Convolutional front end: turns a batch of waveforms into non-negative
    feature frames with a single bias-free strided 1-D convolution."""

    def __init__(self, L, N):
        super().__init__()
        self.L = L
        self.N = N
        # A hop of L // 2 makes consecutive windows overlap by half.
        self.conv = nn.Conv1d(
            in_channels=1,
            out_channels=N,
            kernel_size=L,
            stride=L // 2,
            bias=False,
        )

    def forward(self, mixture):
        """Insert a channel axis at position 1, convolve, and apply ReLU."""
        with_channel = mixture.unsqueeze(1)
        return torch.relu(self.conv(with_channel))
    
class MulCatBlock_PT(nn.Module):
    """Gated recurrent block.

    Two LSTMs read the same sequence; each is projected back to
    ``input_size`` features, the two projections are multiplied element-wise,
    the product is concatenated with the block input along the feature axis,
    and a final linear layer maps the result back to ``input_size`` features.
    """

    def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1
        lstm_width = hidden_size * self.num_direction

        # Main branch.
        self.rnn = nn.LSTM(input_size, hidden_size, 1, dropout=dropout,
                           batch_first=True, bidirectional=bidirectional)
        self.rnn_proj = nn.Linear(lstm_width, input_size)

        # Gate branch.
        self.gate_rnn = nn.LSTM(input_size, hidden_size, num_layers=1,
                                batch_first=True, dropout=dropout, bidirectional=bidirectional)
        self.gate_rnn_proj = nn.Linear(lstm_width, input_size)

        # Maps [gated output, block input] back to input_size features.
        self.block_projection = nn.Linear(input_size * 2, input_size)

    @staticmethod
    def _project(linear, tensor, out_shape):
        """Apply ``linear`` to the last axis of a 3-D tensor via a 2-D view."""
        flat = tensor.contiguous().view(-1, tensor.shape[2])
        return linear(flat).view(out_shape)

    def forward(self, input):
        """Return a tensor with the same shape as ``input``."""
        target_shape = input.shape

        main_branch, _ = self.rnn(input)
        main_branch = self._project(
            self.rnn_proj, main_branch, target_shape).contiguous()

        gate_branch, _ = self.gate_rnn(input)
        gate_branch = self._project(
            self.gate_rnn_proj, gate_branch, target_shape).contiguous()

        # Gate, then append the untouched input along the feature axis.
        merged = torch.cat([main_branch * gate_branch, input], 2)
        return self._project(self.block_projection, merged, target_shape)
        
class ByPass_PT(nn.Module):
    """Parameter-free identity module."""

    def forward(self, input):
        """Hand ``input`` back untouched."""
        return input


class DPMulCat_PT(nn.Module):
    """Stack of dual-path MulCat layers.

    The input is a 4-D tensor (batch, features, d1, d2).  Every layer runs one
    MulCatBlock_PT over sequences taken along d1 and one over sequences taken
    along d2, each followed by an optional normalization and a residual
    addition.  After every layer the shared PReLU + 1x1 convolution head is
    applied, so ``forward`` returns one tensor per layer.
    """

    def __init__(self, input_size, hidden_size, output_size, num_spk,
                 dropout=0, num_layers=1, bidirectional=True, input_normalize=False):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.in_norm = input_normalize
        self.num_layers = num_layers

        self.rows_grnn = nn.ModuleList()
        self.cols_grnn = nn.ModuleList()
        self.rows_normalization = nn.ModuleList()
        self.cols_normalization = nn.ModuleList()

        # Build the dual-path pipeline, one row/column pair per layer.
        for _ in range(num_layers):
            self.rows_grnn.append(MulCatBlock_PT(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.cols_grnn.append(MulCatBlock_PT(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.rows_normalization.append(self._make_norm())
            self.cols_normalization.append(self._make_norm())

        # Output head shared by all layers.
        self.output = nn.Sequential(
            nn.PReLU(), nn.Conv2d(input_size, output_size * num_spk, 1))

    def _make_norm(self):
        """Return a single-group GroupNorm, or an identity module when
        normalization is disabled."""
        if self.in_norm:
            return nn.GroupNorm(1, self.input_size, eps=1e-8)
        return ByPass_PT()

    def forward(self, input):
        """Return a list holding the head output after each layer."""
        batch_size, _, d1, d2 = input.shape
        running = input
        per_layer = []
        for i in range(self.num_layers):
            # Row pass: sequences of length d1, one per (batch, d2) pair.
            row_input = running.permute(0, 3, 2, 1).contiguous().view(
                batch_size * d2, d1, -1)
            row_output = self.rows_grnn[i](row_input).view(batch_size, d2, d1, -1)
            row_output = self.rows_normalization[i](
                row_output.permute(0, 3, 2, 1).contiguous())
            running = running + row_output  # skip connection

            print(i, row_input.shape, row_output.shape, running.shape)

            # Column pass: sequences of length d2, one per (batch, d1) pair.
            col_input = running.permute(0, 2, 3, 1).contiguous().view(
                batch_size * d1, d2, -1)
            col_output = self.cols_grnn[i](col_input).view(batch_size, d1, d2, -1)
            col_output = self.cols_normalization[i](
                col_output.permute(0, 3, 1, 2).contiguous()).contiguous()
            running = running + col_output  # skip connection

            print(i, col_input.shape, col_output.shape, running.shape)

            per_layer.append(self.output(running))
        return per_layer

        
class Separator_PT(nn.Module):
    """Chunk the encoder features, run the dual-path network, and merge back.

    ``forward`` takes features shaped (B, N, T), splits them into half
    overlapping segments of ``segment_size`` frames, feeds the 4-D segment
    tensor to DPMulCat_PT, and overlap-adds each per-layer output back into
    a (B, channels, T) tensor.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                 layer=4, segment_size=100, input_normalize=False, bidirectional=True):
        super(Separator_PT, self).__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk
        self.input_normalize = input_normalize

        self.rnn_model = DPMulCat_PT(self.feature_dim, self.hidden_dim,
                                     self.feature_dim, self.num_spk, num_layers=layer,
                                     bidirectional=bidirectional, input_normalize=input_normalize)

    # ======================================= #
    # The following code block was borrowed and modified from https://github.com/yluo42/TAC
    # ================ BEGIN ================ #
    def pad_segment(self, input, segment_size):
        """Zero-pad the time axis so it can be cut into half-overlapping segments.

        Args:
            input: features shaped (B, N, T).
            segment_size: number of frames per segment.

        Returns:
            (padded, rest): ``padded`` has ``segment_size // 2`` zeros in
            front of the signal and ``rest + segment_size // 2`` zeros after
            it; ``rest`` is the amount of tail padding that
            ``merge_chuncks`` has to strip again.
        """
        seq_len = input.shape[2]
        segment_stride = segment_size // 2
        rest = segment_size - (segment_stride + seq_len %
                               segment_size) % segment_size
        # F.pad creates the zeros with the dtype and device of `input`,
        # replacing the deprecated Variable(...).type(input.type()) idiom.
        input = F.pad(input, (segment_stride, rest + segment_stride))
        return input, rest

    def create_chuncks(self, input, segment_size):
        """Split (B, N, T) features into half-overlapping segments.

        Returns:
            (segments, rest): ``segments`` is shaped
            (B, N, segment_size, n_segments); ``rest`` comes from
            ``pad_segment``.
        """
        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        # Two non-overlapping tilings, offset from each other by half a segment.
        segments1 = input[:, :, :-segment_stride].contiguous().view(
            batch_size, dim, -1, segment_size)
        segments2 = input[:, :, segment_stride:].contiguous().view(
            batch_size, dim, -1, segment_size)
        # Interleave the two tilings along the segment axis.
        segments = torch.cat([segments1, segments2], 3).view(
            batch_size, dim, -1, segment_size).transpose(2, 3)
        return segments.contiguous(), rest

    def merge_chuncks(self, input, rest):
        """Overlap-add segments back into a (B, N, T) tensor.

        Args:
            input: tensor shaped (B, N, segment_size, n_segments).
            rest: tail padding reported by ``pad_segment``.
        """
        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose(2, 3).contiguous().view(
            batch_size, dim, -1, segment_size*2)  # B, N, K, L

        # Undo the interleaving and drop the half-segment of front/back padding.
        input1 = input[:, :, :, :segment_size].contiguous().view(
            batch_size, dim, -1)[:, :, segment_stride:]
        input2 = input[:, :, :, segment_size:].contiguous().view(
            batch_size, dim, -1)[:, :, :-segment_stride]

        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]
        return output.contiguous()

    def forward(self, input):
        """Return one merged (B, channels, T) tensor per DPMulCat_PT layer."""
        # create chunks
        enc_segments, enc_rest = self.create_chuncks(
            input, self.segment_size)
        output_all = self.rnn_model(enc_segments)

        output_all_wav = []
        for ii in range(len(output_all)):
            output_ii = self.merge_chuncks(
                output_all[ii], enc_rest)
            print(ii, output_all[ii].shape)
            output_all_wav.append(output_ii)
        return output_all_wav

In [16]:
# encoder_pt = Encoder_PT(L, N)
# separator_pt = Separator_PT(filter_dim + N, N, H,
#                       filter_dim, num_spk, layer, segment_size, input_normalize)
# e_pt = encoder_pt(y_pt)
# o_pt = separator_pt(e_pt)
# # l.shape, r

In [17]:
# [i.shape for i in o_pt]

In [18]:
def shape_list(x):
    """Return the shape of ``x`` as a list, preferring static sizes.

    Axes whose size is statically known yield a Python int; the remaining
    axes yield the matching entry of ``tf.shape(x)``.
    """
    known_sizes = x.shape.as_list()
    runtime_shape = tf.shape(x)
    dims = []
    for axis, size in enumerate(known_sizes):
        dims.append(runtime_shape[axis] if size is None else size)
    return dims

class Encoder(tf.keras.layers.Layer):
    """Keras version of the waveform encoder: a bias-free strided Conv1D
    followed by ReLU."""

    def __init__(self, L, N, **kwargs):
        super().__init__(name = 'Encoder', **kwargs)
        # A hop of L // 2 makes consecutive windows overlap by half.
        self.conv = tf.keras.layers.Conv1D(
            filters = N,
            kernel_size = L,
            strides = L // 2,
            use_bias = False,
        )

    def call(self, mixture):
        """Append a trailing channel axis, convolve, and apply ReLU."""
        with_channel = tf.expand_dims(mixture, -1)
        return tf.nn.relu(self.conv(with_channel))
    
class Decoder(tf.keras.layers.Layer):
    """Turn per-speaker feature frames back into waveforms.

    The feature axis is average-pooled in windows of ``L`` and the pooled
    frames are overlap-added with a hop of ``L // 2``.
    """

    def __init__(self, L, **kwargs):
        super(Decoder, self).__init__(name = 'Decoder', **kwargs)
        self.L = L
        # Pool over a full window of L features.  A 1x1 window with stride L
        # would only keep every L-th value instead of averaging, and the
        # window must follow L rather than a hard-coded 8.
        self.pool = tf.keras.layers.AveragePooling2D(
            pool_size = (1, L), strides = (1, L), padding = 'same'
        )

    def call(self, est_source):
        """Decode ``est_source`` shaped (B, K, C, N) into (B, C, samples)."""
        # torch.Size([1, 256, 22521])
        # pt (1, 2, 128, 22521), tf (1, 22521, 2, 128)
        # (B, K, C, N) -> (B, K, N, C): pooling runs over the N axis.
        est_source = tf.transpose(est_source, (0, 1, 3, 2))
        est_source = self.pool(est_source)
        # (B, K, N', C) -> (B, C, K, N'): frames on axis -2, frame samples on axis -1.
        est_source = tf.transpose(est_source, (0, 3, 1, 2))
        est_source = tf.signal.overlap_and_add(est_source, self.L // 2)

        return est_source
    
class MulCatBlock(tf.keras.layers.Layer):
    """Keras version of the gated recurrent block.

    Two LSTM branches read the same sequence, each is projected to
    ``input_size`` features, the projections are multiplied, the product is
    concatenated with the block input on axis 2, and a final Dense layer maps
    the result back to ``input_size`` features.  ``dropout`` is accepted but
    not used by this layer.
    """

    def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False, **kwargs):
        super().__init__(name = 'MulCatBlock', **kwargs)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

        def make_recurrent():
            """Build one sequence-returning LSTM, wrapped when bidirectional."""
            lstm = tf.keras.layers.LSTM(hidden_size, return_sequences = True)
            if bidirectional:
                return tf.keras.layers.Bidirectional(lstm)
            return lstm

        self.rnn = make_recurrent()
        self.gate_rnn = make_recurrent()

        self.rnn_proj = tf.keras.layers.Dense(input_size)
        self.gate_rnn_proj = tf.keras.layers.Dense(input_size)
        self.block_projection = tf.keras.layers.Dense(input_size)

    def call(self, input):
        """Return a tensor whose last axis has ``input_size`` features."""
        main_branch = self.rnn_proj(self.rnn(input))
        gate_branch = self.gate_rnn_proj(self.gate_rnn(input))
        # Gate, then append the untouched input along the feature axis.
        merged = tf.concat([main_branch * gate_branch, input], 2)
        return self.block_projection(merged)
    
class GroupNorm(tf.keras.layers.Layer):
    """Single-group normalization with a learned per-channel scale and offset.

    For every example the mean and variance are taken over all non-batch
    axes, which is what group normalization with one group computes.  The
    affine parameters are owned by this layer (created once in ``build``) and
    indexed by ``channels_axis``.

    Args:
        epsilon: value added to the variance before the reciprocal square root.
        channels_axis: axis holding the channels; defaults to 1 because the
            DPMulCat layer in this notebook passes (batch, features, d1, d2).
    """

    def __init__(self, epsilon = 1e-8, channels_axis = 1, **kwargs):
        # The class passed to super() must be GroupNorm itself, otherwise
        # instantiation raises a TypeError.
        super(GroupNorm, self).__init__(name = 'GroupNorm', **kwargs)
        self.epsilon = epsilon
        self.channels_axis = channels_axis

    def build(self, input_shape):
        dims = tf.TensorShape(input_shape).as_list()
        channels = dims[self.channels_axis]
        if channels is None:
            raise ValueError(
                'GroupNorm needs a statically known size on the channels axis'
            )
        # Shape used to broadcast the per-channel parameters against the input.
        self._param_shape = [1] * len(dims)
        self._param_shape[self.channels_axis] = channels
        self._reduce_axes = list(range(1, len(dims)))
        self.gamma = self.add_weight(
            name = 'gamma', shape = (channels,), initializer = 'ones', trainable = True
        )
        self.beta = self.add_weight(
            name = 'beta', shape = (channels,), initializer = 'zeros', trainable = True
        )
        super(GroupNorm, self).build(input_shape)

    def call(self, input):
        """Normalize ``input`` (not a notebook global) and apply the affine."""
        mean, variance = tf.nn.moments(input, axes = self._reduce_axes, keepdims = True)
        normalized = (input - mean) * tf.math.rsqrt(variance + self.epsilon)
        gamma = tf.reshape(self.gamma, self._param_shape)
        beta = tf.reshape(self.beta, self._param_shape)
        return normalized * gamma + beta
    
class ByPass(tf.keras.layers.Layer):
    """Identity layer without weights."""

    def __init__(self, **kwargs):
        super().__init__(name = 'ByPass', **kwargs)

    def call(self, input):
        """Hand ``input`` back untouched."""
        return input
        
class DPMulCat(tf.keras.layers.Layer):
    """Keras version of the dual-path MulCat stack.

    ``call`` takes the segment tensor produced by ``Separator.create_chuncks``
    and moves its axis 2 to position 1, giving (batch, d3, d1, d2) with the
    features on d3.  Every layer runs one MulCatBlock over sequences along d1
    and one over sequences along d2, each followed by an optional
    normalization and a residual addition.  After every layer the shared
    PReLU + 1x1 convolution head is applied to the channels-last view, so one
    tensor per layer is returned.
    """

    def __init__(self, input_size, hidden_size, output_size, num_spk,
                 dropout=0, num_layers=1, bidirectional=True, input_normalize=False, **kwargs):
        super(DPMulCat, self).__init__(name = 'DPMulCat', **kwargs)

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.in_norm = input_normalize
        self.num_layers = num_layers

        self.rows_grnn = []
        self.cols_grnn = []
        self.rows_normalization = []
        self.cols_normalization = []

        # ByPass stands in for GroupNorm when normalization is disabled.
        norm_cls = GroupNorm if self.in_norm else ByPass
        for _ in range(num_layers):
            self.rows_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.cols_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.rows_normalization.append(norm_cls())
            self.cols_normalization.append(norm_cls())

        self.outputs = tf.keras.Sequential()
        # Share the slope over all non-batch axes.  Without shared_axes Keras
        # learns one alpha per element of the first input's non-batch shape,
        # which ties the layer to one segment count; nn.PReLU() in the
        # PyTorch model uses a single shared slope.
        self.outputs.add(tf.keras.layers.PReLU(shared_axes = [1, 2, 3]))
        self.outputs.add(tf.keras.layers.Conv2D(output_size * num_spk, 1, padding = 'SAME'))

    def call(self, input):
        """Return a list with the channels-last head output after each layer."""
        # original, [b, d3, d1, d2]
        input = tf.transpose(input, (0, 2, 1, 3))
        batch_size, d3, d1, d2 = shape_list(input)
        output = input
        output_all = []
        for i in range(self.num_layers):
            # Row pass: sequences of length d1, one per (batch, d2) pair.
            row_input = tf.transpose(output, [0, 3, 2, 1])
            row_input = tf.reshape(row_input, (batch_size * d2, d1, d3))
            row_output = self.rows_grnn[i](row_input)
            row_output = tf.reshape(row_output, (batch_size, d2, d1, d3))
            row_output = tf.transpose(row_output, (0, 3, 2, 1))
            row_output = self.rows_normalization[i](row_output)
            output = output + row_output  # skip connection

            print(i, row_input.shape, row_output.shape, output.shape)

            # Column pass: sequences of length d2, one per (batch, d1) pair.
            col_input = tf.transpose(output, [0, 2, 3, 1])
            col_input = tf.reshape(col_input, (batch_size * d1, d2, d3))
            col_output = self.cols_grnn[i](col_input)
            col_output = tf.reshape(col_output, (batch_size, d1, d2, d3))
            col_output = tf.transpose(col_output, (0, 3, 1, 2))
            col_output = self.cols_normalization[i](col_output)

            output = output + col_output  # skip connection

            print(i, col_input.shape, col_output.shape, output.shape)

            # torch.Size([1, 128, 126, 360]
            # The Keras head is channels-last, so move d3 to the end first.
            output_i = self.outputs(tf.transpose(output, [0, 2, 3, 1]))
            output_all.append(output_i)
        return output_all
    
class Separator(tf.keras.layers.Layer):
    """Keras version of the chunk / dual-path / merge separator.

    ``call`` takes encoder features shaped (B, T, N), splits them into half
    overlapping segments of ``segment_size`` frames, runs DPMulCat on the
    segment tensor, and overlap-adds every per-layer output back into a
    (B, T, channels) tensor.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                 layer=4, segment_size=100, input_normalize=False, bidirectional=True, **kwargs):
        super(Separator, self).__init__(name = 'Separator', **kwargs)
        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk
        self.input_normalize = input_normalize

        self.rnn_model = DPMulCat(self.feature_dim, self.hidden_dim,
                                  self.feature_dim, self.num_spk, num_layers=layer,
                                  bidirectional=bidirectional, input_normalize=input_normalize)

    def pad_segment(self, input, segment_size):
        """Zero-pad the time axis so it can be cut into half-overlapping segments.

        Args:
            input: features shaped (B, T, N).
            segment_size: number of frames per segment.

        Returns:
            (padded, rest): ``padded`` has ``segment_size // 2`` zeros in
            front of the signal and ``rest + segment_size // 2`` zeros after
            it; ``rest`` is the tail padding ``merge_chuncks`` strips again.
        """
        _, seq_len, _ = shape_list(input)
        segment_stride = segment_size // 2
        rest = segment_size - (segment_stride + seq_len %
                               segment_size) % segment_size
        # `rest` always lies in [1, segment_size], so the tail padding is
        # unconditional.  tf.pad yields plain zeros of the input dtype; no
        # tf.Variable is created per call.
        input = tf.pad(
            input, [[0, 0], [segment_stride, rest + segment_stride], [0, 0]]
        )
        return input, rest

    def create_chuncks(self, input, segment_size):
        """Split (B, T, N) features into half-overlapping segments.

        Returns:
            (segments, rest): ``segments`` is shaped
            (B, segment_size, N, n_segments); ``rest`` comes from
            ``pad_segment``.
        """
        input, rest = self.pad_segment(input, segment_size)
        batch_size, seq_len, dim = shape_list(input)
        segment_stride = segment_size // 2

        # Time must be the innermost axis before reshaping into segments;
        # reshaping the time-major tensor directly would mix feature values
        # into the time axis and would not be undone by merge_chuncks.
        features_first = tf.transpose(input, perm = [0, 2, 1])  # (B, N, T)

        # Two non-overlapping tilings, offset from each other by half a segment.
        segments1 = tf.reshape(
            features_first[:, :, :-segment_stride], (batch_size, dim, -1, segment_size)
        )
        segments2 = tf.reshape(
            features_first[:, :, segment_stride:], (batch_size, dim, -1, segment_size)
        )
        # Interleave the two tilings along the segment axis.
        segments = tf.concat([segments1, segments2], axis = 3)
        segments = tf.reshape(segments, (batch_size, dim, -1, segment_size))
        # (B, N, n_segments, S) -> (B, S, N, n_segments)
        segments = tf.transpose(segments, perm = [0, 3, 1, 2])
        return segments, rest

    def merge_chuncks(self, input, rest):
        """Overlap-add segments back into a (B, T, channels) tensor.

        Args:
            input: tensor shaped (B, segment_size, n_segments, channels).
            rest: tail padding reported by ``pad_segment``.
        """
        # original, [b, dim, segment_size, _]
        # torch.Size([1, 256, 126, 360])
        # (1, 126, 360, 256)
        batch_size, segment_size, _, dim = shape_list(input)
        segment_stride = segment_size // 2
        # original, [b, dim, _, segment_size]
        input = tf.transpose(input, perm = [0, 3, 2, 1])
        input = tf.reshape(input, (batch_size, dim, -1, segment_size * 2))

        # Undo the interleaving and drop the half-segment of front/back padding.
        input1 = tf.reshape(input[:, :, :, :segment_size], (batch_size, dim, -1))[:, :, segment_stride:]
        input2 = tf.reshape(input[:, :, :, segment_size:], (batch_size, dim, -1))[:, :, :-segment_stride]

        output = input1 + input2
        # An explicit end index also handles rest == 0 and tensor-valued rest.
        output = output[:, :, :tf.shape(output)[2] - rest]

        return tf.transpose(output, perm = [0, 2, 1])

    def call(self, input):
        """Return one merged (B, T, channels) tensor per DPMulCat layer."""
        # create chunks
        enc_segments, enc_rest = self.create_chuncks(
            input, self.segment_size)
        output_all = self.rnn_model(enc_segments)
        output_all_wav = []
        for ii in range(len(output_all)):
            print(ii, output_all[ii].shape)
            output_ii = self.merge_chuncks(
                output_all[ii], enc_rest)
            output_all_wav.append(output_ii)
        return output_all_wav
    
class Model(tf.keras.Model):
    """Encoder -> Separator -> Decoder pipeline.

    Args:
        N: number of encoder filters (also the separator feature size).
        L: encoder kernel size; the encoder stride is ``L // 2``.
        H: hidden size handed to the separator.
        R: default for ``layer``.
        C: default for ``num_spk``.
        input_normalize: enable normalization inside the separator.
        sample_rate: used to derive ``context_len``, ``context`` and
            ``segment_size`` when those are not given.
        segment: used to derive ``segment_size`` when it is not given.
        context_len, context, layer, filter_dim, num_spk, segment_size:
            optional overrides.  When left as ``None`` each one is derived
            from the arguments above, so the values follow what the caller
            passes instead of whatever notebook globals existed when the
            class was defined.  With all defaults the derived values are
            ``layer=6``, ``num_spk=2``, ``filter_dim=257``,
            ``segment_size=126``.
    """

    def __init__(
        self,
        N = 128,
        L = 8,
        H = 128,
        R = 6,
        C = 2,
        input_normalize = False,
        sample_rate = 8000,
        segment = 4,
        context_len = None,
        context = None,
        layer = None,
        filter_dim = None,
        num_spk = None,
        segment_size = None,
        **kwargs
    ):
        super(Model, self).__init__(name = 'swave', **kwargs)
        if context_len is None:
            context_len = 2 * sample_rate / 1000
        if context is None:
            context = int(sample_rate * context_len / 1000)
        if layer is None:
            layer = R
        if filter_dim is None:
            filter_dim = context * 2 + 1
        if num_spk is None:
            num_spk = C
        if segment_size is None:
            segment_size = int(np.sqrt(2 * sample_rate * segment / (L / 2)))

        self.C = C
        self.N = N
        # The separator emits N * num_spk channels, so the reshape in `call`
        # has to use num_spk as well.
        self.num_spk = num_spk
        self.encoder = Encoder(L, N)
        self.separator = Separator(
            filter_dim + N,
            N,
            H,
            filter_dim,
            num_spk,
            layer,
            segment_size,
            input_normalize,
        )
        self.decoder = Decoder(L)

    def call(self, mixture):
        """Return one (B, num_spk, T) waveform estimate per separator layer.

        Args:
            mixture: waveform batch shaped (B, T).
        """
        mixture_w = self.encoder(mixture)
        output_all = self.separator(mixture_w)
        mixture_shape = tf.shape(mixture)
        batch_size = mixture_shape[0]
        T_mix = mixture_shape[1]
        T_mix_w = tf.shape(mixture_w)[1]
        # generate wav after each RNN block and optimize the loss
        outputs = []
        for layer_output in output_all:
            output_ii = tf.reshape(
                layer_output, (batch_size, T_mix_w, self.num_spk, self.N)
            )
            output_ii = self.decoder(output_ii)
            # Match the input length: zero-pad a short estimate, crop a long one.
            missing = tf.maximum(T_mix - tf.shape(output_ii)[2], 0)
            output_ii = tf.pad(output_ii, [[0, 0], [0, 0], [0, missing]])
            output_ii = output_ii[:, :, :T_mix]
            outputs.append(output_ii)
        return outputs

In [19]:
# Build the TensorFlow model with its default hyper-parameters.
model = Model()

# class Separator(nn.Module):
#     def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
#                  layer=4, segment_size=100, input_normalize=False, bidirectional=True):

# class MulCatBlock(nn.Module):

#     def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False):
#         super(MulCatBlock, self).__init__()
        
# self.rnn_model = DPMulCat(self.feature_dim, self.hidden_dim,
#                           self.feature_dim, self.num_spk, num_layers=layer, bidirectional=bidirectional, input_normalize=input_normalize)

In [20]:
# Check the shape of the TensorFlow input before running the model.
y_tf.shape

TensorShape([Dimension(1), Dimension(90090)])

In [21]:
# Forward pass; the model returns a list with one tensor per separator layer.
outputs = model(y_tf)

0 (360, 126, 128) (1, 128, 126, 360) (1, 128, 126, 360)
0 (126, 360, 128) (1, 128, 126, 360) (1, 128, 126, 360)
1 (360, 126, 128) (1, 128, 126, 360) (1, 128, 126, 360)
1 (126, 360, 128) (1, 128, 126, 360) (1, 128, 126, 360)
2 (360, 126, 128) (1, 128, 126, 360) (1, 128, 126, 360)
2 (126, 360, 128) (1, 128, 126, 360) (1, 128, 126, 360)
3 (360, 126, 128) (1, 128, 126, 360) (1, 128, 126, 360)
3 (126, 360, 128) (1, 128, 126, 360) (1, 128, 126, 360)
4 (360, 126, 128) (1, 128, 126, 360) (1, 128, 126, 360)
4 (126, 360, 128) (1, 128, 126, 360) (1, 128, 126, 360)
5 (360, 126, 128) (1, 128, 126, 360) (1, 128, 126, 360)
5 (126, 360, 128) (1, 128, 126, 360) (1, 128, 126, 360)
0 (1, 126, 360, 256)
1 (1, 126, 360, 256)
2 (1, 126, 360, 256)
3 (1, 126, 360, 256)
4 (1, 126, 360, 256)
5 (1, 126, 360, 256)
Instructions for updating:
Use keras.layers.AveragePooling2D instead.
Instructions for updating:
Please use `layer.__call__` method instead.


In [22]:
# Inspect the returned list of tensors.
outputs

[<tf.Tensor: id=989986, shape=(1, 2, 90090), dtype=float32, numpy=
 array([[[-1.1719704e-03, -4.6524021e-04, -9.3938681e-05, ...,
           1.4410544e-03, -3.5640300e-04, -4.6988868e-04],
         [-4.6890823e-04,  3.0910695e-04, -6.9982489e-05, ...,
          -3.9416447e-04,  3.4138164e-04, -6.2367937e-05]]], dtype=float32)>,
 <tf.Tensor: id=990095, shape=(1, 2, 90090), dtype=float32, numpy=
 array([[[-2.5863254e-03,  4.0950871e-04, -1.8142318e-04, ...,
           3.4108611e-03, -5.1781628e-04, -2.0579109e-03],
         [ 8.2968705e-05,  7.3135237e-04,  9.3379873e-05, ...,
           1.1433249e-03,  1.6798971e-03, -1.6701058e-03]]], dtype=float32)>,
 <tf.Tensor: id=990204, shape=(1, 2, 90090), dtype=float32, numpy=
 array([[[-2.7442961e-03,  2.0390499e-04, -9.5309957e-04, ...,
           1.3520857e-03,  2.5793747e-04, -4.6325871e-04],
         [ 3.0651595e-04,  3.5258654e-05, -1.7121775e-04, ...,
          -1.0525491e-03,  2.0909510e-03, -2.2979626e-03]]], dtype=float32)>,
 <tf.Tenso

In [None]:
# 0 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 0 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 1 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 1 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 2 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 2 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 3 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 3 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 4 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 4 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 5 torch.Size([360, 126, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])
# 5 torch.Size([126, 360, 128]) torch.Size([1, 128, 126, 360]) torch.Size([1, 128, 126, 360])

In [None]:
# 0 torch.Size([1, 256, 126, 360])
# 1 torch.Size([1, 256, 126, 360])
# 2 torch.Size([1, 256, 126, 360])
# 3 torch.Size([1, 256, 126, 360])
# 4 torch.Size([1, 256, 126, 360])
# 5 torch.Size([1, 256, 126, 360])

# [torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521]),
#  torch.Size([1, 256, 22521])]

In [None]:
# NOTE(review): `est_source` is not defined at top level in any cell shown in this notebook;
# this cell appears to rely on leftover kernel state - confirm or remove.
est_source.shape

In [None]:
# NOTE(review): `output_ii` and `T_est` are not defined at top level in any cell shown in this
# notebook; this cell appears to rely on leftover kernel state - confirm or remove.
F.pad(output_ii, (0, 90199 - T_est)).shape

In [None]:
# NOTE(review): this cell uses 90099 while the neighbouring cells use 90199 - confirm which
# length is intended. `T_est` is not defined at top level in any cell shown in this notebook.
(0, 90099 - T_est)

In [None]:
# Two identically padded copies in a list, stacked in the next cell.
# NOTE(review): depends on `output_ii` and `T_est`, which no cell shown here defines at top level.
p = [F.pad(output_ii, (0, 90199 - T_est)), F.pad(output_ii, (0, 90199 - T_est))]

In [None]:
# Shape after stacking the padded copies along a new leading axis.
torch.stack(p).shape

In [None]:
# Inspect the first tensor returned by the model.
outputs[0]

In [None]:
# Try tf.pad: append 3 zeros on the last axis and leave the first two axes unpadded.
tf.pad(outputs[0], [[0,0], [0,0], [0,3]])

In [None]:
# Shape of the first output as a tensor (the form used inside Model.call).
tf.shape(outputs[0])