In [1]:
import math
import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F

In [2]:
class CustomSignFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input for backward computation
        ctx.save_for_backward(input)
        # Output +1 for input > 0, else -1 (including for input == 0)
        return torch.where(input > 0, torch.tensor(1.0, device=input.device), torch.tensor(-1.0, device=input.device))

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the input saved in the forward pass
        input, = ctx.saved_tensors
        # Gradient of the input is the same as the gradient output (STE)
        grad_input = grad_output.clone()
        # Pass the gradient only where input was non-zero, otherwise set it to 0
        grad_input[input.abs() > 0] = grad_output[input.abs() > 0]
        return grad_input

# Wrapper class for convenience
class CustomSignActivation(torch.nn.Module):
    def __init__(self):
        super(CustomSignActivation, self).__init__()

    def forward(self, input):
        return CustomSignFunction.apply(input)

# Example usage:
sign_activation = CustomSignActivation()

# Test the forward pass
x = torch.tensor([2.0, -3.0, 0.0, 1.5], requires_grad=True)
output = sign_activation(x)
print("Output during inference:", output)

# Test the backward pass (gradient computation during training)
loss = output.sum()  # Just an example loss
loss.backward()
print("Gradient during training:", x.grad)

Output during inference: tensor([ 1., -1., -1.,  1.], grad_fn=<CustomSignFunctionBackward>)
Gradient during training: tensor([1., 1., 1., 1.])


In [3]:
class BiKALinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(BiKALinear, self).__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.bias = nn.Parameter(torch.Tensor(out_features, in_features))
        self.sign = CustomSignActivation()
            
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # Expand the input to match the bias shape for broadcasting
        # x is of shape (batch_size, in_features)
        # Expand bias matrix to (batch_size, out_features, in_features)
        x = x.unsqueeze(1) + self.bias.unsqueeze(0)
        
        # Perform element-wise multiplication with weights
        x = x * self.weight.unsqueeze(0)
        
        # Apply sign function: -1 for negative and 0, 1 for positive
        x = self.sign(x)
        
        # Sum the thresholded products along the input features dimension
        x = torch.sum(x, dim=-1) 

        return x

# Example usage
bika_linear = BiKALinear(in_features=2, out_features=3)
input_tensor  = torch.randn(3, 2)  # Batch of 3, 10 input features each
output_tensor = bika_linear(input_tensor)

In [27]:
class BiKAConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super(BiKAConv2D, self).__init__()
        # Define weights for convolution
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        # Define an individual bias for each weight in the kernel
        self.bias = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

    def forward(self, x):        
        print("\nInput: ")
        print(x)
        print(x.shape)
        print("\nBias: ")
        print(self.bias)
        print(self.bias.shape)
        
        # Add the bias to each activation before multiplying by the weight
        # Equivalent to computing w * (a + b) for each kernel position
        modified_input = F.unfold(x, kernel_size=self.weight.shape[2:], stride=self.stride, padding=self.padding)
        # batch_size, in_channels, height, width = x.shape
        # kernel_height, kernel_width = self.weight.shape[2:]
        # modified_input = F.unfold(x, kernel_size=(kernel_height, kernel_width), stride=self.stride, padding=self.padding)
        # unfolded_input = F.unfold(x, kernel_size=(kernel_height, kernel_width), stride=self.stride, padding=self.padding)
        print("\nUnfold Input: ")
        print(modified_input)
        print(modified_input.shape)
        
        modified_input = modified_input.view(
             x.shape[0], x.shape[1], self.weight.shape[2], self.weight.shape[3], -1
        )
        # unfolded_input = unfolded_input.view(batch_size, in_channels, kernel_height, kernel_width, -1)
        # modified_input = modified_input.view(batch_size, in_channels, kernel_height, kernel_width, -1)
        print("\nReshaped Unfold Input: ")
        print(modified_input)
        print(modified_input.shape)
        
        print("\nReshaped Bias: ")
        print(self.bias.unsqueeze(-1))
        print(self.bias.unsqueeze(-1).shape)
        modified_input = modified_input + self.bias.unsqueeze(-1)
        # modified_input = unfolded_input + self.bias.unsqueeze(0).unsqueeze(-1)
        # modified_input = modified_input + self.bias.unsqueeze(0).unsqueeze(-1)
        print("\nBiased Input: ")
        print(modified_input)
        print(modified_input.shape)
        
        modified_input = modified_input.view(x.shape[0], -1, modified_input.shape[-1])
        print(modified_input)
        print(modified_input.shape)
        
        print(self.weight)
        print(self.weight.shape)
        # Perform the convolution with the modified input
        weight_reshaped = self.weight.view(self.weight.shape[0], -1)
        print(weight_reshaped)
        print(weight_reshaped.shape)
        print(weight_reshaped.shape[0])
        print(weight_reshaped.shape[1])
        print(weight_reshaped.shape[1])
        print(weight_reshaped.reshape(1, weight_reshaped.shape[0]*weight_reshaped.shape[1]).unsqueeze(-1).repeat(1, 1, 2))
        print(weight_reshaped.reshape(1, weight_reshaped.shape[0]*weight_reshaped.shape[1]).unsqueeze(-1).repeat(1, 1, 2).shape)
        print(weight_reshaped.unsqueeze(0).expand(x.shape[0], -1, -1))
        print(weight_reshaped.unsqueeze(0).expand(x.shape[0], -1, -1).shape)
        output = torch.bmm(weight_reshaped.unsqueeze(0).expand(x.shape[0], -1, -1), modified_input)
        output = output.view(x.shape[0], self.weight.shape[0], int(x.shape[2]/self.stride), int(x.shape[3]/self.stride))

        return output

# Example usage
bika_conv = BiKAConv2D(in_channels=2, out_channels=3, kernel_size=3, stride=1, padding=1)
input_tensor  = torch.arange(0, 36, dtype=torch.float32).reshape(2,2,3,3)
output_tensor = bika_conv(input_tensor)


Input: 
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]],


        [[[18., 19., 20.],
          [21., 22., 23.],
          [24., 25., 26.]],

         [[27., 28., 29.],
          [30., 31., 32.],
          [33., 34., 35.]]]])
torch.Size([2, 2, 3, 3])

Bias: 
Parameter containing:
tensor([[[[-0.6859, -0.2026, -0.5440],
          [-0.5022,  0.3358, -0.0329],
          [-0.0755,  0.4285, -1.0047]],

         [[-0.9361, -0.1238, -0.3391],
          [ 0.0918,  0.8681,  0.7448],
          [-0.7153, -0.4400,  0.2580]]],


        [[[ 1.4395,  0.7533, -0.6200],
          [-0.4979,  0.4181, -0.1048],
          [-1.3353, -1.1510, -0.8696]],

         [[-0.7482, -1.6774,  0.0405],
          [-1.1515, -1.1363,  0.2383],
          [-1.1684,  0.4206, -1.0500]]],


        [[[-0.0218,  1.5287,  0.4934],
          [ 0.2579,  0.4149,  0.4943],
          [-0.5038,  0.5697,  1.8474]],

  

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

In [7]:
print(input_tensor)
print(input_tensor.shape)

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]])
torch.Size([1, 1, 4, 4])


In [None]:
print(output_tensor)
print(output_tensor.shape)