In [1]:
from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio

In [2]:
def is_target_text_in_range(ref):
    """Tell whether a reference transcript should be kept for scoring.

    Args:
        ref: Reference transcript string.

    Returns:
        False when the whitespace-stripped transcript is empty or equals the
        placeholder "ignore time segment in scoring"; True otherwise.
    """
    cleaned = ref.strip()
    return cleaned not in ("", "ignore time segment in scoring")


def get_text(sample):
    """Return the transcript stored in ``sample``.

    Datasets name their transcript column differently, so the known column
    names are probed in a fixed priority order and the first one present wins.

    Args:
        sample: Mapping-like dataset row (must support ``in``, ``[]`` and
            ``.keys()``).

    Returns:
        The value stored under the first matching transcript column.

    Raises:
        ValueError: If none of the known transcript columns is present. The
            message lists the columns that were actually found.
    """
    # Probe order is significant: earlier names take precedence when a row
    # carries more than one candidate column.
    text_columns = ("text", "sentence", "normalized_text", "transcript", "transcription")
    for column in text_columns:
        if column in sample:
            return sample[column]

    # Render the keys explicitly so the message shows what the row contained
    # (the previous message printed the literal text ".join{sample.keys()}").
    found = ", ".join(str(key) for key in sample.keys())
    raise ValueError(
        "Expected transcript column of either 'text', 'sentence', 'normalized_text', 'transcript' or "
        f"'transcription'. Got sample of {found}. Ensure a text column name is present in the dataset."
    )


In [3]:
# Text normaliser applied to both the reference transcripts and the model
# predictions, so that the two sides are compared in the same normalised form.
whisper_norm = BasicTextNormalizer()

In [4]:
def normalise(batch):
    """Add a normalised transcript to a dataset row.

    Args:
        batch: Dataset row containing one of the transcript columns that
            ``get_text`` recognises.

    Returns:
        The same ``batch`` object, with the normalised transcript stored
        under the "norm_text" key.
    """
    transcript = get_text(batch)
    batch["norm_text"] = whisper_norm(transcript)
    return batch


def data(dataset):
    """Yield one pipeline input per dataset row.

    Each yielded dict holds every entry of the row's "audio" mapping plus a
    "reference" entry carrying the row's "norm_text" value.

    Args:
        dataset: Iterable of rows exposing "audio" (a mapping) and
            "norm_text".

    Yields:
        dict: A fresh dict per row; the source row is not modified.
    """
    for item in dataset:
        record = dict(item["audio"])
        record["reference"] = item["norm_text"]
        yield record

## Evaluate Param Bharat's model

In [5]:
# Print the GPU model, driver / CUDA versions and current memory usage.
! nvidia-smi

Sat Mar  4 14:32:52 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 5000     Off  | 00000000:1E:00.0 Off |                  Off |
| 34%   33C    P8    13W / 230W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
# Identifier of the model that is handed to the ASR pipeline in the next cell.
model_id = "parambharat/whisper-small-ml"

In [7]:
# Speech-recognition pipeline for the selected checkpoint, placed on device 0.
whisper_asr = pipeline("automatic-speech-recognition", model=model_id, device=0)

# Load the Common Voice 11.0 "ml" test split, cast the audio column to
# 16 kHz, attach the normalised transcript and drop rows whose normalised
# transcript is empty or is the "ignore" placeholder.
#
# Debugging aid: evaluate on a subset by truncating the dataset right after
# loading. The original snippet here was
# `dataset = dataset.take(args.max_eval_samples)`, but `args` is not defined
# in this notebook.
dataset = (
    load_dataset("mozilla-foundation/common_voice_11_0", "ml", split="test")
    .cast_column("audio", Audio(sampling_rate=16000))
    .map(normalise)
    .filter(is_target_text_in_range, input_columns=["norm_text"])
)

dataset.shape

In [11]:
%%time
# Normalised model outputs and their matching reference transcripts,
# collected in the same order.
predictions, references = [], []

# Stream the prepared rows through the pipeline in batches of 16.
for out in whisper_asr(data(dataset), batch_size=16):
    predictions.append(whisper_norm(out["text"]))
    # NOTE(review): the pass-through "reference" field is indexed with [0];
    # presumably the pipeline returns it wrapped in a sequence - confirm.
    references.append(out["reference"][0])



CPU times: user 6min 33s, sys: 53 s, total: 7min 26s
Wall time: 4min 38s


In [18]:
from jiwer import wer, cer

In [19]:
# Word error rate over the whole evaluation set, as a percentage rounded to
# two decimals.
rwer = round(100 * wer(references, predictions), 2)
print(f"The WER of model: {rwer}")

The WER of model: 21.65


In [20]:
from jiwer import cer

In [21]:
# Character error rate over the whole evaluation set, as a percentage rounded
# to two decimals.
rcer = round(100 * cer(references, predictions), 2)
print(f"The CER of model: {rcer}")

The CER of model: 11.78


## Common function to evaluate models

In [4]:
from datasets import load_dataset, Audio
from jiwer import wer, cer
from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

