In [1]:
import torch
from torch import nn
from torch.nn import Linear
from torch.optim import SGD

In [2]:
class ComplexLinearLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(ComplexLinearLayer, self).__init__()
        self.W_r = Linear(in_features, out_features,bias=True) 
        self.W_i = Linear(in_features, out_features,bias=True)
        
    def forward(self,in_r, in_i): # in = x 
        Real_OUT = self.W_r(in_r) - self.W_i(in_i)
        Imag_OUT = self.W_r(in_i) + self.W_i(in_r)
        return Real_OUT, Imag_OUT

In [3]:
def complex_sign_activation(in_r, in_i): 
    return torch.sign(in_r), torch.sign(in_i)

In [4]:
class ComplexNetworkPaper(nn.Module):
    def __init__(self,in_features,hidden_features):
        super(ComplexNetworkPaper,self).__init__()
        self.input_linear_Layer = ComplexLinearLayer(in_features,hidden_features)
        
    def forward(self,in_r,in_i):
        hidden_r,hidden_i = self.input_linear_Layer(in_r,in_i)
        output_r,output_i = complex_sign_activation(hidden_r,hidden_i)
        return output_r,output_i

In [5]:
def custom_loss_with_tanh(input_r, input_i, target_r, target_i, H_r,H_i, bias_r,bias_i, lam=0.01):
    '''
    bias_r: z for the real part
    bias_i: z for the imaginary part
    H_r: weight matrix( the channel matrix) (real)
    H_i: weight matrix( the channel matrix) (imaginary)
    input_r: real part of the input
    input_i: imaginary part of the input
    target_r: real part of the target
    target_i: imaginary part of the target
    '''
    # Compute H * x_n + z for r and i parts
    temp_r = torch.matmul(input_r, H_r.t()) + bias_r
    temp_i = torch.matmul(input_i, H_i.t()) + bias_i
    
    # (tanh(Hx_n + z))
    tanh_r = torch.tanh(temp_r)
    tanh_i = torch.tanh(temp_i)
    
    # first_term = 1/N sum [tanh(H * x_n + z)(REAL) + tanh(H * x_n + z)(IMAG)]
    first_term = torch.mean(torch.norm(target_r - tanh_r, dim=1)**2 + torch.norm(target_i - tanh_i, dim=1)**2)
    

    regul_term = lam * (torch.norm(H_r, p=2)+torch.norm(H_i, p=2))
    loss = first_term + regul_term
    return loss


In [6]:
in_r = torch.randn(10, 5)  # 10 is the batch size, so the forwad function will handle these 10 samples at once, and so the same weights are being used for all of these 10 samples 
in_i = torch.randn(10, 5)
target_r = torch.randn(10, 5) 
target_i = torch.randn(10, 5)
model = ComplexNetworkPaper(in_features=in_r.shape[1], hidden_features=target_r.shape[1])
optimizer = SGD(model.parameters(), lr=0.001)

In [7]:
epochs = 100
for epoch in range(epochs):
    model.train() # eval vs train 
    
    # Forward pass
    output_r, output_i = model(in_r, in_i)
    
    # LOss
    loss = custom_loss_with_tanh(input_r=in_r, input_i=in_i, target_r=target_r, target_i=target_i, H_r=model.input_linear_Layer.W_r.weight,H_i = model.input_linear_Layer.W_i.weight, bias_r=model.input_linear_Layer.W_r.bias,bias_i=model.input_linear_Layer.W_i.bias)
    
    # Backpropagation and optimization
    optimizer.zero_grad()  # reset the gradients
    loss.backward()  # Backpropgation to calculation the new gradients
    optimizer.step()  # Update the weights in the direction of these new gradients

    print(f'Epoch {epoch}, Loss: {loss.item()}')


Epoch 0, Loss: 10.04552936553955
Epoch 1, Loss: 10.027202606201172
Epoch 2, Loss: 10.008940696716309
Epoch 3, Loss: 9.990747451782227
Epoch 4, Loss: 9.97261905670166
Epoch 5, Loss: 9.954554557800293
Epoch 6, Loss: 9.936559677124023
Epoch 7, Loss: 9.918630599975586
Epoch 8, Loss: 9.900768280029297
Epoch 9, Loss: 9.88297176361084
Epoch 10, Loss: 9.865241050720215
Epoch 11, Loss: 9.847578048706055
Epoch 12, Loss: 9.829981803894043
Epoch 13, Loss: 9.812451362609863
Epoch 14, Loss: 9.79498291015625
Epoch 15, Loss: 9.777586936950684
Epoch 16, Loss: 9.760254859924316
Epoch 17, Loss: 9.742988586425781
Epoch 18, Loss: 9.725787162780762
Epoch 19, Loss: 9.708650588989258
Epoch 20, Loss: 9.691583633422852
Epoch 21, Loss: 9.674577713012695
Epoch 22, Loss: 9.657639503479004
Epoch 23, Loss: 9.640765190124512
Epoch 24, Loss: 9.623956680297852
Epoch 25, Loss: 9.607213973999023
Epoch 26, Loss: 9.590533256530762
Epoch 27, Loss: 9.573920249938965
Epoch 28, Loss: 9.557369232177734
Epoch 29, Loss: 9.5408849

In [17]:
H = {name: param for name, param in model.named_parameters()}
for name, param in model.named_parameters():
    H[name] = param.data
for h in H:
    print(H[h])

tensor([[-0.3109,  0.3562, -0.1505,  0.1037, -0.1914],
        [ 0.2683, -0.3811,  0.1777,  0.1989, -0.3673],
        [-0.4057,  0.1451,  0.1896, -0.3522,  0.2909],
        [-0.3734, -0.4211, -0.0590,  0.0278, -0.4328],
        [ 0.2413,  0.0966,  0.2457,  0.0462, -0.0449]])
tensor([ 0.2005, -0.1995, -0.4651, -0.1121, -0.3915])
tensor([[ 0.2729,  0.0279, -0.0104,  0.2520,  0.2719],
        [ 0.4346, -0.4275,  0.3036,  0.3758, -0.3768],
        [ 0.3558,  0.3291,  0.0900, -0.1054, -0.0410],
        [ 0.1467, -0.3146, -0.0909, -0.2807, -0.4620],
        [-0.2189, -0.0509,  0.3681,  0.4110,  0.3457]])
tensor([-0.3976, -0.0818, -0.2187, -0.1184, -0.0827])
