In [None]:
!pip install "datasets[audio]>=3" jiwer pandas jinja2 tqdm onnx_asr[gpu,hub]

In [None]:
from itertools import islice
from timeit import default_timer as timer

import jiwer
import pandas as pd
from datasets import load_dataset
from tqdm.notebook import tqdm

from onnx_asr import load_model

# Upper bound on the number of test rows to pull from the stream; None takes the whole split.
n = None

# Materialise the streamed split into a list once, so `ds` can be sliced and iterated repeatedly.
ds = [*islice(load_dataset("istupakov/russian_librispeech", streaming=True, split="test"), n)]

In [None]:
def _normalizing_pipeline(tokenizer):
    """Return a jiwer pipeline that normalises text and then applies ``tokenizer``.

    Normalisation steps, in order: lowercase, replace "ё" with "е", remove
    punctuation, collapse repeated spaces, strip surrounding whitespace.
    ``tokenizer`` is the final jiwer reduction step (characters or words).
    """
    normalisation_steps = [
        jiwer.ToLowerCase(),
        jiwer.SubstituteRegexes({"ё": "е"}),
        jiwer.RemovePunctuation(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
    ]
    return jiwer.Compose([*normalisation_steps, tokenizer])


# Character-level pipeline (used for CER).
cer_transform = _normalizing_pipeline(jiwer.ReduceToListOfListOfChars())

# Word-level pipeline (used for WER).
wer_transform = _normalizing_pipeline(jiwer.ReduceToListOfListOfWords())


def calc_cer(transcript, result):
    """Count character-level edit errors between reference and recognised text.

    Both sides are normalised with ``cer_transform`` before alignment.

    Args:
        transcript: Reference text, or a list of reference texts.
        result: Recognised text(s), in the same form as ``transcript``.

    Returns:
        dict with two keys:
            ``char_errors`` - substitutions + deletions + insertions;
            ``char_count`` - number of reference characters after
            normalisation, summed over all references.
    """
    metrics_cer = jiwer.process_characters(transcript, result, cer_transform, cer_transform)
    return {
        "char_errors": metrics_cer.substitutions + metrics_cer.deletions + metrics_cer.insertions,
        # Count every reference, not just the first one, so the denominator stays
        # consistent with the error totals when a list of sentences is passed.
        "char_count": sum(len(reference) for reference in metrics_cer.references),
    }


def calc_wer(transcript, result):
    """Count word-level edit errors between reference and recognised text.

    Both sides are normalised with ``wer_transform`` before alignment.

    Args:
        transcript: Reference text, or a list of reference texts.
        result: Recognised text(s), in the same form as ``transcript``.

    Returns:
        dict with two keys:
            ``word_errors`` - substitutions + deletions + insertions;
            ``word_count`` - number of reference words after normalisation,
            summed over all references.
    """
    metrics_wer = jiwer.process_words(transcript, result, wer_transform, wer_transform)
    return {
        "word_errors": metrics_wer.substitutions + metrics_wer.deletions + metrics_wer.insertions,
        # Count every reference, not just the first one, so the denominator stays
        # consistent with the error totals when a list of sentences is passed.
        "word_count": sum(len(reference) for reference in metrics_wer.references),
    }


def agg_metrics(group):
    """Summarise one group of per-utterance rows into corpus-level metrics.

    Args:
        group: DataFrame with the columns ``char_errors``, ``char_count``,
            ``word_errors``, ``word_count``, ``duration`` and ``asr_time``.

    Returns:
        pd.Series indexed by ``N`` (row count), ``CER`` and ``WER`` (summed
        errors divided by summed counts) and ``RTFx`` (summed duration divided
        by summed ASR time).
    """

    def ratio_of_sums(numerator, denominator):
        # Divide the totals rather than averaging per-row ratios.
        return group[numerator].sum() / group[denominator].sum()

    summary = {
        "N": len(group),
        "CER": ratio_of_sums("char_errors", "char_count"),
        "WER": ratio_of_sums("word_errors", "word_count"),
        "RTFx": ratio_of_sums("duration", "asr_time"),
    }
    return pd.Series(summary)

In [None]:
def run_test(model_name, providers=None, quantization=None, warmup=100):
    """Benchmark one ASR model over ``ds`` and yield one record per utterance.

    The model is first run on the leading ``warmup`` rows of ``ds`` with the
    results discarded, then every row of ``ds`` is recognised and timed.

    Args:
        model_name: Model identifier passed to ``load_model``.
        providers: Execution providers passed to ``load_model``.
        quantization: Quantization variant passed to ``load_model``.
        warmup: Number of leading rows used for the untimed warm-up pass
            (0 disables the warm-up).

    Yields:
        dict with the keys ``model``, ``providers``, ``quantization``,
        ``result`` (recognised text), ``text`` (reference), ``duration``
        (audio length in seconds) and ``asr_time`` (seconds spent inside
        ``model.recognize``).
    """
    model = load_model(model_name, quantization=quantization, providers=providers)

    # Untimed warm-up pass so one-off start-up costs do not land in the measurements.
    for row in tqdm(ds[:warmup], desc=f"{model_name} warm-up"):
        model.recognize(row["audio"]["array"], language="ru", sample_rate=row["audio"]["sampling_rate"])

    for row in tqdm(ds, desc=model_name):
        # Fetch the audio BEFORE starting the clock: reading the sample may involve
        # decoding work (depends on the datasets version - TODO confirm), and that
        # must not be attributed to the recogniser.
        audio = row["audio"]
        waveform = audio["array"]
        sample_rate = audio["sampling_rate"]

        start = timer()
        result = model.recognize(waveform, language="ru", sample_rate=sample_rate)
        # Stop the clock immediately so only the recognize() call is measured.
        asr_time = timer() - start

        yield {
            "model": model_name,
            "providers": str(providers),
            "quantization": str(quantization),
            "result": result,
            "text": row["text"],
            "duration": waveform.shape[0] / sample_rate,
            "asr_time": asr_time,
        }

In [None]:
providers = ["CUDAExecutionProvider"]

# (model name, quantization) pairs, benchmarked in this order.
# None is the same value run_test uses when quantization is omitted.
benchmark_models = [
    ("gigaam-v2-ctc", None),
    ("gigaam-v2-rnnt", None),
    ("nemo-fastconformer-ru-ctc", None),
    ("nemo-fastconformer-ru-rnnt", None),
    ("alphacep/vosk-model-ru", None),
    ("alphacep/vosk-model-small-ru", None),
    ("whisper-base", None),
    ("onnx-community/whisper-large-v3-turbo", "fp16"),
]

# One row per (model, utterance); the models run sequentially in the listed order.
df = pd.DataFrame(
    [
        record
        for model_name, quantization in benchmark_models
        for record in run_test(model_name, providers, quantization)
    ]
)

# Replace missing recognition results with an empty hypothesis before scoring.
df["result"] = df["result"].fillna("")

# Append the per-row error/count columns: character-level first, then word-level.
for progress_label, metric_fn in (("calc_cer", calc_cer), ("calc_wer", calc_wer)):
    tqdm.pandas(desc=progress_label)
    df = df.join(df.progress_apply(lambda row: pd.Series(metric_fn(row.text, row.result)), axis=1))

# Aggregate per configuration, keeping the order in which the models were run.
summary = df.groupby(["model", "providers", "quantization"], sort=False).apply(
    agg_metrics, include_groups=False
)

with pd.option_context("display.max_rows", None):
    display(
        summary.style.format(
            {
                "N": "{:0.0f}".format,
                "CER": "{:,.2%}".format,
                "WER": "{:,.2%}".format,
                "RTFx": "{:0.1f}".format,
            }
        )
    )