# Importing Brevitas networks into FINN with the QONNX interchange format

**Note: Previously it was possible to directly export the FINN-ONNX interchange format from Brevitas to pass to the FINN compiler. This support is deprecated and FINN uses the export to the QONNX format as a front end, internally FINN uses still the FINN-ONNX format.**

In this notebook we'll go through an example of how to import a Brevitas-trained QNN into FINN. The steps will be as follows:

1. Load up the trained PyTorch model
2. Call Brevitas QONNX export and visualize with Netron
3. Import into FINN and converting QONNX to FINN-ONNX

We'll use the following utility functions to print the source code for function calls (`showSrc()`) and to visualize a network using netron (`showInNetron()`) in the Jupyter notebook:

In [1]:
import onnx
from finn.util.visualization import showSrc, showInNetron

## 1. Load up the Data in tensor format

The FINN Docker image comes with several [example Brevitas networks](https://github.com/Xilinx/brevitas/tree/master/src/brevitas_examples/bnn_pynq), and we'll use the LFC-w1a1 model as the example network here. This is a binarized fully connected network trained on the MNIST dataset. Let's start by looking at what the PyTorch network definition looks like:

We can see that the network topology is constructed using a few helper functions that generate the quantized linear layers and quantized activations. The bitwidth of the layers is actually parametrized in the constructor, so let's instantiate a 1-bit weights and activations version of this network. We also have pretrained weights for this network, which we will load into the model.

We have now instantiated our trained PyTorch network. Let's try to run an example MNIST image through the network using PyTorch.

## 2. Visualize with Netron

Let's examine what the exported ONNX model looks like. For this, we will use the Netron visualizer:

In [14]:
model_path = "/home/omar/finn/notebooks/yolov1_quant.onnx"

In [15]:
showInNetron(model_path)

Stopping http://0.0.0.0:8083
Serving '/home/omar/finn/notebooks/yolov1_quant.onnx' at http://0.0.0.0:8083


When running this notebook in the FINN Docker container, you should be able to see an interactive visualization of the imported network above, and click on individual nodes to inspect their parameters. If you look at any of the MatMul nodes, you should be able to see that the weights are all {-1, +1} values.

# 3. Import into FINN and call cleanup transformations

We will now import this ONNX model into FINN using the ModelWrapper, and examine some of the graph attributes from Python.

In [2]:
from qonnx.util.cleanup import cleanup

model_path_clean = "/home/omaribrahim/Omar/thesis/finn/notebooks/yolov1_quant_clean.onnx"
cleanup(model_path, out_file=model_path_clean)

NameError: name 'model_path' is not defined

In [5]:
showInNetron(model_path_clean)

Stopping http://0.0.0.0:8082
Serving '/home/omaribrahim/Omar/thesis/finn/notebooks/yolov5s_quant_clean.onnx' at http://0.0.0.0:8082


We will now import this QONNX model into FINN using the ModelWrapper. Here we can immediatley execute the model to verify correctness.

Using the `QONNXtoFINN` transformation we can convert the model to the FINN internal FINN-ONNX representation. Notably all Quant and BipolarQuant nodes will have disappeared and are converted into MultiThreshold nodes.

And once again we can execute the model with the FINN/QONNX execution engine.

We have succesfully verified that the transformed and cleaned-up FINN graph still produces the same output, and can now use this model for further processing in FINN.

# 4. Further cleanup and Running the Model

In [3]:

# Print information about input and output tensors
for n in model.graph.node:
    for i in n.input:
        i_shape = model.get_tensor_shape(i)
            print("input: ",i,i_shape)
    for o in n.output:
        o_shape = model.get_tensor_shape(o)
            print("output: ",o,o_shape)

IndentationError: unexpected indent (355756226.py, line 5)

In [9]:
len(model.graph.node[0].output)

1

In [1]:
import json
import numpy as np
import os
import shutil
import warnings
from copy import deepcopy
from distutils.dir_util import copy_tree
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import (
    ApplyConfig,
    GiveReadableTensorNames,
    GiveUniqueNodeNames,
    GiveUniqueParameterTensors,
    RemoveStaticGraphInputs,
    RemoveUnusedTensors,
    ConvertSubToAdd,
    ConvertDivToMul,
    SortGraph
)
from qonnx.transformation.infer_data_layouts import InferDataLayouts
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.insert_topk import InsertTopK
from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from qonnx.util.cleanup import cleanup_model
from qonnx.util.config import extract_model_config_to_json
from shutil import copy

from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.remove import RemoveIdentityOps
from finn.transformation.streamline.sign_to_thres import ConvertSignToThres
from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine
from qonnx.core.datatype import DataType

# just for not linear
from finn.transformation.streamline.reorder import (
    MoveLinearPastEltwiseAdd,
    MoveLinearPastFork,
)

import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
import finn.transformation.streamline.absorb as absorb
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
from finn.transformation.streamline.collapse_repeated import (
    CollapseRepeatedAdd,
    CollapseRepeatedMul,
)

from finn.transformation.streamline.reorder import (
    MoveAddPastMul,
    MoveScalarMulPastMatMul,
    MoveScalarAddPastMatMul,
    MoveAddPastConv,
    MoveScalarMulPastConv,
    MoveScalarLinearPastInvariants,
    MoveMaxPoolPastMultiThreshold,
    MoveMulPastMaxPool,
    MoveTransposePastFork,
    MoveTransposePastScalarMul
)

from finn.analysis.fpgadataflow.dataflow_performance import dataflow_performance
from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
from finn.analysis.fpgadataflow.hls_synth_res_estimation import hls_synth_res_estimation
from finn.analysis.fpgadataflow.op_and_param_counts import (
    aggregate_dict_keys,
    op_and_param_counts,
)
from finn.analysis.fpgadataflow.res_estimation import (
    res_estimation,
    res_estimation_complete,
)
from finn.builder.build_dataflow_config import (
    DataflowBuildConfig,
    DataflowOutputType,
    ShellFlowType,
    VerificationStepType,
)
from finn.core.onnx_exec import execute_onnx
from finn.core.rtlsim_exec import rtlsim_exec
from finn.core.throughput_test import throughput_test_rtlsim
from finn.transformation.fpgadataflow.annotate_cycles import AnnotateCycles
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
from finn.transformation.fpgadataflow.create_dataflow_partition import (
    CreateDataflowPartition,
)
from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
from finn.transformation.fpgadataflow.derive_characteristic import (
    DeriveCharacteristic,
    DeriveFIFOSizes,
)
from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
from finn.transformation.fpgadataflow.insert_dwc import InsertDWC
from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
from finn.transformation.fpgadataflow.make_pynq_driver import MakePYNQDriver
from finn.transformation.fpgadataflow.make_zynq_proj import ZynqBuild
from finn.transformation.fpgadataflow.minimize_accumulator_width import (
    MinimizeAccumulatorWidth,
)
from finn.transformation.fpgadataflow.minimize_weight_bit_width import (
    MinimizeWeightBitWidth,
)
from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
    ReplaceVerilogRelPaths,
)
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.set_fifo_depths import (
    InsertAndSetFIFODepths,
    RemoveShallowFIFOs,
    SplitLargeFIFOs,
)
from finn.transformation.fpgadataflow.set_folding import SetFolding
from finn.transformation.fpgadataflow.synth_ooc import SynthOutOfContext
from finn.transformation.fpgadataflow.vitis_build import VitisBuild
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.transformation.qonnx.quant_act_to_multithreshold import (
    default_filter_function_generator,
)
from finn.transformation.streamline import Streamline
from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
from finn.util.basic import (
    get_rtlsim_trace_depth,
    pyverilate_get_liveness_threshold_cycles,
)
from finn.util.pyverilator import verilator_fifosim
from finn.util.test import execute_parent
from finn.util.visualization import showSrc, showInNetron
import onnx
from qonnx.util.cleanup import cleanup_model

