In [3]:
# Install the packages this notebook depends on.
!pip install transformers==4.28
# The specifier is quoted because an unquoted ">" is shell output redirection:
# pip would install an unpinned `evaluate` and write its log to a file named "=0.3.0".
!pip install "evaluate>=0.3.0"
!pip install datasets==2.6.1
!pip install huggingface_hub
!pip install jiwer
!pip install librosa
!pip install tqdm
!pip install peft==0.2.0
!pip install accelerate
!pip install bitsandbytes

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:
from huggingface_hub import notebook_login

# Show the Hugging Face Hub login widget. `load_dataset` is later called with
# `use_auth_token=True`, so a Hub login is expected before the evaluation cells run.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [8]:
#import argparse

from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
import evaluate
from tqdm import tqdm

# Word-error-rate metric loaded through `evaluate`; `main` calls `wer_metric.compute(...)` on it.
wer_metric = evaluate.load("wer")


def is_target_text_in_range(ref):
    """Tell whether a reference transcript should be kept for scoring.

    The reference is stripped of surrounding whitespace first. It is rejected
    (False) when the result is empty or equals the placeholder
    "ignore time segment in scoring"; any other text is accepted (True).
    """
    stripped = ref.strip()
    return stripped not in ("", "ignore time segment in scoring")


def get_text(sample):
    """Return the transcript stored in a dataset sample.

    The sample is probed for the columns 'text', 'sentence', 'normalized_text',
    'transcript' and 'transcription', in that order; the value of the first
    column that is present is returned.

    Args:
        sample: Mapping from column name to value.

    Returns:
        The value stored under the first matching transcript column.

    Raises:
        ValueError: If none of the known transcript columns is present. The
            message lists the accepted columns and the columns found.
    """
    text_columns = ("text", "sentence", "normalized_text", "transcript", "transcription")
    for column in text_columns:
        if column in sample:
            return sample[column]

    expected = ", ".join(f"'{column}'" for column in text_columns)
    # str() guards against non-string keys so building the message cannot itself fail.
    available = ", ".join(str(key) for key in sample.keys())
    raise ValueError(
        f"Expected transcript column of either {expected}. Got sample of "
        f"{available}. Ensure a text column name is present in the dataset."
    )


# Shared text normalizer: applied to reference transcripts in `normalise` and to model predictions in `main`.
whisper_norm = BasicTextNormalizer()


def normalise(batch):
    """Attach the normalized transcript to a sample.

    The transcript is looked up with `get_text`, passed through `whisper_norm`,
    and stored under the "norm_text" key. The same (mutated) mapping is returned.
    """
    raw_transcript = get_text(batch)
    batch["norm_text"] = whisper_norm(raw_transcript)
    return batch


def data(dataset):
    """Generate one pipeline input per dataset item.

    Every yielded dict contains the entries of the item's "audio" mapping plus a
    "reference" key holding the item's "norm_text" value.
    """
    for sample in dataset:
        audio_fields = sample["audio"]
        reference = sample["norm_text"]
        yield {**audio_fields, "reference": reference}


def main(model_id,dataset,config,split,device,batch_size,max_eval_samples,streaming,language):
    """Evaluate a speech-recognition checkpoint on a dataset split and print its WER.

    Args:
        model_id: Model identifier handed to `pipeline(..., model=model_id)`.
        dataset: Dataset name passed to `load_dataset`.
        config: Dataset configuration name passed to `load_dataset`.
        split: Split name passed to `load_dataset`.
        device: Device handed to `pipeline`.
        batch_size: Batch size used when running the pipeline.
        max_eval_samples: When not None, only the first `max_eval_samples`
            examples of the split are evaluated (useful for debugging).
            None evaluates the whole split.
        streaming: Whether `load_dataset` loads the data in streaming mode.
        language: Language passed to the tokenizer's `get_decoder_prompt_ids`.

    Returns:
        None. The WER is printed as a percentage rounded to two decimals.
    """
    whisper_asr = pipeline(
        "automatic-speech-recognition", model=model_id, device=device
    )

    # Pin the decoder prompt to the requested language with the "transcribe" task.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=language, task="transcribe"
        )
    )

    dataset = load_dataset(
        dataset,
        config,
        split=split,
        streaming=streaming,
        use_auth_token=True,
    )

    # Optionally restrict the evaluation to the first N examples. Streaming datasets
    # expose `take`; map-style datasets are sliced with `select` instead.
    if max_eval_samples is not None:
        if streaming:
            dataset = dataset.take(max_eval_samples)
        else:
            dataset = dataset.select(range(min(max_eval_samples, len(dataset))))

    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    # Drop samples whose normalized reference is empty or the "ignore" placeholder.
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    # run streamed inference
    for out in tqdm(whisper_asr(data(dataset), batch_size=batch_size)):
        predictions.append(whisper_norm(out["text"]))
        # NOTE(review): "reference" comes back from the pipeline as a sequence
        # (presumably one entry per audio chunk - confirm); the first entry is used.
        references.append(out["reference"][0])

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer)

