### Installations

In [1]:
%%capture

!pip install datasets==1.13.3
!pip install transformers==4.11.3
!pip install pyspellchecker
!pip install symspellpy
!pip install jiwer
!pip install s3fs
!pip install boto3
!pip install hazm

print('everything installed')

### Imports

In [5]:
# importing relevant libraries
import re
import sys
import torch
import warnings
import torchaudio
from spellchecker import SpellChecker
from hazm import word_tokenize, Normalizer
from symspellpy import SymSpell, Verbosity
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

### Global Variables

In [6]:
# ignore warnings
warnings.simplefilter("ignore")

# conditional inference
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# initializing spell-checker parameters
spell = SpellChecker(distance=1, language='de')
sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
dictionary_path = "wiki_fa_80k.txt"
_normalizer = Normalizer()

### Helper Functions

In [7]:
def get_model_id(lang_id):
    if lang_id == 'de':
        return 'facebook/wav2vec2-large-xlsr-53-german'
    
    elif lang_id == 'fa':
        return 'm3hrdadfi/wav2vec2-large-xlsr-persian'
    
    else:
        return False

def load_model(lang_id):
    model_id = get_model_id(lang_id)
    if model_id:
        model = Wav2Vec2ForCTC.from_pretrained(model_id).to(DEVICE)
        processor = Wav2Vec2Processor.from_pretrained(model_id)
        return model, processor
    else:
        return None, None

### Spell Checker

In [8]:
def spell_check(sentence, lang_id):
    if lang_id == 'de':
        sentence = ' '.join([spell.correction(word) for word in sentence.split()])
        return sentence
    
    if lang_id == 'fa':
        sentence = normalizer.normalize(sentence)
        sentence = ' '.join([sym_spell.lookup(word, Verbosity.ALL, max_edit_distance=1)[0].term 
                             if sym_spell.lookup(word, Verbosity.ALL, max_edit_distance=2) else word
                             for word in word_tokenize(sentence)])
        return sentence
    
    return sentence

### Inference

In [20]:
def predict(audio, lang_id, model, processor):
    speech, orig_freq = torchaudio.load(audio)
    resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=16_000)
    speech = resampler.forward(speech.squeeze(0)).numpy()
    sampling_rate = resampler.new_freq
    features = processor(speech, sampling_rate=sampling_rate, return_tensors="pt")
    input_values = features.input_values.to(DEVICE)
    attention_mask = features.attention_mask.to(DEVICE)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    pred_ids = torch.argmax(logits, dim=-1)
    predicted = processor.batch_decode(pred_ids)
    predicted = spell_check(predicted[0], LANG_ID)
    return predicted

### Test

In [24]:
lang_id = 'fa'
SAMPLES = 10
test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]", data_dir="./cv-corpus-6.1-2020-12-11",
                           keep_in_memory=True)
model, processor = (load_model(lang_id))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [25]:
audio = test_dataset[3]['path']
prediction = predict(audio, lang_id, model, processor)
print(prediction)

فلیط بید اه نعاف ما لخن اون که بم ریک لنگ دید لیسته


In [26]:
lang_id = 'de'
SAMPLES = 10
test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]", data_dir="./cv-corpus-6.1-2020-12-11",
                           keep_in_memory=True)
model, processor = (load_model(lang_id))



In [27]:
audio = test_dataset[3]['path']
prediction = predict(audio, lang_id, model, processor)
print(prediction)

fhelippehat eine auch für monarchen ungewöhnlich lange titelliste