In [2]:
def step_yolo_tidy(model: ModelWrapper, cfg: DataflowBuildConfig):
    model = cleanup_model(model)
    model = model.transform(ConvertQONNXtoFINN())
    model = model.transform(GiveUniqueParameterTensors())
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(GiveReadableTensorNames())
    model = model.transform(InferDataTypes())
    model = model.transform(InsertTopK())
    model = model.transform(InferShapes())
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(GiveReadableTensorNames())
    model = model.transform(InferDataTypes())
    return model

def step_yolo_streamline_linear(model: ModelWrapper, cfg: DataflowBuildConfig):
    streamline_transformations = [
        absorb.AbsorbScalarMulAddIntoTopK(),  # before MoveAddPastMul to avoid int->float
        ConvertSubToAdd(),
        ConvertDivToMul(),
        RemoveIdentityOps(),
        CollapseRepeatedMul(),
        BatchNormToAffine(),
        ConvertSignToThres(),
        MoveAddPastMul(),
        MoveScalarAddPastMatMul(),
        MoveAddPastConv(),
        MoveScalarMulPastMatMul(),
        MoveScalarMulPastConv(),
        MoveScalarLinearPastInvariants(),
        MoveAddPastMul(),
        CollapseRepeatedAdd(),
        CollapseRepeatedMul(),
        absorb.AbsorbAddIntoMultiThreshold(),
        absorb.FactorOutMulSignMagnitude(),
        MoveMulPastMaxPool(),
        MoveMaxPoolPastMultiThreshold(),
        absorb.AbsorbMulIntoMultiThreshold(),
        absorb.Absorb1BitMulIntoMatMul(),
        absorb.Absorb1BitMulIntoConv(),
        MakeMaxPoolNHWC(),
        absorb.AbsorbConsecutiveTransposes(),
        RoundAndClipThresholds(),
    ]
    for trn in streamline_transformations:
        model = model.transform(trn)
        model = model.transform(GiveUniqueNodeNames())
    return model

