In [1]:
## custom design Network Binary 1

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Function

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Binarization function
class BinaryActivation(Function):
    @staticmethod
    def forward(ctx, input):
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clamp_(-1, 1)

def binary_activation(input):
    return BinaryActivation.apply(input)

# Define your binary neural network
class BinaryNet(nn.Module):
    def __init__(self):
        super(BinaryNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = binary_activation(self.fc1(x))
        x = binary_activation(self.fc2(x))
        logits = self.fc3(x)
        return logits



In [2]:
## Train block

# Initialize network
model = BinaryNet().to(device)

# Load MNIST
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=mnist_transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Training loop
for epoch in range(1, 11):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch}, Loss: {loss.item()}')

print("Training Complete")


torch.save(model.state_dict(), 'binary_mnist_model.pth')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 82662391.78it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 83772955.62it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 30976433.16it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 16609005.03it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Epoch 1, Loss: 0.2997497618198395
Epoch 2, Loss: 0.29428812861442566
Epoch 3, Loss: 0.3297218978404999
Epoch 4, Loss: 0.19530165195465088
Epoch 5, Loss: 0.056826308369636536
Epoch 6, Loss: 0.43031495809555054
Epoch 7, Loss: 0.36292341351509094
Epoch 8, Loss: 0.16632799804210663
Epoch 9, Loss: 0.5014564990997314
Epoch 10, Loss: 0.391250342130661
Training Complete


In [3]:
## Test model before pruning

# Load the saved model
model = BinaryNet().to(device)
model.load_state_dict(torch.load('binary_mnist_model.pth'))
model.eval()

# Load the test dataset
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

test_dataset = datasets.MNIST(root='./data', train=False, transform=mnist_transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Test the model
correct = 0
total = 0

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = correct / total
print(f'Test Accuracy (Model Before pruning): {accuracy * 100:.2f}%')

Test Accuracy (Model Before pruning): 92.69%


In [4]:
import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy


class PruningGeneticAlgorithm:
    def __init__(self, model, dataset, population_size=50, generations=10, mutation_rate=0.01, crossover_rate=0.5):
        self.model = model
        self.dataset = dataset
        self.population_size = population_size
        self.generations = generations
        self.mutation_rate = mutation_rate
        self.crossover_rate = crossover_rate
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

    def initial_population(self):
        population = []
        for _ in range(self.population_size):
            individual = {name: torch.randint(0, 2, weight.shape).bool().to(self.device)
                          for name, weight in self.model.named_parameters()}
            population.append(individual)
        return population

    def fitness(self, individual):
        # Apply the mask to the model
        pruned_model = copy.deepcopy(self.model)
        with torch.no_grad():
            for name, param in pruned_model.named_parameters():
                mask = individual.get(name, torch.ones_like(param.data).bool().to(self.device))
                param.data.mul_(mask)
        # Evaluate fitness with evaluation function -> accuracy
        accuracy = self.evaluate(pruned_model)
        return accuracy

    def evaluate(self, model):
        correct, total = 0, 0
        dataloader = DataLoader(self.dataset, batch_size=64, shuffle=False)
        model.eval()
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total

    def selection(self, population, fitnesses):
        sorted_population = [x for _, x in sorted(zip(fitnesses, population), key=lambda pair: pair[0], reverse=True)]
        return sorted_population[:int(0.2 * len(sorted_population))]

    def crossover(self, parent1, parent2):
        child = {}
        for key in parent1.keys():
            mask = torch.rand(parent1[key].size()).to(self.device) < self.crossover_rate
            child[key] = mask.bool() * parent1[key] + (~mask).bool() * parent2[key]
        return child

    def mutation(self, individual):
        for param in individual.values():
            mutation_mask = torch.rand(param.size()).to(self.device) < self.mutation_rate
            param.logical_xor_(mutation_mask)
        return individual

    def run(self):
        population = self.initial_population()
        for generation in range(self.generations):
            fitnesses = [self.fitness(individual) for individual in population]
            selected = self.selection(population, fitnesses)
            children = []
            while len(children) < self.population_size:
                parent1, parent2 = np.random.choice(selected, size=2, replace=False)
                child = self.crossover(parent1, parent2)
                child = self.mutation(child)
                children.append(child)
            population = children

            best_fitness = max(fitnesses)
            print(f"Generation {generation} -- Best Fitness: {best_fitness}")

        best_individual = population[np.argmax(fitnesses)]
        return best_individual



In [5]:

# def test_pruned_model(model, validation_dataset, best_pruned_mask):
#     # Apply the best mask to the model to get the pruned model
#     pruned_model = copy.deepcopy(model)
#     for name, param in pruned_model.named_parameters():
#         mask = best_pruned_mask.get(name, torch.ones_like(param.data).bool())
#         param.data.mul_(mask)

#     # Evaluate the pruned model on the validation dataset
#     accuracy = evaluate(pruned_model, validation_dataset)
#     print(f"Accuracy of the pruned model: {accuracy}")

def test_pruned_model(model, validation_dataset, best_pruned_mask):
    # Apply the best mask to the model to get the pruned model
    pruned_model = copy.deepcopy(model)
    total_params = 0
    pruned_params = 0

    for name, param in pruned_model.named_parameters():
        mask = best_pruned_mask.get(name, torch.ones_like(param.data).bool())
        pruned_params += torch.sum(~mask).item()
        total_params += torch.numel(param.data)

        param.data.mul_(mask)

    # Evaluate the pruned model on the validation dataset
    accuracy = evaluate(pruned_model, validation_dataset)
    print(f"Accuracy of the pruned model: {accuracy}")

    # Calculate and print the percentage of weights pruned
    percentage_pruned = (pruned_params / total_params) * 100
    print(f"Percentage of weights pruned: {percentage_pruned:.2f}%")


def evaluate(model, dataset):
    correct, total = 0, 0
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

if __name__ == "__main__":

    model = BinaryNet()
    model.load_state_dict(torch.load('binary_mnist_model.pth'))


    validation_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())


    ga = PruningGeneticAlgorithm(model=model, dataset=validation_dataset)
    best_pruned_mask = ga.run()

    # Test the pruned model
    test_pruned_model(model, validation_dataset, best_pruned_mask)

