# **Fine-tuning XLSR-Wav2Vec2 for Multi-Lingual ASR with ü§ó Transformers**

## Pre-configuration

In [52]:
from ipywidgets import widgets

In [53]:
import os

In [101]:
language_code = 'ga-IE'
language_name = 'irish'
base_model = "facebook/wav2vec2-large-xlsr-53"
pretrain_model = f"jimregan/wav2vec2-large-xlsr-{language_name}-extra4"

data_dir = f"/workspace/data/{language_code}"
output_models_dir = f"/workspace/output_models/{language_code}/wav2vec2-large-xlsr-{language_name}-extra4"

In [102]:
from datasets import load_dataset, load_metric

common_voice_train = load_dataset("common_voice", language_code, split="train+validation")
common_voice_test = load_dataset("common_voice", language_code, split="test")

Reusing dataset common_voice (/workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Reusing dataset common_voice (/workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


In [103]:
common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

OVH crashes when trying to read the files from this dataset; so split, load the audio and save as arrow datasets later

In [6]:
#from datasets import load_dataset
#teanglann = load_dataset('json', data_files='/workspace/data/irish/teanglann.json', split='train')

Using custom data configuration default-92d97cb1c79a07fc
Reusing dataset json (/workspace/.cache/huggingface/datasets/json/default-92d97cb1c79a07fc/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02)


In [26]:
#teanglann[0]

Dataset({
    features: ['path', 'sentence'],
    num_rows: 5672
})

In [104]:
from datasets import Dataset
livingaudio = Dataset.load_from_disk('/workspace/data/irish/ga.ie.cll')

In [105]:
from datasets import load_dataset
fuaimeanna = load_dataset('csv', data_files='/workspace/data/irish/fuaimeanna-text.csv', split='train')

Using custom data configuration default-f1e3c9ea09de94af
Reusing dataset csv (/workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0)


In [14]:
# merge after processing
#import datasets
#train_dataset = datasets.concatenate_datasets([teanglann, common_voice_train])

In [106]:
def is_upper_vowel(letter):
  if letter in ['A', 'E', 'I', 'O', 'U', '√Å', '√â', '√ç', '√ì', '√ö']:
    return True
  else:
    return False

def irish_lower(word):
  if len(word) > 1 and word[0] in ['n', 't'] and is_upper_vowel(word[1]):
    return word[0] + '-' + word[1:].lower()
  else:
    return word.lower()

def irish_lower_sentence(sentence):
  return " ".join([irish_lower(w) for w in sentence.split(" ")])


In [107]:
import re
chars_to_ignore_regex = '[,\?\.\!\;\:\"\‚Äú\%\‚Äò\‚Äù\(\)\*\‚Äì]'

def remove_special_characters(batch):
    tmp = re.sub('‚Äô ', ' ', batch["sentence"])
    tmp = re.sub("‚Äô$", '', tmp)
    tmp = re.sub('‚Äô', '\'', tmp)
    tmp = re.sub(chars_to_ignore_regex, '', tmp)
    batch["sentence"] = irish_lower_sentence(tmp).strip() + ' '
    return batch

In [108]:
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)
#teanglann = teanglann.map(remove_special_characters)

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-42ceb4f596ec5537.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-08ecb622b391bcec.arrow


In [109]:
livingaudio = livingaudio.map(remove_special_characters)

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-e2ba4a62057e8990.arrow


In [29]:
fuaimeanna

Dataset({
    features: ['path', 'sentence'],
    num_rows: 2283
})

In [110]:
fuaimeanna = fuaimeanna.map(remove_special_characters)

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-ce2816e0fccf49d5.arrow


In [111]:
vocab_list = [char for char in "a√°bcde√©fghi√≠jklmno√≥pqrstu√∫vwxyz'- "]
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'a': 0,
 '√°': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 '√©': 6,
 'f': 7,
 'g': 8,
 'h': 9,
 'i': 10,
 '√≠': 11,
 'j': 12,
 'k': 13,
 'l': 14,
 'm': 15,
 'n': 16,
 'o': 17,
 '√≥': 18,
 'p': 19,
 'q': 20,
 'r': 21,
 's': 22,
 't': 23,
 'u': 24,
 '√∫': 25,
 'v': 26,
 'w': 27,
 'x': 28,
 'y': 29,
 'z': 30,
 "'": 31,
 '-': 32,
 ' ': 33}

