In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
from torch.nn import LSTM
import torchaudio
import math

In [3]:
# Network hyper-parameter, set as in the Facebook model.
# NOTE(review): presumably the encoder kernel size (the model's `l` argument
# also defaults to 8), but `L` is not referenced by any other visible cell — confirm.
L = 8

In [4]:
# Real audio example: load one utterance from disk.
# NOTE(review): the path is absolute and machine-specific — consider a configurable data directory.
audio_input, s_rate = torchaudio.load('/home/minds/Desktop/test_speech_mix/generated/speaker/id10009/7hpSiT9_gCE/0.wav', normalization=True)
# Add a leading dimension; the displayed shape below is (1, 1, 64000).
audio_input = audio_input.unsqueeze(0)
audio_input.shape

torch.Size([1, 1, 64000])

In [5]:
class BaseRNNBlock(nn.Module):
    '''
    Gated recurrent block: two bidirectional LSTMs read the same input, their
    outputs are multiplied element-wise, concatenated with the input and
    projected back to the input feature size.

    Used as the "odd" block, it runs the RNN along the time dimension,
    i.e. along R chunks of size K.

    Args:
        Parameter: size of the last (feature) dimension of the input.
        hidden_size: hidden size of each LSTM direction.
        **kwargs: extra keyword arguments forwarded to both LSTMs.
    '''

    def __init__(self, Parameter=128, hidden_size=128, **kwargs):
        super().__init__()
        # Both recurrent branches share exactly the same configuration.
        lstm_config = dict(
            input_size=Parameter,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
            **kwargs
        )
        self.lstm1 = LSTM(**lstm_config)
        self.lstm2 = LSTM(**lstm_config)
        # "P" in the paper: maps [gated LSTM output, input] back to `Parameter` features.
        self.P = nn.Linear(2 * hidden_size + Parameter, Parameter)

    def forward(self, x):
        '''
        Args:
            x: tensor whose last dimension has size `Parameter`
               (batch-first layout, as required by the LSTMs).

        Returns:
            Tensor with the same shape as `x`.
        '''
        first_branch, _ = self.lstm1(x)
        second_branch, _ = self.lstm2(x)
        gated = first_branch * second_branch
        # Concatenate the gated features with the raw input, then project back.
        return self.P(torch.cat((gated, x), dim=-1))

