In [None]:
import os
# Step one directory up from the notebook's folder; presumably this is the project root so
# that the project-local imports (`models`, `constants`) and the relative `./onnx` paths
# used further down resolve — confirm against the repository layout.
# NOTE: not idempotent. Re-running this cell without restarting the kernel moves up again.
os.chdir('../')

In [None]:
from torch.fft import rfft as fft
from scipy.signal import check_COLA, get_window
import torch.nn as nn

# Module-level switch read by the STFT class below: it gates the code paths that build or
# consume complex tensors (`torch.complex`, `torch.is_complex`).
support_clp_op = True

In [None]:
import torch
import torch.nn.functional as F
import sentencepiece as spm
import librosa
import numpy as np
from huggingface_hub import hf_hub_download

from models.encoder import AudioEncoder
from models.decoder import Decoder
from models.jointer import Jointer

from constants import SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MELS
from constants import RNNT_BLANK, PAD, VOCAB_SIZE, TOKENIZER_MODEL_PATH, MAX_SYMBOLS
from constants import ATTENTION_CONTEXT_SIZE
from constants import N_STATE, N_LAYER, N_HEAD

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Load trained model

In [None]:
# Fetch the trained RNN-T checkpoint (FP32 weights) from the Hugging Face Hub.
# `trained_model_path` is the path handed to `torch.load` in the next cell.
trained_model_path = hf_hub_download(
    repo_id="hkab/vietnamese-asr-model", 
    filename="rnnt-latest.ckpt",
    subfolder="rnnt-whisper-small/80_3"
)

In [None]:
# Load the checkpoint on the CPU (tensors only) and split its flat state dict into one
# dict per sub-network, with the sub-network prefix removed from every key.
checkpoint = torch.load(trained_model_path, weights_only=True, map_location="cpu")

encoder_weight, decoder_weight, joint_weight = {}, {}, {}

# (marker searched for in the key, text removed from the key, destination dict).
# Order matters: the first marker found in a key decides where the tensor goes.
_weight_routes = (
    ('encoder', 'encoder.', encoder_weight),
    ('decoder', 'decoder.', decoder_weight),
    ('joint', 'joint.', joint_weight),
)

for param_name, param_value in checkpoint['state_dict'].items():
    # Entries whose name mentions 'alibi' are not loaded from the checkpoint.
    if 'alibi' in param_name:
        continue
    for marker, prefix, bucket in _weight_routes:
        if marker in param_name:
            bucket[param_name.replace(prefix, '')] = param_value
            break

In [None]:
# Instantiate the three RNN-T sub-networks with the hyper-parameters from `constants`.
encoder = AudioEncoder(
    n_mels=N_MELS,
    n_state=N_STATE,
    n_head=N_HEAD,
    n_layer=N_LAYER,
    att_context_size=ATTENTION_CONTEXT_SIZE
)

# Decoder and jointer are sized with one extra vocabulary entry (VOCAB_SIZE + 1).
decoder = Decoder(vocab_size=VOCAB_SIZE + 1)

joint = Jointer(vocab_size=VOCAB_SIZE + 1)

# `strict=False` tolerates keys that were intentionally dropped (the 'alibi' entries are
# skipped when the checkpoint is split), but it would also hide a genuine mismatch such
# as a wrong prefix or a renamed layer. Surface whatever did not line up instead of
# ignoring it silently.
for _net_name, _net, _net_weights in (
    ("encoder", encoder, encoder_weight),
    ("decoder", decoder, decoder_weight),
    ("joint", joint, joint_weight),
):
    _load_result = _net.load_state_dict(_net_weights, strict=False)
    if _load_result.missing_keys or _load_result.unexpected_keys:
        print(f"[{_net_name}] missing keys: {_load_result.missing_keys}")
        print(f"[{_net_name}] unexpected keys: {_load_result.unexpected_keys}")

In [None]:
# Move every sub-network to the target device and switch it to inference mode.
# `nn.Module.to` and `nn.Module.eval` both act in place, so no rebinding is needed.
for _net in (encoder, decoder, joint):
    _net.to(DEVICE)
    _net.eval()

