# Evaluation: Whisper (Local)

Evaluate the local Whisper transcription pipeline on the Czech split of the [FLEURS](https://huggingface.co/datasets/google/fleurs) dataset.

## Setup

In [None]:
# Install the notebook's third-party dependencies; %pip targets the running kernel's environment.
%pip install -q ipywidgets tqdm humanize pandas pyarrow datasets mlx-whisper jiwer

In [20]:
import os
import sys
import time

from html import escape
from pathlib import Path
from datetime import timedelta

# These switches have to be in the environment BEFORE `datasets` is imported:
# `datasets` pulls in `huggingface_hub`, which reads its HF_HUB_* settings into
# module-level constants at import time, so setting them afterwards has no
# effect on a fresh kernel. `setdefault` keeps any value the user exported.
os.environ.setdefault("HF_HUB_DISABLE_HF_TRANSFER", "1")
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")

import pandas as pd
from datasets import Audio as AudioFeature, load_dataset

from tqdm.notebook import tqdm
import humanize
from IPython.display import Audio, HTML, display

from utils.text import normalize
from utils.dataset import path_to_url

# Hugging Face dataset repository and revision used for loading the data and
# for building the audio URLs shown in the worst-sample listing.
REPO_ID = "karmiq/fleurs-cs"
REVISION = "files"

# Per-clip transcription results (including WER) are written to this file.
RESULTS_PATH = Path("results-whisper-local.parquet")

## Load dataset

In [3]:
# Load the test split and keep the audio column undecoded, so each example's
# "audio" field exposes a file path that can be handed to the transcriber.
ds = load_dataset(REPO_ID, revision=REVISION, split="test")
ds = ds.cast_column("audio", AudioFeature(decode=False))

# Sum the duration column directly rather than iterating over whole rows,
# which would materialize every column (including audio) for each example.
# NOTE(review): the recorded output (39.20 hours over 723 rows) works out to
# roughly 195 s per clip -- confirm that `duration_ms` really is in milliseconds.
total_ms = sum(ds["duration_ms"])
total_audio = humanize.precisedelta(
    timedelta(milliseconds=total_ms),
    minimum_unit="hours",
    suppress=["days"],
)

print(f"Total: {total_audio}")
display(ds)

Resolving data files:   0%|          | 0/2812 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/306 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/724 [00:00<?, ?it/s]

Total: 39.20 hours


Dataset({
    features: ['id', 'audio', 'raw_transcription', 'normalized_transcription', 'duration_ms', 'gender'],
    num_rows: 723
})

## Text normalization

In [4]:
# Compare the project's `normalize` against the normalized text shipped with
# the dataset. Missing transcriptions are treated as empty strings.
raw_text = pd.Series(ds["raw_transcription"]).fillna("")
dataset_normalized = pd.Series(ds["normalized_transcription"]).fillna("")

reference_df = pd.DataFrame({
    "sentence_id": ds["id"],
    "raw_transcription": raw_text,
    "reference_normalized_transcription": dataset_normalized,
})
reference_df["computed_normalized_transcription"] = raw_text.map(normalize)

# Rows where our normalization disagrees with the dataset's.
disagrees = reference_df["computed_normalized_transcription"].ne(
    reference_df["reference_normalized_transcription"]
)
mismatches = reference_df.loc[
    disagrees,
    ["sentence_id", "raw_transcription", "reference_normalized_transcription"],
]

n_mismatched = len(mismatches)
n_total = len(reference_df)
mismatch_pct = (n_mismatched / n_total) * 100
print(f"Normalization mismatches: {n_mismatched}/{n_total} ({mismatch_pct:.0f}%)")

Normalization mismatches: 169/723 (23%)


## Whisper transcription

In [5]:
import mlx_whisper

model_id = "mlx-community/whisper-large-v3-mlx"
language = "cs"
sample_limit = None  # set to a positive int to transcribe only the first N clips

# Clamp the limit to the dataset size: `select` raises on out-of-range indices,
# so a limit larger than the split would otherwise abort the cell.
pending = ds.select(range(min(sample_limit, len(ds)))) if sample_limit else ds

# One record per clip; turned into a DataFrame when WER is computed.
results = []
t0 = time.perf_counter()

for ex in tqdm(pending, total=len(pending), desc="Whisper", unit="clip"):
    clip_t0 = time.perf_counter()
    out = mlx_whisper.transcribe(
        str(ex["audio"]["path"]),
        path_or_hf_repo=model_id,
        language=language,
        verbose=None,
        word_timestamps=False,
    )
    # A missing or None "text" is recorded as an empty transcription.
    text = (out.get("text") or "").strip()
    results.append({
        "id": ex["id"],
        "text": text,
        "text_normalized": normalize(text),
        "audio_path": ex["audio"]["path"],
        # Wall-clock time of this transcribe call plus normalization.
        # NOTE(review): the first clip presumably also pays for fetching and
        # loading the model, so its timing is not representative -- confirm.
        "processing_duration_ms": (time.perf_counter() - clip_t0) * 1_000,
        "model_version": model_id,
    })

runtime_s = time.perf_counter() - t0
audio_ms_total = sum(pending["duration_ms"])
audio_human = humanize.precisedelta(
    timedelta(milliseconds=audio_ms_total),
    minimum_unit="minutes",
    suppress=["days"],
)
# Processing seconds spent per hour of audio; NaN when there is no audio at all.
s_per_hour = (runtime_s / (audio_ms_total / 1000.0)) * 3600 if audio_ms_total else float("nan")

print(
    f"Transcribed {len(results)} clips covering {audio_human} in "
    f"{humanize.precisedelta(timedelta(seconds=runtime_s))} "
    f"({s_per_hour:.0f}s/h)"
)

Whisper:   0%|          | 0/723 [00:00<?, ?clip/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Transcribed 723 clips covering 39 hours and 11.95 minutes in 13 minutes and 36.68 seconds (21s/h)


## Compute word error rate (WER)

In [None]:
import jiwer

# Reference transcripts keyed by id. Repeated ids are collapsed to their first
# occurrence so the join below stays many-to-one; with a non-unique index the
# join would fan out and duplicate result rows, skewing the overall WER and
# the worst-sample listing.
reference_by_id = pd.Series(
    ds["normalized_transcription"],
    index=ds["id"],
    name="text_normalized_reference",
)
reference_by_id = reference_by_id[~reference_by_id.index.duplicated(keep="first")]

# NOTE(review): hypotheses are normalized with `normalize`, while references
# use the dataset's own `normalized_transcription`. The normalization check
# earlier in this notebook reports that the two disagree on part of the data,
# which may inflate WER -- consider normalizing the references with
# `normalize` as well.
results_df = (
    pd.DataFrame(results)
    .join(reference_by_id, on="id")
    .assign(
        # Per-clip WER: reference first, hypothesis second.
        wer=lambda d: [jiwer.wer([r], [h]) for r, h in zip(d["text_normalized_reference"], d["text_normalized"])]
    )
)
assert len(results_df) == len(results), "join on id must not add or drop rows"

results_df.to_parquet(RESULTS_PATH, index=False)

# Corpus-level WER over all clips (not the mean of the per-clip values).
overall_wer = jiwer.wer(
    results_df["text_normalized_reference"].tolist(),
    results_df["text_normalized"].tolist(),
)

print(f"Overall WER: {overall_wer:.3f} ({overall_wer*100:.1f}%)")

Overall WER: 0.136 (13.6%)


In [7]:
# The 25 clips with the highest per-clip WER, worst first.
ranked_by_wer = results_df.sort_values(by="wer", ascending=False)
worst_samples = ranked_by_wer.iloc[:25]

In [25]:
display(HTML("<h3>Top 25 worst samples by WER</h3>"))

# Compact spacing for the per-sample cards rendered below.
listing_css = """<style>
section { margin: 0 0 0.25rem 0; }
audio { margin: 0.5rem 0 0 0; }
p { margin: 0 0 0.25rem 0; }
</style>"""
display(HTML(listing_css))

# One card per sample: WER, reference (※), hypothesis (≈) and an audio player
# that links to the clip instead of embedding it.
for sample in worst_samples.itertuples(index=False):
    reference_html = escape(sample.text_normalized_reference)
    hypothesis_html = escape(sample.text_normalized)
    clip_url = path_to_url(sample.audio_path, REPO_ID, REVISION)
    player_html = Audio(url=clip_url, embed=False)._repr_html_()

    card = f"""
<section>
<p><code><small>WER: </small>{sample.wer:.3f}</code></p>
<p>※ {reference_html}</p>
<p>≈ {hypothesis_html}</p>
{player_html}
</section>
            """.strip()
    display(HTML(card))