In [1]:
# Enable the autoreload extension and reload all modules before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
from torch.nn import LSTM
import torchaudio
import math

In [3]:
# Network hyper-parameters (values attributed to the Facebook model).
# L is used further down as the encoder/decoder kernel size, with stride L/2.
L = 8

In [4]:
# Load a real audio clip to exercise the model.
# NOTE(review): absolute, machine-specific path; the notebook cannot run elsewhere
# until this is replaced by a configurable data directory.
audio_input, s_rate = torchaudio.load('/home/minds/Desktop/test_speech_mix/generated/speaker/id10009/7hpSiT9_gCE/0.wav', normalization=True)
# Add a leading batch dimension (the recorded output shows shape (1, 1, 64000)).
audio_input = audio_input.unsqueeze(0)
audio_input.shape

torch.Size([1, 1, 64000])

In [5]:
# Sanity check: mean and standard deviation of the loaded waveform.
audio_input.mean(), audio_input.std()

(tensor(-3.8114e-06), tensor(0.0430))

In [6]:
class BaseRNNBlock(nn.Module):
    '''
    Gated recurrent block: two bidirectional LSTMs read the same sequence,
    their outputs are multiplied element-wise, concatenated with the input
    and projected back to the input feature size.

    Parameter   -- size of the last (feature) dimension of the input
    hidden_size -- hidden size of each LSTM direction
    kwargs      -- extra keyword arguments forwarded to both LSTMs

    Input and output are (batch, sequence, Parameter) tensors
    (the LSTMs are built with batch_first=True).
    '''

    def __init__(self, Parameter=128, hidden_size=128, **kwargs):
        super().__init__()
        # Options shared by both recurrent branches.
        lstm_options = dict(input_size=Parameter, hidden_size=hidden_size,
                            batch_first=True, bidirectional=True)
        self.lstm1 = LSTM(**lstm_options, **kwargs)
        self.lstm2 = LSTM(**lstm_options, **kwargs)
        # Projection called P in the paper: (2 * hidden + input) -> input features.
        self.P = nn.Linear(2 * hidden_size + Parameter, Parameter)

    def forward(self, x):
        branch_a = self.lstm1(x)[0]
        branch_b = self.lstm2(x)[0]
        # Multiplicative combination of the two branches plus a bypass of the raw input.
        features = torch.cat([branch_a * branch_b, x], dim=-1)
        # The projection restores the original feature size.
        return self.P(features)

In [7]:
class RNNBlock(nn.Module):
    '''
    One odd block followed by one even block, each wrapped in a residual
    (skip) connection.

    K           -- feature size seen by the odd block (last dim of the input)
    R           -- feature size seen by the even block (dim 1 of the input)
    hidden_size -- hidden size of the LSTMs inside each block
    kwargs      -- forwarded to both BaseRNNBlocks

    forward takes an N x R x K tensor and returns a new N x R x K tensor.
    The input tensor is left unmodified, so callers that stack several
    blocks must feed the returned tensor to the next block.
    '''

    def __init__(self, K, R, hidden_size=128, **kwargs):
        super(RNNBlock, self).__init__()
        self.oddBlock = BaseRNNBlock(Parameter=K, hidden_size=hidden_size, **kwargs)
        self.evenBlock = BaseRNNBlock(Parameter=R, hidden_size=hidden_size, **kwargs)

    def forward(self, x):
        # Skip connection, written out-of-place: an in-place `+=` would overwrite
        # the caller's tensor and a value autograd still needs for the backward pass.
        x = x + self.oddBlock(x)
        # Tensor NxRxK -> NxKxR so the even block runs along the other axis.
        x = torch.transpose(x, 1, -1)
        # Second skip connection, again out-of-place.
        x = x + self.evenBlock(x)
        # Tensor NxKxR -> NxRxK
        return torch.transpose(x, 1, -1)