def step_yolo_streamline_nonlinear(model: ModelWrapper, cfg: DataflowBuildConfig):
    streamline_transformations = [
        MoveLinearPastEltwiseAdd(),
        MoveLinearPastFork(),
    ]
    for trn in streamline_transformations:
        model = model.transform(trn)
        model = model.transform(GiveUniqueNodeNames())
    return model

def step_yolo_streamline(model: ModelWrapper, cfg: DataflowBuildConfig):

    for iter_id in range(8):
        model = step_yolo_streamline_linear(model, cfg)
        model = step_yolo_streamline_nonlinear(model, cfg)

        # big loop tidy up
        model = model.transform(RemoveUnusedTensors())
        model = model.transform(GiveReadableTensorNames())
        model = model.transform(InferDataTypes())
        model = model.transform(SortGraph())

    model = model.transform(DoubleToSingleFloat())

    return model

def step_yolo_convert_to_hls(model: ModelWrapper, cfg: DataflowBuildConfig):
    model.set_tensor_datatype(model.graph.input[0].name, DataType["UINT8"])
    model = model.transform(InferDataLayouts())
    model = model.transform(DoubleToSingleFloat())
    model = model.transform(InferDataTypes())
    model = model.transform(SortGraph())

    to_hls_transformations = [
        to_hls.InferAddStreamsLayer,
        LowerConvsToMatMul,
        to_hls.InferChannelwiseLinearLayer,
        absorb.AbsorbTransposeIntoMultiThreshold,
        RoundAndClipThresholds,
        to_hls.InferQuantizedMatrixVectorActivation,
        #MoveTransposePastFork,
        MakeMaxPoolNHWC,
        absorb.AbsorbConsecutiveTransposes,
        to_hls.InferStreamingMaxPool,
        to_hls.InferConvInpGen,
        to_hls.InferDuplicateStreamsLayer,
        to_hls.InferLabelSelectLayer,
    ]
    for iter_id in range(4):
        for trn in to_hls_transformations:
            if trn.__name__=="InferConvInpGen":
                model = model.transform(trn(cfg.force_rtl_conv_inp_gen))
            else:
                model = model.transform(trn())
    
            model = model.transform(InferDataLayouts())
            model = model.transform(GiveUniqueNodeNames())
            model = model.transform(InferDataTypes())
        
    model = model.transform(RemoveCNVtoFCFlatten())
    model = model.transform(GiveReadableTensorNames())
    model = model.transform(RemoveUnusedTensors())
    model = model.transform(SortGraph())

    return model

