In [1]:
import math
import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F

In [2]:
class CustomSignFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input for backward computation
        ctx.save_for_backward(input)
        # Output +1 for input > 0, else -1 (including for input == 0)
        return torch.where(input > 0, torch.tensor(1.0, device=input.device), torch.tensor(-1.0, device=input.device))

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the input saved in the forward pass
        input, = ctx.saved_tensors
        # Gradient of the input is the same as the gradient output (STE)
        grad_input = grad_output.clone()
        # Pass the gradient only where input was non-zero, otherwise set it to 0
        grad_input[input.abs() > 0] = grad_output[input.abs() > 0]
        return grad_input

# Wrapper class for convenience
class CustomSignActivation(torch.nn.Module):
    def __init__(self):
        super(CustomSignActivation, self).__init__()

    def forward(self, input):
        return CustomSignFunction.apply(input)

# Example usage:
sign_activation = CustomSignActivation()

# Test the forward pass
x = torch.tensor([2.0, -3.0, 0.0, 1.5], requires_grad=True)
output = sign_activation(x)
print("Output during inference:", output)

# Test the backward pass (gradient computation during training)
loss = output.sum()  # Just an example loss
loss.backward()
print("Gradient during training:", x.grad)

Output during inference: tensor([ 1., -1., -1.,  1.], grad_fn=<CustomSignFunctionBackward>)
Gradient during training: tensor([1., 1., 1., 1.])


In [12]:
class BiKAConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super(BiKAConv2D, self).__init__()
        # Define weights for convolution
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        # Define an individual bias for each weight in the kernel
        self.bias = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

    def forward(self, x):
        print(x)
        print(x.shape)
        print(self.bias)
        print(self.bias.shape)
        # Add the bias to each activation before multiplying by the weight
        # Equivalent to computing w * (a + b) for each kernel position
        modified_input = F.unfold(x, kernel_size=self.weight.shape[2:], stride=self.stride, padding=self.padding)
        print(modified_input)
        print(modified_input.shape)
        modified_input = modified_input.view(
            x.shape[0], x.shape[1], self.weight.shape[2], self.weight.shape[3], -1
        )
        print(modified_input)
        print(modified_input.shape)
        modified_input = modified_input + self.bias.unsqueeze(-1)
        print(modified_input)
        print(modified_input.shape)
        modified_input = modified_input.view(x.shape[0], -1, modified_input.shape[-1])
        print(modified_input)
        print(modified_input.shape)
        
        print(self.weight)
        print(self.weight.shape)
        # Perform the convolution with the modified input
        weight_reshaped = self.weight.view(self.weight.shape[0], -1)
        print(weight_reshaped)
        print(weight_reshaped.shape)
        print(weight_reshaped.shape[0])
        print(weight_reshaped.shape[1])
        print(weight_reshaped.shape[1])
        print(weight_reshaped.reshape(1, weight_reshaped.shape[0]*weight_reshaped.shape[1]).unsqueeze(-1).repeat(1, 1, 2))
        print(weight_reshaped.reshape(1, weight_reshaped.shape[0]*weight_reshaped.shape[1]).unsqueeze(-1).repeat(1, 1, 2).shape)
        print(weight_reshaped.unsqueeze(0).expand(x.shape[0], -1, -1))
        print(weight_reshaped.unsqueeze(0).expand(x.shape[0], -1, -1).shape)
        output = torch.bmm(weight_reshaped.unsqueeze(0).expand(x.shape[0], -1, -1), modified_input)
        output = output.view(x.shape[0], self.weight.shape[0], int(x.shape[2]/self.stride), int(x.shape[3]/self.stride))

        return output

# Example usage
bika_conv = BiKAConv2D(in_channels=1, out_channels=3, kernel_size=3, stride=1, padding=1)
input_tensor  = torch.arange(0, 16, dtype=torch.float32).reshape(1,1,4,4)
output_tensor = bika_conv(input_tensor)

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]])
torch.Size([1, 1, 4, 4])
Parameter containing:
tensor([[[[-1.5539, -0.3945, -1.5523],
          [ 0.7011,  1.9568, -0.7478],
          [-1.7810, -1.9629,  1.2301]]],


        [[[-1.3081,  0.2781, -0.0061],
          [-1.0853,  0.4958, -1.7125],
          [ 1.5011, -0.1729, -0.1830]]],


        [[[-1.0890, -0.7653,  0.0072],
          [ 0.5572, -1.3871, -0.9199],
          [ 0.0359, -1.1856, -2.6031]]]], requires_grad=True)
torch.Size([3, 1, 3, 3])
tensor([[[ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  0.,  4.,  5.,  6.,  0.,  8.,
           9., 10.],
         [ 0.,  0.,  0.,  0.,  0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,
          10., 11.],
         [ 0.,  0.,  0.,  0.,  1.,  2.,  3.,  0.,  5.,  6.,  7.,  0.,  9., 10.,
          11.,  0.],
         [ 0.,  0.,  1.,  2.,  0.,  4.,  5.,  6.,  0.,  8.,  9., 10.,  0., 12.,
          13., 14.],
         [ 

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [1, 9] but got: [1, 27].

In [7]:
print(input_tensor)
print(input_tensor.shape)

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]])
torch.Size([1, 1, 4, 4])


In [None]:
print(output_tensor)
print(output_tensor.shape)