In [1]:
# Notebook setup: re-import edited modules automatically before each cell runs
# (%autoreload 2) and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import math
import os

import torch
import torch.nn as nn
import torchaudio
from torch.nn import LSTM

In [3]:
# Example mixture to run through the model. Set the AUDIO_PATH environment
# variable to use another file instead of editing this machine-specific default.
AUDIO_PATH = os.environ.get(
    "AUDIO_PATH",
    "/home/minds/Desktop/test_speech_mix/generated/speaker/id10009/7hpSiT9_gCE/0.wav",
)
audio_input, s_rate = torchaudio.load(AUDIO_PATH)
# torchaudio.load returns (channels, time); add a leading batch axis -> (1, channels, time)
audio_input = audio_input.unsqueeze(0)
audio_input.shape

torch.Size([1, 1, 64000])

In [4]:
class BaseRNNBlock(nn.Module):
    '''
    Gated recurrent block: two bidirectional LSTMs run over the same input,
    their outputs are multiplied element-wise, the block input is concatenated
    back on, and a linear layer (called P in the paper) projects the result to
    the input feature size.

    The "odd" instance runs the RNN along the time dimension, i.e. across the
    R chunks of size K.

    Args:
        Parameter: size of the input's last (feature) dimension; the output
            has the same size.
        hidden_size: hidden size of each LSTM direction.
        **kwargs: forwarded unchanged to both ``LSTM`` constructors.
    '''

    def __init__(self, Parameter=128, hidden_size=128, **kwargs):
        super().__init__()
        lstm_options = dict(
            input_size=Parameter,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        self.lstm1 = LSTM(**lstm_options, **kwargs)
        self.lstm2 = LSTM(**lstm_options, **kwargs)
        # Each bidirectional LSTM emits 2 * hidden_size features; the raw input
        # is concatenated back in before the projection.
        self.P = nn.Linear(2 * hidden_size + Parameter, Parameter)

    def forward(self, x):
        """Apply the block; the output has the same shape as ``x``.

        With batch_first=True the LSTMs read ``x`` as (batch, seq, Parameter).
        """
        gate_a = self.lstm1(x)[0]
        gate_b = self.lstm2(x)[0]
        mixed = torch.cat([gate_a * gate_b, x], dim=-1)
        return self.P(mixed)

In [5]:
class RNNBlock(nn.Module):
    '''
    One dual-path step: an "odd" block followed by an "even" block, each
    wrapped in a residual (skip) connection.

    The odd block sees the input with K as its feature size. The tensor is then
    transposed (dims 1 and -1) so the even block sees R as its feature size, and
    transposed back.

    Args:
        K: size of the input's last dimension (chunk length).
        R: size of the input's dimension 1 (number of chunks).
        hidden_size: LSTM hidden size for both sub-blocks.
        **kwargs: forwarded to both ``BaseRNNBlock`` constructors.
    '''

    def __init__(self, K, R, hidden_size=128, **kwargs):
        super(RNNBlock, self).__init__()
        self.oddBlock = BaseRNNBlock(Parameter=K, hidden_size=hidden_size, **kwargs)
        self.evenBlock = BaseRNNBlock(Parameter=R, hidden_size=hidden_size, **kwargs)

    def forward(self, x):
        """Apply both sub-blocks; presumably x is (N, R, K) and so is the output.

        The residual sums are out-of-place on purpose. ``x += ...`` would mutate
        the caller's tensor. It would also overwrite activations that the
        sub-blocks saved for backward, breaking autograd.
        """
        # skip connection around the odd block (features = K)
        x = x + self.oddBlock(x)
        # N x R x K -> N x K x R
        x = torch.transpose(x, 1, -1)
        # skip connection around the even block (features = R)
        x = x + self.evenBlock(x)
        # N x K x R -> N x R x K
        return torch.transpose(x, 1, -1)

In [None]:
class RNNNetwork(nn.Module):
    '''
    Full RNN stack: ``depth`` RNNBlocks. After each block the activations pass
    through a shared PReLU, and a 1x1 convolution D maps R channels to C*R.
    D's output from every layer is returned, stacked.

    Args:
        K: chunk length (size of the input's last dimension).
        R: number of chunks (size of the input's dimension 1).
        C: number of output groups; D produces C*R channels.
        N: kept for backward compatibility; the leading size of the output is
            now taken from the input at run time.
        depth: number of RNNBlocks.
        h_size: LSTM hidden size of every block.
        **kwargs: forwarded to every RNNBlock.
    '''
    def __init__(self, K, R, C, N=64, depth=6, h_size=128, **kwargs):
        super(RNNNetwork, self).__init__()
        # nn.ModuleList (not a plain list) so the blocks are registered
        # submodules: their weights appear in .parameters()/state_dict() and
        # follow .to(device).
        self.layers = nn.ModuleList([RNNBlock(K, R, h_size, **kwargs) for _ in range(depth)])
        self.activation = torch.nn.PReLU(num_parameters=1, init=0.25)
        self.D = nn.Conv2d(R, C*R, kernel_size=1)

    def forward(self, x):
        """Run every block and collect D's output after each one.

        Args:
            x: input to the first block; presumably (N, R, K) — TODO confirm.

        Returns:
            Tensor of shape (depth, N, C*R, K), one slice per layer.
        """
        outputs = []
        for layer in self.layers:
            x = self.activation(layer(x))
            # D is a Conv2d with kernel 1. Viewing (N, R, K) as (N, R, K, 1)
            # makes N the batch and R the channel axis, so D maps R -> C*R.
            outputs.append(self.D(x.unsqueeze(-1)).squeeze(-1))
        if not outputs:
            # depth == 0: keep the (depth, N, C*R, K) layout with an empty first axis
            return x.new_zeros((0, x.shape[0], self.D.out_channels, x.shape[-1]))
        # Collect per-layer outputs instead of writing into a Parameter in-place,
        # which autograd rejects and which an optimizer would treat as weights.
        return torch.stack(outputs)
            
        

In [None]:
class FacebookModel(nn.Module):
    '''
    Modelo completo
    '''
    def __init__(self, k=179, r=181, p=89, padding=110, l=8, n=64, c=2, **kwargs):
        """Build the encoder and the RNN separation network.

        Args:
            k: chunk length K.
            r: number of chunks R; presumably must equal the number of chunks
                ``chunk`` produces for the input length used — TODO confirm.
            p: hop between consecutive chunks; also the left zero-padding.
            padding: right zero-padding applied before chunking.
            l: encoder kernel size; the encoder stride is ``l // 2``.
            n: number of encoder output channels N.
            c: number of output groups C, passed to RNNNetwork.
            **kwargs: forwarded to RNNNetwork (e.g. depth, h_size).
        """
        super(FacebookModel, self).__init__()
        self.l, self.k, self.p = l, k, p
        self.r, self.n, self.c = r, n, c
        # zero padding added on the left and on the right of the encoded tensor
        self.padding_left, self.padding_right = p, padding
        # encoder E: 1 -> n channels, hop of half the kernel
        self.encoder = nn.Conv1d(1, n, kernel_size=l, stride=l // 2)
        # The network itself. RNNNetwork's signature is (K, R, C, N=...); n
        # must not be passed first, or it would be used as the chunk length K.
        self.RNNNetwork = RNNNetwork(k, r, c, N=n, **kwargs)
        self.decoder = None  # TODO: the waveform-reconstruction decoder is not written yet
        
    def forward(self, x):
        processed_audio = self.encoder(x).squeeze(0) # NxT'
        processed_audio = sef.chunk(processed_audio #NxRxK
        y = self.RNNNetwork(processed_audio)
        
        
        return y
    def chunk(self, x):

        x = torch.cat((x, torch.zeros(1, self.N, self.paddin_right)), -1)
        x = torch.cat((torch.zeros(1, self.N, self.padding_left), x), -1)
        return x.unfold(-1, self.K, self.P).squeeze(0)
                                    
    def overlapandadd(self, x):
            """Place chunk windows at hop-spaced offsets along a long last axis.

            ``x`` must be 4-D, because ``unfold`` below must yield a 5-D result.
            Dim -2 is split into non-overlapping groups of 181. Window ``i`` of
            every group (width = size of x's last dim) is written into ``ad``
            starting at offset ``i * 89``.
            """
            # res: dim -2 split into c groups of 181; the 181 positions inside
            # a group become a new last axis, so e == 181.
            res = x.unfold(-2, 181, 181)
            a, b, c, d, e = res.shape
            # NOTE(review): output length 16308 is hard-coded; its derivation
            # is not visible here.
            ad = torch.zeros((a, b, c, 16308))
            for i in range(e):
                # 181 and 89 equal the __init__ defaults r=181 and p=89;
                # presumably these should follow the configured values — TODO.
                start, end = i*89, (i*89 + d)
                piece = res[...,:,i]
                # NOTE(review): plain assignment overwrites the span that
                # overlaps the previous window instead of summing it, as an
                # overlap-add would (`+=`) — confirm intent.
                ad[...,start:end] = piece