# Custom Quantized Brevitas Model Creation

In [1]:
import numpy as np
import brevitas.nn as qnn
from brevitas.core.quant import QuantType
from torch.nn import Module
from torch import nn
from torch import Tensor
from torch import from_numpy
import torch
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from brevitas.quant_tensor import pack_quant_tensor

In [49]:
#fix the random seed
import random
random.seed(2)

In [2]:
class XorDataset(Dataset):
    def __init__(self):
        a = Tensor([0,0])
        b = Tensor([0,1])
        c = Tensor([1,0])
        d = Tensor([1,1])
        aq = pack_quant_tensor(a,torch.tensor(.125, dtype=torch.float32),torch.tensor(4.0, dtype=torch.float32))
        bq = pack_quant_tensor(b,torch.tensor(.125, dtype=torch.float32),torch.tensor(4.0, dtype=torch.float32))
        cq = pack_quant_tensor(c,torch.tensor(.125, dtype=torch.float32),torch.tensor(4.0, dtype=torch.float32))
        dq = pack_quant_tensor(d,torch.tensor(.125, dtype=torch.float32),torch.tensor(4.0, dtype=torch.float32))
        self.data=[aq,bq,cq,dq]
        self.key=[torch.tensor(0.0, dtype=torch.float32),torch.tensor(1.0, dtype=torch.float32),
                  torch.tensor(1.0, dtype=torch.float32),torch.tensor(0.0, dtype=torch.float32)]

    def __getitem__(self, index):
        x = self.data[index]
        y = self.key[index]
        
        return x, y
    
    def __len__(self):
        return len(self.data)

In [33]:
class QuantXORNet(Module):
    def __init__(self):
        super(QuantXORNet, self).__init__()
        #self.batch1 = qnn.QuantBatchNorm2d(num_features=n_inputs, 
        #                                  momentum=0.6,
        #                                  bias_quant_type=QuantType.INT)
        self.relu0 = qnn.QuantReLU(quant_type=QuantType.INT,
                                  bit_width=4.0,
                                  max_val=8,
                                  return_quant_tensor=True)
        self.linear1 = qnn.QuantLinear(in_features = 2, 
                                       out_features=2, 
                                       bias_quant_type=QuantType.INT,
                                       bias=True,
                                       compute_output_scale=True,
                                       compute_output_bit_width=True,
                                       #input_bit_width=32,
                                       weight_quant_type=QuantType.INT,
                                      return_quant_tensor=True)
        self.relu1 = qnn.QuantReLU(quant_type=QuantType.INT,
                                  bit_width=4,
                                  max_val=8,
                                  return_quant_tensor=True)
        #self.batch2 = qnn.QuantBatchNorm2d(num_features=31, 
        #                                   momentum=0.6,
        #                                   bias_quant_type=QuantType.FP)
        self.linear2 = qnn.QuantLinear(in_features = 2, 
                                       out_features=2, 
                                       bias_quant_type=QuantType.INT, 
                                       bias=True,
                                       compute_output_scale=True,
                                       compute_output_bit_width=True,
                                       #bit_width=4,
                                       weight_quant_type=QuantType.INT,
                                      return_quant_tensor=True)
        self.relu2 = qnn.QuantReLU(quant_type=QuantType.INT,
                                  bit_width=4,
                                  max_val=8,
                                  return_quant_tensor=True)
        #self.batch3 = qnn.QuantBatchNorm2d(num_features=11, 
        #                                   momentum=0.6,
        #                                   bias_quant_type=QuantType.FP)
        self.linear3 = qnn.QuantLinear(in_features = 2, 
                                       out_features=1, 
                                       bias_quant_type=QuantType.INT, 
                                       bias=True,
                                       compute_output_scale=True,
                                       compute_output_bit_width=True,
                                       #bit_width=4,
                                       weight_quant_type=QuantType.INT)
    def forward(self, x):
        res=x
        #res = self.batch1(res)
        res = self.relu0(res)
        #res = pack_quant_tensor(res,torch.tensor(1.0, dtype=torch.float32),torch.tensor(4.0, dtype=torch.float32))
        #print(res)
        res = self.linear1(res)
        #res = pack_quant_tensor(res,torch.tensor(1.0, dtype=torch.float32),torch.tensor(4.0, dtype=torch.float32))
        res = self.relu1(res)
        #res = pack_quant_tensor(res,torch.tensor(1.0, dtype=torch.float32),torch.tensor(4.0, dtype=torch.float32))
        #res = self.relu1(self.linear1(res))
        #res = self.batch2(res)
        res = self.linear2(res)
        #res = pack_quant_tensor(res,torch.tensor(1.0, dtype=torch.float32),torch.tensor(4.0, dtype=torch.float32))
        res = self.relu2(res)
        #res = self.batch3(res)
        #res = self.relu3(self.linear3(res))
        #res = self.batch4(res)
        res = self.linear3(res)
        return res

In [59]:
model = QuantXORNet()

