# Inference

In [1]:
# Session environment: locale, cache directories and CUDA settings.
# NOTE(review): presumably this cell has to run before torch/transformers/datasets
# are imported so that those libraries pick the values up — confirm.
%env LC_ALL=C.UTF-8
%env LANG=C.UTF-8
# Both cache variables point at a ".cache" directory relative to the working directory.
%env TRANSFORMERS_CACHE=.cache
%env HF_DATASETS_CACHE=.cache
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 looks like a leftover debugging aid — confirm it is still wanted.
%env CUDA_LAUNCH_BLOCKING=1
# NOTE(review): GPU index 2 is specific to the original host; adjust for the machine this runs on.
%env CUDA_VISIBLE_DEVICES=2

env: LC_ALL=C.UTF-8
env: LANG=C.UTF-8
env: TRANSFORMERS_CACHE=.cache
env: HF_DATASETS_CACHE=.cache
env: CUDA_LAUNCH_BLOCKING=1
env: CUDA_VISIBLE_DEVICES=2


In [2]:
!pip3 install https://github.com/kpu/kenlm/archive/master.zip
!pip3 install git+https://github.com/hbasafa/py-ctc-decode.git

Collecting https://github.com/kpu/kenlm/archive/master.zip
  Using cached https://github.com/kpu/kenlm/archive/master.zip (541 kB)
