In [1]:
import time
from jiwer import wer
from tqdm import tqdm

import torch
import torchaudio

import fairseq_mod
from fairseq_mod.models.wav2vec.wav2vec2_asr import Wav2VecCtc

from utils import Wav2VecCtc, W2lViterbiDecoder, postprocess_features, post_process_sentence

### Step 1: import and start Ray
Note: set `num_cpus` well below the number of logical CPUs available on the machine. To check how many logical CPUs a machine has, run:
```
import psutil
psutil.cpu_count(logical=True)
```

My machine has 48 logical CPUs, and I set `num_cpus` to 20 so that Ray runs successfully.

In [2]:
import ray
# Tear down any Ray runtime that is already up before starting a new one
# (presumably so this cell can be re-run without restarting the kernel).
ray.shutdown()
# Start a local Ray runtime that may use at most 20 CPUs.
ray.init(num_cpus=20)

2021-03-02 18:25:01,265	INFO services.py:1174 -- View the Ray dashboard at http://127.0.0.1:8265


{'node_ip_address': '172.17.0.2',
 'raylet_ip_address': '172.17.0.2',
 'redis_address': '172.17.0.2:6379',
 'object_store_address': '/tmp/ray/session_2021-03-02_18-25-00_785668_42284/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-03-02_18-25-00_785668_42284/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2021-03-02_18-25-00_785668_42284',
 'metrics_export_port': 60873,
 'node_id': 'e97a56849240c143a6f3d67926f2301601d6001706b8cb7d7e9ffdca'}

### Step 2: Create the wav2vec 2.0 model, decoder and data loader

In [3]:
# NOTE(review): both paths below are absolute and machine-specific; adjust
# them to wherever the checkpoint and the dataset live on your machine.
# wav2vec 2.0 checkpoint file, read with torch.load in the next cell.
model_path = "/home/models/wav2vec2/wav2vec_small_960h.pt"
# Root directory passed to torchaudio's LIBRISPEECH dataset.
data_path = "/home/datasets"
# Output-symbol dictionary; the relative path is resolved against the
# current working directory.
target_dict = fairseq_mod.data.Dictionary.load('ltr_dict.txt')

In [4]:
# NOTE(security): torch.load unpickles the file and can therefore execute
# arbitrary code -- only load checkpoints obtained from a trusted source.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# host; the rest of the notebook feeds the model CPU tensors.
w2v = torch.load(model_path, map_location="cpu")
# Rebuild the architecture from the arguments stored in the checkpoint, then
# restore the weights; strict=True raises on any missing or unexpected key.
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
model.load_state_dict(w2v["model"], strict=True)
# Put the module in evaluation mode for inference.
model.eval()
decoder = W2lViterbiDecoder(target_dict)

In [5]:
# LibriSpeech "dev-clean" split, read from an existing copy under data_path
# (download is disabled).
dev_clean_librispeech_data = torchaudio.datasets.LIBRISPEECH(
    root=data_path,
    url='dev-clean',
    download=False,
)
# One utterance per batch, in dataset order.
data_loader = torch.utils.data.DataLoader(
    dataset=dev_clean_librispeech_data,
    batch_size=1,
    shuffle=False,
)

### Step 3: Define a helper function that converts one audio sample into text

In [6]:
def process_data_sample(data_sample, model, decoder, target_dict):
    """Transcribe one audio sample with a wav2vec 2.0 CTC model.

    Args:
        data_sample: One size-1 batch from the data loader.
            ``data_sample[0][0][0]`` is used as the waveform and
            ``data_sample[1]`` is passed along as the sample rate.
        model: Callable model that accepts the keyword arguments ``source``,
            ``padding_mask``, ``features_only`` and ``mask`` and provides
            ``get_normalized_probs``.
        decoder: Object whose ``decode(emissions)`` result is indexed as
            ``[0][0]["tokens"]`` to obtain the best hypothesis.
        target_dict: Dictionary whose ``string`` method turns token ids into
            a string of pieces.

    Returns:
        The value of ``post_process_sentence(hyp_pieces, 'letter')`` for the
        decoded hypothesis.
    """
    waveform, sample_rate = data_sample[0][0][0], data_sample[1]
    feature = postprocess_features(waveform, sample_rate).unsqueeze(0)
    # A single utterance is never padded, so the mask is all False.
    padding_mask = torch.zeros(1, feature.size(1), dtype=torch.bool)

    encoder_input = {
        "source": feature,
        "padding_mask": padding_mask,
        "features_only": True,
        "mask": False,
    }

    # Inference only: skip autograd bookkeeping so no graph is built and the
    # activations are not kept alive for a backward pass that never happens.
    with torch.no_grad():
        encoder_out = model(**encoder_input)
        emissions = model.get_normalized_probs(encoder_out, log_probs=True)
    # Swap the first two axes and hand the decoder a contiguous float32 CPU
    # tensor.
    emissions = emissions.transpose(0, 1).float().cpu().contiguous()

    decoder_out = decoder.decode(emissions)
    best_tokens = decoder_out[0][0]["tokens"].int().cpu()
    hyp_pieces = target_dict.string(best_tokens)
    return post_process_sentence(hyp_pieces, 'letter')

### Step 4: Define the remote function for Ray

In [7]:
@ray.remote
def remote_process_data_sample(batch, model, generator, target_dict):
    """Ray task wrapper around ``process_data_sample``.

    All four arguments are forwarded positionally; ``generator`` is passed
    in the position of ``process_data_sample``'s ``decoder`` parameter.
    """
    return process_data_sample(batch, model, generator, target_dict)

### Step 5: Put the wav2vec 2.0 model and the decoder into a shared memory space

In [8]:
# Store the model and the decoder once with ray.put and keep the returned
# references for use in the remote calls.
model_id, decoder_id = (ray.put(shared_obj) for shared_obj in (model, decoder))

### Step 6: Start the inference task for each data sample and collect predictions
Note that Ray distributes the inference tasks across its workers in the background.

In [9]:
prediction_futures, ground_truths = [], []
start_time = time.time()
# Store the dictionary once instead of passing it by value with every task;
# this mirrors how the model and the decoder are passed by reference.
target_dict_id = ray.put(target_dict)
for batch in tqdm(data_loader):
    # .remote() returns a future immediately; the work runs in Ray workers.
    prediction_futures.append(
        remote_process_data_sample.remote(batch, model_id, decoder_id, target_dict_id)
    )
    # batch[2][0] is the reference transcript of the single utterance.
    ground_truths.append(batch[2][0])
# Block until every task has finished; results come back in submission
# order, so predictions[i] corresponds to ground_truths[i].
predictions = ray.get(prediction_futures)
inference_time = time.time() - start_time

100%|██████████| 2703/2703 [00:15<00:00, 173.30it/s]


### Step 7: Calculate the WER score

In [10]:
# Word error rate of the predictions against the reference transcripts;
# both lists were filled in the same order by the inference loop.
wer_score = wer(ground_truths, predictions)

In [11]:
# Report the WER as a percentage and the wall-clock time truncated to whole
# seconds.
wer_percent = wer_score*100
elapsed_seconds = int(inference_time)
print("WER is {:.2f}%. Inference took {} seconds".format(wer_percent, elapsed_seconds))

WER is 3.17%. Inference took 259 seconds
