### Test ASR model in batch mode

* consider all the wav files contained in a directory
* two modes: 1. read a csv file with path and sentence columns (the sentence can be empty), 2. process all the wav files in a directory
* check that all the files have the required characteristics (sample_rate=16000, mono)
* apply the model and make a prediction for each file
* optionally, compare the predictions with the expected sentences and compute the WER
* apply a SpellChecker to the predictions
* compute the WER after spell checking

In [1]:
from datasets import load_dataset, Audio
from datasets import Dataset

# progress bar
from tqdm import tqdm

import random

import pandas as pd
import glob
import os
import soundfile as sf
import re

# for model inference
import json
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor

import torch

# used to compute WER
from jiwer import wer

from utils import check_sample_rate, init_empty_list, check_files_exists, check_gpu
from utils import check_mono

from spellchecker import SpellChecker

In [2]:
# --- inputs ---------------------------------------------------------------
# DIR_4_TEST is expected to contain:
#   * a csv file with the columns path,sentence (used only in CSV mode)
#   * the wav files, one for each row of the csv
#   * every wav file must be MONO with sample_rate = 16 kHz

# True  -> read the list of wav paths (and expected sentences) from the csv file
# False -> process every wav file found in DIR_4_TEST
CSV_MODE = True
# True -> save all the predictions in a csv file
CREATE_OUT_CSV = True

DIR_4_TEST = "/home/datascience/asr-atc/data2_4_test/"
CSV_FILE_NAME = "test.csv"
# DIR_4_TEST = "/home/datascience/asr-atc/data_4_test_train/"
# CSV_FILE_NAME = "atco2.csv"

# name of the output csv file with all the predictions (written inside DIR_4_TEST)
OUT_CSV = "predictions.csv"

# directory containing the files of the trained model
REPO_NAME = "wav2vec2-large-xls-r-300m-tr-ls"
# directory containing the tokenizer vocabulary
VOCAB_DIR = "./vocab_atco2"

# --- globals --------------------------------------------------------------
# sampling rate required for the audio files: 16 kHz
SAMPLE_RATE = 16_000

In [3]:
# This notebook is expected to run on a GPU: the model and its inputs are moved
# to "cuda" (to run on CPU, remove the .to("cuda") calls)
# check it (check_gpu comes from the local utils module)
check_gpu()

GPU is available, OK


#### Load the trained model

In [4]:
# Build the processor (feature extractor + tokenizer) and load the trained model.

# the tokenizer is rebuilt from the vocabulary directory
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
    VOCAB_DIR,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

# audio front end: 1-d input at SAMPLE_RATE, normalized, returning an attention mask
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=SAMPLE_RATE,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# single object wrapping both the feature extractor and the tokenizer
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

# the model is moved to the GPU: this notebook is expected to run with CUDA
model = Wav2Vec2ForCTC.from_pretrained(REPO_NAME).to("cuda")

print()
print("Model loaded !!!")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.



Model loaded !!!


In [5]:
#
# Functions
#
def prepare_dataset(batch):
    """Convert one dataset record into model inputs (used with ``Dataset.map``).

    Reads the decoded audio from ``batch["audio"]``. Stores the processed
    samples in ``batch["input_values"]`` and their count in
    ``batch["input_length"]``. Stores the token ids of ``batch["sentence"]``
    in ``batch["labels"]``. Returns the same, mutated, ``batch`` dict.
    """
    audio_info = batch["audio"]
    samples = audio_info["array"]
    rate = audio_info["sampling_rate"]

    # the processor returns a batch of one element: take it ("un-batch")
    features = processor(samples, sampling_rate=rate)
    batch["input_values"] = features.input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # encode the expected transcription as label ids (target-processor context)
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids

    return batch

def do_test(index):
    """Run the model on one record of ``ds_hf_test`` and print the result.

    Parameters
    ----------
    index : int
        Position of the record in ``ds_hf_test``. The same position is used to
        look up the file name in ``list_wav`` and, in CSV mode, the expected
        sentence in ``list_txts``.

    Returns
    -------
    str
        The decoded prediction.
    """
    # from input to prediction
    input_dict = processor(ds_hf_test[index]["input_values"], return_tensors="pt", padding=True, sampling_rate=SAMPLE_RATE)

    # Inference only: no_grad stops autograd from recording the forward pass
    # (no backward pass is ever run), which reduces GPU memory use.
    # It is expected to run on GPU (to(cuda)).
    with torch.no_grad():
        logits = model(input_dict.input_values.to("cuda")).logits

    pred_ids = torch.argmax(logits, dim=-1)[0]
    pred_text = processor.decode(pred_ids)

    print()
    print(f"Prediction on: {list_wav[index]}")
    print(pred_text)

    if CSV_MODE:
        print()
        print("Expected:")
        print(list_txts[index].lower())
    print()

    return pred_text