In [120]:
class Facebookmodel(nn.Module):
    '''
    Encoder -> chunking -> stacked RNNBlocks -> overlap-add -> decoder.

    n      -- number of encoder channels
    k      -- chunk length (in encoder frames); consecutive chunks overlap by k // 2
    r      -- number of chunks the network is built for
    c      -- channel multiplier of the 1x1 convolution `d`; only the first two
              groups of r channels are returned by split_channels
    l      -- encoder/decoder kernel size (the stride is l / 2)
    b      -- number of stacked RNNBlocks
    kwargs -- forwarded to every RNNBlock

    forward expects a single example (the batch dimension is squeezed away)
    whose encoded length fits into k + (k // 2) * (r - 1) frames after the
    front padding. It returns two lists (one per output channel) holding one
    decoded waveform per RNNBlock.
    '''

    def __init__(self, n, k, r, c=2, l=8, b=1, **kwargs):
        super(Facebookmodel, self).__init__()
        self.n = n
        self.k = k
        self.r = r
        self.c = c
        # Hop between consecutive chunks (50% overlap).
        self.hop = k // 2
        # Length of the zero-padded sequence that yields exactly r chunks.
        self.padded_length = k + self.hop * (r - 1)
        # Tail padding added by the latest chunk() call; decode() removes it again.
        self._tail_pad = None
        self.encoder = nn.Conv1d(1, n, l, int(l/2))
        # ModuleList (not a plain list) so the blocks' parameters are registered.
        self.rnnblocks = nn.ModuleList([RNNBlock(k, r, **kwargs) for _ in range(b)])
        self.d = nn.Conv1d(r, c*r, kernel_size=1)
        self.activation = torch.nn.PReLU(num_parameters=1, init=0.25)
        # Input channels must match the encoder's output channels.
        self.decoder = nn.ConvTranspose1d(n, 1, kernel_size=l, stride=int(l/2))

    def forward(self, x):
        encoded = self.encoder(x).squeeze(0)
        chunks = self.chunk(encoded)
        outps = list()
        for block in self.rnnblocks:
            # Chain the blocks explicitly: each block consumes the previous output.
            chunks = block(chunks)
            outps.append(self.d(self.activation(chunks)))

        s1, s2 = self.split_channels(outps)
        o1, o2 = self.apply_overlap_and_add(s1, s2)
        return self.decode(o1, o2)

    def split_channels(self, x):
        '''Split every (n, c*r, k) tensor in x into two (n, k, r) channel tensors.'''
        channel_1 = []
        channel_2 = []
        for o in x:
            # (n, c*r, k) -> (n, c, k, r): one group of r chunks per output channel.
            divided = o.unfold(-2, self.r, self.r)
            channel_1.append(divided[:, 0, ...])
            channel_2.append(divided[:, 1, ...])
        return channel_1, channel_2

    def chunk(self, x):
        '''Zero-pad the (n, T) encoder output and cut it into r overlapping chunks of length k.'''
        tail_pad = self.padded_length - self.hop - x.shape[-1]
        if tail_pad < 0:
            raise ValueError(
                'encoded length %d does not fit into %d chunks of length %d'
                % (x.shape[-1], self.r, self.k))
        self._tail_pad = tail_pad
        # F.pad keeps the dtype and device of x.
        x = torch.nn.functional.pad(x, (self.hop, tail_pad))
        # unfold returns a view with overlapping memory; materialize it so that
        # downstream in-place updates cannot write to shared elements.
        return x.unfold(-1, self.k, self.hop).contiguous()

    def decode(self, c1, c2):
        '''Strip the chunking padding and decode every tensor of both channels.'''
        return self._decode_channel(c1), self._decode_channel(c2)

    def _decode_channel(self, frames):
        # Removes the padding added by chunk() and runs the decoder on each tensor.
        if self._tail_pad is None:
            raise RuntimeError('chunk() must be called before decode()')
        restored = list()
        for a in frames:
            # Explicit stop index so that a tail padding of 0 keeps everything.
            stop = a.shape[-1] - self._tail_pad
            a = a[:, self.hop:stop].unsqueeze(0)
            restored.append(self.decoder(a))
        return restored

    def apply_overlap_and_add(self, channel_1, channel_2):
        '''Apply overlap_and_add to every tensor of both channel lists.'''
        overlapped_1 = [self.overlap_and_add(el) for el in channel_1]
        overlapped_2 = [self.overlap_and_add(el) for el in channel_2]
        return overlapped_1, overlapped_2

    def overlap_and_add(self, x):
        '''Fold an (n, k, r) tensor of chunks back into an (n, padded_length) sequence.'''
        # fold sums the overlapping parts of neighbouring chunks.
        result = torch.nn.functional.fold(
            x, (1, self.padded_length), kernel_size=(1, self.k), stride=(1, self.hop))
        return result.squeeze(1).squeeze(1)

