# Define BiKA Linear and Conv2d Layers

In [1]:
import math
import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F

In [2]:
class CustomSignFunction(torch.autograd.Function):
    """Binary sign activation with a straight-through gradient estimator (STE).

    Forward maps each element to +1 where it is strictly positive and to -1
    everywhere else (an input of exactly 0 maps to -1). The result has the
    same dtype and device as the input.

    Backward hands the incoming gradient back unchanged for every element,
    including positions where the input was zero.
    """

    @staticmethod
    def forward(ctx, input):
        # Build the two constants in the input's own dtype/device so the
        # output follows the input instead of always being float32.
        plus_one = torch.tensor(1.0, dtype=input.dtype, device=input.device)
        minus_one = torch.tensor(-1.0, dtype=input.dtype, device=input.device)
        # Nothing is saved on ctx: the backward pass does not look at the
        # input, and holding on to it would only keep a large tensor alive.
        return torch.where(input > 0, plus_one, minus_one)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: the gradient w.r.t. the input is the
        # gradient w.r.t. the output, with no masking.
        return grad_output

# nn.Module adapter so the sign function can be used like any other layer
class CustomSignActivation(torch.nn.Module):
    """Module wrapper that applies CustomSignFunction to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        binarized = CustomSignFunction.apply(input)
        return binarized

In [3]:
class BiKALinear(nn.Module):
    """BiKA fully connected layer.

    Every (output, input) pair owns its own weight and its own bias. For an
    input vector ``x`` the value of output unit ``o`` is::

        y[o] = sum_i sign((x[i] + bias[o, i]) * weight[o, i])

    where ``sign`` is ``CustomSignActivation``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.

    Shape:
        input ``(..., in_features)`` -> output ``(..., out_features)``.
        The intermediate tensor has shape ``(..., out_features, in_features)``.
    """

    def __init__(self, in_features, out_features):
        super(BiKALinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        # torch.empty makes the "uninitialised until reset_parameters" intent
        # explicit; both tensors are filled right below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features, in_features))
        self.sign = CustomSignActivation()

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight (Kaiming uniform) and bias (uniform in +/- 1/sqrt(fan_in))."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # For the 2-D (out_features, in_features) weight, fan-in is in_features.
        fan_in = self.in_features
        # Guard against a zero-width layer instead of dividing by zero.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # (..., in) -> (..., 1, in); broadcasting against the (out, in) bias
        # gives (..., out, in). Inserting the axis at -2 rather than at 1
        # keeps this correct for any number of leading batch dimensions.
        shifted = x.unsqueeze(-2) + self.bias

        # Element-wise scaling by the per-pair weight
        scaled = shifted * self.weight

        # Threshold to +1 / -1
        binarized = self.sign(scaled)

        # Sum the thresholded products along the input-feature axis
        return torch.sum(binarized, dim=-1)

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}"

In [4]:
class BiKAConv2D(nn.Module):
    """BiKA 2-D convolution with a square kernel.

    Each kernel tap owns a weight and a bias. For every sliding-window
    position the layer computes, per output channel ``o``::

        y[o] = sum_{c, i, j} sign((patch[c, i, j] + bias[o, c, i, j]) * weight[o, c, i, j])

    where ``sign`` is ``CustomSignActivation``. Padding is applied by
    ``F.unfold`` before the bias is added, so padded positions contribute
    ``sign(bias * weight)`` rather than nothing.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: edge length of the square kernel (int).
        stride: window stride (int).
        padding: zero padding on each side (int).
        dilation: spacing between kernel taps (int).

    Shape:
        input ``(batch, in_channels, H, W)`` ->
        output ``(batch, out_channels, H_out, W_out)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super(BiKAConv2D, self).__init__()
        # One weight per kernel tap
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        # One bias per kernel tap (same shape as the weight)
        self.bias = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.sign = CustomSignActivation()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

    def _output_size(self, size):
        """Return the number of window positions along one spatial axis."""
        window_span = self.dilation * (self.kernel_size - 1) + 1
        return (size + 2 * self.padding - window_span) // self.stride + 1

    def forward(self, x):
        batch_size, in_channels, height, width = x.shape
        out_height = self._output_size(height)
        out_width = self._output_size(width)

        # Extract sliding windows: (batch, C*k*k, L) with L = out_height * out_width.
        # dilation must be forwarded here so the window count agrees with the
        # output size computed above.
        patches = F.unfold(
            x,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )

        # Flatten each filter to match unfold's (channel, row, col) ordering.
        taps = in_channels * self.kernel_size * self.kernel_size
        bias = self.bias.reshape(self.out_channels, taps, 1)
        weight = self.weight.reshape(self.out_channels, taps, 1)

        # (batch, 1, taps, L) against (out, taps, 1) -> (batch, out, taps, L):
        # computes w * (a + b) for every tap of every output channel.
        output = (patches.unsqueeze(1) + bias) * weight

        # Threshold to +1 / -1
        output = self.sign(output)

        # Sum the thresholded products over the taps, then restore the 2-D layout
        output = output.sum(dim=2)
        return output.reshape(batch_size, self.out_channels, out_height, out_width)

