# Evaluate fine-tuned models: validation and test WER

In [2]:
!pip3 install -r requirements.txt

Collecting datasets==2.14.4 (from -r requirements.txt (line 4))
  Downloading datasets-2.14.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate==0.4.0 (from -r requirements.txt (line 5))
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface_hub==0.16.4 (from -r requirements.txt (line 6))
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting librosa==0.10.0.post2 (from -r requirements.txt (line 7))
  Downloading librosa-0.10.0.post2-py3-none-any.whl (253 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.0/253.0 kB[0m [31m11.1 MB/s[0m eta [36m0:00:0

In [3]:
from huggingface_hub import login
from utils import WRITE_ACCESS_TOKEN

# Authenticate this session with the Hugging Face Hub using the token imported from utils.
# NOTE(review): the logged output reports a write-scoped token, while this notebook appears
# to only download datasets and models — a read-only token may be sufficient; confirm.
login(WRITE_ACCESS_TOKEN)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
import os
import torch
import pandas as pd
from evaluate import load

from load_fleurs_nl import load_fleurs_nl
from load_fleurs_zu import load_fleurs_zu

from datasets import Audio, load_dataset
from utils import SR, remove_special_characters_batch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

# Run inference on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def predict_transcription(audio_sample, model, processor):
    """Transcribe a single audio sample and return the lower-cased prediction.

    Args:
        audio_sample: Mapping with an "audio" entry that provides the waveform under
            "array" and its rate under "sampling_rate".
        model: CTC model called with the processor's output; its `.logits` are decoded.
        processor: Processor used both to build the model inputs and to decode the
            logits. A `Wav2Vec2ProcessorWithLM` decodes the raw logits; any other
            processor decodes the arg-max token ids.

    Returns:
        The first decoded transcription, lower-cased.
    """
    audio = audio_sample["audio"]

    # Build the model inputs and move them to the globally selected device.
    features = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt",
        padding=True,
    ).to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        logits = model(**features).logits

    # The LM-backed processor consumes the logits as a NumPy array; the plain
    # processor consumes greedy (arg-max) token ids.
    if isinstance(processor, Wav2Vec2ProcessorWithLM):
        decoded = processor.batch_decode(logits.cpu().numpy()).text
    else:
        decoded = processor.batch_decode(torch.argmax(logits, dim=-1))

    return decoded[0].lower()

### Basic fine-tuning models

In [None]:
# Word error rate metric: loaded once and reused for every model and split.
wer = load("wer")

# (validation, test) splits keyed by dataset name, so the runs that share a dataset
# do not reload and re-cast it for every model.
eval_splits = {}

for dataset_name, repo_name in [
    ("asr_af", "lucas-meyer/xls-r-asr_af-run1"),
    ("asr_af", "lucas-meyer/xls-r-asr_af-run2"),
    ("asr_af", "lucas-meyer/xls-r-asr_af-run3"),
    ("asr_af", "lucas-meyer/xls-r-asr_af-run4"),
    ("asr_af", "lucas-meyer/xls-r-asr_af-run5"),
    ("asr_af", "lucas-meyer/xls-r-asr_af-run6"),
    ("asr_af", "lucas-meyer/xls-r-asr_af-run7"),
    ("asr_af", "lucas-meyer/xls-r-asr_af-run8"),

    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run1"),
    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run2"),
    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run3"),
    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run4"),
    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run5"),
    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run6"),
    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run7"),
    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run8"),

    # Disabled cross-checks: each pairs an `asr_af` checkpoint with the `asr_xh` data
    # (and vice versa). Their scores are still present in the saved output below.
    # ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run8"),
    # ("asr_af", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run8"),
]:
    # DL data (only the first time this dataset is needed)
    if dataset_name not in eval_splits:
        if dataset_name in ("asr_af", "asr_xh"):
            val_set = load_dataset("lucas-meyer/" + dataset_name, split="validation")
            test_set = load_dataset("lucas-meyer/" + dataset_name, split="test")
            val_set = val_set.cast_column("audio", Audio(sampling_rate=SR))
            test_set = test_set.cast_column("audio", Audio(sampling_rate=SR))
        else:
            dataset_dir = os.path.join("data", "speech_data", dataset_name)
            # NOTE(review): only the directory's existence is checked, so a half-built
            # folder left behind by a failed run is reused as-is.
            if not os.path.exists(dataset_dir):
                if dataset_name == "fleurs_nl":
                    load_entries = load_fleurs_nl
                elif dataset_name == "fleurs_zu":
                    load_entries = load_fleurs_zu
                else:
                    # Fail early instead of writing an empty metadata.csv.
                    raise ValueError(f"No loader available to build dataset {dataset_name!r}")

                # Build the audiofolder: the loader returns the (file_name, transcription)
                # rows (and presumably writes the audio files, given write_audio=True);
                # the rows are saved as the folder's metadata.csv.
                os.makedirs(dataset_dir, exist_ok=True)
                csv_entries = list(load_entries(write_audio=True))
                metadata = pd.DataFrame(csv_entries, columns=['file_name', 'transcription'])
                metadata.to_csv(path_or_buf=os.path.join(dataset_dir, "metadata.csv"), sep=",", index=False)

            # Load the (new or pre-existing) audiofolder
            dataset = load_dataset("audiofolder", data_dir=dataset_dir)

            # Resample audio to SR, rename the text column and clean the transcriptions
            val_set = dataset["validation"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
            test_set = dataset["test"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
            # NOTE(review): the helper's name suggests batched input, but map() is called
            # without batched=True — confirm against utils.
            val_set = val_set.map(remove_special_characters_batch)
            test_set = test_set.map(remove_special_characters_batch)

        eval_splits[dataset_name] = (val_set, test_set)

    val_set, test_set = eval_splits[dataset_name]

    # The reference text is stored under "sentence" for the FLEURS audiofolders
    # (renamed above) and under "transcription" for the hub datasets.
    text_column = "sentence" if "fleurs" in dataset_name else "transcription"

    # DL model
    model_basic = Wav2Vec2ForCTC.from_pretrained(repo_name).to(device)
    processor_basic = Wav2Vec2Processor.from_pretrained(repo_name)

    print(f"Results: {repo_name}", end="\n\n")

    # --------------------------------------------------------------
    # VALIDATION SET SCORE, THEN TEST SET SCORE
    # --------------------------------------------------------------
    for split_name, split in (("Validation", val_set), ("Test", test_set)):
        true_transcriptions = []
        model_predictions = []
        n_samples = len(split)

        # Iterate the split so every sample is fetched exactly once; indexing the
        # dataset again for the reference text would decode the audio a second time.
        for i, sample in enumerate(split):
            pred_basic = predict_transcription(sample, model_basic, processor_basic)
            model_predictions.append(pred_basic)
            true_transcriptions.append(sample[text_column].lower())

            # Print progress
            print(f"\r{i+1}/{n_samples}\t\t", end="")
        print("")

        wer_score_model = wer.compute(predictions=model_predictions, references=true_transcriptions)
        print(f"{split_name} score: {wer_score_model}", end="\n\n")

Results: lucas-meyer/xls-r-asr_af-run1

447/447		


Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Validation score: 0.4245732738735968

476/476		
Test score: 0.4331452565280147



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_af-run2

447/447		
Validation score: 0.42472704905428266

476/476		
Test score: 0.4341949875344443



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_af-run3

447/447		
Validation score: 0.43057050592034446

476/476		
Test score: 0.43655688229891093



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_af-run4

447/447		
Validation score: 0.3789020452099031

476/476		
Test score: 0.3840703319774308



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_af-run5

447/447		
Validation score: 0.37982469629401816

476/476		
Test score: 0.3801338407033198



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_af-run6

447/447		
Validation score: 0.4133476856835307

476/476		
Test score: 0.4248786248523816



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_af-run7

447/447		
Validation score: 0.41826849146547745

476/476		
Test score: 0.42120456632987796



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_af-run8

447/447		
Validation score: 0.36998308473012453

476/476		
Test score: 0.38774439049993437



Downloading readme:   0%|          | 0.00/715 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/238M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/147M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/306M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/303M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/273M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/37.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/108M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/144M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/142M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/194M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating test split:   0%|          | 0/627 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/2506 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/338 [00:00<?, ? examples/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_xh-run1

338/338		
Validation score: 0.6132167152575316

627/627		
Test score: 0.6238851095993954



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_xh-run2

338/338		
Validation score: 0.6047942986718496

627/627		
Test score: 0.6173847316704459



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_xh-run3

338/338		
Validation score: 0.5008098477486232

627/627		
Test score: 0.5052154195011338



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_xh-run4

338/338		
Validation score: 0.626498218334953

627/627		
Test score: 0.6476190476190476



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_xh-run5

338/338		
Validation score: 0.5435698088759313

627/627		
Test score: 0.5493575207860922



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_xh-run6

338/338		
Validation score: 0.5306122448979592

627/627		
Test score: 0.5387755102040817



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_xh-run7

338/338		
Validation score: 0.5513443472627146

627/627		
Test score: 0.5210884353741496



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/xls-r-asr_xh-run8

338/338		
Validation score: 0.5387107223841918

627/627		
Test score: 0.5371126228269085



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run8

338/338		
Validation score: 1.1956592160673793

627/627		
Test score: 1.127891156462585



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run8

447/447		
Validation score: 0.9675534368752883

476/476		
Test score: 0.9813672746358746



### Sequential fine-tuning models

In [3]:
# Word error rate metric: loaded once and reused for every model and split.
wer = load("wer")

# (validation, test) splits keyed by dataset name, so the runs that share a dataset
# do not reload and re-cast it for every model.
eval_splits = {}

for dataset_name, repo_name in [
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run1"),
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run2"),
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run3"),
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run4"),
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run5"),
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run6"),
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run7"),
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run8"),

    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run1"),
    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run2"),
    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run3"),
    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run4"),
    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run5"),
    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run6"),
    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run7"),
    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run8"),
]:
    # DL data (only the first time this dataset is needed)
    if dataset_name not in eval_splits:
        if dataset_name in ("asr_af", "asr_xh"):
            val_set = load_dataset("lucas-meyer/" + dataset_name, split="validation")
            test_set = load_dataset("lucas-meyer/" + dataset_name, split="test")
            val_set = val_set.cast_column("audio", Audio(sampling_rate=SR))
            test_set = test_set.cast_column("audio", Audio(sampling_rate=SR))
        else:
            dataset_dir = os.path.join("data", "speech_data", dataset_name)
            # NOTE(review): only the directory's existence is checked, so a half-built
            # folder left behind by a failed run is reused as-is.
            if not os.path.exists(dataset_dir):
                if dataset_name == "fleurs_nl":
                    load_entries = load_fleurs_nl
                elif dataset_name == "fleurs_zu":
                    load_entries = load_fleurs_zu
                else:
                    # Fail early instead of writing an empty metadata.csv.
                    raise ValueError(f"No loader available to build dataset {dataset_name!r}")

                # Build the audiofolder: the loader returns the (file_name, transcription)
                # rows (and presumably writes the audio files, given write_audio=True);
                # the rows are saved as the folder's metadata.csv.
                os.makedirs(dataset_dir, exist_ok=True)
                csv_entries = list(load_entries(write_audio=True))
                metadata = pd.DataFrame(csv_entries, columns=['file_name', 'transcription'])
                metadata.to_csv(path_or_buf=os.path.join(dataset_dir, "metadata.csv"), sep=",", index=False)

            # Load the (new or pre-existing) audiofolder
            dataset = load_dataset("audiofolder", data_dir=dataset_dir)

            # Resample audio to SR, rename the text column and clean the transcriptions
            val_set = dataset["validation"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
            test_set = dataset["test"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
            # NOTE(review): the helper's name suggests batched input, but map() is called
            # without batched=True — confirm against utils.
            val_set = val_set.map(remove_special_characters_batch)
            test_set = test_set.map(remove_special_characters_batch)

        eval_splits[dataset_name] = (val_set, test_set)

    val_set, test_set = eval_splits[dataset_name]

    # The reference text is stored under "sentence" for the FLEURS audiofolders
    # (renamed above) and under "transcription" for the hub datasets.
    text_column = "sentence" if "fleurs" in dataset_name else "transcription"

    # DL model
    model_basic = Wav2Vec2ForCTC.from_pretrained(repo_name).to(device)
    processor_basic = Wav2Vec2Processor.from_pretrained(repo_name)

    print(f"Results: {repo_name}", end="\n\n")

    # --------------------------------------------------------------
    # VALIDATION SET SCORE, THEN TEST SET SCORE
    # --------------------------------------------------------------
    for split_name, split in (("Validation", val_set), ("Test", test_set)):
        true_transcriptions = []
        model_predictions = []
        n_samples = len(split)

        # Iterate the split so every sample is fetched exactly once; indexing the
        # dataset again for the reference text would decode the audio a second time.
        for i, sample in enumerate(split):
            pred_basic = predict_transcription(sample, model_basic, processor_basic)
            model_predictions.append(pred_basic)
            true_transcriptions.append(sample[text_column].lower())

            # Print progress
            print(f"\r{i+1}/{n_samples}\t\t", end="")
        print("")

        wer_score_model = wer.compute(predictions=model_predictions, references=true_transcriptions)
        print(f"{split_name} score: {wer_score_model}", end="\n\n")

Downloading readme:   0%|          | 0.00/715 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/315M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/364M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/409M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/858M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/286M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating test split:   0%|          | 0/476 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/2723 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/447 [00:00<?, ? examples/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run1

447/447		


Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Validation score: 0.41396278640627404

476/476		
Test score: 0.42514105760398896



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run2

447/447		
Validation score: 0.38751345532831

476/476		
Test score: 0.3865634431177011



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run3

447/447		
Validation score: 0.36706135629709363

476/476		
Test score: 0.37160477627607924



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run4

447/447		
Validation score: 0.4041211748423804

476/476		
Test score: 0.42251673008791496



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run5

447/447		
Validation score: 0.3722897124404121

476/476		
Test score: 0.37777194593885316



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run6

447/447		
Validation score: 0.3516838382285099

476/476		
Test score: 0.37488518567117174



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run7

447/447		
Validation score: 0.3672151314777795

476/476		
Test score: 0.37160477627607924



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run8

447/447		
Validation score: 0.3776718437644164

476/476		
Test score: 0.3827581682193938



Downloading readme:   0%|          | 0.00/715 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/238M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/147M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/306M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/303M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/273M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/37.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/108M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/144M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/142M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/194M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating test split:   0%|          | 0/627 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/2506 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/338 [00:00<?, ? examples/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run1

338/338		
Validation score: 0.640427599611273

627/627		
Test score: 0.6439909297052154



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run2

338/338		
Validation score: 0.5114998380304503

627/627		
Test score: 0.5132275132275133



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run3

338/338		
Validation score: 0.5205701328150307

627/627		
Test score: 0.5318216175359033



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run4

338/338		
Validation score: 0.540006478781989

627/627		
Test score: 0.5517762660619804



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run5

338/338		
Validation score: 0.5578231292517006

627/627		
Test score: 0.5765684051398338



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run6

338/338		
Validation score: 0.5221898283122773

627/627		
Test score: 0.5529856386999245



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run7

338/338		
Validation score: 0.5011337868480725

627/627		
Test score: 0.4988662131519274



Downloading (…)lve/main/config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/328 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Results: lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run8

338/338		
Validation score: 0.5364431486880467

627/627		
Test score: 0.5685563114134543



### Best models (with LMs)

In [None]:
def _score_split(split, dataset_name, model, processor, metric):
    """Transcribe every sample of ``split`` and score the output with ``metric``.

    Args:
        split: Indexable dataset split; each row is passed to
            ``predict_transcription`` and must carry the reference text column.
        dataset_name: Name of the dataset the split belongs to. Selects the
            reference column: "sentence" when the name contains "fleurs",
            otherwise "transcription".
        model: Model handed to ``predict_transcription``.
        processor: Processor handed to ``predict_transcription``.
        metric: Object exposing ``compute(predictions=..., references=...)``.

    Returns:
        Tuple ``(predictions, references, score)`` where ``references`` are the
        lower-cased reference texts and ``score`` is the value returned by
        ``metric.compute``.
    """
    reference_column = "sentence" if "fleurs" in dataset_name else "transcription"
    total = len(split)
    predictions = []
    references = []

    for idx in range(total):
        # Fetch the row once and reuse it for both the prediction and the reference.
        sample = split[idx]
        predictions.append(predict_transcription(sample, model, processor))
        references.append(sample[reference_column].lower())

        # Print progress (carriage return keeps the counter on a single line)
        print(f"\r{idx+1}/{total}\t\t", end="")
    print("")

    score = metric.compute(predictions=predictions, references=references)
    return predictions, references, score


# The metric does not depend on the dataset or model, so load it once up front.
wer = load("wer")

for dataset_name, repo_name in [
    ("asr_af", "lucas-meyer/xls-r-asr_af-run8-with-LM"),
    ("asr_af", "lucas-meyer/seq-xls-r-fleurs_nl-run2-asr_af-run6-with-LM"),
    ("asr_xh", "lucas-meyer/xls-r-asr_xh-run3-with-LM"),
    ("asr_xh", "lucas-meyer/seq-xls-r-fleurs_zu-run3-asr_xh-run7-with-LM"),
]:
    # DL data
    if dataset_name in ("asr_af", "asr_xh"):
        # These datasets are loaded by name with the "lucas-meyer/" prefix.
        val_set = load_dataset("lucas-meyer/" + dataset_name, split="validation")
        test_set = load_dataset("lucas-meyer/" + dataset_name, split="test")
        val_set = val_set.cast_column("audio", Audio(sampling_rate=SR))
        test_set = test_set.cast_column("audio", Audio(sampling_rate=SR))
    else:
        dataset_dir = os.path.join("data", "speech_data", dataset_name)
        if not os.path.exists(dataset_dir):
            os.makedirs(dataset_dir, exist_ok=True)
            # First use: build the audiofolder and its metadata.csv on disk
            csv_entries = []
            if dataset_name == "fleurs_nl":
                csv_entries += load_fleurs_nl(write_audio=True)
            elif dataset_name == "fleurs_zu":
                csv_entries += load_fleurs_zu(write_audio=True)
            metadata = pd.DataFrame(csv_entries, columns=['file_name', 'transcription'])
            metadata.to_csv(path_or_buf=os.path.join(dataset_dir, "metadata.csv"), sep=",", index=False)

        # Load dataset from the audiofolder (just created above, or already on disk)
        dataset = load_dataset("audiofolder", data_dir=dataset_dir)

        # Resample audio to SR and expose the text under "sentence"
        val_set = dataset["validation"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
        test_set = dataset["test"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
        val_set = val_set.map(remove_special_characters_batch)
        test_set = test_set.map(remove_special_characters_batch)

    # DL model
    model_with_LM = Wav2Vec2ForCTC.from_pretrained(repo_name).to(device)
    processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(repo_name)

    print(f"Results: {repo_name}", end="\n\n")

    # --------------------------------------------------------------
    # VALIDATION SET SCORE
    # --------------------------------------------------------------
    model_with_LM_predictions, true_transcriptions, wer_score_model_with_LM = _score_split(
        val_set, dataset_name, model_with_LM, processor_with_LM, wer
    )
    print(f"Validation score: {wer_score_model_with_LM}", end="\n\n")

    # --------------------------------------------------------------
    # TEST SET SCORE
    # --------------------------------------------------------------
    model_with_LM_predictions, true_transcriptions, wer_score_model_with_LM = _score_split(
        test_set, dataset_name, model_with_LM, processor_with_LM, wer
    )
    print(f"Test score: {wer_score_model_with_LM}", end="\n\n")