# Integrate SimpleHTR with word beam search

This notebook uses the CTCWordBeamSearch decoder (https://github.com/githubharald/CTCWordBeamSearch).


In [None]:
# Delete any existing CTCWordBeamSearch directory so this notebook can be re-run;
# the following cell clones the repository into a directory of the same name.
!rm -rf CTCWordBeamSearch

In [None]:
# Fetch the CTCWordBeamSearch sources and install the package from the checkout.
!git clone https://github.com/githubharald/CTCWordBeamSearch
%cd ./CTCWordBeamSearch
# %pip (rather than !pip) runs pip for the environment of the running kernel.
%pip install .
# Go back to /content, the directory the remaining cells work from.
%cd /content

Cloning into 'CTCWordBeamSearch'...
remote: Enumerating objects: 84, done.
remote: Counting objects: 100% (84/84), done.
remote: Compressing objects: 100% (63/63), done.
remote: Total 340 (delta 19), reused 72 (delta 12), pack-reused 256
Receiving objects: 100% (340/340), 1.63 MiB | 25.62 MiB/s, done.
Resolving deltas: 100% (116/116), done.
/content/CTCWordBeamSearch
Processing /content/CTCWordBeamSearch
Building wheels for collected packages: word-beam-search
  Building wheel for word-beam-search (setup.py) ... done
  Created wheel for word-beam-search: filename=word_beam_search-1.0.0-cp37-cp37m-linux_x86_64.whl size=1178307 sha256=78bc47f343ceb1e2923be32c52a71f07cdb2f2084838b73ee8a80c3fa9ff4359
  Stored in directory: /tmp/pip-ephem-wheel-cache-3czr_c8u/wheels/a9/69/4c/9d6acbecc7bf4b47c5072b213d9b08e4b9c43864bbed5206cc
Successfully built word-beam-search
Installing collected packages: word-beam-search
Successfully installed word-beam-search-1.0.0
/content


# Train the model with the dataset

In [None]:
# Copy the dataset archive from the mounted Drive folder and unpack it in the working directory.
!cp '/content/drive/MyDrive/TFG/TrainSimpleHTR/dataset.zip' .
# -o overwrites already-extracted files without prompting, so re-running the cell does not
# block on an interactive question; -q keeps the per-file listing out of the notebook output.
!unzip -o -q './dataset.zip'
# Rename the ground-truth file to words.txt (presumably the name DataLoaderIAM looks for -- confirm).
!mv '/content/dataset/gt/wordsNew.txt' '/content/dataset/gt/words.txt'

In [None]:
# Install the Python dependencies listed in the project's requirements file.
!cp '/content/drive/MyDrive/TFG/TrainSimpleHTR/requirements.txt' .
# %pip (rather than !pip) runs pip for the environment of the running kernel.
%pip install -r requirements.txt
# `path` provides the Path class that the main code cell imports (`from path import Path`).
%pip install path

In [None]:
# Copy the model, data and source directories from the mounted Drive folder into the
# working directory. The main code cell later reads and writes '../model' and '../data'
# relative to ./src.
# NOTE(review): if ./model, ./data or ./src already exist, `cp -r` copies the source
# directory *into* the existing one (e.g. ./model/model) -- confirm whether re-runs matter.
!cp -r '/content/drive/MyDrive/TFG/TrainSimpleHTR/model' './model'
!cp -r '/content/drive/MyDrive/TFG/TrainSimpleHTR/data' './data'
!cp -r '/content/drive/MyDrive/TFG/TrainSimpleHTR/src' './src'

In [None]:
# Work from ./src: the main code cell uses paths relative to it ('../model/...', '../data/...')
# and imports DataLoaderIAM, Model and SamplePreprocessor (presumably modules in ./src -- confirm).
%cd ./src

/content/src


In [None]:
from pathlib import Path

import argparse
import json

import cv2
import editdistance
from path import Path

from DataLoaderIAM import DataLoaderIAM, Batch
from Model import Model, DecoderType
from SamplePreprocessor import preprocess


class FilePaths:
    """Paths of the files this script reads and writes, relative to the working directory."""

    # character list: written by main() in train/validate mode, read back in inference mode
    fnCharList = "../model/wordCharList.txt"
    # JSON summary of validation metrics, written by write_summary()
    fnSummary = "../model/summary.json"
    # image that main() passes to infer() in inference mode
    fnInfer = "../data/test.png"
    # space-separated training and validation words, written by main()
    fnCorpus = "../data/corpus.txt"


def write_summary(charErrorRates, wordAccuracies):
    """Write both metric lists as one JSON object to FilePaths.fnSummary.

    The file is overwritten on every call; the lists are stored under the keys
    'charErrorRates' and 'wordAccuracies'.
    """
    summary = {
        'charErrorRates': charErrorRates,
        'wordAccuracies': wordAccuracies,
    }
    with open(FilePaths.fnSummary, 'w') as summaryFile:
        json.dump(summary, summaryFile)


def train(model, loader, earlyStopping=25):
    """Train the network until the validation character error rate stops improving.

    Every epoch feeds all training batches to model.trainBatch(), evaluates the
    model with validate(), rewrites the summary file via write_summary() and
    calls model.save() whenever the character error rate reaches a new minimum.

    Args:
        model: object providing trainBatch(batch) and save(); also passed to validate().
        loader: object providing trainSet(), hasNext(), getIteratorInfo() and
            getNext(); also passed to validate().
        earlyStopping: stop after this many consecutive epochs without improvement.
    """
    epoch = 0  # number of training epochs since start
    summaryCharErrorRates = []
    summaryWordAccuracies = []
    bestCharErrorRate = float('inf')  # best validation character error rate
    noImprovementSince = 0  # number of epochs without improvement of the character error rate
    while True:
        epoch += 1
        print('Epoch:', epoch)

        # train
        print('Train NN')
        loader.trainSet()
        while loader.hasNext():
            iterInfo = loader.getIteratorInfo()
            batch = loader.getNext()
            loss = model.trainBatch(batch)
            print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}')

        # validate
        charErrorRate, wordAccuracy = validate(model, loader)

        # write summary (the whole history is rewritten after every epoch)
        summaryCharErrorRates.append(charErrorRate)
        summaryWordAccuracies.append(wordAccuracy)
        write_summary(summaryCharErrorRates, summaryWordAccuracies)

        # if best validation accuracy so far, save model parameters
        if charErrorRate < bestCharErrorRate:
            print('Character error rate improved, save model')
            bestCharErrorRate = charErrorRate
            noImprovementSince = 0
            model.save()
        else:
            # report the best rate seen so far, not the rate of the current epoch
            print(f'Character error rate not improved, best so far: {bestCharErrorRate * 100.0}%')
            noImprovementSince += 1
            print('No improvement since: ' + str(noImprovementSince))

        # stop training if no more improvement in the last x epochs
        if noImprovementSince >= earlyStopping:
            print(f'No more improvement since {earlyStopping} epochs. Training stopped.')
            break


