## Introduction

This notebook implements the Binarized Neural Networks (BNNs) as described in the paper "Binarized Neural Networks: Training Neural Networks with Weights and Activations Constrained to +1 or −1" by Matthieu Courbariaux, Itay Hubara, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio.

BNNs are neural networks with binary weights and activations at run-time. They drastically reduce memory consumption and replace most arithmetic operations with bit-wise operations, making them very power efficient.


In [19]:
!pip install torch
!pip install torchvision



In [20]:
## Setup
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [24]:
def binarize(x):
    return torch.sign(x)

# Binarized activation function
class BinaryActivation(nn.Module):
    def __init__(self):
        super(BinaryActivation, self).__init__()
    
    def forward(self, x):
        return binarize(x)

# Binarized linear layer
class BinarizedLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super(BinarizedLinear, self).__init__(*args, **kwargs)
    
    def forward(self, input):
        binary_weight = binarize(self.weight)
        if self.bias is None:
            return nn.functional.linear(input, binary_weight)
        else:
            return nn.functional.linear(input, binary_weight, self.bias)

# Binarized convolutional layer  
class BinarizedConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super(BinarizedConv2d, self).__init__(*args, **kwargs)
    
    def forward(self, input):
        binarized_weight = binarize(self.weight)
        return nn.functional.conv2d(input, binarized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


## Define the BNN Model

class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        self.features = nn.Sequential(
            BinarizedConv2d(3, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            BinaryActivation(),
            BinarizedConv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            BinaryActivation(),
            BinarizedConv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            BinaryActivation(),
            BinarizedConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            BinaryActivation(),
            BinarizedConv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            BinaryActivation(),
            BinarizedConv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(512),
            BinaryActivation()
        )
        self.classifier = nn.Sequential(
            BinarizedLinear(512 * 4 * 4, 1024),
            nn.BatchNorm1d(1024),
            BinaryActivation(),
            nn.Dropout(0.5),
            BinarizedLinear(1024, 1024),
            nn.BatchNorm1d(1024),
            BinaryActivation(),
            nn.Dropout(0.5),
            BinarizedLinear(1024, 10)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512 * 4 * 4)
        x = self.classifier(x)
        return x

In [22]:

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
## Load and Preprocess Data

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [25]:
## Train the BNN

net = BNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

for epoch in range(100):  
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        for p in net.parameters():
            p.data.clamp_(-1,1) 
        
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:    
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

[1,   100] loss: 69.217
[1,   200] loss: 69.774
[1,   300] loss: 68.488
[2,   100] loss: 69.653
[2,   200] loss: 68.951
[2,   300] loss: 69.358
[3,   100] loss: 69.630
[3,   200] loss: 69.721
[3,   300] loss: 69.471


KeyboardInterrupt: 

In [None]:
## Evaluate the BNN

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

NameError: name 'binarize' is not defined