#### TODO: Find out how fairseq itself evaluates WER/CER on LibriSpeech, so that its beam-search decoding can be used here instead of greedy decoding.

In [1]:
from load_fsq_model import load_model

In [2]:
# Load the wav2vec 2.0 checkpoint used for evaluation (path relative to the working directory).
# Alternative checkpoints from earlier runs; swap one in as needed:
#   './parameters/wav2vec2_vox_960h_new.pt'
#   'wav2vec_big_960h.pt'
model = load_model(
    '../parameters/w2v2/wav2vec_small_960h.pt'
)

# Inference mode on the GPU; the trailing ';' stops Jupyter from echoing the return value.
model.eval()
model.cuda();

In [3]:
# LibriSpeech test-clean split, stored under ../data/ (downloaded there if not already present).
import torchaudio

test_data = torchaudio.datasets.LIBRISPEECH(root="../data/", url="test-clean", download=True)

In [4]:
import numpy as np
from itertools import groupby

# Token -> output index for the model's character vocabulary (insertion order == index order).
# Presumably mirrors the checkpoint's letter dictionary; it must match the model's output indices.
#   <s>   : beginning-of-sentence token, also used as the CTC blank
#   <pad> : padding token
#   </s>  : end-of-sentence token
#   <unk> : unknown token
#   |     : delimiter between words
#   the remaining entries are single output characters (letters and the apostrophe)
json_dict = {
    token: index
    for index, token in enumerate(
        ["<s>", "<pad>", "</s>", "<unk>", "|"] + list("ETAONIHSRDLUMWCFGYPBVK'XJQZ")
    )
}

class Decoder:
    """Greedy CTC decoder: turns per-frame argmax ids into a transcription string.

    Decoding collapses consecutive repeated ids, drops the CTC blank and other non-text
    symbols, and renders the word delimiter as a single space (no leading, trailing, or
    repeated spaces).

    Args:
        json_dict: Mapping from token string to output index. The indices must be exactly
            0..len(json_dict)-1. The lookup table is built from the indices, so the key
            order of the mapping does not matter.
        blank_token: CTC blank symbol, removed after repeats are collapsed.
        word_delimiter: Symbol marking word boundaries; rendered as a space.
        ignore_tokens: Further non-text symbols removed from the output.

    Raises:
        ValueError: If the indices in ``json_dict`` are not exactly 0..len(json_dict)-1.
    """

    def __init__(self, json_dict, blank_token="<s>", word_delimiter="|",
                 ignore_tokens=("<pad>", "</s>")):
        self.dict = json_dict
        self.blank_token = blank_token
        self.word_delimiter = word_delimiter
        self._dropped = {blank_token, *ignore_tokens}

        # Place each token at its own id instead of relying on dict insertion order.
        by_index = sorted(json_dict.items(), key=lambda item: item[1])
        if [idx for _, idx in by_index] != list(range(len(by_index))):
            raise ValueError("json_dict indices must be exactly 0..len(json_dict)-1")
        self.look_up = np.asarray([tok for tok, _ in by_index])

    def decode(self, ids):
        """Decode one utterance's frame ids.

        Args:
            ids: 1-D sequence or array of integer ids, one per output frame.

        Returns:
            The transcription, with words separated by single spaces.
        """
        converted_tokens = self.look_up[ids]
        # Collapse repeats BEFORE removing blanks: a blank between two equal tokens
        # keeps them as two distinct characters.
        fused_tokens = (tok for tok, _ in groupby(converted_tokens))
        text = ''.join(tok for tok in fused_tokens if tok not in self._dropped)
        # str.split() with no argument also discards leading/trailing/duplicate spaces.
        return ' '.join(text.replace(self.word_delimiter, ' ').split())
    
decoder = Decoder(json_dict)  # greedy CTC decoder over the vocabulary defined above

In [5]:
from datasets import load_metric

# Accumulating metrics: pairs are added with add_batch() and scored once with compute().
wer_metric, cer_metric = load_metric("wer"), load_metric("cer")

In [6]:
# Imports for the unbatched evaluation below (one utterance per forward pass)
from tqdm.auto import tqdm
import numpy as np
from jiwer import wer
from datasets import load_metric

In [14]:
import torch  # in a fresh kernel `torch` is otherwise first imported in a later cell

# One utterance per forward pass, so no padding mask is needed.
# Inference only: no_grad avoids building an autograd graph for every utterance.
with torch.no_grad():
    for waveform, _, transcript, *_ in tqdm(test_data):
        # The waveform is passed as returned by the dataset; its leading dim is assumed to be
        # a single channel that doubles as a batch of 1 — TODO confirm for multi-channel audio.
        encoder_out = model(source=waveform.cuda(), padding_mask=None)["encoder_out"]
        # Presumably (frames, batch, vocab) from fairseq; transpose to batch-first.
        logits = encoder_out.transpose(0, 1)
        predicted_ids = np.argmax(logits.detach().cpu().numpy(), axis=-1)
        predictions = [decoder.decode(ids) for ids in predicted_ids]

        # Stack predictions and ground truth into the metric's cache
        wer_metric.add_batch(predictions=predictions, references=[transcript])

