In [627]:
import sys
sys.executable

import numpy as np

### CTC Beam Search Algorithm with an External Language Model (LM)
Beam search as described in the papers by Hwang et al. and Graves et al.

In [766]:
class Beam():
    """A single beam-search hypothesis: a labeling and its probabilities.

    The labeling is stored as given (the default is an empty tuple); the
    three probabilities are stored unchanged from the constructor arguments.
    """

    def __init__(self, labeling=(), pnB=0, pB=0, pT=0):
        """Create a beam.

        Args:
            labeling: sequence of label indices (indexable, supports ``len``).
            pnB: initial value of ``pNonBlank``.
            pB: initial value of ``pBlank``.
            pT: initial value of ``pTotal``.
        """
        self.labeling = labeling
        self.pNonBlank = pnB
        self.pBlank = pB
        self.pTotal = pT

        # Language Model
        # The LM score is accumulated multiplicatively (see applyLM), so it has
        # to start at the neutral element 1; starting at 0 pins it to 0 forever.
        self.pLM = 1
        self.LM_applied = False

    def applyLM(self, alphabet, LM=None, lmFactor=0.01):
        """Multiply ``pLM`` by the weighted bigram score of the last two characters.

        Does nothing when no LM is given, when the LM was already applied to
        this beam, or when the labeling is empty (there is no character to score).

        Args:
            alphabet: indexable mapping from label index to character.
            LM: object providing ``getCharBigram(c1, c2)``; falsy disables scoring.
            lmFactor: exponent applied to the bigram value (influence of the LM).
        """
        if not LM or self.LM_applied or not self.labeling:
            return

        # A labeling with a single character is scored against a preceding
        # space. The literal ' ' is used directly so that alphabets which do
        # not contain a space no longer raise ValueError here.
        if len(self.labeling) >= 2:
            c1 = alphabet[self.labeling[-2]]
        else:
            c1 = ' '
        c2 = alphabet[self.labeling[-1]]

        # fold the probability of seeing c1 followed by c2 into the LM score
        self.pLM *= LM.getCharBigram(c1, c2) ** lmFactor
        self.LM_applied = True

    def __str__(self):
        return str(self.labeling)

    def __repr__(self):
        return "Beam(labeling={}, pnB={}, pB={}, pT={})".format(self.labeling, self.pNonBlank, self.pBlank, self.pTotal)



class BeamDict():
    """Collection of beams stored in ``self.entries``, keyed by labeling."""

    def __init__(self, beamList=None):
        """Build the collection from an optional iterable of beams.

        Args:
            beamList: iterable of objects exposing ``labeling``; ``None`` or an
                empty sequence yields an empty collection.
        """
        # Single pass over the input. The previous nested loop rebuilt the
        # whole dict once per element and lost the first beam when handed a
        # one-shot iterator.
        self.entries = {beam.labeling: beam for beam in (beamList or ())}

    def getBest(self, BW=4):
        """Return a new BeamDict holding the ``BW`` beams with the highest ``pTotal``.

        The new collection references the same beam objects, inserted from
        best to worst.
        """
        ranked = sorted(self.entries.values(), key=lambda beam: beam.pTotal, reverse=True)
        return BeamDict(ranked[:BW])

    def addBeam(self, labeling):
        """Insert a fresh ``Beam(labeling)`` unless that labeling is already present."""
        if labeling not in self.entries:
            self.entries[labeling] = Beam(labeling)

    def norm(self):
        "length-normalise LM score"
        for beam in self.entries.values():
            labelingLen = len(beam.labeling)
            # an empty labeling is treated as length 1 to avoid dividing by zero
            beam.pLM = beam.pLM ** (1.0 / (labelingLen if labelingLen else 1.0))

    def __str__(self):
        return self.entries.__str__()

    def __repr__(self):
        return "BeamDict(" + self.entries.__repr__() + ")"

def _rankBeams(beamDict, count):
    """Return the ``count`` best beams of ``beamDict`` (best first), scored by pTotal * pLM."""
    ordered = sorted(beamDict.entries.values(),
                     key=lambda beam: beam.pTotal * beam.pLM,
                     reverse=True)
    return ordered[:count]


