## How to use pyctcdecode when working with a NeMo model

In [None]:
# install NeMo
!pip install "nemo-toolkit[asr]==1.3.0"

In [2]:
# get a single audio file
!wget https://dldata-public.s3.us-east-2.amazonaws.com/1919-142785-0028.wav

--2021-10-01 10:20:31--  https://dldata-public.s3.us-east-2.amazonaws.com/1919-142785-0028.wav
Resolving dldata-public.s3.us-east-2.amazonaws.com (dldata-public.s3.us-east-2.amazonaws.com)... 52.219.88.16
Connecting to dldata-public.s3.us-east-2.amazonaws.com (dldata-public.s3.us-east-2.amazonaws.com)|52.219.88.16|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 165164 (161K) [audio/wav]
Saving to: ‘1919-142785-0028.wav’


2021-10-01 10:20:31 (1.94 MB/s) - ‘1919-142785-0028.wav’ saved [165164/165164]



In [None]:
# load pretrained NeMo model
import nemo.collections.asr as nemo_asr
# we could choose for example a BPE encoded conformer-ctc model
# asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name='stt_en_conformer_ctc_small')
# let's take a standard quartznet model though to start
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name='QuartzNet15x5Base-En')

In [5]:
# transcribe audio to logits
logits = asr_model.transcribe(["1919-142785-0028.wav"], logprobs=True)[0]

Transcribing:   0%|          | 0/1 [00:00<?, ?it/s]

In [6]:
# look at the alphabet of our model defining the labels for the logit matrix we just calculated
asr_model.decoder.vocabulary

[' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'"]

In [7]:
from pyctcdecode import build_ctcdecoder

# build the decoder and decode the logits
decoder = build_ctcdecoder(asr_model.decoder.vocabulary)
decoder.decode(logits)

'boil them before they are put into the soup or other dish they may be intended for'

## Librispeech experiments

The real use of a decoder however comes from the ability to 

In [None]:
# NOTE: some of this code is borrowed from the official NeMo tutorial 
#     https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/Offline_ASR.ipynb

In [None]:
# download pretrained knelm model for librispeech
# NOTE: since out nemo vocabulary is all lowercased, we need to convert all librispeech data as well
import gzip
import os, shutil, wget

lm_gzip_path = '3-gram.pruned.1e-7.arpa.gz'
if not os.path.exists(lm_gzip_path):
    print('Downloading pruned 3-gram model.')
    lm_url = 'http://www.openslr.org/resources/11/3-gram.pruned.1e-7.arpa.gz'
    lm_gzip_path = wget.download(lm_url)
    print('Downloaded the 3-gram language model.')
else:
    print('Pruned .arpa.gz already exists.')

uppercase_lm_path = '3-gram.pruned.1e-7.arpa'
if not os.path.exists(uppercase_lm_path):
    with gzip.open(lm_gzip_path, 'rb') as f_zipped:
        with open(uppercase_lm_path, 'wb') as f_unzipped:
            shutil.copyfileobj(f_zipped, f_unzipped)
    print('Unzipped the 3-gram language model.')
else:
    print('Unzipped .arpa already exists.')

lm_path = 'lowercase_3-gram.pruned.1e-7.arpa'
if not os.path.exists(lm_path):
    with open(uppercase_lm_path, 'r') as f_upper:
        with open(lm_path, 'w') as f_lower:
            for line in f_upper:
                f_lower.write(line.lower())
print('Converted language model file to lowercase.')

In [12]:
# download unigram vocab
!wget http://www.openslr.org/resources/11/librispeech-vocab.txt

--2021-04-27 12:26:03--  http://www.openslr.org/resources/11/librispeech-vocab.txt
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1737588 (1.7M) [text/plain]
Saving to: ‘librispeech-vocab.txt’


2021-04-27 12:26:05 (1.07 MB/s) - ‘librispeech-vocab.txt’ saved [1737588/1737588]



In [6]:
# load unigram list
with open("librispeech-vocab.txt") as f:
    unigram_list = [t.lower() for t in f.read().strip().split("\n")]
    
# load kenlm Model
import kenlm
kenlm_model = kenlm.Model('lowercase_3-gram.pruned.1e-7.arpa')

In [7]:
decoder = build_ctcdecoder(
    asr_model.decoder.vocabulary,
    kenlm_model,
    unigram_list,
)
decoder.decode(logits)

'boil them before they are put into the soup or other dish they may be intended for'

## Experiments on librispeech dev-other

In [None]:
# download librispeech dev-other corpus, using one of the great existing scripts, for example:
#     https://github.com/NVIDIA/NeMo/blob/main/scripts/dataset_processing/get_librispeech_data.py

In [8]:
# load manifest that holds meta information on all files in dev_other
import pandas as pd
dev_other_df = pd.read_json("/my_dir/dev_other.json", lines=True)

In [10]:
# decode all logits (this may take a while)
logits_list = [
    a.cpu().detach().numpy() 
    for a in asr_model.transcribe(dev_other_df["audio_filepath"].tolist(), logprobs=True)
]

In [11]:
decoder = build_ctcdecoder(
    asr_model.decoder.vocabulary,
    kenlm_model,
    unigram_list,
)

In [12]:
import multiprocessing
with multiprocessing.get_context("fork").Pool() as pool:
    pred_list = decoder.decode_batch(pool, logits_list)

In [13]:
from nemo.collections.asr.metrics.wer import word_error_rate
word_error_rate(dev_other_df["text"].tolist(), pred_list)

0.07708590580349269

In [None]:
# let's compare this to greedy decoding
def _greedy_decode(logits):
    """Decode argmax of logits and squash in CTC fashion."""
    label_dict = {n: c for n, c in enumerate(labels)}
    prev_c = None
    out = []
    for n in logits.argmax(axis=1):
        c = label_dict.get(n, "")  # if not in labels, then assume it's ctc blank char
        if c != prev_c:
            out.append(c)
        prev_c = c
    return "".join(out)

In [15]:
word_error_rate(dev_other_df["text"].tolist(), [_greedy_decode(l) for l in logits_list])

0.10084215497225611

we did better by using a language model, but we can improve this further by tuning the decoder parameters

## gridsearch optimal parameters

In [16]:
data_grid = []
for a in [0.6, 0.7, 0.8]:
    for b in [2.0, 3.0, 4.0]:
        decoder.reset_params(alpha=a, beta=b)
        with multiprocessing.get_context("fork").Pool(15) as pool:
            # use lower beam-with here for fast testing
            lm_preds =  decoder.decode_batch(pool, logits_list, beam_width=50)
        wer_val = word_error_rate(dev_other_df["text"].tolist(), lm_preds)
        data_grid.append((a, b, wer_val))
pd.DataFrame(data_grid, columns=["alpha", "beta", "wer"]).sort_values(by="wer").head()

Unnamed: 0,alpha,beta,wer
4,0.7,3.0,0.076822
0,0.6,2.0,0.076937
1,0.6,3.0,0.077032
3,0.7,2.0,0.077083
7,0.8,3.0,0.077113


In [None]:
# advanced parameters to tune:
#     beam_width: how many beams to keep after each step
#     beam_prune_logp: beams that are much worse than best beam will be pruned
#     token_min_logp: tokens below this logp are skipped unless they are argmax of frame
#     prune_logp: beams that are much worse than best beam will be pruned
#     unk_score_offset: score decrease for token if oov
#     lm_score_boundary: whether to have kenlm respect boundaries when scoring