In [1]:
!pip install setuptools==69.5.0
import datetime
import torch
import torchvision
import brevitas.nn as qnn
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int32Bias


Defaulting to user installation because normal site-packages is not writeable


In [2]:
WEIGHT_BIT_WIDTH = 8
ACT_BIT_WIDTH = 3

class QuantModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_0 = qnn.QuantConv2d(
            1,
            6,
            kernel_size=3,
            bias=False,
            weight_bit_width=WEIGHT_BIT_WIDTH)
        self.relu_0 = qnn.QuantReLU(bit_width=ACT_BIT_WIDTH)
        self.conv_1 = qnn.QuantConv2d(6, 16, 6,
                                      weight_bit_width=WEIGHT_BIT_WIDTH,
                                      bias=False)
        self.relu_1 = qnn.QuantReLU(bit_width=ACT_BIT_WIDTH)
        self.conv_2 = qnn.QuantConv2d(16, 128, 4,
                                      weight_bit_width=WEIGHT_BIT_WIDTH,
                                      bias=False)
        self.fc1 = qnn.QuantLinear(128, 84,
                                   weight_bit_width=WEIGHT_BIT_WIDTH,
                                   bias=True)
        self.relu_2 = qnn.QuantReLU(bit_width=ACT_BIT_WIDTH)
        self.fc2 = qnn.QuantLinear(84, 10,
                                   weight_bit_width=WEIGHT_BIT_WIDTH,
                                   bias=True)

    def forward(self, x):
        x = self.conv_0(x)
        x = self.relu_0(x)
        x = torch.nn.functional.max_pool2d(x, 2, 2)
        x = self.conv_1(x)
        x = self.relu_1(x)
        x = torch.nn.functional.max_pool2d(x, 2, 2)
        x = self.conv_2(x)
        x = x.view(-1, 128)
        x = self.fc1(x)
        x = self.relu_2(x)
        x = self.fc2(x)
        return x
    

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

Using device:  cuda


In [4]:
input_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