# Try Tiny CNN-like BiKA with MNIST

## 1. Dataset Loading

In [5]:
import torchvision
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader, random_split

In [6]:
# MNIST training images: converted to tensors, then normalized with mean 0.5 / std 0.5.
full_data_train = datasets.MNIST(
    './data/',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)

# Split the training images 80/20 into training and validation subsets.
# NOTE(review): no seed has been set at this point in the notebook, so the
# split presumably differs between runs — confirm whether that is intended.
train_size = int(len(full_data_train) * 0.8)
val_size = len(full_data_train) - train_size
data_train, data_valid = random_split(full_data_train, [train_size, val_size])

# MNIST test images with the same preprocessing as the training images.
data_test = datasets.MNIST(
    './data/',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)

## 2. Define CNN-like BiKA structure

In [7]:
import math
import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score
from itertools import product
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm, trange

In [8]:
import brevitas.nn as qnn
from brevitas.nn import QuantLinear, QuantReLU, QuantConv2d
from brevitas.quant.binary import SignedBinaryActPerTensorConst
from brevitas.quant.binary import SignedBinaryWeightPerTensorConst
from brevitas.inject.enum import QuantType

In [9]:
# Kernel edge length shared by both convolution layers
kernel_size = 3

# First convolution: single-channel input
in_channels0 = 1
out_channels0 = 64

# Second convolution consumes the first one's output channels
in_channels1 = out_channels0
out_channels1 = 64

# Flattened feature count fed to the first linear layer (a 7x7 map per channel)
input_size = out_channels1 * 7 * 7
hidden0 = 64
num_classes = 10

In [10]:
class BiKA_MNIST(Module):
    """Small CNN-like BiKA network: two conv + max-pool stages, then two linear layers.

    Layer sizes come from the module-level constants (kernel_size,
    in/out_channels0/1, input_size, hidden0, num_classes).
    """

    def __init__(self):
        super().__init__()

        # Both convolutions share the same kernel geometry.
        conv_geometry = dict(kernel_size=kernel_size, stride=1, padding=1)

        self.conv0 = BiKAConv2D(in_channels=in_channels0, out_channels=out_channels0, **conv_geometry)
        self.pool0 = nn.MaxPool2d(2)

        self.conv1 = BiKAConv2D(in_channels=in_channels1, out_channels=out_channels1, **conv_geometry)
        self.pool1 = nn.MaxPool2d(2)

        self.fc0 = BiKALinear(in_features=input_size, out_features=hidden0)

        self.out = BiKALinear(in_features=hidden0, out_features=num_classes)

    def forward(self, x):
        # Stage 0: convolution followed by 2x2 max pooling
        features = self.conv0(x)
        features = self.pool0(features)

        # Stage 1: convolution followed by 2x2 max pooling
        features = self.conv1(features)
        features = self.pool1(features)

        # Flatten everything except the batch axis, then classify
        flat = features.reshape(features.shape[0], -1)
        hidden = self.fc0(flat)
        return self.out(hidden)

