In [1]:
!pip install setuptools==69.5.0
import datetime
import torch
import torchvision
import brevitas.nn as qnn
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int32Bias


Defaulting to user installation because normal site-packages is not writeable


In [2]:
WEIGHT_BIT_WIDTH = 8
ACT_BIT_WIDTH = 3

class QuantModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_0 = qnn.QuantConv2d(
            1,
            6,
            kernel_size=3,
            bias=False,
            weight_bit_width=WEIGHT_BIT_WIDTH)
        self.relu_0 = qnn.QuantReLU(bit_width=ACT_BIT_WIDTH)
        self.conv_1 = qnn.QuantConv2d(6, 16, 6,
                                      weight_bit_width=WEIGHT_BIT_WIDTH,
                                      bias=False)
        self.relu_1 = qnn.QuantReLU(bit_width=ACT_BIT_WIDTH)
        self.conv_2 = qnn.QuantConv2d(16, 128, 4,
                                      weight_bit_width=WEIGHT_BIT_WIDTH,
                                      bias=False)
        self.fc1 = qnn.QuantLinear(128, 84,
                                   weight_bit_width=WEIGHT_BIT_WIDTH,
                                   bias=True)
        self.relu_2 = qnn.QuantReLU(bit_width=ACT_BIT_WIDTH)
        self.fc2 = qnn.QuantLinear(84, 10,
                                   weight_bit_width=WEIGHT_BIT_WIDTH,
                                   bias=True)

    def forward(self, x):
        x = self.conv_0(x)
        x = self.relu_0(x)
        x = torch.nn.functional.max_pool2d(x, 2, 2)
        x = self.conv_1(x)
        x = self.relu_1(x)
        x = torch.nn.functional.max_pool2d(x, 2, 2)
        x = self.conv_2(x)
        x = x.view(-1, 128)
        x = self.fc1(x)
        x = self.relu_2(x)
        x = self.fc2(x)
        return x
    

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

Using device:  cuda


In [4]:
input_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,))
])

In [10]:
train_dataset = torchvision.datasets.FashionMNIST('./data', train=True, download=True, transform=input_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

val_dataset = torchvision.datasets.FashionMNIST('./data', train=False, download=True, transform=input_transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

print("Train dataset size: ", len(train_dataset))
print("Val dataset size: ", len(val_dataset))


Train dataset size:  60000
Val dataset size:  10000


In [6]:
model = QuantModel().to(device)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [8]:
for epoch in range(20):
    print(f"Epoch {epoch}")
    train_loss = 0
    val_loss = 0
    last_loss = 0

    model.train()
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss_value = loss(outputs, labels)
        loss_value.backward()
        optimizer.step()
        train_loss += loss_value.item()
        if i > 0 and i % 100 == 0:
            last_loss = train_loss / 100
            print(f"Batch {i}: Train loss: {last_loss}")
            train_loss = 0

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss_value = loss(outputs, labels)
            val_loss += loss_value.item()

    val_loss = val_loss / len(val_loader)
    print(f"Val loss: {val_loss}")

torch.save(model.state_dict(), "fashion_mnist_quant.pt")

Epoch 0


  return super(Tensor, self).rename(names)


Batch 100: Train loss: 2.309877800941467
Batch 200: Train loss: 2.236604051589966
Batch 300: Train loss: 2.057586097717285
Batch 400: Train loss: 1.617937502861023
Batch 500: Train loss: 1.30959286570549
Batch 600: Train loss: 1.113057897090912
Batch 700: Train loss: 1.0147151958942413
Batch 800: Train loss: 0.9384061247110367
Batch 900: Train loss: 0.8605263912677765
Val loss: 0.827880182463652
Epoch 1
Batch 100: Train loss: 0.801836873292923
Batch 200: Train loss: 0.7603584504127503
Batch 300: Train loss: 0.7437887847423553
Batch 400: Train loss: 0.7246499216556549
Batch 500: Train loss: 0.6847815686464309
Batch 600: Train loss: 0.6992537569999695
Batch 700: Train loss: 0.659537806212902
Batch 800: Train loss: 0.650908077955246
Batch 900: Train loss: 0.6508463153243065
Val loss: 0.6473856923306823
Epoch 2
Batch 100: Train loss: 0.625213440656662
Batch 200: Train loss: 0.6157177484035492
Batch 300: Train loss: 0.6062486585974693
Batch 400: Train loss: 0.5901076719164848
Batch 500: Tra

In [8]:
model.load_state_dict(torch.load("fashion_mnist_quant.pt"))

<All keys matched successfully>

In [9]:
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup

input_shape = (1, 1, 28, 28)
inp = torch.rand(input_shape)
print(next(val_loader.__iter__())[0].shape)

model.cpu()

qonnx_path = "fashion_mnist_quant.onnx"
export_qonnx(model, inp, export_path=qonnx_path)
qonnx_cleanup(qonnx_path, out_file=qonnx_path)

from finn.util.visualization import showInNetron

showInNetron(qonnx_path)


torch.Size([64, 1, 28, 28])
Serving 'fashion_mnist_quant.onnx' at http://0.0.0.0:8081


In [10]:
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.datatype import DataType
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

finn_path = "fashion_mnist_finn.onnx"

model_for_finn = ModelWrapper(qonnx_path)
model_for_finn = model_for_finn.transform(ConvertQONNXtoFINN())
model_for_finn.save(finn_path)



In [11]:
showInNetron(finn_path)

Stopping http://0.0.0.0:8081
Serving 'fashion_mnist_finn.onnx' at http://0.0.0.0:8081


In [12]:
import finn.core.onnx_exec as oxe

model_for_finn = ModelWrapper(finn_path)
input_name = model_for_finn.graph.input[0].name
input_shape = model_for_finn.get_tensor_shape(input_name)
output_name = model_for_finn.graph.output[0].name
inp = next(val_loader.__iter__())[0][0:1]
inp_dict = {input_name: inp.detach().numpy()}
out_dict = oxe.execute_onnx(model_for_finn, inp_dict)

print(f"FINN output: {out_dict[output_name]}")
print(f"Brevitas output: {model(inp)}")

[1, 1, 28, 28]
FINN output: [[-1.3989073  -9.333837   -3.0418384  -3.5792613  -3.0741603   4.254069
  -2.909774    5.8974247   0.96250105  8.780479  ]]
Brevitas output: tensor([[-1.3989, -9.3338, -3.0418, -3.5793, -3.0742,  4.2541, -2.9098,  5.8974,
          0.9625,  8.7805]], grad_fn=<AddmmBackward0>)
