# Pretraining for ASR

In [None]:
# installing libs
# !pip3 install torch torchvision torchaudio datasets transformers soundfile jiwer --index-url https://download.pytorch.org/whl/cu118
# !pip3 install librosa --index-url https://pypi.org/simple

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import re
import torch
import torch.nn as nn
import numpy as np

from datasets import load_dataset, disable_caching, load_metric
from transformers import Wav2Vec2ForPreTraining, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Encoder


## Finetuning Wav2Vec2 model on CTC loss (5 points)


In this task you have to create a pipeline for fine-tuning a pretrained multilingual Wav2Vec2 model on Belarusian audio from the [Fleurs](https://huggingface.co/datasets/google/fleurs) dataset.

#### Prepare data

In [None]:
# Load the Belarusian FLEURS splits; the splits are requested as a list, and the
# cells below index `fleurs` positionally (0 = train, 1 = validation, 2 = test).
fleurs_split_names = ["train", "validation", "test"]
fleurs = load_dataset("google/fleurs", "be_by", split=fleurs_split_names)

In [None]:
# Peek at a single training transcription to see what the text looks like.
train_split = fleurs[0]
train_split["transcription"][9]

In this task, you should:

* filter out all samples whose `transcription` contains digits. Hint: take care of the Belarusian-specific letters "і" and "ў";
* remove punctuation from `transcription`.

In [None]:
# Python 3 regexes are Unicode-aware by default, so `\d` and `\w` cover Cyrillic
# text (including the Belarusian letters "і" and "ў") without listing the alphabet.
_DIGIT_RE = re.compile(r"\d")
# Everything that is not a word character, whitespace or an apostrophe is treated
# as punctuation. NOTE(review): apostrophes are kept on the assumption that they
# are part of Belarusian spelling rather than punctuation - confirm.
_PUNCTUATION_RE = re.compile(r"[^\w\s'’]")
_WHITESPACE_RE = re.compile(r"\s+")


def has_no_digits(transcription):
    """Return True when the transcription contains no digit characters."""
    return _DIGIT_RE.search(transcription) is None


def remove_punctuation(transcription):
    """Return a column update with punctuation removed and whitespace collapsed.

    Punctuation is replaced by a space (not deleted) so that words separated
    only by a punctuation mark are not glued together.
    """
    without_punctuation = _PUNCTUATION_RE.sub(" ", transcription)
    return {"transcription": _WHITESPACE_RE.sub(" ", without_punctuation).strip()}


def preprocess_split(split):
    """Drop samples whose transcription has digits, then strip punctuation.

    Only the `transcription` column is passed to the callbacks, so the audio
    column does not need to be read for this text-only step.
    """
    digit_free = split.filter(has_no_digits, input_columns=["transcription"])
    return digit_free.map(remove_punctuation, input_columns=["transcription"])


# `fleurs` is ordered as [train, validation, test].
preprocessed_train = preprocess_split(fleurs[0])
preprocessed_val = preprocess_split(fleurs[1])

#### Train tokenizer

Here you should train your own BPE tokenizer on the texts from the Fleurs dataset using the [HuggingFace tokenizers](https://huggingface.co/docs/tokenizers/en/training_from_memory) library.

In [None]:
from tokenizers import models, trainers, tokenizers, normalizers, pre_tokenizers, decoders

PAD_TOKEN = "[PAD]"
BOS_TOKEN = "[BOS]"
EOS_TOKEN = "[EOS]"
UNK_TOKEN = "[UNK]"
VOCAB_SIZE = 1000

# BPE model; anything not covered by the learned vocabulary maps to UNK_TOKEN.
tokenizer = tokenizers.Tokenizer(models.BPE(unk_token=UNK_TOKEN))
# Unicode normalisation so that visually identical letters share one code point sequence.
tokenizer.normalizer = normalizers.NFKC()
# Metaspace marks word boundaries inside the tokens themselves, and the matching
# decoder turns the marks back into spaces. Without it, decoded sub-word pieces
# could not be re-assembled into words, which WER depends on.
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
tokenizer.decoder = decoders.Metaspace()

# Named `bpe_trainer` so it does not clash with the HuggingFace `Trainer` instance
# that is bound to `trainer` later in the notebook.
bpe_trainer = trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    special_tokens=[PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN],
)
# Train on the cleaned training transcriptions only.
tokenizer.train_from_iterator(preprocessed_train["transcription"], trainer=bpe_trainer)


#### Loading model and preprocessor

In [None]:
from transformers import Wav2Vec2FeatureExtractor

# The feature extractor and the CTC model are loaded from the same checkpoint.
PRETRAINED_CHECKPOINT = "facebook/wav2vec2-xls-r-300m"

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(PRETRAINED_CHECKPOINT)

# `pad_token_id` and `vocab_size` are taken from the custom tokenizer.
ctc_model_kwargs = dict(
    ctc_loss_reduction="mean",
    pad_token_id=tokenizer.token_to_id(PAD_TOKEN),
    vocab_size=tokenizer.get_vocab_size(),
)
model = Wav2Vec2ForCTC.from_pretrained(PRETRAINED_CHECKPOINT, **ctc_model_kwargs)


#### Data processor and data collator 

In [None]:
class CtcDataProcessor:
    """Convert one dataset row into model inputs (`input_values`) and targets (`labels`)."""

    def __init__(self, tokenizer, feature_extractor):
        # `tokenizer` must provide `encode(text).ids`; `feature_extractor` must be
        # callable as `feature_extractor(array, sampling_rate=...)` and expose
        # `input_values` on its result.
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor

    def __call__(self, row):
        """
            Tokenize row['transcription'] and run the feature extractor on row['audio'].

            Input: dict with `transcription` (str) and `audio` (dict with `array` and
                   `sampling_rate`) fields.
            Output: the same dict, extended with a `labels` column (token ids of the
                    transcription) and an `input_values` column (features computed by
                    the feature extractor for this single audio).
        """
        audio = row["audio"]
        # Passing the sampling rate explicitly lets the extractor check it against
        # the rate it was configured for.
        features = self.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
        # The extractor returns a batch; exactly one audio was passed in, so take item 0.
        row["input_values"] = features.input_values[0]
        row["labels"] = self.tokenizer.encode(row["transcription"]).ids
        # The row is updated in place and returned, as the contract above states.
        # NOTE(review): returning the same mapping rather than a copy is presumably
        # what lets `Dataset.map(remove_columns=...)` drop the source columns - confirm.
        return row

In [None]:
data_processor = CtcDataProcessor(tokenizer, feature_extractor)

# Apply the processor to both splits, discarding every source column so only the
# columns produced by the processor remain.
train, val = (
    split.map(data_processor, keep_in_memory=True, remove_columns=split.column_names)
    for split in (preprocessed_train, preprocessed_val)
)

In [None]:
class CTCDataCollator:
    """Pad a list of processed rows into one batch of tensors."""

    # HuggingFace requires transcript tokens to be padded with this value
    # (such label positions are ignored by the loss).
    LABELS_PAD_IDX = -100

    @staticmethod
    def collate_tokens(tokens_batch, type, pad_value=0.0):
        """
            Right-pad variable-length 1-D sequences into a single 2-D tensor.

            Input: `tokens_batch` - list of 1-D sequences (lists, arrays or tensors);
                   `type` - torch dtype of the resulting tensor;
                   `pad_value` - value written into the padded positions.
            Output: tensor of shape (len(tokens_batch), longest sequence length);
                    an empty batch yields a tensor of shape (0, 0).
        """
        sequences = [torch.as_tensor(tokens, dtype=type) for tokens in tokens_batch]
        max_len = max((seq.shape[0] for seq in sequences), default=0)
        # Start from a tensor full of padding and copy every sequence over its prefix.
        collated = torch.full((len(sequences), max_len), pad_value, dtype=type)
        for row_idx, seq in enumerate(sequences):
            collated[row_idx, : seq.shape[0]] = seq
        return collated

    def __call__(self, batch):
        """
            Collate `input_values` and `labels` of all rows into one tensor each.

            Input: list of dicts, the output of CtcDataProcessor.
            Output: dict with `input_values` (float32, padded with 0.0), `labels`
                    (int64, padded with LABELS_PAD_IDX) and `attention_mask`
                    (0 for not-attending position, 1 for attending).
        """
        waveforms = [row["input_values"] for row in batch]
        input_values = self.collate_tokens(waveforms, torch.float32, pad_value=0.0)

        # A position is attended iff it lies inside the original (unpadded) sequence.
        lengths = torch.tensor([len(waveform) for waveform in waveforms], dtype=torch.long)
        positions = torch.arange(input_values.shape[1])
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        labels = self.collate_tokens(
            [row["labels"] for row in batch], torch.long, pad_value=self.LABELS_PAD_IDX
        )
        return {
            "input_values": input_values,
            "attention_mask": attention_mask,
            "labels": labels,
        }

#### Inference and metrics computing

Here you should use a simple greedy strategy for decoding the CTC output.

Hint: Don't forget about padding value -100 in reference.

Hint: Don't forget about CTC output format.

In [None]:
from itertools import groupby
# NOTE(review): `load_metric` is deprecated in recent `datasets` releases in favour
# of the separate `evaluate` package - confirm the installed version still provides it.
wer_metric = load_metric("wer")


class MetricsComputer:
    """Greedy-decode CTC logits, decode the references and report WER."""

    # Value used to pad the reference label ids.
    IGNORE_IDX = -100

    def __init__(self, text_tokenizer=None, blank_id=None):
        """
            Input: `text_tokenizer` - tokenizer used for decoding ids into text; when
                   omitted, the notebook-level `tokenizer` is used;
                   `blank_id` - id of the CTC blank; when omitted, the id of PAD_TOKEN
                   is used (the model is configured with it as `pad_token_id`).
        """
        self.text_tokenizer = text_tokenizer
        self.blank_id = blank_id

    @staticmethod
    def greedy_decode(frame_ids, blank_id):
        """
            Convert per-frame token ids into a token sequence following the CTC rule:
            first merge runs of identical ids, then drop the blanks.
        """
        return [token_id for token_id, _ in groupby(frame_ids) if token_id != blank_id]

    def __call__(self, pred):
        """
            Input: object with fields `predictions` for CTC model output and `label_ids` for tokenized reference;
            Output: dict with key `wer` and computed wer
        """
        # model prediction tensor, tensor batch_size x max_seq_len x vocab_size
        preds_logits = np.asarray(pred.predictions)
        # reference, tensor batch_size x max_seq_len
        label_ids = np.asarray(pred.label_ids)

        # Resolved lazily so that `MetricsComputer()` keeps working without arguments.
        text_tokenizer = self.text_tokenizer if self.text_tokenizer is not None else tokenizer
        blank_id = self.blank_id if self.blank_id is not None else text_tokenizer.token_to_id(PAD_TOKEN)

        # Greedy strategy: the most probable token at every frame.
        frame_ids = preds_logits.argmax(axis=-1)
        # NOTE(review): when evaluation batches have different lengths the gathered
        # logits are presumably padded with -100 along the time axis; such frames
        # carry no prediction, so they are treated as blanks - confirm.
        frame_ids[(preds_logits == self.IGNORE_IDX).all(axis=-1)] = blank_id

        pred_token_ids = [self.greedy_decode(row, blank_id) for row in frame_ids.tolist()]
        # Padding (-100) is not a valid token id and must be removed before decoding.
        label_token_ids = [
            [token_id for token_id in row if token_id != self.IGNORE_IDX]
            for row in label_ids.tolist()
        ]

        pred_str = text_tokenizer.decode_batch(pred_token_ids, skip_special_tokens=True)
        label_str = text_tokenizer.decode_batch(label_token_ids, skip_special_tokens=True)

        # Show one example per evaluation as a sanity check (skipped for an empty batch).
        if pred_str and label_str:
            print(f"Prediction: {pred_str[0]}")
            print(f"Reference: {label_str[0]}")

        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}

#### Overfitting on train batch

In this task you should first check the correctness of the pipeline by overfitting on a single train batch. Then you need to fine-tune the Wav2Vec2 model and achieve a WER of 50% or lower on the validation set.

In [None]:
from transformers import TrainingArguments

# NOTE(review): the three optimisation hyperparameters below (learning rate,
# weight decay, warmup) are starting points, not tuned values - adjust them while
# watching the training loss and validation WER.
training_args = TrainingArguments(
    output_dir="test",
    per_device_train_batch_size=2,  # you could increase batch size
    gradient_accumulation_steps=8,  # effective batch size = 2 * 8 per device
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="steps",
    max_steps=3000,
    fp16=True,
    save_steps=50,
    eval_steps=10,
    logging_steps=10,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=500,
    gradient_checkpointing=True,
)

In [None]:
from transformers import Trainer

# Instantiate the pipeline pieces first, then hand everything to the Trainer.
batch_collator = CTCDataCollator()
metrics_computer = MetricsComputer()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=val,
    data_collator=batch_collator,
    compute_metrics=metrics_computer,
)

In [None]:
# Run training with the `trainer` configured above; the call is left as the
# cell's final expression so the notebook displays its return value.
trainer.train()