# Evaluation: AssemblyAI

Evaluate the [AssemblyAI](https://www.assemblyai.com/docs) speech-to-text service on the test set of the Czech split of the [FLEURS](https://huggingface.co/datasets/google/fleurs) dataset, measured by word error rate (WER).

## Setup

In [None]:
%pip install -q ipywidgets tqdm humanize pandas pyarrow datasets assemblyai jiwer

In [None]:
import os
import sys
import time

from html import escape
from pathlib import Path
from datetime import timedelta

import pandas as pd
from datasets import Audio as AudioFeature, load_dataset

from dotenv import load_dotenv
from tqdm.notebook import tqdm
import humanize
from IPython.display import Audio, HTML, display

from utils.text import normalize
from utils.dataset import path_to_url

# Dataset repository and revision; both are passed to `load_dataset` and `path_to_url`.
REPO_ID = "karmiq/fleurs-cs"
REVISION = "files"
# Parquet file that receives the per-clip results, including the WER column.
RESULTS_PATH = Path("results-assemblyai.parquet")

## Load dataset

In [2]:
# Load the test split and keep the audio column undecoded: later cells only
# read each clip's file path and turn it into a URL.
ds = load_dataset(REPO_ID, revision=REVISION, split="test")
ds = ds.cast_column("audio", AudioFeature(decode=False))

# Sum the duration column directly instead of iterating over rows, which would
# materialize every column of every example just to read a single field.
total_ms = sum(ds["duration_ms"])
total_audio = humanize.precisedelta(timedelta(milliseconds=total_ms), minimum_unit='hours', suppress=['days'])

print(f"Total: {total_audio}")
display(ds)

Resolving data files:   0%|          | 0/2812 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/306 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/724 [00:00<?, ?it/s]

Total: 39.20 hours


Dataset({
    features: ['id', 'audio', 'raw_transcription', 'normalized_transcription', 'duration_ms', 'gender'],
    num_rows: 723
})

## Text normalization

In [3]:
# Re-normalize every raw transcription with the local `normalize` helper and
# compare the result with the normalized text stored in the dataset.
raw_text = pd.Series(ds["raw_transcription"]).fillna("")
stored_normalized = pd.Series(ds["normalized_transcription"]).fillna("")

reference_df = pd.DataFrame({
    "sentence_id": ds["id"],
    "raw_transcription": raw_text,
    "reference_normalized_transcription": stored_normalized,
    "computed_normalized_transcription": raw_text.map(normalize),
})

# Rows where the locally computed normalization disagrees with the stored one.
differs = reference_df["computed_normalized_transcription"].ne(
    reference_df["reference_normalized_transcription"]
)
mismatches = reference_df.loc[
    differs,
    ["sentence_id", "raw_transcription", "reference_normalized_transcription"],
]

n_mismatches, n_total = len(mismatches), len(reference_df)
print(
    "Normalization mismatches:",
    f"{n_mismatches}/{n_total}",
    f"({(n_mismatches / n_total) * 100:.0f}%)")

Normalization mismatches: 169/723 (23%)


## AssemblyAI transcription

In [17]:
from concurrent.futures import ThreadPoolExecutor

import assemblyai as aai

# Read the API key from `.env` / the process environment and stop right away
# when it is missing, before any transcription request is attempted.
load_dotenv(".env")

assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
if not assemblyai_api_key:
    raise RuntimeError("ASSEMBLYAI_API_KEY environment variable is required")

aai.settings.api_key = assemblyai_api_key

# The language is fixed to Czech and automatic language detection is disabled.
config_aai = aai.TranscriptionConfig(
    format_text=True,
    punctuate=True,
    speech_model=aai.SpeechModel.best,
    language_detection=False,
    language_code="cs",
)
transcriber_aai = aai.Transcriber(config=config_aai)
# Label written into every result row.
model_version = "assemblyai-default"

# Tunables: `sample_limit` restricts the run to the first N clips (None = all);
# `workers` is the size of the thread pool issuing requests.
sample_limit = None
workers = 5

# Clamp the limit to the dataset size so an oversized value selects every clip
# instead of asking `Dataset.select` for indices that do not exist.
pending = ds.select(range(min(sample_limit, len(ds)))) if sample_limit else ds

def _transcribe(ex):
    """Transcribe one dataset example with AssemblyAI and build its result record.

    Args:
        ex: Dataset row; its ``id`` and ``audio["path"]`` fields are read.

    Returns:
        dict with the example id, the stripped hypothesis text, its normalized
        form, the audio path, the elapsed wall-clock time in milliseconds, the
        model label and the AssemblyAI transcript id.

    Raises:
        RuntimeError: If AssemblyAI reports the transcription as failed.
    """
    clip_t0 = time.perf_counter()
    url = path_to_url(ex["audio"]["path"], REPO_ID, REVISION)
    tx = transcriber_aai.transcribe(url)

    # Surface the service's own error message instead of an opaque
    # AttributeError from calling `.strip()` on a missing text.
    if tx.status == aai.TranscriptStatus.error:
        raise RuntimeError(f"AssemblyAI transcription failed for {ex['id']!r}: {tx.error}")

    # A transcript without text is treated as an empty hypothesis.
    text = (tx.text or "").strip()
    return {
        "id": ex["id"],
        "text": text,
        "text_normalized": normalize(text),
        "audio_path": ex["audio"]["path"],
        "processing_duration_ms": (time.perf_counter() - clip_t0) * 1_000,
        "model_version": model_version,
        "transcription_id": tx.id,
    }

# Fan the clips out over a thread pool and collect the records behind a
# progress bar.
t0 = time.perf_counter()

with ThreadPoolExecutor(max_workers=workers) as pool:
    progress = tqdm(
        pool.map(_transcribe, pending),
        total=len(pending),
        desc="AssemblyAI",
        unit="clip",
    )
    results = list(progress)

runtime_s = time.perf_counter() - t0

# Throughput: seconds of processing spent per hour of audio.
audio_ms_total = sum(pending["duration_ms"])
audio_human = humanize.precisedelta(
    timedelta(milliseconds=audio_ms_total),
    minimum_unit="minutes",
    suppress=["days"],
)
if audio_ms_total:
    s_per_hour = (runtime_s / (audio_ms_total / 1000.0)) * 3600
else:
    s_per_hour = float("nan")

runtime_human = humanize.precisedelta(timedelta(seconds=runtime_s))
print(
    f"Transcribed {len(results)} clips covering {audio_human} in "
    f"{runtime_human} "
    f"({s_per_hour:.0f}s/h)")

AssemblyAI:   0%|          | 0/723 [00:00<?, ?clip/s]

Transcribed 723 clips covering 39 hours and 11.95 minutes in 9 minutes and 22.35 seconds (14s/h)


## Compute word error rate (WER)

In [22]:
import jiwer

# Reference text looked up by id. The lookup index must be unique: joining on
# a key that repeats on the right-hand side emits one row per match, which
# would silently duplicate clips and skew every WER figure computed below.
# NOTE(review): assumes rows sharing an id carry the same reference text, so
# keeping the first occurrence is lossless. Confirm against the dataset.
reference_by_id = pd.Series(
    ds["normalized_transcription"],
    index=ds["id"],
    name="text_normalized_reference",
)
reference_by_id = reference_by_id[~reference_by_id.index.duplicated(keep="first")]

results_df = (
    pd.DataFrame(results)
    .join(reference_by_id, on="id")
    .assign(
        # Per-clip WER of the normalized hypothesis against the normalized reference.
        wer=lambda d: [jiwer.wer([r], [h]) for r, h in zip(d["text_normalized_reference"], d["text_normalized"])]
    )
)
assert len(results_df) == len(results), "join must keep exactly one row per transcribed clip"

# Persist the per-clip results.
results_df.to_parquet(RESULTS_PATH, index=False)

In [23]:
# WER over the whole result set, computed in a single jiwer call on all
# reference/hypothesis pairs.
references = results_df["text_normalized_reference"].tolist()
hypotheses = results_df["text_normalized"].tolist()
overall_wer = jiwer.wer(references, hypotheses)

print(f"Overall WER: {overall_wer:.3f} ({overall_wer*100:.1f}%)")

Overall WER: 0.154 (15.4%)


In [20]:
# Keep the 25 clips with the highest per-clip WER for manual inspection.
worst_samples = results_df.sort_values(by="wer", ascending=False).iloc[:25]

In [21]:
# Heading plus a small stylesheet that tightens the spacing of the sample cards.
display(
    HTML("<h3>Top 25 worst samples by WER</h3>"),
    HTML(
        """<style>
section { margin: 0 0 0.25rem 0; }
audio { margin: 0.5rem 0 0 0; }
p { margin: 0 0 0.25rem 0; }
</style>"""
    ),
)

# One card per clip: WER, reference (※), hypothesis (≈) and an audio player.
for sample in worst_samples.itertuples(index=False):
    # The player points at the clip's URL; the audio is not embedded in the notebook.
    player_html = Audio(url=path_to_url(sample.audio_path, REPO_ID, REVISION), embed=False)._repr_html_()
    card_html = f"""
<section>
<p><code><small>WER: </small>{sample.wer:.3f}</code></p>
<p>※ {escape(sample.text_normalized_reference)}</p>
<p>≈ {escape(sample.text_normalized)}</p>
{player_html}
</section>
    """.strip()
    display(HTML(card_html))