In [1]:
from finn.util.basic import make_build_dir
from finn.util.visualization import showSrc, showInNetron
    
import os

# Directory that receives every ONNX file written by this notebook.
build_dir = "/workspace/finn/lenet"
# Create it up front: the export/save cells write into this directory and
# would fail on a fresh environment where it does not exist yet.
os.makedirs(build_dir, exist_ok=True)

# Common filename prefix of the ONNX files generated below.
MODEL_PREFIX = "test_model"

In [2]:
import torch
from torch.nn import Module, ModuleList, BatchNorm2d, MaxPool2d, BatchNorm1d

from brevitas.nn import QuantConv2d, QuantIdentity, QuantLinear
from brevitas.core.restrict_val import RestrictValueType
from brevitas_examples.bnn_pynq.models.common import CommonWeightQuant, CommonActQuant
from brevitas.core.restrict_val import RestrictValueType
from brevitas_examples.bnn_pynq.models.tensor_norm import TensorNorm

In [3]:
from brevitas.core.scaling import ScalingImplType
from brevitas.core.stats import StatsOp
from brevitas.nn import QuantReLU
from brevitas.core.quant import QuantType


In [4]:
# One (out_channels, pool_after_stage) pair per convolutional stage.
# Wider configuration kept for reference:
# [(64, False), (64, True), (128, False), (128, True), (256, False), (256, False)]
CNV_OUT_CH_POOL = [
    (6, True),
    (16, True),
]

# One (in_features, out_features) pair per hidden fully connected layer.
# 16 * 4 * 4 is presumably 16 channels times the 4x4 feature map left after
# the second pooling stage for a 28x28 input -- TODO confirm.
INTERMEDIATE_FC_FEATURES = [
    (16 * 4 * 4, 120),
    (120, 84),
]
# Input width of the final classifier layer.
LAST_FC_IN_FEATURES = 84

# Max-pooling window size.
POOL_SIZE = 2
# Side length of the square convolution kernels.
KERNEL_SIZE = 5

# Not referenced by any active code in this section.
MAX_RELU_VAL = 6.0

