This notebook is inspired by an excellent [HuggingFace Tutorial](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb#scrollTo=bTjNp2KUYAl8).

# pip install

In [1]:
!pip install phonemizer
!apt-get install espeak
!pip install git+https://github.com/huggingface/datasets.git
!pip install git+https://github.com/huggingface/transformers.git
!pip install torchaudio
!pip install librosa
!pip install jiwer
!pip install soundfile


# Import

In [None]:
!nvidia-smi -L

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [29]:
# Import libraries
from datasets import load_dataset, load_metric, ClassLabel, load_from_disk
import datasets
datasets.set_caching_enabled(False)

import torch

from dataclasses import dataclass, field

from typing import Any, Dict, List, Optional, Union

from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import AdamW, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
from transformers import trainer_pt_utils

from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

import random
import math
import pandas as pd
import numpy as np
import sys

from IPython.display import display, HTML

import re
import json
import os
from tqdm.notebook import tqdm

sys.path.append('/content/drive/MyDrive/speech_w2v/')
from utils import *
from trainer import Trainer

# Visualisation Function

In [2]:
# Visualisation
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

# Load & Preprocess Dataset
**If you haven't preprocessed the dataset yet, continue with this section; otherwise, skip it.**

## Download/Load

First, choose a language (see https://huggingface.co/datasets/common_voice for the codes of the other available languages).

In [3]:
code_lang = "cs" # Change this to use another Common Voice language

For this experiment, we chose Czech (`cs`). Let's download the dataset.

In [4]:
# Use code_lang so that changing the language above actually takes effect
common_voice = load_dataset("common_voice", code_lang, data_dir="./cv-corpus-6.1-2020-12-11", split="train+validation")
common_voice_test = load_dataset("common_voice", code_lang, data_dir="./cv-corpus-6.1-2020-12-11", split="test")





Downloading and preparing dataset common_voice/cs (download: 1.18 GiB, generated: 9.06 MiB, post-processed: Unknown size, total: 1.19 GiB) to /root/.cache/huggingface/datasets/common_voice/cs-ad9f7b76efa9f3a0/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f...



Dataset common_voice downloaded and prepared to /root/.cache/huggingface/datasets/common_voice/cs-ad9f7b76efa9f3a0/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f. Subsequent calls will reuse this data.




If you are going to use only audio & transcription, you can remove the other columns.

In [14]:
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

## Preprocess
We now clean the transcriptions. Let's first look at a few random examples:

In [7]:
show_random_elements(common_voice.remove_columns(['path']), num_examples=10)

|   | sentence |
|---|---|
| 0 | Velká Sestra tě sleduje! |
| 1 | Další následovala v následujících letech. |
| 2 | Tu o ztraceném kůzlátku. |
| 3 | Domén s diakritikou se v Česku jen tak nedočkáme. |
| 4 | Navštěvoval střední školu v Manchesteru a později studoval na univerzitě v Manchesteru a Cambridge. |
| 5 | Není žádná čistá filosofie existence. |
| 6 | Potřebujeme skutečně společnou směrnici o pracovní době? |
| 7 | Kromě obličejů jsme zatím dělali testy jen na již segmentovaných číslech. |
| 8 | Svatý Stolec je reprezentován apoštolským nunciem. |
| 9 | Pro poskytnutí lze využít různé pomůcky a metody. |


The next steps depend on what we want to evaluate: the phoneme transcription or the word transcription. Run only one of the two subsections below (the rest of this notebook uses the phoneme transcription).

### Word Transcription

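We preprocess the text by removing special symbols such as `,.?!;`. We don't use a language model on top of the acoustic model, and these characters don't correspond to a characteristic sound, so they are hard to predict from speech alone. As the output below shows, the text is also lowercased.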
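`Preprocess.remove_special_characters` comes from the author's `utils.py`, which isn't shown here. Judging from the outputs below (the `sentence` column becomes a lowercase `text` column without punctuation), it likely does something close to the following sketch, modeled on the HuggingFace tutorial. The function name and the exact regex are assumptions.

```python
import re

# Hypothetical regex: punctuation that has no characteristic sound of its own
CHARS_TO_IGNORE = r'[,?.!\-;:"“%‘”�—…–«»]'


def remove_special_characters(batch: dict) -> dict:
    """Strip punctuation and lowercase the transcription (sketch of utils.Preprocess)."""
    batch["text"] = re.sub(CHARS_TO_IGNORE, "", batch["sentence"]).lower()
    return batch


print(remove_special_characters({"sentence": "Velká Sestra tě sleduje!"})["text"])
# velká sestra tě sleduje
```

Used with `Dataset.map(..., remove_columns=["sentence"])`, as in the next cell, it leaves only the cleaned `text` column next to `path`.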

In [11]:
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\—\…\–\«\»]' # Raw string avoids invalid-escape warnings; add or remove characters as needed
preprocess = Preprocess(chars_to_ignore_regex)
common_voice = common_voice.map(preprocess.remove_special_characters, remove_columns=["sentence"])
common_voice_test = common_voice_test.map(preprocess.remove_special_characters, remove_columns=["sentence"])





In [12]:
show_random_elements(common_voice.remove_columns(['path']), num_examples=10)

|   | text |
|---|---|
| 0 | je to film stejného režiséra mladého tomáše pavlíčka |
| 1 | a máte pocit že tam či onam někdo nepatří |
| 2 | barvy byly neuvěřitelné |
| 3 | kůzlátko je samo doma |
| 4 | je aktivní převážně v noci a živí se především rybami |
| 5 | tříaktová hra o dobytí cizího sklepa začíná |
| 6 | po odsunu německých obyvatel kaple postupně chátrala |
| 7 | i srdce je větší |
| 8 | z tohoto hlediska nemám žádné pochybnosti |
| 9 | jsem velmi potěšena že evropský parlament činí taková konstruktivní rozhodnutí |


### Phoneme Transcription

For phoneme transcription, we first need to convert the text into phonemes.

In [15]:
txt2phon = Text2Phoneme(language='cs')
common_voice = common_voice.map(txt2phon.text2phoneme, num_proc=4)
common_voice_test = common_voice_test.map(txt2phon.text2phoneme, num_proc=4)
common_voice = common_voice.rename_column("sentence", "text")
common_voice_test = common_voice_test.rename_column("sentence", "text")

    








In [16]:
show_random_elements(common_voice.remove_columns(['path']), num_examples=10)

|   | text |
|---|---|
| 0 | jednaː se opostavu ze steɪnojmeneːho komiksu |
| 1 | botanitski neɲiː oblast pr̝̊iːliʒ zajiːmavaː |
| 2 | pr̝̊ipoj to na steɪnosmɲerniː adapteːr |
| 3 | podle statistik se zdravotɲiː situatse v evropje shorʃuje |
| 4 | tr̝̊inaːtst ameritʃanuː jeden kanaɟan a jeden aʊstralan |
| 5 | fstal tedi a poloʒil ruku nasr̩ttse |
| 6 | ɲikdo s kompetentɲiːx liɟiː naːm tuto teoriji neportvr̩ɟil |
| 7 | pak se uːzemiː tʃeskoslovenska rozɟelilo nakraje a tʃexi jako spraːvɲiː jednotka zaɲikli |
| 8 | klasifikaːtor se zvjetʃuje |
| 9 | eksistujiː ruːzneː naːmɲeti kontse teːto hri a to i v kɲiʒɲiːm vidaːɲiː |


## Building Vocabulary

Since the model ends with a CTC layer, it classifies each chunk of speech into a single character (here, a phoneme symbol). We therefore extract all distinct characters from the transcriptions and build our vocabulary from them.
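`extract_all_chars` is also imported from `utils.py`, which isn't shown. A minimal sketch of such a helper, assumed to follow the HuggingFace tutorial: with `batched=True, batch_size=-1` the whole `text` column arrives as one batch, so the function can join every transcription and return the set of characters as a single output row.

```python
def extract_all_chars(batch: dict) -> dict:
    """Collect every distinct character of the `text` column (sketch)."""
    all_text = " ".join(batch["text"])
    vocab = sorted(set(all_text))
    # Return exactly one row holding the whole vocabulary (hence the one-element lists)
    return {"vocab": [vocab], "all_text": [all_text]}


out = extract_all_chars({"text": ["kůzlátko je samo doma", "i srdce je větší"]})
print(out["vocab"][0])
```

Because the output has fewer rows than the input, `Dataset.map` needs `remove_columns=...column_names`, exactly as in the cell below.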

In [17]:
vocab_train = common_voice.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)





Now we take the union of the distinct characters from both datasets and assign each one an integer index, just as we would build the output vocabulary for a translation or generation task.

In [18]:
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{' ': 34,
 'a': 5,
 'b': 18,
 'c': 14,
 'd': 20,
 'e': 16,
 'f': 30,
 'h': 27,
 'i': 37,
 'j': 17,
 'k': 11,
 'l': 1,
 'm': 8,
 'n': 13,
 'o': 26,
 'p': 6,
 'r': 12,
 's': 4,
 't': 15,
 'u': 35,
 'v': 21,
 'x': 0,
 'z': 19,
 'ŋ': 32,
 'ɟ': 29,
 'ɡ': 22,
 'ɣ': 7,
 'ɪ': 2,
 'ɲ': 33,
 'ɹ': 36,
 'ʃ': 10,
 'ʊ': 3,
 'ʑ': 28,
 'ʒ': 9,
 'ː': 23,
 '̊': 31,
 '̝': 24,
 '̩': 25}
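Because `vocab_list` is built from a Python `set`, its order (and therefore each character's index) can change between sessions: string hashing is randomized per process unless `PYTHONHASHSEED` is fixed. The JSON file saved below pins the mapping for this run, but if you rebuild the vocabulary later, sorting it first gives the same indices every time. A small sketch with illustrative inputs (`build_vocab` is a hypothetical helper, not part of `utils.py`):

```python
def build_vocab(*char_lists):
    """Union of the characters of several splits, mapped to stable integer ids."""
    chars = sorted(set().union(*map(set, char_lists)))
    return {c: i for i, c in enumerate(chars)}


vocab = build_vocab(["a", "b", " "], ["b", "ʃ"])
print(vocab)  # {' ': 0, 'a': 1, 'b': 2, 'ʃ': 3}
```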

In [19]:
# "|" replaces the space as the word delimiter; [UNK] is the unknown token and
# [PAD] is the padding token, which Wav2Vec2ForCTC also uses as the CTC blank
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
print(f"The final layer will have an output dimension of {len(vocab_dict)}")

The final layer will have an output dimension of 40
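During training, `Wav2Vec2ForCTC` passes its `pad_token_id` to the CTC loss as the blank label, while `|` marks word boundaries. At inference time, greedy CTC decoding takes the most likely token for each audio frame, collapses consecutive repeats, removes blanks and maps `|` back to a space. The tokenizer's `decode`/`batch_decode` already does this; the toy sketch below (hypothetical function, made-up four-token vocabulary) just makes the steps explicit:

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, id_to_token, blank="[PAD]", word_delimiter="|"):
    """Greedy CTC decoding of per-frame argmax ids (toy sketch)."""
    # 1. Collapse runs of the same id: a a a -> a
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2. Drop the blank, which is what separates genuine repeats (a [PAD] a -> a a)
    tokens = [id_to_token[i] for i in collapsed if id_to_token[i] != blank]
    # 3. Turn word delimiters back into spaces
    return "".join(" " if t == word_delimiter else t for t in tokens).strip()


toy_vocab = {"a": 0, "b": 1, "|": 2, "[PAD]": 3}
id_to_token = {i: t for t, i in toy_vocab.items()}
print(ctc_greedy_decode([0, 0, 3, 0, 1, 1, 2, 2, 1], id_to_token))  # aab b
```

This is also why the blank must be a dedicated entry in the vocabulary: without it, the model could not output the same phoneme twice in a row.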


In [20]:
# Now let's save our dictionary
parent_dir = "/content/drive/MyDrive/speech_w2v/" # Directory where the vocabulary (and later the datasets and model) will be saved
with open(os.path.join(parent_dir, 'czeck_phonem_vocab.json'), 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

## Audio Preprocessing

Now we load each audio file and store it in the dataset as a NumPy array.

In [21]:
audio_preprocess = Preprocess(is_torch=True, source_sampling=48_000, target_sampling=16_000)
common_voice = common_voice.map(audio_preprocess.speech_file_to_array_fn, remove_columns=common_voice.column_names)
common_voice_test = common_voice_test.map(audio_preprocess.speech_file_to_array_fn, remove_columns=common_voice_test.column_names)





### Resample
**If your dataset is already sampled at 16 kHz, skip this step.**

XLSR-Wav2Vec2 was pretrained on audio from Babel, Multilingual LibriSpeech (MLS) and Common Voice, and the English-only Wav2Vec2 checkpoints on LibriSpeech / Libri-Light. These pretraining data were used at 16 kHz, so the models expect 16 kHz input. Common Voice clips are sampled at 48 kHz, so we downsample our fine-tuning data to 16 kHz.
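`Preprocess.resample` is another helper from `utils.py` (with `is_torch=True` it presumably relies on torchaudio, but that is an assumption). As a stand-in, here is a sketch using SciPy's polyphase resampler `scipy.signal.resample_poly`, which low-pass filters before decimating so that content above the new 8 kHz Nyquist frequency doesn't alias. The `speech`/`sampling_rate` column names match the ones read later in this notebook; `resample_example` itself is hypothetical.

```python
from math import gcd

import numpy as np
from scipy.signal import resample_poly


def resample_example(batch, source_sr=48_000, target_sr=16_000):
    """Resample one example's `speech` array with an anti-aliasing polyphase filter."""
    g = gcd(source_sr, target_sr)
    speech = np.asarray(batch["speech"], dtype=np.float32)
    # For 48 kHz -> 16 kHz this is up=1, down=3
    batch["speech"] = resample_poly(speech, target_sr // g, source_sr // g)
    batch["sampling_rate"] = target_sr
    return batch


t = np.arange(48_000) / 48_000  # one second of audio at 48 kHz
example = resample_example({"speech": np.sin(2 * np.pi * 440 * t), "sampling_rate": 48_000})
print(len(example["speech"]), example["sampling_rate"])  # 16000 16000
```

Simply keeping every third sample would also produce 16 kHz audio, but without the low-pass filter any energy between 8 and 24 kHz would fold back into the audible band.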

In [22]:
# Downsample the audio from 48 kHz to 16 kHz
common_voice = common_voice.map(audio_preprocess.resample, num_proc=4)
common_voice_test = common_voice_test.map(audio_preprocess.resample, num_proc=4)

    








# Train Dev Test
Now we split our data into three subsets. Common Voice already provides official train, validation and test splits, but we merged train and validation when loading the data (`split="train+validation"`), so we re-split them below; you can adapt this step to build your own split. \
I don't modify the test set provided by Common Voice: keeping the same test set ensures a fair comparison with other results.

In [23]:
# Split into train/dev
np.random.seed(42)
data = common_voice.train_test_split(test_size=0.2, seed=42)
common_voice_train, common_voice_validation = data['train'], data['test']

If you want to experiment (as I did) with how these pretrained models perform with little labeled data, you can split the training set into smaller subsets (for instance 10 minutes, 1 hour and 10 hours of audio).

In [None]:
# Now let's shuffle data
common_voice_train = common_voice_train.shuffle(seed=42)

In [None]:
total_len_seconds = 0
indices_10mn = []
indices_1h = []
indices_10h = []
for i in tqdm(range(len(common_voice_train))):
    example = common_voice_train[i]  # fetch the row once
    duration_audio = len(example["speech"]) / example["sampling_rate"]  # clip duration in seconds
    if total_len_seconds <= 600:  # 600 s => 10 minutes
        indices_10mn.append(i)
    if total_len_seconds <= 3600:  # 3600 s => 60 minutes => 1 hour
        indices_1h.append(i)
    if total_len_seconds <= 36000:  # 36000 s => 600 minutes => 10 hours
        indices_10h.append(i)
    if total_len_seconds > 36000:
        break
    total_len_seconds += duration_audio
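The loop above keeps clip `i` whenever the clips before it total at most the budget, so each subset is a prefix of the shuffled training set: the 10 min subset is contained in the 1 h subset, which is contained in the 10 h subset, and each subset may overshoot its budget by the duration of its last clip. Given a list of clip durations, the same prefix can be computed from cumulative sums; a sketch with a hypothetical helper name:

```python
import bisect
from itertools import accumulate


def prefix_within_budget(durations, budget_s):
    """Indices kept by the loop above for one budget: clip i is kept when the
    clips before it total at most `budget_s` seconds."""
    # starts[i] = total duration of clips 0..i-1 (non-decreasing, durations >= 0)
    starts = [0.0, *accumulate(durations)]
    # Count the clips (among the first len(durations) starts) that begin within the budget
    n_kept = bisect.bisect_right(starts, budget_s, 0, len(durations))
    return list(range(n_kept))


print(prefix_within_budget([4.0, 4.0, 4.0, 4.0], 8))  # [0, 1, 2]
```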

In [None]:
common_voice_train_10mn = common_voice_train.select(indices_10mn)
common_voice_train_1h = common_voice_train.select(indices_1h)
common_voice_train_10h = common_voice_train.select(indices_10h)

# Final Preparation of the Data

In [24]:
# Loading tokenizer
tokenizer = Wav2Vec2CTCTokenizer(os.path.join(parent_dir, 'czeck_phonem_vocab.json'), unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
# Load Feature Extractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
# Wrap the feature_extractor and the tokenizer into one class (thanks so much HuggingFace)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
# Initialize our data_collator
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
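`DataCollatorCTCWithPadding` is defined in `utils.py` too. In the HuggingFace tutorial this notebook builds on, the collator pads the audio inputs and the labels separately (they have very different lengths) and replaces label padding with -100; `Wav2Vec2ForCTC` only treats label ids >= 0 as targets, so that padding is ignored by the CTC loss. A NumPy-only sketch of the idea (`pad_ctc_batch` is hypothetical; the real collator goes through `processor.pad` and returns PyTorch tensors):

```python
import numpy as np


def pad_ctc_batch(input_values, label_ids, padding_value=0.0, label_pad_id=-100):
    """Pad a list of audio arrays and a list of label-id lists into rectangular arrays."""
    batch_size = len(input_values)
    max_audio = max(len(x) for x in input_values)
    max_label = max(len(y) for y in label_ids)
    inputs = np.full((batch_size, max_audio), padding_value, dtype=np.float32)
    attention_mask = np.zeros((batch_size, max_audio), dtype=np.int64)
    labels = np.full((batch_size, max_label), label_pad_id, dtype=np.int64)
    for i, (audio, ids) in enumerate(zip(input_values, label_ids)):
        inputs[i, : len(audio)] = audio
        attention_mask[i, : len(audio)] = 1  # 1 = real sample, 0 = padding
        labels[i, : len(ids)] = ids  # the remaining -100 entries are ignored by the loss
    return {"input_values": inputs, "attention_mask": attention_mask, "labels": labels}


batch = pad_ctc_batch([np.ones(5), np.ones(3)], [[7, 8, 9], [4]])
print(batch["labels"])
```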

In [25]:
PrepData = PrepareDataset(processor)
common_voice_train_10mn = common_voice_train_10mn.map(PrepData.prepare_dataset, remove_columns=common_voice_train_10mn.column_names, batch_size=8, num_proc=4, batched=True)
common_voice_train_1h = common_voice_train_1h.map(PrepData.prepare_dataset, remove_columns=common_voice_train_1h.column_names, batch_size=8, num_proc=4, batched=True)
common_voice_train_10h = common_voice_train_10h.map(PrepData.prepare_dataset, remove_columns=common_voice_train_10h.column_names, batch_size=8, num_proc=4, batched=True)
common_voice_train = common_voice_train.map(PrepData.prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=4, batched=True)
common_voice_validation = common_voice_validation.map(PrepData.prepare_dataset, remove_columns=common_voice_validation.column_names, batch_size=8, num_proc=4, batched=True)
common_voice_test = common_voice_test.map(PrepData.prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=4, batched=True)

    

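`PrepareDataset.prepare_dataset` (from `utils.py`) presumably runs `processor` on each batch to produce `input_values` and `labels`. Because the feature extractor was created with `do_normalize=True`, each utterance is normalized to zero mean and unit variance before it enters the model. A NumPy sketch of that per-utterance normalization; the helper name is hypothetical, and the small epsilon mirrors, to the best of my knowledge, the one used in `transformers`:

```python
import numpy as np


def zero_mean_unit_var(speech, eps=1e-7):
    """Normalize one utterance to zero mean and unit variance."""
    speech = np.asarray(speech, dtype=np.float32)
    return (speech - speech.mean()) / np.sqrt(speech.var() + eps)


rng = np.random.default_rng(0)
normalized = zero_mean_unit_var(0.3 * rng.standard_normal(16_000) + 0.1)
print(round(float(normalized.mean()), 4), round(float(normalized.std()), 4))
```

Normalizing each clip separately removes differences in recording level between speakers and microphones, which are common in a crowd-sourced corpus like Common Voice.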








We can save the data to disk, but **you will need plenty of free storage, as the files are huge**.

In [26]:
common_voice_train_10mn.save_to_disk(os.path.join(parent_dir, 'train_czeck_phonem_10mn.files'))
common_voice_train_1h.save_to_disk(os.path.join(parent_dir, 'train_czeck_phonem_1h.files'))
common_voice_train_10h.save_to_disk(os.path.join(parent_dir, 'train_czeck_phonem_10h.files'))
common_voice_train.save_to_disk(os.path.join(parent_dir, 'train_czeck_phonem.files'))
common_voice_validation.save_to_disk(os.path.join(parent_dir, 'validation_czeck_phonem.files'))
common_voice_test.save_to_disk(os.path.join(parent_dir, 'test_czeck_phonem.files'))

# Training

## Loading
Skip these cells if the datasets, tokenizer, processor and data collator are already loaded (for example, because you just ran the preprocessing above).

In [None]:
parent_dir = "/content/drive/MyDrive/speech_w2v/" # Your path

In [None]:
common_voice_train_10mn = load_from_disk(os.path.join(parent_dir, 'train_czeck_phonem_10mn.files'))
common_voice_train_1h = load_from_disk(os.path.join(parent_dir, 'train_czeck_phonem_1h.files'))
common_voice_train_10h = load_from_disk(os.path.join(parent_dir, 'train_czeck_phonem_10h.files'))
common_voice_train = load_from_disk(os.path.join(parent_dir, 'train_czeck_phonem.files'))
common_voice_validation = load_from_disk(os.path.join(parent_dir, 'validation_czeck_phonem.files'))
common_voice_test = load_from_disk(os.path.join(parent_dir, 'test_czeck_phonem.files'))

In [None]:
# Loading tokenizer
tokenizer = Wav2Vec2CTCTokenizer(os.path.join(parent_dir, 'czeck_phonem_vocab.json'), unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
# Load Feature Extractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
# Wrap the feature_extractor and the tokenizer into one class (thanks so much HuggingFace)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
# Prepare our data collator
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

## Run Training

The first component of XLSR-Wav2Vec2 is a stack of CNN layers that extracts acoustically meaningful but contextually independent features from the raw speech signal. This part of the model was already sufficiently trained during pretraining and, as stated in the [paper](https://arxiv.org/pdf/2006.13979.pdf), does not need to be fine-tuned.
Thus, we set `requires_grad` to `False` for all parameters of the *feature extraction* part (`model.freeze_feature_extractor()` in the training cell).
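Freezing simply switches off `requires_grad`, so no gradients are computed for those weights and the optimizer skips them (in later `transformers` versions the method is called `freeze_feature_encoder()`). A toy PyTorch sketch; `ToyWav2Vec2` is a hypothetical miniature (a small CNN encoder plus a linear CTC head), not the real architecture:

```python
import torch
from torch import nn


class ToyWav2Vec2(nn.Module):
    """Hypothetical stand-in: a CNN feature encoder followed by a linear head."""

    def __init__(self, vocab_size=40):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(1, 8, kernel_size=10, stride=5), nn.GELU(),
            nn.Conv1d(8, 8, kernel_size=3, stride=2), nn.GELU(),
        )
        self.lm_head = nn.Linear(8, vocab_size)

    def forward(self, x):  # x: (batch, samples)
        feats = self.feature_extractor(x.unsqueeze(1)).transpose(1, 2)
        return self.lm_head(feats)  # (batch, frames, vocab_size)


def freeze(module: nn.Module) -> None:
    """Exclude every parameter of a sub-module from gradient updates."""
    for param in module.parameters():
        param.requires_grad = False


model = ToyWav2Vec2()
freeze(model.feature_extractor)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logits = model(torch.randn(2, 200))
print(trainable, tuple(logits.shape))  # 360 (2, 19, 40): only lm_head (8 * 40 + 40) trains
```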

Since this notebook is only a template, I train on the first 100 samples of `common_voice_train`. For your own experiments, replace `common_voice_train_example` in the training cell with `common_voice_train` (or one of the 10 min, 1 h or 10 h subsets).

In [27]:
common_voice_train_example = common_voice_train.select(range(100))

In [33]:
# Load model
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53", 
    # "facebook/wav2vec2-base-960h",
    # "facebook/wav2vec2-large-lv60",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

# Freeze the feature extractor
model.freeze_feature_extractor()

# Set to GPU
model.cuda()

# Get sampler
model_input_name = processor.feature_extractor.model_input_names[0]
sampler_train = trainer_pt_utils.LengthGroupedSampler(common_voice_train_example, batch_size=12, model_input_name=model_input_name)
sampler_val = trainer_pt_utils.LengthGroupedSampler(common_voice_validation, batch_size=12, model_input_name=model_input_name)

# Get Loader
train_loader = DataLoader(common_voice_train_example, batch_size=12, sampler=sampler_train, collate_fn=data_collator, num_workers=4)
valid_loader = DataLoader(common_voice_validation, batch_size=12, sampler=sampler_val, collate_fn=data_collator, num_workers=4)

# Training hyperparameters
learning_rate = 4e-4
n_epochs = 2 

num_update_steps_per_epoch = len(train_loader)
max_steps = math.ceil(n_epochs * num_update_steps_per_epoch)
validation_freq = int(1*num_update_steps_per_epoch)
print_freq = int(1*num_update_steps_per_epoch)
scheduler_on_plateau_freq = int(num_update_steps_per_epoch)

# Optimizer: weight decay may apply to all weights except biases and LayerNorm parameters
weight_decay = 0.0  # set > 0 to enable weight decay (0.0 reproduces the run below)
decay_parameters = trainer_pt_utils.get_parameter_names(model, [torch.nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

# Scheduler
num_warmup_steps = int(0.5 * num_update_steps_per_epoch) # Number of steps to ramp the learning rate up from 0.0 to learning_rate
#warmup_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, max_steps)
warmup_scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, max_steps, lr_end=1e-7)
reduce_lr_plateau = None
## reduce_lr_plateau = ReduceLROnPlateau(optimizer, factor=0.6, patience=7) ## To define when warmup scheduler is finished


trainer = Trainer(model=model, processor=processor, 
                  optimizer=optimizer, warmup_scheduler=warmup_scheduler,
                  validation_freq=validation_freq, log_freq=print_freq,
                  num_warmup_steps=num_warmup_steps, is_plateau_scheduler=False, 
                  type_score='PER')

trainer.train(train_loader, valid_loader, n_epochs, path=os.path.join(parent_dir, "model.pt"))

 Epoch  |  Batch  |  Train Loss  | Val Metric |  Elapsed 
----------------------------------------------------------------------
   1    |    9    |   4.373020   |     -     |   15.14  
   1    |    9    |   4.373020   |   0.99    |   85.59  
Hooray! New Best Validation Score, Saving model.
   2    |    9    |   3.502168   |     -     |   15.61  
   2    |    9    |   3.502168   |   0.99    |   86.56  
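For reference, `get_polynomial_decay_schedule_with_warmup` increases the learning rate linearly from 0 to `learning_rate` over `num_warmup_steps`, then decays it polynomially (linearly with the default `power=1.0`) down to `lr_end` at `max_steps`. A plain-Python sketch of that formula as I understand it from `transformers` (check the library source for the authoritative version; `polynomial_lr` is a hypothetical name). With 100 examples and a batch size of 12 there are 9 batches per epoch, matching "Batch 9" in the log, so warmup lasts `int(0.5 * 9) = 4` steps and `max_steps = 18`:

```python
def polynomial_lr(step, lr_init, lr_end, num_warmup_steps, num_training_steps, power=1.0):
    """Learning rate at `step` for linear warmup followed by polynomial decay."""
    if step < num_warmup_steps:
        # Linear warmup from 0 to lr_init
        return lr_init * step / max(1, num_warmup_steps)
    if step > num_training_steps:
        return lr_end
    decay_steps = num_training_steps - num_warmup_steps
    pct_remaining = 1 - (step - num_warmup_steps) / decay_steps
    return (lr_init - lr_end) * pct_remaining ** power + lr_end


# Values from the training cell: warmup = 4 steps, max_steps = 18
for step in (0, 2, 4, 11, 18):
    print(step, polynomial_lr(step, 4e-4, 1e-7, num_warmup_steps=4, num_training_steps=18))
```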


# Result on Test Set

In [36]:
sampler_test = trainer_pt_utils.LengthGroupedSampler(common_voice_test, batch_size=32, model_input_name=model_input_name)
test_loader = DataLoader(common_voice_test, batch_size=32, sampler=sampler_test, collate_fn=data_collator, num_workers=4)
print(f"The final PER score on the test set is {trainer.compute_score(test_loader, os.path.join(parent_dir, 'model.pt'))}")

The final PER score on the test set is 0.9946907071456387
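`Trainer.compute_score` comes from the author's `trainer.py`, which isn't shown. A phoneme error rate is commonly computed as the total Levenshtein (edit) distance between predicted and reference phoneme sequences divided by the total number of reference phonemes. The sketch below treats every character as one phoneme symbol, consistent with the character-level vocabulary built earlier, and ignores spaces; both choices, and the helper names, are assumptions about how the author scores. A PER close to 1, as above, is unsurprising after 2 epochs on only 100 examples.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution or match
        prev = cur
    return prev[-1]


def phoneme_error_rate(references, predictions):
    """Corpus-level PER: total edits / total reference phonemes (spaces ignored)."""
    refs = [r.replace(" ", "") for r in references]
    hyps = [p.replace(" ", "") for p in predictions]
    edits = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    return edits / sum(len(r) for r in refs)


print(phoneme_error_rate(["kanaɟan", "jeden"], ["kanadan", "jedn"]))  # 2 edits / 12 phonemes
```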
