Based on "Multi-Dataset Evaluation with 🤗 Transformers and Datasets" by [Sanchit Gandhi](https://huggingface.co/sanchit-gandhi)

In [None]:
!pip install datasets transformers evaluate huggingface_hub jiwer soundfile librosa

# Encoder context mode:
#   True  -> size the encoder's audio context to each clip's duration
#   False -> run the model normally (input padded to 30 s)
USE_AUDIO_CTX = True

# Base checkpoint. The pipeline's tokenizer and feature extractor are always
# loaded from it, so it must match the architecture of LOAD_FROM when that is set.
MODEL_TYPE = "openai/whisper-small.en"

# Path to fine-tuned weights (e.g. a local directory), or None to evaluate
# MODEL_TYPE unchanged.
LOAD_FROM = "/workspace/acft-small.en"

In [None]:
from datasets import load_dataset

# Stream the evaluation splits so nothing has to be downloaded up front.
librispeech_clean, librispeech_other = (
    load_dataset("librispeech_asr", "all", split=split_name, streaming=True)
    for split_name in ("test.clean", "test.other")
)

voxpopuli = load_dataset("facebook/voxpopuli", "en", split="test", streaming=True)

In [None]:
# Display name -> streamed test split. Insertion order determines the row
# order of the final results table.
esb_datasets = dict(
    zip(
        ("LibriSpeech Clean", "LibriSpeech Other", "VoxPopuli"),
        (librispeech_clean, librispeech_other, voxpopuli),
    )
)

In [None]:
def get_text(sample):
    """Return the reference transcription stored in ``sample``.

    Different datasets name their transcript column differently. The keys
    "text", "sentence", "normalized_text" and "transcript" are tried in that
    order, and the first one present is returned.

    Raises:
        ValueError: if none of those keys is present in ``sample``.
    """
    for column in ("text", "sentence", "normalized_text", "transcript"):
        if column in sample:
            return sample[column]
    raise ValueError(f"Sample: {sample.keys()} has no transcript.")

In [None]:
from transformers import pipeline, WhisperForConditionalGeneration

# ASR pipeline built from the base checkpoint on device 0.
whisper_asr = pipeline("automatic-speech-recognition", model=MODEL_TYPE, device=0)
device = whisper_asr.device

# Optionally swap in fine-tuned weights. Only the model is replaced; the
# tokenizer and feature extractor loaded for MODEL_TYPE stay in place.
if LOAD_FROM:
    whisper_asr.model = (
        WhisperForConditionalGeneration
        .from_pretrained(LOAD_FROM)
        .to(device)
    )

## Load the Word Error Rate metric

We'll assess our system using the [Word Error Rate (WER)](https://huggingface.co/spaces/evaluate-metric/wer) metric, the de facto metric for evaluating ASR systems. We'll load the WER metric from the 🤗 Evaluate library:

In [None]:
import evaluate

# Word Error Rate metric, loaded by name from the evaluate library.
wer_metric = evaluate.load(path="wer")

Bonus: You can also try other evaluation metrics, such as the [Character Error Rate (CER)](https://huggingface.co/spaces/evaluate-metric/cer). To use the CER, change the statement above to `evaluate.load("cer")`.

## Normalisation

The [Whisper paper](https://cdn.openai.com/papers/whisper.pdf) demonstrates the drastic effect that normalising the text outputs has on WER. The normalisation step is important because it removes errors unrelated to the speech recognition task, such as casing and punctuation. It also makes the formatting consistent between references and predictions by converting spelled-out numbers to symbolic form (e.g. "two" -> "2") and British spellings to American (e.g. "grey" -> "gray").

We first write a function to normalise the reference of a single sample according to the Whisper English text normaliser:

In [None]:
# Whisper's English text normaliser, taken from the pipeline's tokenizer.
# Recent transformers releases expose it as the public ``normalize`` method and
# deprecate the private ``_normalize`` (which emits a deprecation warning).
# Prefer the public name and fall back to ``_normalize`` on older releases
# that do not have it.
whisper_norm = (
    getattr(whisper_asr.tokenizer, "normalize", None)
    or whisper_asr.tokenizer._normalize
)


def normalise(batch):
    """Store the Whisper-normalised reference transcript in ``batch["norm_text"]``.

    The raw transcript is looked up with ``get_text``. ``batch`` is modified in
    place and also returned, which is the form ``datasets`` ``.map`` expects.
    """
    batch["norm_text"] = whisper_norm(get_text(batch))
    return batch

We'll apply this function to our data using 🤗 Datasets' [`.map`](https://huggingface.co/docs/datasets/process#map) method in our evaluation pipeline.

We also need to remove any empty reference transcriptions from our dataset. The WER is normalised by the number of reference words, so an empty reference would cause a division-by-zero error in the WER calculation.

We write a function that indicates which samples to keep and which to discard. This function, `is_target_text_in_range`, returns a boolean. A reference transcription that is empty, or that consists only of the placeholder "ignore time segment in scoring", returns False; all other references return True:

In [None]:
# Reference texts that must not be scored (after whitespace stripping).
filter_sequences = ["ignore time segment in scoring", ""]

def is_target_text_in_range(ref):
    """Return True if the reference ``ref`` should be kept for scoring.

    ``ref`` is stripped of surrounding whitespace. It is discarded (False)
    when it then equals one of ``filter_sequences``, which include the empty
    string.
    """
    stripped = ref.strip()
    return all(stripped != seq for seq in filter_sequences)

Again, we'll apply this function to our data in our evaluation pipeline, this time using 🤗 Datasets' [`.filter`](https://huggingface.co/docs/datasets/process#select-and-filter) method.

## Multi-Dataset Evaluation

In this final section, we combine everything together to form the multi-dataset evaluation loop for the Whisper model.

First, we define a generator that iterates over the dataset and yields the audio samples and reference text ready for our model:

In [None]:
def data(dataset):
    """Yield model-ready inputs from a normalised dataset.

    Each yielded dict is a shallow copy of the sample's ``"audio"`` mapping
    (which the consumers read ``"array"`` and ``"sampling_rate"`` from). The
    copy gets an extra ``"reference"`` key holding the sample's ``"norm_text"``.
    """
    for sample in dataset:
        record = dict(sample["audio"])
        record["reference"] = sample["norm_text"]
        yield record

In [None]:
from datasets import Audio
from torch import nn
import torch
from transformers.modeling_outputs import BaseModelOutput

def compute_partially_encoder(model, data, n_audio_ctx):
    """Run the Whisper encoder over only the first ``n_audio_ctx`` positions.

    Instead of always encoding a full window, the input features are cropped
    (or zero-padded) along their last axis to ``2 * n_audio_ctx`` frames. Only
    the first ``n_audio_ctx`` positional embeddings are then added. The factor
    of 2 matches the encoder's second convolution, which halves the time axis.

    Args:
        model: Whisper model exposing ``.encoder`` (the caller passes
            ``whisper_asr.model.model``).
        data: Input features whose last axis is time (``data.shape[2]``).
            Presumably ``(batch, n_mels, frames)``; confirm against the caller.
        n_audio_ctx: Number of encoder positions to compute. Clamped to
            ``[1, max_ctx]``, where ``max_ctx`` is the number of rows in the
            encoder's positional-embedding table.

    Returns:
        ``model.encoder(data)`` when the full context is requested. Otherwise a
        ``BaseModelOutput`` whose ``last_hidden_state`` covers ``n_audio_ctx``
        positions.
    """
    encoder = model.encoder

    # The encoder has no positional embedding beyond max_ctx. Without this
    # clamp, a larger request pads the input past the full window, and the
    # ``input_embeds + embed_pos`` addition below fails with a shape mismatch.
    # Callers derive n_audio_ctx from the raw clip duration, which can exceed
    # the window covered by the features.
    max_ctx = encoder.embed_positions.weight.shape[0]
    n_audio_ctx = max(1, min(n_audio_ctx, max_ctx))

    frame_diff = 2 * n_audio_ctx - data.shape[2]
    if frame_diff > 0:
        # Too few frames: zero-pad the end of the time axis.
        data = nn.functional.pad(data, [0, frame_diff, 0, 0, 0, 0], "constant", 0.0)
    elif frame_diff < 0:
        # Too many frames: crop the end of the time axis.
        data = data[:, :, :frame_diff]

    if n_audio_ctx == max_ctx:
        # Full context requested: defer to the stock encoder.
        return encoder(data)

    input_embeds = nn.functional.gelu(encoder.conv1(data))
    input_embeds = nn.functional.gelu(encoder.conv2(input_embeds))
    input_embeds = input_embeds.permute(0, 2, 1)

    # Only the first n_audio_ctx positional embeddings are used.
    embed_pos = encoder.embed_positions.weight[:n_audio_ctx]

    hidden_states = input_embeds + embed_pos
    hidden_states = nn.functional.dropout(
        hidden_states, p=encoder.dropout, training=encoder.training
    )

    for encoder_layer in encoder.layers:
        # LayerDrop: while training, skip each layer with probability
        # encoder.layerdrop. The RNG is only drawn in training mode.
        if encoder.training and torch.rand([]) < encoder.layerdrop:
            continue

        if encoder.gradient_checkpointing and encoder.training:
            layer_outputs = encoder._gradient_checkpointing_func(
                encoder_layer.__call__,
                hidden_states,
                None,
                None,
                False,
            )
        else:
            layer_outputs = encoder_layer(
                hidden_states,
                None,
                layer_head_mask=None,
                output_attentions=False,
            )
        hidden_states = layer_outputs[0]

    hidden_states = encoder.layer_norm(hidden_states)
    return BaseModelOutput(last_hidden_state=hidden_states)


def whisper_asr_partial(whisper_asr, data, batch_size=0):
    """Transcribe samples one at a time with a duration-sized encoder context.

    For each sample, the number of encoder positions is derived from the clip
    duration: 1500 positions per 30 s (50 per second), plus a margin of 8. The
    result is capped at the model's ``max_source_positions``.

    Args:
        whisper_asr: The ASR pipeline. Its ``feature_extractor``, ``tokenizer``
            and ``model`` are used directly.
        data: Iterable of dicts with ``"array"``, ``"sampling_rate"`` and
            ``"reference"`` keys.
        batch_size: Accepted for call-compatibility with the pipeline. It is
            unused, because samples are processed one at a time.

    Yields:
        ``{"text": <decoded transcription>, "reference": [<reference>]}``.
    """
    model = whisper_asr.model
    # The encoder cannot attend to more positions than this. Without the cap,
    # clips longer than ~30 s would request an oversized context.
    max_ctx = model.config.max_source_positions

    for sample in data:
        waveform = sample["array"]
        sampling_rate = sample["sampling_rate"]

        # Use the pipeline's feature extractor to compute the model inputs.
        input_features = whisper_asr.feature_extractor(
            waveform, sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features.to(model.model.device)

        duration = len(waveform) / sampling_rate
        n_ctx = min(int(round((1500.0 / 30.0) * duration)) + 8, max_ctx)

        # Inference only: avoid building an autograd graph through the
        # encoder. The yield stays outside so grad mode is not held disabled
        # while the consumer runs.
        with torch.no_grad():
            encoder_outputs = compute_partially_encoder(model.model, input_features, n_ctx)
            tokens = model.generate(encoder_outputs=encoder_outputs)

        text = whisper_asr.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
        yield {"text": text, "reference": [sample["reference"]]}

In [None]:
import time

# set the batch size according to your device
BATCH_SIZE = 1
# Samples evaluated per dataset (for quick debugging); set to None to
# evaluate each full test split.
NUM_SAMPLES = 128
wer_results = []

# perf_counter is monotonic, unlike time.time(), so the measured duration is
# not skewed by system clock adjustments.
t0 = time.perf_counter()
# loop over all the datasets in the ESB benchmark
for dataset_name, dataset in esb_datasets.items():
    # optionally restrict the number of rows (debugging)
    if NUM_SAMPLES is not None:
        dataset = dataset.take(NUM_SAMPLES)

    # resample to 16kHz
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    # normalise references
    dataset = dataset.map(normalise)

    # remove any empty references
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    # Pick the transcriber. Both yield {"text": ..., "reference": [...]}
    # dicts, so the collection loop below is shared.
    if USE_AUDIO_CTX:
        outputs = whisper_asr_partial(whisper_asr, data(dataset), batch_size=BATCH_SIZE)
    else:
        outputs = whisper_asr(data(dataset), batch_size=BATCH_SIZE)

    # run streamed inference, collecting normalised predictions and references
    predictions = []
    references = []
    for out in outputs:
        predictions.append(whisper_norm(out["text"]))
        references.append(out["reference"][0])

    # compute the WER (as a percentage, 2 decimals)
    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    wer_results.append(wer)
t1 = time.perf_counter()

In [None]:
import pandas as pd

# Report the total evaluation time, then one WER row per dataset (in the
# iteration order of esb_datasets).
elapsed = t1 - t0
print("Time:", elapsed)
df = pd.DataFrame({"Dataset": list(esb_datasets), "WER": wer_results})
df