In [124]:
import kenlm
import pandas as pd
import numpy as np
from pyctcdecode import build_ctcdecoder
import soundfile as sf
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import os
import pickle
from datasets import Dataset, DatasetDict, load_metric
from tqdm import tqdm

# Pick the GPU only when CUDA is really usable. `torch.cuda.is_available` is a
# function and has to be *called*: the bare function object is always truthy,
# which made this expression evaluate to 'cuda' on every machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [33]:
# greedy (best-path) CTC decoding
def greedy_decode(logits, labels):
    """Pick the best-scoring label of every frame and merge consecutive repeats.

    Args:
        logits: array whose ``argmax(axis=1)`` yields one label index per frame.
        labels: iterable of label strings; position ``i`` is the text for index ``i``.

    Returns:
        The merged label strings joined into one string. Indices that have no
        entry in ``labels`` contribute an empty string (treated as the CTC blank).
        Blank symbols that are part of ``labels`` are kept, so callers strip them.
    """
    index_to_symbol = dict(enumerate(labels))
    symbols = [index_to_symbol.get(idx, "") for idx in logits.argmax(axis=1)]
    # pair every frame's symbol with the one from the frame before it
    previous = [None] + symbols[:-1]
    return "".join(cur for prev, cur in zip(previous, symbols) if cur != prev)

In [34]:
# path to the gzipped ARPA language-model file that is handed to the CTC decoder
KENLM_MODEL_LOC = 'lm/4gram_big.arpa.gz'

In [35]:
# restore the processor (audio preprocessing + tokenizer) and the fine-tuned
# CTC model from their local directories
asr_processor = Wav2Vec2Processor.from_pretrained('./processor/')
asr_model = Wav2Vec2ForCTC.from_pretrained('./saved_model/')

In [36]:
# show the tokenizer's token -> id mapping
print("Vocab: ", asr_processor.tokenizer.get_vocab())

Vocab:  {'B': 0, 'D': 1, 'S': 2, 'C': 3, 'K': 4, 'N': 5, 'Y': 6, 'W': 7, 'U': 8, 'A': 9, 'P': 10, 'Z': 11, 'X': 12, 'J': 13, 'G': 14, "'": 15, 'I': 16, 'L': 17, 'V': 18, 'O': 19, 'F': 20, 'T': 21, 'Q': 22, 'R': 23, 'E': 24, 'H': 25, 'M': 26, '|': 27, '[UNK]': 28, '[PAD]': 29, '<s>': 30, '</s>': 31}


In [37]:
# Build the label list ordered by token id, so that position i is the label for
# logit column i. Taking the dict keys as-is only works when the vocabulary
# happens to be stored in id order; sorting makes the alignment explicit.
token_to_id = asr_processor.tokenizer.get_vocab()
vocab = sorted(token_to_id, key=token_to_id.get)
print(vocab)

['B', 'D', 'S', 'C', 'K', 'N', 'Y', 'W', 'U', 'A', 'P', 'Z', 'X', 'J', 'G', "'", 'I', 'L', 'V', 'O', 'F', 'T', 'Q', 'R', 'E', 'H', 'M', '|', '[UNK]', '[PAD]', '<s>', '</s>']


In [38]:
# Rename the special tokens to the symbols used downstream: '[PAD]' (CTC blank)
# becomes '_' and the word delimiter '|' becomes a real space.
# A mapping is used instead of in-place `vocab.index(...)` replacement so the
# cell can be re-run safely (the old form raised ValueError on a second run
# because '[PAD]' was already gone).
token_aliases = {'[PAD]': '_', '|': ' '}
vocab = [token_aliases.get(token, token) for token in vocab]
# still fail loudly when the expected special tokens are absent
assert '_' in vocab and ' ' in vocab, "blank or word-delimiter token missing from vocab"

In [40]:
# check the label list after the '[PAD]' -> '_' and '|' -> ' ' substitutions
print(vocab)

['B', 'D', 'S', 'C', 'K', 'N', 'Y', 'W', 'U', 'A', 'P', 'Z', 'X', 'J', 'G', "'", 'I', 'L', 'V', 'O', 'F', 'T', 'Q', 'R', 'E', 'H', 'M', ' ', '[UNK]', '_', '<s>', '</s>']


In [41]:
# create the pyctcdecode decoder from the label list and the language-model file
decoder = build_ctcdecoder(
    vocab,
    KENLM_MODEL_LOC,
    alpha=0.6,  # tuned on a val set
    beta=2.0,  # tuned on a val set
)

