<a href="https://colab.research.google.com/github/mehdihosseinimoghadam/Complex-Neural-Networks/blob/main/Deep_Complex_U_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from torch import nn
import torch

In [42]:
class CLinear(nn.Module):
  """Complex-valued linear layer built from two real ``nn.Linear`` layers.

  The input stores the real part at index 0 and the imaginary part at index 1
  of its last axis.  The output uses the same layout:

      real = re_linear(x_re) - im_linear(x_im)
      imag = re_linear(x_im) + im_linear(x_re)

  Args:
    in_channels: number of input features of each real-valued layer.
    out_channels: number of output features of each real-valued layer.
    **kwargs: forwarded unchanged to both ``nn.Linear`` constructors.
  """

  def __init__(self, in_channels, out_channels, **kwargs):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    # Real layer is created before the imaginary one; both weights are then
    # re-initialised with Xavier-uniform in that same order.
    self.re_linear = nn.Linear(in_channels, out_channels, **kwargs)
    self.im_linear = nn.Linear(in_channels, out_channels, **kwargs)
    for layer in (self.re_linear, self.im_linear):
      nn.init.xavier_uniform_(layer.weight)

  def forward(self, x):
    """Apply the complex linear map; returns real/imag stacked on a new last axis."""
    real_in, imag_in = x[..., 0], x[..., 1]

    real_out = self.re_linear(real_in) - self.im_linear(imag_in)
    imag_out = self.re_linear(imag_in) + self.im_linear(real_in)

    return torch.stack((real_out, imag_out), dim=-1)

In [23]:
class CConv2d(nn.Module):
  """Complex-valued 2-D convolution built from two real ``nn.Conv2d`` layers.

  The last axis of the input holds the real part (index 0) and the imaginary
  part (index 1).  The two parts are combined as

      real = re_conv(x_re) - im_conv(x_im)
      imag = re_conv(x_im) + im_conv(x_re)

  and returned stacked on a new last axis.

  Args:
    in_channels: input channels of each real-valued convolution.
    out_channels: output channels of each real-valued convolution.
    **kwargs: forwarded unchanged to both ``nn.Conv2d`` constructors
      (e.g. ``kernel_size``, ``stride``, ``padding``).
  """

  def __init__(self, in_channels, out_channels, **kwargs):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    # Real layer first, imaginary second; weights are then overwritten with
    # Xavier-uniform values in the same order.
    self.re_conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    self.im_conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    for conv in (self.re_conv, self.im_conv):
      nn.init.xavier_uniform_(conv.weight)

  def forward(self, x):
    """Convolve the complex input; returns real/imag stacked on the last axis."""
    real_in, imag_in = x[..., 0], x[..., 1]

    real_out = self.re_conv(real_in) - self.im_conv(imag_in)
    imag_out = self.re_conv(imag_in) + self.im_conv(real_in)

    return torch.stack((real_out, imag_out), dim=-1)


In [24]:
class CConvTrans2d(nn.Module):
  """Complex-valued transposed 2-D convolution.

  Two real ``nn.ConvTranspose2d`` layers act on the real part (last-axis
  index 0) and the imaginary part (last-axis index 1) of the input:

      real = re_Tconv(x_re) - im_Tconv(x_im)
      imag = re_Tconv(x_im) + im_Tconv(x_re)

  The two results are stacked on a new last axis.

  Args:
    in_channels: input channels of each real-valued layer.
    out_channels: output channels of each real-valued layer.
    **kwargs: forwarded unchanged to both ``nn.ConvTranspose2d`` constructors.
  """

  def __init__(self, in_channels, out_channels, **kwargs):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    # Construction and Xavier-uniform initialisation both go real -> imaginary.
    self.re_Tconv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
    self.im_Tconv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
    for tconv in (self.re_Tconv, self.im_Tconv):
      nn.init.xavier_uniform_(tconv.weight)

  def forward(self, x):
    """Apply the complex transposed convolution to ``x``."""
    real_in, imag_in = x[..., 0], x[..., 1]

    real_out = self.re_Tconv(real_in) - self.im_Tconv(imag_in)
    imag_out = self.re_Tconv(imag_in) + self.im_Tconv(real_in)

    return torch.stack((real_out, imag_out), dim=-1)