# Calculate at once (same as WER between concatenated predictions & references)
print(f"WER: {wer_metric.compute()}")

  0%|          | 0/2620 [00:00<?, ?it/s]

WER: 0.03395085209981741


---

## Batched evaluation with WER and CER metrics

In [8]:
# Superseded by DataCollatorWithPadding in the next cell; kept for reference only.
# from torch.nn.utils.rnn import pad_sequence
# import torch

# def data_collator(features):
#     # split inputs and labels since they have to be of different lengths and need
#     # different padding methods
#     input_features = [feature[0].squeeze(0) for feature in features]
#     labels = [feature[2] for feature in features]

#     src = pad_sequence(input_features, batch_first=True, padding_value=0.0)
#     mask = torch.zeros(src.shape).masked_fill_(src==0, 1)

#     return src, mask, labels

In [22]:
from typing import Any, Dict, List
from torch.nn.utils.rnn import pad_sequence
import torch

class DataCollatorWithPadding:
    """Collate dataset items into a padded waveform batch plus a boolean padding mask.

    Each feature is indexed as ``feature[0][0]`` (waveform samples) and ``feature[2]``
    (transcript), matching the tuples of ``torchaudio.datasets.LIBRISPEECH``. The waveform
    is assumed to be ``(1, samples)`` — TODO confirm for multi-channel audio.

    The mask is derived from the true lengths rather than from the padded values. Real
    samples that happen to equal the padding value are therefore never marked as padding.
    Padding defaults to 0.0 instead of ``-inf``: the model presumably runs its feature
    extractor on the padded samples before the mask is applied, so non-finite padding could
    turn into NaNs there (verify against the fairseq model in use).

    Args:
        padding_value: Value written into padded sample positions.
    """

    def __init__(self, padding_value: float = 0.0):
        self.padding_value = padding_value

    def __call__(self, features: List[tuple]) -> Dict[str, Any]:
        """Return ``{'src': (B, T_max) float, 'mask': (B, T_max) bool, 'labels': list}``.

        ``mask`` is True at positions past the end of each utterance.
        """
        # Split inputs and labels: waveforms need padding, transcripts stay as strings.
        input_features = [feature[0][0] for feature in features]
        labels = [feature[2] for feature in features]

        src = pad_sequence(input_features, batch_first=True, padding_value=self.padding_value)
        lengths = torch.tensor([len(wave) for wave in input_features])
        mask = torch.arange(src.shape[1]).unsqueeze(0) >= lengths.unsqueeze(1)

        return {'src': src, 'mask': mask, 'labels': labels}


data_collator = DataCollatorWithPadding()

In [23]:
from torch.utils.data import DataLoader

# batch_size=1 reproduces the unbatched loop's WER above; larger batches introduce padding.
test_dataloader = DataLoader(
    test_data,
    batch_size=1,
    num_workers=4,
    collate_fn=data_collator,
)

# Re-create the metrics so this loop accumulates into fresh instances.
wer_metric = load_metric("wer")
cer_metric = load_metric("cer")

In [24]:
# Inference only: no_grad skips autograd bookkeeping for every batch.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        source = batch['src'].cuda()
        # Boolean mask (True = padded sample), on the same device as `source`.
        padding_mask = batch['mask'].bool().cuda()
        net_output = model(source=source, padding_mask=padding_mask)
        # Presumably (frames, batch, vocab) from fairseq; transpose to batch-first.
        logits = net_output["encoder_out"].transpose(0, 1)
        predicted_ids = np.argmax(logits.detach().cpu().numpy(), axis=-1)

        # With batch_size > 1, frames past an utterance's end still receive logits. Force them
        # to the CTC blank so the decoder drops them. Assumes the model's output carries a
        # (batch, frames) boolean "padding_mask" entry (None when nothing is padded).
        # TODO confirm for the fairseq version in use.
        frame_padding = net_output.get("padding_mask")
        if frame_padding is not None and frame_padding.any():
            predicted_ids[frame_padding.cpu().numpy()] = json_dict["<s>"]

        predictions = [decoder.decode(ids) for ids in predicted_ids]
        wer_metric.add_batch(predictions=predictions, references=batch['labels'])
        cer_metric.add_batch(predictions=predictions, references=batch['labels'])

print(f"WER: {wer_metric.compute()}")
print(f"CER: {cer_metric.compute()}")

  0%|          | 0/2620 [00:00<?, ?it/s]

WER: 0.03395085209981741
<jiwer.transforms.Compose object at 0x7f53649715d0>
CER: 0.009526515824246084
