In [61]:
import sys
sys.path.append("../")

import torch
import torch.nn as nn
import torch.nn.functional as F

from UNet_m import ResBlock,Encoder,Decoder

In [68]:
class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(DepthwiseSeparableConv2d, self).__init__()
        
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,  # because it passed already in the previous conv
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        self.bn_depth = nn.BatchNorm2d(out_channels)
        self.pointwise = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.bn_point = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.bn_depth(x)
        x = F.relu(x)

        x = self.pointwise(x)
        x = self.bn_point(x)
        x = F.relu(x)
        return x


"""
From https://github.com/JusperLee/AFRCNN-For-Speech-Separation
"""
class GlobalChannelLayerNorm(nn.Module):
    '''
        Global Layer Normalization
    '''
    def __init__(self, channel_size):
        super(GlobalChannelLayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size),
                                  requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size),
                                 requires_grad=True)
    
    def apply_gain_and_bias(self, normed_x):
        """ Assumes input of size `[batch, chanel, *]`. """
        return (self.gamma * normed_x.transpose(1, -1) +
                self.beta).transpose(1, -1)

    def forward(self, x):
        """
        x: N x C x T
        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())
"""
Modified FFN Block of  
Li, Kai, Runxuan Yang, and Xiaolin Hu. 
"An efficient encoder-decoder architecture with top-down attention for speech separation."
arXiv preprint arXiv:2209.15200 (2022).

