In [None]:
import os
from pathlib import Path

import jiwer
import matplotlib.pyplot as plt
import numpy as np

from utils import load_custom_dataset

In [None]:
# Dataset configuration. Inputs are read from
# <parent of cwd>/data/inputs/<dataset_name>/{lectures,transcripts}[-tiny].
base_directory = Path.cwd().parent

dataset_name = "yale_econ251"
data_dir = base_directory / 'data'

dataset_size = "normal"  # or 'tiny'

# dataset_size is also part of the prediction file names used when scoring,
# so reject an unrecognised value here instead of silently falling back to
# the 'normal' folders and failing later with a missing-file error.
if dataset_size not in ("normal", "tiny"):
    raise ValueError(f"dataset_size must be 'normal' or 'tiny', got {dataset_size!r}")

size_suffix = "-tiny" if dataset_size == "tiny" else ""
inputs_dir = data_dir / 'inputs' / dataset_name
audio_dir = inputs_dir / f'lectures{size_suffix}'
transcripts_dir = inputs_dir / f'transcripts{size_suffix}'


In [None]:

# Model whose predictions are evaluated (baseline checkpoint). The name is
# also used as the sub-folder under data/predictions/<dataset_name>/.
model_dir = 'wav2vec2-base-100h'
# model_name = "facebook/wav2vec2-base"

In [None]:
# Prediction files for this model: data/predictions/<dataset_name>/<model_dir>/
predictions_dir = data_dir.joinpath('predictions', dataset_name, model_dir)

In [None]:
# Sanity check: list the prediction files available for this model.
# os.listdir returns entries in arbitrary order, so sort for a stable display.
sorted(os.listdir(predictions_dir))

In [None]:
# Text normalisation applied identically to the reference and the hypothesis
# before scoring: upper-case, turn whitespace characters into spaces, collapse
# runs of spaces, strip punctuation, then split on " " into jiwer's
# list-of-word-lists form.
normalisation_steps = [
    jiwer.ToUpperCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.RemovePunctuation(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
]
transformation = jiwer.Compose(normalisation_steps)

In [None]:
# Score each selected lecture's prediction against its ground-truth transcript.
# `errors` gets one [wer, mer, wil] row per scored lecture and `lecture_ids`
# the matching transcript stem, both in sorted filename order.
errors = []
lecture_ids = []

selected_files = ['06']  # e.g. ['01', '02', '03', '04', '05']

# Path.glob yields files in arbitrary order; sort so rows follow lecture order.
for transcript in sorted(transcripts_dir.glob('*.txt')):
    file_no = transcript.stem
    if file_no not in selected_files:
        continue
    print(transcript)

    # Ground-truth text (assumes UTF-8 files -- TODO confirm).
    with open(transcript, 'r', encoding='utf-8') as f:
        ground_truth = f.read()

    # Prediction for the same lecture, e.g. pred_06_normal_with_lm.txt.
    # The name is built directly: with_suffix('.txt') would cut off
    # everything after a '.' in the stem.
    pred_file_path = predictions_dir / f'pred_{file_no}_{dataset_size}_with_lm.txt'
    with open(pred_file_path, 'r', encoding='utf-8') as f:
        hypothesis = f.read()

    # NOTE(review): jiwer >= 3.0 deprecates compute_measures (use
    # process_words / reference_transform); confirm the installed version.
    measures = jiwer.compute_measures(ground_truth,
                                      hypothesis,
                                      truth_transform=transformation,
                                      hypothesis_transform=transformation)
    errors.append([measures['wer'], measures['mer'], measures['wil']])
    lecture_ids.append(file_no)

# Surface selected lectures that had no transcript instead of skipping silently.
missing = sorted(set(selected_files) - set(lecture_ids))
if missing:
    print(f"Warning: no transcript found in {transcripts_dir} for: {missing}")


In [None]:
# Raw per-lecture scores: one [wer, mer, wil] row per scored transcript.
errors

In [None]:
# Plot each metric per scored lecture (in the order rows were appended to
# `errors`) and report the averages.
scores = np.array(errors, dtype=float).reshape(-1, 3)  # columns: WER, MER, WIL
metric_names = ["WER", "MER", "WIL"]
positions = np.arange(len(scores))

fig, ax = plt.subplots()
for col, name in enumerate(metric_names):
    # marker='o': a line series with a single point draws nothing otherwise.
    ax.plot(positions, scores[:, col], marker='o', label=name)
ax.set_xticks(positions)
ax.set_title(f"{dataset_name}")
ax.set_xlabel("lecture no.")
ax.set_ylabel("Score")
ax.legend()
ax.grid()
plt.show()

if len(scores) == 0:
    print("No scores to average: `errors` is empty.")
else:
    for col, name in enumerate(metric_names):
        print(f"Average {name}: {np.average(scores[:, col])}")