Loading the LM will be faster if you build a binary file.
Reading /w2v2_kenlm_pipeline/lm/4gram_big.arpa.gz
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
Unigrams not provided and cannot be automatically determined from LM file (only arpa format). Decoding accuracy might be reduced.
****************************************************************************************************
Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
No known unigrams provided, decoding results might be a lot worse.


In [99]:
# single test utterance (FLAC file) used for the one-off decoding check below
audio_path = './datasets/magister_data_flac_16000/test/11039/2614000/11039-2614000-0053.flac'

In [100]:
# get logits
# Keep the file's real sample rate and hand it to the processor instead of
# hard-coding 16000: a file recorded at another rate is then reported by the
# processor rather than being decoded silently at the wrong rate.
arr, sample_rate = sf.read(audio_path)
input_values = asr_processor(arr, return_tensors="pt", sampling_rate=sample_rate).input_values  # Batch size 1
# inference only: skip autograd bookkeeping for the forward pass
with torch.no_grad():
    logits = asr_model(input_values).logits.cpu().numpy()[0]

In [106]:
# peek at the first 50 samples of the waveform read from disk
arr[:50]

array([-2.44140625e-04, -1.86157227e-03, -3.47900391e-03, -3.32641602e-03,
       -3.14331055e-03, -2.22778320e-03, -1.28173828e-03, -9.15527344e-05,
        1.09863281e-03,  1.64794922e-03,  2.22778320e-03,  3.66210938e-04,
       -1.49536133e-03, -4.66918945e-03, -7.84301758e-03, -9.76562500e-03,
       -1.16882324e-02, -1.10168457e-02, -1.03454590e-02, -7.26318359e-03,
       -4.15039062e-03, -1.19018555e-03,  1.80053711e-03,  3.20434570e-03,
        4.63867188e-03,  6.01196289e-03,  7.41577148e-03,  6.95800781e-03,
        6.50024414e-03,  6.43920898e-03,  6.37817383e-03,  8.33129883e-03,
        1.02844238e-02,  1.14135742e-02,  1.25732422e-02,  1.25732422e-02,
        1.25732422e-02,  9.33837891e-03,  6.10351562e-03,  1.67846680e-03,
       -2.71606445e-03, -5.34057617e-03, -7.96508789e-03, -9.27734375e-03,
       -1.05895996e-02, -1.06201172e-02, -1.06201172e-02, -9.94873047e-03,
       -9.24682617e-03, -9.61303711e-03])

In [84]:
# read the pickled test-set frame from disk
# NOTE: unpickling executes code stored in the file; only load pickles you trust.
test_pickle_path = './pkl/magister_data_flac_16000_test.pkl'
with open(test_pickle_path, 'rb') as pickle_file:
    df_test = pickle.load(pickle_file)

In [88]:
# wrap the unpickled frame in a Hugging Face Dataset and display its summary
df_test_data = Dataset.from_pandas(df_test)
df_test_data

Dataset({
    features: ['file', 'audio', 'text'],
    num_rows: 334
})

In [97]:
# file path stored with the first test sample
print(df_test_data[0]['audio']['path'])

./datasets/magister_data_flac_16000/test/11039/2614000/11039-2614000-0053.flac


In [98]:
# first 40 waveform samples stored with the first test sample
print(df_test_data[0]['audio']['array'][:40])

[-0.000244140625, -0.001861572265625, -0.00347900390625, -0.003326416015625, -0.003143310546875, -0.002227783203125, -0.00128173828125, -9.1552734375e-05, 0.0010986328125, 0.00164794921875, 0.002227783203125, 0.0003662109375, -0.001495361328125, -0.004669189453125, -0.007843017578125, -0.009765625, -0.011688232421875, -0.011016845703125, -0.010345458984375, -0.00726318359375, -0.004150390625, -0.001190185546875, 0.001800537109375, 0.003204345703125, 0.004638671875, 0.006011962890625, 0.007415771484375, 0.0069580078125, 0.006500244140625, 0.006439208984375, 0.006378173828125, 0.008331298828125, 0.010284423828125, 0.01141357421875, 0.0125732421875, 0.0125732421875, 0.0125732421875, 0.00933837890625, 0.006103515625, 0.001678466796875]


In [112]:
# the Dataset hands the samples back as a plain Python list, so wrap them in a
# numpy array again before passing them to the processor
np.array(df_test_data[0]['audio']['array'])