In [121]:
# Positional arguments are n=64, k=178, r=181 (see Facebookmodel.__init__).
m = Facebookmodel(64, 178, 181)

In [122]:
# Run the model on the loaded clip and keep the first element of the returned pair.
r = m(audio_input)[0]

torch.Size([64, 16198])
torch.Size([1, 64, 15999])


In [30]:
# Scratch check: cut dim -2 into windows of 181 and look at the first window's shape.
# NOTE(review): executed out of order (In [30]); `r` is whatever was bound at that time.
r.unfold(-2, 181, 181)[:,0,...].shape

torch.Size([64, 178, 181])

In [32]:
# Stand-alone encoder for experiments: 1 -> 64 channels, kernel L, stride L/2.
e = nn.Conv1d(1, 64, L, int(L/2))

In [33]:
# Encode the clip and drop the leading batch dimension.
t = e(audio_input).squeeze(0)

In [None]:
t

In [36]:
# Transposed convolution mirroring the encoder: 64 -> 1 channel, kernel L, stride L/2.
dec = nn.ConvTranspose1d(64, 1, kernel_size=L, stride=int(L/2))

In [39]:
# Decoding the encoded clip gives back the input length (64000 samples in the recorded output).
dec(t.unsqueeze(0)).shape

torch.Size([1, 1, 64000])

In [34]:
t.shape

torch.Size([64, 15999])

In [42]:
t

tensor([[ 0.2963,  0.2988,  0.2976,  ...,  0.3097,  0.2949,  0.2941],
        [ 0.0497,  0.0493,  0.0543,  ...,  0.0472,  0.0316,  0.0453],
        [-0.2260, -0.2260, -0.2255,  ..., -0.2244, -0.2102, -0.2111],
        ...,
        [-0.3466, -0.3421, -0.3494,  ..., -0.3353, -0.3397, -0.3497],
        [-0.1494, -0.1512, -0.1506,  ..., -0.1456, -0.1427, -0.1534],
        [-0.2383, -0.2415, -0.2306,  ..., -0.2481, -0.2472, -0.2326]],
       grad_fn=<SqueezeBackward1>)

In [40]:
# Zero-pad the encoded frames: 110 frames appended at the end ...
t0 = torch.cat((t, torch.zeros((64, 110))), dim=-1)
# ... and 89 frames prepended at the front (16198 frames in total, see output).
t0 = torch.cat((torch.zeros((64, 89)), t0), -1)
t0.shape

torch.Size([64, 16198])

In [45]:
# Slicing off the 89 leading and 110 trailing frames undoes the padding.
t0[:,89:-110]

tensor([[ 0.2963,  0.2988,  0.2976,  ...,  0.3097,  0.2949,  0.2941],
        [ 0.0497,  0.0493,  0.0543,  ...,  0.0472,  0.0316,  0.0453],
        [-0.2260, -0.2260, -0.2255,  ..., -0.2244, -0.2102, -0.2111],
        ...,
        [-0.3466, -0.3421, -0.3494,  ..., -0.3353, -0.3397, -0.3497],
        [-0.1494, -0.1512, -0.1506,  ..., -0.1456, -0.1427, -0.1534],
        [-0.2383, -0.2415, -0.2306,  ..., -0.2481, -0.2472, -0.2326]],
       grad_fn=<SliceBackward>)

In [58]:
# Cut the padded sequence into windows of 178 frames with hop 89
# (181 overlapping chunks in the recorded output). Note that this rebinds `t`.
t = t0.unfold(-1,178,89)
t.shape

torch.Size([64, 181, 178])

In [59]:
# NOTE(review): the stride here (178) differs from the hop used by unfold above (89).
# With stride 178 Fold expects 91 blocks rather than the 181 produced, which is the
# RuntimeError recorded for the next cell; the later nnf.fold call uses stride 89.
fold = nn.Fold(output_size=(1, 16198), kernel_size=(1, 178), stride=(1, 178))

In [60]:
# Put the chunk axis last before folding, then drop the two singleton dims.
# Raises RuntimeError with the Fold configured above (see recorded output).
res = fold(torch.transpose(t, -1,-2)).squeeze(1).squeeze(1)