joint

In [None]:
# Exportable STFT from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py
class STFT(torch.nn.Module):
    def __init__(self, win_len=1024, win_hop=512, fft_len=1024,
                 enframe_mode='continue', win_type='hann',
                 win_sqrt=False, pad_center=True):
        """
        Implement of STFT using 1D convolution and 1D transpose convolutions.
        Implement of framing the signal in 2 ways, `break` and `continue`.
        `break` method is a kaldi-like framing.
        `continue` method is a librosa-like framing.

        More information about `perfect reconstruction`:
        1. https://ww2.mathworks.cn/help/signal/ref/stft.html
        2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html

        Args:
            win_len (int): Number of points in one frame.  Defaults to 1024.
            win_hop (int): Number of framing stride. Defaults to 512.
            fft_len (int): Number of DFT points. Defaults to 1024.
            enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'.
            win_type (str, optional): The type of window to create. Defaults to 'hann'.
            win_sqrt (bool, optional): using square root window. Defaults to False.
            pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True.
        """
        super(STFT, self).__init__()
        assert enframe_mode in ['break', 'continue']
        assert fft_len >= win_len
        self.win_len = win_len
        self.win_hop = win_hop
        self.fft_len = fft_len
        self.mode = enframe_mode
        self.win_type = win_type
        self.win_sqrt = win_sqrt
        self.pad_center = pad_center
        self.pad_amount = self.fft_len // 2
        # Length of the last signal passed to `transform`; `inverse` uses it to trim its
        # output. Initialised here so `inverse` can be called before any `transform`.
        self.num_samples = None

        en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
        self.register_buffer('en_k', en_k)
        self.register_buffer('fft_k', fft_k)
        self.register_buffer('ifft_k', ifft_k)
        self.register_buffer('ola_k', ola_k)

    def __init_kernel__(self):
        """
        Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
        ** enframe_kernel: Using conv1d layer and identity matrix.
        ** fft_kernel: Using linear layer for matrix multiplication. In fact,
        enframe_kernel and fft_kernel can be combined, But for the sake of 
        readability, I took the two apart.
        ** ifft_kernel, pinv of fft_kernel.
        ** overlap-add kernel, just like enframe_kernel, but transposed.
        
        Returns:
            tuple: four kernels.
        """
        enframed_kernel = torch.eye(self.fft_len)[:, None, :]
        # `fft` is torch.fft.rfft, which always returns a complex tensor, so the real and
        # imaginary parts are stacked unconditionally. (The former fallback passed `1` as
        # the second positional argument, which torch.fft.rfft interprets as the FFT
        # length `n`, so it could never produce the expected [N, bins, 2] layout.)
        spectrum = fft(torch.eye(self.fft_len))
        fft_kernel = torch.stack([spectrum.real, spectrum.imag], dim=2)
        if self.mode == 'break':
            enframed_kernel = torch.eye(self.win_len)[:, None, :]
            fft_kernel = fft_kernel[:self.win_len]
        # Lay the kernel out as [samples, real bins | imag bins].
        fft_kernel = torch.cat(
            (fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
        ifft_kernel = torch.pinverse(fft_kernel)[:, None, :]
        window = get_window(self.win_type, self.win_len, fftbins=False)

        self.perfect_reconstruct = check_COLA(
            window,
            self.win_len,
            self.win_len-self.win_hop)
        window = torch.as_tensor(window, dtype=torch.float32)
        if self.mode == 'continue':
            # Centre the window inside the (possibly longer) FFT frame.
            left_pad = (self.fft_len - self.win_len)//2
            right_pad = left_pad + (self.fft_len - self.win_len) % 2
            window = F.pad(window, (left_pad, right_pad))
        # `padded_window` is the per-sample weight applied after analysis + synthesis
        # windowing; `inverse` overlap-adds it to obtain the normalisation envelope.
        if self.win_sqrt:
            self.padded_window = window
            window = torch.sqrt(window)
        else:
            self.padded_window = window**2

        fft_kernel = fft_kernel.T * window
        ifft_kernel = ifft_kernel * window
        ola_kernel = torch.eye(self.fft_len)[:self.win_len, None, :]
        if self.mode == 'continue':
            ola_kernel = torch.eye(self.fft_len)[:, None, :self.fft_len]
        return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel

    def is_perfect(self):
        """
        Whether the parameters win_len, win_hop and win_sqrt
        obey constants overlap-add(COLA)

        Returns:
            bool: Return true if parameters obey COLA.
        """
        return self.perfect_reconstruct and self.pad_center

    def transform(self, inputs, return_type='complex'):
        """Take input data (audio) to STFT domain.

        Args:
            inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
            return_type (str, optional): return (mag, phase) when `magphase`,
            return (real, imag) when `realimag` and complex(real, imag) when `complex`.
            Defaults to 'complex'.

        Returns:
            tuple: (mag, phase) when `magphase`, return (real, imag) when
            `realimag`. Defaults to 'complex', each elements with shape 
            [num_batch, num_frequencies, num_frames]
        """
        assert return_type in ['magphase', 'realimag', 'complex']
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        # Remembered so that `inverse` can trim its output to the original length.
        self.num_samples = inputs.size(-1)
        if self.pad_center:
            inputs = F.pad(
                inputs, (self.pad_amount, self.pad_amount), mode='reflect')
        # Framing is a strided convolution with an identity kernel.
        enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
        outputs = torch.transpose(enframe_inputs, 1, 2)
        outputs = F.linear(outputs, self.fft_k)
        outputs = torch.transpose(outputs, 1, 2)
        # First `dim` channels hold the real part, the remaining ones the imaginary part.
        dim = self.fft_len//2+1
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        if return_type == 'realimag':
            return real, imag
        elif return_type == 'complex':
            assert support_clp_op
            return torch.complex(real, imag)
        else:
            mags = torch.sqrt(real**2+imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

    def inverse(self, input1, input2=None, input_type='magphase'):
        """Call the inverse STFT (iSTFT), given tensors produced 
        by the `transform` function.

        Args:
            input1 (tensors): Magnitude/Real-part of STFT with shape 
            [num_batch, num_frequencies, num_frames]
            input2 (tensors): Phase/Imag-part of STFT with shape
            [num_batch, num_frequencies, num_frames]
            input_type (str, optional): Mathematical meaning of input tensor's.
            Defaults to 'magphase'.

        Returns:
            tensors: Reconstructed audio given magnitude and phase. Of
                shape [num_batch, num_samples]. The output is trimmed to the length
                of the last signal given to `transform`; if `transform` has not been
                called yet, only the centre padding (if any) is removed.
        """
        assert input_type in ['magphase', 'realimag']
        if input_type == 'realimag':
            real, imag = None, None
            if support_clp_op and torch.is_complex(input1):
                real, imag = input1.real, input1.imag
            else:
                real, imag = input1, input2
        else:
            real = input1*torch.cos(input2)
            imag = input1*torch.sin(input2)
        inputs = torch.cat([real, imag], dim=1)
        outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
        # Overlap-add the window weights to get the per-sample normalisation envelope.
        t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
        t = t.to(inputs.device)
        coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)
        # Only strip the centre padding when `transform` actually added it; without
        # padding the frames already start at sample 0.
        rm_start = self.pad_amount if self.pad_center else 0
        if self.num_samples is None:
            rm_end = outputs.size(-1) - rm_start
        else:
            rm_end = rm_start + self.num_samples
        outputs = outputs[..., rm_start:rm_end]
        coff = coff[..., rm_start:rm_end]
        # `coff` has batch size 1, so normalise by broadcasting over the batch. Samples
        # whose envelope is (numerically) zero are left untouched rather than divided.
        safe_coff = torch.where(coff > 1e-8, coff, torch.ones_like(coff))
        outputs = outputs / safe_coff
        return outputs.squeeze(dim=1)

    def forward(self, inputs):
        """Take input data (audio) to STFT domain and then back to audio.

        Args:
            inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]

        Returns:
            tensor: Reconstructed audio given magnitude and phase.
            Of shape [num_batch, num_samples]
        """
        # `transform` defaults to a single complex tensor, so the (mag, phase) pair has
        # to be requested explicitly to match `inverse`'s default `input_type`.
        mag, phase = self.transform(inputs, return_type='magphase')
        rec_wav = self.inverse(mag, phase)
        return rec_wav

In [None]:
class WrapperPreprocessor(nn.Module):
    """Log-mel feature extractor built on the convolutional `STFT` module.

    The STFT framing and the mel filterbank are both sized from the shared
    constants (`N_FFT`, `HOP_LENGTH`, `SAMPLE_RATE`, `N_MELS`) so that they stay
    in step with each other and with the streaming audio cache, which keeps
    `N_FFT - HOP_LENGTH` samples between chunks.
    """

    def __init__(self):
        super().__init__()
        self.stft = STFT(
            win_len=N_FFT, win_hop=HOP_LENGTH, fft_len=N_FFT,
            pad_center=False,  # For streaming
        )

        # Registered as a buffer so the filterbank follows the module across `.to(...)`.
        self.register_buffer(
            'filters',
            torch.from_numpy(librosa.filters.mel(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS)),
        )

    def forward(self, audio_signal):
        """Turn raw audio into scaled log-mel features.

        Args:
            audio_signal (tensor): waveform accepted by `STFT.transform`,
                i.e. [num_batch, num_samples].

        Returns:
            tensor: `(log10(max(mel_power, 1e-10)) + 4) / 4`, with the mel axis on
            dim 1 and frames on the last dim.
        """
        mags, _ = self.stft.transform(audio_signal, return_type='magphase')
        power = mags**2

        mel_power = self.filters @ power
        # Clamp before the log so silent frames do not produce -inf.
        log_mel = torch.clamp(mel_power, min=1e-10).log10()
        return (log_mel + 4.0) / 4.0

In [None]:
class WrapperEncoderALiBi(nn.Module):
    """Streaming wrapper around the encoder with every cache passed in and out explicitly.

    One `forward` call consumes one chunk of raw audio plus the caches returned by
    the previous call, and returns the encoder output for that chunk together with
    the updated caches. Keeping the state outside the module is what lets the
    wrapper be exported as a single stateless graph.
    """
    def __init__(self, encoder):
        """Move `encoder` to DEVICE and create the log-mel preprocessor on the same device."""
        super().__init__()

        self.encoder = encoder.to(DEVICE)
        self.preprocessor = WrapperPreprocessor().to(DEVICE)

    def forward(self, 
                audio_chunk, audio_cache, 
                conv1_cache, conv2_cache, conv3_cache,
                k_cache, v_cache, cache_len):
        """Encode one audio chunk.

        Args:
            audio_chunk: new raw samples; concatenated with `audio_cache` along dim 1.
            audio_cache: trailing samples kept from the previous call.
            conv1_cache, conv2_cache, conv3_cache: one cached frame for the input of
                each convolution; concatenated along dim 2.
            k_cache, v_cache: per-block attention caches, indexed by block on dim 0.
            cache_len: integer tensor counting how much of the attention cache is filled.
                (The dummy inputs used for export in this notebook show the concrete
                shapes of all of the above.)

        Returns:
            tuple: (enc_out, audio_cache, conv1_cache, conv2_cache, conv3_cache,
            k_cache, v_cache, cache_len), i.e. the encoder output followed by the caches
            to feed into the next call.
        """
        # Prepend the samples kept from the previous call, then remember the last
        # N_FFT - HOP_LENGTH samples for the next one.
        audio_chunk = torch.cat([audio_cache, audio_chunk], dim=1)
        audio_cache = audio_chunk[:, -(N_FFT - HOP_LENGTH):]

        x_chunk = self.preprocessor(audio_chunk)
        # Each convolution sees one cached frame in front of its input; the last frame
        # of that input becomes the cache for the next call.
        # NOTE(review): presumably this supplies the left context of the conv kernel —
        # confirm against the kernel size / padding of encoder.conv1..conv3.
        x_chunk = torch.cat([conv1_cache, x_chunk], dim=2)

        conv1_cache = x_chunk[:, :, -1].unsqueeze(2)
        x_chunk = F.gelu(self.encoder.conv1(x_chunk))

        x_chunk = torch.cat([conv2_cache, x_chunk], dim=2)
        conv2_cache = x_chunk[:, :, -1].unsqueeze(2)
        x_chunk = F.gelu(self.encoder.conv2(x_chunk))

        x_chunk = torch.cat([conv3_cache, x_chunk], dim=2)
        conv3_cache = x_chunk[:, :, -1].unsqueeze(2)
        x_chunk = F.gelu(self.encoder.conv3(x_chunk))

        # Swap dims 1 and 2 before the transformer blocks.
        x_chunk = x_chunk.permute(0, 2, 1)

        # Mask length is left context + right context + 1; the offset is the left
        # context size minus the number of cached positions filled so far.
        x_len = torch.tensor([ATTENTION_CONTEXT_SIZE[0] + ATTENTION_CONTEXT_SIZE[1] + 1]).to(DEVICE)
        offset = torch.neg(cache_len) + ATTENTION_CONTEXT_SIZE[0]

        attn_mask = self.encoder.form_attention_mask_for_streaming(ATTENTION_CONTEXT_SIZE, x_len, offset.to(DEVICE), DEVICE)
        # Drop the first ATTENTION_CONTEXT_SIZE[0] entries along dim 2.
        # NOTE(review): presumably those are the query rows of the cached positions — confirm.
        attn_mask = attn_mask[:, :, ATTENTION_CONTEXT_SIZE[0]:, :]

        new_k_cache = []
        new_v_cache = []
        # Every block receives its own slice of the caches and returns the updated slice.
        for i, block in enumerate(self.encoder.blocks):
            x_chunk, layer_k_cache, layer_v_cache = block(x_chunk, mask=attn_mask, k_cache=k_cache[i], v_cache=v_cache[i])
            new_k_cache.append(layer_k_cache)
            new_v_cache.append(layer_v_cache)

        enc_out = self.encoder.ln_post(x_chunk)

        k_cache = torch.stack(new_k_cache, dim=0)
        v_cache = torch.stack(new_v_cache, dim=0)
        # Advance the fill counter by ATTENTION_CONTEXT_SIZE[1] + 1, saturating at the
        # left context size.
        cache_len = torch.clamp(cache_len + ATTENTION_CONTEXT_SIZE[1] + 1, max=ATTENTION_CONTEXT_SIZE[0])

        return enc_out, audio_cache, conv1_cache, conv2_cache, conv3_cache, k_cache, v_cache, cache_len

In [None]:
# Wrap the trained encoder for export and put the wrapper in inference mode
# (`eval()` returns the module itself).
export_encoder = WrapperEncoderALiBi(encoder).eval()
export_encoder

In [None]:
# Number of new samples per streaming step. Together with the N_FFT - HOP_LENGTH cached
# samples this makes HOP_LENGTH * 31 + N_FFT samples available to the STFT.
CHUNK_SAMPLES = HOP_LENGTH * 31 + N_FFT - (N_FFT - HOP_LENGTH)

# All-zero example inputs. Sizes come from the shared model constants rather than
# literals so they cannot drift from the encoder built earlier in the notebook.
audio_chunk = torch.zeros(1, CHUNK_SAMPLES, device=DEVICE)
audio_cache = torch.zeros(1, N_FFT - HOP_LENGTH, device=DEVICE)
conv1_cache = torch.zeros(1, N_MELS, 1, device=DEVICE)
conv2_cache = torch.zeros(1, N_STATE, 1, device=DEVICE)
conv3_cache = torch.zeros(1, N_STATE, 1, device=DEVICE)

# One attention cache slice per encoder block.
k_cache = torch.zeros(N_LAYER, 1, ATTENTION_CONTEXT_SIZE[0], N_STATE, device=DEVICE)
v_cache = torch.zeros(N_LAYER, 1, ATTENTION_CONTEXT_SIZE[0], N_STATE, device=DEVICE)
cache_len = torch.zeros(1, dtype=torch.int, device=DEVICE)

# Smoke-test the wrapper before exporting; no autograd graph is needed for that.
with torch.no_grad():
    r = export_encoder(
        audio_chunk, audio_cache,
        conv1_cache, conv2_cache, conv3_cache,
        k_cache, v_cache, cache_len
    )

In [None]:
# The exporter writes straight to the given path and does not create missing
# directories, so make sure the output folder exists first.
os.makedirs("./onnx", exist_ok=True)

# NOTE: every cache appears under the same name in `input_names` and `output_names`.
# The inference code later in this notebook feeds those inputs as "<name>.1"
# (presumably because the exporter de-duplicates the clashing names), so do not rename
# one list without updating the other places.
with torch.inference_mode():
    torch.onnx.export(
        export_encoder,
        (audio_chunk, audio_cache, conv1_cache, conv2_cache, conv3_cache, k_cache, v_cache, cache_len),
        "./onnx/encoder.onnx",
        input_names=["audio_chunk", "audio_cache", "conv1_cache", "conv2_cache", "conv3_cache", "k_cache", "v_cache", "cache_len"],
        output_names=["enc_out", "audio_cache", "conv1_cache", "conv2_cache", "conv3_cache", "k_cache", "v_cache", "cache_len"],
        export_params=True,
        opset_version=17,
        do_constant_folding=False,  # Must be false for alibi
    )

In [None]:
class WrapperDecoder(nn.Module):
    """Thin export shell around the prediction network.

    Places the wrapped decoder on the GPU when CUDA is available (CPU otherwise)
    and forwards a single `(token, h_n)` step to it.
    """

    def __init__(self, decoder):
        super(WrapperDecoder, self).__init__()
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.decoder = decoder.to(self.device)

    def forward(self, token, h_n):
        """Run one decoder step and return `(decoder_output, new_hidden_state)`."""
        step_output, next_h_n = self.decoder(token, h_n)
        return step_output, next_h_n

In [None]:
# This cell rebinds `decoder`, so running it twice would wrap the wrapper again.
# Guard against that to keep the cell safe to re-run.
if not isinstance(decoder, WrapperDecoder):
    decoder = WrapperDecoder(decoder)
decoder.eval()

In [None]:
# Example inputs for one decoder step: a single blank token and an all-zero hidden state.
token = torch.tensor([[RNNT_BLANK]], device=DEVICE, dtype=torch.long)
h_n = torch.zeros((1, 1, 768), device=DEVICE)

# Smoke-test the wrapped decoder before exporting it.
r = decoder(token, h_n)

In [None]:
# The exporter does not create missing directories; make sure the target folder exists.
os.makedirs("./onnx", exist_ok=True)

# NOTE: "h_n" is used both as an input and as an output name; the inference code later
# in this notebook feeds the input as "h_n.1".
with torch.inference_mode():
    torch.onnx.export(
        decoder,
        (token, h_n),
        "./onnx/decoder.onnx",
        input_names=["token", "h_n"],
        output_names=["dec", "h_n"],
        export_params=True,
        opset_version=17,
    )

In [None]:
class WrapperJoint(nn.Module):
    """Thin export shell around the joint network.

    Places the wrapped module on the GPU when CUDA is available (CPU otherwise)
    and returns only the vector at index `[0, 0, 0]` of the joint output.
    """

    def __init__(self, joint):
        super(WrapperJoint, self).__init__()
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.joint = joint.to(self.device)

    def forward(self, enc, dec):
        """Combine one encoder frame with one decoder state; return the last-dim vector."""
        joint_out = self.joint(enc, dec)
        return joint_out[0, 0, 0, :]

In [None]:
# Wrap the joint network for export and put the wrapper in inference mode
# (`eval()` returns the module itself).
jointer = WrapperJoint(joint).eval()
jointer

In [None]:
# Example inputs for the jointer: one all-zero encoder frame and one all-zero decoder state.
_joint_input_shape = (1, 1, 768)
enc = torch.zeros(_joint_input_shape, device=DEVICE)
dec = torch.zeros(_joint_input_shape, device=DEVICE)

# Smoke-test the wrapped jointer before exporting it.
r = jointer(enc, dec)

In [None]:
# The exporter does not create missing directories; make sure the target folder exists.
os.makedirs("./onnx", exist_ok=True)

with torch.inference_mode():
    torch.onnx.export(
        jointer,
        (enc, dec),
        "./onnx/jointer.onnx",
        input_names=["enc", "dec"],
        output_names=["output"],
        export_params=True,
        opset_version=17,
    )

## Quantization

In [None]:
import onnxruntime as ort
import librosa
from tqdm import tqdm
import time
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

In [None]:
# Run onnxruntime's quantization pre-processing on every exported graph before quantizing.
# Equivalent CLI: python -m onnxruntime.quantization.preprocess --input X.onnx --output X-infer.onnx
import subprocess
import sys

for _model_stem in ("encoder", "decoder", "jointer"):
    subprocess.run(
        [
            # Use the kernel's own interpreter: a bare "python" may resolve to a
            # different environment that does not have onnxruntime installed.
            sys.executable, "-m", "onnxruntime.quantization.preprocess",
            "--input", f"./onnx/{_model_stem}.onnx",
            "--output", f"./onnx/{_model_stem}-infer.onnx",
        ],
        # Fail loudly here instead of letting later cells pick up a missing or stale file.
        check=True,
    )

In [None]:
# Dynamic INT8 weight quantization of the three pre-processed graphs.
# Each entry: (input graph, output graph, operator types to quantize).
_quantization_jobs = (
    ('./onnx/encoder-infer.onnx', './onnx/encoder-infer.quant.onnx', ['MatMul']),
    ('./onnx/decoder-infer.onnx', './onnx/decoder-infer.quant.onnx', ['GRU']),
    ('./onnx/jointer-infer.onnx', './onnx/jointer-infer.quant.onnx', ['MatMul']),
)

for _model_in, _model_out, _op_types in _quantization_jobs:
    quantize_dynamic(
        _model_in,
        _model_out,
        weight_type=QuantType.QInt8,
        op_types_to_quantize=_op_types,
    )

In [None]:
# One onnxruntime session per exported graph.
# NOTE(review): these load the pre-processed but un-quantized graphs (*-infer.onnx), not
# the *-infer.quant.onnx files written by quantize_dynamic in this notebook — confirm
# whether that is intentional (e.g. as a baseline) or whether the quantized files were meant.
ort_encoder_session = ort.InferenceSession("./onnx/encoder-infer.onnx")
ort_decoder_session = ort.InferenceSession("./onnx/decoder-infer.onnx")
ort_jointer_session = ort.InferenceSession("./onnx/jointer-infer.onnx")

In [None]:
def onnx_online_inference(audio, ort_encoder_session, ort_decoder_session, ort_jointer_session, tokenizer, max_symbols=3):
    """Greedy streaming RNN-T decoding with the exported ONNX graphs.

    The audio is fed to the encoder in fixed-size chunks; all encoder caches are
    carried from one chunk to the next. For every encoder frame the decoder and
    jointer are run until the blank token wins or `max_symbols` tokens have been
    emitted for that frame.

    Args:
        audio: waveform as a numpy array or torch tensor, shaped [num_samples] or
            [1, num_samples]. It is converted to float32.
        ort_encoder_session: session for the encoder graph (`run(None, feeds)`).
        ort_decoder_session: session for the decoder graph.
        ort_jointer_session: session for the jointer graph.
        tokenizer: object with a `decode(list_of_ids)` method.
        max_symbols (int, optional): maximum number of non-blank tokens emitted per
            encoder frame. Defaults to 3.

    Returns:
        tuple: `(text, speed)` where `text` is `tokenizer.decode` of the emitted ids
        and `speed` is the audio duration divided by the wall-clock decoding time
        (seconds of audio processed per second, i.e. the inverse of the real-time factor).
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()

    # The caches below are float32, so the waveform has to be float32 as well.
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 1:
        audio = np.expand_dims(audio, 0)

    # New samples consumed per encoder call (same length as the dummy chunk used for export).
    chunk_samples = HOP_LENGTH * 31 + N_FFT - (N_FFT - HOP_LENGTH)

    # Encoder state, all zeros at the start of an utterance.
    audio_cache = np.zeros((1, N_FFT - HOP_LENGTH), dtype=np.float32)
    conv1_cache = np.zeros((1, N_MELS, 1), dtype=np.float32)
    conv2_cache = np.zeros((1, N_STATE, 1), dtype=np.float32)
    conv3_cache = np.zeros((1, N_STATE, 1), dtype=np.float32)

    k_cache = np.zeros((N_LAYER, 1, ATTENTION_CONTEXT_SIZE[0], N_STATE), dtype=np.float32)
    v_cache = np.zeros((N_LAYER, 1, ATTENTION_CONTEXT_SIZE[0], N_STATE), dtype=np.float32)
    cache_len = np.zeros((1,), dtype=np.int32)

    # Decoder state: zero hidden state (same shape as the dummy `h_n` used for export)
    # and the blank token as the first input.
    h_n = np.zeros((1, 1, 768), dtype=np.float32)
    token = np.array([[RNNT_BLANK]], dtype=np.int64)

    audio_seconds = audio.shape[1] / SAMPLE_RATE
    seq_ids = []

    # Time the decoding loop itself; perf_counter is monotonic and high resolution.
    start = time.perf_counter()
    for chunk_start in range(0, audio.shape[1], chunk_samples):
        audio_chunk = audio[:, chunk_start:chunk_start + chunk_samples]
        missing = chunk_samples - audio_chunk.shape[1]
        if missing > 0:
            # The last chunk is zero-padded on the right up to the fixed chunk length.
            audio_chunk = np.pad(audio_chunk, ((0, 0), (0, missing)))

        # Inputs whose name also appears among the outputs carry a ".1" suffix in the graph.
        encoder_outputs = ort_encoder_session.run(
            None,
            {
                "audio_chunk": audio_chunk,
                "audio_cache.1": audio_cache,
                "conv1_cache.1": conv1_cache,
                "conv2_cache.1": conv2_cache,
                "conv3_cache.1": conv3_cache,
                "k_cache.1": k_cache,
                "v_cache.1": v_cache,
                "cache_len.1": cache_len
            }
        )

        enc_out, audio_cache, conv1_cache, conv2_cache, conv3_cache, k_cache, v_cache, cache_len = encoder_outputs

        for time_idx in range(enc_out.shape[1]):
            current_enc_out = enc_out[:, time_idx, :].reshape(1, 1, N_STATE)

            symbols_added = 0
            while symbols_added < max_symbols:
                dec, new_h_n = ort_decoder_session.run(
                    None,
                    {
                        "token": token,
                        "h_n.1": h_n
                    }
                )

                logits = ort_jointer_session.run(
                    None,
                    {
                        "enc": current_enc_out,
                        "dec": dec
                    }
                )[0]

                new_token = int(logits.argmax())

                if new_token == RNNT_BLANK:
                    # Blank: move on to the next encoder frame without touching the
                    # decoder state.
                    break

                # Non-blank: emit the token and advance the decoder state.
                symbols_added += 1
                token = np.array([[new_token]], dtype=np.int64)
                h_n = new_h_n
                seq_ids.append(new_token)
    # Guard against a zero reading for empty input on coarse clocks.
    elapsed = max(time.perf_counter() - start, 1e-9)

    return tokenizer.decode(seq_ids), audio_seconds / elapsed

In [None]:
# Decode the recording at the model's sample rate (librosa returns `(samples, rate)`).
waveform, _loaded_sr = librosa.load("/path/to/vnbacnam.m4a", sr=SAMPLE_RATE)
# add some zeros to the start of the audio for warmup
waveform = np.pad(waveform, (16000, 0))
# Add a leading batch axis and make sure the samples are float32.
audio = waveform[np.newaxis, :].astype(np.float32)

tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)

In [None]:
# Transcribe the loaded audio with the ONNX sessions; the cell output is the tuple
# returned by `onnx_online_inference` (decoded text and the timing ratio it computes).
onnx_online_inference(audio, ort_encoder_session, ort_decoder_session, ort_jointer_session, tokenizer)