In [25]:
class CBatchnorm(nn.Module):
    """Batch normalisation for tensors holding real/imaginary parts in the last axis.

    The real part (last-axis index 0) and the imaginary part (last-axis
    index 1) are normalised independently, each by its own ``nn.BatchNorm2d``
    layer, and stacked back together on the last axis.

    Args:
        in_channels: number of feature channels passed to both
            ``nn.BatchNorm2d`` layers.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.re_batch = nn.BatchNorm2d(in_channels)
        self.im_batch = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        """Normalise both parts of ``x`` and return them stacked on the last axis."""
        x_re = x[..., 0]
        x_im = x[..., 1]

        out_re = self.re_batch(x_re)
        # The imaginary part goes through its own layer.  Routing it through
        # ``re_batch`` (as before) left ``im_batch`` unused and updated the
        # real-part running statistics twice per forward pass.
        out_im = self.im_batch(x_im)

        return torch.stack([out_re, out_im], -1)



In [26]:
class CconvBlock(nn.Module):
  """Encoder building block: complex convolution -> batch norm -> LeakyReLU.

  Args:
    in_channels: input channels of the complex convolution.
    out_channels: output channels of the complex convolution; also the
      channel count given to the batch-norm layer.
    **kwargs: forwarded unchanged to ``CConv2d`` (and from there to
      ``nn.Conv2d``), e.g. ``kernel_size``, ``stride``, ``padding``.
  """

  def __init__(self, in_channels, out_channels, **kwargs):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    self.CConv2d = CConv2d(self.in_channels, self.out_channels, **kwargs)
    self.CBatchnorm = CBatchnorm(self.out_channels)
    self.leaky_relu = nn.LeakyReLU()

  def forward(self, x):
    """Run ``x`` through convolution, normalisation and activation."""
    # The debug ``print(1)/print(2)/print(3)`` calls were removed: they wrote
    # to stdout on every forward pass.
    conved = self.CConv2d(x)
    normed = self.CBatchnorm(conved)
    # LeakyReLU is applied element-wise, i.e. to real and imaginary
    # components separately.
    activated = self.leaky_relu(normed)

    return activated



In [27]:
class CConvTransBlock(nn.Module):
  """Decoder building block based on a complex transposed convolution.

  For intermediate layers the transposed convolution is followed by batch
  norm and LeakyReLU.  For the last layer (``last_layer=True``) the
  convolution output is instead turned into a bounded complex mask: its
  direction (phase) is kept and its magnitude is squashed with ``tanh``.

  Args:
    in_channels: input channels of the complex transposed convolution.
    out_channels: output channels of the complex transposed convolution;
      also the channel count given to the batch-norm layer.
    last_layer: when True, produce the bounded mask instead of
      batch norm + activation.
    **kwargs: forwarded unchanged to ``CConvTrans2d``.
  """

  def __init__(self, in_channels, out_channels, last_layer=False, **kwargs):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.last_layer = last_layer

    self.CConvTrans2d = CConvTrans2d(self.in_channels, self.out_channels, **kwargs)
    self.CBatchnorm = CBatchnorm(self.out_channels)
    self.leaky_relu = nn.LeakyReLU()

  def forward(self, x):
    """Return activated features, or the bounded mask when ``last_layer`` is set."""
    conved = self.CConvTrans2d(x)

    if not self.last_layer:
      normed = self.CBatchnorm(conved)
      return self.leaky_relu(normed)

    # Complex magnitude = Euclidean norm over the (real, imag) last axis.
    # ``torch.abs`` on this real-valued tensor would take the absolute value
    # of each component separately, which reduces the mask to an element-wise
    # tanh instead of a magnitude/phase decomposition.
    magnitude = torch.linalg.norm(conved, dim=-1, keepdim=True)
    m_phase = conved / (magnitude + 1e-8)  # epsilon guards against division by zero
    m_mag = torch.tanh(magnitude)
    return m_phase * m_mag


In [28]:
class Encoder(nn.Module):
  """Stack of five ``CconvBlock`` layers applied one after another.

  ``forward`` returns the output of every block so that later code can use
  the intermediate activations.
  """

  def __init__(self):
    super().__init__()

    # The keyword arguments end up in ``nn.Conv2d``, which accepts
    # ``kernel_size`` and ``stride``; the former ``filter_size`` /
    # ``stride_size`` names raised TypeError at construction time.
    self.CconvBlock0 = CconvBlock(in_channels=1, out_channels=45, kernel_size=(7, 5), stride=(2, 2), padding=(0, 0))
    self.CconvBlock1 = CconvBlock(in_channels=45, out_channels=90, kernel_size=(7, 5), stride=(2, 2), padding=(0, 0))
    self.CconvBlock2 = CconvBlock(in_channels=90, out_channels=90, kernel_size=(5, 3), stride=(2, 2), padding=(0, 0))
    self.CconvBlock3 = CconvBlock(in_channels=90, out_channels=90, kernel_size=(5, 3), stride=(2, 2), padding=(0, 0))
    self.CconvBlock4 = CconvBlock(in_channels=90, out_channels=90, kernel_size=(5, 3), stride=(2, 1), padding=(0, 0))

  def forward(self, x):
    """Return ``[ccb0, ccb1, ccb2, ccb3, ccb4]``, the outputs of the five blocks in order."""
    ccb0 = self.CconvBlock0(x)
    ccb1 = self.CconvBlock1(ccb0)
    ccb2 = self.CconvBlock2(ccb1)
    ccb3 = self.CconvBlock3(ccb2)
    ccb4 = self.CconvBlock4(ccb3)

    return [ccb0, ccb1, ccb2, ccb3, ccb4]

In [29]:
class Decoder(nn.Module):
  """Stack of five ``CConvTransBlock`` layers with skip connections.

  Args:
    n_fft: FFT size passed to ``torch.istft``.
    hop_length: hop length passed to ``torch.istft``.
  """

  def __init__(self, n_fft=64, hop_length=16):
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = hop_length

    # The keyword arguments end up in ``nn.ConvTranspose2d``, which accepts
    # ``kernel_size`` and ``stride``; the former ``filter_size`` /
    # ``stride_size`` names raised TypeError at construction time.
    self.CConvTransBlock0 = CConvTransBlock(in_channels=90, out_channels=90, kernel_size=(5, 3), stride=(2, 1),
                                            output_padding=(0, 0), padding=(0, 0))
    self.CConvTransBlock1 = CConvTransBlock(in_channels=180, out_channels=90, kernel_size=(5, 3), stride=(2, 2),
                                            output_padding=(0, 0), padding=(0, 0))
    self.CConvTransBlock2 = CConvTransBlock(in_channels=180, out_channels=90, kernel_size=(5, 3), stride=(2, 2),
                                            output_padding=(0, 0), padding=(0, 0))
    self.CConvTransBlock3 = CConvTransBlock(in_channels=180, out_channels=45, kernel_size=(7, 5), stride=(2, 2),
                                            output_padding=(0, 0), padding=(0, 0))
    self.CConvTransBlock4 = CConvTransBlock(in_channels=90, out_channels=1, kernel_size=(7, 5), stride=(2, 2),
                                            output_padding=(0, 1), padding=(0, 0), last_layer=True)

  def forward(self, x0, x, is_istft=True):
    """Decode the encoder activations and apply the result to ``x0``.

    Args:
      x0: tensor multiplied element-wise with the output of the last block.
      x: sequence of at least five tensors, read from the end
        (``x[-1]`` first, ``x[-5]`` last) and used as skip connections.
      is_istft: when True, the product is passed through ``torch.istft``.

    Returns:
      The result of ``torch.istft`` when ``is_istft`` is True, otherwise the
      element-wise product itself.
    """
    cctb0 = self.CConvTransBlock0(x[-1])
    # Skip connection: concatenate along the channel axis (dim=1).
    c0 = torch.cat((cctb0, x[-2]), dim=1)

    cctb1 = self.CConvTransBlock1(c0)
    c1 = torch.cat((cctb1, x[-3]), dim=1)

    cctb2 = self.CConvTransBlock2(c1)
    c2 = torch.cat((cctb2, x[-4]), dim=1)

    cctb3 = self.CConvTransBlock3(c2)
    c3 = torch.cat((cctb3, x[-5]), dim=1)

    cctb4 = self.CConvTransBlock4(c3)

    # NOTE(review): this is an element-wise product of the (real, imag)
    # components, not a complex multiplication -- confirm it is the intended
    # way to apply the mask.
    output = cctb4 * x0
    if is_istft:
      output = torch.squeeze(output, 1)
      # torch.istft expects a complex tensor; reinterpret the trailing
      # (real, imag) axis as complex values.  view_as_complex needs the last
      # axis to have stride 1, hence the contiguous() call.
      output = torch.view_as_complex(output.contiguous())
      output = torch.istft(output, n_fft=self.n_fft, hop_length=self.hop_length, normalized=True)

    return output


In [30]:
class Model(nn.Module):
  """Encoder followed by decoder.

  Args:
    encoder_out: stored on the instance as ``self.encoder_out``; not used by
      the encoder or decoder constructors.
    n_fft: FFT size handed to the decoder.
    hop_length: hop length handed to the decoder.
  """

  def __init__(self, encoder_out, n_fft=64, hop_length=16):
    super().__init__()
    self.encoder_out = encoder_out
    self.n_fft = n_fft
    self.hop_length = hop_length

    self.Encoder = Encoder()
    # Decoder.__init__ only takes (n_fft, hop_length).  Passing encoder_out
    # as an extra positional argument raised TypeError, so the two values are
    # now passed by keyword.
    self.Decoder = Decoder(n_fft=self.n_fft, hop_length=self.hop_length)

  def forward(self, x):
    """Encode ``x`` and decode the activations together with ``x`` itself."""
    encoded = self.Encoder(x)
    decoded = self.Decoder(x, encoded)
    return decoded

In [31]:
# Complex convolution used by the test cells below: 1 -> 2 channels,
# 2x2 kernel, stride 1, no padding.
CConv2d1 = CConv2d(
    in_channels=1,
    out_channels=2,
    kernel_size=(2, 2),
    stride=(1, 1),
    padding=(0, 0),
)

## Complex Linear Layer Test

In [44]:
# Two random length-5 vectors serve as real and imaginary parts; they are
# stacked on a new last axis and given a leading axis of size 1.
x0 = torch.randn(5)
x1 = torch.randn(5)
x = torch.stack((x0, x1), dim=-1).unsqueeze(0)
x.shape

torch.Size([1, 5, 2])

In [46]:
# Complex linear layer with 5 input features and 20 output features.
CLinear1 = CLinear(5, 20)
# The result of this call is not stored or printed.
CLinear1(x)
# Print the input shape, then the layer output, then the output shape.
# Each print of CLinear1(x) runs a separate forward pass.
print(x.shape)
print(CLinear1(x))
print(CLinear1(x).shape)

torch.Size([1, 5, 2])
tensor([[[ 2.3109, -1.0268],
         [ 0.1623, -0.9172],
         [ 0.3190, -1.2287],
         [-0.2733,  1.0967],
         [ 1.5082,  0.7720],
         [-0.0378, -0.7488],
         [-0.2157, -0.3021],
         [ 0.5175,  0.4255],
         [ 1.1525,  0.2468],
         [-0.6438, -0.1749],
         [ 0.3438,  0.8782],
         [ 0.8914,  0.0172],
         [ 0.7914,  0.7718],
         [ 1.2767,  1.2431],
         [ 0.4427,  0.1786],
         [ 1.4706, -0.0756],
         [-0.5609,  1.6882],
         [ 1.0263,  1.0759],
         [ 0.3197, -1.4567],
         [ 1.4837, -0.2569]]], grad_fn=<StackBackward0>)
torch.Size([1, 20, 2])


## Complex Convolution Tests

In [32]:
# Two random 5x5 maps serve as real and imaginary parts; they are stacked on
# a new last axis and given two leading axes of size 1.
x0 = torch.randn(5, 5)
x1 = torch.randn(5, 5)
x = torch.stack((x0, x1), dim=-1)[None, None]

In [33]:
# Print the input shape, then the complex convolution output, then the
# output shape.  The layer is called once per print.
print(x.shape)
print(CConv2d1(x))
print(CConv2d1(x).shape)

torch.Size([1, 1, 5, 5, 2])
tensor([[[[[-0.2341,  0.0716],
           [ 0.3865, -0.8192],
           [-1.3823, -0.6631],
           [-1.8096, -0.9856]],

          [[-0.8186, -3.8495],
           [-2.0101, -0.4146],
           [-0.1466, -0.1537],
           [ 1.0804, -0.3110]],

          [[ 0.0705, -0.0787],
           [-2.3859, -2.2655],
           [-1.0929,  1.6646],
           [ 1.0593, -1.4782]],

          [[ 0.1098, -0.9718],
           [ 0.3863, -0.8386],
           [-0.7495,  0.3499],
           [-0.5392, -2.6627]]],


         [[[ 0.7597,  1.4941],
           [ 0.1508,  0.6988],
           [-1.4712, -0.9276],
           [ 0.5946,  0.7313]],

          [[ 0.8548,  0.8320],
           [ 1.4564, -0.3194],
           [-0.7587,  0.4565],
           [ 1.1803,  1.2750]],

          [[ 0.2698,  1.3152],
           [-0.1428,  1.3002],
           [ 0.6749,  2.2751],
           [ 0.3063,  1.2067]],

          [[-1.8984,  1.0157],
           [ 0.3635,  1.3501],
           [-1.0483, -0.42

In [34]:
# Complex transposed convolution for the test cell below: 1 -> 2 channels,
# 2x2 kernel, stride 1, no padding.
CConvTrans2d1 = CConvTrans2d(
    in_channels=1,
    out_channels=2,
    kernel_size=(2, 2),
    stride=(1, 1),
    padding=(0, 0),
)

In [35]:
# Print the input shape, then the complex transposed-convolution output,
# then the output shape.  The layer is called once per print.
print(x.shape)
print(CConvTrans2d1(x))
print(CConvTrans2d1(x).shape)

torch.Size([1, 1, 5, 5, 2])
tensor([[[[[-0.5543,  0.4390],
           [-0.2612,  0.8152],
           [-0.0586,  0.7387],
           [ 0.0968, -0.7288],
           [ 0.5347,  0.5047],
           [-0.4359,  0.4578]],

          [[ 0.7379,  1.4173],
           [ 0.0898,  0.0964],
           [-0.5463,  0.0163],
           [-2.6312,  0.7264],
           [-0.3534, -0.4425],
           [-0.3940,  0.4020]],

          [[-1.2603, -0.8353],
           [ 0.1504, -0.2374],
           [ 0.1259, -0.6706],
           [ 1.4613,  1.6072],
           [ 0.4606,  1.5097],
           [-0.7338,  0.9787]],

          [[ 0.2797,  2.1281],
           [-2.3418,  1.0650],
           [-1.1543,  2.3451],
           [-0.3229,  0.2800],
           [-0.5932, -1.0071],
           [-0.5940,  0.3839]],

          [[ 0.0170, -0.6591],
           [-0.5761,  0.3649],
           [ 0.4268, -0.0229],
           [-1.4475, -0.4828],
           [ 0.4071, -0.4063],
           [-0.2501,  0.7546]],

          [[-0.7438,  0.0076],
 

In [36]:
# Batch-norm layer for a single-channel input, used by the test cell below.
CBatchnorm1 = CBatchnorm(1)

In [37]:
# Print the input shape, then the batch-norm output, then the output shape.
# The layer is called once per print, so this cell runs two forward passes.
print(x.shape)
print(CBatchnorm1(x))
print(CBatchnorm1(x).shape)

torch.Size([1, 1, 5, 5, 2])
tensor([[[[[ 0.4625, -0.4345],
           [-0.2492, -0.6675],
           [-0.3995, -0.0650],
           [ 0.6636,  1.9140],
           [-0.1620, -0.1056]],

          [[-1.6141, -0.9015],
           [-0.4035,  0.9692],
           [ 0.5049,  0.1958],
           [ 2.5199, -0.9564],
           [ 0.0864,  0.3517]],

          [[ 0.8564, -0.1743],
           [-1.1273,  1.5183],
           [ 1.2328,  2.4845],
           [-0.2767, -1.3543],
           [-1.7000, -0.6963]],

          [[-1.3393, -1.9326],
           [ 0.5609, -0.7481],
           [-0.2811, -0.8967],
           [ 0.2470, -0.1657],
           [-0.1225,  0.6096]],

          [[-0.4179, -0.0791],
           [ 0.3312, -0.5425],
           [-0.8905,  0.2389],
           [ 2.1253,  0.4539],
           [-0.6071,  0.9843]]]]], grad_fn=<StackBackward0>)
torch.Size([1, 1, 5, 5, 2])


In [38]:
# Convolution block for the test cell below: 1 -> 2 channels, 2x2 kernel,
# stride 1, no padding.
CconvBlock1 = CconvBlock(
    in_channels=1,
    out_channels=2,
    kernel_size=(2, 2),
    stride=(1, 1),
    padding=(0, 0),
)

In [39]:
# Print the input shape, then the block output, then the output shape.
# The block is called once per print, so this cell runs two forward passes.
print(x.shape)
print(CconvBlock1(x))
print(CconvBlock1(x).shape)

torch.Size([1, 1, 5, 5, 2])
1
2
3
tensor([[[[[-3.6366e-03,  1.0668e-01],
           [-2.5801e-03, -7.2432e-04],
           [ 6.9928e-01,  3.2533e-01],
           [ 7.5711e-01, -2.8911e-03]],

          [[-1.4797e-02,  1.4519e+00],
           [ 1.5331e+00,  1.5626e+00],
           [ 1.8055e+00, -1.5193e-02],
           [-2.8347e-03, -1.4290e-02]],

          [[-7.2677e-03,  4.3105e-01],
           [ 4.1220e-01,  2.2348e+00],
           [ 1.2165e+00, -3.1249e-03],
           [-1.6484e-02, -6.0245e-03]],

          [[-1.1996e-02, -5.0755e-03],
           [-6.7233e-03, -9.4710e-03],
           [ 2.6352e-01, -2.6660e-03],
           [-5.5278e-04, -1.6631e-03]]],


         [[[ 1.6595e+00,  3.7760e-02],
           [-9.9751e-05, -1.1516e-02],
           [-9.8102e-03, -6.9797e-03],
           [-8.7162e-03,  1.2551e+00]],

          [[-1.1766e-02,  6.6014e-01],
           [ 1.7939e+00, -4.7465e-03],
           [-5.6363e-04, -3.5260e-03],
           [ 7.3796e-01,  1.1304e+00]],

          [[-6.4

In [40]:
# Transposed-convolution block for the test cell below: 1 -> 2 channels,
# 2x2 kernel, stride 1, no padding; last_layer keeps its default (False).
CConvTransBlock1 = CConvTransBlock(
    in_channels=1,
    out_channels=2,
    kernel_size=(2, 2),
    stride=(1, 1),
    padding=(0, 0),
)

In [41]:
# Print the input shape, then the block output, then the output shape.
# The block is called once per print, so this cell runs two forward passes.
print(x.shape)
print(CConvTransBlock1(x))
print(CConvTransBlock1(x).shape)

torch.Size([1, 1, 5, 5, 2])
tensor([[[[[ 4.8799e-01, -2.4996e-03],
           [ 8.6699e-02, -8.1122e-03],
           [-8.9990e-03, -5.3460e-03],
           [-9.6118e-03,  1.1314e+00],
           [ 8.0377e-01,  8.0137e-01],
           [-1.3261e-03,  3.8154e-02]],

          [[-6.0247e-03, -1.1592e-02],
           [-2.4552e-02,  1.8082e-01],
           [-1.1305e-03,  1.0884e+00],
           [ 1.9745e+00,  4.3998e-01],
           [ 2.3054e+00, -1.2850e-02],
           [ 2.3022e-01,  4.3867e-01]],

          [[ 1.0779e+00, -3.6623e-04],
           [-1.1136e-02,  4.0976e-01],
           [-5.4729e-03,  2.9596e+00],
           [ 1.6568e+00,  6.1198e-01],
           [-2.0432e-02, -2.5005e-02],
           [-1.3835e-02,  2.2321e-01]],

          [[-1.1945e-03, -1.5280e-02],
           [-1.1892e-02, -1.4337e-02],
           [ 4.4088e-01, -9.1974e-03],
           [ 6.6649e-01, -6.9065e-03],
           [-4.4885e-03,  2.0695e-01],
           [-5.4372e-04,  1.2002e+00]],

          [[ 2.8109e-01, -2.