def validate(model, loader):
    """Run the model over the validation set and report its accuracy.

    Prints every ground-truth/recognized pair together with its edit distance.

    Returns:
        Tuple (charErrorRate, wordAccuracy): summed edit distance divided by the
        number of ground-truth characters, and the fraction of exactly matching words.
    """
    print('Validate NN')
    loader.validationSet()
    charErrors = charCount = 0
    wordsCorrect = wordCount = 0

    while loader.hasNext():
        progress = loader.getIteratorInfo()
        print(f'Batch: {progress[0]} / {progress[1]}')
        batch = loader.getNext()
        recognized, _ = model.inferBatch(batch)

        print('Ground truth -> Recognized')
        for idx, predicted in enumerate(recognized):
            truth = batch.gtTexts[idx]
            if truth == predicted:
                wordsCorrect += 1
            wordCount += 1

            dist = editdistance.eval(predicted, truth)
            charErrors += dist
            charCount += len(truth)

            status = '[OK]' if dist == 0 else '[ERR:%d]' % dist
            print(status, '"' + truth + '"', '->', '"' + predicted + '"')

    # print validation result
    charErrorRate = charErrors / charCount
    wordAccuracy = wordsCorrect / wordCount
    print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.')
    return charErrorRate, wordAccuracy


def infer(model, fnImg):
    """Print the text recognized in the image file fnImg and its probability."""
    # load as grayscale, then let preprocess() bring it to the model's input size
    grayImg = cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE)
    sample = preprocess(grayImg, Model.imgSize)

    # single-image batch without ground truth; True is passed as the second positional argument
    singleBatch = Batch(None, [sample])
    texts, probabilities = model.inferBatch(singleBatch, True)

    print(f'Recognized: "{texts[0]}"')
    print(f'Probability: {probabilities[0]}')