def step_create_dataflow_partition(model: ModelWrapper, cfg: DataflowBuildConfig):
    """Separate consecutive groups of HLSCustomOp nodes into StreamingDataflowPartition
    nodes, which point to a separate ONNX file. Dataflow accelerator synthesis
    can only be performed on those HLSCustomOp sub-graphs."""

    parent_model = model.transform(
        CreateDataflowPartition(
            partition_model_dir=cfg.output_dir + "/intermediate_models/supported_op_partitions"
        )
    )
    sdp_nodes = parent_model.get_nodes_by_op_type("StreamingDataflowPartition")
    assert len(sdp_nodes) == 1, "Only a single StreamingDataflowPartition supported."
    sdp_node = sdp_nodes[0]
    sdp_node = getCustomOp(sdp_node)
    dataflow_model_filename = sdp_node.get_nodeattr("model")
    if cfg.save_intermediate_models:
        parent_model.save(cfg.output_dir + "/intermediate_models/dataflow_parent.onnx")
    model = ModelWrapper(dataflow_model_filename)
    return model

def step_hls_codegen(model: ModelWrapper, cfg: DataflowBuildConfig):
    "Generate Vivado HLS code to prepare HLSCustomOp nodes for IP generation."

    model = model.transform(PrepareIP(cfg._resolve_fpga_part(), cfg._resolve_hls_clk_period()))
    return model



In [4]:
import finn.builder.build_dataflow as build
import finn.builder.build_dataflow_config as build_cfg
from qonnx.util.cleanup import cleanup
import os
import shutil

model_dir = os.environ['FINN_ROOT'] + "/notebooks"

estimates_output_dir = os.environ['FINN_ROOT'] + "/notebooks/lpyoloW4A4"

#Delete previous run results if exist
if os.path.exists(estimates_output_dir):
    shutil.rmtree(estimates_output_dir)
    print("Previous run results deleted!")


cfg_estimates = build.DataflowBuildConfig(
    output_dir          = estimates_output_dir,
    mvau_wwidth_max     = 80,
    target_fps          = 100,
    synth_clk_period_ns = 10.0,
    fpga_part           = "xcu280-fsvh2892-2L-e",
    verbose             =  True,
    steps               = build_cfg.yolo_build_steps,
    save_intermediate_models = True,
    force_rtl_conv_inp_gen = True,
    generate_outputs=[
        build_cfg.DataflowOutputType.ESTIMATE_REPORTS,
    ]
)

model_name = "lpyoloW4A4"

model_path = "/home/omar/finn/notebooks/{}_quant.onnx".format(model_name)

model_path_clean = "/home/omar/finn/notebooks/{}_quant_clean.onnx".format(model_name)
cleanup(model_path, out_file=model_path_clean)

model_path_transformed = "/home/omar/finn/notebooks/{}_quant_transformed.onnx".format(model_name)

# model = ModelWrapper(model_path_add_trans)
# model = step_yolo_convert_to_hls(model,cfg_estimates)
# model.save(model_path_transformed.split(".onnx")[0]+"hls.onnx")
# model = step_create_dataflow_partition(model,cfg_estimates)
# model.save(model_path_transformed.split(".onnx")[0]+"partition.onnx")

model = ModelWrapper(model_path_clean)
model.save(model_path_transformed.split(".onnx")[0]+"before.onnx")
model = step_yolo_tidy(model,cfg_estimates)
model.save(model_path_transformed.split(".onnx")[0]+"tidy.onnx")
model = step_yolo_streamline(model,cfg_estimates)
model.save(model_path_transformed.split(".onnx")[0]+"streamline.onnx")
model = step_yolo_convert_to_hls(model,cfg_estimates)
model.save(model_path_transformed.split(".onnx")[0]+"hls.onnx")
model = step_create_dataflow_partition(model,cfg_estimates)
model.save(model_path_transformed.split(".onnx")[0]+"partition.onnx")
model = step_hls_codegen(model,cfg_estimates)
model.save(model_path_transformed.split(".onnx")[0]+"hls_codegen.onnx")
model.save(model_path_transformed)

Previous run results deleted!


AssertionError: Thresholds in MatrixVectorActivation_7 can't be expressed with type INT32

In [20]:
showInNetron(model_path_transformed.split(".onnx")[0]+"tidy.onnx")

Stopping http://0.0.0.0:8078
Serving '/home/omar/finn/notebooks/lpyoloW8A8_quant_transformedtidy.onnx' at http://0.0.0.0:8078


In [5]:
showInNetron(model_path_transformed.split(".onnx")[0]+"streamline.onnx")

Serving '/home/omar/finn/notebooks/lpyoloW4A4_quant_transformedstreamline.onnx' at http://0.0.0.0:8078


In [11]:
showInNetron(model_path_transformed.split(".onnx")[0]+"hls.onnx")

