# Fine-tune a pre-trained model using Afrikaans/isiXhosa data

### 1. Install dependencies

In [None]:
# Set to True when the system packages below have to be installed through sudo.
sudo_password = False
if sudo_password:
    # Ask for the password instead of hard-coding it in the notebook, and
    # create the file with owner-only permissions. The install cell feeds
    # password.txt to `sudo -S` on stdin.
    import os
    from getpass import getpass

    password_fd = os.open("password.txt", os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(password_fd, "w") as password_file:
        password_file.write(getpass("sudo password: ") + "\n")

In [35]:
# Install Python dependencies
!pip3 install -r requirements.txt

# Install GitLFS
if sudo_password:
    !sudo -S apt-get update < password.txt
    !sudo apt-get install git-lfs tree
else:
    !apt-get update
    !apt-get install git-lfs



### 2. Setup experiment

#### NB: Before running, set ``WRITE_ACCESS_TOKEN`` in ``utils.py`` to your own Hugging Face write token.

### 2.1 Choose dataset

In [2]:
# Dataset to fine-tune on. ASR_MODEL.load_datasets also handles
# "asr_xh" and "asr_af_xh".
dataset_name = "asr_af"

### 2.2 Choose pre-trained model

In [3]:
# Hub id of the pre-trained checkpoint that gets fine-tuned
PULL_REPO = "facebook/wav2vec2-xls-r-300m"

### 2.3 Choose repo name for pushing fine-tuned model

In [6]:
# The run number distinguishes repeated fine-tuning runs on the same dataset.
# When the RUN_NUMBER environment variable is set the prompt is skipped, so the
# notebook can also be executed non-interactively (e.g. with nbconvert).
import os

run_number = int(os.environ.get("RUN_NUMBER") or input("Provide the run number please: "))

# Name the push repo after the pre-trained checkpoint, without its organisation prefix
base_model_name = PULL_REPO.split("/")[-1]
PUSH_REPO = f"{base_model_name}-{dataset_name}-run{run_number}"

Provide the run number please: 2


### 3. Import libraries

In [37]:
# Standard imports
import os
import json
import torch
import evaluate
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict, List, Union

# My own imports
from load_nchlt import load_nchlt
from load_fleurs import load_fleurs
from load_high_quality_tts import load_high_quality_tts
from utils import (
    SR, WRITE_ACCESS_TOKEN,
    clear_cache,
    remove_special_characters_batch
)

# HuggingFace imports
from datasets import Audio
from datasets import load_dataset
from huggingface_hub import login
from transformers import Wav2Vec2ForCTC
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

### 4. ASR model class

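A small convenience wrapper that keeps the datasets, tokenizer, processor, data collator, model and trainer of one fine-tuning run together in a single object.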

In [38]:
class ASR_MODEL:
    """Bundle the state of one wav2vec2 CTC fine-tuning run.

    The object holds the train/validation/test splits, the tokenizer, feature
    extractor, processor, data collator, model and trainer. The methods build
    on each other and are meant to be called in this order:
    ``load_datasets``, ``create_tokenizer``, ``extract_features``,
    ``create_data_collator``, ``download_xlsr``.

    Args:
        repo_name: Local output directory (created if missing); also the name
            of the Hub repo the tokenizer is pushed to.
        dataset_name: Name of the dataset. Building a local dataset supports
            ``"asr_af"``, ``"asr_xh"`` and ``"asr_af_xh"``.
        load_from_hf: Load the dataset from the Hugging Face Hub instead of
            from a local audiofolder.
        push_dataset: Push a freshly built local dataset to the Hub.
        push_repo: Push the tokenizer to the Hub after creating it.
        write_audio: Forwarded unchanged to the ``load_*`` helpers.
        hub_namespace: Hub user/organisation that owns the dataset and model
            repos.
    """

    # Language argument passed to the ``load_*`` helpers for every dataset
    # that can be built locally.
    DATASET_LANGUAGES = {"asr_af": "af", "asr_xh": "xh", "asr_af_xh": "both"}

    def __init__(self, repo_name, dataset_name, load_from_hf, push_dataset, push_repo, write_audio,
                 hub_namespace="lucas-meyer"):
        self.repo_name = repo_name
        self.dataset_name = dataset_name
        self.load_from_hf = load_from_hf
        self.push_dataset = push_dataset
        self.push_repo = push_repo
        self.write_audio = write_audio
        self.hub_namespace = hub_namespace
        self.dataset_dir = os.path.join("data", "speech_data", dataset_name)
        os.makedirs(repo_name, exist_ok=True)

        # Filled in step by step by the methods below
        self.train_set = None
        self.val_set = None
        self.test_set = None
        self.tokenizer = None
        self.feature_extractor = None
        self.processor = None
        self.data_collator = None
        self.xlsr_model = None
        self.trainer = None

        # WER metric, loaded on first use and then reused for every evaluation
        self._wer_metric = None

    def _build_local_dataset(self):
        """Combine the source corpora into a local audiofolder and load it.

        Raises:
            ValueError: If ``dataset_name`` is not a key of
                ``DATASET_LANGUAGES``.
        """
        # Validate before creating the directory: an empty directory left
        # behind by a typo would be treated as a finished dataset next time.
        if self.dataset_name not in self.DATASET_LANGUAGES:
            raise ValueError(
                f"Unknown dataset_name {self.dataset_name!r}; "
                f"expected one of {sorted(self.DATASET_LANGUAGES)}"
            )
        language = self.DATASET_LANGUAGES[self.dataset_name]
        os.makedirs(self.dataset_dir, exist_ok=True)

        loaders = [load_fleurs, load_high_quality_tts]
        # load_nchlt is currently disabled for the Afrikaans-only dataset
        if self.dataset_name != "asr_af":
            loaders.append(load_nchlt)

        csv_entries = []
        for loader in loaders:
            csv_entries += loader(language=language, write_audio=self.write_audio)

        metadata = pd.DataFrame(csv_entries, columns=['file_name', 'transcription'])
        metadata.to_csv(path_or_buf=os.path.join(self.dataset_dir, "metadata.csv"), sep=",", index=False)

        # Load dataset from the audiofolder that was just created
        dataset = load_dataset("audiofolder", data_dir=self.dataset_dir)

        # Push dataset to huggingface hub
        if self.push_dataset:
            dataset.push_to_hub(f"{self.hub_namespace}/{self.dataset_name}")
        return dataset

    @staticmethod
    def _prepare_split(split):
        """Resample a split to SR, rename the text column and clean the text."""
        split = split.cast_column("audio", Audio(sampling_rate=SR))
        split = split.rename_column("transcription", "sentence")
        return split.map(remove_special_characters_batch)

    def load_datasets(self):
        """Load the dataset and initialise ``train_set``/``val_set``/``test_set``.

        The dataset comes from the Hub when ``load_from_hf`` is set, otherwise
        from the local audiofolder, which is built first if its directory does
        not exist yet.
        """
        self.dataset_dir = os.path.join("data", "speech_data", self.dataset_name)
        if self.load_from_hf:
            # Load dataset from huggingface hub
            dataset = load_dataset(f"{self.hub_namespace}/{self.dataset_name}")
        elif os.path.exists(self.dataset_dir):
            # Load dataset from an audiofolder created by an earlier run
            dataset = load_dataset("audiofolder", data_dir=self.dataset_dir)
        else:
            dataset = self._build_local_dataset()

        self.train_set = self._prepare_split(dataset["train"])
        self.val_set = self._prepare_split(dataset["validation"])
        self.test_set = self._prepare_split(dataset["test"])
        torch.cuda.empty_cache()

    def extract_all_chars(self, batch):
        """Collect the distinct characters of a batch of sentences.

        Returns:
            A dict with ``"vocab"`` (a one-element list holding the list of
            distinct characters) and ``"all_text"`` (a one-element list
            holding all sentences joined by spaces).
        """
        all_text = " ".join(batch["sentence"])
        vocab = list(set(all_text))
        return {"vocab": [vocab], "all_text": [all_text]}

    def create_tokenizer(self):
        """Build a character vocabulary from all splits and create the CTC tokenizer.

        The vocabulary is written to ``<repo_name>/vocab.json``; the tokenizer
        is pushed to the Hub when ``push_repo`` is set.
        """
        # Union of the characters that occur in train/val/test
        vocab_chars = set()
        for split in (self.train_set, self.val_set, self.test_set):
            extracted = split.map(self.extract_all_chars,
                                  batched=True, batch_size=-1,
                                  keep_in_memory=True,
                                  remove_columns=split.column_names)
            vocab_chars.update(extracted["vocab"][0])

        vocab_dict = {char: index for index, char in enumerate(sorted(vocab_chars))}
        # The space character is represented by the word delimiter token "|"
        vocab_dict["|"] = vocab_dict.pop(" ")
        vocab_dict["[UNK]"] = len(vocab_dict)
        vocab_dict["[PAD]"] = len(vocab_dict)

        # Save vocabulary file
        vocab_path = os.path.join(self.repo_name, 'vocab.json')
        with open(vocab_path, 'w', encoding='utf-8') as vocab_file:
            json.dump(vocab_dict, vocab_file)

        self.tokenizer = Wav2Vec2CTCTokenizer(vocab_path,
                                              unk_token="[UNK]",
                                              pad_token="[PAD]",
                                              bos_token=None,
                                              eos_token=None,
                                              word_delimiter_token="|")
        if self.push_repo:
            self.tokenizer.push_to_hub(f"{self.hub_namespace}/{self.repo_name}")

    def prepare_dataset(self, batch):
        """Turn one example into model inputs (``input_values``) and ``labels``."""
        audio = batch["audio"]
        batch["input_values"] = self.processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
        batch["input_length"] = len(batch["input_values"])
        batch["labels"] = self.processor(text=batch["sentence"]).input_ids
        return batch

    def extract_features(self):
        """Create the feature extractor and processor and apply them to all splits."""
        # Use the same sampling rate the audio was resampled to in load_datasets
        self.feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1,
                                                          sampling_rate=SR,
                                                          padding_value=0.0,
                                                          do_normalize=True,
                                                          return_attention_mask=True)

        self.processor = Wav2Vec2Processor(feature_extractor=self.feature_extractor,
                                           tokenizer=self.tokenizer)

        self.train_set = self.train_set.map(self.prepare_dataset, remove_columns=self.train_set.column_names)
        self.val_set = self.val_set.map(self.prepare_dataset, remove_columns=self.val_set.column_names)
        self.test_set = self.test_set.map(self.prepare_dataset, remove_columns=self.test_set.column_names)

    def create_data_collator(self):
        """Create the padding data collator around the processor."""
        self.data_collator = DataCollatorCTCWithPadding(processor=self.processor, padding=True)

    def download_xlsr(self, pull_repo=None):
        """Load the pre-trained checkpoint with a CTC head sized for the tokenizer.

        Args:
            pull_repo: Hub id of the checkpoint. Defaults to the notebook-level
                ``PULL_REPO``.
        """
        checkpoint = PULL_REPO if pull_repo is None else pull_repo
        self.xlsr_model = Wav2Vec2ForCTC.from_pretrained(
            checkpoint,
            attention_dropout=0.0,
            hidden_dropout=0.0,
            feat_proj_dropout=0.0,
            mask_time_prob=0.05,
            layerdrop=0.0,
            ctc_loss_reduction="mean",
            pad_token_id=self.processor.tokenizer.pad_token_id,
            vocab_size=len(self.processor.tokenizer),
        )
        self.xlsr_model.freeze_feature_encoder()          # Freeze feature extraction weights
        self.xlsr_model.gradient_checkpointing_enable()   # Enable gradient checkpointing

    def compute_metrics(self, pred):
        """Compute the word error rate for a set of predictions.

        Args:
            pred: Object with ``predictions`` (logits) and ``label_ids``.

        Returns:
            ``{"wer": <word error rate>}``
        """
        pred_ids = np.argmax(pred.predictions, axis=-1)
        # Put the pad token back where the collator masked the labels with
        # -100, without modifying the caller's array in place.
        label_ids = np.where(pred.label_ids == -100,
                             self.processor.tokenizer.pad_token_id,
                             pred.label_ids)
        pred_str = self.processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = self.processor.batch_decode(label_ids, group_tokens=False)

        # Load the metric once instead of on every evaluation
        if self._wer_metric is None:
            self._wer_metric = evaluate.load("wer")
        wer = self._wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
            only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
            maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
            different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into one batch.

        Each feature must provide ``"input_values"`` and ``"labels"``. The
        returned batch is the padded inputs plus a ``"labels"`` tensor in
        which padded positions are set to -100.
        """
        # split inputs and labels since they have to be of different lengths and need different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # Pad the labels through the `labels` argument instead of the
        # deprecated `as_target_processor()` context manager.
        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)
        batch["labels"] = labels
        return batch

### 5. Log in to Hugging Face

In [None]:
# Authenticate with the Hugging Face Hub using the token imported from utils.py
login(token=WRITE_ACCESS_TOKEN)

### 6. Instantiate ASR model object

In [39]:
# This run loads the dataset from the Hub (no local audiofolder is built)
# and pushes the tokenizer to the Hub.
model = ASR_MODEL(
    repo_name=PUSH_REPO,
    dataset_name=dataset_name,
    load_from_hf=True,
    push_dataset=False,
    push_repo=True,
    write_audio=False,
)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


### 7. Prepare the data, tokenizer, features, data collator and model

In [40]:
# Run the preparation steps in order; each one relies on state set by the previous ones
preparation_steps = (
    model.load_datasets,
    model.create_tokenizer,
    model.extract_features,
    model.create_data_collator,
    model.download_xlsr,
)
for step in preparation_steps:
    step()

Resolving data files:   0%|          | 0/2723 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/447 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/476 [00:00<?, ?it/s]

Map:   0%|          | 0/2723 [00:00<?, ? examples/s]

Map:   0%|          | 0/447 [00:00<?, ? examples/s]

Map:   0%|          | 0/476 [00:00<?, ? examples/s]

Map:   0%|          | 0/2723 [00:00<?, ? examples/s]

Map:   0%|          | 0/447 [00:00<?, ? examples/s]

Map:   0%|          | 0/476 [00:00<?, ? examples/s]

Some weights of the model checkpoint at facebook/wav2vec2-xls-r-300m were not used when initializing Wav2Vec2ForCTC: ['project_q.bias', 'quantizer.weight_proj.weight', 'project_hid.weight', 'quantizer.codevectors', 'project_hid.bias', 'project_q.weight', 'quantizer.weight_proj.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it 

### 8. TRAIN

In [None]:
clear_cache()

# Effective train batch size per device is batch_size * grad_accum_steps
batch_size = 8
grad_accum_steps = 4

training_args = TrainingArguments(
    # Output directory (repo name)
    output_dir=model.repo_name,
    overwrite_output_dir=True,

    # Save and evaluate on a step schedule (see save_steps / eval_steps below)
    save_strategy="steps",
    evaluation_strategy="steps",

    # Batch sizes and num accumulation steps; the eval batch is a quarter of
    # the train batch but never smaller than 1
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=max(1, batch_size // 4),
    gradient_accumulation_steps=grad_accum_steps,
    eval_accumulation_steps=grad_accum_steps,

    # Train for 30 epochs max
    num_train_epochs=30,

    # Enable gradient checkpointing
    gradient_checkpointing=True,

    # Half-point precision
    fp16=True,

    # Save/evaluate/log/warmup steps. save_steps has to stay a multiple of
    # eval_steps because the best model is loaded at the end.
    save_steps=500,
    eval_steps=50,
    logging_steps=50,
    warmup_steps=500,
    save_total_limit=2,

    # Learning rate
    learning_rate=3e-4,

    # Load best model at end; "best" is the lowest validation loss (the
    # library default, spelled out here), which also drives early stopping
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,

    # Push to hub only when the model object was configured to do so
    push_to_hub=model.push_repo,
)

model.trainer = Trainer(
    model=model.xlsr_model,
    data_collator=model.data_collator,
    args=training_args,
    compute_metrics=model.compute_metrics,
    train_dataset=model.train_set,
    eval_dataset=model.val_set,
    tokenizer=model.processor.feature_extractor,
    # Stop once the tracked metric has not improved for 3 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

model.trainer.train()
if model.push_repo:
    model.trainer.push_to_hub()

Cloning https://huggingface.co/lucas-meyer/wav2vec2-xls-r-300m-asr_af-run2 into local empty directory.


Step,Training Loss,Validation Loss,Wer
20,13.4255,14.059743,1.013686
40,11.1376,9.16298,1.0
60,6.3128,5.592697,1.0
80,4.7219,4.517402,1.0
100,4.1284,3.975215,1.0
120,3.6961,3.557995,1.0
140,3.3833,3.257957,1.0
160,3.1534,3.084857,1.0
180,3.0469,3.008968,1.0
200,2.9851,2.979771,1.0
