In [1]:
from jiwer import wer

import torch
import torchaudio

import fairseq_mod
from fairseq_mod.models.wav2vec.wav2vec2_asr import Wav2VecCtc

from utils import Wav2VecCtc, W2lViterbiDecoder, postprocess_features, post_process_sentence

### Step 1: Specify the paths to the wav2vec 2.0 model and the dataset, and load the letter dictionary.

In [2]:
# Checkpoint file that the model-initialization cell reads with torch.load().
model_path = "/home/models/wav2vec_big_960h.pt"
# Root directory handed to torchaudio.datasets.LIBRISPEECH (called with download=False,
# so the dev-clean data is presumably expected to already exist under this directory).
data_path = "/home/datasets"
# Dictionary loaded from 'ltr_dict.txt' (a path relative to the working directory).
# It is passed to the model builder and the decoder, and its .string() method is
# used to turn decoded token ids into text.
target_dict = fairseq_mod.data.Dictionary.load('ltr_dict.txt')

### Step 2: Initialize wav2vec 2.0 model

In [3]:
# Deserialize the checkpoint onto the CPU. Without map_location, torch.load restores
# every tensor to the device it was saved from, which can duplicate the weights on the
# GPU (or fail when that device does not exist) before the model is moved explicitly.
# NOTE(review): torch.load unpickles the file, so model_path must point to a trusted
# checkpoint.
w2v = torch.load(model_path, map_location="cpu")
# The checkpoint stores the training arguments under "args" and the weights under "model".
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
# strict=True makes a missing or unexpected key in the state dict an error.
model.load_state_dict(w2v["model"], strict=True)
model = model.cuda()
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module structure.
model.eval()

Wav2VecCtc(
  (w2v_encoder): Wav2VecEncoder(
    (w2v_model): Wav2Vec2Model(
      (feature_extractor): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
            (3): GELU()
          )
          (1): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (2): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (3): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (4): Sequen

### Step 3: Create decoder

In [4]:
# Build the decoder from the letter dictionary. The helper defined further down calls
# decoder.decode(emissions) and reads decoder_out[0][0]["tokens"] from its result.
decoder = W2lViterbiDecoder(target_dict)

### Step 4: Create data loader

In [5]:
# LibriSpeech "dev-clean" subset, read from data_path without attempting a download.
dev_clean_librispeech_data = torchaudio.datasets.LIBRISPEECH(
    data_path,
    url='dev-clean',
    download=False,
)

# Serve the utterances one per batch, in dataset order.
data_loader = torch.utils.data.DataLoader(
    dataset=dev_clean_librispeech_data,
    shuffle=False,
    batch_size=1,
)

### Step 5: Define a helper function that converts one audio sample into text

In [6]:
def process_data_sample(data_sample, model, decoder, target_dict):
    """Transcribe a single batch-of-one audio sample into text.

    Args:
        data_sample: One item yielded by the data loader. ``data_sample[0][0][0]``
            and ``data_sample[1]`` are forwarded to ``postprocess_features``
            (presumably the 1-D waveform and its sample rate -- confirm against
            the dataset being used).
        model: Acoustic model. It is called with ``source``, ``padding_mask``,
            ``features_only`` and ``mask`` keyword arguments and must provide
            ``get_normalized_probs``. Inputs are moved to the device of its
            first parameter.
        decoder: Object whose ``decode(emissions)`` result is indexed as
            ``[0][0]["tokens"]``.
        target_dict: Dictionary whose ``string`` method maps token ids to text.

    Returns:
        The hypothesis text after ``post_process_sentence(..., 'letter')``.
    """
    # Follow the model's own device rather than hard-coding CUDA; for a model that
    # was moved with .cuda() this is the same device as before.
    device = next(model.parameters()).device

    feature = postprocess_features(data_sample[0][0][0], data_sample[1]).unsqueeze(0)
    # A single utterance has no padded positions, so the mask is all False.
    padding_mask = torch.zeros(1, feature.size(1), dtype=torch.bool)

    encoder_input = {
        "source": feature.to(device),
        "padding_mask": padding_mask.to(device),
        "features_only": True,
        "mask": False,
    }

    # Inference only: without no_grad every call records an autograd graph, which
    # wastes memory and leaves the emissions attached to gradient history.
    with torch.no_grad():
        encoder_out = model(**encoder_input)
        emissions = model.get_normalized_probs(encoder_out, log_probs=True)
    # Swap the first two axes and hand the decoder a contiguous float32 CPU tensor.
    emissions = emissions.transpose(0, 1).float().cpu().contiguous()

    decoder_out = decoder.decode(emissions)
    hyp_pieces = target_dict.string(decoder_out[0][0]["tokens"].int().cpu())
    prediction = post_process_sentence(hyp_pieces, 'letter')

    return prediction

### Step 6: Calculate the WER on the entire dataset

In [7]:
# Hypotheses and reference transcripts, kept in the same order.
predictions = []
ground_truths = []
# The loop index was never used, so iterate over the loader directly.
for data_sample in data_loader:
    prediction = process_data_sample(data_sample, model, decoder, target_dict)
    predictions.append(prediction)
    # data_sample[2][0] is taken as the reference text of the single utterance in the batch.
    ground_truths.append(data_sample[2][0])
# wer() receives the references first and the hypotheses second; the score is
# printed as a percentage.
wer_score = wer(ground_truths, predictions)
print("WER is {:.2f}".format(wer_score*100))

WER is 2.63