You should consider upgrading via the '/home/gpu/services/wav2vec2/wav2vec2-finetune/env/bin/python -m pip install --upgrade pip' command.[0m
Collecting git+https://github.com/hbasafa/py-ctc-decode.git
  Cloning https://github.com/hbasafa/py-ctc-decode.git to /tmp/pip-req-build-hgj90jas
  Running command git clone -q https://github.com/hbasafa/py-ctc-decode.git /tmp/pip-req-build-hgj90jas
  Resolved https://github.com/hbasafa/py-ctc-decode.git to commit 4af8aa29487c658a6746c2a01d71da74e93a3aa3
You should consider upgrading via the '/home/gpu/services/wav2vec2/wav2vec2-finetune/env/bin/python -m pip install --upgrade pip' command.[0m


In [2]:
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
from datasets import load_dataset, load_metric

import ctcdecode
import torchaudio
import librosa
import numpy as np
import pandas as pd
import torch
# torch.multiprocessing.set_start_method('spawn')

In [3]:

# --- Configuration: edit these values before running the notebook ---

# Directory (or checkpoint) the fine-tuned model, processor and tokenizer are loaded from.
final_path = "/path/to/model_or_checkpoint"

# KenLM language model handed to the word-level scorer during beam search.
lm_path = '/path/to/lm.gz'

# CSV file the per-utterance references and predictions are written to.
results_path = "test_results.csv"

# Tab-separated test manifest; it is read with `path` and `sentence` columns.
test_path = 'path/to/dataset/test.tsv'

# Prefer the GPU, but fall back to the CPU so that `.to(device)` does not
# fail on machines without a usable CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Sampling rate (Hz) every clip is resampled to before feature extraction.
target_sampling_rate = 16_000

In [None]:
!gunzip /path/to/lm.gz

In [4]:
# Restore the processor and tokenizer stored under final_path, then the CTC
# model itself, which is moved onto the configured device.
processor = Wav2Vec2Processor.from_pretrained(final_path)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(final_path)

model = Wav2Vec2ForCTC.from_pretrained(final_path)
model = model.to(device)

In [5]:
# Read the tab-separated manifest through the generic CSV loader and keep
# the single "test" split it is registered under.
dataset_splits = load_dataset(
    "csv",
    data_files={"test": test_path},
    delimiter="\t",
)
test_data = dataset_splits["test"]
print(test_data)

Using custom data configuration default-7600145e3f20b17b
Reusing dataset csv (.cache/csv/default-7600145e3f20b17b/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Dataset({
    features: ['path', 'sentence'],
    num_rows: 994
})


In [6]:
def speech_file_to_array_fn(batch):
    """Load one audio file and attach its resampled waveform and metadata.

    Reads the file at ``batch["path"]``, down-mixes it to a single channel,
    resamples it to ``target_sampling_rate`` and stores the result on the
    example.

    Args:
        batch: A single dataset example providing ``path`` and ``sentence``.

    Returns:
        The same example with ``speech``, ``sampling_rate``,
        ``duration_in_seconds`` and ``target_text`` added.
    """
    speech_tensor, sampling_rate = torchaudio.load(batch["path"])

    # torchaudio yields a (channels, frames) tensor. Averaging over the
    # channel axis gives a 1-D signal for mono and multi-channel files alike;
    # squeeze() only did so for mono input and left stereo clips 2-D, which
    # made len() below return the channel count instead of the sample count.
    if speech_tensor.dim() > 1:
        speech_tensor = speech_tensor.mean(dim=0)
    speech_array = speech_tensor.numpy()

    # Pass the rates by keyword: recent librosa releases made orig_sr and
    # target_sr keyword-only, and older releases accept the same names.
    if sampling_rate != target_sampling_rate:
        speech_array = librosa.resample(
            speech_array,
            orig_sr=sampling_rate,
            target_sr=target_sampling_rate,
        )

    batch["speech"] = speech_array
    batch["sampling_rate"] = target_sampling_rate
    batch["duration_in_seconds"] = len(speech_array) / target_sampling_rate
    batch["target_text"] = batch["sentence"]
    return batch

def prepare_dataset(batch):
    """Convert a batch of waveforms and transcripts into model inputs and labels.

    Args:
        batch: Batched examples with ``speech``, ``sampling_rate`` and
            ``target_text`` lists.

    Returns:
        The batch with ``input_values`` (from the audio) and ``labels``
        (from the transcripts) added.

    Raises:
        AssertionError: If the batch does not contain exactly one distinct
            sampling rate.
    """
    # Only the first sampling rate is forwarded below, so a batch must not mix rates.
    distinct_rates = set(batch["sampling_rate"])
    assert (
        len(distinct_rates) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch_rate = batch["sampling_rate"][0]
    audio_features = processor(batch["speech"], sampling_rate=batch_rate)
    batch["input_values"] = audio_features.input_values

    # Within this context the processor is called on the transcripts and
    # its input_ids become the labels.
    with processor.as_target_processor():
        encoded_text = processor(batch["target_text"])
        batch["labels"] = encoded_text.input_ids
    return batch

In [None]:
# Stage 1: load and resample every clip, one example at a time, replacing
# the original manifest columns with the derived ones.
_test_data = test_data.map(
    speech_file_to_array_fn,
    num_proc=64,
    remove_columns=test_data.column_names,
)
# Stage 2: turn waveforms and transcripts into model inputs and labels,
# in batches of 16.
_test_data = _test_data.map(
    prepare_dataset,
    batched=True,
    batch_size=16,
    num_proc=64,
    remove_columns=_test_data.column_names,
)

In [8]:
# Index of the test utterance to transcribe as a quick sanity check.
sample_id = 0

# Fetch just the selected row rather than materialising the whole
# input_values column and then indexing into it.
sample_values = _test_data[sample_id]["input_values"]
input_dict = processor(sample_values, return_tensors="pt", padding=True, sampling_rate=target_sampling_rate)

with torch.no_grad():
    logits = model(input_dict.input_values.to(device)).logits

# The batch holds exactly one utterance, so its prediction always sits at
# index 0. Indexing with sample_id only worked while sample_id was 0 and
# raised an IndexError for any other sample.
pred_ids = torch.argmax(logits, dim=-1)[0]

In [9]:
# Show the greedy (argmax) transcription next to the reference sentence.
greedy_text = processor.decode(pred_ids)
print("Prediction:", greedy_text, sep="\n")

reference_text = test_data["sentence"][sample_id].lower()
print("\nReference:", reference_text, sep="\n")


Prediction:
نه گفتم من رفتم گفتن که این هم دیگه کاری احتیاج نیست انجام بدید چازدی گه شماتو

Reference:
نه گفتم من رفتم گفتم که این هم دیگه کاری احتیاج نیست انجام بدید چون دیگه شما دو 


In [10]:
# Build the decoder vocabulary as a list ordered by token id.
vocab_dict = tokenizer.get_vocab()
sort_vocab = sorted(zip(vocab_dict.values(), vocab_dict.keys()))

# Special tokens are all written as "_"; in every other token "|" becomes a space.
special_tokens = tokenizer.all_special_tokens
vocab = [
    "_" if token in special_tokens else token.replace("|", " ")
    for _token_id, token in sort_vocab
]
print(vocab)

['_', '_', '_', '_', ' ', 'آ', 'ئ', 'ا', 'ب', 'ت', 'ث', 'ج', 'ح', 'خ', 'د', 'ذ', 'ر', 'ز', 'س', 'ش', 'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ف', 'ق', 'ل', 'م', 'ن', 'ه', 'و', 'پ', 'چ', 'ژ', 'ک', 'گ', 'ی']


In [12]:
vocabulary = vocab
alpha = 0.5  # LM weight
beta = 1.0  # LM usage reward

# Word-level KenLM scorer; point lm_path at your own KenLM model.
word_lm_scorer = ctcdecode.WordKenLMScorer(lm_path, alpha, beta)

# Beam search over the acoustic outputs, rescored by the language model.
decoder = ctcdecode.BeamSearchDecoder(
    vocabulary,
    scorers=[word_lm_scorer],
    beam_width=128,
    cutoff_top_n=100,
    cutoff_prob=np.log(1e-7),
    num_workers=64,
)

# Decode the single-utterance logits computed in the earlier inference cell.
# NOTE(review): cutoff_prob is given in log-probability space while the
# decoder is fed raw logits — confirm whether it expects log_softmax output.
text = decoder.decode_batch(logits.cpu().numpy())

found 1gram
found 2gram


Loading the LM will be faster if you build a binary file.
Reading /home/gpu/services/wav2vec2/wav2vec2-finetune/lms/cst.v1
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
100%|█████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.38s/it]


In [13]:
print(text)

['نه گفتم من رفتم گفتن که این هم دیگه کاری احتیاج نیست انجام بدید چز دیگه شما تو']


# Evaluation

In [14]:


def prepare_results(batch):
    """Transcribe a batch of pre-processed utterances with the beam-search decoder.

    Args:
        batch: Batched examples holding one ``input_values`` entry per
            utterance.

    Returns:
        The batch with an ``asr_text`` entry holding the decoder output for
        the batch.
    """
    input_dict = processor(batch["input_values"], return_tensors="pt", padding=True, sampling_rate=target_sampling_rate)

    model_inputs = {"input_values": input_dict.input_values.to(device)}
    # Utterances are padded to a common length. When the processor emits an
    # attention mask, forward it so the model can tell padding from audio;
    # when it does not, the model is called exactly as before.
    if "attention_mask" in input_dict:
        model_inputs["attention_mask"] = input_dict["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(**model_inputs).logits

    batch['asr_text'] = decoder.decode_batch(logits.cpu().numpy())
    return batch


In [15]:
# Decode the whole test set in batches of 64 using a single map process;
# only the column produced by prepare_results is kept.
results = _test_data.map(
    prepare_results,
    batched=True,
    batch_size=64,
    num_proc=1,
    remove_columns=_test_data.column_names,
)


  0%|          | 0/16 [00:00<?, ?ba/s]


  0%|                                                                        | 0/64 [00:00<?, ?it/s][A
  2%|█                                                               | 1/64 [00:06<07:17,  6.94s/it][A
  3%|██                                                              | 2/64 [00:07<03:01,  2.93s/it][A
  6%|████                                                            | 4/64 [00:07<01:13,  1.23s/it][A
  9%|██████                                                          | 6/64 [00:07<00:41,  1.40it/s][A
 11%|███████                                                         | 7/64 [00:07<00:32,  1.77it/s][A
 16%|█████████▊                                                     | 10/64 [00:08<00:15,  3.40it/s][A
 19%|███████████▊                                                   | 12/64 [00:08<00:12,  4.12it/s][A
 22%|█████████████▊                                                 | 14/64 [00:08<00:09,  5.52it/s][A
 27%|████████████████▋                                         

In [None]:
# Line up each clip's path and reference sentence with its decoded transcript.
test_results = pd.DataFrame(
    {
        "path": test_data["path"],
        "text": test_data["sentence"],
        "asr_text": results["asr_text"],
    }
)
print(test_results)

In [17]:
# Word error rate of the decoded transcripts against the reference sentences.
wer_metric = load_metric("wer")

wer = wer_metric.compute(
    references=test_results['text'],
    predictions=test_results['asr_text'],
)
print("WER: ", wer)

WER:  0.25727616085314375


In [18]:
# Persist the per-utterance results, omitting the DataFrame index column.
test_results.to_csv(path_or_buf=results_path, index=False)

# References

1. https://github.com/Wikidepia/wav2vec2-indonesian/blob/master/notebooks/kenlm-wav2vec2.ipynb
2. https://github.com/hbasafa/py-ctc-decode
3. https://github.com/huggingface/transformers/pull/11606
4. https://discuss.huggingface.co/t/language-model-for-wav2vec2-0-decoding/4434/6
5. https://github.com/OthmaneJ/distil-wav2vec2/blob/main/distil-wav2vec2-evaluation.ipynb