Generation 0 -- Best Fitness: 0.8497
Generation 1 -- Best Fitness: 0.8732
Generation 2 -- Best Fitness: 0.8811
Generation 3 -- Best Fitness: 0.8836
Generation 4 -- Best Fitness: 0.8926
Generation 5 -- Best Fitness: 0.8924
Generation 6 -- Best Fitness: 0.8957
Generation 7 -- Best Fitness: 0.8988
Generation 8 -- Best Fitness: 0.9014
Generation 9 -- Best Fitness: 0.9054
Accuracy of the pruned model: 0.8946
Percentage of weights pruned: 49.81%


In [None]:
#  approaches to achieve train BNNs
#1- BinaryConnect or BinaryWeight Networks:
# Use training techniques like BinaryConnect or BinaryWeight networks, where during the forward pass, weights are binarized to +1 or -1, and during the backward pass, real-valued weights are used for gradient descent.

#2-Use Binary Neural Network Libraries:
# Using specialized binary neural network libraries that provide tools and techniques for training binary networks. Examples include Brevitas, BNN-PYNQ

#3-Quantization Techniques:
# Apply quantization techniques to train a low-bit model. This is different from a strict binary model but allows weights to take on a limited set of values, such as -1, 0, and 1.

#4-Post-Training Binarization:
# Train a standard neural network first, and then apply post-training binarization to convert the weights to +1 or -1.

In [11]:
#Post-Training Binarization Method
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy

# Binarization function
class BinaryActivation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def binary_activation(input):
    return BinaryActivation.apply(input)

