In [2]:
import torch
import torchaudio

# Seed torch's global RNG so repeated runs of the notebook start from the same state.
torch.random.manual_seed(0)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Record the library versions and the selected device next to the results.
for environment_detail in (torch.__version__, torchaudio.__version__, device):
    print(environment_detail)

# Common Voice items listed in test.tsv under the local _cv_corpus/en directory.
cv_dataset = torchaudio.datasets.COMMONVOICE(root="_cv_corpus/en", tsv="test.tsv")

# Display the first item: a (waveform, sample rate, metadata dict) tuple.
cv_dataset[0]

1.13.0.dev20220725
0.13.0.dev20220725
cpu


(tensor([[ 0.0000e+00, -3.8697e-12, -9.5148e-12,  ...,  1.1894e-06,
           1.4817e-06,  1.8046e-06]]),
 32000,
 {'client_id': '000abb3006b78ea4c1144e55d9d158f05a9db0110160510fef2b006f2c2c8e35f7bb538b04542511834b61503cdda5b0331566a5cf59dc0d375a44afc4d10777',
  'path': 'common_voice_en_27710027.mp3',
  'sentence': 'Joe Keaton disapproved of films, and Buster also had reservations about the medium.',
  'up_votes': '3',
  'down_votes': '1',
  'age': '',
  'gender': '',
  'accents': '',
  'locale': 'en',
  'segment': ''})

In [3]:
from collections import Counter
from itertools import islice
from pprint import pprint

# Upper bound on how many dataset items are inspected by this sanity check.
# Each item yields a full waveform, and an unbounded pass over the dataset was
# slow enough to be interrupted by hand. Set to None to scan every item.
SAMPLE_RATE_SCAN_LIMIT = 200

# Map each sample rate (index 1 of a dataset item) to the number of inspected
# clips that use it, so the output shows both which rates occur and how often.
sample_rates = Counter(
    data_item[1] for data_item in islice(cv_dataset, SAMPLE_RATE_SCAN_LIMIT)
)

pprint(dict(sample_rates))

KeyboardInterrupt: 

In [4]:
# Pretrained wav2vec 2.0 ASR bundle: provides the acoustic model and the
# sample rate that model expects as input.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

# Instantiate the model and place it on the device chosen earlier.
w2v_model = bundle.get_model()
w2v_model = w2v_model.to(device)

print(type(w2v_model))
print(f"sample_rate: {bundle.sample_rate}")

<class 'torchaudio.models.wav2vec2.model.Wav2Vec2Model'>
sample_rate: 16000


In [5]:
from torchaudio.models.decoder import ctc_decoder
from torchaudio.models.decoder import download_pretrained_files

# Pretrained decoding assets (lexicon, token list, language model) for the
# "librispeech-4-gram" configuration.
files = download_pretrained_files("librispeech-4-gram")

# The two decoders differ only in how strongly the language model is weighted.
LM_WEIGHT1 = 3.23
LM_WEIGHT2 = 1.0
WORD_SCORE = -0.26

# Keyword arguments common to both decoders.
_shared_decoder_kwargs = dict(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=3,
    beam_size=1500,
    word_score=WORD_SCORE,
)

beam_search_decoder1 = ctc_decoder(lm_weight=LM_WEIGHT1, **_shared_decoder_kwargs)
beam_search_decoder2 = ctc_decoder(lm_weight=LM_WEIGHT2, **_shared_decoder_kwargs)

In [16]:
from numpy import average
import jiwer
from tqdm.notebook import tqdm

# Text normalisation applied to both references and hypotheses before scoring:
# lowercase, expand common English contractions, strip punctuation, turn
# whitespace characters into single spaces, then split each sentence into words.
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.ExpandCommonEnglishContractions(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])


def _normalized_wer(truth, hypothesis):
    """Return the WER of ``hypothesis`` against ``truth``.

    Both arguments may be a single sentence or a list of sentences; the
    ``transformation`` pipeline is applied to each side before scoring.
    """
    return jiwer.wer(
        truth,
        hypothesis,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )


def _best_transcript(decoder, emission):
    """Return the top hypothesis of ``decoder`` for the first batch item as a space-joined string."""
    return ' '.join(decoder(emission)[0][0].words)


# Running sums of per-utterance WER for each decoder, plus the raw transcripts
# needed for the corpus-level WER computed after the loop.
accum_wer1 = 0
accum_wer2 = 0
hypothesies1 = []
hypothesies2 = []
truths = []

with torch.inference_mode():
    for wav, sr, metadata in tqdm(cv_dataset):
        ground_truth = metadata["sentence"]
        truths.append(ground_truth)

        # Bring the clip to the sample rate the acoustic model expects.
        waveform = torchaudio.functional.resample(wav, sr, bundle.sample_rate)

        # The model was moved to `device`, so its input has to live there too;
        # without this the forward pass fails whenever CUDA is selected.
        emission, _ = w2v_model(waveform.to(device))

        # Keep only the first row of the batch (as before) and move it back to
        # the CPU, which is where the lexicon beam-search decoder reads from.
        emission = emission[:1].cpu()

        hypothesis1 = _best_transcript(beam_search_decoder1, emission)
        hypothesis2 = _best_transcript(beam_search_decoder2, emission)
        hypothesies1.append(hypothesis1)
        hypothesies2.append(hypothesis2)

        accum_wer1 += _normalized_wer(ground_truth, hypothesis1)
        accum_wer2 += _normalized_wer(ground_truth, hypothesis2)

# Mean of the per-utterance WERs: every utterance counts equally regardless of length.
average_wer1 = accum_wer1 / len(cv_dataset)
average_wer2 = accum_wer2 / len(cv_dataset)
print(f"average_wer1: {average_wer1}")
print(f"average_wer2: {average_wer2}")

# WER computed over all utterances at once.
total_wer1 = _normalized_wer(truths, hypothesies1)
total_wer2 = _normalized_wer(truths, hypothesies2)
print(f"total_wer1: {total_wer1}")
print(f"total_wer2: {total_wer2}")


  0%|          | 0/16345 [00:00<?, ?it/s]

average_wer1: 0.31724393969763076
average_wer2: 0.33207283387666553
total_wer1: 0.31346207127623243
total_wer2: 0.3258977184131341


In [10]:
import jiwer

# Sanity check for the normalisation pipeline: the reference differs from the
# hypothesis only in case, punctuation and spacing, so the score shown is 0.0.
sanity_truths = ["Hello,,,   this is a test", 'past']
sanity_hypotheses = ["hello this is a test", 'past']
jiwer.wer(
    sanity_truths,
    sanity_hypotheses,
    truth_transform=transformation,
    hypothesis_transform=transformation,
)

0.0