In [1]:
from jiwer import wer
from datasets import load_dataset, DatasetDict, Audio
from whisper_normalizer.basic import BasicTextNormalizer
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import gc

import torch
# from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from numerize import numerize


In [3]:

# %% ../nbs/00_utils.ipynb 4
whisper_norm = BasicTextNormalizer()

# %% ../nbs/00_utils.ipynb 5
def is_target_text_in_range(ref):
    if ref.strip() == "ignore time segment in scoring":
        return False
    else:
        return ref.strip() != ""


def get_text(sample):
    if "text" in sample:
        return sample["text"]
    elif "sentence" in sample:
        return sample["sentence"]
    elif "normalized_text" in sample:
        return sample["normalized_text"]
    elif "transcript" in sample:
        return sample["transcript"]
    elif "transcription" in sample:
        return sample["transcription"]
    else:
        raise ValueError(
            f"Expected transcript column of either 'text', 'sentence', 'normalized_text' or 'transcript'. Got sample of "
            ".join{sample.keys()}. Ensure a text column name is present in the dataset."
        )


def normalise(batch):
    batch["norm_text"] = whisper_norm(get_text(batch))
    return batch


def data(dataset):
    for i, item in enumerate(dataset):
        yield {**item["audio"], "reference": item["norm_text"]}

# %% ../nbs/00_utils.ipynb 6
def get_model_size(model):
    total_params = sum(param.numel() for param in model.parameters())
    return numerize.numerize(total_params)

# %% ../nbs/00_utils.ipynb 7
def clear_gpu_memory():
    torch.cuda.empty_cache()
    gc.collect()

In [4]:
poly = DatasetDict()

poly["train"] = load_dataset(
    "PolyAI/minds14", "en-US", split="train[0%:80%]"
)
poly["test"] = load_dataset(
    "PolyAI/minds14", "en-US", split="train[80%:100%]"
)

print(poly)

Downloading builder script: 100%|██████████| 5.95k/5.95k [00:00<00:00, 14.8MB/s]
Downloading readme: 100%|██████████████████| 5.29k/5.29k [00:00<00:00, 19.5MB/s]
Downloading data: 100%|██████████████████████| 471M/471M [01:18<00:00, 5.97MB/s]
Generating train split: 563 examples [00:00, 23518.45 examples/s]


DatasetDict({
    train: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 450
    })
    test: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 113
    })
})


In [5]:
whisper_asr = pipeline("automatic-speech-recognition", model="kurianbenoy/hfa-poly_english_small")

Downloading (…)lve/main/config.json: 100%|█| 2.23k/2.23k [00:00<00:00, 4.82MB/s]
Downloading pytorch_model.bin: 100%|█████████| 967M/967M [02:37<00:00, 6.12MB/s]
Downloading (…)okenizer_config.json: 100%|█████| 805/805 [00:00<00:00, 3.40MB/s]
Downloading (…)olve/main/vocab.json: 100%|█| 1.04M/1.04M [00:00<00:00, 1.14MB/s]
Downloading (…)olve/main/merges.txt: 100%|████| 494k/494k [00:00<00:00, 561kB/s]
Downloading (…)main/normalizer.json: 100%|██| 52.7k/52.7k [00:00<00:00, 173kB/s]
Downloading (…)in/added_tokens.json: 100%|█| 2.08k/2.08k [00:00<00:00, 2.17MB/s]
Downloading (…)cial_tokens_map.json: 100%|█| 2.08k/2.08k [00:00<00:00, 7.45MB/s]
Downloading (…)rocessor_config.json: 100%|█████| 339/339 [00:00<00:00, 1.02MB/s]


In [6]:
def ld():
    dataset = poly["test"]
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
    return dataset

In [7]:
%%time
ds = ld()
predictions = []
predictions_raw = []
references = []
references_raw = []
normalizer = BasicTextNormalizer()
for out in whisper_asr(data(ds), batch_size=4):
    predictions_raw.append(out["text"])
    references_raw.append(out["reference"][0])
    predictions.append(normalizer(out["text"]))
    references.append(normalizer(out["reference"][0]))

Map: 100%|██████████████████████████| 113/113 [00:00<00:00, 11071.68 examples/s]
Filter: 100%|███████████████████████| 113/113 [00:00<00:00, 55459.44 examples/s]


CPU times: user 3h 22min 46s, sys: 4min 56s, total: 3h 27min 42s
Wall time: 21min 9s


In [8]:
wer(predictions_raw, references_raw)

0.3715670436187399

In [9]:
wer(predictions, references)

0.24654023577652487