In [None]:
import torch
from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
import evaluate
from tqdm import tqdm

from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperProcessor,
)
from peft import PeftModel, PeftConfig

# Word-error-rate metric loaded through `evaluate`; `main_peft` calls `wer_metric.compute(...)` on it.
wer_metric = evaluate.load("wer")


def is_target_text_in_range(ref):
    """Tell whether a reference transcript should be kept for scoring.

    The reference is stripped of surrounding whitespace first. It is rejected
    (False) when the result is empty or equals the placeholder
    "ignore time segment in scoring"; any other text is accepted (True).
    """
    stripped = ref.strip()
    return stripped not in ("", "ignore time segment in scoring")


def get_text(sample):
    """Return the transcript stored in a dataset sample.

    The sample is probed for the columns 'text', 'sentence', 'normalized_text',
    'transcript' and 'transcription', in that order; the value of the first
    column that is present is returned.

    Args:
        sample: Mapping from column name to value.

    Returns:
        The value stored under the first matching transcript column.

    Raises:
        ValueError: If none of the known transcript columns is present. The
            message lists the accepted columns and the columns found.
    """
    text_columns = ("text", "sentence", "normalized_text", "transcript", "transcription")
    for column in text_columns:
        if column in sample:
            return sample[column]

    expected = ", ".join(f"'{column}'" for column in text_columns)
    # str() guards against non-string keys so building the message cannot itself fail.
    available = ", ".join(str(key) for key in sample.keys())
    raise ValueError(
        f"Expected transcript column of either {expected}. Got sample of "
        f"{available}. Ensure a text column name is present in the dataset."
    )


# Shared text normalizer: applied to reference transcripts in `normalise` and to model predictions in `main_peft`.
whisper_norm = BasicTextNormalizer()


def normalise(batch):
    """Attach the normalized transcript to a sample.

    The transcript is looked up with `get_text`, passed through `whisper_norm`,
    and stored under the "norm_text" key. The same (mutated) mapping is returned.
    """
    raw_transcript = get_text(batch)
    batch["norm_text"] = whisper_norm(raw_transcript)
    return batch


def data(dataset):
    """Generate one pipeline input per dataset item.

    Every yielded dict contains the entries of the item's "audio" mapping plus a
    "reference" key holding the item's "norm_text" value.
    """
    for sample in dataset:
        audio_fields = sample["audio"]
        reference = sample["norm_text"]
        yield {**audio_fields, "reference": reference}