#Post-Training Binarization ()
class BinaryNet_(nn.Module):
    def __init__(self):
        super(BinaryNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = binary_activation(self.fc1(x))
        x = binary_activation(self.fc2(x))
        logits = self.fc3(x)
        return logits

# Load MNIST
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=mnist_transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=mnist_transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Initialize and train the model
model = BinaryNet()

# Training
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

for epoch in range(5):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Function to perform post-training binarization
def binarize_weights(model):
    binarized_model = copy.deepcopy(model)
    for name, param in binarized_model.named_parameters():
        param.data = param.data.sign()
    return binarized_model

# Apply post-training binarization
binarized_model = binarize_weights(model)

# Print all weights
for name, param in binarized_model.named_parameters():
    print(f'{name}: {param.data}')


# Test the binarized model
binarized_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data, target in test_loader:
        output = binarized_model(data)
        _, predicted = torch.max(output, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = correct / total
print(f'Test Accuracy of the Post-Training Binarization Model: {accuracy * 100:.2f}%')
torch.save(model.state_dict(), 'Post_Training_Binarization.pth')


fc1.weight: tensor([[ 1., -1.,  1.,  ..., -1., -1., -1.],
        [ 1.,  1., -1.,  ..., -1.,  1., -1.],
        [ 1., -1.,  1.,  ...,  1.,  1.,  1.],
        ...,
        [ 1.,  1., -1.,  ...,  1.,  1.,  1.],
        [-1.,  1., -1.,  ..., -1., -1.,  1.],
        [ 1., -1.,  1.,  ..., -1., -1., -1.]])
fc1.bias: tensor([-1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1.,  1.,  1.,  1.,
        -1.,  1.,  1., -1., -1., -1., -1., -1., -1.,  1.,  1., -1., -1.,  1.,
         1., -1., -1., -1.,  1., -1.,  1.,  1., -1., -1.,  1., -1., -1., -1.,
         1., -1., -1.,  1.,  1., -1., -1.,  1.,  1., -1., -1.,  1., -1.,  1.,
         1., -1., -1.,  1.,  1.,  1.,  1.,  1., -1., -1.,  1.,  1., -1., -1.,
        -1., -1.,  1., -1., -1.,  1.,  1., -1.,  1., -1., -1., -1.,  1., -1.,
         1., -1., -1.,  1.,  1.,  1.,  1., -1.,  1., -1., -1., -1.,  1.,  1.,
         1., -1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,
        -1., -1.,  1.,  1., -1., -1., -1.,  1.,  1.,  1., -1., -

In [12]:
## prune Post-Training Binarization model
def test_pruned_model(model, validation_dataset, best_pruned_mask):
    # Apply the best mask to the model to get the pruned model
    pruned_model = copy.deepcopy(model)
    total_params = 0
    pruned_params = 0

    for name, param in pruned_model.named_parameters():
        mask = best_pruned_mask.get(name, torch.ones_like(param.data).bool())
        pruned_params += torch.sum(~mask).item()
        total_params += torch.numel(param.data)

        param.data.mul_(mask)

    # Evaluate the pruned model on the validation dataset
    accuracy = evaluate(pruned_model, validation_dataset)
    print(f"Accuracy of the pruned model: {accuracy}")

    # Calculate and print the percentage of weights pruned
    percentage_pruned = (pruned_params / total_params) * 100
    print(f"Percentage of weights pruned: {percentage_pruned:.2f}%")


def evaluate(model, dataset):
    correct, total = 0, 0
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

if __name__ == "__main__":

    model = BinaryNet()
    model.load_state_dict(torch.load('Post_Training_Binarization.pth'))


    validation_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())


    ga = PruningGeneticAlgorithm(model=model, dataset=validation_dataset)
    best_pruned_mask = ga.run()

    # Test the pruned model
    test_pruned_model(model, validation_dataset, best_pruned_mask)


Generation 0 -- Best Fitness: 0.8625
Generation 1 -- Best Fitness: 0.8792
Generation 2 -- Best Fitness: 0.8902
Generation 3 -- Best Fitness: 0.8932
Generation 4 -- Best Fitness: 0.8979
Generation 5 -- Best Fitness: 0.9022
Generation 6 -- Best Fitness: 0.9035
Generation 7 -- Best Fitness: 0.9054
Generation 8 -- Best Fitness: 0.9081
Generation 9 -- Best Fitness: 0.9078
Accuracy of the pruned model: 0.9003
Percentage of weights pruned: 50.01%
