In [1]:
import torchvision
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable

In [2]:
# Build the MNIST training and test splits from a single recipe so both are
# guaranteed to use the same preprocessing: ToTensor() followed by
# Normalize((0.5,), (0.5,)), i.e. (x - 0.5) / 0.5 on the single channel.
# The training split is created (and downloaded, if missing) first.
data_train, data_test = (
    datasets.MNIST(
        './data/',
        train=is_train_split,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
    )
    for is_train_split in (True, False)
)

In [3]:
import torch

# Mini-batch sizes used by the training and evaluation loaders.
batch_size_train, batch_size_test = 64, 1000

# Wrap each split in a DataLoader; both loaders reshuffle on every pass.
# NOTE(review): shuffling the test split is harmless for aggregate metrics but
# not required — confirm it is intentional.
data_loader_train = torch.utils.data.DataLoader(
    data_train, batch_size=batch_size_train, shuffle=True
)
data_loader_test = torch.utils.data.DataLoader(
    data_test, batch_size=batch_size_test, shuffle=True
)

In [4]:
import brevitas.nn as qnn
from brevitas.nn import QuantLinear, QuantReLU, QuantConv2d
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from brevitas.quant.binary import SignedBinaryActPerTensorConst
from brevitas.quant.binary import SignedBinaryWeightPerTensorConst
from brevitas.inject.enum import QuantType

# Seed PyTorch's global random number generator for reproducibility.
# NOTE(review): only torch is seeded here; seed `random` / NumPy as well if
# they are ever used for anything stochastic in this notebook.
torch.manual_seed(0)

ModuleNotFoundError: No module named 'brevitas'

In [None]:
class TCV_W1A1(Module):
    """Small quantised CNN built from Brevitas layers (binary weights).

    Layer sequence::

        QuantIdentity
        -> QuantConv2d -> BatchNorm2d -> QuantReLU -> QuantMaxPool2d(2)
        -> QuantConv2d -> BatchNorm2d -> QuantReLU -> QuantMaxPool2d(2)
        -> flatten
        -> QuantLinear -> BatchNorm1d -> QuantReLU
        -> QuantLinear

    All hyper-parameters are constructor arguments with defaults, so the class
    no longer depends on module-level names being defined before it is
    instantiated; ``TCV_W1A1()`` works on a fresh kernel.

    NOTE(review): the channel and hidden widths below (16 / 32 / 128) are
    placeholder defaults chosen for this fix — confirm the intended sizes.

    Args:
        in_channels1: Channels of the input image (1 for the MNIST tensors
            loaded earlier in this notebook).
        out_channels1: Output channels of the first convolution.
        out_channels2: Output channels of the second convolution.
        kernel_size: Kernel size shared by both convolutions (stride 1,
            padding 1).
        hidden1: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
        weight_bit_width: Bit width passed to every weight quantiser.
        act_bit_width: Bit width passed to every activation quantiser.
        image_size: Side length of the (square) input image; only used to
            derive ``input_size`` when that is not given.
        in_channels2: Input channels of the second convolution. Defaults to
            ``out_channels1``, which is what the forward pass feeds it.
        input_size: Number of features entering the first linear layer.
            Defaults to the value computed by :meth:`_flat_features`; pass it
            explicitly for non-square inputs or non-integer kernel sizes.

    Raises:
        ValueError: If ``input_size`` has to be derived and the feature map
            would be empty after the two conv/pool stages.
    """

    @staticmethod
    def _flat_features(out_channels, image_size=28, kernel_size=3):
        """Return the flattened feature count after the two conv+pool stages.

        Each stage is a stride-1, padding-1 convolution followed by a 2x2
        max-pool, matching the layers built in ``__init__``.
        """
        side = image_size
        for _ in range(2):
            side = side + 2 - kernel_size + 1  # conv: stride 1, padding 1
            side = side // 2                   # 2x2 max-pool, stride 2
        if side <= 0:
            raise ValueError(
                "image_size=%r is too small for kernel_size=%r: the feature "
                "map is empty after two conv/pool stages"
                % (image_size, kernel_size)
            )
        return out_channels * side * side

    def __init__(self,
                 in_channels1=1,
                 out_channels1=16,
                 out_channels2=32,
                 kernel_size=3,
                 hidden1=128,
                 num_classes=10,
                 weight_bit_width=1,
                 act_bit_width=1,
                 image_size=28,
                 in_channels2=None,
                 input_size=None):
        super().__init__()

        # conv2 consumes conv1's (pooled) output, so its input channel count
        # follows out_channels1 unless the caller overrides it.
        if in_channels2 is None:
            in_channels2 = out_channels1
        # Derive the flattened size instead of relying on a hand-computed
        # constant that can silently drift from the conv/pool configuration.
        if input_size is None:
            input_size = self._flat_features(out_channels2, image_size, kernel_size)

        # Input quantiser: binary, constant scaling, range [-1, 1].
        self.input = qnn.QuantIdentity(
                         quant_type='binary',
                         scaling_impl_type='const',
                         bit_width=act_bit_width,
                         min_val=-1.0,
                         max_val=1.0,
                         return_quant_tensor=True
                     )

        # Stage 1: conv -> batch norm -> quantised ReLU -> 2x2 max-pool.
        self.conv1 = qnn.QuantConv2d(
                         in_channels=in_channels1,
                         out_channels=out_channels1,
                         kernel_size=kernel_size,
                         stride=1,
                         padding=1,
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.BINARY,
                         bias=False
                     )
        self.bn1   = nn.BatchNorm2d(out_channels1)
        self.relu1 = qnn.QuantReLU(
                         bit_width=act_bit_width,
                         return_quant_tensor=True
                     )
        self.pool1 = qnn.QuantMaxPool2d(2, return_quant_tensor=True)

        # Stage 2: same structure as stage 1.
        self.conv2 = qnn.QuantConv2d(
                         in_channels=in_channels2,
                         out_channels=out_channels2,
                         kernel_size=kernel_size,
                         stride=1,
                         padding=1,
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.BINARY,
                         bias=False
                     )
        self.bn2   = nn.BatchNorm2d(out_channels2)
        self.relu2 = qnn.QuantReLU(
                         bit_width=act_bit_width,
                         return_quant_tensor=True
                     )
        self.pool2 = qnn.QuantMaxPool2d(2, return_quant_tensor=True)

        # Classifier head: hidden linear layer, then the output layer.
        self.fc1   = qnn.QuantLinear(
                         input_size,
                         hidden1,
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.BINARY,
                         bias=False
                     )
        self.bn3   = nn.BatchNorm1d(hidden1)
        self.relu3 = qnn.QuantReLU(
                         bit_width=act_bit_width,
                         return_quant_tensor=True
                     )
        self.out   = qnn.QuantLinear(
                         hidden1,
                         num_classes,
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.BINARY,
                         bias=False
                     )

    def forward(self, x):
        """Run the network on ``x`` and return the output of the last linear layer."""
        out = self.input(x)
        out = self.pool1(self.relu1(self.bn1(self.conv1(out))))
        out = self.pool2(self.relu2(self.bn2(self.conv2(out))))
        # Flatten everything except the leading (batch) dimension.
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.bn3(self.fc1(out)))
        out = self.out(out)
        return out
   
# Instantiate the TCV_W1A1 network defined above.
model = TCV_W1A1()