Start by making sure you have the following packages in your environment:

In [None]:
# ! pip install evaluate jiwer datasets[audio] transformers[torch]

In [None]:
# Hub dataset providing the "audio" inputs and "transcript" references scored below.
EVAL_DATASET: str = "janaab/supreme-court-speech"

In [None]:
# Hub checkpoint the ASR pipeline is built from.
BASE_MODEL: str = "openai/whisper-small"
# NOTE(review): presumably the fine-tuned counterpart of BASE_MODEL; it is not
# referenced by the cells in this view — confirm where it is consumed.
TUNED_MODEL: str = "janaab/whisper-small-sc"

## Load data and models

In [None]:
from datasets import load_dataset

# Evaluation uses only the held-out "test" split of the dataset.
sc_speech = load_dataset(EVAL_DATASET, split="test")

In [None]:
from transformers import pipeline
import torch

# Run on the first CUDA device in half precision when a GPU is present;
# otherwise fall back to the CPU in full precision.
device, torch_dtype = (
    ("cuda:0", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", torch.float32)
)

# Speech-to-text pipeline around the base (not fine-tuned) checkpoint.
pipe = pipeline(
    "automatic-speech-recognition",
    model=BASE_MODEL,
    device=device,
    torch_dtype=torch_dtype,
)

## Generate predictions

In [None]:
from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset

# Stream the dataset's "audio" column through the pipeline in batches of 32
# and keep only the decoded text, one entry per sample in dataset order.
# `max_new_tokens` is passed inside `generate_kwargs`: supplying it as a
# top-level pipeline argument is deprecated in recent transformers releases,
# while the `generate_kwargs` form works across versions.
all_predictions = [
    prediction["text"]
    for prediction in tqdm(
        pipe(
            KeyDataset(sc_speech, "audio"),
            generate_kwargs={"task": "transcribe", "max_new_tokens": 128},
            batch_size=32,
        ),
        # the pipeline yields one result per sample, so the bar length is the
        # dataset size even though inference is batched
        total=len(sc_speech),
    )
]

## Evaluate metrics

In [None]:
from evaluate import load

wer_metric = load("wer")

# Orthographic WER: raw model output scored against the raw transcripts,
# with no text normalisation, expressed as a percentage.
wer_ortho = wer_metric.compute(
    predictions=all_predictions, references=sc_speech["transcript"]
) * 100
wer_ortho

In [None]:
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

# compute normalised WER: apply the same normaliser to predictions and
# references so formatting differences are not counted as word errors
all_predictions_norm = [normalizer(pred) for pred in all_predictions]
all_references_norm = [normalizer(label) for label in sc_speech["transcript"]]

if len(all_predictions_norm) != len(all_references_norm):
    raise ValueError(
        f"got {len(all_predictions_norm)} predictions for "
        f"{len(all_references_norm)} references"
    )

# Keep only samples whose normalised reference still contains words; an empty
# reference cannot be scored by WER. BasicTextNormalizer collapses runs of
# whitespace but does not strip it, so a reference emptied by normalisation
# may be " " rather than "" -- hence `.strip()` instead of a length check.
# Filtering the pairs in a single zip keeps predictions aligned with references.
kept_pairs = [
    (pred, ref)
    for pred, ref in zip(all_predictions_norm, all_references_norm)
    if ref.strip()
]
all_predictions_norm = [pred for pred, _ in kept_pairs]
all_references_norm = [ref for _, ref in kept_pairs]

wer = 100 * wer_metric.compute(
    references=all_references_norm, predictions=all_predictions_norm
)

wer