<a href="https://colab.research.google.com/github/finardi/tutos/blob/master/convert_Wav2vec_to_ONNX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install magic_timer
!pip install -q onnx
!pip install -q onnxruntime
!pip install -q transformers



In [2]:
import torch
import torchaudio
import numpy as np

import onnxruntime as rt
from onnxruntime.quantization.quantize import quantize

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from magic_timer import MagicTimer

In [11]:
# Set True to (re-)export the PyTorch checkpoint to ONNX and quantize it.
CONVERT = False

model_name = "lgris/bp400-xlsr"
# Directory holding the exported/quantized ONNX files (Google Drive mount on Colab).
ONNX_DIR = "/content/drive/MyDrive/Wav2Vec/model_checkpoint/ONNX_weights"
quantized_model_name = f"{ONNX_DIR}/{model_name.split('/')[-1]}.quant.onnx"

In [6]:
# script from: https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py

def convert_to_onnx(model_id_or_path, onnx_model_name, audio_len=250_000, opset_version=11):
    """Export a Hugging Face ``Wav2Vec2ForCTC`` checkpoint to an ONNX file.

    The graph is traced with a random dummy waveform of shape ``(1, audio_len)``.
    Axis 1 of the input and of the output logits is declared dynamic, so the
    exported model accepts clips of other lengths.

    Args:
        model_id_or_path: Hugging Face model id or local checkpoint path.
        onnx_model_name: Destination ``.onnx`` path; its parent directory is
            created if it does not exist.
        audio_len: Number of samples in the dummy tracing input.
        opset_version: ONNX opset version to export to.
    """
    import os

    print(f"Converting {model_id_or_path} to onnx")
    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)

    # The dummy input only drives tracing; gradients are never needed.
    x = torch.randn(1, audio_len)

    out_dir = os.path.dirname(onnx_model_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    torch.onnx.export(
        model,                          # model being run
        x,                              # model input (or a tuple for multiple inputs)
        onnx_model_name,                # where to save the model (can be a file or file-like object)
        export_params=True,             # store the trained parameter weights inside the model file
        opset_version=opset_version,    # the ONNX opset to export the model to
        do_constant_folding=True,       # whether to execute constant folding for optimization
        input_names=['input'],          # the model's input names
        output_names=['output'],        # the model's output names
        # Input samples and output frames have different lengths, so they need
        # distinct symbolic names: ONNX treats identical dim_params as equal sizes.
        dynamic_axes={'input': {1: 'audio_len'},
                      'output': {1: 'num_frames'}},
    )


def quantize_onnx_model(onnx_model_path, quantized_model_path):
    """Dynamically quantize the weights of an ONNX model to unsigned 8-bit ints.

    Args:
        onnx_model_path: Path of the float ONNX model to read.
        quantized_model_path: Path where the quantized model is written.
    """
    print("Starting quantization...")
    from onnxruntime.quantization import quantize_dynamic, QuantType
    quantize_dynamic(onnx_model_path,
                     quantized_model_path,
                     weight_type=QuantType.QUInt8)

    print(f"Quantized model saved to: {quantized_model_path}")


if CONVERT:
    onnx_model_name = "/content/drive/MyDrive/Wav2Vec/model_checkpoint/ONNX_weights/" + model_name.split("/")[-1] + ".onnx"
    # `model_name` is the checkpoint id from the configuration cell
    # (`model_id_or_path` only exists inside convert_to_onnx).
    convert_to_onnx(model_name, onnx_model_name)
    print('model ONNX exported !!!')

    quantize_onnx_model(onnx_model_name, quantized_model_name)
    print('model quantized !!!')

In [8]:
class Wave2Vec2Inference():
    """Transcribe audio with a Wav2Vec2 CTC model through ONNX Runtime or PyTorch.

    Both backends are loaded up front so one instance can compare them: an
    ONNX Runtime session for ``onnx_path`` and the PyTorch ``Wav2Vec2ForCTC``
    checkpoint ``model_name``. The ``Wav2Vec2Processor`` loaded from
    ``model_name`` prepares model inputs and decodes predicted token ids.

    Args:
        model_name: Hugging Face model id or local path for the processor and
            the PyTorch model.
        onnx_path: Path to the (optionally quantized) ONNX export.
        model_sample_rate: Sample rate in Hz that audio is resampled to.
        providers: ONNX Runtime execution providers; defaults to CPU only.
    """

    def __init__(self, model_name, onnx_path, model_sample_rate=16_000, providers=None):
        options = rt.SessionOptions()
        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        if providers is None:
            # onnxruntime >= 1.9 GPU builds raise unless providers are explicit;
            # CPU matches the default of CPU-only builds and of the PyTorch path.
            providers = ["CPUExecutionProvider"]
        self.model_quant = rt.InferenceSession(onnx_path, sess_options=options, providers=providers)

        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model_sample_rate = model_sample_rate

        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

    def buffer_to_text(self, audio_tensor, onnx):
        """Transcribe an in-memory waveform using greedy (argmax) CTC decoding.

        Args:
            audio_tensor: Waveform sampled at ``model_sample_rate`` (torch tensor
                or numpy array; presumably 1-D mono — confirm for other callers).
            onnx: If True run the ONNX Runtime session, otherwise the PyTorch model.

        Returns:
            The decoded transcription, or ``""`` when the waveform has no samples.
        """
        if audio_tensor.shape[-1] == 0:
            return ""

        inputs = self.processor(audio_tensor, sampling_rate=self.model_sample_rate, return_tensors="pt")
        input_values = inputs.input_values

        if onnx:
            input_name = self.model_quant.get_inputs()[0].name
            logits = self.model_quant.run(None, {input_name: input_values.numpy()})[0]
            predicted_ids = np.argmax(logits, axis=-1)[0]
        else:
            # Pure inference: skip autograd bookkeeping (saves time and memory).
            with torch.no_grad():
                logits = self.model(input_values)[0]
            predicted_ids = torch.argmax(logits, dim=-1)[0]

        # Take batch item 0 rather than squeeze(): a single-frame output must
        # still decode as a list of ids, not a bare int.
        return self.processor.decode(predicted_ids.tolist())

    def file_to_text(self, filename, onnx=True):
        """Load an audio file, convert it to mono at ``model_sample_rate``, and transcribe it.

        Args:
            filename: Path of any audio file ``torchaudio.load`` can read.
            onnx: If True (default) use the ONNX Runtime session, otherwise PyTorch.

        Returns:
            The decoded transcription.
        """
        speech, sample_rate = torchaudio.load(filename)
        # torchaudio.load returns (channels, frames); average channels to mono so
        # stereo files do not reach the processor as a 2-row batch.
        speech = speech.mean(dim=0)
        if sample_rate != self.model_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.model_sample_rate)
            speech = resampler(speech)

        return self.buffer_to_text(speech, onnx)

