## PyTorch `ComplexTensor` examples (compared against NumPy complex arrays)

In [1]:
import os, sys
import numpy as np
import torch
import torch.nn as nn
sys.path.append('../pytorch-complex-tensor')
from pytorch_complex_tensor import ComplexTensor

---   
### Creation

In [2]:
# NumPy reference: a 2x3 complex64 array, row 0 = 1+3j, row 1 = 2+4j
np_c = np.array([[1 + 3j] * 3, [2 + 4j] * 3], dtype=np.complex64)
np_c

array([[1.+3.j, 1.+3.j, 1.+3.j],
       [2.+4.j, 2.+4.j, 2.+4.j]], dtype=complex64)

In [3]:
# torch equivalent: ComplexTensor takes the real rows stacked above the imaginary rows,
# so the first two rows [[1,1,1],[2,2,2]] become the real part and the last two
# [[3,3,3],[4,4,4]] the imaginary part (confirmed by the .real / .imag prints in the next cells).
pt_c = ComplexTensor([[1, 1, 1], [2,2,2], [3,3,3], [4,4,4]])
print(pt_c)

tensor([1.+3.j, 1.+3.j, 1.+3.j, 2.+4.j, 2.+4.j, 2.+4.j], dtype=complex64)


In [4]:
# Real parts should agree: NumPy reference first, then the ComplexTensor.
for real_part in (np_c.real, pt_c.real):
    print(real_part)

[[1. 1. 1.]
 [2. 2. 2.]]
tensor([[1., 1., 1.],
        [2., 2., 2.]])


In [5]:
# Imaginary parts should agree: NumPy reference first, then the ComplexTensor.
for imag_part in (np_c.imag, pt_c.imag):
    print(imag_part)

[[3. 3. 3.]
 [4. 4. 4.]]
tensor([[3., 3., 3.],
        [4., 4., 4.]])


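### Caution: operations that `ComplexTensor` does not support silently give wrong results. Below, `log2` runs elementwise on the raw 4×3 storage (real rows stacked above imaginary rows) instead of on the 2×3 complex values.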

In [6]:
# Deliberately wrong result: log2 appears not to be complex-aware here. It is applied
# elementwise to the raw 4x3 storage (real rows stacked above imaginary rows). The output
# is log2(1..4) per row with shape [4, 3], not the complex log of the 2x3 values.
a = pt_c.log2()
print(a)
print(a.shape)

tensor([[0.0000, 0.0000, 0.0000],
        [1.0000, 1.0000, 1.0000],
        [1.5850, 1.5850, 1.5850],
        [2.0000, 2.0000, 2.0000]])
torch.Size([4, 3])


---   
### Verify complex addition

In [7]:
np.add(np_c, 3 + 2j)

array([[4.+5.j, 4.+5.j, 4.+5.j],
       [5.+6.j, 5.+6.j, 5.+6.j]], dtype=complex64)

In [8]:
# Adding a Python complex scalar shifts every element; should match the NumPy result above.
pt_c + (3 + 2j)

tensor([4.+5.j, 4.+5.j, 4.+5.j, 5.+6.j, 5.+6.j, 5.+6.j], dtype=complex64)

--- 
### Verify `abs`

In [9]:
np.absolute(np_c)

array([[3.1622777, 3.1622777, 3.1622777],
       [4.472136 , 4.472136 , 4.472136 ]], dtype=float32)

In [10]:
# Magnitude |a+bj| = sqrt(a^2 + b^2) (e.g. sqrt(10) = 3.1623); the result is a real 2x3 tensor.
pt_c.abs()

tensor([[3.1623, 3.1623, 3.1623],
        [4.4721, 4.4721, 4.4721]])

--- 
### Verify complex-by-real matrix multiply

In [11]:
# Real 3x2 right-hand operand, built once as NumPy (int) and once as a torch float tensor.
x_rows = [[3, 3], [4, 4], [2, 2]]
np_x = np.asarray(x_rows)
pt_x = torch.Tensor(x_rows)

print(np_x)
print(pt_x)

[[3 3]
 [4 4]
 [2 2]]
tensor([[3., 3.],
        [4., 4.],
        [2., 2.]])


In [12]:
# complex 2x3 @ real 3x2 -> complex 2x2
np_mm_out = np_c @ np_x
np_mm_out

array([[ 9.+27.j,  9.+27.j],
       [18.+36.j, 18.+36.j]])

In [13]:
# complex (2x3) @ real (3x2) -> complex 2x2; should match np.matmul above
pt_mm_out = pt_c.mm(pt_x)
pt_mm_out

tensor([ 9.+27.j,  9.+27.j, 18.+36.j, 18.+36.j], dtype=complex64)

In [14]:
# Real parts of the matmul results: NumPy first, then ComplexTensor.
for mm_real in (np_mm_out.real, pt_mm_out.real):
    print(mm_real)

[[ 9.  9.]
 [18. 18.]]
tensor([[ 9.,  9.],
        [18., 18.]])


In [15]:
# Imaginary parts of the matmul results: NumPy first, then ComplexTensor.
for mm_imag in (np_mm_out.imag, pt_mm_out.imag):
    print(mm_imag)

[[27. 27.]
 [36. 36.]]
tensor([[27., 27.],
        [36., 36.]])


--- 
### Verify transpose

In [16]:
np_c.transpose()

array([[1.+3.j, 2.+4.j],
       [1.+3.j, 2.+4.j],
       [1.+3.j, 2.+4.j]], dtype=complex64)

In [17]:
# Transpose of the complex 2x3 matrix; ComplexTensor's repr prints the 3x2 result flattened.
pt_c.t()

tensor([1.+3.j, 2.+4.j, 1.+3.j, 2.+4.j, 1.+3.j, 2.+4.j], dtype=complex64)

--- 
## Autograd through `ComplexTensor` (wfalcon/pytorch-complex-tensor)


In [18]:
# Real rows [[1,3,5],[7,9,11]] stacked above imaginary rows [[2,4,6],[8,10,12]],
# giving the 2x3 complex matrix 1+2j, 3+4j, ... printed below.
pt_c2 = ComplexTensor([[1, 3, 5], [7,9,11], [2,4,6], [8,10,12]])
print(pt_c2)
# Enable gradient tracking so backward() below populates pt_c2.grad.
pt_c2.requires_grad = True

tensor([ 1. +2.j,  3. +4.j,  5. +6.j,  7. +8.j,  9.+10.j, 11.+12.j],
      dtype=complex64)


In [19]:
# (pt_c2 + 4) @ pt_c2.t(): complex 2x3 @ 3x2 -> 2x2. The t() here is a plain
# (non-conjugating) transpose; e.g. entry [0, 0] = 15+136j.
out = pt_c2 + 4
out = out.mm(pt_c2.t())
print(out)

tensor([15.+136.j, 69.+334.j, -3.+262.j, 51.+676.j], dtype=complex64)


In [20]:
# Reduce real and imaginary parts separately; real_sum is the scalar backpropagated below.
real_sum, out_imag = out.real.sum(), out.imag.sum()
print(real_sum)
print(out_imag)

tensor(132., grad_fn=<SumBackward0>)
tensor(1408., grad_fn=<SumBackward0>)


In [21]:
# Backpropagate the sum of the real parts into pt_c2.grad.
real_sum.backward()

In [22]:
# Inspect the gradient. NOTE(review): the printed repr shows string entries such as
# '(24.0-20.0j)' rather than numbers, so treat this display with caution.
pt_c2.grad

tensor([['(24.0-20.0j)' '(32.0-28.0j)' '(40.0-36.0j)'],
        ['(24.0-20.0j)' '(32.0-28.0j)' '(40.0-36.0j)']])

### Complex convolution
The `ComplexConv` module below is taken from https://github.com/litcoderr/ComplexCNN

In [23]:
class ComplexConv(nn.Module):
    """2-D complex-valued convolution built from two real ``nn.Conv2d`` layers.

    The complex kernel is ``W = W_re + i * W_im`` (``conv_re`` / ``conv_im``), and the
    product with ``x = x_re + i * x_im`` is expanded as::

        real = W_re * x_re - W_im * x_im
        imag = W_re * x_im + W_im * x_re

    Input and output carry an explicit real/imaginary axis at dim 1:
    ``[batch, 2, channel, axis1, axis2]`` with index 0 = real part, 1 = imaginary part.

    Note: with ``bias=True`` both layers add their own bias, so the effective complex
    bias is ``(b_re - b_im) + i * (b_re + b_im)``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        # NOTE: kept for backward compatibility only; forward() never uses it and the
        # module is not moved to this device automatically.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.padding = padding

        ## Model components: real and imaginary parts of the complex kernel,
        ## configured identically so their outputs can be combined elementwise.
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.conv_re = nn.Conv2d(in_channel, out_channel, kernel_size, **conv_kwargs)
        self.conv_im = nn.Conv2d(in_channel, out_channel, kernel_size, **conv_kwargs)

    def forward(self, x):
        """Apply the complex convolution.

        Args:
            x: real tensor of shape [batch, 2, in_channel, axis1, axis2]; x[:, 0] is the
               real part and x[:, 1] the imaginary part.

        Returns:
            Tensor of shape [batch, 2, out_channel, out_axis1, out_axis2] using the same
            real/imaginary convention at dim 1.

        Raises:
            ValueError: if ``x`` is not 5-D or its dim 1 does not have size 2. Without
                this check a 4-D input would make ``x[:, 0]`` 3-D, which Conv2d may
                silently treat as an unbatched [C, H, W] image.
        """
        if x.dim() != 5 or x.size(1) != 2:
            raise ValueError(
                "ComplexConv expects input of shape [batch, 2, channel, axis1, axis2], "
                f"got {tuple(x.shape)}"
            )
        x_re, x_im = x[:, 0], x[:, 1]
        real = self.conv_re(x_re) - self.conv_im(x_im)
        imaginary = self.conv_re(x_im) + self.conv_im(x_re)
        # Re-insert the real/imaginary axis at dim 1.
        return torch.stack((real, imaginary), dim=1)

In [79]:
## Random tensor for input
## shape : [batchsize, 2, channel, axis1_size, axis2_size]  (dim 1: 0 = real, 1 = imaginary)
## The sizes below are arbitrary.
torch.manual_seed(0)  # make the random input and the conv weight init reproducible
bz = 16
x = torch.randn((bz, 2, 3, 100, 100))

# 1. Make ComplexConv object
## (in_channel, out_channel, kernel_size) parameters are required
complexConv = ComplexConv(3, 10, (5, 5))

# 2. Compute: a 5x5 kernel without padding shrinks each spatial axis by 4 (100 -> 96)
y = complexConv(x)
print(y.shape)

torch.Size([16, 2, 10, 96, 96])


In [80]:
# Convert the [batch, 2, channel, axis1, axis2] tensor into a NumPy complex array.
# Index dim 1 directly instead of using np.squeeze: squeeze would also drop the batch or
# channel axis whenever bz == 1 or channel == 1, silently changing the shape.
x_np = x.detach().numpy()
real = x_np[:, 0]
imag = x_np[:, 1]
z = real + 1j * imag
print(real.shape)
print(z.shape)

(16, 3, 100, 100)
(16, 3, 100, 100)


In [81]:
# Build a ComplexTensor from the NumPy complex array; .shape reports the complex
# shape [16, 3, 100, 100].
C = ComplexTensor(z)
C.shape

torch.Size([16, 3, 100, 100])

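### Example: initializing a `ComplexTensor` from PyTorch tensors
The real part is stacked above the imaginary part along dim -2, and the plain tensor is then re-tagged as `ComplexTensor`.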

In [114]:
# ComplexTensor storage layout: real part stacked above imaginary part along dim -2,
# so complex [batch, channel, axis1, axis2] -> storage [batch, channel, 2*axis1, axis2].
# x[:, 0] / x[:, 1] already remove the real/imag axis. Avoid .squeeze(): it would also
# drop batch/channel axes of size 1.
D = torch.cat([x[:, 0], x[:, 1]], dim=-2)
print(D.shape)
D.__class__ = ComplexTensor  # re-tag the plain tensor in place (no copy)
assert torch.all(D.real == C.real), "real parts differ from the NumPy-built ComplexTensor"
assert torch.all(D.imag == C.imag), "imaginary parts differ from the NumPy-built ComplexTensor"

torch.Size([16, 3, 200, 100])


In [107]:
class ComplexConv_w_complexTensor(nn.Module):
    """Complex 2-D convolution (same math as ``ComplexConv``) that consumes and
    produces ``ComplexTensor`` instead of an explicit [batch, 2, ...] real/imag axis.

    ``forward`` reads ``x.real`` and ``x.imag``. It packs its result the way the
    ComplexTensors in this notebook are laid out: real and imaginary outputs are
    concatenated along dim -2, and the result is re-tagged as ``ComplexTensor``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        # NOTE: kept for backward compatibility only; forward() never uses it and the
        # module is not moved to this device automatically.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.padding = padding

        ## Model components: real and imaginary parts of the complex kernel.
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.conv_re = nn.Conv2d(in_channel, out_channel, kernel_size, **conv_kwargs)
        self.conv_im = nn.Conv2d(in_channel, out_channel, kernel_size, **conv_kwargs)

    def forward(self, x):
        """Apply the complex convolution to a ComplexTensor.

        Args:
            x: ComplexTensor with complex shape [batch, channel, axis1, axis2]
               (e.g. ``C`` above, whose storage is [batch, channel, 2*axis1, axis2]).
               Unlike ``ComplexConv`` there is no explicit real/imag axis at dim 1.

        Returns:
            ComplexTensor whose storage is the real output stacked above the imaginary
            output along dim -2.
        """
        real = self.conv_re(x.real) - self.conv_im(x.imag)
        imaginary = self.conv_re(x.imag) + self.conv_im(x.real)
        result = torch.cat([real, imaginary], dim=-2)
        # Re-tag in place so .real / .imag split the stacked storage again.
        result.__class__ = ComplexTensor
        return result
    

In [108]:
complexConv_w_CT = ComplexConv_w_complexTensor(3, 10, (5, 5))
# Reuse the very same conv layers, so both implementations share weights and can be compared.
for layer_name in ("conv_re", "conv_im"):
    setattr(complexConv_w_CT, layer_name, getattr(complexConv, layer_name))

In [109]:
# Compare the ComplexTensor implementation against ComplexConv's output y,
# whose shape is [batch, 2, out_channel, H, W]. Index dim 1 directly; .squeeze()
# would also drop batch/channel axes of size 1.
y2 = complexConv_w_CT(C)
print(y2.shape)
real_err = y2.real - y[:, 0]
imag_err = y2.imag - y[:, 1]
print(torch.max(real_err.abs()))
print(torch.max(imag_err.abs()))
# Both implementations share the same conv layers, so the results should agree.
assert torch.max(real_err.abs()) <= 1e-5, "real parts of the two implementations differ"
assert torch.max(imag_err.abs()) <= 1e-5, "imaginary parts of the two implementations differ"

torch.Size([16, 10, 96, 96])
tensor(0., grad_fn=<MaxBackward1>)
tensor(0., grad_fn=<MaxBackward1>)