In [5]:
# Text normaliser used by `normalise` for reference transcripts and by the
# evaluation function for model predictions.
whisper_norm = BasicTextNormalizer()

In [15]:
def is_target_text_in_range(ref):
    """Tell whether a reference transcript should be kept for scoring.

    Args:
        ref: Reference transcript string.

    Returns:
        False when the whitespace-stripped transcript is empty or equals the
        placeholder "ignore time segment in scoring"; True otherwise.
    """
    cleaned = ref.strip()
    return cleaned not in ("", "ignore time segment in scoring")


def get_text(sample):
    """Return the transcript stored in ``sample``.

    Datasets name their transcript column differently, so the known column
    names are probed in a fixed priority order and the first one present wins.

    Args:
        sample: Mapping-like dataset row (must support ``in``, ``[]`` and
            ``.keys()``).

    Returns:
        The value stored under the first matching transcript column.

    Raises:
        ValueError: If none of the known transcript columns is present. The
            message lists the columns that were actually found.
    """
    # Probe order is significant: earlier names take precedence when a row
    # carries more than one candidate column.
    text_columns = ("text", "sentence", "normalized_text", "transcript", "transcription")
    for column in text_columns:
        if column in sample:
            return sample[column]

    # Render the keys explicitly so the message shows what the row contained
    # (the previous message printed the literal text ".join{sample.keys()}").
    found = ", ".join(str(key) for key in sample.keys())
    raise ValueError(
        "Expected transcript column of either 'text', 'sentence', 'normalized_text', 'transcript' or "
        f"'transcription'. Got sample of {found}. Ensure a text column name is present in the dataset."
    )

def normalise(batch):
    """Add a normalised transcript to a dataset row.

    Args:
        batch: Dataset row containing one of the transcript columns that
            ``get_text`` recognises.

    Returns:
        The same ``batch`` object, with the normalised transcript stored
        under the "norm_text" key.
    """
    transcript = get_text(batch)
    batch["norm_text"] = whisper_norm(transcript)
    return batch

def data(dataset):
    """Yield one pipeline input per dataset row.

    Each yielded dict holds every entry of the row's "audio" mapping plus a
    "reference" entry carrying the row's "norm_text" value.

    Args:
        dataset: Iterable of rows exposing "audio" (a mapping) and
            "norm_text".

    Yields:
        dict: A fresh dict per row; the source row is not modified.
    """
    for item in dataset:
        record = dict(item["audio"])
        record["reference"] = item["norm_text"]
        yield record

In [18]:
def evaluate_whisper_model_common_voice(
    model_name: str,
    language: str = "ml",
    split: str = "test",
    batch_size: int = 32,
    device: int = 0,
) -> None:
    """Evaluate a speech-recognition model on Common Voice 11.0 and print WER / CER.

    The chosen split is loaded, its audio column is cast to 16 kHz, transcripts
    are normalised with ``whisper_norm`` and rows with an empty or "ignore"
    transcript are dropped. The remaining rows are transcribed in batches and
    scored against the normalised references.

    Args:
        model_name: Model identifier handed to ``transformers.pipeline``.
        language: Common Voice configuration name. Defaults to "ml", the
            value that was previously hard-coded.
        split: Dataset split to evaluate on. Defaults to "test".
        batch_size: Number of samples per pipeline batch. Defaults to 32.
        device: Device index handed to the pipeline. Defaults to 0.

    Returns:
        None. The metrics are printed as percentages rounded to two decimals.
    """
    whisper_asr = pipeline("automatic-speech-recognition", model=model_name, device=device)

    dataset = load_dataset("mozilla-foundation/common_voice_11_0", language, split=split)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    for out in whisper_asr(data(dataset), batch_size=batch_size):
        # Predictions go through the same normaliser as the references.
        predictions.append(whisper_norm(out["text"]))
        # NOTE(review): the pass-through "reference" field is indexed with
        # [0]; presumably the pipeline returns it wrapped in a sequence - confirm.
        references.append(out["reference"][0])

    rwer = round(100 * wer(references, predictions), 2)
    print(f"The WER of model: {rwer}")

    rcer = round(100 * cer(references, predictions), 2)
    print(f"The CER of model: {rcer}")

In [19]:
# Run the reusable evaluation helper; it prints the WER and CER for this model.
evaluate_whisper_model_common_voice("parambharat/whisper-small-ml")

Found cached dataset common_voice_11_0 (/home/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ml/11.0.0/2c65b95d99ca879b1b1074ea197b65e0497848fd697fdb0582e0f6b75b6f4da0)
Loading cached processed dataset at /home/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ml/11.0.0/2c65b95d99ca879b1b1074ea197b65e0497848fd697fdb0582e0f6b75b6f4da0/cache-b5fde927f6328b58.arrow
Loading cached processed dataset at /home/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ml/11.0.0/2c65b95d99ca879b1b1074ea197b65e0497848fd697fdb0582e0f6b75b6f4da0/cache-34d1ec8a736a6ac3.arrow


The WER of model: 21.65
The CER of model: 11.78