array([-2.44140625e-04, -1.86157227e-03, -3.47900391e-03, ...,
       -2.13623047e-04, -6.10351562e-05,  1.22070312e-04])

In [102]:
# reference (ground-truth) transcript of the first test sample
print(df_test_data[0]['text'])

OOW INCREASE ONE THREE


In [65]:
# inspect the raw model outputs (one row per frame, one column per vocab entry)
logits

array([[ 1.135748  ,  1.7107844 ,  1.3733399 , ..., 14.112412  ,
        -1.7596617 , -1.1428883 ],
       [ 1.1059318 ,  1.5879831 ,  1.2237644 , ..., 14.059038  ,
        -1.8348356 , -1.3138465 ],
       [ 0.21142852,  1.9927497 ,  1.7712679 , ..., 13.7063265 ,
        -1.67544   , -1.1022942 ],
       ...,
       [-0.4605291 ,  2.3351345 ,  3.1894143 , ..., 12.909706  ,
        -1.1828046 , -1.0336481 ],
       [-0.30894375,  2.5699425 ,  2.9037273 , ..., 14.281372  ,
        -1.0672334 , -0.8433927 ],
       [ 0.21268058,  1.9065666 ,  1.7108347 , ..., 13.669061  ,
        -1.7176361 , -1.1836007 ]], dtype=float32)

In [66]:
# decode the logits with the language-model-backed CTC decoder
text = decoder.decode(logits)

In [67]:
# best-path decoding, then drop the "_" blank placeholders from the result
greedy_text = greedy_decode(logits, vocab).replace("_", "")

In [68]:
# show both transcripts, each followed by blank lines
print(f"Greedy Decoding: \n{greedy_text}")
print("\n")
print(f"Language Model Decoding: \n{text}")
print("\n")
# print("Ground truth \n" + true_text)
# print("\n")

Greedy Decoding: 
ALL STATION THIS IS PWO STAND BY FOR SITREP ONE EXTERNAL FOLLOW BYTE GUNERY BROADCAST


Language Model Decoding: 
ALL STATION THIS IS STANDBY FOR SITREPONEEXTERNAL FOLLOBYGUNERY BROADCAST




## The code to decode all the test audio and obtain the WER

In [48]:
# greedy (best-path) CTC decoding
def greedy_decode(logits, labels):
    """Return the per-frame argmax labels with consecutive duplicates merged.

    Args:
        logits: array whose ``argmax(axis=1)`` yields one label index per frame.
        labels: iterable of label strings; position ``i`` is the text for index ``i``.

    Returns:
        The merged label strings joined into one string. An index without an
        entry in ``labels`` is emitted as an empty string (treated as CTC blank).
    """
    lookup = dict(enumerate(labels))
    pieces = []
    last_symbol = None
    for frame_best in logits.argmax(axis=1):
        symbol = lookup.get(frame_best, "")
        if symbol == last_symbol:
            # same symbol as the previous frame: part of the same emission
            continue
        pieces.append(symbol)
        last_symbol = symbol
    return "".join(pieces)

In [49]:
# path to the gzipped ARPA language-model file that is handed to the CTC decoder
KENLM_MODEL_LOC = 'lm/4gram_big.arpa.gz'

In [50]:
# restore the processor (audio preprocessing + tokenizer) and the fine-tuned
# CTC model from their local directories
asr_processor = Wav2Vec2Processor.from_pretrained('./processor/')
asr_model = Wav2Vec2ForCTC.from_pretrained('./saved_model/')

In [73]:
# fetch the token -> id mapping once and show it
token_to_id = asr_processor.tokenizer.get_vocab()
print(f'Vocab: {token_to_id}')
print()
# Order the labels by token id so that position i is the label for logit
# column i; plain dict-key order is only correct when the vocabulary happens
# to be stored in id order.
vocab = sorted(token_to_id, key=token_to_id.get)
print(f'Vocab List: {vocab}')

Vocab: {'B': 0, 'D': 1, 'S': 2, 'C': 3, 'K': 4, 'N': 5, 'Y': 6, 'W': 7, 'U': 8, 'A': 9, 'P': 10, 'Z': 11, 'X': 12, 'J': 13, 'G': 14, "'": 15, 'I': 16, 'L': 17, 'V': 18, 'O': 19, 'F': 20, 'T': 21, 'Q': 22, 'R': 23, 'E': 24, 'H': 25, 'M': 26, '|': 27, '[UNK]': 28, '[PAD]': 29, '<s>': 30, '</s>': 31}