# NOTE: running the test in batch mode causes out-of-memory (OOM) problems,
# so the records are processed one at a time

In [6]:
# Read the list of wav files, either from the csv file or from all the wav
# files in DIR_4_TEST.
# os.path.join is used so that DIR_4_TEST works with or without a trailing "/".

if CSV_MODE:
    # there is a csv file that will guide
    CSV_FULL_NAME = os.path.join(DIR_4_TEST, CSV_FILE_NAME)

    df_test = pd.read_csv(CSV_FULL_NAME)

    # create the list of wav from DataFrame
    list_wav = df_test["path"].tolist()
    print(f"Read {len(list_wav)} wav paths from {CSV_FULL_NAME}")
else:
    # build the list of wav directly from the contents of directory
    print(f"Not using CSV file ...reading list of wav from directory {DIR_4_TEST}")

    list_wav = sorted(glob.glob(os.path.join(DIR_4_TEST, "*.wav")))

In [7]:
# check that all wav files listed in list_wav are available
# (check_files_exists comes from the local utils module)
check_files_exists(list_wav)


All required wav files are available!


In [8]:
# check that all files have SAMPLE_RATE, the rate the feature extractor is configured with
check_sample_rate(list_wav, ref_sample_rate=SAMPLE_RATE)

All wav files have sample rate = 16000.


In [9]:
# check that all files are MONO (a single audio channel)
check_mono(list_wav)

All wav files are MONO.


#### Create the HF dataset

In [10]:
#
# load all data in HF dataset
#
list_path_names = list_wav

if CSV_MODE:
    # Empty cells in the "sentence" column (allowed by the csv format) are read
    # by pandas as NaN, which is a float. Replace them with "" so the tokenizer
    # and .lower() always receive strings.
    list_txts = df_test["sentence"].fillna("").astype(str).tolist()
else:
    # no expected values available
    list_txts = init_empty_list(len(list_wav))

# create a dictionary: "audio" also holds the paths; the Audio() cast below
# turns that column into decoded audio records
dict_res = {"path": list_path_names, "audio": list_path_names, "sentence": list_txts}

# create the dataset in HF format
ds_hf_test = Dataset.from_dict(dict_res).cast_column("audio", Audio())

print("HF dataset created !")

HF dataset created !


In [11]:
# report how many records will be used for the test
print(f"We have {len(ds_hf_test)} records in the dataset to be used for test.")

We have 9 records in the dataset to be used for test.


In [12]:
# make a final check for compatibility with HF example
# (disabled: uncomment the lines below to inspect one random record)

# get a random index to select a random item from the dataset
# rand_int = random.randint(0, len(ds_hf_test)-1)

# print()
# print(f"Checking record n. {rand_int}")
# print("Input audio array shape:", ds_hf_test[rand_int]["audio"]["array"].shape)
# print("Sampling rate:", ds_hf_test[rand_int]["audio"]["sampling_rate"])

# if CSV_MODE == True:
#    print("Expected text:", ds_hf_test[rand_int]["sentence"])

In [13]:
# Prepare the dataset for inference: every record goes through prepare_dataset.
# The original columns (path, audio, sentence) are dropped, so only the
# columns added by prepare_dataset remain.

print("Preparing dataset for inference....")

original_columns = ds_hf_test.column_names
ds_hf_test = ds_hf_test.map(prepare_dataset, remove_columns=original_columns)

Preparing dataset for inference....


  0%|          | 0/9 [00:00<?, ?ex/s]

#### Ready: all the files are now packed in a HF dataset, ready to be used for the test
#### Do Test

In [14]:
%%time

# Run the model on every record, one at a time, because batched inference caused OOM.
# The predictions are always collected: the WER and spell-checking cells use
# list_predictions. CREATE_OUT_CSV only controls whether they are also saved
# to a csv file.
list_predictions = []

for INDEX in range(len(ds_hf_test)):
    str_pred = do_test(INDEX)
    list_predictions.append(str_pred)

if CREATE_OUT_CSV:
    # create the output csv file inside DIR_4_TEST
    out_dict = {"file": list_wav, "preds": list_predictions}

    out_df = pd.DataFrame(out_dict)

    out_df.to_csv(os.path.join(DIR_4_TEST, OUT_CSV), index=False)


Prediction on: /home/datascience/asr-atc/data2_4_test/luigi1.wav
alpha bata gma delta

Expected:
alfa beta gamma delta


Prediction on: /home/datascience/asr-atc/data2_4_test/luigi2.wav
euro wind seven alpha bravo turnd right heading two one zero cleared ils approach runway two four leport established

Expected:
eurowings seven alfa bravo turn right heading two one zero cleared ils approach runway two four report established


