In [1]:
# Import necessary libraries
import torch
print(torch.__version__)

import torch.nn as nn
import torch.optim as optim

import torchvision
print(torchvision.__version__)
import torchvision.transforms as transforms

from torchsummary import summary

import brevitas
from brevitas.nn import QuantLinear
from brevitas.core.quant.binary import ClampedBinaryQuant
from brevitas.core.scaling import ConstScaling
from brevitas.core.quant import QuantType
print(brevitas.__version__)

2.1.0
0.16.0
0.9.1


In [2]:
# Load CIFAR-10
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100.0%


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# Network Architectures

## Simple BNN vs. Real-Valued DNN

In [3]:
# Real-valued DNN with a single layer
class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 10)  # For CIFAR-10

    def forward(self, x):
        x = x.view(-1, 32 * 32 * 3)
        x = self.fc1(x)
        return x

class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        
        # Use predefined BINARY weight quant type
        self.fc1 = QuantLinear(
            32 * 32 * 3, 
            10, 
            bias=True, 
            weight_quant_type=QuantType.BINARY)

    def forward(self, x):
        x = x.view(-1, 32 * 32 * 3)
        x = self.fc1(x)
        return x


In [4]:
dnn_model = DNN()
bnn_model = BNN()

In [5]:
print(summary(dnn_model).trainable_params)
print(summary(bnn_model).trainable_params)

Layer (type:depth-idx)                   Param #
├─Linear: 1-1                            30,730
Total params: 30,730
Trainable params: 30,730
Non-trainable params: 0
30730
Layer (type:depth-idx)                                  Param #
├─QuantLinear: 1-1                                      --
|    └─ActQuantProxyFromInjector: 2-1                   --
|    |    └─StatelessBuffer: 3-1                        --
|    └─ActQuantProxyFromInjector: 2-2                   --
|    |    └─StatelessBuffer: 3-2                        --
|    └─WeightQuantProxyFromInjector: 2-3                --
|    |    └─StatelessBuffer: 3-3                        --
|    |    └─BinaryQuant: 3-4                            30,720
|    └─BiasQuantProxyFromInjector: 2-4                  --
|    |    └─StatelessBuffer: 3-5                        --
Total params: 30,720
Trainable params: 30,720
Non-trainable params: 0
30720


## XNOR NIN vs. Real-Valued NIN

In [7]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class BinActive(torch.autograd.Function):
    '''
    Binarize the input activations and calculate the mean across channel dimension.
    '''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        size = input.size()
        mean = torch.mean(input.abs(), 1, keepdim=True)
        input = input.sign()
        return input, mean

    @staticmethod
    def backward(ctx, grad_output, grad_output_mean):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input