def ctcBeamSearch(alphabet, mat, LM=None, width=25):
    """CTC beam search decoding with an optional character-bigram language model.

    Args:
        alphabet: indexable sequence of characters; label index ``i`` maps to
            ``alphabet[i]``.
        mat: 2-D array read as ``mat[label, t]``. Row ``len(alphabet)`` is used
            as the blank; the second axis is iterated as time.
        LM: optional language model handed to ``Beam.applyLM`` (which calls
            ``LM.getCharBigram``). With ``None`` every LM score stays at 1 and
            beams are ranked by ``pTotal`` alone.
        width: number of beams kept after each time-step.

    Returns:
        ``(bestBeam, bestlabeling)``: the best ``Beam`` after the last time-step
        and its labeling rendered by concatenating ``alphabet`` entries.
    """
    blankIndex = len(alphabet)
    _, maxT = mat.shape

    # Start from the empty labeling, which "ends in blank" with probability 1.
    beams = BeamDict()
    beams.addBeam(labeling=())
    root = beams.entries[()]
    root.pBlank = 1
    root.pTotal = 1
    root.pLM = 1  # neutral element: LM scores are combined by multiplication

    for t in range(maxT):
        bestBeams = _rankBeams(beams, width)
        beams = BeamDict()
        for beam in bestBeams:
            labeling = beam.labeling

            # 1) Copying: the labeling stays the same at this time-step
            beams.addBeam(labeling)
            kept = beams.entries[labeling]
            if labeling:
                # 1.a) repeating the last character
                kept.pNonBlank += beam.pNonBlank * mat[labeling[-1], t]
            # 1.b) labeling ending with blank
            kept.pBlank += beam.pTotal * mat[blankIndex, t]
            # pNonBlank may already hold an extension contribution added
            # earlier in this time-step, so recompute the total from both parts
            kept.pTotal = kept.pNonBlank + kept.pBlank
            kept.pLM = beam.pLM
            kept.LM_applied = True

            # 2) Extending: append every character of the alphabet
            for c in range(len(alphabet)):
                if not labeling or labeling[-1] == c:
                    # 2.a) doubling the last character needs a blank in between
                    pNonBlank = beam.pBlank * mat[c, t]
                else:
                    # 2.b) extending with a different character than the last one
                    pNonBlank = beam.pTotal * mat[c, t]

                labeling_ = labeling + (c,)
                beams.addBeam(labeling_)
                grown = beams.entries[labeling_]
                grown.pNonBlank += pNonBlank
                grown.pTotal += pNonBlank

                # Score the new character once: inherit the parent's LM score
                # and fold in the bigram for the appended character. The
                # alphabet and LM have to be forwarded here; passing `mat`
                # instead meant the LM was never consulted.
                if not grown.LM_applied:
                    grown.pLM = beam.pLM
                    grown.applyLM(alphabet, LM)
                    grown.LM_applied = True

    # normalise LM scores according to beam-labeling-length
    beams.norm()

    bestBeam = _rankBeams(beams, 1)[0]
    bestlabeling = ''.join(alphabet[c_ix] for c_ix in bestBeam.labeling)
    return bestBeam, bestlabeling

### Examples

In [768]:
# Two-character alphabet; the blank is implied as the extra third label.
alphabet = 'ab'

# Each inner list is one time-step holding the scores for (a, b, blank).
# ctcBeamSearch reads mat[label, t], hence the transpose.
mat1 = np.array([
    [0.4, 0, 0.6],
    [0.4, 0.4, 0.2],
    [0.4, 0.1, 0.5],
]).T

mat2 = np.array([
    [0.4, 0, 0.6],
    [0.4, 0.2, 0.4],
    [0.4, 0.2, 0.4],
]).T

# Decode both matrices without a language model.
print(ctcBeamSearch(alphabet=alphabet, mat=mat1, LM=None))
print(ctcBeamSearch(alphabet=alphabet, mat=mat2, LM=None))

(Beam(labeling=(0,), pnB=0.20800000000000002, pB=0.24000000000000002, pT=0.448), 'a')
(Beam(labeling=(0,), pnB=0.256, pB=0.22400000000000003, pT=0.4800000000000001), 'a')
