# Evaluation: Groq Whisper Turbo

Evaluate the [Whisper Large v3 Turbo](https://groq.com/blog/whisper-large-v3-turbo-now-available-on-groq-combining-speed-quality-for-speech-recognition/) model served by [Groq](https://groq.com) on the Czech split of the [FLEURS](https://huggingface.co/datasets/google/fleurs) dataset.


## Setup

In [None]:
%pip install -q ipywidgets tqdm humanize pandas pyarrow datasets groq requests jiwer python-dotenv

In [None]:
import os
import sys
import time

from html import escape
from pathlib import Path
from datetime import timedelta

import pandas as pd
from datasets import Audio as AudioFeature, load_dataset

from dotenv import load_dotenv
from tqdm.notebook import tqdm
import humanize
import requests
from groq import Groq, RateLimitError
from IPython.display import Audio, HTML, display

from utils.text import normalize
from utils.dataset import path_to_url

REPO_ID = "karmiq/fleurs-cs"
REVISION = "files"

MODEL_NAME = "whisper-large-v3-turbo"
RESULTS_PATH = Path("results-whisper-groq.parquet")

## Load dataset

In [29]:
ds = load_dataset(REPO_ID, revision=REVISION, split="test")
ds = ds.cast_column("audio", AudioFeature(decode=False))

total_ms = sum(ex["duration_ms"] for ex in ds)
total_audio = humanize.precisedelta(timedelta(milliseconds=total_ms), minimum_unit='hours', suppress=['days'])

print(f"Total: {total_audio}")
display(ds)

Resolving data files:   0%|          | 0/2812 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/306 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/724 [00:00<?, ?it/s]

Total: 39.20 hours


Dataset({
    features: ['id', 'audio', 'raw_transcription', 'normalized_transcription', 'duration_ms', 'gender'],
    num_rows: 723
})

## Resolve audio URLs


In [None]:
session = requests.Session()


def resolve_hf_url(audio_path: str) -> str:
    url = path_to_url(audio_path, REPO_ID, REVISION)
    response = session.head(url, allow_redirects=True, timeout=10)
    response.raise_for_status()
    return response.url


def _map_audio_url(example):
    example = dict(example)
    audio_url = resolve_hf_url(example["audio"]["path"])
    example["audio_url"] = audio_url
    return example


ds = ds.map(_map_audio_url, load_from_cache_file=False)

## Text normalization

In [35]:
reference_df = (
    pd.DataFrame({
        "sentence_id": ds["id"],
        "raw_transcription": pd.Series(ds["raw_transcription"]).fillna(""),
        "reference_normalized_transcription": pd.Series(ds["normalized_transcription"]).fillna(""),
    })
    .assign(
        computed_normalized_transcription=lambda df: df["raw_transcription"].map(normalize)
    )
)

mismatches = reference_df.loc[
    reference_df["computed_normalized_transcription"] != reference_df["reference_normalized_transcription"],
    ["sentence_id", "raw_transcription", "reference_normalized_transcription"],
]

print(
    "Normalization mismatches:",
    f"{len(mismatches)}/{len(reference_df)}",
    f"({(len(mismatches) / len(reference_df)) * 100:.0f}%)")

Normalization mismatches: 169/723 (23%)


## Groq transcription


In [None]:
load_dotenv(".env")

groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    raise RuntimeError("GROQ_API_KEY environment variable is required")

client = Groq(api_key=groq_api_key)

sample_limit = None

MAX_RATE_LIMIT_RETRIES = 5
BASE_BACKOFF_SECONDS = 5
BACKOFF_CAP_SECONDS = 300

pending = ds.select(range(sample_limit)) if sample_limit else ds


def _rate_limit_wait_seconds(error, attempt):
    retry_after = error.response.headers.get("Retry-After")
    if retry_after:
        return max(float(retry_after), 1.0)

    reset_header = error.response.headers.get("X-RateLimit-Reset")
    if reset_header:
        return max(float(reset_header) - time.time(), 1.0)

    backoff_seconds = BASE_BACKOFF_SECONDS * (2 ** (attempt - 1))
    return min(max(backoff_seconds, 1.0), BACKOFF_CAP_SECONDS)


def _transcribe(example):
    clip_t0 = time.perf_counter()
    attempt = 0

    while True:
        attempt += 1
        try:
            tx = client.audio.transcriptions.create(
                model=MODEL_NAME,
                url=example["audio_url"],
                language="cs",
                temperature=0.0,
            )
            break
        except RateLimitError as err:
            wait_seconds = _rate_limit_wait_seconds(err, attempt)
            if attempt >= MAX_RATE_LIMIT_RETRIES:
                raise
            tqdm.write(
                f"Rate limit hit (attempt {attempt}/{MAX_RATE_LIMIT_RETRIES}); "
                f"sleeping for {wait_seconds:.1f}s"
            )
            time.sleep(wait_seconds)

    text = tx.text.strip()
    return {
        "id": example["id"],
        "text": text,
        "text_normalized": normalize(text),
        "audio_path": example["audio"]["path"],
        "audio_url": example["audio_url"],
        "processing_duration_ms": (time.perf_counter() - clip_t0) * 1_000,
        "model_version": MODEL_NAME,
        "transcription_id": None,
    }


t0 = time.perf_counter()

results = []
for example in tqdm(pending, total=len(pending), desc="Groq", unit="clip"):
    results.append(_transcribe(example))

runtime_s = time.perf_counter() - t0
audio_ms_total = sum(pending["duration_ms"])
audio_human = humanize.precisedelta(
    timedelta(milliseconds=audio_ms_total),
    minimum_unit="minutes",
    suppress=["days"],
)
s_per_hour = (runtime_s / (audio_ms_total / 1000.0)) * 3600 if audio_ms_total else float("nan")

print(
    f"Transcribed {len(results)} clips covering {audio_human} in "
    f"{humanize.precisedelta(timedelta(seconds=runtime_s))} "
    f"({s_per_hour:.0f}s/h)"
)


Groq:   0%|          | 0/723 [00:00<?, ?clip/s]

Transcribed 723 clips covering 39 hours and 11.95 minutes in 41 minutes and 45.39 seconds (64s/h)


## Compute word error rate (WER)

In [37]:
import jiwer

results_df = (
    pd.DataFrame(results)
    .join(
        pd.Series(
            ds["normalized_transcription"],
            index=ds["id"],
            name="text_normalized_reference",
        ),
        on="id",
    )
    .assign(
        wer=lambda d: [jiwer.wer([r], [h]) for r, h in zip(d["text_normalized_reference"], d["text_normalized"])]
    )
)

results_df.to_parquet(RESULTS_PATH, index=False)

In [38]:
overall_wer = jiwer.wer(
    results_df["text_normalized_reference"].tolist(),
    results_df["text_normalized"].tolist(),
)

print(f"Overall WER: {overall_wer:.3f} ({overall_wer*100:.1f}%)")

Overall WER: 0.132 (13.2%)


In [39]:
worst_samples = results_df.sort_values("wer", ascending=False).head(25)
# worst_samples

In [40]:
display(HTML("<h3>Top 25 worst samples by WER</h3>"))
display(HTML(
    """<style>
section { margin: 0 0 0.25rem 0; }
audio { margin: 0.5rem 0 0 0; }
p { margin: 0 0 0.25rem 0; }
</style>"""
))

for row in worst_samples.itertuples(index=False):
    display(
        HTML(
            f"""
<section>
<p><code><small>WER: </small>{row.wer:.3f}</code></p>
<p>※ {escape(row.text_normalized_reference)}</p>
<p>≈ {escape(row.text_normalized)}</p>
{Audio(url=path_to_url(row.audio_path, REPO_ID, REVISION), embed=False)._repr_html_()}
</section>
            """.strip()
        )
    )