<a href="https://colab.research.google.com/github/k2-fsa/colab/blob/master/kaldi-native-fbank/speech_recognition_with_kaldi_native_fbank_and_whisper_and_onnxruntime.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

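This Colab notebook demonstrates how to combine [kaldi-native-fbank][kaldi-native-fbank], [Whisper][Whisper], and [onnxruntime][onnxruntime] for speech recognition: kaldi-native-fbank computes the Whisper log-mel features, and onnxruntime runs the Whisper encoder and decoder exported to ONNX.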

[kaldi-native-fbank]: https://github.com/csukuangfj/kaldi-native-fbank
[Whisper]: https://github.com/openai/whisper/
[onnxruntime]: https://github.com/microsoft/onnxruntime

# Install dependencies

In [1]:
%%shell

pip install kaldi-native-fbank onnxruntime

Collecting kaldi-native-fbank
  Downloading kaldi_native_fbank-1.18.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (210 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.9 MB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
Installing collec



# Download the Whisper ONNX models

Please see
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
for details.

In [2]:
%%shell

sudo apt-get install -y git-lfs

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
cd sherpa-onnx-whisper-tiny.en
git lfs pull --include "*.onnx"


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.2).
0 upgraded, 0 newly installed, 0 to remove and 16 not upgraded.
Cloning into 'sherpa-onnx-whisper-tiny.en'...
remote: Enumerating objects: 46, done.
remote: Counting objects: 100% (46/46), done.
remote: Compressing objects: 100% (45/45), done.
remote: Total 46 (delta 5), reused 0 (delta 0), pack-reused 0
Unpacking objects: 100% (46/46), 1.00 MiB | 7.38 MiB/s, done.




In [3]:
%%shell

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny
cd sherpa-onnx-whisper-tiny
git lfs pull --include "*.onnx"

Cloning into 'sherpa-onnx-whisper-tiny'...
remote: Enumerating objects: 31, done.[K
remote: Counting objects:  70% (22/31)



# Compute features with kaldi-native-fbank

In [4]:
import torch
import torchaudio
import kaldi_native_fbank as knf

def compute_features(filename: str) -> torch.Tensor:
    """
    Args:
      filename:
        Path to an audio file.
    Returns:
      Return a 3-D float32 tensor of shape (1, 80, 3000), i.e.,
      (batch, num_mel_bins, num_frames), containing the log-mel features.
    """
    wave, sample_rate = torchaudio.load(filename)
    audio = wave[0].contiguous()  # only use the first channel
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(
            audio, orig_freq=sample_rate, new_freq=16000
        )

    features = []
    online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
    online_whisper_fbank.accept_waveform(16000, audio.numpy())
    online_whisper_fbank.input_finished()
    for i in range(online_whisper_fbank.num_frames_ready):
        f = online_whisper_fbank.get_frame(i)
        f = torch.from_numpy(f)
        features.append(f)

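    features = torch.stack(features)  # (num_frames, 80)

    # Whisper-style normalization (same constants as openai-whisper):
    # log10 with a floor, limit the dynamic range to 8 (i.e., 80 dB), then rescale.
    log_spec = torch.clamp(features, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    mel = (log_spec + 4.0) / 4.0

    # Whisper always consumes 30 s of audio: 3000 frames at a 10 ms frame shift.
    target = 3000
    mel = mel[:target]  # truncate inputs longer than 30 s
    mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
    return mel.t().unsqueeze(0)  # (1, 80, 3000)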
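The normalization in `compute_features()` mirrors `log_mel_spectrogram()` in openai-whisper. The following pure-Python sketch (no torch required) replays those steps on a tiny hand-made input so the effect of each constant is visible, and checks the 3000-frame arithmetic assuming a 16 kHz sample rate and a 10 ms (160-sample) frame shift. `normalize_log_mel()` and `pad_or_trim()` are hypothetical helper names, and the input values are made up for illustration.

```python
import math
from typing import List

SAMPLE_RATE = 16000  # assumed: Whisper expects 16 kHz audio
HOP_LENGTH = 160     # assumed: 10 ms frame shift at 16 kHz
CHUNK_SECONDS = 30   # Whisper processes fixed 30-second windows

# 30 s * 16000 samples/s / 160 samples per frame = 3000 frames
N_FRAMES = CHUNK_SECONDS * SAMPLE_RATE // HOP_LENGTH


def normalize_log_mel(power: List[List[float]]) -> List[List[float]]:
    """Hypothetical helper: normalize a (num_frames, num_bins) nested list of
    mel power values the same way compute_features() does."""
    # log10 with a floor of 1e-10 to avoid log(0)
    log_spec = [[math.log10(max(v, 1e-10)) for v in row] for row in power]
    # keep at most 8 orders of magnitude below the global maximum
    peak = max(max(row) for row in log_spec)
    log_spec = [[max(v, peak - 8.0) for v in row] for row in log_spec]
    # shift and scale with Whisper's constants
    return [[(v + 4.0) / 4.0 for v in row] for row in log_spec]


def pad_or_trim(mel: List[List[float]], target: int = N_FRAMES) -> List[List[float]]:
    """Hypothetical helper: truncate or zero-pad along the frame axis."""
    num_bins = len(mel[0])
    mel = mel[:target]
    return mel + [[0.0] * num_bins for _ in range(target - len(mel))]


print(N_FRAMES)  # 3000
print(normalize_log_mel([[1.0, 1e-12], [100.0, 0.0]]))  # [[1.0, -0.5], [1.5, -0.5]]
```

The tiny values (1e-12 and 0.0) end up at the same floor, `peak - 8.0`, which is why silence and near-silence become indistinguishable after normalization.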

# Load the Whisper onnxruntime model

In [5]:
import onnxruntime as ort
from typing import Tuple


class OnnxModel:
    def __init__(
        self,
        encoder: str,
        decoder: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 2

        self.session_opts = session_opts

        self.init_encoder(encoder)
        self.init_decoder(decoder)

    def init_encoder(self, encoder: str):
        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=self.session_opts,
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
        self.n_text_layer = int(meta["n_text_layer"])
        self.n_text_ctx = int(meta["n_text_ctx"])
        self.n_text_state = int(meta["n_text_state"])
        self.sot = int(meta["sot"])
        self.eot = int(meta["eot"])
        self.translate = int(meta["translate"])
        self.transcribe = int(meta["transcribe"])
        self.no_timestamps = int(meta["no_timestamps"])
        self.no_speech = int(meta["no_speech"])
        self.blank = int(meta["blank_id"])

        self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))

        self.sot_sequence.append(self.no_timestamps)

        self.all_language_tokens = list(
            map(int, meta["all_language_tokens"].split(","))
        )
        self.all_language_codes = meta["all_language_codes"].split(",")
        self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
        self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))

        self.is_multilingual = int(meta["is_multilingual"]) == 1

    def init_decoder(self, decoder: str):
        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=self.session_opts,
        )

    def run_encoder(
        self,
        mel: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_layer_cross_k, n_layer_cross_v = self.encoder.run(
            [
                self.encoder.get_outputs()[0].name,
                self.encoder.get_outputs()[1].name,
            ],
            {
                self.encoder.get_inputs()[0].name: mel.numpy(),
            },
        )
        return torch.from_numpy(n_layer_cross_k), torch.from_numpy(n_layer_cross_v)

    def run_decoder(
        self,
        tokens: torch.Tensor,
        n_layer_self_k_cache: torch.Tensor,
        n_layer_self_v_cache: torch.Tensor,
        n_layer_cross_k: torch.Tensor,
        n_layer_cross_v: torch.Tensor,
        offset: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
            [
                self.decoder.get_outputs()[0].name,
                self.decoder.get_outputs()[1].name,
                self.decoder.get_outputs()[2].name,
            ],
            {
                self.decoder.get_inputs()[0].name: tokens.numpy(),
                self.decoder.get_inputs()[1].name: n_layer_self_k_cache.numpy(),
                self.decoder.get_inputs()[2].name: n_layer_self_v_cache.numpy(),
                self.decoder.get_inputs()[3].name: n_layer_cross_k.numpy(),
                self.decoder.get_inputs()[4].name: n_layer_cross_v.numpy(),
                self.decoder.get_inputs()[5].name: offset.numpy(),
            },
        )
        return (
            torch.from_numpy(logits),
            torch.from_numpy(out_n_layer_self_k_cache),
            torch.from_numpy(out_n_layer_self_v_cache),
        )

    def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = 1
        n_layer_self_k_cache = torch.zeros(
            self.n_text_layer,
            batch_size,
            self.n_text_ctx,
            self.n_text_state,
        )
        n_layer_self_v_cache = torch.zeros(
            self.n_text_layer,
            batch_size,
            self.n_text_ctx,
            self.n_text_state,
        )
        return n_layer_self_k_cache, n_layer_self_v_cache

    def suppress_tokens(self, logits, is_initial: bool) -> None:
        # logits is modified in-place: disallowed tokens are set to -inf.
        if is_initial:
            # The first generated token must be neither <|endoftext|>
            # nor the blank (space) token.
            logits[self.eot] = float("-inf")
            logits[self.blank] = float("-inf")

        # suppress <|notimestamps|>, <|startoftranscript|>, and <|nospeech|>
        logits[self.no_timestamps] = float("-inf")
        logits[self.sot] = float("-inf")
        logits[self.no_speech] = float("-inf")

        # suppress <|translate|>
        logits[self.translate] = float("-inf")

    def detect_language(
        self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
    ) -> int:
        tokens = torch.tensor([[self.sot]], dtype=torch.int64)
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()

        logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder(
            tokens=tokens,
            n_layer_self_k_cache=n_layer_self_k_cache,
            n_layer_self_v_cache=n_layer_self_v_cache,
            n_layer_cross_k=n_layer_cross_k,
            n_layer_cross_v=n_layer_cross_v,
            offset=offset,
        )
        logits = logits.reshape(-1)
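        # Keep only the language tokens with a boolean mask that is True
        # everywhere else. (An int64 mask would be treated as a list of indices.)
        mask = torch.ones(logits.shape[0], dtype=torch.bool)
        mask[self.all_language_tokens] = False
        logits[mask] = float("-inf")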
        lang_id = logits.argmax().item()
        print("detected language: ", self.id2lang[lang_id])
        return lang_id


def load_tokens(filename):
    tokens = dict()
    with open(filename, "r") as f:
        for line in f:
            t, i = line.split()
            tokens[int(i)] = t
    return tokens
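`load_tokens()` expects each line of tokens.txt to hold a base64-encoded token followed by its integer id, and `decode()` concatenates the decoded bytes of all tokens before calling `.decode()`. Joining the bytes first matters because Whisper uses a byte-level BPE vocabulary, so a single UTF-8 character can be split across tokens. The sketch below assumes that file layout and uses a hypothetical three-token vocabulary that splits "我愛你" (Chinese for "I love you") in the middle of a character; decoding each token on its own would fail on the first piece. `load_token_table()` and `tokens_to_text()` are illustrative names, not part of the notebook.

```python
import base64
import io
from typing import Dict, Iterable, TextIO


def load_token_table(f: TextIO) -> Dict[int, str]:
    """Parse lines of the form '<base64-token> <id>' into {id: base64-token}."""
    table: Dict[int, str] = {}
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines
        t, i = line.split()
        table[int(i)] = t
    return table


def tokens_to_text(ids: Iterable[int], table: Dict[int, str]) -> str:
    # Join the raw bytes of all tokens first, then decode UTF-8 once.
    # Unknown ids are skipped, as in decode().
    data = b"".join(base64.b64decode(table[i]) for i in ids if i in table)
    return data.decode("utf-8", errors="replace").strip()


# Hypothetical toy vocabulary: three byte pieces that split "我愛你" mid-character.
raw = "我愛你".encode("utf-8")
pieces = [raw[:2], raw[2:6], raw[6:]]
tokens_txt = "\n".join(
    f"{base64.b64encode(p).decode('ascii')} {i}" for i, p in enumerate(pieces, start=100)
)
table = load_token_table(io.StringIO(tokens_txt))
print(tokens_to_text([100, 101, 102], table))  # 我愛你
```

Using `errors="replace"` is a design choice for this sketch: if decoding stops at the context limit in the middle of a character, the output ends with a replacement character instead of raising `UnicodeDecodeError`.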

# Decoding

In [6]:
import base64
from typing import Optional

def decode(sound_file: str, encoder: str, decoder: str, tokens_txt: str,
           language: Optional[str] = None,
           task: Optional[str] = None):
    '''
    Args:
        - sound_file (str): Path to the sound file to decode.
        - encoder (str): Path to the encoder model in ONNX format.
        - decoder (str): Path to the decoder model in ONNX format.
        - tokens_txt (str): Path to tokens.txt.
        - language (str): If given, it specifies the spoken language
            of the sound_file. Example values: en, zh, ja, fr, de.
            If omitted, a multilingual model detects the language.
        - task (str): Either None, "transcribe", or "translate".
            None means transcribe.
    '''
    mel = compute_features(sound_file)
    model = OnnxModel(encoder, decoder)

    n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)

    if language is not None:
        if not model.is_multilingual and language != "en":
            print(f"This model supports only English. Given: {language}")
            return

        if language not in model.lang2id:
            print(f"Invalid language: {language}")
            print(f"Valid values are: {list(model.lang2id.keys())}")
            return

        if model.is_multilingual:
            # [sot, lang, task, notimestamps]
            model.sot_sequence[1] = model.lang2id[language]
    elif model.is_multilingual:
        print("detecting language")
        lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
        model.sot_sequence[1] = lang

    if task is not None:
        if not model.is_multilingual and task != "transcribe":
            print('This model supports only English. Please use task="transcribe"')
            return
        assert task in ["transcribe", "translate"], task

        if task == "translate":
            model.sot_sequence[2] = model.translate

    n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()

    tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
    offset = torch.zeros(1, dtype=torch.int64)
    logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    offset += len(model.sot_sequence)
    # logits.shape (batch_size, tokens.shape[1], vocab_size); keep the last position
    logits = logits[0, -1]
    model.suppress_tokens(logits, is_initial=True)
    # For greedy search, we don't need to compute softmax or log_softmax.
    max_token_id = logits.argmax(dim=-1)
    results = []
    # The self-attention cache has n_text_ctx slots and the prompt already
    # occupies len(model.sot_sequence) of them.
    for _ in range(model.n_text_ctx - len(model.sot_sequence)):
        if max_token_id == model.eot:
            break
        results.append(max_token_id.item())
        tokens = torch.tensor([[results[-1]]])

        logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
            tokens=tokens,
            n_layer_self_k_cache=n_layer_self_k_cache,
            n_layer_self_v_cache=n_layer_self_v_cache,
            n_layer_cross_k=n_layer_cross_k,
            n_layer_cross_v=n_layer_cross_v,
            offset=offset,
        )
        offset += 1
        logits = logits[0, -1]
        model.suppress_tokens(logits, is_initial=False)
        max_token_id = logits.argmax(dim=-1)
    token_table = load_tokens(tokens_txt)
    s = b""
    for i in results:
        if i in token_table:
            s += base64.b64decode(token_table[i])

    print(s.decode().strip())
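Stripped of onnxruntime, the loop in `decode()` is greedy search over a KV-cached decoder: feed the whole `sot_sequence` once, then feed one token at a time at position `offset`, and stop at `<|endoftext|>` or when the cache is full. The sketch below assumes a decoder step can be modeled as a function from (new tokens, offset) to one row of logits per new token; `toy_step`, `SCRIPT`, and the tiny vocabulary are hypothetical stand-ins for `model.run_decoder`, and token suppression is omitted for brevity.

```python
from typing import Callable, Dict, List, Sequence

# A decoder step takes the new tokens and the position (offset) at which they
# start, and returns one row of logits per new token.
DecoderStep = Callable[[Sequence[int], int], List[List[float]]]


def greedy_decode(step: DecoderStep, prompt: Sequence[int], eot: int, n_ctx: int) -> List[int]:
    logits = step(prompt, 0)[-1]  # feed the whole prompt once; keep the last row
    offset = len(prompt)          # next free slot in the self-attention cache
    results: List[int] = []
    while offset < n_ctx:
        token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        if token == eot:
            break
        results.append(token)
        logits = step([token], offset)[-1]  # feed only the newest token
        offset += 1
    return results


# Hypothetical toy decoder: vocabulary of 6 tokens, id 5 plays <|endoftext|>.
VOCAB, EOT = 6, 5
SCRIPT: Dict[int, int] = {1: 3, 2: 4, 3: 2}  # input position -> predicted token


def toy_step(tokens: Sequence[int], offset: int) -> List[List[float]]:
    rows = []
    for j, _ in enumerate(tokens):
        target = SCRIPT.get(offset + j, EOT)  # predict EOT once the script ends
        rows.append([1.0 if v == target else 0.0 for v in range(VOCAB)])
    return rows


print(greedy_decode(toy_step, prompt=[0, 1], eot=EOT, n_ctx=448))  # [3, 4, 2]
```

With a smaller `n_ctx` the loop stops as soon as the cache is full, which is the same bound the notebook's loop uses via `model.n_text_ctx - len(model.sot_sequence)`.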

## Test tiny.en (English-only model)

In [7]:
args = {
    "sound_file": "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav",
    "encoder": "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx",
    "decoder": "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx",
    "tokens_txt": "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt",
}
decode(**args)

After early nightfall, the yellow lamps would light up here and there, the squalid quarter of the brothels.


## Test tiny (Multilingual model)

### Transcribe

In [8]:
args = {
    "sound_file": "./sherpa-onnx-whisper-tiny/test_wavs/chinese-i-love-you.wav",
    "encoder": "./sherpa-onnx-whisper-tiny/tiny-encoder.onnx",
    "decoder": "./sherpa-onnx-whisper-tiny/tiny-decoder.onnx",
    "tokens_txt": "./sherpa-onnx-whisper-tiny/tiny-tokens.txt",
}
decode(**args)

detecting language
detected language:  zh
我愛你
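`detect_language()` runs a single decoder step from `<|startoftranscript|>` and then takes the argmax over the language tokens only. A minimal pure-Python sketch of that restriction follows, using a hypothetical five-token vocabulary; `pick_language()` is an illustrative name. In PyTorch the mask must be a `torch.bool` tensor: an int64 tensor of 0s and 1s used as an index selects the elements at positions 0 and 1 instead of masking.

```python
import math
from typing import Dict, List


def pick_language(logits: List[float], language_tokens: List[int], id2lang: Dict[int, str]) -> str:
    """Return the language whose token has the highest logit, ignoring all
    other vocabulary entries."""
    allowed = set(language_tokens)
    # Same effect as logits[mask] = -inf with a boolean mask that is True
    # everywhere except at the language-token positions.
    masked = [v if i in allowed else -math.inf for i, v in enumerate(logits)]
    best = max(range(len(masked)), key=masked.__getitem__)
    return id2lang[best]


# Hypothetical vocabulary: ids 1, 2, 4 are language tokens. Id 0 has the highest
# logit overall but is not a language token, so it is ignored.
logits = [5.0, 0.1, 0.3, 2.0, 0.2]
print(pick_language(logits, [1, 2, 4], {1: "en", 2: "zh", 4: "de"}))  # zh
```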


### Translate to English

In [9]:
args = {
    "sound_file": "./sherpa-onnx-whisper-tiny/test_wavs/chinese-i-love-you.wav",
    "encoder": "./sherpa-onnx-whisper-tiny/tiny-encoder.onnx",
    "decoder": "./sherpa-onnx-whisper-tiny/tiny-decoder.onnx",
    "tokens_txt": "./sherpa-onnx-whisper-tiny/tiny-tokens.txt",
    "task": "translate",
}
decode(**args)

detecting language
detected language:  zh
I love you
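For the multilingual model, the transcription and translation runs above differ only in the task token at index 2 of `sot_sequence`, which `decode()` overwrites with `model.translate`. The sketch below builds the `[sot, language, task, no_timestamps]` prompt from a metadata-style dict whose keys mirror the custom metadata names read in `init_encoder()`; `build_prompt()` and every token id in `META` are made up for illustration and are not the real Whisper ids.

```python
from typing import Dict, List, Optional

# Hypothetical token ids; a real model stores these in its ONNX metadata.
META: Dict[str, int] = {
    "sot": 10,
    "transcribe": 20,
    "translate": 21,
    "no_timestamps": 30,
}


def build_prompt(meta: Dict[str, int], lang_token: int, task: Optional[str] = None) -> List[int]:
    """Build [sot, language, task, no_timestamps] for a multilingual model."""
    task = task or "transcribe"  # None means transcribe, as in decode()
    if task not in ("transcribe", "translate"):
        raise ValueError(f"unknown task: {task}")
    return [meta["sot"], lang_token, meta[task], meta["no_timestamps"]]


ZH = 11  # hypothetical id of the <|zh|> token
print(build_prompt(META, ZH))               # [10, 11, 20, 30]
print(build_prompt(META, ZH, "translate"))  # [10, 11, 21, 30]
```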
