#### NOTE: Also need to identify how fairseq evaluates model by wer/cer on librispeech -> enables beam search decode

In [1]:
from load_fsq_model import load_model

In [2]:
# model = load_model('wav2vec_small_960h.pt')
model = load_model('wav2vec2_vox_960h_new.pt')
# model = load_model('wav2vec_big_960h.pt')

In [3]:
model.eval();
model.cuda();

In [4]:
import torchaudio
test_data = torchaudio.datasets.LIBRISPEECH("../", "test-clean", download=True)

In [5]:
import numpy as np
from itertools import groupby

json_dict = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}

class Decoder:
    def __init__(self, json_dict):
        self.dict = json_dict
        self.look_up = np.asarray(list(self.dict.keys()))

    def decode(self, ids):
        converted_tokens = self.look_up[ids]
        fused_tokens = [tok[0] for tok in groupby(converted_tokens)]
        output = ' '.join(''.join(''.join(fused_tokens).split("<s>")).split("|"))
        return output
    
decoder = Decoder(json_dict=json_dict)

In [6]:
from datasets import load_metric

wer_metric = load_metric("wer")
cer_metric = load_metric("cer")

In [7]:
# Evaluation without batch
from tqdm.auto import tqdm
import numpy as np
from jiwer import wer
from datasets import load_metric

wer_metric = load_metric("wer")

wer_ = []

for i, data in enumerate(tqdm(test_data)):
    logits = model(source=data[0].cuda(), padding_mask=None)["encoder_out"].transpose(0,1)
    predicted_ids = np.argmax(logits.cpu().detach().numpy(), axis=-1)
    predictions = [decoder.decode(ids) for ids in predicted_ids]
    labels = [data[2]]
    
    # wer_.append(wer(labels, predictions))
    wer_metric.add_batch(predictions=predictions, references=labels)
    
# print(f"WER: {np.mean(wer_)}")
print(f"WER: {wer_metric.compute()}")

  0%|          | 0/2620 [00:00<?, ?it/s]

WER: 0.03210590383444918


---

## Implement batch evaluation & metric wer, cer

In [8]:
# from torch.nn.utils.rnn import pad_sequence
# import torch

# def data_collator(features):
#     # split inputs and labels since they have to be of different lengths and need
#     # different padding methods
#     input_features = [feature[0].squeeze(0) for feature in features]
#     labels = [feature[2] for feature in features]

#     src = pad_sequence(input_features, batch_first=True, padding_value=0.0)
#     mask = torch.zeros(src.shape).masked_fill_(src==0, 1)

#     return src, mask, labels

In [9]:
from typing import Any, Dict, List
from torch.nn.utils.rnn import pad_sequence
import torch

class DataCollatorWithPadding:

    def __call__(self, features: List[Dict[str, Any]]):
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [feature[0][0] for feature in features]
        labels = [feature[2] for feature in features]
        
        src = pad_sequence(input_features, batch_first=True, padding_value=0)
        mask = torch.zeros(src.shape).masked_fill_(src==0, 1)
        
        return {'src': src, 'mask': mask, 'labels': labels}
    
data_collator = DataCollatorWithPadding()

In [10]:
from torch.utils.data import DataLoader
test_dataloader = DataLoader(test_data, batch_size=2, collate_fn=data_collator, num_workers=4)

In [12]:
# Full batch evaluation
from tqdm.auto import tqdm
import numpy as np

wer_ = []
cer_ = []

for i, batch in enumerate(tqdm(test_dataloader)):
    logits = model(source=batch['src'].cuda(), padding_mask=batch['mask'])["encoder_out"].transpose(0,1)
    predicted_ids = np.argmax(logits.cpu().detach().numpy(), axis=-1)
    predictions = [decoder.decode(ids) for ids in predicted_ids]
    
    # wer_.append(wer_metric.compute(predictions=predictions, references=labels))
    # cer_.append(cer_metric.compute(predictions=predictions, references=labels))
    wer_metric.add_batch(predictions=predictions, references=batch['labels'])
    cer_metric.add_batch(predictions=predictions, references=batch['labels'])
    
wer = wer_metric.compute()
cer = cer_metric.compute()

# print(f"WER: {np.mean(wer_)}")
# print(f"CER: {np.mean(cer_)}")
print(f"WER: {wer}")
print(f"CER: {cer}")

  0%|          | 0/1310 [00:00<?, ?it/s]

WER: 0.03596137763247928
CER: 0.012510212055553582
