In [52]:
import torch
import torch.nn as nn
import numpy as np
from pytorch_complex_tensor import ComplexTensor
from torch.nn import Conv2d

#### The goal is to build a wrapper around torch.nn layers such that 

**MetaComplexMM('conv2d',kwargs)**

creates an subclass of nn.Module that has 2 conv2d functions

    f_re = conv2d(**kwargs)

    f_im = conv2d(**kwargs)
    
The initialization and the forward operation should be as below. 

In [53]:
class ComplexConv2d(nn.Module):
    '''
    This class is meant to be a generic wrapper around torch.nn
    layers such as Dense, Conv, ConvTran,...etc. 

    If you keep the bias in f_re and f_im, then the 
    bias becomes (b_r - b_i) + i(b_r + b_i) wheraas the bias 
    should really be (b_r + ib_i). That's why we need to do 
    "bias" trick below. 
    '''
    def __init__(self, *args, **kwargs):
        super(ComplexConv2d,self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bias = kwargs.get('bias',False)
        # turn the bias off so as to only do matrix multiplication 
        kwargs['bias'] = False
        self.f_re = f(*args,**kwargs)
        self.f_im = f(*args,**kwargs)
        self.b = None
        if self.bias:
            n = args[1] # output dim
            b_r = np.random.randn(n,1).astype('float32')
            b_i = np.random.randn(n,1).astype('float32')
            z = b_r + 1j*b_i
            self.b = ComplexTensor(z)

    def forward(self, x): 
        real = self.f_re(x.real) - self.f_im(x.imag)
        imaginary = self.f_re(x.imag) + self.f_im(x.real)
        result = torch.cat([real, imaginary], dim=-2)
        result.__class__ = ComplexTensor
        if self.b:
            result = result + self.b
        return result

I believe that this should work for the other layers too, e.g. Dense, Conv, ConvTran. But not maxpool, batchNorm, etc. I would like to build a meta class that builds complex layers as follows:

ComplexConv2d = MetaComplexMM('conv2d',kwargs)

ComplexDense = MetaComplexMM('Dense',kwargs)

.
.
.

but I don't know how. 

I'm trying some approaches from here

https://realpython.com/python-metaclasses/

https://stackoverflow.com/questions/681953/how-to-decorate-a-class

In [54]:
class MetaComplexMM(type):
    
    def __new__(meta, name, bases, dct):
        x = super(MetaComplexMM, meta).__new__(meta, name, bases, dct)
        return x
    
    def __init__(cls, name, bases, dct):
        super(MetaComplexMM, cls).__init__(name, bases, dct)
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.bias = dct.get('bias',False)
        # turn the bias off so as to only do matrix multiplication 
        dct['bias'] = False
        cls.f_re = bases[0](**dct)
        cls.f_im = bases[0](**dct)
        cls.b = None
        if cls.bias:
            n = args[1] # output dim
            b_r = np.random.randn(n,1).astype('float32')
            b_i = np.random.randn(n,1).astype('float32')
            z = b_r + 1j*b_i
            cls.b = ComplexTensor(z)    

    def forward(cls, x): 
        real = cls.f_re(x.real) - cls.f_im(x.imag)
        imaginary = cls.f_re(x.imag) + cls.f_im(x.real)
        result = torch.cat([real, imaginary], dim=-2)
        result.__class__ = ComplexTensor
        if cls.b:
            result = result + cls.b
        return result

In [55]:
dct = {'in_channels':3, 'out_channels':10, 'kernel_size':5}
complexConv = MetaComplexMM('ComplexConv2d',(Conv2d,), dct)

In [56]:
complexConv.__name__

'ComplexConv2d'

In [57]:
complexConv.in_channels

3

In [60]:
bz = 16
x = torch.randn((bz,2,3,100,100))
x_np = x.detach().numpy()
real = np.squeeze(x_np[:,0,:,:])
imag = np.squeeze(x_np[:,1,:,:])
z = real + 1j*imag
z = ComplexTensor(z)

In [61]:
out = complexConv(z)

TypeError: __init__() missing 2 required positional arguments: 'out_channels' and 'kernel_size'

In [62]:
out = complexConv.forward(z)

TypeError: forward() missing 1 required positional argument: 'input'

In [63]:
complexConv

__main__.ComplexConv2d

In [68]:
complexConv.f_re

Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1), bias=False)

In [71]:
out = complexConv.f_re(z.real)

In [72]:
out.shape

torch.Size([16, 10, 96, 96])