#### This notebook shows how to build complex-valued versions of PyTorch's Linear, Conv and ConvTranspose layers

In [163]:
import torch
import torch.nn as nn
import numpy as np
from pytorch_complex_tensor import ComplexTensor
from torch.nn import Conv2d, Linear, ConvTranspose2d

#### Generic complex wrapper

In [340]:
class complexLayer(nn.Module):
    """Wrap a real-valued PyTorch layer type as the equivalent complex layer.

    So far it works for Linear, Conv and ConvTranspose layers.

    Two bias-free copies of ``Layer`` hold the real (W_re) and imaginary
    (W_im) weights. For an input x = x_re + i*x_im the output is

        real = W_re(x_re) - W_im(x_im) + b_re
        imag = W_re(x_im) + W_im(x_re) + b_im

    The inner layers must not have their own bias. Otherwise the formula
    above would add and subtract it in the wrong places. The complex bias is
    therefore kept separately, as the trainable parameters ``b_re`` / ``b_im``.

    Args:
        Layer: layer class to wrap, e.g. ``torch.nn.Conv2d``.
        kwargs: constructor keyword arguments for ``Layer``.
            - Must contain 'out_channels' or 'out_features'.
            - 'bias' defaults to False when absent (PyTorch's own layers
              default to True).
            - The dict is copied, not modified.

    Raises:
        ValueError: if ``kwargs`` has neither 'out_channels' nor
            'out_features'.

    TODO:   1. Code it for RNN layers.
    """

    def __init__(self, Layer, kwargs):
        super().__init__()
        # Kept for backward compatibility only; nothing below uses it.
        # Move the module with .to(device) instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        kwargs = dict(kwargs)  # never mutate the caller's dict
        self.bias = kwargs.get('bias', False)

        out_keys = {'out_channels', 'out_features'}.intersection(kwargs)
        if not out_keys:
            raise ValueError(
                "kwargs must contain 'out_channels' or 'out_features'")
        self.outType = sorted(out_keys)[0]
        self.out_dim = kwargs[self.outType]

        # Turn the inner bias off so the layers do a pure (bias-free) linear
        # map; see the class docstring for why.
        kwargs['bias'] = False
        self.f_re = Layer(**kwargs)
        self.f_im = Layer(**kwargs)

        if self.bias:
            # Same N(0, 1) numpy draws (real first, then imaginary) as before.
            # They are registered as parameters so they are trained, saved in
            # the state_dict and moved by .to()/.cuda().
            b_r = np.random.randn(self.out_dim).astype('float32')
            b_i = np.random.randn(self.out_dim).astype('float32')
            self.b_re = nn.Parameter(torch.from_numpy(b_r))
            self.b_im = nn.Parameter(torch.from_numpy(b_i))
        else:
            self.register_parameter('b_re', None)
            self.register_parameter('b_im', None)

    @property
    def b(self):
        """Return the complex bias as a ComplexTensor, or None without bias.

        This is a detached, read-only snapshot built from an (out_dim, 1)
        complex array, kept for backward compatibility. The trainable values
        live in ``b_re`` and ``b_im``.
        """
        if not self.bias:
            return None
        b_r = self.b_re.detach().cpu().numpy().reshape(-1, 1)
        b_i = self.b_im.detach().cpu().numpy().reshape(-1, 1)
        return ComplexTensor(b_r + 1j * b_i)

    def forward(self, x):
        """Apply the complex layer to ``x``, which exposes ``.real`` and ``.imag``.

        Returns a ComplexTensor whose real and imaginary parts are
        concatenated along dim -2. This assumes that is ComplexTensor's
        storage layout; confirm against pytorch_complex_tensor.
        """
        real = self.f_re(x.real) - self.f_im(x.imag)
        imaginary = self.f_re(x.imag) + self.f_im(x.real)
        if self.bias:
            if self.outType == 'out_channels':
                # Shape (1, C, 1, ..., 1) broadcasts over the batch and every
                # spatial dim, so 1d/2d/3d conv layers all work.
                shape = (1, self.out_dim) + (1,) * (real.dim() - 2)
            else:
                # Linear: the features are the last dim.
                shape = (self.out_dim,)
            real = real + self.b_re.view(shape)
            imaginary = imaginary + self.b_im.view(shape)
        result = torch.cat([real, imaginary], dim=-2)
        result.__class__ = ComplexTensor
        return result

#### Testing on Conv2d

In [350]:
bz = 16
bias = True  # toggle to test the layers with and without a complex bias
# Axis 1 holds the real (index 0) and imaginary (index 1) parts.
x = torch.randn((bz, 2, 3, 100, 100))
x_np = x.detach().numpy()
# Select the real/imag slice by indexing instead of using np.squeeze,
# which would also drop the batch axis when bz == 1.
real = x_np[:, 0]
imag = x_np[:, 1]
z = real + 1j * imag
z = ComplexTensor(z)

In [351]:
# Complex 3 -> 10 channel conv with a 5x5 kernel; the bias follows the toggle above.
dct = dict(in_channels=3, out_channels=10, kernel_size=5, bias=bias)
compConv2D = complexLayer(Conv2d, dct)

In [352]:
out = compConv2D(z)
out.shape  # spatial size shrinks 100 -> 96 (5x5 kernel, no padding)

torch.Size([16, 10, 96, 96])

#### Testing on ConvTranspose2d

In [353]:
# Expected output side: (100 - 1) * 2 - 2 * 2 + 5 = 199.
dct = dict(
    in_channels=3,
    out_channels=10,
    kernel_size=5,
    padding=(2, 2),
    stride=2,
    bias=bias,
)
compConvTran2D = complexLayer(ConvTranspose2d, dct)
out = compConvTran2D(z)
out.shape

torch.Size([16, 10, 199, 199])

#### Testing on Linear 

In [355]:
# Complex dense layer mapping 3 input features to 10 output features.
dct = dict(in_features=3, out_features=10, bias=bias)
denseLayer = complexLayer(Linear, dct)

In [356]:
bz = 16
# Axis 1 holds the real (index 0) and imaginary (index 1) parts.
x = torch.randn((bz, 2, 1, 3))
x_np = x.detach().numpy()
# Drop the singleton axis explicitly so the result is (bz, 3).
# np.squeeze would also drop the batch axis when bz == 1.
real = x_np[:, 0, 0, :]
imag = x_np[:, 1, 0, :]
z = real + 1j * imag
z = ComplexTensor(z)

In [357]:
out = denseLayer(z)
out.shape  # (batch, out_features)

torch.Size([16, 10])