def main_peft(model_id,dataset,config,split,device,batch_size,max_eval_samples,streaming,language):
    """Evaluate a PEFT adapter on top of its base Whisper model and print the WER.

    Args:
        model_id: PEFT adapter identifier. Its config names the base model
            (`base_model_name_or_path`) that is loaded in 8-bit.
        dataset: Dataset name passed to `load_dataset`.
        config: Dataset configuration name passed to `load_dataset`.
        split: Split name passed to `load_dataset`.
        device: Accepted for signature parity with `main` but not used here;
            the base model is loaded with `device_map="auto"`.
        batch_size: Batch size used when running the pipeline.
        max_eval_samples: When not None, only the first `max_eval_samples`
            examples of the split are evaluated (useful for debugging).
            None evaluates the whole split.
        streaming: Whether `load_dataset` loads the data in streaming mode.
        language: Language used for the tokenizer, the processor and the
            forced decoder prompt.

    Returns:
        None. The WER is printed as a percentage rounded to two decimals.
    """
    task = "transcribe"
    peft_config = PeftConfig.from_pretrained(model_id)
    base_model_id = peft_config.base_model_name_or_path

    # Load the base model in 8-bit, then wrap it with the adapter weights.
    model = WhisperForConditionalGeneration.from_pretrained(
        base_model_id, load_in_8bit=True, device_map="auto"
    )
    model = PeftModel.from_pretrained(model, model_id)

    tokenizer = WhisperTokenizer.from_pretrained(base_model_id, language=language, task=task)
    processor = WhisperProcessor.from_pretrained(base_model_id, language=language, task=task)
    forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
    pipe = AutomaticSpeechRecognitionPipeline(
        model=model, tokenizer=tokenizer, feature_extractor=processor.feature_extractor
    )

    dataset = load_dataset(
        dataset,
        config,
        split=split,
        streaming=streaming,
        use_auth_token=True,
    )

    # Optionally restrict the evaluation to the first N examples. Streaming datasets
    # expose `take`; map-style datasets are sliced with `select` instead.
    if max_eval_samples is not None:
        if streaming:
            dataset = dataset.take(max_eval_samples)
        else:
            dataset = dataset.select(range(min(max_eval_samples, len(dataset))))

    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    # Drop samples whose normalized reference is empty or the "ignore" placeholder.
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    # run streamed inference
    with torch.cuda.amp.autocast():
        # Iterate over the pipeline output directly: fed a generator, the pipeline
        # yields one result dict per sample, so it must not be subscripted with
        # ["text"] before the loop.
        results = pipe(
            data(dataset),
            batch_size=batch_size,
            generate_kwargs={"forced_decoder_ids": forced_decoder_ids},
            max_new_tokens=255,
        )
        for out in tqdm(results):
            predictions.append(whisper_norm(out["text"]))
            # NOTE(review): "reference" comes back from the pipeline as a sequence
            # (presumably one entry per audio chunk - confirm); the first entry is used.
            references.append(out["reference"][0])

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer)

In [4]:
# Evaluation settings for `openai/whisper-medium`; the next cell passes them positionally to `main`.
model_id = "openai/whisper-medium"
# NOTE(review): the captured run output reports this Common Voice version as deprecated.
dataset = "common_voice"
config = "id"
split = "test"
device = 0
batch_size = 8
max_eval_samples = None
streaming = False
language = "id"


In [5]:
# Run the evaluation with the settings from the previous cell; `main` prints the resulting WER.
main(model_id,dataset,config,split,device,batch_size,max_eval_samples,streaming,language)

            This version of the Common Voice dataset is deprecated.
            You can download the latest one with
            >>> load_dataset("mozilla-foundation/common_voice_11_0", "en")
            
1844it [19:16,  1.59it/s]


WER: 11.96


In [9]:
# Evaluation settings for `cahya/whisper-medium-id`; the next cell passes them positionally to `main`.
# Only `model_id` differs from the earlier settings cell.
model_id = "cahya/whisper-medium-id"
dataset = "common_voice"
config = "id"
split = "test"
device = 0
batch_size = 8
max_eval_samples = None
streaming = False
language = "id"

In [10]:
# Run the evaluation with the settings from the previous cell; `main` prints the resulting WER.
main(model_id,dataset,config,split,device,batch_size,max_eval_samples,streaming,language)

Downloading (…)lve/main/config.json: 0.00B [00:00, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/3.06G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/832 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json: 0.00B [00:00, ?B/s]

Downloading (…)olve/main/merges.txt: 0.00B [00:00, ?B/s]

Downloading (…)main/normalizer.json: 0.00B [00:00, ?B/s]

Downloading (…)in/added_tokens.json: 0.00B [00:00, ?B/s]

Downloading (…)cial_tokens_map.json: 0.00B [00:00, ?B/s]

Downloading (…)rocessor_config.json: 0.00B [00:00, ?B/s]

1844it [1:00:48,  1.98s/it]


WER: 2.88