In [112]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [113]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

36

In [114]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [115]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [116]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

In [117]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [118]:
processor.save_pretrained(output_models_dir)

In [119]:
import torchaudio

def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = speech_array[0].numpy()
    batch["sampling_rate"] = sampling_rate
    batch["target_text"] = batch["sentence"]
    return batch

In [120]:
common_voice_train = common_voice_train.map(speech_file_to_array_fn, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)
#teanglann = teanglann.map(speech_file_to_array_fn, remove_columns=teanglann.column_names)

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-233e362c82c57cbd.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-e107863191228708.arrow


In [21]:
#teanglann.save_to_disk('/workspace/data/irish/teanglann')

In [121]:
from datasets import Dataset
teanglann = Dataset.load_from_disk('/workspace/data/irish/teanglann')

In [122]:
livingaudio = livingaudio.map(speech_file_to_array_fn, remove_columns=livingaudio.column_names)

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-d35be5c1ea300185.arrow


In [123]:
fuaimeanna = fuaimeanna.map(speech_file_to_array_fn, remove_columns=fuaimeanna.column_names)

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-1f6acb1efa2b91b1.arrow


In [124]:
import librosa
import numpy as np

def resample(batch):
    batch["speech"] = librosa.resample(np.asarray(batch["speech"]), batch["sampling_rate"], 16_000)
    batch["sampling_rate"] = 16_000
    return batch

In [125]:
common_voice_train = common_voice_train.map(resample, num_proc=12)
common_voice_test = common_voice_test.map(resample, num_proc=12)
teanglann = teanglann.map(resample, num_proc=12)
livingaudio = livingaudio.map(resample, num_proc=12)
fuaimeanna = fuaimeanna.map(resample, num_proc=12)

  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-8fb992b263a78fa5.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-be0843154634faa7.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-a8bead7dd6b7b182.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-ffa9de8c33e3d0fa.arrow


  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-e9a90ec295c31bd4.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-c056f10aedcb156d.arrow


   

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-6694955fc1f2d7fa.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-a887264ae2ac7c20.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-836d38b0fdd62a4e.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-99cf6b168faf0405.arrow


  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-cc65aa36eaff56ef.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-fa3f5ca243b75383.arrow


  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-bd46dbbe690f0be6.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-5158189223ce0790.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-1763d2540b66435e.arrow


  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-b4dd4bf42b649c4b.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-767b276ea93afe4c.arrow


   

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-cb5bd0475623eddd.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-f22c010ac7bae2d5.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-c68a55d216df753a.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-2baf6d114205a7e2.arrow


  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-db985ac6fd3b8534.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-4903f0bd77b0f7b0.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-282ed708e2df192a.arrow


  

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-54a26c07d22ce0b7.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-5c7e1d3f5c60d720.arrow


  

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-49122c9f7bfbb104.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-8f044570ccf30f38.arrow


 

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-f7b5c96e7318b1e9.arrow


 

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-809d12b20bdc5d29.arrow


  

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-fd4a799c9df5d22b.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-77705fb40d5ea463.arrow


  

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-8ad1ba31570fa11a.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-1b22f6ffb0025fdc.arrow


  

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-f9664a0be0bba037.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-cbfd834446e0dfed.arrow


 

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-51324bc82e65ecd0.arrow


 

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-328f90385995870e.arrow
Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-75f2aea06c5320aa.arrow


  

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-5353cac089b69d45.arrow
Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-db0464c2fa470b40.arrow


  

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-ebca8676e6251b3c.arrow
Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-98e32fd80d22ae85.arrow


  

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-fc09d610a24d8305.arrow


 

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-01837b1a16efd098.arrow


 

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-70b38f50ea616be2.arrow


  

Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-c1457f2f34b2d9ea.arrow
Loading cached processed dataset at /workspace/data/irish/ga.ie.cll/cache-1961b2be63d8742a.arrow


  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-d835da1417c6542e.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-99537fcbc6a88648.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-8f0a769286fb07eb.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-531776cb6155ea4f.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-58d937aa068921de.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-cac96a911e2804ab.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-7169e5cbfcf192f4.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-36a4da5a56cd6d8d.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-375acf58b1e4a398.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-c628886a6210f5a6.arrow


  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-1bc41d8fdb0bc8a5.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-f1e3c9ea09de94af/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-d92a561a9873ed00.arrow