Prediction on: /home/datascience/asr-atc/data2_4_test/luigi3.wav
rya aunawr seven three halpho tol turn left heading three six zero

Expected:
ryanair seven three alpha hotel turn left heading three six zero


Prediction on: /home/datascience/asr-atc/data2_4_test/luigi4.wav
rya yai noawr seven three allpha hotoel

Expected:
ryanair seven three alpha hotel


Prediction on: /home/datascience/asr-atc/data2_4_test/luigi5.wav
oscar kilo kilo uniform november proceed direct lybiltu

Expected:
oscar kilo kilo uniform november proceed direct baltu


Prediction on: /home

#### Compute WER on all dataset

In [15]:
# Compute the WER of the raw model predictions against the expected sentences.
# It needs the expected sentences (CSV mode) and the collected predictions.

# everything to lower case before computing the WER
# (list_txts is rebound to its lower-case version)
list_txts = [sentence.lower() for sentence in list_txts]

can_compute_wer = (CSV_MODE == True) and (CREATE_OUT_CSV == True)
if can_compute_wer:
    v_wer = wer(list_txts, list_predictions)

    print()
    print(f"Computed WER is: {round(v_wer, 3)}")


Computed WER is: 0.362


### Adding Spell Checker

In [98]:
from tqdm import tqdm

In [131]:
%load_ext autoreload
%autoreload 2

#
# ATCSpellChecker: a SpellChecker that can also correct whole sentences (correct_text)
#
class ATCSpellChecker(SpellChecker):
    """SpellChecker whose dictionary is extended with the words of an ATC text.

    The word frequencies are enriched with the words read from a text file
    (by default ``./atco2.txt``), so that the "special" words of the domain
    are known to the checker. ``correct_text`` corrects an entire sentence.
    """

    def __init__(self, atc_text_file="./atco2.txt", **kwargs):
        """Create the checker.

        Parameters
        ----------
        atc_text_file : str
            Text file whose words are added to the dictionary
            (default ``./atco2.txt``, the file used before this was a parameter).
        **kwargs
            Forwarded unchanged to ``SpellChecker.__init__``.
        """
        super().__init__(**kwargs)
        # load the atco2 text to add the "special" words to the dictionary
        self.word_frequency.load_text_file(atc_text_file)

    def correct_text(self, text):
        """Return ``text`` with every word replaced by its correction.

        The text is tokenized with ``split_words``. A word for which
        ``correction`` returns None is kept unchanged. The resulting words
        are joined with single spaces.
        """
        SEP = " "
        l_text_corrected = []
        for word in self.split_words(text):
            # call correction() once per word (it used to be called twice)
            candidate = self.correction(word)
            l_text_corrected.append(candidate if candidate is not None else word)

        # rebuild the sentence and return
        return SEP.join(l_text_corrected)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [132]:
# instantiate the ATC spell checker (its constructor loads ./atco2.txt into the dictionary)
spell = ATCSpellChecker()

In [128]:
# spell-correct every raw prediction (with a progress bar), then show the results
list_corr_predictions = []
for text in tqdm(list_predictions):
    list_corr_predictions.append(spell.correct_text(text))

for corr_txt in list_corr_predictions:
    print(corr_txt)

100%|██████████| 9/9 [00:02<00:00,  4.01it/s]

alpha data ma delta
euro wind seven alpha bravo turn right heading two one zero cleared ils approach runway two four report established
ya lunar seven three alpha to turn left heading three six zero
ya yai now seven three alpha hotel
oscar kilo kilo uniform november proceed direct lybiltu
prosedend e react malt oscar kilo kilo uniform november
hotel delta lima runway vacate on delta
bluparking hotel delta lima
hotel bravo charlie lima hotel i go line up and wait runway one zero





In [129]:
# Compute the WER after spell correction. Expected sentences exist only in
# CSV mode; in directory mode list_txts holds placeholders, not references.
if CSV_MODE:
    v_wer = wer(list_txts, list_corr_predictions)

    print()
    print(f"Computed WER after spell corrections is: {round(v_wer, 3)}")
else:
    print("WER not computed: no expected sentences (CSV_MODE is False).")


Computed WER after spell corrections is: 0.275


In [120]:
# let's see if we identify unknown words
for corr_txt in list_corr_predictions:
    l_text = spell.split_words(corr_txt)
    
    unk = spell.unknown(l_text)
    if len(unk) > 0:
        print(unk)

{'lybiltu'}
{'prosedend'}
{'bluparking'}


In [133]:
# Is "baltu" (the expected word in luigi5.wav, predicted as "lybiltu") in the
# dictionary? If it is known but was not chosen by correction(), the likely
# reason is that "lybiltu" is more than 2 edits away (pyspellchecker's default
# distance). TODO confirm.
spell.known(["baltu"])

{'baltu'}