class CNV(Module):
    """Quantized convolutional network assembled from Brevitas layers.

    The topology is driven by module-level constants: ``CNV_OUT_CH_POOL``
    (convolutional stages), ``INTERMEDIATE_FC_FEATURES`` and
    ``LAST_FC_IN_FEATURES`` (fully connected stages), ``KERNEL_SIZE`` and
    ``POOL_SIZE``.

    Layers are kept in two ``ModuleList`` attributes:

    * ``conv_features``: input ``QuantIdentity`` followed, per stage, by
      ``QuantConv2d -> BatchNorm2d -> QuantIdentity [-> MaxPool2d]``.
    * ``linear_features``: per hidden layer
      ``QuantLinear -> BatchNorm1d -> QuantIdentity``, then the final
      ``QuantLinear`` and a ``TensorNorm``.

    Args:
        num_classes: Number of outputs of the final ``QuantLinear`` layer.
        weight_bit_width: Bit width handed to every conv/linear weight quantizer.
        act_bit_width: Bit width of the ``QuantIdentity`` after each BatchNorm.
        in_bit_width: Bit width of the input ``QuantIdentity``.
        in_ch: Number of channels of the network input (default 1).
    """

    def __init__(self, num_classes, weight_bit_width, act_bit_width, in_bit_width, in_ch=1):
        super().__init__()

        self.conv_features = ModuleList()
        self.linear_features = ModuleList()

        # Input quantizer for the Q1.7 input format: range [-1, 1 - 2^-7].
        self.conv_features.append(QuantIdentity(
            act_quant=CommonActQuant,
            bit_width=in_bit_width,
            min_val=-1.0,
            max_val=1.0 - 2.0 ** (-7),
            narrow_range=False,
            restrict_scaling_type=RestrictValueType.POWER_OF_TWO))

        # Convolutional stages.
        # NOTE: a QuantReLU (quant_type=QuantType.INT) right after each conv was
        # tried and is currently disabled; BatchNorm + QuantIdentity is used.
        for out_ch, is_pool_enabled in CNV_OUT_CH_POOL:
            self.conv_features.append(QuantConv2d(
                kernel_size=KERNEL_SIZE,
                in_channels=in_ch,
                out_channels=out_ch,
                bias=False,
                weight_quant=CommonWeightQuant,
                weight_bit_width=weight_bit_width))
            # The next stage consumes what this stage produced.
            in_ch = out_ch
            self.conv_features.append(BatchNorm2d(in_ch, eps=1e-4))
            self.conv_features.append(QuantIdentity(
                act_quant=CommonActQuant,
                bit_width=act_bit_width))
            if is_pool_enabled:
                # Use the shared constant instead of a literal window size.
                self.conv_features.append(MaxPool2d(kernel_size=POOL_SIZE))

        # Hidden fully connected layers (same disabled-QuantReLU note applies).
        for in_features, out_features in INTERMEDIATE_FC_FEATURES:
            self.linear_features.append(QuantLinear(
                in_features=in_features,
                out_features=out_features,
                bias=False,
                weight_quant=CommonWeightQuant,
                weight_bit_width=weight_bit_width))
            self.linear_features.append(BatchNorm1d(out_features, eps=1e-4))
            self.linear_features.append(QuantIdentity(
                act_quant=CommonActQuant,
                bit_width=act_bit_width))

        # Classifier layer; a trailing QuantReLU (max_val=MAX_RELU_VAL) was also
        # tried here and is disabled.
        self.linear_features.append(QuantLinear(
            in_features=LAST_FC_IN_FEATURES,
            out_features=num_classes,
            bias=False,
            weight_quant=CommonWeightQuant,
            weight_bit_width=weight_bit_width))

        self.linear_features.append(TensorNorm())

        # Initialise every quantized conv/linear weight uniformly in [-1, 1].
        for m in self.modules():
            if isinstance(m, (QuantConv2d, QuantLinear)):
                torch.nn.init.uniform_(m.weight.data, -1, 1)

    def clip_weights(self, min_val, max_val):
        """Clamp, in place, the weights of all quantized conv and linear layers.

        Args:
            min_val: Lower bound applied to every weight.
            max_val: Upper bound applied to every weight.
        """
        for mod in self.conv_features:
            if isinstance(mod, QuantConv2d):
                mod.weight.data.clamp_(min_val, max_val)
        for mod in self.linear_features:
            if isinstance(mod, QuantLinear):
                mod.weight.data.clamp_(min_val, max_val)

    def forward(self, x):
        """Run the network on a batch.

        The input is first rescaled as ``2 * x - 1`` (0 maps to -1, 1 maps to
        +1), passed through ``conv_features``, flattened to
        ``(batch, features)`` and passed through ``linear_features``.

        Args:
            x: Input tensor whose first dimension is the batch dimension.

        Returns:
            The output of the last module in ``linear_features``.
        """
        x = 2.0 * x - torch.tensor([1.0], device=x.device)
        for mod in self.conv_features:
            x = mod(x)
        # Collapse everything but the batch dimension before the linear layers.
        x = x.view(x.shape[0], -1)
        for mod in self.linear_features:
            x = mod(x)
        return x


def cnv(weight_bit_width, act_bit_width, in_bit_width, num_classes, in_channels):
    """Construct a ``CNV`` network from explicitly passed hyper-parameters.

    The values are given as arguments rather than read from a config object.

    Args:
        weight_bit_width: Bit width for the conv/linear weight quantizers.
        act_bit_width: Bit width for the intermediate activation quantizers.
        in_bit_width: Bit width for the input quantizer.
        num_classes: Number of network outputs.
        in_channels: Number of input channels (forwarded as ``in_ch``).

    Returns:
        The newly created ``CNV`` instance.
    """
    return CNV(
        num_classes=num_classes,
        weight_bit_width=weight_bit_width,
        act_bit_width=act_bit_width,
        in_bit_width=in_bit_width,
        in_ch=in_channels,
    )

In [5]:
import onnx
from finn.util.test import get_test_model_trained
import brevitas.onnx as bo
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs


In [6]:
# CNV initialises its weights with uniform_(-1, 1); seed PyTorch's RNG first so
# that a Restart & Run All exports the same weights every time.
torch.manual_seed(0)