In [60]:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.3)
criterion = nn.MSELoss()
loader = DataLoader(XorDataset())

In [61]:
epoch = 0
running_loss = 999
while running_loss / 4 > .1:
#for epoch in range(2000):  # loop over the dataset multiple times
    model.train()
    criterion.train()
    running_loss = 0.0
    for i, data in enumerate(loader):
        # zero the parameter gradients
        optimizer.zero_grad()
        inputs,target = data

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 4 == 3:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 4))
    epoch += 1

print('Finished Training')

[1,     4] loss: 1.365
[2,     4] loss: 1.139
[3,     4] loss: 0.954
[4,     4] loss: 0.808
[5,     4] loss: 0.692
[6,     4] loss: 0.601
[7,     4] loss: 0.529
[8,     4] loss: 0.472
[9,     4] loss: 0.426
[10,     4] loss: 0.391
[11,     4] loss: 0.363
[12,     4] loss: 0.340
[13,     4] loss: 0.323
[14,     4] loss: 0.309
[15,     4] loss: 0.298
[16,     4] loss: 0.289
[17,     4] loss: 0.282
[18,     4] loss: 0.277
[19,     4] loss: 0.272
[20,     4] loss: 0.269
[21,     4] loss: 0.266
[22,     4] loss: 0.264
[23,     4] loss: 0.262
[24,     4] loss: 0.261
[25,     4] loss: 0.260
[26,     4] loss: 0.259
[27,     4] loss: 0.259
[28,     4] loss: 0.258
[29,     4] loss: 0.258
[30,     4] loss: 0.257
[31,     4] loss: 0.257
[32,     4] loss: 0.257
[33,     4] loss: 0.257
[34,     4] loss: 0.257
[35,     4] loss: 0.256
[36,     4] loss: 0.256
[37,     4] loss: 0.256
[38,     4] loss: 0.256
[39,     4] loss: 0.256
[40,     4] loss: 0.256
[41,     4] loss: 0.256
[42,     4] loss: 0.256
[

[342,     4] loss: 0.174
[343,     4] loss: 0.173
[344,     4] loss: 0.173
[345,     4] loss: 0.173
[346,     4] loss: 0.173
[347,     4] loss: 0.173
[348,     4] loss: 0.173
[349,     4] loss: 0.173
[350,     4] loss: 0.173
[351,     4] loss: 0.173
[352,     4] loss: 0.173
[353,     4] loss: 0.173
[354,     4] loss: 0.173
[355,     4] loss: 0.173
[356,     4] loss: 0.173
[357,     4] loss: 0.173
[358,     4] loss: 0.173
[359,     4] loss: 0.173
[360,     4] loss: 0.173
[361,     4] loss: 0.173
[362,     4] loss: 0.173
[363,     4] loss: 0.173
[364,     4] loss: 0.173
[365,     4] loss: 0.173
[366,     4] loss: 0.173
[367,     4] loss: 0.173
[368,     4] loss: 0.173
[369,     4] loss: 0.173
[370,     4] loss: 0.173
[371,     4] loss: 0.173
[372,     4] loss: 0.173
[373,     4] loss: 0.173
[374,     4] loss: 0.172
[375,     4] loss: 0.172
[376,     4] loss: 0.172
[377,     4] loss: 0.172
[378,     4] loss: 0.172
[379,     4] loss: 0.172
[380,     4] loss: 0.172
[381,     4] loss: 0.172


[672,     4] loss: 0.171
[673,     4] loss: 0.171
[674,     4] loss: 0.171
[675,     4] loss: 0.171
[676,     4] loss: 0.171
[677,     4] loss: 0.171
[678,     4] loss: 0.171
[679,     4] loss: 0.171
[680,     4] loss: 0.171
[681,     4] loss: 0.171
[682,     4] loss: 0.171
[683,     4] loss: 0.171
[684,     4] loss: 0.171
[685,     4] loss: 0.171
[686,     4] loss: 0.171
[687,     4] loss: 0.171
[688,     4] loss: 0.171
[689,     4] loss: 0.171
[690,     4] loss: 0.171
[691,     4] loss: 0.171
[692,     4] loss: 0.171
[693,     4] loss: 0.171
[694,     4] loss: 0.171
[695,     4] loss: 0.171
[696,     4] loss: 0.171
[697,     4] loss: 0.171
[698,     4] loss: 0.171
[699,     4] loss: 0.171
[700,     4] loss: 0.171
[701,     4] loss: 0.171
[702,     4] loss: 0.171
[703,     4] loss: 0.171
[704,     4] loss: 0.171
[705,     4] loss: 0.171
[706,     4] loss: 0.171
[707,     4] loss: 0.171
[708,     4] loss: 0.171
[709,     4] loss: 0.171
[710,     4] loss: 0.171
[711,     4] loss: 0.171


[1002,     4] loss: 0.173
[1003,     4] loss: 0.173
[1004,     4] loss: 0.225
[1005,     4] loss: 0.174
[1006,     4] loss: 0.174
[1007,     4] loss: 0.173
[1008,     4] loss: 0.222
[1009,     4] loss: 0.175
[1010,     4] loss: 0.174
[1011,     4] loss: 0.174
[1012,     4] loss: 0.219
[1013,     4] loss: 0.175
[1014,     4] loss: 0.174
[1015,     4] loss: 0.174
[1016,     4] loss: 0.217
[1017,     4] loss: 0.175
[1018,     4] loss: 0.175
[1019,     4] loss: 0.214
[1020,     4] loss: 0.176
[1021,     4] loss: 0.175
[1022,     4] loss: 0.175
[1023,     4] loss: 0.213
[1024,     4] loss: 0.176
[1025,     4] loss: 0.175
[1026,     4] loss: 0.210
[1027,     4] loss: 0.176
[1028,     4] loss: 0.176
[1029,     4] loss: 0.208
[1030,     4] loss: 0.177
[1031,     4] loss: 0.176
[1032,     4] loss: 0.176
[1033,     4] loss: 0.208
[1034,     4] loss: 0.177
[1035,     4] loss: 0.176
[1036,     4] loss: 0.206
[1037,     4] loss: 0.177
[1038,     4] loss: 0.176
[1039,     4] loss: 0.204
[1040,     4

[1322,     4] loss: 0.171
[1323,     4] loss: 0.171
[1324,     4] loss: 0.170
[1325,     4] loss: 0.170
[1326,     4] loss: 0.170
[1327,     4] loss: 0.170
[1328,     4] loss: 0.170
[1329,     4] loss: 0.170
[1330,     4] loss: 0.170
[1331,     4] loss: 0.170
[1332,     4] loss: 0.170
[1333,     4] loss: 0.170
[1334,     4] loss: 0.170
[1335,     4] loss: 0.170
[1336,     4] loss: 0.170
[1337,     4] loss: 0.170
[1338,     4] loss: 0.170
[1339,     4] loss: 0.170
[1340,     4] loss: 0.170
[1341,     4] loss: 0.170
[1342,     4] loss: 0.170
[1343,     4] loss: 0.170
[1344,     4] loss: 0.170
[1345,     4] loss: 0.170
[1346,     4] loss: 0.170
[1347,     4] loss: 0.170
[1348,     4] loss: 0.170
[1349,     4] loss: 0.170
[1350,     4] loss: 0.170
[1351,     4] loss: 0.170
[1352,     4] loss: 0.170
[1353,     4] loss: 0.170
[1354,     4] loss: 0.170
[1355,     4] loss: 0.170
[1356,     4] loss: 0.170
[1357,     4] loss: 0.170
[1358,     4] loss: 0.198
[1359,     4] loss: 0.171
[1360,     4

[1642,     4] loss: 0.170
[1643,     4] loss: 0.170
[1644,     4] loss: 0.170
[1645,     4] loss: 0.170
[1646,     4] loss: 0.170
[1647,     4] loss: 0.170
[1648,     4] loss: 0.170
[1649,     4] loss: 0.170
[1650,     4] loss: 0.170
[1651,     4] loss: 0.170
[1652,     4] loss: 0.170
[1653,     4] loss: 0.170
[1654,     4] loss: 0.170
[1655,     4] loss: 0.170
[1656,     4] loss: 0.170
[1657,     4] loss: 0.170
[1658,     4] loss: 0.170
[1659,     4] loss: 0.170
[1660,     4] loss: 0.170
[1661,     4] loss: 0.170
[1662,     4] loss: 0.170
[1663,     4] loss: 0.170
[1664,     4] loss: 0.170
[1665,     4] loss: 0.170
[1666,     4] loss: 0.170
[1667,     4] loss: 0.170
[1668,     4] loss: 0.170
[1669,     4] loss: 0.170
[1670,     4] loss: 0.170
[1671,     4] loss: 0.170
[1672,     4] loss: 0.170
[1673,     4] loss: 0.170
[1674,     4] loss: 0.170
[1675,     4] loss: 0.170
[1676,     4] loss: 0.170
[1677,     4] loss: 0.170
[1678,     4] loss: 0.170
[1679,     4] loss: 0.170
[1680,     4

KeyboardInterrupt: 

In [62]:
print(model(Tensor([0,0])))
print(model(Tensor([0,1])))
print(model(Tensor([1,0])))
print(model(Tensor([1,1])))

tensor([0.3356], grad_fn=<AddBackward0>)
tensor([1.0076], grad_fn=<AddBackward0>)
tensor([0.3356], grad_fn=<AddBackward0>)
tensor([0.3356], grad_fn=<AddBackward0>)


# Export Model from Brevitas to FINN ONNX

In [63]:
from brevitas.onnx import export_finn_onnx

In [64]:
input_shape = [2]
export_finn_onnx(model, input_shape, "xor.onnx")

Exception: Unsupported config combination for export