In [39]:
# Build both backends once, then push the same clip through each of them:
# first the quantized ONNX Runtime session, then the regular PyTorch model.
asr = Wave2Vec2Inference(model_name, quantized_model_name)

for use_onnx in (True, False):
    timer = MagicTimer()
    predicted = asr.file_to_text("/content/drive/MyDrive/Wav2Vec/data/audio_andrea_3.wav", onnx=use_onnx)
    print(f'predicted: {predicted}\n\tTempo: {timer}')

predicted: por favor como eu faço para solicitar talão de chaque meu talão de cheque acabou e eu queria solicitar novos talões
	Tempo: 5.9 seconds
predicted: por favor como eu faço pra solicitar talão de cheque meu talão de cheque acabou e eu queria solicitar novos talões
	Tempo: 7.0 seconds


In [40]:
%%time
predicted = asr.file_to_text("/content/drive/MyDrive/Wav2Vec/data/audio_andrea_3.wav")

CPU times: user 11.5 s, sys: 80.7 ms, total: 11.6 s
Wall time: 6.76 s


In [41]:
%%time
predicted = asr.file_to_text("/content/drive/MyDrive/Wav2Vec/data/audio_andrea_3.wav", onnx=False)

CPU times: user 5.86 s, sys: 1.37 s, total: 7.23 s
Wall time: 3.73 s
