In [None]:
import os
import time
import torch
import torchaudio
import numpy as np
import soundfile as sf
import onnxruntime as ort
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

In [None]:
def convert_to_onnx(model_id_or_path, onnx_model_name):
    """Export a Wav2Vec2 CTC model to an ONNX file.

    The model is loaded with ``Wav2Vec2ForCTC.from_pretrained`` and exported by
    running it on a random waveform of shape ``(1, 16000)``. Axis 1 of both the
    input and the output is exported as a dynamic (variable-length) axis; the
    batch axis stays fixed at the size of the dummy input.

    Args:
        model_id_or_path: Identifier or local path forwarded to
            ``Wav2Vec2ForCTC.from_pretrained``.
        onnx_model_name: Destination forwarded to ``torch.onnx.export``
            (a file name or file-like object).
    """
    print(f"Converting {model_id_or_path} to onnx")
    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
    # Make inference mode explicit so the export does not depend on how the
    # model happened to be loaded.
    model.eval()
    audio_len = 16000

    # A plain tensor is enough here: the export only needs an example input,
    # not gradients with respect to it.
    dummy_input = torch.randn(1, audio_len)

    torch.onnx.export(
        model,                          # model being run
        dummy_input,                    # model input (or a tuple for multiple inputs)
        onnx_model_name,                # where to save the model (can be a file or file-like object)
        export_params=True,             # store the trained parameter weights inside the model file
        opset_version=20,               # the ONNX version to export the model to
        do_constant_folding=True,       # whether to execute constant folding for optimization
        input_names=['input'],          # the model's input names
        output_names=['output'],        # the model's output names
        dynamic_axes={
            # Variable-length axes. The output's time axis gets its own
            # symbolic name: reusing 'audio_len' would declare it equal to the
            # raw waveform length, which is not expected to hold for Wav2Vec2
            # (its feature encoder downsamples the signal).
            'input': {1: 'audio_len'},
            'output': {1: 'num_frames'},
        },
    )


In [None]:
def predict(file):
    """Transcribe a single audio file with the ONNX model.

    Relies on the notebook-level ``processor`` and ``ort_session`` objects
    being defined before this function is called.

    Args:
        file: Path or file object handed to ``soundfile.read``.

    Returns:
        The result of ``processor.decode`` applied to the arg-max token ids
        of the model output.
    """
    speech_array, sr = sf.read(file)
    # Forward the file's real sampling rate instead of a hard-coded 16000 so
    # the processor can validate it rather than silently accepting audio
    # recorded at a different rate.
    features = processor(speech_array, sampling_rate=sr, return_tensors="pt")
    input_values = features.input_values
    input_name = ort_session.get_inputs()[0].name
    onnx_outputs = ort_session.run(None, {input_name: input_values.numpy()})[0]
    # Greedy decoding: pick the highest-scoring token along the last axis.
    prediction = np.argmax(onnx_outputs, axis=-1)
    return processor.decode(prediction.squeeze().tolist())

In [None]:
def inference(file, num_trials=100, num_warmup=1):
    """Benchmark ONNX inference latency on a single audio file.

    The file is read and pre-processed once; only the ``ort_session.run``
    call is timed. Median, 95th and 99th percentile latency and the overall
    throughput are printed.

    Relies on the notebook-level ``processor`` and ``ort_session`` objects
    being defined before this function is called.

    Args:
        file: Path or file object handed to ``soundfile.read``.
        num_trials: Number of timed runs (default 100).
        num_warmup: Number of untimed runs executed first so that one-off
            start-up cost of the first call does not skew the tail
            percentiles (default 1).

    Raises:
        ValueError: If ``num_trials`` is smaller than 1.
    """
    if num_trials < 1:
        raise ValueError("num_trials must be at least 1")

    speech_array, sr = sf.read(file)
    # Forward the file's real sampling rate instead of a hard-coded 16000 so
    # the processor can validate it.
    features = processor(speech_array, sampling_rate=sr, return_tensors="pt")

    # Loop-invariant work is done once, outside the timed region, so the
    # measurements cover the session run only.
    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: features.input_values.numpy()}

    for _ in range(num_warmup):
        ort_session.run(None, ort_inputs)

    latencies = []
    for _ in range(num_trials):
        # perf_counter is the monotonic, high-resolution clock intended for
        # measuring short durations; time.time() can jump with clock changes.
        start_time = time.perf_counter()
        ort_session.run(None, ort_inputs)
        latencies.append(time.perf_counter() - start_time)

    print(f"Inference Latency (single sample, median): {np.percentile(latencies, 50) * 1000:.2f} ms")
    print(f"Inference Latency (single sample, 95th percentile): {np.percentile(latencies, 95) * 1000:.2f} ms")
    print(f"Inference Latency (single sample, 99th percentile): {np.percentile(latencies, 99) * 1000:.2f} ms")
    print(f"Inference Throughput (single sample): {num_trials/np.sum(latencies):.2f} FPS")


In [None]:
# Root of the Speech Commands data, taken from the environment. Fail fast with
# a clear message: os.getenv returns None when the variable is unset, which
# would otherwise surface later as an opaque TypeError while building paths.
SPEECHCOMMANDS_DATA_DIR = os.getenv("SPEECHCOMMANDS_DATA_DIR")
if not SPEECHCOMMANDS_DATA_DIR:
    raise RuntimeError(
        "Environment variable SPEECHCOMMANDS_DATA_DIR is not set; "
        "point it at the directory containing the Speech Commands data."
    )

# Notebook-level objects used by predict() and inference().
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
ort_session = ort.InferenceSession("test.onnx")

# Last expression of the cell, so the decoded transcription is displayed.
predict(
    os.path.join(
        SPEECHCOMMANDS_DATA_DIR,
        "speech_commands_test_set_v0.02_processed",
        "down",
        "022cd682_nohash_0.wav",
    )
)

In [None]:
# Run the latency benchmark on one "up" clip from the processed test set.
latency_clip_path = (
    SPEECHCOMMANDS_DATA_DIR
    + "/speech_commands_test_set_v0.02_processed/up/"
    + "03401e93_nohash_0.wav"
)
inference(latency_clip_path)
