This notebook adapts the CTC prefix beam-search decoder from [this gist](https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0).

We start with the imports:

In [42]:
import numpy as np
import math
import sys
import collections

Next, we define a class that wraps the CTC prefix beam-search decoder:

In [93]:
class CTCDecoder:
    """
    CTC prefix beam-search decoder.

    Adapted from https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0.
    All probability bookkeeping is done in log space to avoid underflow.

    Attributes:
        alphabet: Sequence mapping output indices to printable symbols; used
            by test() and run() to render decoded labels as text.
        NEG_INF: log(0), the identity element of logsumexp.
        trace: When True, decode() writes per-step diagnostics to stderr.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.NEG_INF = -float("inf")
        self.trace = False

    def make_new_beam(self):
        """
        Return an empty beam mapping prefix -> (p_blank, p_no_blank).

        Unseen prefixes default to (log 0, log 0), so decode() can read and
        accumulate into any prefix without checking whether it exists.
        """
        return collections.defaultdict(lambda: (self.NEG_INF, self.NEG_INF))

    def logsumexp(self, *args):
        """
        Stable log sum exp: log(sum(exp(a) for a in args)).

        Returns NEG_INF when every argument is NEG_INF (a sum of zero
        probabilities); this also avoids computing -inf - (-inf) = nan.
        """
        if all(a == self.NEG_INF for a in args):
            return self.NEG_INF
        a_max = max(args)
        lsp = math.log(sum(math.exp(a - a_max) for a in args))
        return a_max + lsp

    def decode(self, probs, beam_size=100, blank=0):
        """
        Performs inference for the given output probabilities.

        Arguments:
            probs: The output probabilities (e.g. post-softmax) for each
                time step. Anything np.asarray accepts that yields an array
                of shape (time x output dim); zero entries are allowed.
            beam_size (int): Size of the beam to use during inference.
            blank (int): Index of the CTC blank label.

        Returns the output label sequence (a tuple of output indices with
        blanks removed and un-separated repeats merged) and the
        corresponding negative log-likelihood estimated by the decoder.

        Raises:
            ValueError: If probs is not two-dimensional.
        """
        probs = np.asarray(probs)
        if probs.ndim != 2:
            raise ValueError(
                "probs must have shape (time x output dim), got {}".format(
                    probs.shape))
        T, S = probs.shape
        # log(0) = -inf is exactly the log-space "zero probability" the
        # updates below expect, so silence NumPy's divide-by-zero warning.
        with np.errstate(divide="ignore"):
            probs = np.log(probs)

        # Elements in the beam are (prefix, (p_blank, p_no_blank))
        # Initialize the beam with the empty sequence, a probability of
        # 1 for ending in blank and zero for ending in non-blank
        # (in log space).
        beam = [(tuple(), (0.0, self.NEG_INF))]

        for t in range(T):  # Loop over time
            if self.trace:
                print('t:', t, file=sys.stderr)
                print('BEAM:', beam, file=sys.stderr)
            # A default dictionary to store the next step candidates.
            next_beam = self.make_new_beam()

            for s in range(S):  # Loop over vocab
                p = probs[t, s]

                # The variables p_b and p_nb are respectively the
                # probabilities for the prefix given that it ends in a
                # blank and does not end in a blank at this time step.
                for prefix, (p_b, p_nb) in beam:  # Loop over beam

                    # If we propose a blank the prefix doesn't change.
                    # Only the probability of ending in blank gets updated.
                    if s == blank:
                        n_p_b, n_p_nb = next_beam[prefix]
                        n_p_b = self.logsumexp(n_p_b, p_b + p, p_nb + p)
                        next_beam[prefix] = (n_p_b, n_p_nb)
                        continue

                    # Extend the prefix by the new character s and add it to
                    # the beam. Only the probability of not ending in blank
                    # gets updated.
                    end_t = prefix[-1] if prefix else None
                    n_prefix = prefix + (s,)
                    n_p_b, n_p_nb = next_beam[n_prefix]
                    if s != end_t:
                        n_p_nb = self.logsumexp(n_p_nb, p_b + p, p_nb + p)
                    else:
                        # We don't include the previous probability of not
                        # ending in blank (p_nb) if s is repeated at the end.
                        # The CTC algorithm merges characters not separated
                        # by a blank.
                        n_p_nb = self.logsumexp(n_p_nb, p_b + p)

                    # *NB* this would be a good place to include an LM score.
                    next_beam[n_prefix] = (n_p_b, n_p_nb)

                    # If s is repeated at the end we also update the unchanged
                    # prefix. This is the merging case.
                    if s == end_t:
                        if self.trace:
                            # Same stream as every other trace line.
                            print('MERGE', file=sys.stderr)
                        n_p_b, n_p_nb = next_beam[prefix]
                        n_p_nb = self.logsumexp(n_p_nb, p_nb + p)
                        next_beam[prefix] = (n_p_b, n_p_nb)

            # Sort and trim the beam before moving on to the
            # next time-step.
            beam = sorted(next_beam.items(),
                          key=lambda x: self.logsumexp(*x[1]),
                          reverse=True)
            beam = beam[:beam_size]

            if self.trace:
                print('NEW BEAM:', beam, file=sys.stderr)

        best = beam[0]
        return best[0], -self.logsumexp(*best[1])

    def test(self):
        """
        Decode a reproducible random (6 x len(alphabet)) distribution and
        print the result via run().

        A private RandomState seeded with 3 draws the same numbers as
        np.random.seed(3) would, without clobbering the global NumPy RNG.
        """
        rng = np.random.RandomState(3)

        time = 6
        output_dim = len(self.alphabet)

        probs = rng.rand(time, output_dim)
        probs = probs / np.sum(probs, axis=1, keepdims=True)

        self.run(probs)

    def run(self, probs):
        """
        Decode probs and print the label indices, the rendered text and the
        negative log-likelihood score.
        """
        labels, score = self.decode(probs)
        print(labels)
        print(''.join([self.alphabet[i] for i in labels]))
        print("Score {:.3f}".format(score))



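Run a quick test on a random probability matrix. Note that `decode` treats index 0 as the CTC blank by default (`blank=0`), so the space character at the start of the alphabet plays the role of the blank symbol: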

In [94]:
# Output alphabet; index 0 (the space) doubles as the CTC blank because
# CTCDecoder.decode() defaults to blank=0.
V = list(' abcdefghijklmnopqrstuvwxyz')
dec = CTCDecoder(V)
dec.test()


(5, 14, 7)
eng
Score 13.048


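Here we build an input matrix that emulates the output of an acoustic model. For each time step, the target character gets a high score (10) and every other symbol gets a random score between 1 and 5. Each row is then normalised to sum to 1.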

In [95]:
import random, sys

C = ['c', 'c', 'a', 'a', 't', 't']  # sequence we want to output
M = []  # matrix for output

for c in C:
    # The target symbol gets the top score (10); every other symbol gets a
    # random lower score in [1, 5], drawn in alphabet order.
    row = [10 if v == c else random.randint(1, 5) for v in V]
    # Normalise so the row sums to 1; compute the total once instead of
    # re-summing the row for every element.
    total = sum(row)
    nrow = [i / total for i in row]
    M.append(nrow)

M = np.array(M)  # numpy-ise it.

This is our T × V matrix (time steps by vocabulary/alphabet size):

In [96]:
# Show the emulated T x V probability matrix (one row per time step).
print(M)

[[0.03614458 0.04819277 0.02409639 0.12048193 0.02409639 0.01204819
  0.04819277 0.06024096 0.01204819 0.03614458 0.03614458 0.03614458
  0.01204819 0.01204819 0.01204819 0.06024096 0.01204819 0.03614458
  0.01204819 0.06024096 0.06024096 0.01204819 0.06024096 0.02409639
  0.06024096 0.01204819 0.06024096]
 [0.05128205 0.01282051 0.03846154 0.12820513 0.06410256 0.06410256
  0.02564103 0.05128205 0.03846154 0.02564103 0.02564103 0.02564103
  0.01282051 0.02564103 0.01282051 0.01282051 0.01282051 0.01282051
  0.06410256 0.05128205 0.01282051 0.05128205 0.03846154 0.02564103
  0.02564103 0.06410256 0.02564103]
 [0.04597701 0.11494253 0.03448276 0.04597701 0.04597701 0.04597701
  0.04597701 0.05747126 0.02298851 0.02298851 0.04597701 0.04597701
  0.04597701 0.02298851 0.02298851 0.01149425 0.03448276 0.04597701
  0.01149425 0.05747126 0.05747126 0.01149425 0.02298851 0.02298851
  0.01149425 0.03448276 0.01149425]
 [0.02020202 0.1010101  0.02020202 0.05050505 0.01010101 0.04040404
  0.0303

In [97]:
# Decode the emulated matrix and print labels, text and score.
dec.run(probs=M)

(3, 1, 20)
cat
Score 11.288


In [98]:
# Decode again with tracing on, so decode() also writes per-step
# beam diagnostics to stderr.
dec.trace = True
dec.run(probs=M)

(3, 1, 20)
cat
Score 11.288


t: 0
BEAM: [((), (0.0, -inf))]
NEW BEAM: [((3,), (-inf, -2.1162555148025524)), ((7,), (-inf, -2.8094026953624978)), ((15,), (-inf, -2.8094026953624978)), ((19,), (-inf, -2.8094026953624978)), ((20,), (-inf, -2.8094026953624978)), ((22,), (-inf, -2.8094026953624978)), ((24,), (-inf, -2.8094026953624978)), ((26,), (-inf, -2.8094026953624978)), ((1,), (-inf, -3.0325462466767075)), ((6,), (-inf, -3.0325462466767075)), ((), (-3.3202283191284883, -inf)), ((9,), (-inf, -3.3202283191284883)), ((10,), (-inf, -3.3202283191284883)), ((11,), (-inf, -3.3202283191284883)), ((17,), (-inf, -3.3202283191284883)), ((2,), (-inf, -3.7256934272366524)), ((4,), (-inf, -3.7256934272366524)), ((23,), (-inf, -3.7256934272366524)), ((5,), (-inf, -4.418840607796598)), ((8,), (-inf, -4.418840607796598)), ((12,), (-inf, -4.418840607796598)), ((13,), (-inf, -4.418840607796598)), ((14,), (-inf, -4.418840607796598)), ((16,), (-inf, -4.418840607796598)), ((18,), (-inf, -4.418840607796598)), ((21,), (-inf, -4.418840607