RuntimeError: Given output_size=(1, 16198), kernel_size=(1, 178), dilation=(1, 1), padding=(0, 0), stride=(1, 178), expected size of input's dimension 2 to match the calculated number of sliding blocks 1 * 91 = 91, but got input.size(2)=181.

In [61]:
# Intended round-trip check against the padded tensor; fails with NameError
# because the previous cell raised before `res` was bound.
(t0 == res).all()

NameError: name 'res' is not defined

In [66]:
# NOTE(review): import placed mid-notebook; later cells that use `nnf` depend on this cell.
from torch.nn import functional as nnf
# Toy 2x5 image with leading batch and channel dimensions: shape (1, 1, 2, 5).
result = torch.tensor([[1,2,4,6.,7], [1,2,4,6,7]]).unsqueeze(0).unsqueeze(0)
# Extract 2x3 patches with stride 2.
recovered = nnf.unfold(result, kernel_size=(2,3), stride=2)

In [None]:
# Random (1, 4, 2) tensor used as input for the nn.Fold experiment below.
tmp = torch.rand((1,4,2))

In [None]:
tmp

In [None]:
# Fold with a 1x4 kernel, stride 4 and a 1x4 output.
f = nn.Fold(output_size=(1,4), kernel_size=(1,4), stride=(1,4))

In [None]:
# NOTE(review): a 1x4 output with a 1x4 kernel holds a single block, while `tmp`
# provides 2 along its last dim, so this call presumably raises. No output is
# recorded; confirm by running.
f(tmp)

In [62]:
# Random signal with the padded length (16198) for an unfold/fold round trip.
# Note that this rebinds `t`.
t = torch.rand((1,16198))
t

tensor([[0.6204, 0.2541, 0.5079,  ..., 0.5675, 0.3362, 0.8327]])

In [63]:
# Windows of 178 samples with hop 89, giving shape (1, 181, 178) in the recorded output.
# Note that this rebinds `f`, which previously held an nn.Fold module.
f = t.unfold(-1, 178, 89)
f, f.shape

(tensor([[[0.6204, 0.2541, 0.5079,  ..., 0.4269, 0.8965, 0.3063],
          [0.8041, 0.0994, 0.6862,  ..., 0.4787, 0.2633, 0.9002],
          [0.5556, 0.5323, 0.0520,  ..., 0.6909, 0.8630, 0.1000],
          ...,
          [0.5980, 0.1780, 0.8383,  ..., 0.0549, 0.7329, 0.6425],
          [0.2111, 0.6527, 0.5950,  ..., 0.7827, 0.4977, 0.7434],
          [0.2652, 0.1109, 0.0736,  ..., 0.5675, 0.3362, 0.8327]]]),
 torch.Size([1, 181, 178]))

In [68]:
# Swap the last two dims so the window axis comes last: (1, 178, 181).
torch.transpose(f, -2, -1).shape

torch.Size([1, 178, 181])

In [67]:
# Fold the windows back with a stride equal to the unfold hop (89);
# values of overlapping windows are summed by fold.
res = nnf.fold(torch.transpose(f, -2, -1), (1, 16198), kernel_size=(1, 178), stride=(1, 89))
res

tensor([[[[0.6204, 0.2541, 0.5079,  ..., 0.5675, 0.3362, 0.8327]]]])

In [None]:
# NOTE(review): fold sums overlapping windows, so samples covered by two windows
# are presumably doubled and this equality is expected to be False. No output is
# recorded; confirm by running.
(res.squeeze(1).squeeze(1) == t).all()

In [None]:
recovered

In [None]:
# Two 2x3 patches for a small 2-D fold experiment.
a = torch.tensor([[1, 2, 3], [1, 2, 3.]])    
b = torch.tensor([[5, 6, 7], [5, 6, 7.]])

In [None]:
# Flatten each patch to 6 values and stack the two as columns: shape (1, 6, 2).
uf = torch.cat((a.view(1, 6, 1), b.view(1, 6, 1)), dim=2)
uf, uf.shape

In [None]:
# Fold the two 2x3 patches at stride 2 into a 2x5 image; the patches share the
# middle column, where fold sums their values.
raw = nnf.fold(uf, (2,5), kernel_size=(2,3), stride=2)

In [None]:
raw