# load dataset

In [3]:
from datasets import load_dataset
# Load the "dev" split and keep only its first 100 rows.
dataset = load_dataset("hsekhalilian/commonvoice", split="dev").select(
    indices=range(100)
)
dataset

Dataset({
    features: ['client_id', 'path', 'sentence_id', 'sentence', 'sentence_domain', 'up_votes', 'down_votes', 'age', 'gender', 'accents', 'variant', 'locale', 'segment', 'audio', 'normalized_transcription'],
    num_rows: 100
})

# load model

In [4]:
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

In [5]:
# Checkpoint to load, and the device to run it on (CUDA when available, else CPU).
model_name_or_path = "m3hrdadfi/wav2vec2-large-xlsr-persian-v3"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [6]:
# Load the processor and the CTC model from the same checkpoint,
# then move the model's weights onto the selected device.
processor = Wav2Vec2Processor.from_pretrained(model_name_or_path)
model = Wav2Vec2ForCTC.from_pretrained(model_name_or_path)
model = model.to(device)



# predict

In [11]:
import sys_append
from utils.normalizer import persian_normalizer

In [8]:
def predict(batch):
    """Transcribe a batch of audio clips and store the text in ``batch["predicted"]``.

    Args:
        batch: Batched mapping (as supplied by ``Dataset.map(..., batched=True)``)
            whose ``"audio"`` entry is a list of decoded clips, each exposing
            ``"array"`` and ``"sampling_rate"``.

    Returns:
        The same ``batch`` object, with a new ``"predicted"`` list holding one
        normalized transcription per clip.

    Raises:
        ValueError: If any clip's sampling rate differs from the rate the
            feature extractor is configured for.
    """
    clips = batch["audio"]
    target_rate = processor.feature_extractor.sampling_rate

    # Handing the extractor its own rate can never trip its mismatch check,
    # so validate the clips' real rate here instead of silently transcribing
    # audio that was recorded at a different rate.
    clip_rates = {clip["sampling_rate"] for clip in clips}
    if clip_rates - {target_rate}:
        raise ValueError(
            f"Audio sampling rate(s) {sorted(clip_rates)} do not match the "
            f"{target_rate} Hz expected by the feature extractor; resample "
            "the dataset first (e.g. cast the audio column to that rate)."
        )

    features = processor(
        [clip["array"] for clip in clips],
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,  # clips in a batch have different lengths
    )

    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    # Greedy decoding: pick the highest-scoring token id at every frame,
    # then let the processor turn the id sequences into strings.
    pred_ids = torch.argmax(logits, dim=-1)

    batch["predicted"] = [
        persian_normalizer(text) for text in processor.batch_decode(pred_ids)
    ]
    return batch

In [9]:
# Run `predict` over the dataset in batches of 4 clips.
result = dataset.map(
    function=predict,
    batched=True,
    batch_size=4,
)



Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [10]:
# Display the mapped dataset (features and row count).
result

Dataset({
    features: ['client_id', 'path', 'sentence_id', 'sentence', 'sentence_domain', 'up_votes', 'down_votes', 'age', 'gender', 'accents', 'variant', 'locale', 'segment', 'audio', 'normalized_transcription', 'predicted'],
    num_rows: 100
})

# evaluate

In [None]:
from utils.evaluate import evaluate_asr

In [18]:
# Compare the dataset's normalized transcriptions with the model's output.
# NOTE(review): presumably evaluate_asr takes (references, predictions) in
# that order — confirm against utils.evaluate.
target_texts = result["normalized_transcription"]
predicted_texts = result["predicted"]
evaluate_asr(target_texts, predicted_texts)

{'wer': 0.28858024691358025, 'cer': 0.07687231936654569}