Stopping http://0.0.0.0:8078
Serving '/home/omar/finn/notebooks/lpyoloW4A4_quant_transformedhls.onnx' at http://0.0.0.0:8078


In [14]:
showInNetron(model_path_transformed.split(".onnx")[0]+"partition.onnx")

Stopping http://0.0.0.0:8078
Serving '/home/omar/finn/notebooks/lpyoloW4A4_quant_transformedpartition.onnx' at http://0.0.0.0:8078


In [11]:
showInNetron(model_path_transformed.split(".onnx")[0]+"hls_codegen.onnx")

Stopping http://0.0.0.0:8082
Serving '/home/omar/finn/notebooks/yolov1_quant_transformedhls_codegen.onnx' at http://0.0.0.0:8082


In [21]:
showInNetron(estimates_output_dir + "/intermediate_models/supported_op_partitions/partition_0.onnx")

Stopping http://0.0.0.0:8083
Serving '/home/omaribrahim/Omar/thesis/finn/notebooks/build/intermediate_models/supported_op_partitions/partition_0.onnx' at http://0.0.0.0:8083


In [22]:
showInNetron(estimates_output_dir + "/intermediate_models/dataflow_parent.onnx")

Stopping http://0.0.0.0:8083
Serving '/home/omaribrahim/Omar/thesis/finn/notebooks/build/intermediate_models/dataflow_parent.onnx' at http://0.0.0.0:8083


In [145]:
model.graph.node[3]

input: "/model.0/conv/export_handler/Add_output_0"
input: "/model.0/act/export_handler/Constant_output_0"
output: "/model.0/act/export_handler/MultiThreshold_output_0"
name: "/model.0/act/export_handler/MultiThreshold"
op_type: "MultiThreshold"
attribute {
  name: "out_dtype"
  s: "UINT8"
  type: STRING
}
domain: "qonnx.custom_op.general"

In [189]:
model.graph.node[0].input[0]

'inp.1'

In [163]:
print(model.get_tensor_shape(model.graph.output[0].name))

[0, 0, 0]


In [None]:
for idx in range(len(model.graph.node)):
    print(model.get_tensor_datatype(model.graph.node[idx].input[0]))

In [None]:
for n in model.graph.node:
    for i in n.input:
        i_shape = model.get_tensor_shape(i)
        print("input: ",i,i_shape)
    for o in n.output:
        o_shape = model.get_tensor_shape(o)
        print("output: ",o,o_shape)


In [16]:
import cv2
import numpy as np
import onnx.numpy_helper as nph

img = cv2.imread(os.environ['FINN_ROOT'] + "/notebooks/zidane.jpg")
img = cv2.resize(img, (384,640))
img = np.float32(img)
img = np.reshape(img, (-1,3,384,640))
img_tensor = nph.from_array(img)

In [17]:
import finn.core.onnx_exec as oxe
model = ModelWrapper(model_path_clean)
input_dict = {"global_in": nph.to_array(img_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
produced_finn = output_dict[list(output_dict.keys())[0]]

produced_finn[0]

array([[-5.4415359e+00,  2.2633389e+01,  5.2430069e+01, ...,
         3.8146973e-04,  3.5927892e-03,  8.9058280e-03],
       [ 3.1717010e+01,  3.3369095e+01,  1.2519133e+02, ...,
         6.1967969e-04,  4.4372380e-03,  6.3653886e-03],
       [ 5.6003559e+01,  3.4059624e+01,  1.6523679e+02, ...,
         4.9712956e-03,  3.5220087e-03,  1.3423741e-02],
       ...,
       [ 5.5408630e+02,  3.5350375e+02,  2.0497575e+02, ...,
         3.4766406e-02,  7.7051461e-02,  1.9621313e-02],
       [ 5.7463354e+02,  3.6684723e+02,  1.7511247e+02, ...,
         4.7931284e-02,  5.0269932e-02,  8.8611543e-03],
       [ 6.0777545e+02,  3.5122327e+02,  2.3339964e+02, ...,
         4.3502390e-02,  1.3307890e-01,  1.1202604e-02]], dtype=float32)

In [18]:
import time
import torch
import torchvision

def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm))] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5))
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = torch.max(x[:, 5:mi],1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output

In [19]:
non_max_suppression(torch.from_numpy(produced_finn))

[tensor([], size=(0, 6))]