In [9]:
# Import necessary libraries
import torch
print(torch.__version__)

import torch.nn as nn
import torch.optim as optim

import torchvision
print(torchvision.__version__)
import torchvision.transforms as transforms

from torchsummary import summary

import brevitas
from brevitas.nn import QuantLinear
from brevitas.core.quant.binary import ClampedBinaryQuant
from brevitas.core.scaling import ConstScaling
from brevitas.core.quant import QuantType
print(brevitas.__version__)

import time

2.1.0
0.16.0
0.9.1


In [10]:
# Load CIFAR-10
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


# Network Architectures

## Simple BNN vs. Real-Valued DNN

In [11]:
# Real-valued DNN with a single layer
class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 10)  # For CIFAR-10

    def forward(self, x):
        x = x.view(-1, 32 * 32 * 3)
        x = self.fc1(x)
        return x

class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        
        # Use predefined BINARY weight quant type
        self.fc1 = QuantLinear(
            32 * 32 * 3, 
            10, 
            bias=True, 
            weight_quant_type=QuantType.BINARY)

    def forward(self, x):
        x = x.view(-1, 32 * 32 * 3)
        x = self.fc1(x)
        return x


In [12]:
dnn_model = DNN()
bnn_model = BNN()

In [13]:
print(summary(dnn_model).trainable_params)
print(summary(bnn_model).trainable_params)

Layer (type:depth-idx)                   Param #
├─Linear: 1-1                            30,730
Total params: 30,730
Trainable params: 30,730
Non-trainable params: 0
30730
Layer (type:depth-idx)                                  Param #
├─QuantLinear: 1-1                                      --
|    └─ActQuantProxyFromInjector: 2-1                   --
|    |    └─StatelessBuffer: 3-1                        --
|    └─ActQuantProxyFromInjector: 2-2                   --
|    |    └─StatelessBuffer: 3-2                        --
|    └─WeightQuantProxyFromInjector: 2-3                --
|    |    └─StatelessBuffer: 3-3                        --
|    |    └─BinaryQuant: 3-4                            30,720
|    └─BiasQuantProxyFromInjector: 2-4                  --
|    |    └─StatelessBuffer: 3-5                        --
Total params: 30,720
Trainable params: 30,720
Non-trainable params: 0
30720


## XNOR NIN vs. Real-Valued NIN

In [21]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class BinActive(torch.autograd.Function):
    '''
    Binarize the input activations and calculate the mean across channel dimension.
    '''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        size = input.size()
        mean = torch.mean(input.abs(), 1, keepdim=True)
        input = input.sign()
        return input, mean

    @staticmethod
    def backward(ctx, grad_output, grad_output_mean):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input


class BinConv2d(nn.Module):
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, dropout=0):
        super(BinConv2d, self).__init__()
        self.layer_type = 'BinConv2d'
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dropout_ratio = dropout

        self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)
        self.bn.weight.data = self.bn.weight.data.zero_().add(1.0)
        if dropout!=0:
            self.dropout = nn.Dropout(dropout)
        self.conv = nn.Conv2d(input_channels, output_channels,
                kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.bn(x)
        x, mean = BinActive.apply(x)
        if self.dropout_ratio!=0:
            x = self.dropout(x)
        x = self.conv(x)
        x = self.relu(x)
        return x

    
# XNOR NIN
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.xnor = nn.Sequential(
                nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
                nn.ReLU(inplace=True),
                BinConv2d(192, 160, kernel_size=1, stride=1, padding=0),
                BinConv2d(160,  96, kernel_size=1, stride=1, padding=0),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                BinConv2d( 96, 192, kernel_size=5, stride=1, padding=2, dropout=0.5),
                BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
                BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

                BinConv2d(192, 192, kernel_size=3, stride=1, padding=1, dropout=0.5),
                BinConv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False),
                nn.Conv2d(192,  10, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
                )

    def forward(self, x):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                if hasattr(m.weight, 'data'):
                    m.weight.data.clamp_(min=0.01)
        x = self.xnor(x)
        x = x.view(x.size(0), 10)
        return x
    
# Real NIN
class RealNIN(nn.Module):
    def __init__(self):
        super(RealNIN, self).__init__()
        self.net = nn.Sequential(
                nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1),
                nn.ReLU(inplace=True),
                nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(160,  96, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1),
                nn.ReLU(inplace=True),
                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),

                nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(192, eps=1e-4, momentum=0.1),
                nn.ReLU(inplace=True),
                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(192,  10, kernel_size=1, stride=1, padding=0),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(1)
                )

    def forward(self, x):
        x = self.net(x)
        x = x.view(x.size(0), 10)
        return x


In [24]:
class TimerHook:
    def __init__(self):
        self.start_time = 0
        self.total_time = 0

    def __call__(self, module, input, output):
        if self.start_time == 0:
            self.start_time = time.time()
        else:
            end_time = time.time()
            self.total_time += end_time - self.start_time
            self.start_time = 0

    def reset(self):
        self.start_time = 0
        self.total_time = 0

    def total(self):
        return self.total_time

In [25]:
model_nin = Net()
# model_real_nin = RealNIN()

hooks = {}

In [17]:
print(summary(model_nin).trainable_params)
# print(summary(model_real_nin).trainable_params)

