In [3]:
import sys
sys.path.append("../")

import torch
import torch.nn as nn

from UNet_m import ResBlock,Encoder,Decoder


In [111]:
class ResUNetOnFreq3(nn.Module) :
    """Residual U-Net over the frequency axis with a multi-scale global-feature gate.

    Design notes: matched channels; multi-scale global feature.

    Data layout is [Batch, Channel, Freq, Time]. Every encoder level is kept as a
    skip connection (after a kernel-size-1 "res path" Encoder). All encoder levels
    except the deepest are reduced by strided Conv2d layers to the deepest level's
    frequency height and summed into one multi-scale feature ``ms``. Decoder stage
    ``i`` adds ``skip_i * sigmoid(upsample_i(ms))`` to its input, so the multi-scale
    feature gates each skip connection. A 3-layer LSTM runs over time on the
    flattened deepest feature map.

    For n_fft=512 and n_block=5 the encoder heights are 257, 128, 63, 31, 15, 7.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        n_fft: FFT size; the input frequency axis must have int(n_fft/2+1) bins,
            because it is normalised by ``nn.LayerNorm(int(n_fft/2+1))``.
        device: currently unused.
        print_shape: if True, ``forward`` prints intermediate tensor shapes.
        n_block: number of stride-(2,1) encoder blocks; must be >= 2.
        activation: output mask activation. "Softplus" uses ``Softplus_thr`` as
            threshold, "Sigmoid" uses a sigmoid, anything else falls back to a
            default ``nn.Softplus()``. Encoder/decoder blocks always use "PReLU".
        Softplus_thr: ``threshold`` of the Softplus mask activation.
        norm: normalisation name forwarded to the Encoder/Decoder blocks.
        dropout: dropout between the stacked LSTM bottleneck layers.

    Raises:
        Exception: if ``n_block < 2``.

    NOTE(review): several sizes are hard-coded for n_fft=512 and n_block=5: the
    LSTM width (f_dim * 7), the six ``upscale`` factors and the decoder kernels.
    Other values are likely to fail in ``forward``. TODO: confirm, or derive these
    from n_fft and n_block.
    """
    def __init__(self, 
                 c_in = 1,
                 c_out = 1,
                 n_fft=512,
                 device="cuda:0",
                 print_shape=False,
                 n_block = 5,
                 activation = "Softplus" , 
                 Softplus_thr = 20,
                 norm = "BatchNorm2d",
                 dropout = 0.0
                 ):
        super().__init__()

        n_hfft = int(n_fft/2+1)

        self.print_shape=print_shape

        self.F = n_hfft
        f_dim = 30

        # Nearest-upsampling factors that bring the 7-bin multi-scale feature back
        # to the heights 7, 15, 31, 63, 128, 257 (one per decoder stage).
        upscale = [1,2.2, 4.5, 9, 18.4, 36.8]

        if n_block < 2 :
            raise Exception("ERROR::ResUnetOnFreq : n_block({}) < 2".format(n_block))

        ## Model Implementation

        # Input layer: works on [B, C, T, F] (see forward); LayerNorm over freq.
        self.layer_input = nn.Sequential(
            Encoder(c_in,f_dim,(1,3),1,(0,1),1),
            nn.LayerNorm(n_hfft) 
        )

        # Encoder: a ResBlock at full resolution, then n_block blocks with
        # stride (2,1), which downsample the frequency axis.
        encoders=[]
        encoders.append(ResBlock(f_dim))
        for i in range(n_block) :
            encoders.append(nn.Sequential(
                    Encoder(f_dim,f_dim,(3,1),
                    (2,1),(0,0),activation="PReLU",norm=norm),
                    ResBlock(f_dim)))

        # Plain lists are not tracked by nn.Module, so register each entry
        # explicitly; the "enc_{i}" names define the state_dict keys.
        self.encoders=encoders
        for i,enc in enumerate(self.encoders) : 
            self.add_module("enc_{}".format(i),enc)
        
        # Multi-scale encoder: level i is shrunk to the deepest height by a
        # strided conv with kernel 2**(n_block+1-i) and stride 2**(n_block-i).
        ms = []
        for i in range(n_block):
            ms.append(nn.Sequential(
                nn.Conv2d(f_dim,f_dim,(2**(n_block+1-i),1),stride=(2**(n_block-i),1),padding=(1,0))
            ))
        self.ms = ms
        for i,i_ms in enumerate(self.ms) : 
            self.add_module("ms_{}".format(i),i_ms)
            
        # Attention maps: upsample the multi-scale feature along freq, then sigmoid.
        self.upsample = []
        for scale in upscale : 
            self.upsample.append(nn.Sequential(
                            nn.Upsample(scale_factor=(scale,1), mode='nearest'),
                            nn.Sigmoid()
            ))
        for i,i_up in enumerate(self.upsample) : 
            self.add_module("up_{}".format(i),i_up)
            
        # Decoder: for n_fft=512, n_block=5 the heights go 7 -> 15 -> 31 -> 63
        # -> 128 -> 257. The single (5,1) kernel handles the 63 -> 128 step.
        # The last stage is a plain ResBlock at full resolution.
        decoders=[]
        for i in range(n_block-2) :
            decoders.append(nn.Sequential(
                ResBlock(f_dim),
                Decoder(f_dim,f_dim,(4,1),
                (2,1),(1,0),output_padding=(1,0),activation="PReLU",norm=norm)))
        decoders.append(nn.Sequential(
                ResBlock(f_dim),
                Decoder(f_dim,f_dim,(5,1),
                (2,1),(1,0),output_padding=(1,0),activation="PReLU",norm=norm)))
        decoders.append(nn.Sequential(
                ResBlock(f_dim),
                Decoder(f_dim,f_dim,(4,1),
                (2,1),(1,0),output_padding=(1,0),activation="PReLU",norm=norm)))
        decoders.append(ResBlock(f_dim))

        self.decoders=decoders
        for i,dec in enumerate(self.decoders) : 
            self.add_module("dec_{}".format(i),dec)

        self.len_model = len(encoders)

        # Residual path: one kernel-size-1 Encoder per encoder level (skip connections).
        res_paths = []
        res_paths.append(Encoder(f_dim,f_dim,1,1,0,1,activation="PReLU"))
        for i in range(n_block) : 
            res_paths.append(Encoder(f_dim,f_dim,1,1,0,1,activation="PReLU"))

        self.res_paths = res_paths
        for i,res_path in enumerate(self.res_paths) : 
            self.add_module("res_path_{}".format(i),res_path)

        # Bottleneck: LSTM over time on the flattened [C * F'] deepest feature.
        # f_dim * 7 (= 210) assumes F' == 7, i.e. n_fft=512 and n_block=5.
        bottleneck_dim = f_dim * 7
        self.bottleneck = nn.LSTM(bottleneck_dim,300,3,batch_first=True,proj_size=bottleneck_dim,dropout=dropout)

        # Output layer: (3,1) kernel, stride 1, padding (1,0) keeps the [F, T] size.
        self.out_layer = nn.ConvTranspose2d(f_dim,c_out,(3,1),stride=1,padding=(1,0),dilation=1,output_padding=(0,0))

        if activation == "Softplus" : 
            self.activation_mask = nn.Softplus(threshold=Softplus_thr)
        elif activation == "Sigmoid" : 
            self.activation_mask = nn.Sigmoid()
        else : 
            self.activation_mask = nn.Softplus()

    def forward(self,input):
        """Compute the output mask for ``input``.

        Args:
            input: tensor of shape [Batch, Channel, Freq, Time], with Freq equal
                to ``self.F``.

        Returns:
            ``activation_mask(out_layer(...))``. For n_fft=512 and n_block=5 this
            has shape [Batch, c_out, Freq, Time].
        """
        # [B, C, F, T] -> [B, C, T, F] so that LayerNorm(n_hfft) acts on freq.
        feature = torch.permute(input,(0,1,3,2))
        feature = self.layer_input(feature)

        # back to [B, C, F, T]
        x = torch.permute(feature,(0,1,3,2))

        ## Encoder (every level output is kept for the skip connections)
        res=[]
        
        ## Multi-Scale Feature: sum of all levels except the deepest, each
        ## reduced to the deepest height by its strided conv.
        ms = None
        for i,enc in enumerate(self.encoders):
            x = enc(x)
            if self.print_shape : 
                print("x_{} : {}".format(i,x.shape))
            res.append(x)
            
            if i < len(self.ms) :
                if ms is None : 
                    ms = self.ms[i](x)
                else : 
                    ms += self.ms[i](x)
                if self.print_shape : 
                    print("ms_{} : {}".format(i,ms.shape))
        
        ## bottleneck
        # [B,C,F',T] -> [B,C*F',T]
        d0,d1,d2,d3 = x.shape
        x = torch.reshape(x,(x.shape[0],x.shape[1]*x.shape[2],x.shape[3]))
        # [B,C*F',T] -> [B,T,C*F'] (batch_first LSTM runs over time)
        x = torch.permute(x,(0,2,1))

        x = self.bottleneck(x)[0]

        # [B,T,C*F'] -> [B,C*F',T]
        x = torch.permute(x,(0,2,1))
        # [B,C*F',T] -> [B,C,F',T]
        x = torch.reshape(x,(d0,d1,d2,d3))
        
        ## ResPath
        for i,res_path in enumerate(self.res_paths) : 
            res[i] = res_path(res[i])

        ## Decoder: skip connections are consumed deepest-first, each gated by
        ## the multi-scale feature upsampled to that level's height.
        y = x
        for i,dec in enumerate(self.decoders) : 
            if self.print_shape : 
                print("y : {} += r_{}*att_{} : {}".format(y.shape,i,i,res[-1-i].shape))
            
            up = self.upsample[i]
            att = up(ms)
            
            y  = torch.add(y,res[-1-i]*att)
            y = dec(y)
            if self.print_shape : 
                print("-> y_{} : {}".format(i,y.shape))

        ## output
        
        output = self.out_layer(y)
        return self.activation_mask(output)
    

# Smoke test: the output must keep the input shape, both for a single time
# frame (the low-latency, frame-by-frame case) and for a 10-frame chunk.
# no_grad avoids building an autograd graph for these inference-only checks.
m = ResUNetOnFreq3(c_in = 1, print_shape=True, activation = "Sigmoid")
m.eval()

with torch.no_grad():
    # to check low latency: one frame in, one frame out
    x = torch.rand(1,1,257,1)
    y = m(x)
    assert y.shape == x.shape, y.shape
    print("--------------------")

    x = torch.rand(1,1,257,10)
    print(x.shape)
    y = m(x)
    print(y.shape)
    assert y.shape == x.shape, y.shape

--------------------
torch.Size([1, 1, 257, 10])
torch.Size([1, 1, 257, 10])


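## Upsample

Encoder output heights printed by `ResUNetOnFreq3(print_shape=True)` for a single-frame input (n_fft=512, n_block=5). The upsampled multi-scale feature must match each of these heights: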

```
x_0 : torch.Size([1, 30, 257, 1])
x_1 : torch.Size([1, 30, 128, 1])
x_2 : torch.Size([1, 30, 63, 1])
x_3 : torch.Size([1, 30, 31, 1])
x_4 : torch.Size([1, 30, 15, 1])
x_5 : torch.Size([1, 30, 7, 1])
```

In [91]:
# The multi-scale feature is 7 bins high. Nearest upsampling gives
# floor(7 * scale) bins, so each factor used by ResUNetOnFreq3 must land
# exactly on the corresponding encoder height.
x = torch.rand(1,30,7,10)
print(x.shape)

for scale, expected_height in [(2.2, 15), (4.5, 31), (9, 63), (18.4, 128), (36.8, 257)]:
    m = nn.Upsample(scale_factor=(scale,1), mode='nearest')
    y = m(x)
    print(y.shape)
    assert y.shape[2] == expected_height, (scale, y.shape)

torch.Size([1, 30, 7, 10])
torch.Size([1, 30, 15, 10])
torch.Size([1, 30, 31, 10])
torch.Size([1, 30, 63, 10])
torch.Size([1, 30, 128, 10])
torch.Size([1, 30, 257, 10])


## Multi-scale encoder

Strided convolutions that reduce an encoder level to the 7-bin height of the deepest feature map.

In [52]:
# Probe strided convs that shrink an encoder level to the 7-bin bottleneck height.
# NOTE: with the default n_block=5, the model's multi-scale layers also use
# padding=(1,0), and a (4,1) kernel at the 15-bin level. These probes use
# unpadded variants.
for height, kernel, stride in [(257, 64, 32), (15, 3, 2)]:
    x = torch.rand(1,30,height,1)
    m = torch.nn.Conv2d(30,30,(kernel,1),stride=(stride,1))
    y = m(x)
    print(x.shape)
    print(y.shape)
    assert y.shape[2] == 7, (height, kernel, stride, y.shape)

torch.Size([1, 30, 257, 1])
torch.Size([1, 30, 7, 1])
torch.Size([1, 30, 15, 1])
torch.Size([1, 30, 7, 1])