In [6]:
class RNNBlock(nn.Module):
    '''
    Pair of residual recurrent blocks: an "odd" block followed by an "even" block.

    The odd block processes the input as given (N x R x K, features of size K);
    the even block processes the tensor with dimension 1 and the last dimension
    swapped (N x K x R, features of size R).

    Args:
        K: size of the last dimension of the input (feature size of the odd block).
        R: size of dimension 1 of the input (feature size of the even block).
        hidden_size: hidden size of each LSTM direction.
        **kwargs: extra keyword arguments forwarded to the LSTMs of both blocks.
    '''

    def __init__(self, K, R, hidden_size=128, **kwargs):
        super(RNNBlock, self).__init__()
        self.oddBlock = BaseRNNBlock(Parameter=K, hidden_size=hidden_size, **kwargs)
        self.evenBlock = BaseRNNBlock(Parameter=R, hidden_size=hidden_size, **kwargs)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape N x R x K.

        Returns:
            New tensor of shape N x R x K; the input tensor is left untouched.
        '''
        # Skip connection. Deliberately out-of-place: `x += ...` would overwrite the
        # caller's tensor (possibly an overlapping `unfold` view) and a tensor that
        # autograd has already saved for the backward pass.
        x = x + self.oddBlock(x)
        # Tensor N x R x K -> N x K x R
        x = torch.transpose(x, 1, -1)
        # Skip connection around the even block, again out-of-place.
        x = x + self.evenBlock(x)
        # Tensor N x K x R -> N x R x K
        return torch.transpose(x, 1, -1)

In [7]:
class Facebookmodel(nn.Module):
    '''
    The "Facebook model": a 1-D convolutional encoder, a stack of RNNBlocks
    applied to overlapping chunks of the encoded signal, and a transposed
    convolution decoder producing `c` separated signals.

    Args:
        n: number of encoder output channels.
        k: chunk size (length of each overlapping window).
        r: number of chunks; the encoded signal is zero-padded so that exactly
           `r` chunks with hop `k // 2` are produced.
        c: number of sources to separate.
        l: encoder/decoder kernel size (must be even; the stride is `l // 2`).
        b: number of RNNBlocks.
        **kwargs: forwarded to every RNNBlock.

    Note:
        `forward` removes the leading dimension of the encoder output with
        `squeeze(0)`, so it handles a single example per call.
    '''

    def __init__(self, n, k, r, c=2, l=8, b=1, **kwargs):
        super(Facebookmodel, self).__init__()
        assert l % 2 == 0, 'l must be even'
        self.n = n
        self.k = k
        self.r = r
        self.c = c
        # Hop between consecutive chunks.
        self.hop = k // 2
        # Padded length for which `unfold(k, hop)` yields exactly r chunks.
        self.padded_length = (r - 1) * self.hop + k
        # Un-padded length seen by the last `chunk` call (fallback for the decoders).
        self._encoded_length = None

        self.encoder = nn.Conv1d(1, n, l, l // 2)
        # nn.ModuleList registers the blocks so their parameters appear in
        # `parameters()`, `state_dict()` and follow `.to(device)`.
        self.rnnblocks = nn.ModuleList([RNNBlock(k, r, **kwargs) for _ in range(b)])
        self.d = nn.Conv1d(r, c * r, kernel_size=1)
        self.activation = torch.nn.PReLU(num_parameters=1, init=0.25)
        self.decoder = nn.ConvTranspose1d(n, 1, kernel_size=l, stride=l // 2)

        # Experimental decoder based on a 2-D transposed convolution.
        self.decoder2d = nn.ConvTranspose2d(n, 1, kernel_size=(1, l), stride=(1, l // 2))

    def forward(self, x):
        '''
        Args:
            x: mixture waveform of shape (1, 1, samples).

        Returns:
            List with one tensor per RNNBlock, each of shape (1, c, samples_out).
        '''
        encoded = self.encoder(x).squeeze(0)
        length = encoded.shape[-1]
        chunks = self.chunk(encoded)
        outps = []
        for block in self.rnnblocks:
            chunks = block(chunks)
            # One estimate per block (multi-scale outputs).
            outps.append(self.d(self.activation(chunks)))

        outps = self.apply_overlap_and_add(outps)
        return self.decode2d(outps, length=length)

    def chunk(self, x):
        '''
        Zero-pad the encoded signal and split it into `r` overlapping chunks.

        Args:
            x: tensor of shape (n, T).

        Returns:
            Tensor of shape (n, r, k).

        Raises:
            ValueError: if T is too long to fit into `r` chunks.
        '''
        length = x.shape[-1]
        tail = self.padded_length - self.hop - length
        if tail < 0:
            raise ValueError(
                'encoded length %d does not fit into r=%d chunks of size k=%d '
                '(maximum is %d)' % (length, self.r, self.k, self.padded_length - self.hop)
            )
        self._encoded_length = length
        # `hop` zeros in front, `tail` zeros at the end; F.pad keeps the input's
        # dtype and device.
        padded = torch.nn.functional.pad(x, (self.hop, tail))
        return padded.unfold(-1, self.k, self.hop)

    def _valid_span(self, length=None):
        '''Return (start, stop) of the un-padded region inside an overlap-added signal.'''
        if length is None:
            length = self._encoded_length
        if length is None:
            raise ValueError('length is unknown: pass `length` or call `chunk` first')
        return self.hop, self.hop + length

    def decode2d(self, x, length=None):
        '''
        Decode with a 2-D transposed convolution, keeping the signals of the `c`
        sources together in the same tensor.

        Args:
            x: iterable of overlap-added tensors of shape (n, c, padded_length).
            length: un-padded encoded length; defaults to the length seen by the
                last `chunk` call.

        Returns:
            List of tensors of shape (1, c, samples_out), one per element of `x`.
        '''
        start, stop = self._valid_span(length)
        restored = []
        for a in x:
            # Drop the padding added by `chunk`, then add a batch dimension.
            a = a[..., start:stop].unsqueeze(0)
            restored.append(self.decoder2d(a).squeeze(1))

        return restored

    def decode(self, x, length=None):
        '''
        Decode with a 1-D transposed convolution, separating the `c` sources
        before decoding.

        Args:
            x: iterable of overlap-added tensors of shape (n, c, padded_length).
            length: un-padded encoded length; defaults to the length seen by the
                last `chunk` call.

        Returns:
            List of `c` lists; entry `i` holds one decoded tensor per element of `x`.
        '''
        start, stop = self._valid_span(length)
        restored = [[] for _ in range(self.c)]

        for a in x:
            for i in range(self.c):
                t = a[:, i, start:stop].unsqueeze(0)
                t = self.decoder(t)
                restored[i].append(t.squeeze(0))

        return restored

    def apply_overlap_and_add(self, x):
        '''Apply `overlap_and_add` to every tensor in `x` and return the results as a list.'''
        return [self.overlap_and_add(el) for el in x]

    def overlap_and_add(self, x):
        '''
        Overlap-and-add implemented with `torch.nn.functional.fold`.

        Args:
            x: tensor of shape (n, c * r, k).

        Returns:
            Tensor of shape (n, c, padded_length).
        '''
        x = torch.transpose(x, -2, -1)
        result = torch.nn.functional.fold(
            x,
            (self.c, self.padded_length),
            kernel_size=(1, self.k),
            stride=(1, self.hop),
        )
        return result.squeeze(1)

In [19]:
def si_snr_2speaker(y_hat, y, eps=1e-8):
    '''
    Permutation-invariant negative SI-SNR for two speakers.

    For every (estimate i, reference j) pair the estimate is projected onto the
    reference (s_target = <y_hat_i, y_j> y_j / ||y_j||^2), the noise is the part
    of the estimate not explained by that projection (e_noise = y_hat_i - s_target),
    and SI-SNR = 10 * log10(||s_target||^2 / ||e_noise||^2).

    Args:
        y_hat: estimated signals, shape (2, T).
        y: reference signals, shape (2, T).
        eps: lower bound applied to energies so that division and log10 never
            see zero.

    Returns:
        0-dim tensor: minus the summed SI-SNR of the better of the two possible
        speaker assignments.
    '''
    # Energy of each reference, shape (2,).
    y_power = torch.pow(y, 2).sum(-1).clamp_min(eps)

    # scale_factor[i, j] = <y_hat_i, y_j> / ||y_j||^2
    scale_factor = (y_hat @ y.t()) / y_power

    # s_target[i, j, :] = scale_factor[i, j] * y_j
    s_target = scale_factor.unsqueeze(-1) * y.unsqueeze(0)

    # e_noise is measured against the *scaled* target; subtracting the raw
    # reference would make the measure depend on the estimate's gain.
    e_noise = y_hat.unsqueeze(1) - s_target

    target_energy = torch.pow(s_target, 2).sum(-1).clamp_min(eps)
    noise_energy = torch.pow(e_noise, 2).sum(-1).clamp_min(eps)

    pairwise = 10 * (torch.log10(target_energy) - torch.log10(noise_energy))

    # The two possible assignments of estimates to references.
    straight = pairwise[0, 0] + pairwise[1, 1]
    swapped = pairwise[0, 1] + pairwise[1, 0]

    return -torch.max(straight, swapped)
    

In [36]:
def si_snr(y_hat, y):
    '''
    Accumulate the two-speaker SI-SNR loss over a sequence of predictions.

    Args:
        y_hat: iterable of predictions; the leading dimension of each one is
            squeezed away before it is scored.
        y: reference signals, passed unchanged to `si_snr_2speaker`.

    Returns:
        Sum of the per-prediction losses (0 when `y_hat` is empty).
    '''
    return sum(si_snr_2speaker(estimate.squeeze(0), y) for estimate in y_hat)

In [27]:
# Instantiate the model: n=64 encoder channels, chunk size k=178, r=181 chunks, b=1 RNN block.
m = Facebookmodel(64, 178, 181, b = 1)

In [28]:
# Forward pass on the example mixture; the result is a list with one tensor per RNN block.
r = m(audio_input)

In [29]:
# Load the two reference (target) signals.
# NOTE(review): both references are read from the same file, presumably as a
# placeholder — confirm the intended second source. Paths are machine-specific.
audio_input1, s_rate = torchaudio.load('/home/minds/Desktop/test_speech_mix/generated/speaker/id10009/7hpSiT9_gCE/0.wav', normalization=True)
audio_input2, s_rate = torchaudio.load('/home/minds/Desktop/test_speech_mix/generated/speaker/id10009/7hpSiT9_gCE/0.wav', normalization=True)

In [30]:
# Stack the two references along dimension 0 to form the target tensor.
target = torch.cat((audio_input1, audio_input2), 0)
target.shape

torch.Size([2, 64000])

In [31]:
# Shape of the first (and, with b=1, only) model output.
r[0].shape

torch.Size([1, 2, 64000])

In [33]:
# Sum of the two-speaker SI-SNR losses over every model output.
# `si_snr` is the function defined in this notebook; the previously used name
# `apply_snr` is not defined anywhere and fails on a fresh kernel.
loss = si_snr(r, target)

In [35]:
# Back-propagate the loss to check that gradients flow through the model.
loss.backward()