# Keyword arguments instead of bare positional numbers.
model = cnv(
    weight_bit_width=4,
    act_bit_width=8,
    in_bit_width=8,
    num_classes=10,
    in_channels=1,
)

In [7]:
# Exploratory: list the attributes of the first conv_features module whose
# name contains "ame" (e.g. "name"/"named"/"parameters").
first_layer = model.conv_features[0]
[attr for attr in dir(first_layer) if 'ame' in attr]

['_export_debug_name',
 '_get_name',
 '_named_members',
 '_parameters',
 '_tracing_name',
 'export_debug_name',
 'named_buffers',
 'named_children',
 'named_modules',
 'named_parameters',
 'parameters',
 'register_parameter']

In [8]:
# Layer type names of the convolutional part, in execution order.
[layer._get_name() for layer in model.conv_features]

['QuantIdentity',
 'QuantConv2d',
 'BatchNorm2d',
 'QuantIdentity',
 'MaxPool2d',
 'QuantConv2d',
 'BatchNorm2d',
 'QuantIdentity',
 'MaxPool2d']

In [9]:
# Layer type names of the fully connected part, in execution order.
[layer._get_name() for layer in model.linear_features]

['QuantLinear',
 'BatchNorm1d',
 'QuantIdentity',
 'QuantLinear',
 'BatchNorm1d',
 'QuantIdentity',
 'QuantLinear',
 'TensorNorm']

In [10]:
# Export the Brevitas network for a (1, 1, 28, 28) input and reload the file
# as a ModelWrapper.
# NOTE: `model` is rebound from the torch module to the ModelWrapper here, so
# the layer-introspection cells above only work before this cell has run.
export_onnx_path = build_dir + f"/{MODEL_PREFIX}_export.onnx"
bo.export_finn_onnx(model, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)



In [None]:
# Apply the InferShapes transformation; transform() returns the resulting
# model, which replaces `model`.
infer_shapes = InferShapes()
model = model.transform(infer_shapes)

In [None]:
# Apply the FoldConstants transformation and keep the returned model.
fold_constants = FoldConstants()
model = model.transform(fold_constants)


In [None]:
# Apply the GiveUniqueNodeNames transformation and keep the returned model.
unique_node_names = GiveUniqueNodeNames()
model = model.transform(unique_node_names)


In [None]:
# Apply the GiveReadableTensorNames transformation and keep the returned model.
readable_tensor_names = GiveReadableTensorNames()
model = model.transform(readable_tensor_names)


In [None]:
# Apply the RemoveStaticGraphInputs transformation and keep the returned model.
remove_static_inputs = RemoveStaticGraphInputs()
model = model.transform(remove_static_inputs)


In [None]:
# Write the tidied graph to disk and open that same file in Netron.
tidy_onnx_path = build_dir + f"/{MODEL_PREFIX}_tidy.onnx"
model.save(tidy_onnx_path)
showInNetron(tidy_onnx_path)

In [None]:
from finn.util.pytorch import ToTensor
from finn.transformation.merge_onnx_models import MergeONNXModels
from finn.core.datatype import DataType

# Reload the tidied graph and read the shape of its first graph input; the
# preprocessing graph is exported for that same shape.
model = ModelWrapper(build_dir + f"/{MODEL_PREFIX}_tidy.onnx")
global_inp_name = model.graph.input[0].name
ishape = model.get_tensor_shape(global_inp_name)

# Preprocessing: torchvision's ToTensor divides uint8 inputs by 255.
# Export it as a separate ONNX graph.
chkpt_preproc_name = build_dir + f"/{MODEL_PREFIX}_preproc.onnx"
totensor_pyt = ToTensor()
bo.export_finn_onnx(totensor_pyt, ishape, chkpt_preproc_name)

# Join the preprocessing graph and the core model into one graph.
pre_model = ModelWrapper(chkpt_preproc_name)
model = model.transform(MergeONNXModels(pre_model))

# Look the input name up again on the merged graph, then add the input
# quantization annotation: UINT8 for all BNN-PYNQ models.
global_inp_name = model.graph.input[0].name
model.set_tensor_datatype(global_inp_name, DataType.UINT8)