Vocab List: ['B', 'D', 'S', 'C', 'K', 'N', 'Y', 'W', 'U', 'A', 'P', 'Z', 'X', 'J', 'G', "'", 'I', 'L', 'V', 'O', 'F', 'T', 'Q', 'R', 'E', 'H', 'M', '|', '[UNK]', '[PAD]', '<s>', '</s>']


In [74]:
# Rename the special tokens to the symbols used downstream: '[PAD]' (CTC blank)
# becomes '_' and the word delimiter '|' becomes a real space.
# A mapping is used instead of in-place `vocab.index(...)` replacement so the
# cell can be re-run safely (the old form raised ValueError on a second run
# because '[PAD]' was already gone).
token_aliases = {'[PAD]': '_', '|': ' '}
vocab = [token_aliases.get(token, token) for token in vocab]
# still fail loudly when the expected special tokens are absent
assert '_' in vocab and ' ' in vocab, "blank or word-delimiter token missing from vocab"
print(vocab)

['B', 'D', 'S', 'C', 'K', 'N', 'Y', 'W', 'U', 'A', 'P', 'Z', 'X', 'J', 'G', "'", 'I', 'L', 'V', 'O', 'F', 'T', 'Q', 'R', 'E', 'H', 'M', ' ', '[UNK]', '_', '<s>', '</s>']


In [76]:
# create the pyctcdecode decoder from the label list and the language-model file
decoder = build_ctcdecoder(
    vocab,
    KENLM_MODEL_LOC,
    alpha=0.6,  # tuned on a val set
    beta=2.0,  # tuned on a val set
)

Loading the LM will be faster if you build a binary file.
Reading /w2v2_kenlm_pipeline/lm/4gram_big.arpa.gz
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Unigrams not provided and cannot be automatically determined from LM file (only arpa format). Decoding accuracy might be reduced.
Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?
No known unigrams provided, decoding results might be a lot worse.


In [113]:
# read the pickled test-set frame used for evaluation
# NOTE: unpickling executes code stored in the file; only load pickles you trust.
test_pickle_path = './pkl/magister_data_flac_16000_test.pkl'
with open(test_pickle_path, 'rb') as pickle_file:
    df_test = pickle.load(pickle_file)

# wrap the frame in a Hugging Face Dataset and display its summary
df_test_data = Dataset.from_pandas(df_test)
df_test_data

Dataset({
    features: ['file', 'audio', 'text'],
    num_rows: 334
})

In [114]:
# file path stored with the first test sample
print(df_test_data[0]['audio']['path'])

./datasets/magister_data_flac_16000/test/11039/2614000/11039-2614000-0053.flac


In [115]:
# the samples are wrapped in a numpy array again here before they are passed
# on to the processor
np.array(df_test_data[0]['audio']['array'])

array([-2.44140625e-04, -1.86157227e-03, -3.47900391e-03, ...,
       -2.13623047e-04, -6.10351562e-05,  1.22070312e-04])

In [116]:
# reference (ground-truth) transcript of the first test sample
print(df_test_data[0]['text'])

OOW INCREASE ONE THREE


In [117]:
# fetch the first test sample once; every index into the Dataset rebuilds the row
sample = df_test_data[0]

# get logits
audio_array = np.array(sample['audio']['array'])
# Feed the array built on the line above. The earlier version passed `arr`, a
# variable left over from a previous section, so this cell only worked (and
# only by coincidence decoded the right file) when that older cell had been run.
input_values = asr_processor(audio_array, return_tensors="pt", sampling_rate=16000).input_values  # Batch size 1
# inference only: skip autograd bookkeeping for the forward pass
with torch.no_grad():
    logits = asr_model(input_values).logits.cpu().numpy()[0]

# beam search decoding
text = decoder.decode(logits)

# greedy search decoding, with the "_" blank placeholders removed
greedy_text = greedy_decode(logits, vocab).replace("_", "")

# ground truth
ground_truth_text = sample['text']

In [118]:
# print the three transcripts separated by blank lines
print(
    f'Beam Search: {text}',
    f'Greedy Search: {greedy_text}',
    f'Ground Truth: {ground_truth_text}',
    sep='\n\n',
)

Beam Search: OW INCREASE ONE THREE

Greedy Search: OW INCREASE ONE THREE

Ground Truth: OOW INCREASE ONE THREE


## Actual code: decode the full test set and compare greedy vs. beam-search WER