"""
class MultiScaleConvBlock(nn.Module):
    def __init__(self, in_channels) : 
        super(MultiScaleConvBlock, self).__init__()
        
        self.net = nn.Sequential(
                    nn.Conv2d(in_channels, in_channels*2, kernel_size=1),
                    GlobalChannelLayerNorm(in_channels*2),
                    DepthwiseSeparableConv2d(in_channels*2, in_channels*2, kernel_size=1),
                    GlobalChannelLayerNorm(in_channels*2),
                    nn.Conv2d(in_channels*2, in_channels, kernel_size=1),
                    GlobalChannelLayerNorm(in_channels)
        
        )

    def forward(self,x):
        # x : [B, C, F, T]
        x = self.net(x)
        return x

class LSTMBlock(nn.Module):
    def __init__(self,n_dim,n_hidden,n_layer=3,proj_size=None,dropout=0.2) : 
        super(LSTMBlock, self).__init__()
        
        if proj_size == None :
            proj_size = n_dim
        self.rnn = nn.LSTM(n_dim,n_hidden,n_layer,batch_first=True,proj_size=proj_size,dropout=dropout)
    
    def forward(self,x):
        # [B,C,F',T] -> [B,C*F',T]
        d0,d1,d2,d3 = x.shape
        x = torch.reshape(x,(x.shape[0],x.shape[1]*x.shape[2],x.shape[3]))
        # [B,C*F',T] -> [B,T,C*F']
        x = torch.permute(x,(0,2,1))
        #print("bottle in : {}".format(x.shape))

        x,h = self.rnn(x)

        # [B,T,C*F'] -> [B,C*F',T]
        x = torch.permute(x,(0,2,1))
        # [B,C*F',T] -> [B,C,F',T]
        x = torch.reshape(x,(d0,d1,d2,d3))
        
        return x,h
    

class FGRUBlock(nn.Module):
    def __init__(self, in_channels, hidden_size, out_channels):
        super(FGRUBlock, self).__init__()
        self.GRU = nn.GRU(
            in_channels, hidden_size, batch_first=True, bidirectional=True
        )
        # the GRU is bidirectional -> multiply hidden_size by 2
        self.conv = nn.Conv2d(hidden_size * 2, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.hidden_size = hidden_size
        self.relu = nn.ReLU()

    def forward(self, x):
        # X : [B, C, F', T]
        # goal : [BT, F, C ]
        B, C, T, F_ = x.shape
        x_ = x.permute(0, 2, 3, 1)  # x_.shape == (B,T,F,C)
        x_ = x_.reshape(B * T, F_, C)
        y, h = self.GRU(x_)  # x_.shape == (BT,F,C)
        y = y.reshape(B, T, F_, self.hidden_size * 2)
        output = y.permute(0, 3, 1, 2)  # output.shape == (B,C,T,F)
        output = self.conv(output)
        output = self.bn(output)
        return self.relu(output)


class TGRUBlock(nn.Module):
    def __init__(self, in_channels, hidden_size, out_channels):
        super(TGRUBlock, self).__init__()
        self.GRU = nn.GRU(in_channels, hidden_size, batch_first=True)
        self.conv = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.hidden_size = hidden_size
        self.relu = nn.ReLU()

    def forward(self, x, rnn_state=None):
        """
        X :[B, C, F', T]
        
        X' : [B*F', T, C']
        """
        B, C, F_, T = x.shape
 
        # -> [B, F', T, C]
        x = x.permute(0, 2, 3, 1)
        # -> [B*F', T, C]
        x = x.reshape(B * F_, T, C)
 
            
        x, rnn_state = self.GRU(x, rnn_state)  # y_.shape == (BF,T,C)
        #  X' : [B*F', T, hidden_size]
        # -> X' : [B, F', T, hidden_size]
        print("TGRU::{}".format(x.shape))
        x = x.reshape(B, F_, T, self.hidden_size)
        # -> X' : [B, hidden_size, F', T]
        x = x.permute(0, 3, 1, 2)     
        #  X' : [B, C, F', T]
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x, rnn_state

class ResUNetOnFreq2(nn.Module) :
    def __init__(self, 
                 c_in = 1,
                 c_out = 1,
                 n_fft=512,
                 device="cuda:0",
                 print_shape=False,
                 n_block = 5,
                 activation = "Softplus" , 
                 bottleneck = "LSTM",
                 Softplus_thr = 20,
                 norm = "BatchNorm2d",
                 dropout = 0.0,
                 activation_layer = "PReLU",
                 multi_scale = False
                 ):
        super().__init__()

        n_hfft = int(n_fft/2+1)

        self.print_shape=print_shape
        self.activation = activation
        self.multi_scale = multi_scale
        self.F = n_hfft
        upscale = [1,2.2, 4.5, 9, 18.4, 36.8]
        f_dim = 30

        if n_block < 2 :
            raise Exception("ERROR::ResUnetOnFreq : n_block({}) < 2".fomrat(n_block))

        ## Model Implementation

        # input layer
        self.layer_input = nn.Sequential(
            Encoder(c_in,f_dim,(1,3),1,(0,1),1),
            nn.LayerNorm(n_hfft) 
        )

        # Encoder
        encoders=[]
        encoders.append(ResBlock(f_dim))
        for i in range(n_block) :
            encoders.append(nn.Sequential(
                    Encoder(f_dim,f_dim,(3,1),
                    (2,1),(0,0),activation=activation_layer,norm=norm),
                    ResBlock(30)))

        self.encoders=encoders
        for i,enc in enumerate(self.encoders) : 
            self.add_module("enc_{}".format(i),enc)

        # Decoder
        decoders=[]
        for i in range(n_block-2) :
            decoders.append(nn.Sequential(
                ResBlock(f_dim),
                Decoder(f_dim,f_dim,(4,1),
                (2,1),(1,0),output_padding=(1,0),activation=activation_layer,norm=norm)))
        decoders.append(nn.Sequential(
                ResBlock(f_dim),
                Decoder(f_dim,f_dim,(5,1),
                (2,1),(1,0),output_padding=(1,0),activation=activation_layer,norm=norm)))
        decoders.append(nn.Sequential(
                ResBlock(f_dim),
                Decoder(f_dim,f_dim,(4,1),
                (2,1),(1,0),output_padding=(1,0),activation=activation_layer,norm=norm)))
        decoders.append(ResBlock(f_dim))

        self.decoders=decoders
        for i,dec in enumerate(self.decoders) : 
            self.add_module("dec_{}".format(i),dec)

        self.len_model = len(encoders)

        # Residual Path
        res_paths = []
        res_paths.append(Encoder(f_dim,f_dim,1,1,0,1,activation=activation_layer))
        for i in range(n_block) : 
            res_paths.append(Encoder(f_dim,f_dim,1,1,0,1,activation=activation_layer))

        self.res_paths = res_paths
        for i,res_path in enumerate(self.res_paths) : 
            self.add_module("res_path_{}".format(i),res_path)

        # Bottleneck
        if bottleneck == "LSTM" : 
            self.bottleneck = LSTMBlock(210,300,n_layer=3,dropout=dropout)
        elif bottleneck == "FTGRU" : 
            bottleneck_hidden = 256
            bottleneck_channel = f_dim*2
            self.bottleneck = nn.Sequential(
                FGRUBlock(f_dim, bottleneck_hidden, bottleneck_channel),
                TGRUBlock(bottleneck_channel, bottleneck_hidden, f_dim)
            )            
        else :
            self.bottleneck = nn.LSTM(210,300,3,batch_first=True,proj_size=210,dropout=dropout)
            
        # multi-scale encoder
        if multi_scale : 
            ms = []
            for i in range(n_block):
                ms.append(nn.Sequential(
                    #nn.Conv2d(30,30,(2**(n_block+1-i),1),stride=(2**(n_block-i),1),padding=(1,0))
                    nn.AvgPool2d((2**(n_block+1-i),1),stride=(2**(n_block-i),1),padding=(1,0))
                ))
            self.ms = ms
            for i,i_ms in enumerate(self.ms) : 
                self.add_module("ms_{}".format(i),i_ms)
                
            self.ms_module = MultiScaleConvBlock(f_dim)

            self.upsample = []
            for scale in upscale : 
                self.upsample.append(nn.Sequential(
                                nn.Upsample(scale_factor=(scale,1), mode='nearest'),
                                nn.Sigmoid()
                ))
            for i,i_up in enumerate(self.upsample) : 
                self.add_module("up_{}".format(i),i_up)

        # output layer
        self.out_layer = nn.ConvTranspose2d(f_dim,c_out,(3,1),stride=1,padding=(1,0),dilation=1,output_padding=(0,0))

        if activation == "Softplus" : 
            self.activation_mask = nn.Softplus(threshold=Softplus_thr)
        elif activation == "Sigmoid" : 
            self.activation_mask = nn.Sigmoid()
        elif activation == "Tanh" : 
            self.activation_mask = nn.Tanh()
        elif activation == "Identity" : 
            self.activation_mask = nn.Identity()
        elif activation == "MEA" : 
            if c_in == 1 :
                raise Exception("ERROR::ResUnetOnFreq::feature must be complex")
            self.activation_mask = MEA(in_channels = c_out)
            self.add_module("MEA",self.activation_mask)
        else : 
            self.activation_mask = nn.Softplus()

    def forward(self,input):
        ## ipnut : [ Batch Channel Freq Time]
        # reshape
        # [ B C T F]
        feature = torch.permute(input[:,:,:,:],(0,1,3,2))
        feature = self.layer_input(feature)

        # reshape
        x = torch.permute(feature,(0,1,3,2))

        ## Encoder
        res = []
        ms = None
        for i,enc in enumerate(self.encoders):
            x = enc(x)
            if self.print_shape : 
                print("x_{} : {}".format(i,x.shape))
            res.append(x)
            
            # multi-scale
            if self.multi_scale and i < len(self.ms) :
                if ms is None : 
                    ms = self.ms[i](x)
                else : 
                    ms += self.ms[i](x)
                if self.print_shape : 
                    print("ms_{} : {}".format(i,ms.shape))
                ms = self.ms_module(ms)

        ## bottleneck
        x = self.bottleneck(x)[0]

        ## ResPath
        for i,res_path in enumerate(self.res_paths) : 
            res[i] = res_path(res[i])

        ## Decoder
        y = x

        for i,dec in enumerate(self.decoders) : 
            if self.print_shape : 
                print("y {} += r_{} : {}".format(y.shape,i,res[-1-i].shape))
                
            if self.multi_scale : 
                up = self.upsample[i]
                att = up(ms)
                y  = torch.add(y,res[-1-i]*att)
            else : 
                y  = torch.add(y,res[-1-i])
                
            y = dec(y)
            if self.print_shape : 
                print("y_{} : {}".format(i,y.shape))     

        ## output
        output = self.out_layer(y)
        return self.activation_mask(output)

    def output(self,mask,feature):

        if self.activation == "MEA" : 
            return self.last_activation.output(mask,feature)
        else :
            return mask * feature[:,:2]

In [70]:
m = ResUNetOnFreq2(c_in = 1, print_shape=True, activation = "Sigmoid",bottleneck = "FTGRU",multi_scale=True)
m.eval()

# to check low latency
x = torch.rand(1,1,257,1)
m(x)
print("--------------------")

x = torch.rand(1,1,257,10)
print(x.shape)
y = m(x)
print(y.shape)

x_0 : torch.Size([1, 30, 257, 1])
ms_0 : torch.Size([1, 30, 7, 1])
x_1 : torch.Size([1, 30, 128, 1])
ms_1 : torch.Size([1, 30, 7, 1])
x_2 : torch.Size([1, 30, 63, 1])
ms_2 : torch.Size([1, 30, 7, 1])
x_3 : torch.Size([1, 30, 31, 1])
ms_3 : torch.Size([1, 30, 7, 1])
x_4 : torch.Size([1, 30, 15, 1])
ms_4 : torch.Size([1, 30, 7, 1])
x_5 : torch.Size([1, 30, 7, 1])
TGRU::torch.Size([7, 1, 256])
y torch.Size([1, 30, 7, 1]) += r_0 : torch.Size([1, 30, 7, 1])
y_0 : torch.Size([1, 30, 15, 1])
y torch.Size([1, 30, 15, 1]) += r_1 : torch.Size([1, 30, 15, 1])
y_1 : torch.Size([1, 30, 31, 1])
y torch.Size([1, 30, 31, 1]) += r_2 : torch.Size([1, 30, 31, 1])
y_2 : torch.Size([1, 30, 63, 1])
y torch.Size([1, 30, 63, 1]) += r_3 : torch.Size([1, 30, 63, 1])
y_3 : torch.Size([1, 30, 128, 1])
y torch.Size([1, 30, 128, 1]) += r_4 : torch.Size([1, 30, 128, 1])
y_4 : torch.Size([1, 30, 257, 1])
y torch.Size([1, 30, 257, 1]) += r_5 : torch.Size([1, 30, 257, 1])
y_5 : torch.Size([1, 30, 257, 1])
-------------

In [37]:
m = ResUNetOnFreq2(c_in = 1, print_shape=True, activation = "Sigmoid",bottleneck = "FTGRU")
m.eval()

# to check low latency
x = torch.rand(1,1,257,1)
m(x)
print("--------------------")

x = torch.rand(1,1,257,10)
print(x.shape)
y = m(x)
print(y.shape)

x_0 : torch.Size([1, 30, 257, 1])
x_1 : torch.Size([1, 30, 128, 1])
x_2 : torch.Size([1, 30, 63, 1])
x_3 : torch.Size([1, 30, 31, 1])
x_4 : torch.Size([1, 30, 15, 1])
x_5 : torch.Size([1, 30, 7, 1])
TGRU::torch.Size([7, 1, 256])
y torch.Size([1, 30, 7, 1]) += r_0 : torch.Size([1, 30, 7, 1])
y_0 : torch.Size([1, 30, 15, 1])
y torch.Size([1, 30, 15, 1]) += r_1 : torch.Size([1, 30, 15, 1])
y_1 : torch.Size([1, 30, 31, 1])
y torch.Size([1, 30, 31, 1]) += r_2 : torch.Size([1, 30, 31, 1])
y_2 : torch.Size([1, 30, 63, 1])
y torch.Size([1, 30, 63, 1]) += r_3 : torch.Size([1, 30, 63, 1])
y_3 : torch.Size([1, 30, 128, 1])
y torch.Size([1, 30, 128, 1]) += r_4 : torch.Size([1, 30, 128, 1])
y_4 : torch.Size([1, 30, 257, 1])
y torch.Size([1, 30, 257, 1]) += r_5 : torch.Size([1, 30, 257, 1])
y_5 : torch.Size([1, 30, 257, 1])
--------------------
torch.Size([1, 1, 257, 10])
x_0 : torch.Size([1, 30, 257, 10])
x_1 : torch.Size([1, 30, 128, 10])
x_2 : torch.Size([1, 30, 63, 10])
x_3 : torch.Size([1, 30, 

In [32]:
m = ResUNetOnFreq2(c_in = 1, print_shape=True, activation = "Sigmoid",bottleneck = "LSTM")
m.eval()

# to check low latency
x = torch.rand(1,1,257,1)
m(x)
print("--------------------")

x = torch.rand(1,1,257,10)
print(x.shape)
y = m(x)
print(y.shape)

x_0 : torch.Size([1, 30, 257, 1])
x_1 : torch.Size([1, 30, 128, 1])
x_2 : torch.Size([1, 30, 63, 1])
x_3 : torch.Size([1, 30, 31, 1])
x_4 : torch.Size([1, 30, 15, 1])
x_5 : torch.Size([1, 30, 7, 1])
y torch.Size([1, 30, 7, 1]) += r_0 : torch.Size([1, 30, 7, 1])
y_0 : torch.Size([1, 30, 15, 1])
y torch.Size([1, 30, 15, 1]) += r_1 : torch.Size([1, 30, 15, 1])
y_1 : torch.Size([1, 30, 31, 1])
y torch.Size([1, 30, 31, 1]) += r_2 : torch.Size([1, 30, 31, 1])
y_2 : torch.Size([1, 30, 63, 1])
y torch.Size([1, 30, 63, 1]) += r_3 : torch.Size([1, 30, 63, 1])
y_3 : torch.Size([1, 30, 128, 1])
y torch.Size([1, 30, 128, 1]) += r_4 : torch.Size([1, 30, 128, 1])
y_4 : torch.Size([1, 30, 257, 1])
y torch.Size([1, 30, 257, 1]) += r_5 : torch.Size([1, 30, 257, 1])
y_5 : torch.Size([1, 30, 257, 1])
--------------------
torch.Size([1, 1, 257, 10])
x_0 : torch.Size([1, 30, 257, 10])
x_1 : torch.Size([1, 30, 128, 10])
x_2 : torch.Size([1, 30, 63, 10])
x_3 : torch.Size([1, 30, 31, 10])
x_4 : torch.Size([1, 