In [1]:
# !wget https://huggingface.co/datasets/malaysia-ai/fleurs-my-ms/resolve/main/test-fleurs.json
# !wget https://huggingface.co/datasets/malaysia-ai/fleurs-my-ms/resolve/main/fleurs-test.zip
# !unzip fleurs-test.zip

In [2]:
from glob import glob
import json
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
from datasets import Audio
from tqdm import tqdm
import jiwer

# Sample rate (16 kHz) used for decoding audio; `audio` decodes every clip at this rate.
sr = 16_000
audio = Audio(sampling_rate = sr)

# Characters stripped from hypothesis and reference before scoring.
# This is string.punctuation minus '-', so hyphenated words survive
# (presumably intentional for Malay reduplication, e.g. "kanak-kanak" — confirm).
PUNCTUATION = "!\"#$%&'()*+,./:;<=>?@[\\]^_`{|}~"

In [3]:
# Load the Whisper-small processor and model in bfloat16 with Flash Attention 2.
# `attn_implementation` replaces the deprecated `use_flash_attention_2=True` flag.
# Flash Attention 2 needs the model on GPU; the next cell moves it there.
MODEL_ID = 'openai/whisper-small'

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    attn_implementation = 'flash_attention_2',
    torch_dtype = torch.bfloat16
)

The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [4]:
# Move the weights to the GPU (required for Flash Attention 2); the trailing ';' hides the module repr.
model.cuda();

In [5]:
# Evaluation manifest: a list of samples, each read below via its
# 'audio_filename' and 'new_text' keys.
# JSON is UTF-8 by spec, so decode explicitly rather than with the platform locale.
with open('test-fleurs.json', encoding = 'utf-8') as fopen:
    data = json.load(fopen)

In [6]:
# Number of evaluation samples in the manifest.
n_samples = len(data)
n_samples

749

In [7]:
# Transcribe every test clip with Whisper and score it against its reference.
# Scores are per utterance, collected in `wer` / `cer`; the summary cell
# averages them (a macro average, not corpus-level WER/CER over pooled words).
wer, cer = [], []

# Deletion table built once: str.translate removes every PUNCTUATION character
# in a single pass instead of one str.replace call per character.
strip_punctuation = str.maketrans('', '', PUNCTUATION)

for sample in tqdm(data):
    # Round-trip the path through datasets.Audio to get the waveform at `sr` Hz.
    waveform = audio.decode_example(audio.encode_example(sample['audio_filename']))['array']
    # Pass the same `sr` the decoder used instead of repeating the literal.
    inputs = processor([waveform], return_tensors = 'pt', sampling_rate = sr)
    # Follow the model's device/dtype (bfloat16 on CUDA here) rather than hardcoding them.
    input_features = inputs['input_features'].to(device = model.device, dtype = model.dtype)
    generated_ids = model.generate(input_features, language='ms', return_timestamps=True)

    out = processor.tokenizer.decode(generated_ids[0], skip_special_tokens = True).strip()
    # The reference is encoded then decoded with the same tokenizer, presumably so
    # both strings go through the same decode path — TODO confirm intent.
    actual = processor.tokenizer.decode(processor.tokenizer.encode(sample['new_text']), skip_special_tokens = True).strip()

    actual = actual.translate(strip_punctuation).lower()
    out = out.translate(strip_punctuation).lower()

    # jiwer raises if `actual` is empty; every reference here is expected to be non-empty.
    wer.append(jiwer.wer(actual, out))
    cer.append(jiwer.cer(actual, out))

100%|██████████| 749/749 [08:05<00:00,  1.54it/s]


In [9]:
import numpy as np

# Macro averages: every utterance counts equally regardless of its length, so
# these differ from corpus-level WER/CER computed over all pooled words/characters.
mean_wer = np.mean(wer)
mean_cer = np.mean(cer)
mean_wer, mean_cer

(0.2327510905228186, 0.07028889922090295)