## Fine-tuning and inference

To train an Automatic Speech Recognition (ASR) model, the following inputs are required:

1. **Audio Data with Corresponding Transcriptions**: A dataset of audio recordings paired with their textual transcriptions.  
2. **Dictionary**: Defines the set of tokens over which the acoustic model predicts probabilities for each audio frame.  
3. **Lexicon**: Maps words to sequences of tokens, enabling the conversion of text into token representations.
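
As a minimal sketch of items 2 and 3, the snippet below derives a character-level dictionary and a lexicon from a handful of transcriptions. It assumes letter tokens with `|` as the word-boundary token, which matches the lexicon format generated later in this notebook. The helper names `build_dictionary` and `build_lexicon` and the sample sentences are hypothetical.

```python
from collections import Counter

def build_dictionary(transcripts):
    """Count each character token; '|' marks a word boundary (assumed convention)."""
    counts = Counter()
    for line in transcripts:
        for word in line.split():
            counts.update(word)   # one count per character of the word
            counts["|"] += 1      # one boundary token after every word
    return counts

def build_lexicon(transcripts):
    """Map each distinct word to its space-separated letters followed by '|'."""
    words = sorted({w for line in transcripts for w in line.split()})
    return {w: " ".join(w) + " |" for w in words}

transcripts = ["HELLO WORLD", "HELLO THERE"]  # made-up sample data
dictionary = build_dictionary(transcripts)
lexicon = build_lexicon(transcripts)

# Dictionary lines in "token count" form, most frequent first
for token, count in dictionary.most_common():
    print(token, count)

# Lexicon lines in "word<TAB>spelling" form
for word, spelling in lexicon.items():
    print(f"{word}\t{spelling}")
```

The dictionary fixes the output classes of the acoustic model, while the lexicon is only needed when decoding with a word-level language model.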

#### Tools used:

1. **Fairseq**: Used for fine-tuning pre-trained models, enabling efficient training of ASR systems with existing models and datasets.  
2. **KenLM**: Utilized for building and integrating language models, which enhance the recognition accuracy by capturing the probabilities of word sequences.  
3. **Flashlight**: Employed for decoding, providing a fast and flexible beam search decoder to predict text sequences from acoustic model outputs.

Clone the fairseq and AI4Bharat IndicWav2Vec repositories.

In [None]:
!git clone https://github.com/AI4Bharat/IndicWav2Vec.git
!git clone https://github.com/pytorch/fairseq.git

In [None]:
# Pin pip to an older release; try a different version if this one does not work.
# The latest pip release is likely to be incompatible with the packages installed below.
!pip install pip==23.1.2

In [None]:
%cd /content/IndicWav2Vec
!pip install packaging soundfile swifter -r w2v_inference/requirements.txt
%cd ..

##### Install fairseq

In [None]:
%cd fairseq
!pip install --editable ./
%cd ..

In [None]:
# Install other dependencies
!pip install torch torchvision torchaudio soundfile sentencepiece editdistance scikit-learn

#### Prepare the folder structure

In [None]:
!mkdir -p datasets/santali

##### Download the dataset

In [None]:
!pip install gdown

In [None]:
# Anonymous links for train, valid, test
!gdown 1kJSOB3hpAcYCLwpExR03A3TXQZvFkqfD
!gdown 1QBz-IeMmgi2EPCa9endt0Zvwa3tLiuP2
!gdown 1gfyDg464ExtaS8OdzcVxB5BQW1S3ZqKi

In [None]:
# Extract files
!unzip train.zip -d datasets/santali/train
!unzip test.zip -d datasets/santali/test
!unzip valid.zip -d datasets/santali/valid
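
The fine-tuning command further down reads a manifest directory (`task.data`) rather than the extracted folders. The sketch below writes one split in the layout fairseq's wav2vec recipes use: `<split>.tsv` holds the audio root on its first line followed by `relative_path<TAB>num_samples`, and `<split>.wrd` and `<split>.ltr` hold the word-level and letter-level transcriptions. It assumes the archives contain WAV files and that the transcriptions are available as a path-to-sentence mapping; `write_manifest` and `to_ltr` are hypothetical helpers, and the demo clip is synthetic. The fairseq task also looks for a `dict.ltr.txt` token dictionary in the same directory.

```python
import os
import tempfile
import wave

def to_ltr(sentence):
    """Spell a sentence as letters, using '|' as the word boundary."""
    return " ".join("|".join(sentence.split())) + " |"

def write_manifest(audio_root, transcripts, out_dir, split):
    """Write <split>.tsv/.wrd/.ltr for the WAV files listed in `transcripts`.

    `transcripts` maps a path relative to `audio_root` to its sentence
    (hypothetical structure; adapt it to how your dataset stores text).
    """
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.join(out_dir, split)
    with open(base + ".tsv", "w", encoding="utf-8") as tsv, \
         open(base + ".wrd", "w", encoding="utf-8") as wrd, \
         open(base + ".ltr", "w", encoding="utf-8") as ltr:
        tsv.write(os.path.abspath(audio_root) + "\n")
        for rel_path, sentence in sorted(transcripts.items()):
            with wave.open(os.path.join(audio_root, rel_path), "rb") as wav:
                num_samples = wav.getnframes()
            tsv.write(f"{rel_path}\t{num_samples}\n")
            wrd.write(" ".join(sentence.split()) + "\n")
            ltr.write(to_ltr(sentence) + "\n")

# Demo on a one-second silent 16 kHz clip in a temporary directory
demo_root = tempfile.mkdtemp()
with wave.open(os.path.join(demo_root, "clip.wav"), "wb") as wav:
    wav.setnchannels(1)
    wav.setsampwidth(2)
    wav.setframerate(16000)
    wav.writeframes(b"\x00\x00" * 16000)

manifest_dir = os.path.join(demo_root, "manifest")
write_manifest(demo_root, {"clip.wav": "HELLO WORLD"}, manifest_dir, "train")

with open(os.path.join(manifest_dir, "train.tsv"), encoding="utf-8") as f:
    tsv_lines = f.read().splitlines()
with open(os.path.join(manifest_dir, "train.ltr"), encoding="utf-8") as f:
    ltr_lines = f.read().splitlines()
print(tsv_lines[1])
print(ltr_lines[0])
```

Repeat the call for the `valid` and `test` splits so that `--gen-subset test` can find `test.tsv` at inference time.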

##### Download pre-trained model

In [None]:
!wget https://indic-asr-public.objectstore.e2enetworks.net/aaai_ckpts/pretrained_models/indicw2v_base_pretrained.pt

In [None]:
!mkdir checkpoint

##### Start fine-tuning

In [None]:
## Run this cell if a fairseq "module not found" error shows up.
## This is only a workaround; reset PYTHONPATH after fine-tuning.

import os
os.environ['PYTHONPATH'] = "/content/fairseq/"
!echo $PYTHONPATH

In [None]:
!fairseq-hydra-train task.data="/content/datasets/santali/manifest" \
    dataset.max_tokens=1000000 \
    common.log_interval=50 \
    model.freeze_finetune_updates=1000 \
    model.w2v_path="/content/indicw2v_base_pretrained.pt" \
    checkpoint.save_dir="/content/checkpoint" \
    checkpoint.restore_file="/content/checkpoint/checkpoint_last.pt" \
    distributed_training.distributed_world_size=1 \
    +optimization.update_freq='[1]' \
    +optimization.lr=[0.00005] \
    optimization.max_update=100000 \
    checkpoint.save_interval_updates=10000 \
    --config-dir "IndicWav2Vec/finetune_configs" \
    --config-name ai4b_base

Change `--config-name` to `ai4b_large` when fine-tuning the large model.
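
For raw-audio models, `dataset.max_tokens` counts audio samples, so the batch budget in the command above can be converted into seconds of speech. The sketch below does that arithmetic. It assumes 16 kHz audio (confirm the sample rate of your own data), and `batch_seconds` is a hypothetical helper, not a fairseq function.

```python
def batch_seconds(max_tokens, sample_rate=16_000, update_freq=1, world_size=1):
    """Upper bound on the seconds of audio consumed per optimizer update."""
    return max_tokens / sample_rate * update_freq * world_size

# Values taken from the fine-tuning command: max_tokens=1000000,
# update_freq=[1], distributed_world_size=1
per_update = batch_seconds(1_000_000)
print(f"{per_update:.1f} s of audio per update")

# Doubling update_freq doubles the effective batch without using more GPU memory
print(f"{batch_seconds(1_000_000, update_freq=2):.1f} s with update_freq=2")
```

If the GPU runs out of memory, lowering `dataset.max_tokens` and raising `optimization.update_freq` by the same factor keeps the effective batch roughly constant.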

### Evaluate


#### Download and build Flashlight for decoding

In [None]:
!pip install flashlight-text
!pip install git+https://github.com/kpu/kenlm.git

In [None]:
!git clone https://github.com/flashlight/sequence
%cd sequence
!pip install .

In [None]:
# The repository was cloned and entered in the previous cell, so build in place
!cmake -S . -B build
!cmake --build build --parallel
!cd build && ctest  # run tests
!cmake --install build  # install at the CMAKE_INSTALL_PREFIX
%cd ..

In [None]:
!git clone https://github.com/flashlight/text
%cd text
!cmake -S . -B build
!cmake --build build --parallel
!cd build && ctest && cd .. # run tests
!cmake --install build # install at the CMAKE_INSTALL_PREFIX

Run inference on the test or validation set

In [None]:
!python3 /content/IndicWav2Vec/w2v_inference/infer/infer.py "/content/datasets/santali/manifest" --task audio_finetuning \
--nbest 1 --path "/content/gdrive/MyDrive/checkpoint_last.pt" --gen-subset test --results-path "/content/res/" --w2l-decoder viterbi \
--lexicon none --lm-weight 0 --word-score 0 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 1000000 \
--post-process letter
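
With `--w2l-decoder viterbi` and no language model, decoding essentially picks the best token for each frame and collapses the result under the CTC rules, and `--post-process letter` then turns the letter sequence back into words. The pure Python sketch below illustrates the idea; it is not the fairseq implementation. The toy vocabulary and frame scores are made up, and the blank symbol is assumed to sit at index 0.

```python
from itertools import groupby

def ctc_greedy_decode(frame_scores, vocab, blank=0):
    """Argmax per frame, merge consecutive repeats, then drop blanks."""
    best = [max(range(len(scores)), key=scores.__getitem__) for scores in frame_scores]
    collapsed = [idx for idx, _ in groupby(best)]
    return [vocab[idx] for idx in collapsed if idx != blank]

def letters_to_words(tokens):
    """Join letter tokens and turn the '|' boundary token into spaces."""
    return "".join(tokens).replace("|", " ").strip()

vocab = ["<blank>", "|", "H", "I"]
frame_scores = [
    [0.1, 0.0, 0.8, 0.1],  # H
    [0.1, 0.0, 0.8, 0.1],  # H (repeat, merged with the previous frame)
    [0.7, 0.1, 0.1, 0.1],  # blank
    [0.1, 0.0, 0.1, 0.8],  # I
    [0.1, 0.8, 0.0, 0.1],  # | (word boundary)
    [0.1, 0.0, 0.8, 0.1],  # H
    [0.0, 0.0, 0.1, 0.9],  # I
]
tokens = ctc_greedy_decode(frame_scores, vocab)
print(tokens)                    # ['H', 'I', '|', 'H', 'I']
print(letters_to_words(tokens))  # HI HI
```

Because no lexicon or language model constrains the output here, misspelled words pass through unchanged, which is what the LM integration section addresses.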

Inference on an external dataset

In [None]:
!pip install jiwer
!pip install Levenshtein

Below is an example of running inference on the Mozilla Common Voice dataset.

In [None]:
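import subprocess
import sys

import Levenshtein as Lev
import pandas as pd

# Note: wer() is defined below, so jiwer's wer is not imported here;
# importing it would only be shadowed by the local definition.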

# Load the .tsv file
input_file = '/content/test/audio/test.tsv'  # Replace with your file path
data = pd.read_csv(input_file, sep='\t')

def wer(s1, s2):
    """
    Computes the Word Error Rate: the word-level edit distance between the
    two sentences divided by the number of words in the reference.
    Arguments:
        s1 (string): space-separated reference sentence
        s2 (string): space-separated predicted sentence
    """
    ref_words, hyp_words = s1.split(), s2.split()
    if not ref_words:
        # Avoid dividing by zero for an empty reference
        return 0.0 if not hyp_words else 1.0

    # build mapping of words to integers
    b = set(ref_words + hyp_words)
    word2char = dict(zip(b, range(len(b))))

    # map the words to a char array (the Levenshtein package only accepts
    # strings)
    w1 = [chr(word2char[w]) for w in ref_words]
    w2 = [chr(word2char[w]) for w in hyp_words]

    # normalize the edit distance so the result is a rate, not a raw count
    return Lev.distance(''.join(w1), ''.join(w2)) / len(ref_words)

def get_trans(filepath):
    try:
        # Format the command
        command = f"""
        python /content/IndicWav2Vec/w2v_inference/scripts/sfi.py \
        --audio-file {filepath} \
        --ft-model /content/gdrive/MyDrive/checkpoint_last.pt \
        --w2l-decoder viterbi \
        --lexicon none
        """

        # Run the command and capture output
        result = subprocess.run(command, shell=True, text=True, capture_output=True)

        if result.returncode == 0:
            # Successful execution, process the output
            output = result.stdout.strip()  # Get the command's stdout and strip whitespace
            lines = output.split('\n')     # Split output into lines
            prediction = lines[-1].strip()  # Get the last non-empty line
            return prediction
        else:
            # Print the error message and exit
            print(f"Error running command: {result.stderr.strip()}")
            sys.exit(1)
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        sys.exit(1)


wer_list = []

# Iterate through each file in the 'path' column
for index, row in data.iterrows():
    filepath = row['path'].replace('.mp3', '.wav')  # Change .mp3 to .wav
    actual_sentence = row['sentence']

    # Get the transcript using the CLI command
    predicted_transcript = get_trans(filepath)

    # Calculate the WER
    file_wer = wer(actual_sentence, predicted_transcript)
    wer_list.append(file_wer)

    # Debug log for current file's WER
    print(f"File: {filepath}, WER: {file_wer:.2%}")
    print(f"Lev_WER: {file_wer}")
    print(f"Actual: {actual_sentence}")
    print(f"Predicted: {predicted_transcript}")

# Calculate the average WER
average_wer = sum(wer_list) / len(wer_list) if wer_list else 0

# Print the average WER
print(f"\nAverage WER: {average_wer:.2%}")
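
The loop above averages per-file WER, which weights a two-word clip the same as a ten-word one. A common alternative pools the errors over the whole test set before dividing. The sketch below shows both figures using a plain word-level edit distance; the function names and sample sentences are illustrative, and the code is independent of the `jiwer` and `Levenshtein` packages.

```python
def word_edit_distance(reference, hypothesis):
    """Minimum substitutions, insertions and deletions between two word sequences."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # match or substitution
        prev = curr
    return prev[-1]

def corpus_wer(pairs):
    """Total word errors divided by total reference words."""
    errors = sum(word_edit_distance(ref, hyp) for ref, hyp in pairs)
    words = sum(len(ref.split()) for ref, _ in pairs)
    return errors / words if words else 0.0

def mean_utterance_wer(pairs):
    """Unweighted mean of per-utterance WER (what the loop above reports)."""
    rates = [word_edit_distance(ref, hyp) / len(ref.split()) for ref, hyp in pairs]
    return sum(rates) / len(rates) if rates else 0.0

# Made-up (reference, hypothesis) pairs
pairs = [
    ("good morning", "good evening"),
    ("the quick brown fox jumps over the lazy dog today",
     "the quick brown fox jumps over the lazy dog today"),
]
print(f"Mean utterance WER: {mean_utterance_wer(pairs):.2%}")
print(f"Corpus WER: {corpus_wer(pairs):.2%}")
```

Here one wrong word in a two-word clip gives a 25% mean utterance WER but only about 8% corpus WER, so it is worth stating which of the two a reported number refers to.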

#### LM integration

Download and build KenLM

In [None]:
# Work from /content so the build matches the kenlm_path used in the training cell
%cd /content
!wget -O - https://kheafield.com/code/kenlm.tar.gz |tar xz
!mkdir kenlm/build
%cd kenlm/build
!cmake ..
!make -j2

Train the KenLM model

In [None]:
import os
from os.path import abspath

from tqdm import tqdm

kenlm_path = "/content/kenlm/"
transcript_file = "/content/gdrive/MyDrive/transcription.txt"
additional_file = "/content/gdrive/MyDrive/corpus.txt.txt"
ngram = 3
output_path = "/content/output"


if not os.path.exists(output_path):
    os.makedirs(output_path)

with open(transcript_file, encoding="utf-8") as f:
    train = f.read().upper().splitlines()
    # Drop the first field of each line (assumed here to be an utterance ID)
    train = [' '.join(d.split()[1:]) for d in train]


chars = [list(d.replace(' ','')) for d in train]
chars = [j for i in chars for j in i]
chars = set(chars)

if additional_file is not None:
    with open(additional_file, encoding="utf-8") as f:
        train += f.read().upper().splitlines()

vocabs = set([])
for line in tqdm(train):
    for word in line.split():
        vocabs.add(word)
vocabs = list(vocabs)
print(len(vocabs))
# Keep only words spelled entirely with characters seen in the transcripts
vocabs = [v for v in vocabs if all(c in chars for c in v)]
print(len(vocabs))

vocab_path = os.path.join(output_path,'vocabs.txt')
lexicon_path = os.path.join(output_path,'lexicon.txt')
train_text_path = os.path.join(output_path,'world_lm_data.train')
train_text_path_train = train_text_path.replace('world_lm_data.train','kenlm.train')
model_arpa = train_text_path.replace('world_lm_data.train','kenlm.arpa')
model_bin  = train_text_path.replace('world_lm_data.train','lm.bin')
kenlm_path_train = os.path.join(abspath(kenlm_path) , 'build/bin/lmplz')
kenlm_path_convert = os.path.join(abspath(kenlm_path) , 'build/bin/build_binary')
kenlm_path_query = os.path.join(abspath(kenlm_path) , 'build/bin/query')

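# Write with an explicit UTF-8 encoding to match how the inputs were read
with open(train_text_path, 'w', encoding="utf-8") as f:
    f.write('\n'.join(train))

with open(vocab_path, 'w', encoding="utf-8") as f:
    f.write(' '.join(vocabs))

# Lexicon entry: word, tab, space-separated letters, then the '|' word boundary
for i in range(0,len(vocabs)):
    vocabs[i] = vocabs[i] + '\t' + ' '.join(list(vocabs[i])) + ' |'

with open(lexicon_path, 'w', encoding="utf-8") as f:
    f.write('\n'.join(vocabs))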

cmd = kenlm_path_train + " -T /tmp -S 4G --discount_fallback -o " + str(ngram) +" --limit_vocab_file " + vocab_path + " trie < " + train_text_path +  ' > ' + model_arpa
os.system(cmd)
cmd = kenlm_path_convert +' trie ' + model_arpa + ' ' + model_bin
os.system(cmd)
cmd = kenlm_path_query + ' ' + model_bin + " < " + train_text_path + ' > ' + train_text_path_train
os.system(cmd)
os.remove(train_text_path)
os.remove(train_text_path_train)
os.remove(model_arpa)
os.remove(vocab_path)
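
Before decoding with the language model, it can help to confirm that every spelling in the generated lexicon uses only tokens the acoustic model knows. The sketch below checks lexicon lines in the `word<TAB>letters |` format written above against a token dictionary. The `token count` dictionary layout is an assumption about the `dict.ltr.txt` file in the manifest directory, both helper names are hypothetical, and the inline sample lines stand in for the real files.

```python
def load_dictionary_tokens(lines):
    """Return the token set from 'token count' lines (assumed dictionary layout)."""
    return {line.split()[0] for line in lines if line.strip()}

def find_bad_lexicon_entries(lexicon_lines, tokens):
    """Yield (word, unknown_tokens) for spellings that use tokens outside the dictionary."""
    for line in lexicon_lines:
        if not line.strip():
            continue
        word, spelling = line.rstrip("\n").split("\t", 1)
        unknown = [t for t in spelling.split() if t not in tokens]
        if unknown:
            yield word, unknown

# Made-up sample data; in practice read the lines from the lexicon file
# written above and from the dictionary next to the manifest files.
dict_lines = ["| 10", "A 7", "B 3"]
lexicon_lines = ["AB\tA B |", "ABC\tA B C |"]

tokens = load_dictionary_tokens(dict_lines)
bad_entries = list(find_bad_lexicon_entries(lexicon_lines, tokens))
for word, unknown in bad_entries:
    print(f"{word}: unknown tokens {unknown}")
```

Entries reported here are candidates for removal from the lexicon, since their spellings contain tokens the acoustic model has no output class for.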

In [None]:
!python3 /content/IndicWav2Vec/w2v_inference/infer/infer.py "/content/datasets/santali/manifest" --task audio_finetuning \
--nbest 1 --path "/content/gdrive/MyDrive/checkpoint_last.pt" --gen-subset test --results-path "/content/res/" --w2l-decoder kenlm \
--lexicon "/content/output/lexicon.txt" --kenlm-model "/content/output/lm.bin" --sil-weight 0 --criterion ctc --labels ltr --max-tokens 1000000 \
--post-process letter
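
Lexicon-based beam search decoders of this kind typically rank each hypothesis by its acoustic score plus a weighted language model score and a per-word bonus, which is what the `--lm-weight` and `--word-score` flags seen in the Viterbi command control. The sketch below illustrates that ranking with made-up scores. It is a simplification for intuition, not the Flashlight decoder, and `combined_score` and `best_hypothesis` are hypothetical helpers.

```python
def combined_score(am_score, lm_log_prob, num_words, lm_weight=0.0, word_score=0.0):
    """Acoustic score plus weighted LM log-probability plus a per-word bonus."""
    return am_score + lm_weight * lm_log_prob + word_score * num_words

def best_hypothesis(hypotheses, lm_weight, word_score):
    """Return the text of the highest-scoring hypothesis."""
    return max(
        hypotheses,
        key=lambda h: combined_score(
            h["am"], h["lm"], len(h["text"].split()), lm_weight, word_score
        ),
    )["text"]

# Made-up log-domain scores for two competing transcriptions
hypotheses = [
    {"text": "RECOGNIZE SPEECH", "am": -12.0, "lm": -4.0},
    {"text": "WRECK A NICE BEACH", "am": -11.5, "lm": -9.0},
]

print(best_hypothesis(hypotheses, lm_weight=0.0, word_score=0.0))  # acoustic score only
print(best_hypothesis(hypotheses, lm_weight=2.0, word_score=0.0))  # LM changes the ranking
```

Since the best values depend on the acoustic model and the language model, sweeping `--lm-weight` and `--word-score` on the validation set and keeping the pair with the lowest WER is a reasonable starting point.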