class BinConv2d(nn.Module):
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, dropout=0):
        super(BinConv2d, self).__init__()
        self.layer_type = 'BinConv2d'
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dropout_ratio = dropout

        self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)
        self.bn.weight.data = self.bn.weight.data.zero_().add(1.0)
        if dropout!=0:
            self.dropout = nn.Dropout(dropout)
        self.conv = nn.Conv2d(input_channels, output_channels,
                kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.bn(x)
        x, mean = BinActive.apply(x)
        if self.dropout_ratio!=0:
            x = self.dropout(x)
        x = self.conv(x)
        x = self.relu(x)
        return x

    
# XNOR NIN
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.xnor = nn.Sequential(
                nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
                nn.ReLU(inplace=True),
                BinConv2d(192, 160, kernel_size=1, stride=1, padding=0),
                BinConv2d(160,  96, kernel_size=1, stride=1, padding=0),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                BinConv2d( 96, 192, kernel_size=5, stride=1, padding=2, dropout=0.5),
                BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
                BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

                BinConv2d(192, 192, kernel_size=3, stride=1, padding=1, dropout=0.5),
                BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
                nn.Conv2d(192,  10, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
                )

    def forward(self, x):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                if hasattr(m.weight, 'data'):
                    m.weight.data.clamp_(min=0.01)
        x = self.xnor(x)
        x = x.view(x.size(0), 10)
        return x
    
# Real NIN
class RealNIN(nn.Module):
    def __init__(self):
        super(RealNIN, self).__init__()
        self.net = nn.Sequential(
                nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1),
                nn.ReLU(inplace=True),
                nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(160,  96, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1),
                nn.ReLU(inplace=True),
                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

                nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1),
                nn.ReLU(inplace=True),
                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(192,  10, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(1)
                )

    def forward(self, x):
        x = self.net(x)
        x = x.view(x.size(0), 10)
        return x


In [8]:
model_nin = Net()
model_real_nin = RealNIN()

In [9]:
print(summary(model_nin).trainable_params)
print(summary(model_real_nin).trainable_params)

Layer (type:depth-idx)                   Param #
├─Sequential: 1-1                        --
|    └─Conv2d: 2-1                       14,592
|    └─BatchNorm2d: 2-2                  --
|    └─ReLU: 2-3                         --
|    └─BinConv2d: 2-4                    --
|    |    └─BatchNorm2d: 3-1             384
|    |    └─Conv2d: 3-2                  30,880
|    |    └─ReLU: 3-3                    --
|    └─BinConv2d: 2-5                    --
|    |    └─BatchNorm2d: 3-4             320
|    |    └─Conv2d: 3-5                  15,456
|    |    └─ReLU: 3-6                    --
|    └─MaxPool2d: 2-6                    --
|    └─BinConv2d: 2-7                    --
|    |    └─BatchNorm2d: 3-7             192
|    |    └─Dropout: 3-8                 --
|    |    └─Conv2d: 3-9                  460,992
|    |    └─ReLU: 3-10                   --
|    └─BinConv2d: 2-8                    --
|    |    └─BatchNorm2d: 3-11            384
|    |    └─Conv2d: 3-12                 37,056
| 

# Model Training

## BNN vs. RDNN

In [10]:
def train(model, dataloader, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(dataloader)}")
    print('Finished Training')

# Train the real-valued DNN
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(dnn_model.parameters(), lr=0.001, momentum=0.9)
train(dnn_model, trainloader, criterion, optimizer)

# Train the BNN
optimizer_bnn = optim.SGD(bnn_model.parameters(), lr=0.001, momentum=0.9)
train(bnn_model, trainloader, criterion, optimizer_bnn)


Epoch 1, Loss: 2.154309051605463
Epoch 2, Loss: 2.1049341724646093
Epoch 3, Loss: 2.0855155343353746
Epoch 4, Loss: 2.075864036922455
Epoch 5, Loss: 2.066716543700695
Epoch 6, Loss: 2.061993913832903
Epoch 7, Loss: 2.0530951206195356
Epoch 8, Loss: 2.046263520656824
Epoch 9, Loss: 2.0302124366790055
Epoch 10, Loss: 2.0334975581109522
Finished Training
Epoch 1, Loss: 2.1836362445765736
Epoch 2, Loss: 2.165489381093979
Epoch 3, Loss: 2.2206116836598517
Epoch 4, Loss: 2.193145127355009
Epoch 5, Loss: 2.267161365763396
Epoch 6, Loss: 2.2126669481050967
Epoch 7, Loss: 2.2651475597362967
Epoch 8, Loss: 2.228112700407803
Epoch 9, Loss: 2.2825963844022157
Epoch 10, Loss: 2.2447486337159575
Finished Training


## NIN vs. Real NIN

In [12]:
def train(model, dataloader, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(dataloader)}")
    print('Finished Training')

# Train the real-valued NIN
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_real_nin.parameters(), lr=0.001, momentum=0.9)
train(model_real_nin, trainloader, criterion, optimizer)

# Train the B-NIN
optimizer_bnn = optim.SGD(model_nin.parameters(), lr=0.001, momentum=0.9)
train(model_nin, trainloader, criterion, optimizer_bnn)

Epoch 1, Loss: 2.3026707262802124
Epoch 2, Loss: 2.3025851249694824
Epoch 3, Loss: 2.3025851249694824
Epoch 4, Loss: 2.3025851249694824
Epoch 5, Loss: 2.3025851249694824
Epoch 6, Loss: 2.3025851249694824
Epoch 7, Loss: 2.3025851249694824
Epoch 8, Loss: 2.3025851249694824
Epoch 9, Loss: 2.3025851249694824
Epoch 10, Loss: 2.3025851249694824
Finished Training
Epoch 1, Loss: 2.3027420357322694
Epoch 2, Loss: 2.302663890571594
Epoch 3, Loss: 2.3026093805122376
Epoch 4, Loss: 2.3025438442611694
Epoch 5, Loss: 2.302365941066742
Epoch 6, Loss: 2.2497773061943054
Epoch 7, Loss: 2.0179050044202804
Epoch 8, Loss: 1.9382286017131805
Epoch 9, Loss: 1.8031506101131438
Epoch 10, Loss: 1.681749998474121
Finished Training


# Model Evaluation

In [13]:
def evaluate(model, dataloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

dnn_accuracy = evaluate(dnn_model, testloader)
bnn_accuracy = evaluate(bnn_model, testloader)

print(f'DNN Accuracy on test set: {dnn_accuracy}%')
print(f'BNN Accuracy on test set: {bnn_accuracy}%')


DNN Accuracy on test set: 33.82%
BNN Accuracy on test set: 33.08%


In [14]:
def evaluate(model, dataloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

r_nin_accuracy = evaluate(model_real_nin, testloader)
b_nin_accuracy = evaluate(model_nin, testloader)

print(f'R-NIN Accuracy on test set: {r_nin_accuracy}%')
print(f'B-NIN Accuracy on test set: {b_nin_accuracy}%')

R-NIN Accuracy on test set: 10.0%
B-NIN Accuracy on test set: 41.03%
