# Use pretrained models (with and without a language model) for test-set predictions and WER comparison

In [1]:
!pip3 install kenlm
!pip3 install -r requirements.txt



In [2]:
import torch

from utils import SR
from datasets import Audio, load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

# Pick the compute device once; fall back to CPU so the notebook still runs
# on machines without a CUDA GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# DL data: test split, with the audio column resampled to the project rate SR
dataset_name = "lucas-meyer/asr_af"
af_test_set = load_dataset(dataset_name, split="test")
af_test_set = af_test_set.cast_column("audio", Audio(sampling_rate=SR))

# DL model (used without a language model)
repo_name = "lucas-meyer/wav2vec2-xls-r-300m-asr_af"
model_basic = Wav2Vec2ForCTC.from_pretrained(repo_name).to(DEVICE)
processor_basic = Wav2Vec2Processor.from_pretrained(repo_name)

# DL model with LM (its processor decodes logits with a language model)
repo_name = "lucas-meyer/wav2vec2-xls-r-300m-with-LM-asr_af"
model_with_LM = Wav2Vec2ForCTC.from_pretrained(repo_name).to(DEVICE)
processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(repo_name)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
def predict_transcription(audio_sample, model, processor):
    """Transcribe a single dataset sample with a CTC model and its processor.

    Args:
        audio_sample: A dataset row whose ``"audio"`` entry holds
            ``"array"`` and ``"sampling_rate"`` keys.
        model: A CTC model (e.g. ``Wav2Vec2ForCTC``) whose output exposes ``.logits``.
        processor: Either a ``Wav2Vec2ProcessorWithLM`` (logits decoded with its
            language model) or a plain processor (greedy argmax decoding).

    Returns:
        The predicted transcription, lower-cased.
    """
    # Build model inputs and place them on the model's own device rather than
    # a hard-coded "cuda", so CPU-only runs work too.
    inputs = processor(
        audio_sample["audio"]["array"],
        sampling_rate=audio_sample["audio"]["sampling_rate"],
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # Forward pass without building an autograd graph
    with torch.no_grad():
        logits = model(**inputs).logits

    # Decode logits to get the predicted transcription
    if isinstance(processor, Wav2Vec2ProcessorWithLM):
        # Only one sample here: decode() avoids batch_decode() spinning up a
        # fresh multiprocessing pool on every call.
        pred = processor.decode(logits[0].cpu().numpy()).text
    else:
        predicted_ids = torch.argmax(logits, dim=-1)
        pred = processor.batch_decode(predicted_ids)[0]

    return pred.lower()

In [4]:
# Spot-check: compare the reference against both models on the first few test clips.
N_PREVIEW = 50

for i in range(min(N_PREVIEW, len(af_test_set))):
    # Fetch the row once: indexing a dataset with an Audio-cast column decodes
    # the audio on every access.
    sample = af_test_set[i]
    true_transcription = sample["transcription"].lower()
    pred_basic = predict_transcription(sample, model_basic, processor_basic)
    pred_with_LM = predict_transcription(sample, model_with_LM, processor_with_LM)
    print(f"Test {i}:")
    print(f"  - true (.....): {true_transcription}")
    print(f"  - pred (basic): {pred_basic}")
    print(f"  - pred (w/ LM): {pred_with_LM}\n")

Test 0:
  - true (.....): waarom is dit positief negatief of neutraal
  - pred (basic): waarom is dit positief negatief of neeutraaal
  - pred (w/ LM): waarom is dit positief negatief of neutraal

Test 1:
  - true (.....): dus is die woord measure baie naby aan leisure
  - pred (basic): dus is die woord measurer baie naby aan leisure
  - pred (w/ LM): dus is die woord measurer baie naby aan leisure

Test 2:
  - true (.....): die muskiekgroep zef maar zen het meer as 25 miljoen albums wêreldwyd verkoop
  - pred (basic): die musiekgroep zefmaar zen het meer as 25 miljoen albums wêreldwyd verkoop
  - pred (w/ LM): die musiekgroep zefmaar zen het meer as 34 miljoen albums wêreldwyd verkoop

Test 3:
  - true (.....): thus had the raw wilderness prepared him for this day
  - pred (basic): thus hat te row worldenes proepiedhim for this dayy
  - pred (w/ LM): thus ha te rog worldenes proepiedhim for this day

Test 4:
  - true (.....): uiteindelik was daar geen meer sprake van dominees en emira

In [34]:
# Collect lower-cased references and both models' predictions for the full test split.
true_transcriptions = []
model_predictions = []
model_with_LM_predictions = []

n_test = len(af_test_set)
for i in range(n_test):
    # Fetch the row once: indexing a dataset with an Audio-cast column decodes
    # the audio on every access, so repeated indexing triples the work.
    sample = af_test_set[i]

    pred_basic = predict_transcription(sample, model_basic, processor_basic)
    pred_with_LM = predict_transcription(sample, model_with_LM, processor_with_LM)

    model_predictions.append(pred_basic)
    model_with_LM_predictions.append(pred_with_LM)
    true_transcriptions.append(sample["transcription"].lower())

    # Print progress in place (carriage return, no newline)
    print(f"\r{i+1}/{n_test}\t\t", end="")
print("")

585/585		


In [35]:
from evaluate import load

# Predictions and references must be aligned one-to-one for WER to be meaningful.
assert len(model_predictions) == len(model_with_LM_predictions) == len(true_transcriptions)

# Word error rate (lower is better) of each model against the lower-cased references
wer = load("wer")
wer_score_model = wer.compute(predictions=model_predictions, references=true_transcriptions)
wer_score_model_with_LM = wer.compute(predictions=model_with_LM_predictions, references=true_transcriptions)

# Label the scores so the output is unambiguous when skimmed
print(f"WER (basic): {wer_score_model}")
print(f"WER (w/ LM): {wer_score_model_with_LM}")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

0.393003646308113
0.32440747493163175