In [None]:
# greedy (best-path) CTC decoding
def greedy_decode(logits, labels):
    """Pick the best-scoring label of every frame and merge consecutive repeats.

    Args:
        logits: array whose ``argmax(axis=1)`` yields one label index per frame.
        labels: iterable of label strings; position ``i`` is the text for index ``i``.

    Returns:
        The merged label strings joined into one string. Indices that have no
        entry in ``labels`` contribute an empty string (treated as the CTC blank).
        Blank symbols that are part of ``labels`` are kept, so callers strip them.
    """
    index_to_symbol = dict(enumerate(labels))
    symbols = [index_to_symbol.get(idx, "") for idx in logits.argmax(axis=1)]
    # pair every frame's symbol with the one from the frame before it
    previous = [None] + symbols[:-1]
    return "".join(cur for prev, cur in zip(previous, symbols) if cur != prev)

In [None]:
# load the finetuned model and the processor
asr_model = Wav2Vec2ForCTC.from_pretrained('./saved_model/')
asr_processor = Wav2Vec2Processor.from_pretrained('./processor/')

# location of the KenLM language model file (it is loaded by build_ctcdecoder below)
KENLM_MODEL_LOC = 'lm/4gram_big.arpa.gz'

# Build the label list ordered by token id, so that position i is the label for
# logit column i; plain dict-key order is only correct when the vocabulary
# happens to be stored in id order.
token_to_id = asr_processor.tokenizer.get_vocab()
vocab = sorted(token_to_id, key=token_to_id.get)

# rename the special tokens: '[PAD]' (CTC blank) -> '_', word delimiter '|' -> ' '
vocab[vocab.index('[PAD]')] = '_'
vocab[vocab.index('|')] = ' '

# build the decoder
decoder = build_ctcdecoder(
    labels = vocab,
    kenlm_model_path = KENLM_MODEL_LOC,
    alpha=0.6,  # tuned on a val set
    beta=2.0,  # tuned on a val set
)

In [121]:
# read the pickled test-set frame used for evaluation
# NOTE: unpickling executes code stored in the file; only load pickles you trust.
test_pickle_path = './pkl/magister_data_flac_16000_test.pkl'
with open(test_pickle_path, 'rb') as pickle_file:
    df_test = pickle.load(pickle_file)

# wrap the frame in a Hugging Face Dataset and display its summary
data_test = Dataset.from_pandas(df_test)
data_test

Dataset({
    features: ['file', 'audio', 'text'],
    num_rows: 334
})

In [123]:
# Run both decoders over every test utterance and collect the transcripts.
ground_truth_list = []
pred_beam_search_list = []
pred_greedy_search_list = []

# Read each row from the loop variable of `data_test`. The earlier version
# looped over `data_test` but indexed `df_test_data`, a variable left over from
# a previous section, so the cell failed on a fresh kernel and also rebuilt
# every row twice. Passing the Dataset itself to tqdm lets it pick up the
# length and show a real progress bar.
for entry in tqdm(data_test):
    # get logits
    audio_array = np.array(entry['audio']['array'])
    input_values = asr_processor(audio_array, return_tensors="pt", sampling_rate=16000).input_values  # Batch size 1
    # inference only: skip autograd bookkeeping for the forward pass
    with torch.no_grad():
        logits = asr_model(input_values).logits.cpu().numpy()[0]

    # beam search decoding
    beam_text = decoder.decode(logits)

    # greedy search decoding, with the "_" blank placeholders removed
    greedy_text = greedy_decode(logits, vocab).replace("_", "")

    # ground truth
    ground_truth_text = entry['text']

    # appending the data to the individual lists
    ground_truth_list.append(ground_truth_text)
    pred_beam_search_list.append(beam_text)
    pred_greedy_search_list.append(greedy_text)

334it [03:31,  1.58it/s]


In [125]:
# define evaluation metric: word error rate (WER)
# NOTE(review): `load_metric` is deprecated in recent `datasets` releases in
# favour of `evaluate.load("wer")` — confirm against the installed version.
wer_metric = load_metric("wer")

Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

In [129]:
# WER of the greedy transcripts against the references
greedy_wer = wer_metric.compute(predictions=pred_greedy_search_list, references=ground_truth_list)
print(f"WER (greedy search): {greedy_wer:.5f}")
print()
# WER of the language-model (beam search) transcripts against the references
beam_wer = wer_metric.compute(predictions=pred_beam_search_list, references=ground_truth_list)
print(f"WER (beam search): {beam_wer:.5f}")

WER (greedy search): 0.26360

WER (beam search): 0.58328
