Prepare colab environment:

In [None]:
!pip install datasets>=1.18.3
!pip install transformers==4.11.3
!pip install librosa
!pip install jiwer
!pip install evaluate
!pip install rouge_score

In [1]:
import evaluate

In [6]:
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import soundfile as sf
import torch
from jiwer import wer


common_voice = load_dataset('DTU54DL/dmeo', split="test")

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")



Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
def map_to_pred(batch):
    input_values = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt", padding="longest").input_values
    input_values = input_values.to(torch.device('cuda'))
    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription[0].lower()
    batch["sentence"] = batch["sentence"].lower()
    return batch

result = common_voice.map(map_to_pred)




  0%|          | 0/1000 [00:00<?, ?ex/s]

In [15]:
print("WER:", wer(result["sentence"], result["transcription"]))

WER: 0.3634438955539873


In [16]:
bleu = evaluate.load('bleu')
rouge = evaluate.load('rouge')

bleu_res = bleu.compute(predictions=result["sentence"], references=result["transcription"])
rouge_res = rouge.compute(predictions=result["sentence"], references=result["transcription"])

print(f"BLEU: {bleu_res}\nROUGE: {rouge_res}")

INFO:absl:Using default tokenizer.


BLEU: {'bleu': 0.49055386368568643, 'precisions': [0.6657503563429037, 0.5391067785082748, 0.4443876246484275, 0.3630769230769231], 'brevity_penalty': 1.0, 'length_ratio': 1.1533583842179427, 'translation_length': 9822, 'reference_length': 8516}
ROUGE: {'rouge1': 0.7614702528311696, 'rouge2': 0.6446272859730633, 'rougeL': 0.7603843611078379, 'rougeLsum': 0.7599907297070143}
