<span style="font-size:36px"><b>CTC Decoding</b></span>

Copyright &copy; 2020 Gunawan Lumban Gaol

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language overning permissions and limitations under the License.

# Import Packages

In [1]:
import os
import re
import pickle
from multiprocessing import Pool
# from multiprocess import Pool  # uses dill
from collections import defaultdict, Counter
from string import ascii_lowercase

import numpy as np
import tensorflow as tf

from gurih.utils import batch

In [2]:
resource_dir = '4.0-glg-ctc-decoding-resources/'

# Prefix Beam Search

This borrows example from https://github.com/corticph/prefix-beam-search.

In [3]:
def greedy_decoder(ctc):
    """
    Performs greedy decoding (max decoding) on the output of a CTC network.

    Args:
    ctc (np.ndarray): The CTC output. Should be a 2D array (timesteps x alphabet_size)

    Returns:
    string: The decoded CTC output.
    """

    alphabet = list(ascii_lowercase) + [' ', '>']
    alphabet_size = len(alphabet)

    #  collapse repeating characters
    arg_max = np.argmax(ctc, axis=1)
    repeat_filter = arg_max[1:] != arg_max[:-1]
    repeat_filter = np.concatenate([[True], repeat_filter])
    collapsed = arg_max[repeat_filter]

    # discard blank tokens (the blank is always last in the alphabet)
    blank_filter = np.where(collapsed < (alphabet_size - 1))[0]
    final_sequence = collapsed[blank_filter]
    full_decode = ''.join([alphabet[letter_idx] for letter_idx in final_sequence])

    return full_decode[:full_decode.find('>')]

In [4]:
def prefix_beam_search(ctc, lm=None, k=25, alpha=0.30, beta=5, prune=0.001):
    """
    Performs prefix beam search on the output of a CTC network.

    Parameters
    ----------
    ctc : np.ndarray 
        The CTC output. Should be a 2D array (timesteps x alphabet_size)
    lm : function, [default=None]
        Should take as input a string and output a probability
    k : int, [default=25]
        The beam width. Will keep the 'k' most likely candidates at each timestep
    alpha : float, [default=0.30]
        The language model weight. Should usually be between 0 and 1.
    beta : float, [default=0.5]
        The language model compensation term. The higher the 'alpha', the higher the 'beta'.
    prune : float, [default=0.001]
        Only extend prefixes with chars with an emission probability higher than 'prune'.

    Returns
    -------
    string: The decoded CTC output.
    """

    lm = (lambda l: 1) if lm is None else lm # if no LM is provided, just set to function returning 1
    W = lambda l: re.findall(r'\w+[\s|>]', l)
    alphabet = list(ascii_lowercase) + [' ', '>', '%']
    F = ctc.shape[1]
    ctc = np.vstack((np.zeros(F), ctc)) # just add an imaginative zero'th step (will make indexing more intuitive)
    T = ctc.shape[0]

    # STEP 1: Initiliazation
    O = ''
    Pb, Pnb = defaultdict(Counter), defaultdict(Counter)
    Pb[0][O] = 1
    Pnb[0][O] = 0
    A_prev = [O]
    # END: STEP 1

    # STEP 2: Iterations and pruning
    for t in range(1, T):
        pruned_alphabet = [alphabet[i] for i in np.where(ctc[t] > prune)[0]]
        for l in A_prev:

            if len(l) > 0 and l[-1] == '>':
                Pb[t][l] = Pb[t - 1][l]
                Pnb[t][l] = Pnb[t - 1][l]
                continue  

            for c in pruned_alphabet:
                c_ix = alphabet.index(c)
                # END: STEP 2

                # STEP 3: “Extending” with a blank
                if c == '%':
                    Pb[t][l] += ctc[t][-1] * (Pb[t - 1][l] + Pnb[t - 1][l])
                # END: STEP 3

                # STEP 4: Extending with the end character
                else:
                    l_plus = l + c
                    if len(l) > 0 and c == l[-1]:
                        Pnb[t][l_plus] += ctc[t][c_ix] * Pb[t - 1][l]
                        Pnb[t][l] += ctc[t][c_ix] * Pnb[t - 1][l]
                # END: STEP 4

                    # STEP 5: Extending with any other non-blank character and LM constraints
                    elif len(l.replace(' ', '')) > 0 and c in (' ', '>'):
                        lm_prob = lm(l_plus.strip(' >')) ** alpha
                        Pnb[t][l_plus] += lm_prob * ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
                    else:
                        Pnb[t][l_plus] += ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
                    # END: STEP 5

                    # STEP 6: Make use of discarded prefixes
                    if l_plus not in A_prev:
                        Pb[t][l_plus] += ctc[t][-1] * (Pb[t - 1][l_plus] + Pnb[t - 1][l_plus])
                        Pnb[t][l_plus] += ctc[t][c_ix] * Pnb[t - 1][l_plus]
                    # END: STEP 6

        # STEP 7: Select most probable prefixes
        A_next = Pb[t] + Pnb[t]
        sorter = lambda l: A_next[l] * (len(W(l)) + 1) ** beta
        A_prev = sorted(A_next, key=sorter, reverse=True)[:k]
        # END: STEP 7

    return A_prev[0].strip('>')