In [80]:
import datasets
train_dataset = datasets.concatenate_datasets([teanglann, common_voice_train, livingaudio, fuaimeanna])

In [42]:
#train_dataset.save_to_disk('/workspace/data/irish/previous_training')

In [126]:
def prepare_dataset(batch):
    # check that all files have the correct sampling rate
    assert (
        len(set(batch["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch

In [127]:
common_voice_train = train_dataset.map(prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=12, batched=True)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=12, batched=True)

 

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-97481a7fbe508a9c.arrow


    

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-47e8ea28e4808b52.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-060d0136b86e227b.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-8af145ee847d623a.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-3ae9ef823cbf103f.arrow


 

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-f7bd2b90587993b3.arrow


  

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-ca5c7a2972fd2e6b.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-b36e0914c52c4b6b.arrow


   

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-5e3974d18d8de48f.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-437f1e64ce16425f.arrow


 

Loading cached processed dataset at /workspace/data/irish/teanglann/cache-6a1d9f42e92c093e.arrow
Loading cached processed dataset at /workspace/data/irish/teanglann/cache-3698e8c1be56ca4a.arrow


         

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-4083777e5b21d2e6.arrow


 

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-a2f6480f0c81b845.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-fed2d46be89a6826.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-fc74f342fa152265.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-7b0b39f453a8802f.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-fd4a640c58d2df21.arrow


  

Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-99f67547988544d3.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-1e8b3f9f7582a68f.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-017d4117dcf54a1a.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-0cf31671a4ee4fb9.arrow
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-95f6501bc73bb1a5.arrow
Loading cached processed datas

In [128]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [129]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [130]:
wer_metric = load_metric("wer")

In [131]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [132]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53", 
    attention_dropout=0.055,
    hidden_dropout=0.047,
    feat_proj_dropout=0.04,
    mask_time_prob=0.082,
    layerdrop=0.041,
    gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [133]:
model.freeze_feature_extractor()

In [134]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=output_models_dir,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=13,
  fp16=True,
  save_steps=400,
  eval_steps=400,
  logging_steps=400,
  learning_rate=2.34e-4,
  warmup_steps=500,
  save_total_limit=20,
)

Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [135]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

In [None]:
trainer.train()



Step,Training Loss,Validation Loss,Wer,Runtime,Samples Per Second
400,6.3177,3.035551,0.949956,48.9714,10.333
800,1.2729,1.585447,0.914545,48.433,10.447
1200,inf,1.399172,1.048581,48.5188,10.429
1600,,1.242888,0.853088,47.8496,10.575
2000,,1.191748,0.829968,47.9145,10.56
2400,,1.135158,0.813872,47.6923,10.61
2800,,1.134999,0.819432,48.9846,10.33
3200,,1.044153,0.821481,48.5205,10.429
3600,,1.077358,0.807141,48.962,10.335
4000,,0.994849,0.77729,48.8829,10.351


In [41]:
trainer.save_model(output_models_dir)
tokenizer.save_pretrained(output_models_dir)

('/workspace/output_models/ga-IE/wav2vec2-large-xlsr-irish-base/tokenizer_config.json',
 '/workspace/output_models/ga-IE/wav2vec2-large-xlsr-irish-base/special_tokens_map.json',
 '/workspace/output_models/ga-IE/wav2vec2-large-xlsr-irish-base/vocab.json',
 '/workspace/output_models/ga-IE/wav2vec2-large-xlsr-irish-base/added_tokens.json')

In [None]:
trainer.save_model('/workspace/output_models/newest-run')
tokenizer.save_pretrained('/workspace/output_models/newest-run')

In [42]:
model2 = Wav2Vec2ForCTC.from_pretrained(output_models_dir).to("cuda")
processor2 = Wav2Vec2Processor.from_pretrained(output_models_dir)

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


Now, we will just take the first example of the test set, run it through the model and take the `argmax(...)` of the logits to retrieve the predicted token ids.

In [43]:
input_dict = processor2(common_voice_test["input_values"][0], return_tensors="pt", padding=True)

logits = model2(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

It is strongly recommended to pass the ``sampling_rate`` argument to this function.Failing to do so can result in silent errors that might be hard to debug.


We adapted `common_voice_test` quite a bit so that the dataset instance does not contain the original sentence label anymore. Thus, we re-use the original dataset to get the label of the first example.

In [44]:
common_voice_test_transcription = load_dataset("common_voice", language_code, data_dir=data_dir, split="test")

Using custom data configuration ga-IE-d1da170b20bac7b9
Reusing dataset common_voice (/workspace/.cache/huggingface/datasets/common_voice/ga-IE-d1da170b20bac7b9/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


Finally, we can decode the example.

In [45]:
print("Prediction:")
print(processor2.decode(pred_ids))

print("\nReference:")
print(common_voice_test_transcription["sentence"][0].lower())


Prediction:
bh√≠ tara viscair√≠ agus meadh bh o'rourc seaint i n√°rd√≠n na mbl√°

Reference:
"bh√≠ tara viscardi agus meadhbh o'rourke ag seinnt i ngaird√≠n na mbl√°th"


Alright! The transcription can definitely be recognized from our prediction, but it is far from being perfect. Training the model a bit longer, spending more time on the data preprocessing, and especially using a language model for decoding would certainly improve the model's overall performance. 

For a demonstration model on a low-resource language, the results are acceptable, however ü§ó.

In [46]:
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re
test_dataset = load_dataset("common_voice", "ga-IE", split="test")
wer = load_metric("wer")
def is_upper_vowel(letter):
    if letter in ['A', 'E', 'I', 'O', 'U', '√Å', '√â', '√ç', '√ì', '√ö']:
        return True
    else:
        return False
def irish_lower(word):
    if len(word) > 1 and word[0] in ['n', 't'] and is_upper_vowel(word[1]):
        return word[0] + '-' + word[1:].lower()
    else:
        return word.lower()
def irish_lower_sentence(sentence):
    return " ".join([irish_lower(w) for w in sentence.split(" ")])
chars_to_ignore_regex = '[,\?\.\!\;\:\"\‚Äú\%\‚Äò\‚Äù\(\)\*]'
def remove_special_characters(sentence):
  tmp = re.sub('‚Äô ', ' ', sentence)
  tmp = re.sub("‚Äô$", '', tmp)
  tmp = re.sub('‚Äô', '\'', tmp)
  tmp = re.sub(chars_to_ignore_regex, '', tmp)
  sentence = irish_lower_sentence(tmp) + ' '
  return sentence
resampler = torchaudio.transforms.Resample(48_000, 16_000)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
    batch["sentence"] = remove_special_characters(batch["sentence"])
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)

# Preprocessing the datasets.
# We need to read the aduio files as arrays
def evaluate(batch):
    inputs = processor2(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model2(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch

result = test_dataset.map(evaluate, batched=True, batch_size=8)

print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))

Reusing dataset common_voice (/workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-97fff0ec51dc91d3.arrow


HBox(children=(FloatProgress(value=0.0, max=64.0), HTML(value='')))


WER: 43.680515