## 3. Define Training Function

In [11]:
num_of_gpus = torch.cuda.device_count()
print(num_of_gpus)

# Check for GPU
# Prefer the second GPU (index 1) when the machine has one; on a single-GPU
# machine "cuda:1" is not a valid device, so fall back to "cuda:0". Without
# CUDA, run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if num_of_gpus > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Setting seeds for reproducibility
torch.manual_seed(0)

2
Using device: cuda:1


<torch._C.Generator at 0x7f42350c5970>

In [12]:
def display_loss_plot(losses, title="Training loss", xlabel="Iterations", ylabel="Loss"):
    """Plot ``losses`` against their position in the sequence and show the figure.

    Args:
        losses: sequence of loss values; the x coordinate of each is its index.
        title: figure title.
        xlabel: label of the x axis.
        ylabel: label of the y axis.
    """
    positions = list(range(len(losses)))
    plt.plot(positions, losses)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()

In [13]:
def train_and_validate(model, train_loader, val_loader, criterion, learning_rate,
                       num_epochs=100, lr_milestones=(50, 75)):
    """Train ``model`` with Adam and report validation accuracy after every epoch.

    Args:
        model: network to train; it is moved between train and eval mode here.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        learning_rate: sequence of learning rates. ``learning_rate[0]`` is the
            initial rate; ``learning_rate[1:]`` are paired in order with
            ``lr_milestones``.
        num_epochs: number of training epochs.
        lr_milestones: 1-based epoch numbers. Once epoch ``lr_milestones[i]``
            has finished training, the rate becomes ``learning_rate[i + 1]``
            for the following epochs. Milestones without a matching rate are
            ignored.

    Returns:
        Validation accuracy (fraction in [0, 1]) measured after the last
        epoch, or 0.0 when ``num_epochs`` is 0.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate[0])

    # epoch number (1-based) -> learning rate that takes effect after it
    lr_schedule = dict(zip(lr_milestones, learning_rate[1:]))

    val_acc = 0.0

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: the validation phase below
        # switches the model to eval mode.
        model.train()

        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Adjust the learning rate if this epoch is a milestone
        new_lr = lr_schedule.get(epoch + 1)
        if new_lr is not None:
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
                print(f"Learning rate changed to {param_group['lr']} at epoch {epoch+1}")

        # Validation phase
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_acc = accuracy_score(all_labels, all_preds)
        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {running_loss/len(train_loader):.4f}, "
              f"Val Accuracy: {val_acc*100:.2f}%")

    return val_acc

## 4. Define Evaluation Function

In [14]:
def evaluate_model(model, test_loader):
    """Measure and print the classification accuracy of ``model`` on ``test_loader``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: iterable of ``(images, labels)`` batches.

    Returns:
        The accuracy computed by ``accuracy_score`` over all batches.
    """
    model.eval()
    predicted = []
    expected = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            scores = model(images)
            # Index of the highest score per sample is the predicted class
            preds = torch.max(scores, 1)[1]
            predicted.extend(preds.cpu().numpy())
            expected.extend(labels.cpu().numpy())
    test_acc = accuracy_score(expected, predicted)
    print(f"Test Accuracy: {test_acc * 100:.2f}%")
    return test_acc

## 5. Train CNN-like BiKA for MNIST

In [15]:
# Hyperparameter grid for the search loop.
batch_sizes = [128]
# Each entry is one learning-rate schedule: [initial, second stage, third stage].
learning_rates = [[0.0010, 0.0010, 0.0001]]

In [16]:
best_acc = 0.0
best_params = None
# Copy of the weights of the best-scoring model, so that `model` can be
# restored to it once the search is over.
best_state = None

for batch_size, learning_rate in product(batch_sizes, learning_rates):
    print(f"Training with batch_size={batch_size}, learning_rate={learning_rate}")

    # Data loaders
    train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(data_valid, batch_size=batch_size, shuffle=False)

    # Initialize the model and loss (the optimizer is created inside train_and_validate)
    model = BiKA_MNIST().to(device)
    criterion = nn.CrossEntropyLoss()

    # Train and validate
    val_acc = train_and_validate(model, train_loader, val_loader, criterion, learning_rate)

    # Update best parameters
    if val_acc > best_acc:
        best_acc = val_acc
        best_params = (batch_size, learning_rate)
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

# `model` currently holds the LAST configuration trained. Every configuration
# uses the same architecture, so loading the best weights into it makes any
# later evaluation of `model` match the reported best parameters.
if best_state is not None:
    model.load_state_dict(best_state)

print(f"Best Accuracy: {best_acc*100:.2f}%")
if best_params is None:
    # No configuration scored above the initial 0.0, so there is nothing to index.
    print("Best Parameters: none (no configuration improved on the initial accuracy)")
else:
    print(f"Best Parameters: Batch Size={best_params[0]}, Learning Rate={best_params[1]}")

Training with batch_size=128, learning_rate=[0.001, 0.001, 0.0001]


Epoch [1/100], Train Loss: 2.0419, Val Accuracy: 86.78%


Epoch [2/100], Train Loss: 1.1253, Val Accuracy: 88.52%


Epoch [3/100], Train Loss: 0.9457, Val Accuracy: 89.14%


Epoch [4/100], Train Loss: 0.8015, Val Accuracy: 91.12%


Epoch [5/100], Train Loss: 0.7112, Val Accuracy: 89.30%


Epoch [6/100], Train Loss: 0.8149, Val Accuracy: 90.57%


Epoch [7/100], Train Loss: 0.8715, Val Accuracy: 90.34%


Epoch [8/100], Train Loss: 0.8274, Val Accuracy: 86.78%


Epoch [9/100], Train Loss: 0.8286, Val Accuracy: 89.07%


Epoch [10/100], Train Loss: 0.7794, Val Accuracy: 88.70%


Epoch [11/100], Train Loss: 0.7475, Val Accuracy: 90.09%


Epoch [12/100], Train Loss: 0.7216, Val Accuracy: 88.97%


Epoch [13/100], Train Loss: 0.7687, Val Accuracy: 89.22%


Epoch [14/100], Train Loss: 0.7556, Val Accuracy: 89.49%


Epoch [15/100], Train Loss: 0.7476, Val Accuracy: 89.80%


Epoch [16/100], Train Loss: 0.7161, Val Accuracy: 90.89%


Epoch [17/100], Train Loss: 0.7418, Val Accuracy: 90.72%


Epoch [18/100], Train Loss: 0.7154, Val Accuracy: 90.25%


Epoch [19/100], Train Loss: 0.7189, Val Accuracy: 87.51%


Epoch [20/100], Train Loss: 0.7027, Val Accuracy: 88.68%


Epoch [21/100], Train Loss: 0.7414, Val Accuracy: 89.88%


Epoch [22/100], Train Loss: 0.7263, Val Accuracy: 89.81%


Epoch [23/100], Train Loss: 0.7597, Val Accuracy: 90.30%


Epoch [24/100], Train Loss: 0.7605, Val Accuracy: 90.59%


Epoch [25/100], Train Loss: 0.7377, Val Accuracy: 90.14%


Epoch [26/100], Train Loss: 0.7052, Val Accuracy: 90.54%


Epoch [27/100], Train Loss: 0.7326, Val Accuracy: 89.43%


Epoch [28/100], Train Loss: 0.7543, Val Accuracy: 88.98%


Epoch [29/100], Train Loss: 0.7214, Val Accuracy: 88.01%


Epoch [30/100], Train Loss: 0.7756, Val Accuracy: 89.28%


Epoch [31/100], Train Loss: 0.7295, Val Accuracy: 88.66%


Epoch [32/100], Train Loss: 0.7381, Val Accuracy: 87.62%


Epoch [33/100], Train Loss: 0.7623, Val Accuracy: 88.69%


Epoch [34/100], Train Loss: 0.7395, Val Accuracy: 86.24%


Epoch [35/100], Train Loss: 0.7603, Val Accuracy: 88.26%


Epoch [36/100], Train Loss: 0.7342, Val Accuracy: 89.51%


Epoch [37/100], Train Loss: 0.7516, Val Accuracy: 87.68%


Epoch [38/100], Train Loss: 0.7670, Val Accuracy: 88.37%


Epoch [39/100], Train Loss: 0.7628, Val Accuracy: 89.26%


Epoch [40/100], Train Loss: 0.7931, Val Accuracy: 88.22%


Epoch [41/100], Train Loss: 0.7619, Val Accuracy: 87.99%


Epoch [42/100], Train Loss: 0.7570, Val Accuracy: 89.75%


Epoch [43/100], Train Loss: 0.7393, Val Accuracy: 88.74%


Epoch [44/100], Train Loss: 0.7469, Val Accuracy: 88.84%


Epoch [45/100], Train Loss: 0.7432, Val Accuracy: 89.62%


Epoch [46/100], Train Loss: 0.7538, Val Accuracy: 88.52%


Epoch [47/100], Train Loss: 0.7621, Val Accuracy: 89.16%


Epoch [48/100], Train Loss: 0.7557, Val Accuracy: 89.38%


Epoch [49/100], Train Loss: 0.7740, Val Accuracy: 89.22%


Learning rate changed to 0.001 at epoch 50


Epoch [50/100], Train Loss: 0.7530, Val Accuracy: 89.16%


Epoch [51/100], Train Loss: 0.7698, Val Accuracy: 88.82%


Epoch [52/100], Train Loss: 0.7799, Val Accuracy: 88.74%


Epoch [53/100], Train Loss: 0.7697, Val Accuracy: 88.73%


Epoch [54/100], Train Loss: 0.7608, Val Accuracy: 88.00%


Epoch [55/100], Train Loss: 0.7600, Val Accuracy: 88.98%


Epoch [56/100], Train Loss: 0.7569, Val Accuracy: 88.82%


Epoch [57/100], Train Loss: 0.7598, Val Accuracy: 89.12%


Epoch [58/100], Train Loss: 0.7427, Val Accuracy: 89.22%


Epoch [59/100], Train Loss: 0.7294, Val Accuracy: 87.83%


Epoch [60/100], Train Loss: 0.7220, Val Accuracy: 89.71%


Epoch [61/100], Train Loss: 0.7557, Val Accuracy: 89.10%


Epoch [62/100], Train Loss: 0.7486, Val Accuracy: 88.00%


Epoch [63/100], Train Loss: 0.7356, Val Accuracy: 89.56%


Epoch [64/100], Train Loss: 0.7129, Val Accuracy: 90.17%


Epoch [65/100], Train Loss: 0.7253, Val Accuracy: 89.28%


Epoch [66/100], Train Loss: 0.7533, Val Accuracy: 89.83%


Epoch [67/100], Train Loss: 0.7256, Val Accuracy: 88.02%


Epoch [68/100], Train Loss: 0.7433, Val Accuracy: 87.72%


Epoch [69/100], Train Loss: 0.7350, Val Accuracy: 88.81%


Epoch [70/100], Train Loss: 0.7408, Val Accuracy: 89.49%


Epoch [71/100], Train Loss: 0.7369, Val Accuracy: 89.10%


Epoch [72/100], Train Loss: 0.7379, Val Accuracy: 89.74%


Epoch [73/100], Train Loss: 0.7265, Val Accuracy: 89.70%


Epoch [74/100], Train Loss: 0.7229, Val Accuracy: 87.52%


Learning rate changed to 0.0001 at epoch 75


Epoch [75/100], Train Loss: 0.7407, Val Accuracy: 88.58%


Epoch [76/100], Train Loss: 0.6797, Val Accuracy: 89.55%


Epoch [77/100], Train Loss: 0.6677, Val Accuracy: 89.21%


Epoch [78/100], Train Loss: 0.6703, Val Accuracy: 89.40%


Epoch [79/100], Train Loss: 0.6900, Val Accuracy: 89.38%


Epoch [80/100], Train Loss: 0.6929, Val Accuracy: 89.84%


Epoch [81/100], Train Loss: 0.6998, Val Accuracy: 89.55%


Epoch [82/100], Train Loss: 0.7020, Val Accuracy: 89.38%


Epoch [83/100], Train Loss: 0.6782, Val Accuracy: 89.29%


Epoch [84/100], Train Loss: 0.6831, Val Accuracy: 89.58%


Epoch [85/100], Train Loss: 0.6873, Val Accuracy: 89.05%


Epoch [86/100], Train Loss: 0.6719, Val Accuracy: 89.47%


Epoch [87/100], Train Loss: 0.6900, Val Accuracy: 89.66%


Epoch [88/100], Train Loss: 0.6893, Val Accuracy: 89.13%


Epoch [89/100], Train Loss: 0.6869, Val Accuracy: 89.54%


Epoch [90/100], Train Loss: 0.6844, Val Accuracy: 89.99%


Epoch [91/100], Train Loss: 0.6944, Val Accuracy: 89.61%


Epoch [92/100], Train Loss: 0.6820, Val Accuracy: 89.12%


Epoch [93/100], Train Loss: 0.6912, Val Accuracy: 89.80%


Epoch [94/100], Train Loss: 0.6938, Val Accuracy: 88.66%


Epoch [95/100], Train Loss: 0.6830, Val Accuracy: 89.34%


Epoch [96/100], Train Loss: 0.6908, Val Accuracy: 89.02%


Epoch [97/100], Train Loss: 0.6880, Val Accuracy: 89.38%


Epoch [98/100], Train Loss: 0.6874, Val Accuracy: 89.74%


Epoch [99/100], Train Loss: 0.6754, Val Accuracy: 89.64%


Epoch [100/100], Train Loss: 0.6901, Val Accuracy: 89.87%
Best Accuracy: 89.87%
Best Parameters: Batch Size=128, Learning Rate=[0.001, 0.001, 0.0001]


## 6. Evaluate CNN-like BiKA for MNIST

In [17]:
# Optional retraining path (disabled): rebuild the loaders and train a fresh
# model with the best hyperparameters. While it stays commented out, the
# model left over from the search cell is the one that gets evaluated.
#train_loader = DataLoader(data_train, batch_size=best_params[0], shuffle=True)
#val_loader = DataLoader(data_valid, batch_size=best_params[0], shuffle=False)
#model = BiKA_MNIST().to(device)
#criterion = nn.CrossEntropyLoss()
#train_and_validate(model, train_loader, val_loader, criterion, best_params[1])

# Test batches use the best batch size and keep the dataset order.
test_loader = DataLoader(
    data_test,
    batch_size=best_params[0],
    shuffle=False,
)

In [18]:
# Recap the outcome of the hyperparameter search before testing.
print(f"Best Validation Accuracy: {best_acc*100:.2f}%")
print(f"Best Parameters: Batch Size={best_params[0]}, Learning Rate={best_params[1]}")

# Evaluate on the held-out test set. As the last expression of the cell, the
# returned accuracy is also echoed as the cell's output value.
evaluate_model(model, test_loader)

Best Validation Accuracy: 89.87%
Best Parameters: Batch Size=128, Learning Rate=[0.001, 0.001, 0.0001]


Test Accuracy: 90.42%


0.9042