In [11]:
import torch
import re
import jiwer
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_from_disk

In [3]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so that `model.to(device)` does not fail on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Local checkpoint directory of the model under evaluation.
model_id = "../bin/whisper-cslu-kids"
# The tokenizer and feature extractor are loaded from this id, not from the
# local checkpoint directory.
tokenizer_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
)
model.to(device)

processor = AutoProcessor.from_pretrained(tokenizer_id)

# Speech-recognition pipeline combining the local model with the processor's
# tokenizer and feature extractor.
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=1,  # batch size for inference - set based on your device
    device=device,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [4]:
# Evaluation datasets keyed by a short name; only the CSLU set is loaded here.
dss = dict(cslu=load_from_disk("../data/cslu_kids.ds"))

In [5]:
# Smoke test: transcribe the first CSLU example, then show its reference
# sentence followed by the raw pipeline output for comparison by eye.
sample = dss["cslu"][0]["audio"]
result = pipe(sample)

print(dss["cslu"][0]["sentence"], result, sep="\n")

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


<bn> a b c d e f g <br> h i j k<ln> l m n o p<ln> <br> q r s t u v w x y and z <bn> <pau> my<bn> family<bn> <bn> she<bn> went<bn> to<bn> go<bn> pick<bn> up<bn> my<bn> little<bn> sister<bn> and<bn> she's gonna<bn> <br> come<bn> tomorrow she's gonna come at eleven <pau> yeah <pau> <bn> okay <bn> clean my room <bn> and<bn> then<bn> when<bn> i'm<bn> done<bn> i<bn> get<bn> to<bn> play<bn> with<bn> my<bn> friend<bn> <pau> brittney we go over to her house and we play barbies <pau> and <br> we uhm <pau> we ride our bikes after we're done and then we eat some ice cream <pau> i have four sisters <pau> <bs> one's fifteen <bn> th* four* thirteen <br> and ten and one's five <pau> yeah <pau> <bn> they're nice and they let me <br> uhm watch tv<ln> in their room <bs> and uhm <br> <pau> and<bn> <br> she<bn> when sometimes<ln> when i <br> do a little bit of chores <br> she gives me a dollar
{'text': ' a b c d e f g h i j k l m n o p q r s t u v w x y and z my family she went to go pick up my little sist

# Run on original data

In [None]:
# Run the pipeline over the whole "audio" column of the CSLU dataset.
# The scoring cell further down reads the "text" field of each element of
# `results`.
# NOTE(review): this cell has no execution count (In [None]) while the scoring
# cell has recorded output - make sure it has been run in the current kernel
# before scoring, otherwise `results` is undefined or stale.
results = pipe(dss["cslu"]["audio"])

In [9]:
# Shared BasicTextNormalizer instance, built with its default arguments;
# normalize_transcript calls it on every transcript.
normalizer = BasicTextNormalizer()


def normalize_transcript(text):
    """Strip transcript annotations and normalise the remaining text.

    Steps, in order:
      1. Delete annotation tags written in angle brackets (e.g. ``<pau>``).
      2. Delete false starts, i.e. tokens ending in an asterisk (e.g. ``th*``),
         which the ASR output does not contain.
      3. Pass the result through the module-level ``normalizer``
         (a ``BasicTextNormalizer`` instance).

    Args:
        text: Raw reference transcript or model prediction.

    Returns:
        Whatever ``normalizer`` returns for the cleaned text.
    """
    without_tags = re.sub(r"<[^>]*>", "", text)
    without_false_starts = re.sub(r"\S*\*", "", without_tags)
    return normalizer(without_false_starts)

In [13]:
def weighted_wer(ref: list[str], pred: list[str]):
    """Compute a length-weighted (corpus-level) word error rate.

    Each reference/prediction pair is normalised with ``normalize_transcript``
    and scored with ``jiwer.wer``. The per-sample rate is multiplied back by the
    number of reference words to recover that sample's error count, and the
    final score is ``total errors / total reference words``, so long utterances
    weigh more than short ones.

    Args:
        ref: Reference transcripts.
        pred: Predicted transcripts, aligned index-by-index with ``ref``.

    Returns:
        ``{"wer": <float>}``. The value is ``0.0`` when there are no reference
        words at all.

    Raises:
        ValueError: If ``ref`` and ``pred`` contain a different number of items.
    """
    ref_items = list(ref)
    pred_items = list(pred)

    # zip() would silently drop the unmatched tail and report a WER computed on
    # only part of the data, so a mismatch is rejected up front.
    if len(ref_items) != len(pred_items):
        raise ValueError(
            f"ref and pred must have the same length, "
            f"got {len(ref_items)} and {len(pred_items)}"
        )

    total_errors = 0
    total_words = 0

    for ref_raw, pred_raw in zip(ref_items, pred_items):
        ref_text = normalize_transcript(ref_raw)
        pred_text = normalize_transcript(pred_raw)
        n_ref_words = len(ref_text.split())

        if n_ref_words == 0:
            # A reference that is empty after normalisation cannot be scored
            # per-sample: jiwer releases that reject empty references raise
            # here, and multiplying any rate by zero words would discard the
            # errors anyway. Count every hypothesis word as an insertion.
            total_errors += len(pred_text.split())
            continue

        # jiwer.wer takes (reference, hypothesis), in that order.
        sample_wer = jiwer.wer(ref_text, pred_text)

        # Convert the rate back into this sample's error count.
        total_errors += sample_wer * n_ref_words
        total_words += n_ref_words

    overall_wer = total_errors / total_words if total_words > 0 else 0.0

    return {"wer": overall_wer}


# Score the pipeline transcripts against the CSLU reference sentences.
references = dss["cslu"]["sentence"]
predictions = [output["text"] for output in results]

wer_score = weighted_wer(references, predictions)
print(wer_score)  # recorded on 07-08-25: {'wer': 0.15899313905005624}

{'wer': 0.15899313905005624}