In [5]:
train_dataset = torchvision.datasets.FashionMNIST('./data', train=True, download=True, transform=input_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

val_dataset = torchvision.datasets.FashionMNIST('./data', train=False, download=True, transform=input_transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

print("Train dataset size: ", len(train_dataset))
print("Val dataset size: ", len(val_dataset))


Train dataset size:  60000
Val dataset size:  10000


In [6]:
model = QuantModel().to(device)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [26]:
for epoch in range(20):
    print(f"Epoch {epoch}")
    train_loss = 0
    val_loss = 0
    last_loss = 0

    model.train()
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss_value = loss(outputs, labels)
        loss_value.backward()
        optimizer.step()
        train_loss += loss_value.item()
        if i > 0 and i % 100 == 0:
            last_loss = train_loss / 100
            print(f"Batch {i}: Train loss: {last_loss}")
            train_loss = 0

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss_value = loss(outputs, labels)
            val_loss += loss_value.item()

    val_loss = val_loss / len(val_loader)
    print(f"Val loss: {val_loss}")

torch.save(model.state_dict(), "fashion_mnist_quant.pt")

Epoch 0
Batch 100: Train loss: 2.326232163906097
Batch 200: Train loss: 2.2999991583824158
Batch 300: Train loss: 2.296544842720032
Batch 400: Train loss: 2.2940742707252504
Batch 500: Train loss: 2.2860090065002443
Batch 600: Train loss: 2.2772892189025877
Batch 700: Train loss: 2.260222737789154
Batch 800: Train loss: 2.232508268356323
Batch 900: Train loss: 2.1825063848495483
Val loss: 2.115875739200859
Epoch 1
Batch 100: Train loss: 2.080527250766754
Batch 200: Train loss: 1.905735763311386
Batch 300: Train loss: 1.732987620830536
Batch 400: Train loss: 1.5567605876922608
Batch 500: Train loss: 1.3984897255897522
Batch 600: Train loss: 1.268468418121338
Batch 700: Train loss: 1.1667830550670624
Batch 800: Train loss: 1.1032745975255966
Batch 900: Train loss: 1.0356007397174836
Val loss: 1.0123573389782268
Epoch 2
Batch 100: Train loss: 0.9851358270645142
Batch 200: Train loss: 0.9386532229185104
Batch 300: Train loss: 0.9216842442750931
Batch 400: Train loss: 0.8831329649686813
Bat

In [7]:
model.load_state_dict(torch.load("fashion_mnist_quant.pt"))

<All keys matched successfully>

In [8]:
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup

input_shape = (1, 1, 28, 28)
inp = torch.rand(input_shape)
print(next(val_loader.__iter__())[0].shape)

model.cpu()

qonnx_path = "fashion_mnist_quant.onnx"
export_qonnx(model, inp, export_path=qonnx_path)
qonnx_cleanup(qonnx_path, out_file=qonnx_path)

from finn.util.visualization import showInNetron

showInNetron(qonnx_path)


torch.Size([64, 1, 28, 28])


OSError: [Errno 98] Address already in use

In [9]:
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.datatype import DataType
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

finn_path = "fashion_mnist_finn.onnx"

model_for_finn = ModelWrapper(qonnx_path)
model_for_finn = model_for_finn.transform(ConvertQONNXtoFINN())
model_for_finn.save(finn_path)



In [11]:
showInNetron(finn_path)

Serving 'fashion_mnist_finn.onnx' at http://0.0.0.0:8081


In [12]:
import finn.core.onnx_exec as oxe
from qonnx.core.modelwrapper import ModelWrapper
import numpy as np

finn_path = "fashion_mnist_finn.onnx"

model_for_finn = ModelWrapper(finn_path)
input_name = model_for_finn.graph.input[0].name
input_shape = model_for_finn.get_tensor_shape(input_name)
output_name = model_for_finn.graph.output[0].name
inp = next(val_loader.__iter__())[0][0:1]
inp_dict = {input_name: inp.detach().numpy()}
out_dict = oxe.execute_onnx(model_for_finn, inp_dict)

print(f"FINN output: {out_dict[output_name]}")
print(f"Brevitas output: {model(inp)}")

FINN output: [[-2.390639   -4.6819415  -3.6313727  -1.5422814  -3.9852936   3.3341618
  -4.3270936   6.643266    0.62800026  9.485837  ]]
Brevitas output: tensor([[-2.3906, -4.6819, -3.6314, -1.5423, -3.9853,  3.3342, -4.3271,  6.6433,
          0.6280,  9.4858]], grad_fn=<AddmmBackward0>)


In [13]:
from finn.util.pytorch import ToTensor
from qonnx.transformation.merge_onnx_models import MergeONNXModels
from qonnx.core.datatype import DataType

finn_path = "fashion_mnist_finn.onnx"

model = ModelWrapper(finn_path)
input_name = model.graph.input[0].name
scale_inp = ToTensor()
pre_ckpt = "scale_input.onnx"

input_shape = (1, 1, 28, 28)
inp = torch.rand(input_shape)
export_qonnx(scale_inp, inp, export_path=pre_ckpt)
qonnx_cleanup(pre_ckpt, out_file=pre_ckpt)

pre_model = ModelWrapper(pre_ckpt)
pre_model = pre_model.transform(ConvertQONNXtoFINN())

model = model.transform(MergeONNXModels(pre_model))
global_inp_name = model.graph.input[0].name
model.set_tensor_datatype(global_inp_name, DataType["UINT8"])
model.save(finn_path)



In [14]:
showInNetron(finn_path)

Stopping http://0.0.0.0:8081
Serving 'fashion_mnist_finn.onnx' at http://0.0.0.0:8081


In [17]:
model = ModelWrapper(finn_path)
inp = next(val_loader.__iter__())[0][0:1] * 255
np.save("test_inp.npy", inp.swapaxes(1, -1))

input_name = model_for_finn.graph.input[0].name
inp_dict = {input_name: inp.detach().numpy()}

out_dict = oxe.execute_onnx(model, inp_dict)
output_name = model.graph.output[0].name
print(f"FINN output: {out_dict[output_name]}")

FINN output: [[-2.390639   -4.6819415  -3.6313727  -1.5422814  -3.9852936   3.3341618
  -4.3270936   6.643266    0.62800026  9.485837  ]]