Layer (type:depth-idx)                   Param #
├─Sequential: 1-1                        --
|    └─Conv2d: 2-1                       14,592
|    └─BatchNorm2d: 2-2                  --
|    └─ReLU: 2-3                         --
|    └─BinConv2d: 2-4                    --
|    |    └─BatchNorm2d: 3-1             384
|    |    └─Conv2d: 3-2                  30,880
|    |    └─ReLU: 3-3                    --
|    └─BinConv2d: 2-5                    --
|    |    └─BatchNorm2d: 3-4             320
|    |    └─Conv2d: 3-5                  15,456
|    |    └─ReLU: 3-6                    --
|    └─MaxPool2d: 2-6                    --
|    └─BinConv2d: 2-7                    --
|    |    └─BatchNorm2d: 3-7             192
|    |    └─Dropout: 3-8                 --
|    |    └─Conv2d: 3-9                  460,992
|    |    └─ReLU: 3-10                   --
|    └─BinConv2d: 2-8                    --
|    |    └─BatchNorm2d: 3-11            384
|    |    └─Conv2d: 3-12                 37,056
| 

In [26]:
# Use named_children to get immediate child modules. If you need all nested modules, use named_modules instead.
for name, layer in model_nin.named_children():
    if isinstance(layer, nn.Sequential):  # since 'net' is Sequential
        for sub_name, sub_layer in layer.named_children():
            hook = TimerHook()
            sub_layer.register_forward_hook(hook)
            hooks[sub_name] = hook
    else:
        hook = TimerHook()
        layer.register_forward_hook(hook)
        hooks[name] = hook

# After an epoch, print the aggregated times
def print_times_and_reset():
    for name, hook in hooks.items():
        print(f"Total time for {name}: {hook.total()} seconds")
        hook.reset()


# Model Training

## BNN vs. RDNN

In [19]:
def train(model, dataloader, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(dataloader)}")
    print('Finished Training')

# Train the real-valued DNN
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(dnn_model.parameters(), lr=0.001, momentum=0.9)
train(dnn_model, trainloader, criterion, optimizer)

# Train the BNN
optimizer_bnn = optim.SGD(bnn_model.parameters(), lr=0.001, momentum=0.9)
train(bnn_model, trainloader, criterion, optimizer_bnn)


Epoch 1, Loss: 2.155623477058411
Epoch 2, Loss: 2.1059559034729003
Epoch 3, Loss: 2.0895450543379783
Epoch 4, Loss: 2.0689636600238086
Epoch 5, Loss: 2.0588250145411493
Epoch 6, Loss: 2.0522228903388977
Epoch 7, Loss: 2.0453778921216728
Epoch 8, Loss: 2.0422889260053636
Epoch 9, Loss: 2.036768742130399
Epoch 10, Loss: 2.0323982599473
Finished Training
Epoch 1, Loss: 2.1664529304897786
Epoch 2, Loss: 2.1771481917607782
Epoch 3, Loss: 2.191689472646117
Epoch 4, Loss: 2.2208557810384035
Epoch 5, Loss: 2.2242107924509047
Epoch 6, Loss: 2.202486847819686
Epoch 7, Loss: 2.240039607576132
Epoch 8, Loss: 2.252634796155095
Epoch 9, Loss: 2.2478501925316454
Epoch 10, Loss: 2.249462173079848
Finished Training


## NIN vs. Real NIN

In [27]:
def train(model, dataloader, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(dataloader)}")
        print_times_and_reset()
    print('Finished Training')

# Train the real-valued NIN
criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model_real_nin.parameters(), lr=0.001, momentum=0.9)
# train(model_real_nin, trainloader, criterion, optimizer)

# Train the B-NIN
optimizer_bnn = optim.SGD(model_nin.parameters(), lr=0.001, momentum=0.9)
train(model_nin, trainloader, criterion, optimizer_bnn)

Epoch 1, Loss: 1.7454025261569024
Total time for 0: 165.91272497177124 seconds
Total time for 1: 165.91626811027527 seconds
Total time for 2: 165.9168107509613 seconds
Total time for 3: 165.90263390541077 seconds
Total time for 4: 165.9103901386261 seconds
Total time for 5: 165.90477681159973 seconds
Total time for 6: 165.91030263900757 seconds
Total time for 7: 165.93570709228516 seconds
Total time for 8: 165.9660141468048 seconds
Total time for 9: 165.97250533103943 seconds
Total time for 10: 165.97922015190125 seconds
Total time for 11: 165.98755264282227 seconds
Total time for 12: 165.99263215065002 seconds
Total time for 13: 165.99957704544067 seconds
Total time for 14: 166.00323724746704 seconds
Total time for 15: 166.00459933280945 seconds
Epoch 2, Loss: 1.3910341137969493
Total time for 0: 164.85368847846985 seconds
Total time for 1: 164.86217665672302 seconds
Total time for 2: 164.865398645401 seconds
Total time for 3: 164.9595353603363 seconds
Total time for 4: 164.9498412609

KeyboardInterrupt: 

# Model Evaluation

In [None]:
def evaluate(model, dataloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

dnn_accuracy = evaluate(dnn_model, testloader)
bnn_accuracy = evaluate(bnn_model, testloader)

print(f'DNN Accuracy on test set: {dnn_accuracy}%')
print(f'BNN Accuracy on test set: {bnn_accuracy}%')


DNN Accuracy on test set: 33.82%
BNN Accuracy on test set: 33.08%


In [None]:
def evaluate(model, dataloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

r_nin_accuracy = evaluate(model_real_nin, testloader)
b_nin_accuracy = evaluate(model_nin, testloader)

print(f'R-NIN Accuracy on test set: {r_nin_accuracy}%')
print(f'B-NIN Accuracy on test set: {b_nin_accuracy}%')

R-NIN Accuracy on test set: 10.0%
B-NIN Accuracy on test set: 41.03%