def main():
    """Run training, validation or single-image inference using the settings below.

    The settings are local constants standing in for the command-line options
    of the original script, so the function can be run from a notebook cell.
    With neither training nor validation enabled, the image at
    FilePaths.fnInfer is recognized instead.

    Raises:
        ValueError: if training and validation are both enabled, or if the
            configured decoder name is not one of the supported choices.
    """
    # settings (stand-ins for the command-line arguments)
    args_train = False  # train the NN
    args_validate = True  # validate the NN
    args_decoder = 'wordbeamsearch'  # CTC decoder: 'bestpath', 'beamsearch' or 'wordbeamsearch'
    args_batch_size = 100  # batch size
    args_data_dir = Path('/content/dataset')  # directory containing IAM dataset
    args_fast = False  # use lmdb to load images
    args_dump = False  # dump output of NN to CSV file(s)

    # the two modes are mutually exclusive; fail with an explicit error
    if args_train and args_validate:
        raise ValueError('Both Train and Validate are enabled.')

    # set chosen CTC decoder
    decoders = {
        'bestpath': DecoderType.BestPath,
        'beamsearch': DecoderType.BeamSearch,
        'wordbeamsearch': DecoderType.WordBeamSearch,
    }
    if args_decoder not in decoders:
        raise ValueError(f'Unknown decoder "{args_decoder}", expected one of: {", ".join(decoders)}')
    decoderType = decoders[args_decoder]

    # train or validate on IAM dataset
    if args_train or args_validate:
        # load training data, create TF model
        loader = DataLoaderIAM(args_data_dir, args_batch_size, Model.imgSize, Model.maxTextLen, args_fast)

        # save characters of model for inference mode
        with open(FilePaths.fnCharList, 'w') as charListFile:
            charListFile.write(''.join(loader.charList))

        # save words contained in dataset into file
        with open(FilePaths.fnCorpus, 'w') as corpusFile:
            corpusFile.write(' '.join(loader.trainWords + loader.validationWords))

        # execute training or validation
        if args_train:
            model = Model(loader.charList, decoderType)
            train(model, loader)
        elif args_validate:
            model = Model(loader.charList, decoderType, mustRestore=True)
            validate(model, loader)

    # infer text on test image
    else:
        # reuse the character list saved by an earlier train/validate run
        with open(FilePaths.fnCharList) as charListFile:
            charList = charListFile.read()
        model = Model(charList, decoderType, mustRestore=True, dump=args_dump)
        infer(model, FilePaths.fnInfer)


# Entry point: call main() when this code executes as the top-level module
# (presumably also the case when the cell is run in the notebook kernel -- confirm).
if __name__ == '__main__':
    main()


In [None]:
# Archive ../data and ../model -- the directories the main code cell writes the corpus,
# character list and summary files to -- into model_trained.zip.
!zip -r model_trained.zip ../data ../model

# Copy the archive to the mounted Drive folder so it is kept outside /content.
!cp 'model_trained.zip' '/content/drive/MyDrive/TFG/TrainSimpleHTR/model_trained.zip'