In [1]:
import torch
from torch import nn
from torch.nn import Linear
from torch.optim import SGD

In [9]:
class ComplexLinearLayer(nn.Module):
    """Complex-valued affine layer stored as two real ``Linear`` modules.

    ``W_r`` holds the real part of the weight matrix and of the bias, ``W_i``
    holds the imaginary part.  For an input ``x = in_r + j*in_i`` the layer
    computes ``W x + b`` with ``W = W_r.weight + j*W_i.weight`` and
    ``b = W_r.bias + j*W_i.bias``:

        real = in_r @ W_r.weight.T - in_i @ W_i.weight.T + W_r.bias
        imag = in_i @ W_r.weight.T + in_r @ W_i.weight.T + W_i.bias

    Args:
        in_features: size of the last dimension of ``in_r`` / ``in_i``.
        out_features: size of the last dimension of both outputs.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.W_r = Linear(in_features, out_features, bias=True)
        self.W_i = Linear(in_features, out_features, bias=True)

    def forward(self, in_r, in_i):
        """Return ``(real, imag)`` of the complex affine map applied to the input.

        Args:
            in_r: real part of the input.
            in_i: imaginary part of the input (same shape as ``in_r``).

        Returns:
            Tuple ``(Real_OUT, Imag_OUT)`` of tensors.
        """
        linear = nn.functional.linear
        weight_r, weight_i = self.W_r.weight, self.W_i.weight

        # Each bias is added exactly once, to its own part.  Calling
        # self.W_r(...) / self.W_i(...) for all four products would add both
        # biases to both parts (b_r - b_i on the real side, b_r + b_i on the
        # imaginary side), so W_r.bias / W_i.bias would no longer be the real
        # and imaginary bias.
        Real_OUT = linear(in_r, weight_r, self.W_r.bias) - linear(in_i, weight_i)
        Imag_OUT = linear(in_i, weight_r, self.W_i.bias) + linear(in_r, weight_i)
        return Real_OUT, Imag_OUT

In [10]:
def complex_sign_activation(in_r, in_i):
    """Apply ``torch.sign`` separately to the real and imaginary parts.

    Args:
        in_r: real part of the pre-activation.
        in_i: imaginary part of the pre-activation.

    Returns:
        Tuple ``(sign(in_r), sign(in_i))``.
    """
    sign_r = torch.sign(in_r)
    sign_i = torch.sign(in_i)
    return sign_r, sign_i

In [12]:
class ComplexNetworkPaper(nn.Module):
    """One ``ComplexLinearLayer`` followed by ``complex_sign_activation``.

    Args:
        in_features: number of input features passed to the linear layer.
        hidden_features: number of output features of the linear layer.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.input_linear_Layer = ComplexLinearLayer(in_features, hidden_features)

    def forward(self, in_r, in_i):
        """Run the linear layer, then the sign activation, on both parts.

        Args:
            in_r: real part of the input.
            in_i: imaginary part of the input.

        Returns:
            Tuple ``(output_r, output_i)`` produced by the activation.
        """
        pre_activation = self.input_linear_Layer(in_r, in_i)
        output_r, output_i = complex_sign_activation(*pre_activation)
        return output_r, output_i

In [13]:
def custom_loss_with_tanh(input_r, input_i, target_r, target_i, H_r,H_i, bias_r,bias_i, lam=0.01):
    '''
    Squared-error loss between the target and tanh(H x + z), plus a norm
    penalty on the weights.

    H x + z is evaluated as a complex affine map with H = H_r + j*H_i,
    x = input_r + j*input_i and z = bias_r + j*bias_i:

        real part: input_r @ H_r.T - input_i @ H_i.T + bias_r
        imag part: input_i @ H_r.T + input_r @ H_i.T + bias_i

    tanh is applied separately to the real and imaginary parts.

    input_r: real part of the input
    input_i: imaginary part of the input
    target_r: real part of the target
    target_i: imaginary part of the target
    H_r: weight matrix (the channel matrix), real part
    H_i: weight matrix (the channel matrix), imaginary part
    bias_r: z for the real part
    bias_i: z for the imaginary part
    lam: weight of the regularisation term

    Inputs and targets are reduced over dim=1 and then averaged, so they are
    expected to be 2-D with samples along dim 0 (assumption: (batch, features)).

    Returns a scalar tensor: mean squared error + lam * (||H_r|| + ||H_i||).
    '''
    H_r_t = H_r.t()
    H_i_t = H_i.t()

    # Complex product (H_r + jH_i)(x_r + jx_i): the cross terms are required,
    # otherwise the real and imaginary channels are treated as two unrelated
    # real-valued problems.
    temp_r = torch.matmul(input_r, H_r_t) - torch.matmul(input_i, H_i_t) + bias_r
    temp_i = torch.matmul(input_i, H_r_t) + torch.matmul(input_r, H_i_t) + bias_i

    # tanh(Hx_n + z), applied per part
    tanh_r = torch.tanh(temp_r)
    tanh_i = torch.tanh(temp_i)

    # first_term = 1/N * sum_n ( ||target_r - tanh_r||^2 + ||target_i - tanh_i||^2 )
    # Summing squares directly avoids taking a square root only to square it again.
    sq_err_r = (target_r - tanh_r).pow(2).sum(dim=1)
    sq_err_i = (target_i - tanh_i).pow(2).sum(dim=1)
    first_term = torch.mean(sq_err_r + sq_err_i)

    # Penalty on the (unsquared) 2-norm of each weight matrix.
    regul_term = lam * (torch.norm(H_r, p=2) + torch.norm(H_i, p=2))
    loss = first_term + regul_term
    return loss


In [19]:
# Seed the global RNG so the random inputs/targets and the initial model weights
# are the same every time this cell is run (Restart & Run All reproducibility).
torch.manual_seed(0)

# 10 is the batch size: the forward function handles these 10 samples at once,
# so the same weights are used for all 10 samples. 5 is the number of features.
in_r = torch.randn(10, 5)
in_i = torch.randn(10, 5)

# NOTE(review): the loss compares these targets against tanh outputs, which lie in
# (-1, 1), while randn targets are unbounded -- presumably placeholder data; confirm.
target_r = torch.randn(10, 5)
target_i = torch.randn(10, 5)

model = ComplexNetworkPaper(in_features=in_r.shape[1], hidden_features=target_r.shape[1])
optimizer = SGD(model.parameters(), lr=0.001)

In [None]:
epochs = 100
for epoch in range(epochs):
    model.train()

    # Clear the gradients accumulated during the previous iteration.
    optimizer.zero_grad()

    # Forward pass through the sign-activated network. The loss below is built
    # directly from the layer's weights and biases and does not consume these outputs.
    output_r, output_i = model(in_r, in_i)

    # tanh-based loss evaluated on the layer parameters
    loss = custom_loss_with_tanh(
        input_r=in_r,
        input_i=in_i,
        target_r=target_r,
        target_i=target_i,
        H_r=model.input_linear_Layer.W_r.weight,
        H_i=model.input_linear_Layer.W_i.weight,
        bias_r=model.input_linear_Layer.W_r.bias,
        bias_i=model.input_linear_Layer.W_i.bias,
    )

    # Backpropagate to compute fresh gradients, then take one SGD step along them.
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch}, Loss: {loss.item()}')