## With Subword Regularization

As implemented by Jennifer Drexler and James Glass in [Subword Regularization and Beam Searh Decoding for End-to-End ASR](http://groups.csail.mit.edu/sls/publications/2019/JenniferDrexler_ICASSP-2019.pdf).

In [5]:
import sentencepiece as spm

# train sentencepiece model from `botchan.txt` and makes `m.model` and `m.vocab`
# `m.vocab` is just a reference. not used in the segmentation.
spm.SentencePieceTrainer.train(f'--input={resource_dir}botchan.txt --model_prefix={resource_dir}m_botchan --vocab_size=2000')

# makes segmenter instance and loads the model file (m.model)
sp = spm.SentencePieceProcessor()
sp.load(f'{resource_dir}m_botchan.model')

# encode: text => id
print(sp.encode_as_pieces('This is a test'))
print(sp.encode_as_ids('This is a test'))

# decode: id => text
print(sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']))
print(sp.decode_ids([209, 31, 9, 375, 586]))

['▁This', '▁is', '▁a', '▁t', 'est']
[209, 31, 9, 375, 586]
This is a test
This is a test


In [6]:
class SPLM:
    def __init__(self, spp, log=True, regularize=True):
        self.spp = spp
        self.log = log
        self.regularize = regularize
    
    def __call__(self, sentence, **kwargs):
        return self.sp_score(sentence, **kwargs)
        
    def sp_score(self, sentence, l=-1, alpha=0.2):
        """Score sentence using unigram model of sentencepiece"""
        if self.regularize:
            encoded = self.spp.sample_encode_as_ids(sentence, l, alpha)
        else:
            encoded = self.spp.encode_as_ids(sentence)

        score = 0
        for idx in encoded:
            # return emission log probabilities, so just add them by chain-rule
            score += self.spp.GetScore(idx)  

        if not self.log:
            score = 10 ** score

        return score

In [7]:
splm = SPLM(sp, log=False, regularize=True)
splm('this is a test')

4.237340023052399e-43

## With Parallel Processing on CPU

Beam search decoding on ctc matrix requires an awful lot of computation. On `IPython`, we need to store the worker function on a different module in order to see the results.

In [8]:
from worker import ctc_beam_search_sp_mp

This is the barebone code in worker.py 

```python
def worker(b, lm):
    res = prefix_beam_search(b,
                             lm=lm,
                             k=100,
                             alpha=0.30,
                             beta=5,
                             prune=0.001)
    return res

# create the threadpool
with Pool(os.cpu_count() - 1) as p:
    # schedule one map/worker for each row in the original data
    q = p.starmap(worker, ([b for b in examples], splm))
```    

# Benchmark

Perform benchmark on various algorithm on our model and dataset.

## Create Dummy Data

In [9]:
def load_example(filename):
    with open(filename, 'rb') as f:
        example = pickle.load(f)
    return example

In [10]:
example_1 = load_example(resource_dir+"example_99.p")
example_2 = load_example(resource_dir+"example_1518.p")
example_3 = load_example(resource_dir+"example_2002.p")

print(example_1.shape, example_2.shape, example_3.shape)

(860, 29) (860, 29) (860, 29)


In [11]:
example_1 = np.expand_dims(example_1, axis=0)
example_2 = np.expand_dims(example_2, axis=0)
example_3 = np.expand_dims(example_3, axis=0)

Benchmark by creating 100 examples.

In [12]:
example_1 = np.vstack([example_1]*33)
example_2 = np.vstack([example_2]*33)
example_3 = np.vstack([example_3]*34)

In [13]:
examples = np.vstack([example_1, example_2, example_3])
examples.shape

(100, 860, 29)

## Quality

In [14]:
example_1 = load_example(resource_dir+"example_99.p")
example_2 = load_example(resource_dir+"example_1518.p")
example_3 = load_example(resource_dir+"example_2002.p")

print(example_1.shape, example_2.shape, example_3.shape)

(860, 29) (860, 29) (860, 29)


In following respective order:
1. Python numpy greedy
2. Python numpy prefix beam search
3. Python numpy multiprocessing prefix beam search

In [15]:
for example in [example_1, example_2, example_3]:
    res = greedy_decoder(example)
    print(res)

but no ghoes tor anything else appeared upon the angient wall
mister qualter as the apostle of the middle classes and we re glad twelcomed his gospe
alloud laugh followed at chunkeys expencs


In [16]:
for example in [example_1, example_2, example_3]:
    res = prefix_beam_search(example,
                             lm=None,
                             k=100,
                             alpha=0.30,
                             beta=5,
                             prune=0.001)
    print(res)

but no ghoest tor anything else appeared upon the angient walls
mister qualter as the apostle of the middle classes and we are glad t welcomed his gospel
alloud laugh followed at chunkeys expense


In [17]:
for example in [example_1, example_2, example_3]:
    res = prefix_beam_search(example,
                             lm=splm,
                             k=100,
                             alpha=0.30,
                             beta=5,
                             prune=0.001)
    print(res)

but noghoestoranything elseappeared upon theagenwalls
mister quiteras theirpostle of the middleclasses andweregladwelcomehis gospll
loud laugh  followedatchunkeysexpens


## Time

In following respective order:
1. Python numpy greedy
2. Python numpy prefix beam search
3. C++ tensorflow greedy
4. C++ tensorflow prefix beam search
5. Python numpy multiprocessing prefix beam search

In [18]:
%%time
for example in examples:
    greedy_decoder(example)

Wall time: 22 ms


In [19]:
%%time
for example in examples:
    prefix_beam_search(example,
                       lm=None,
                       k=100,
                       alpha=0.30,
                       beta=5,
                       prune=0.001)

Wall time: 2min 13s


In [20]:
%%time
_ = tf.nn.ctc_greedy_decoder(np.transpose(examples, [1, 0, 2]),
                             [examples.shape[1]]*examples.shape[0],
                             merge_repeated=True)

Wall time: 217 ms


In [21]:
%%time
_ = tf.nn.ctc_beam_search_decoder(np.transpose(examples, [1, 0, 2]),
                                  [examples.shape[1]]*examples.shape[0],
                                  beam_width=100,
                                  top_paths=1)

Wall time: 30.1 s


In [22]:
%%time
_ = ctc_beam_search_sp_mp(examples)

Wall time: 1min 22s
