In [1]:
import torch
import torch.nn as nn
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.module import Module
from torch.autograd import Variable
import math
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


## Encoder

In [2]:
class Encoder(nn.Module):
    def __init__(self, C, L, N):
        super(Encoder, self).__init__()
        self.C = C # in_channels
        self.L = L  # length_kernel
        self.N = N  # n_output
        self.conv = nn.Conv2d(in_channels=C,
                                out_channels=N,
                                kernel_size=(L,1),
                                stride=(L,1),
                                padding=0,
                                bias=False)
        self.activation = nn.PReLU()
    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        return x
    
x = torch.rand(2,5,257,32)
print(x.shape)
m = []
m.append(Encoder(C=5,N=8,L=2))
m.append(Encoder(C=8,N=16,L=2))
m.append(Encoder(C=16,N=32,L=2))
m.append(Encoder(C=32,N=64,L=2))
m.append(Encoder(C=64,N=128,L=2))
m.append(Encoder(C=128,N=256,L=2))
m.append(Encoder(C=256,N=512,L=2))
m.append(Encoder(C=512,N=1024,L=2))

for i in range(len(m)) : 
    x = m[i](x)
    print("enc_{} : {} | {}".format(i+1,x.shape, x.shape[1]*x.shape[2]))

torch.Size([2, 5, 257, 32])
enc_1 : torch.Size([2, 8, 128, 32]) | 1024
enc_2 : torch.Size([2, 16, 64, 32]) | 1024
enc_3 : torch.Size([2, 32, 32, 32]) | 1024
enc_4 : torch.Size([2, 64, 16, 32]) | 1024
enc_5 : torch.Size([2, 128, 8, 32]) | 1024
enc_6 : torch.Size([2, 256, 4, 32]) | 1024
enc_7 : torch.Size([2, 512, 2, 32]) | 1024
enc_8 : torch.Size([2, 1024, 1, 32]) | 1024


In [3]:
class TransformerEncoderLayer(Module):
    """
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
    
        TransformerEncoderLayer is made up of self-attn and feedforward network.
        This standard encoder layer is based on the paper "Attention Is All You Need".
        Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
        Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
        Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
        in a different way during application.
        Args:
            d_model: the number of expected features in the input (required).
            nhead: the number of heads in the multiheadattention models (required).
            dim_feedforward: the dimension of the feedforward network model (default=2048).
            dropout: the dropout value (default=0.1).
            activation: the activation function of intermediate layer, relu or gelu (default=relu).
        Examples:
            >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
            >>> src = torch.rand(10, 32, 512)
            >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, n_head, dropout=0):
        super(TransformerEncoderLayer, self).__init__()
        self.LayerNorm1 = nn.LayerNorm(normalized_shape=d_model)
        self.self_attn = MultiheadAttention(d_model, n_head, dropout=dropout)
        self.Dropout1 = nn.Dropout(p=dropout)
        self.LayerNorm2 = nn.LayerNorm(normalized_shape=d_model)
        self.FeedForward = nn.Sequential(nn.Linear(d_model, d_model*2*2),
                                         nn.ReLU(),
                                         nn.Dropout(p=dropout),
                                         nn.Linear(d_model*2*2, d_model))
        self.Dropout2 = nn.Dropout(p=dropout)

    def forward(self, z):
        z1 = self.LayerNorm1(z)
        z2 = self.self_attn(z1, z1, z1, attn_mask=None, key_padding_mask=None)[0]
        z3 = self.Dropout1(z2) + z
        z4 = self.LayerNorm2(z3)
        z5 = self.Dropout2(self.FeedForward(z4)) + z3
        return z5
    
x = torch.rand(10, 32, 512)
m = TransformerEncoderLayer(d_model=512, n_head=8)

print(x.shape)
y = m(x)    
print(y.shape)

torch.Size([10, 32, 512])
torch.Size([10, 32, 512])


In [19]:
class RNN(Module):
    def __init__(self, C, hidden_size= 1024,  num_layers=2):
        super(RNN, self).__init__()
        
        self.rnn = nn.GRU(input_size = C, hidden_size = hidden_size, num_layers = num_layers, batch_first  = True, bidirectional  = False)
        self.activation = nn.PReLU()
        self.FC = nn.Linear(hidden_size,C)
    def forward(self, x, h = None):
        
        x,h_out = self.rnn(x,h)
        x = self.activation(x)
        x = self.FC(x)
        return x,h_out
    
x = torch.rand(2, 32, 1024)
m = RNN(1024)
y,h = m(x)
print(y.shape)
print(h.shape)

y,h = m(x,h)
print(y.shape)
print(h.shape)

torch.Size([2, 32, 1024])
torch.Size([2, 2, 1024])
torch.Size([2, 32, 1024])
torch.Size([2, 2, 1024])


In [27]:
class Model_v1(nn.Module):
    def __init__(self):
        super(Model_v1, self).__init__()
        n_fft = 512
        n_hfft = 257

        # Convolution Encoders
        self.encoders = []
        self.encoders.append(Encoder(C=5,N=8,L=2))
        for i in range(1,8):
            self.encoders.append(Encoder(C=2**(i+2),N=2**(i+3),L=2))
    
        # Transformer Encoder
        self.formers = []
        self.formers.append(TransformerEncoderLayer(d_model=1024,n_head=8))
        self.formers.append(TransformerEncoderLayer(d_model=1024,n_head=8))
        
        self.recurrents = []
        self.recurrents.append(RNN(1024))
        self.recurrents.append(RNN(1024))
        
        self.output = nn.Linear(1024,257*4)

    def forward(self,x):
        for enc in self.encoders : 
            x = enc(x)
        
        # x : [B, 1024,1,T]
        print("x : {}".format(x.shape))
        x = torch.reshape(x,(x.shape[0],x.shape[1],x.shape[3]))
        print("x : {}".format(x.shape))
        x = torch.permute(x,(0,2,1))
        print("x : {}".format(x.shape))

        for former in self.formers : 
            x = former(x)
        
        for recurrent in self.recurrents : 
            x,h = recurrent(x)
        
        w = self.output(x)
        print("w : {}".format(w.shape))
        w = torch.permute(w,(0,2,1))
        print("w : {}".format(w.shape))
        w = torch.reshape(w,(w.shape[0],257,4,w.shape[-1]))
        return w
            
x = torch.rand(2,5,257,32)
print(x.shape)

m = Model_v1()

y = m(x)    
print(y.shape)

torch.Size([2, 5, 257, 32])
x : torch.Size([2, 1024, 1, 32])
x : torch.Size([2, 1024, 32])
x : torch.Size([2, 32, 1024])
w : torch.Size([2, 32, 1028])
w : torch.Size([2, 1028, 32])
torch.Size([2, 257, 4, 32])
