<a href="https://colab.research.google.com/github/mehdihosseinimoghadam/Signal-Processing/blob/main/Deep_Complex_U_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from torch import nn
import torch

In [None]:
class CConv2d(nn.Module):
  """Complex-valued 2-D convolution built from two real ``nn.Conv2d`` layers.

  The input stores real and imaginary parts in its last dimension
  (``x[..., 0]`` and ``x[..., 1]``). With complex kernel ``W = A + iB`` and
  input ``z = x_re + i*x_im`` the output is (biases aside)
  ``(A*x_re - B*x_im) + i*(A*x_im + B*x_re)``.

  Args:
    in_channels: input channels of each (real / imaginary) part.
    out_channels: output channels of each part.
    *args, **kwargs: forwarded unchanged to both ``nn.Conv2d`` layers
      (e.g. ``kernel_size``, ``stride``, ``padding``).
  """
  def __init__(self, in_channels, out_channels, *args, **kwargs):
    super(CConv2d, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    # Real (A) and imaginary (B) parts of the complex kernel.
    self.re_conv = nn.Conv2d(self.in_channels, self.out_channels, *args, **kwargs)
    self.im_conv = nn.Conv2d(self.in_channels, self.out_channels, *args, **kwargs)

    nn.init.xavier_uniform_(self.re_conv.weight)
    nn.init.xavier_uniform_(self.im_conv.weight)

  def forward(self, x):
    """Apply the complex convolution.

    Args:
      x: tensor whose last dimension has size 2 (real, imag); each part must
        be a valid ``nn.Conv2d`` input.

    Returns:
      Tensor with the same trailing (real, imag) layout as ``x``.
    """
    x_re = x[..., 0]
    x_im = x[..., 1]

    out_re = self.re_conv(x_re) - self.im_conv(x_im)
    out_im = self.re_conv(x_im) + self.im_conv(x_re)

    # stack (not cat) so the result keeps a separate trailing real/imag axis
    # that the next complex layer can index with [..., 0] / [..., 1].
    return torch.stack([out_re, out_im], dim=-1)


In [4]:
class CConvTrans2d(nn.Module):
  """Complex-valued 2-D transposed convolution built from two real
  ``nn.ConvTranspose2d`` layers.

  The input stores real and imaginary parts in its last dimension
  (``x[..., 0]`` and ``x[..., 1]``). With complex kernel ``W = A + iB`` the
  output is (biases aside) ``(A*x_re - B*x_im) + i*(A*x_im + B*x_re)``.

  Args:
    in_channels: input channels of each (real / imaginary) part.
    out_channels: output channels of each part.
    *args, **kwargs: forwarded unchanged to both ``nn.ConvTranspose2d`` layers
      (e.g. ``kernel_size``, ``stride``, ``padding``, ``output_padding``).
  """
  def __init__(self, in_channels, out_channels, *args, **kwargs):
    super(CConvTrans2d, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    # Real (A) and imaginary (B) parts of the complex kernel.
    self.re_Tconv = nn.ConvTranspose2d(self.in_channels, self.out_channels, *args, **kwargs)
    self.im_Tconv = nn.ConvTranspose2d(self.in_channels, self.out_channels, *args, **kwargs)

    nn.init.xavier_uniform_(self.re_Tconv.weight)
    nn.init.xavier_uniform_(self.im_Tconv.weight)

  def forward(self, x):
    """Apply the complex transposed convolution.

    Args:
      x: tensor whose last dimension has size 2 (real, imag); each part must
        be a valid ``nn.ConvTranspose2d`` input.

    Returns:
      Tensor with the same trailing (real, imag) layout as ``x``.
    """
    x_re = x[..., 0]
    x_im = x[..., 1]

    out_re = self.re_Tconv(x_re) - self.im_Tconv(x_im)
    out_im = self.re_Tconv(x_im) + self.im_Tconv(x_re)

    # stack (not cat) keeps the trailing real/imag axis intact.
    return torch.stack([out_re, out_im], dim=-1)

In [6]:
class CBatchnorm(nn.Module):
    """Batch normalization applied independently to real and imaginary parts.

    Args:
        in_channels: channel count of each part (passed to ``nn.BatchNorm2d``).
        out_channels: optional; accepted for signature compatibility and
            ignored (batch norm does not change the channel count).
        **kwargs: accepted for signature compatibility and ignored.
    """
    def __init__(self, in_channels, out_channels=None, **kwargs):
        super(CBatchnorm, self).__init__()
        self.in_channels = in_channels

        self.re_batch = nn.BatchNorm2d(in_channels)
        self.im_batch = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        """Normalize ``x[..., 0]`` with ``re_batch`` and ``x[..., 1]`` with ``im_batch``.

        Args:
            x: tensor whose last dimension has size 2 (real, imag); each part
                must be a valid ``nn.BatchNorm2d`` input.

        Returns:
            Tensor with the same trailing (real, imag) layout as ``x``.
        """
        x_re = x[..., 0]
        x_im = x[..., 1]

        out_re = self.re_batch(x_re)
        # Use the imaginary-part layer so each part keeps its own statistics.
        out_im = self.im_batch(x_im)

        # stack (not cat) keeps the trailing real/imag axis intact.
        return torch.stack([out_re, out_im], dim=-1)



In [7]:
class CconvBlock(nn.Module):
  """Encoder stage: complex convolution -> complex batch norm -> LeakyReLU.

  Args:
    in_channels: input channels of each (real / imaginary) part.
    out_channels: output channels of each part.
    kernel: convolution kernel size.
    stride: convolution stride.
    padding: convolution padding.
  """
  def __init__(self, in_channels, out_channels, kernel, stride, padding):
    super(CconvBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel = kernel
    self.stride = stride
    self.padding = padding

    # Pass conv hyper-parameters by keyword so they map onto nn.Conv2d's
    # kernel_size / stride / padding regardless of how CConv2d forwards them.
    self.CConv2d = CConv2d(self.in_channels, self.out_channels,
                           kernel_size=self.kernel, stride=self.stride,
                           padding=self.padding)
    self.CBatchnorm = CBatchnorm(self.out_channels, self.out_channels)
    self.leaky_relu = nn.LeakyReLU()

  def forward(self, x):
    """Run conv, norm and activation; LeakyReLU acts element-wise on both parts."""
    conved = self.CConv2d(x)
    normed = self.CBatchnorm(conved)
    activated = self.leaky_relu(normed)

    return activated



In [8]:
class CConvTransBlock(nn.Module):
  """Decoder stage: complex transposed convolution, then either
  batch norm + LeakyReLU or, for the last layer, a bounded complex mask.

  Args:
    in_channels: input channels of each (real / imaginary) part.
    out_channels: output channels of each part.
    kernel: transposed-convolution kernel size.
    stride: transposed-convolution stride.
    padding: transposed-convolution padding.
    last_layer: if True, skip norm/activation and emit a mask whose complex
      magnitude is ``tanh(|o|)`` and whose phase equals that of the conv output.
    output_padding: extra size added to one side of the output
      (forwarded to ``nn.ConvTranspose2d``); defaults to 0.
  """
  def __init__(self, in_channels, out_channels, kernel, stride, padding, last_layer=False, output_padding=0):
    super(CConvTransBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel = kernel
    self.stride = stride
    self.padding = padding
    self.output_padding = output_padding
    self.last_layer = last_layer

    self.CConvTrans2d = CConvTrans2d(self.in_channels, self.out_channels,
                                     kernel_size=self.kernel, stride=self.stride,
                                     padding=self.padding,
                                     output_padding=self.output_padding)
    self.CBatchnorm = CBatchnorm(self.out_channels, self.out_channels)
    self.leaky_relu = nn.LeakyReLU()

  def forward(self, x):
    """Apply the stage to ``x`` (trailing dim of size 2 = real, imag)."""
    conved = self.CConvTrans2d(x)

    if not self.last_layer:
      normed = self.CBatchnorm(conved)
      activated = self.leaky_relu(normed)
      return activated

    # Complex magnitude sqrt(re^2 + im^2) over the (re, im) axis; element-wise
    # abs would normalize each component separately and destroy the phase.
    mag = torch.linalg.vector_norm(conved, dim=-1, keepdim=True)
    m_phase = conved / (mag + 1e-8)
    m_mag = torch.tanh(mag)
    out = m_phase * m_mag
    return out


In [11]:
class Encoder(nn.Module):
  """Five-stage complex convolutional encoder.

  ``forward`` returns every stage's output so the decoder can use them as
  skip connections.
  """
  def __init__(self):
    super(Encoder, self).__init__()

    # CconvBlock's parameters are named kernel / stride (not filter_size /
    # stride_size, which it does not accept).
    self.CconvBlock0 = CconvBlock(kernel=(7,5), stride=(2,2), in_channels=1, out_channels=45, padding=(0,0))
    self.CconvBlock1 = CconvBlock(kernel=(7,5), stride=(2,2), in_channels=45, out_channels=90, padding=(0,0))
    self.CconvBlock2 = CconvBlock(kernel=(5,3), stride=(2,2), in_channels=90, out_channels=90, padding=(0,0))
    self.CconvBlock3 = CconvBlock(kernel=(5,3), stride=(2,2), in_channels=90, out_channels=90, padding=(0,0))
    self.CconvBlock4 = CconvBlock(kernel=(5,3), stride=(2,1), in_channels=90, out_channels=90, padding=(0,0))

  def forward(self, x):
    """Return the list ``[ccb0, ccb1, ccb2, ccb3, ccb4]`` of stage outputs."""
    ccb0 = self.CconvBlock0(x)
    ccb1 = self.CconvBlock1(ccb0)
    ccb2 = self.CconvBlock2(ccb1)
    ccb3 = self.CconvBlock3(ccb2)
    ccb4 = self.CconvBlock4(ccb3)

    return [ccb0, ccb1, ccb2, ccb3, ccb4]

In [12]:
class Decoder(nn.Module):
  """Five-stage complex transposed-conv decoder with skip connections.

  The last stage produces a complex mask that is multiplied (as complex
  numbers) with the input spectrogram; optionally the result is inverted
  with ``torch.istft``.

  Args:
    n_fft: FFT size passed to ``torch.istft``.
    hop_length: hop length passed to ``torch.istft``.
  """
  def __init__(self, n_fft=64, hop_length=16):
    super(Decoder, self).__init__()
    self.n_fft = n_fft
    self.hop_length = hop_length

    # Stages 1-4 take doubled in_channels because each consumes the previous
    # stage's output concatenated (dim=1) with the matching encoder output.
    self.CConvTransBlock0 = CConvTransBlock(kernel=(5,3), stride=(2,1), in_channels=90, out_channels=90, output_padding=(0,0), padding=(0,0))
    self.CConvTransBlock1 = CConvTransBlock(kernel=(5,3), stride=(2,2), in_channels=180, out_channels=90, output_padding=(0,0), padding=(0,0))
    self.CConvTransBlock2 = CConvTransBlock(kernel=(5,3), stride=(2,2), in_channels=180, out_channels=90, output_padding=(0,0), padding=(0,0))
    self.CConvTransBlock3 = CConvTransBlock(kernel=(7,5), stride=(2,2), in_channels=180, out_channels=45, output_padding=(0,0), padding=(0,0))
    self.CConvTransBlock4 = CConvTransBlock(kernel=(7,5), stride=(2,2), in_channels=90, output_padding=(0,1), padding=(0,0),
                                            out_channels=1, last_layer=True)

  def forward(self, x0, x, is_istft=True):
    """Decode encoder features into a masked spectrogram (or waveform).

    Args:
      x0: input spectrogram with (real, imag) in its last dimension; the mask
        is applied to it. Shape presumably (batch, 1, freq, time, 2) — TODO
        confirm against the data pipeline.
      x: list of encoder outputs ``[ccb0, ..., ccb4]`` as returned by
        ``Encoder.forward``.
      is_istft: if True, squeeze dim 1 and invert with ``torch.istft``.

    Returns:
      The masked spectrogram (trailing real/imag axis) or, if ``is_istft``,
      the ``torch.istft`` result.
    """
    cctb0 = self.CConvTransBlock0(x[-1])
    # skip-connection
    c0 = torch.cat((cctb0, x[-2]), dim=1)

    cctb1 = self.CConvTransBlock1(c0)
    c1 = torch.cat((cctb1, x[-3]), dim=1)

    cctb2 = self.CConvTransBlock2(c1)
    c2 = torch.cat((cctb2, x[-4]), dim=1)

    cctb3 = self.CConvTransBlock3(c2)
    c3 = torch.cat((cctb3, x[-5]), dim=1)

    mask = self.CConvTransBlock4(c3)

    # Complex product mask * x0; a plain element-wise product of the
    # (re, im) pairs would not be complex multiplication.
    m_re, m_im = mask[..., 0], mask[..., 1]
    s_re, s_im = x0[..., 0], x0[..., 1]
    output = torch.stack([m_re * s_re - m_im * s_im,
                          m_re * s_im + m_im * s_re], dim=-1)

    if is_istft:
      output = torch.squeeze(output, 1)
      # torch.istft requires a complex tensor (real (..., 2) input is no
      # longer supported), so reinterpret the trailing axis as complex.
      output = torch.view_as_complex(output.contiguous())
      output = torch.istft(output, n_fft=self.n_fft, hop_length=self.hop_length, normalized=True)

    return output


In [None]:
class Model(nn.Module):
  """Complex U-Net: ``Encoder`` followed by ``Decoder`` with skip connections.

  Args:
    encoder_out: stored as ``self.encoder_out``; not otherwise used here.
    n_fft: FFT size forwarded to the decoder's ``torch.istft`` call.
    hop_length: hop length forwarded to the decoder's ``torch.istft`` call.
  """
  def __init__(self, encoder_out, n_fft=64, hop_length=16):
    super(Model, self).__init__()
    self.encoder_out = encoder_out
    self.n_fft = n_fft
    self.hop_length = hop_length

    self.Encoder = Encoder()
    # Decoder's signature is (n_fft, hop_length); pass by keyword so
    # encoder_out is not forwarded as an extra positional argument.
    self.Decoder = Decoder(n_fft=self.n_fft, hop_length=self.hop_length)

  def forward(self, x):
    """Encode ``x`` and decode it, reusing ``x`` as the spectrogram to mask."""
    encoded = self.Encoder(x)
    decoded = self.Decoder(x, encoded)
    return decoded