In [1]:
import brevitas.onnx as bo
import numpy as np
import torch

import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import RemoveStaticGraphInputs
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import make_build_dir, gen_finn_dt_tensor
from finn.core.datatype import DataType
import brevitas_examples.speech_to_text as stt

# Directory that receives the exported FINN-ONNX model; read by
# test_brevitas_quartznet_onnx_export_and_exec when building the .onnx path.
# NOTE(review): presumably make_build_dir creates a new (temporary) directory
# whose name starts with this prefix and returns its path as a str — confirm.
export_onnx_path = make_build_dir("test_brevitas_quartznet_")


def test_brevitas_quartznet_onnx_export_and_exec():
    """Export quantized QuartzNet to FINN-ONNX and run it with both FINN and Brevitas.

    Steps:
      1. Build the 4-bit per-channel-scaled QuartzNet from brevitas_examples
         (``export_mode=True``) and export it to
         ``<export_onnx_path>/quartznet-4b.onnx``.
      2. Tidy the exported graph (shape inference, constant folding, removal of
         static graph inputs) and check it has exactly one input and one output.
      3. Execute the ONNX graph with FINN's ``execute_onnx`` on a seeded
         (seed 42) random FLOAT32 input of shape (1, 64, 256).
      4. Run the same input through the PyTorch/Brevitas model as the reference.

    Unlike a regular pytest test, this returns both outputs instead of
    asserting on them, so they can be inspected interactively; the numerical
    comparison is left to the caller.

    Returns:
        tuple: ``(expected, produced)`` — the Brevitas reference output (numpy
        array) and the output produced by FINN execution.

    Raises:
        AssertionError: if the cleaned-up graph does not have exactly one input
        and one output.
    """
    nname = "quartznet-4b"
    finn_onnx = export_onnx_path + "/%s.onnx" % nname
    quartznet_torch = stt.quant_quartznet_perchannelscaling_4b(export_mode=True)
    ishape = (1, 64, 256)
    idt = DataType.FLOAT32
    bo.export_finn_onnx(quartznet_torch, ishape, finn_onnx)
    model = ModelWrapper(finn_onnx)
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    assert len(model.graph.input) == 1
    assert len(model.graph.output) == 1
    iname = model.graph.input[0].name
    oname = model.graph.output[0].name
    # generate a random test vector; seeded so repeated runs use the same input
    np.random.seed(42)
    rand_inp = gen_finn_dt_tensor(idt, ishape)
    # run using FINN-based execution
    input_dict = {iname: rand_inp}
    output_dict = oxe.execute_onnx(model, input_dict)
    produced = output_dict[oname]
    # Run the same input through PyTorch/Brevitas as the reference.
    # Switch to inference mode explicitly: whether the ONNX export leaves the
    # module in eval() mode may depend on the Brevitas version, and any
    # train/eval-dependent layers (e.g. BatchNorm or dropout, if present) would
    # otherwise not match the exported inference graph.
    quartznet_torch.eval()
    rand_inp_torch = torch.from_numpy(rand_inp).float()
    # Only the output values are needed, so skip building the autograd graph.
    with torch.no_grad():
        expected = quartznet_torch.forward(rand_inp_torch).numpy()
    return expected, produced

Could not import torchaudio. Some features might not work.


In [None]:
# Export, execute with FINN, and run the Brevitas reference in one call;
# keep both outputs around for the comparison in the next cell.
(expected_brevitas,
 produced_finn) = test_brevitas_quartznet_onnx_export_and_exec()

=> Loading encoder checkpoint from:'https://github.com/Xilinx/brevitas/releases/download/quant_quartznet_4b-r0/quant_quartznet_encoder_4b-0a46a232.pth'
=> Loading decoder checkpoint from:'https://github.com/Xilinx/brevitas/releases/download/quant_quartznet_4b-r0/quant_quartznet_decoder_4b-bcbf8c7b.pth'
Checkpoint restored


In [None]:
# FINN output must match the Brevitas reference within an absolute tolerance of 1e-3.
# assert_allclose uses the same test as np.isclose(expected, produced, atol=1e-3)
# (|expected - produced| <= atol + rtol * |produced|, NaN never equal), but it also
# rejects shape mismatches instead of silently broadcasting, and on failure it
# reports how many elements differ and the max absolute/relative error.
np.testing.assert_allclose(
    expected_brevitas, produced_finn, rtol=1e-5, atol=1e-3, equal_nan=False
)