In [1]:
# Mount Google Drive at /content/drive; every path used later in this notebook
# lives under '/content/drive/My Drive/LineHTR/'.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
from __future__ import division
from __future__ import print_function

import cv2
import matplotlib.pyplot as plt
import numpy as np
import json
import random
import os
import codecs
import sys

from skimage.filters import threshold_local, threshold_yen
import argparse
import editdistance

%tensorflow_version 1.x
import tensorflow as tf
# Show which TensorFlow version was selected, then fail fast unless the first GPU is visible.
print(tf.__version__)
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
  print('Found GPU at: {}'.format(device_name))
else:
  raise SystemError('GPU device not found')

TensorFlow 1.x selected.
1.15.2
Found GPU at: /device:GPU:0


## Helpers

In [3]:
# NOTE(review): not referenced by any code visible in this notebook - confirm whether it is still needed.
SMALL_HEIGHT = 800

In [4]:
def preprocessor(imgPath, imgSize, binary=True):
    """ Load an image file and turn it into a normalized network input.

    The image is converted to grayscale (after an optional contrast boost,
    local thresholding and erosion), scaled to fit inside `imgSize` while
    keeping its aspect ratio, pasted into the top-left corner of a white
    canvas, transposed and finally standardized.

    Parameters
    ----------
    imgPath : str
        Path of the image file to read.
    imgSize : tuple
        Target size as (width, height).
    binary : bool
        If True, boost contrast, binarize with a local threshold and thicken
        the strokes before scaling; if False, only convert to grayscale.

    Returns
    -------
    numpy.ndarray
        Array of shape (width, height) with zero mean and, unless the canvas
        is constant, unit standard deviation.

    Raises
    ------
    IOError
        If the image cannot be read (cv2.imread returned None).
    """
    img = cv2.imread(imgPath)
    # cv2.imread does not raise for missing/unreadable files, it returns None;
    # fail here with the offending path instead of deep inside numpy/OpenCV.
    if img is None:
        raise IOError('Could not read image: ' + str(imgPath))

    if binary:
        # Linear contrast stretch, done in int16 so intermediate values can
        # leave the 0..255 range before being clipped back
        brightness = 0
        contrast = 50
        img = np.int16(img)
        img = img * (contrast/127+1) - contrast + brightness
        img = np.clip(img, 0, 255)
        img = np.uint8(img)

        # Binarize: pixels brighter than their local Gaussian threshold become white
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        T = threshold_local(img, 11, offset=10, method="gaussian")
        img = (img > T).astype("uint8") * 255

        # Increase line width: eroding the white background grows the dark strokes
        kernel = np.ones((3, 3), np.uint8)
        img = cv2.erode(img, kernel, iterations=1)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Create target image and copy sample image into it
    (wt, ht) = imgSize
    (h, w) = img.shape
    fx = w / wt
    fy = h / ht
    f = max(fx, fy)

    # Scale according to f (result at least 1 and at most wt or ht)
    newSize = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1))
    img = cv2.resize(img, newSize)
    target = np.ones([ht, wt]) * 255  # white canvas
    target[0:newSize[1], 0:newSize[0]] = img

    # Transpose for TF: (height, width) -> (width, height)
    img = cv2.transpose(target)

    # Normalize; skip the division when the image is constant (std == 0)
    (m, s) = cv2.meanStdDev(img)
    m = m[0][0]
    s = s[0][0]
    img = img - m
    img = img / s if s > 0 else img

    return img

In [5]:
def wer(r, h):
    """
    Levenshtein distance between two sequences, counted in whole elements.

    Despite the name this returns the raw number of substitutions,
    insertions and deletions, not a rate; callers that need a rate must
    divide by a sequence length themselves.

    The table is kept in plain Python integers, so there is no upper limit
    on the sequence lengths. O(nm) time, O(m) space.

    Parameters
    ----------
    r : list
        Reference sequence (e.g. the words of the ground truth).
    h : list
        Hypothesis sequence (e.g. the words of the recognized text).

    Returns
    -------
    int

    Examples
    --------
    >>> wer("who is there".split(), "is there".split())
    1
    >>> wer("who is there".split(), "".split())
    3
    >>> wer("".split(), "who is there".split())
    3
    """
    # Row 0 of the DP table: turning an empty reference into h[:j] costs j insertions
    prev = list(range(len(h) + 1))

    for i, refItem in enumerate(r, start=1):
        # Column 0: turning r[:i] into an empty hypothesis costs i deletions
        curr = [i] + [0] * len(h)
        for j, hypItem in enumerate(h, start=1):
            if refItem == hypItem:
                curr[j] = prev[j - 1]
            else:
                substitution = prev[j - 1] + 1
                insertion = curr[j - 1] + 1
                deletion = prev[j] + 1
                curr[j] = min(substitution, insertion, deletion)
        prev = curr

    return prev[len(h)]


## DataLoader

In [6]:
class FilePaths:
    """ Filenames and paths to data """
    # Character list; read by DataLoader, main() and the word beam search decoder setup
    fnCharList = '/content/drive/My Drive/LineHTR/model/charList.txt'
    # Characters that form words; only read when word beam search decoding is selected
    fnWordCharList = '/content/drive/My Drive/LineHTR/model/wordCharList.txt'
    # Text corpus; only read when word beam search decoding is selected
    fnCorpus = '/content/drive/My Drive/LineHTR/data/corpus.txt'
    # train() writes the validation character error rate of the saved model here
    fnAccuracy = '/content/drive/My Drive/LineHTR/model/accuracy.txt'
    # Dataset root folder handed to DataLoader (which requires the trailing '/')
    fnTrain = '/content/drive/My Drive/LineHTR/data/'
    fnInfer = '/content/drive/My Drive/LineHTR/data/a01-000u-00.png'  # Test image

In [7]:
class Sample:
    """ One dataset entry: a ground-truth text together with the path of its image """

    def __init__(self, gtText, filePath):
        # Plain value holder; nothing is loaded or validated here
        self.gtText, self.filePath = gtText, filePath

In [8]:
class Batch:
    """ A group of equally shaped images plus their ground truth texts """

    def __init__(self, gtTexts, imgs):
        self.gtTexts = gtTexts
        # np.stack joins along a new leading axis by default: N images -> (N, ...) array
        self.imgs = np.stack(imgs)

In [9]:
class DataLoader:
    """ Loads line images and their ground truth texts from the data folder and serves them in batches """

    def __init__(self, filePath, batchSize, imgSize, maxTextLen):
        """ Loader for dataset at given location, preprocess images and text according to parameters

        filePath: dataset folder, must end with '/'. It has to contain the label
            file 'test.txt' and the images below 'lines/'.
        batchSize: number of samples returned by each getNext() call
        imgSize: (width, height) handed to preprocessor() for every image
        maxTextLen: label budget enforced by truncateLabel()
        """

        assert filePath[-1] == '/'

        self.currIdx = 0
        self.batchSize = batchSize
        self.imgSize = imgSize
        self.samples = []

        bad_samples = []
        # Read the label file: one sample per line, '#' lines are comments
        with open(filePath + 'test.txt', 'r') as f:
            for line in f:
                parsed = self._parseLine(line, filePath, maxTextLen)
                if parsed is None:
                    continue
                (sampleName, fileName, gtText) = parsed

                # check if image is not empty
                if not os.path.getsize(fileName):
                    bad_samples.append(sampleName + '.png')
                    continue

                # put sample into list
                self.samples.append(Sample(gtText, fileName))

        # Make skipped samples visible instead of dropping them silently
        if bad_samples:
            print('Warning: skipped', len(bad_samples), 'empty image file(s):', bad_samples)

        # The character list comes from the stored file, not from the characters
        # seen in the labels, so it stays aligned with a previously saved model
        with open(FilePaths.fnCharList) as f:
            self.charList = list(f.read())

        # Split into training and validation set: 95% - 5%
        splitIdx = int(0.95 * len(self.samples))
        self.trainSamples = self.samples[:splitIdx]
        self.validationSamples = self.samples[splitIdx:]

        print("Train on", len(self.trainSamples), "images. Validate on",
              len(self.validationSamples), "images.")

        # Number of randomly chosen samples per epoch for training
        self.numTrainSamplesPerEpoch = 5500

        # Start with train set
        self.trainSet()

    def _parseLine(self, line, filePath, maxTextLen):
        """ Parse one label-file line into (sampleName, imagePath, gtText); None for blank or comment lines """
        stripped = line.strip()
        # Blank lines (e.g. a trailing newline at the end of the file) carry no sample
        if not stripped or stripped[0] == '#':
            return None

        lineSplit = stripped.split(' ')
        assert len(lineSplit) >= 9

        # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
        fileNameSplit = lineSplit[0].split('-')
        fileName = (filePath + 'lines/' + fileNameSplit[0] + '/' +
                    fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + lineSplit[0] + '.png')

        # GT text is everything from the 9th column onwards
        gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen)
        return (lineSplit[0], fileName, gtText)

    def truncateLabel(self, text, maxTextLen):
        """ Cut `text` so that its CTC cost does not exceed `maxTextLen` """
        # ctc_loss can't compute loss if it cannot find a mapping between text label and input
        # labels. Repeat letters cost double because of the blank symbol needing to be inserted.
        # If a too-long label is provided, ctc_loss returns an infinite gradient
        cost = 0
        for i in range(len(text)):
            if i != 0 and text[i] == text[i - 1]:
                cost += 2
            else:
                cost += 1
            if cost > maxTextLen:
                return text[:i]
        return text

    def trainSet(self):
        """ Switch to randomly chosen subset of training set """
        self.currIdx = 0
        random.shuffle(self.trainSamples)
        self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch]

    def validationSet(self):
        """ Switch to validation set """
        self.currIdx = 0
        self.samples = self.validationSamples

    def getIteratorInfo(self):
        """ Current batch index (1-based) and overall number of full batches """
        return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize)

    def hasNext(self):
        """ True while a complete batch is left; a trailing partial batch is never served """
        return self.currIdx + self.batchSize <= len(self.samples)

    def getNext(self):
        """ Load and preprocess the next `batchSize` samples and return them as a Batch """
        batchRange = range(self.currIdx, self.currIdx + self.batchSize)
        gtTexts = [self.samples[i].gtText for i in batchRange]
        imgs = [preprocessor(self.samples[i].filePath,
                             self.imgSize, binary=True) for i in batchRange]
        self.currIdx += self.batchSize
        return Batch(gtTexts, imgs)


## Model

In [10]:
class DecoderType:
    """ Integer identifiers of the CTC decoders the model can be built with """
    # 0 -> best path (greedy) decoding, 1 -> word beam search decoding
    BestPath, WordBeamSearch = range(2)

In [11]:
class Model:
    """ TensorFlow 1.x text-line recognition model built from CNN, RNN and CTC stages """
    # Model Constants
    # Fixed batch dimension of the input placeholder
    batchSize = 50
    # Input size as (width, height); the placeholder is laid out width first
    imgSize = (800, 64)
    # Label budget handed to DataLoader; also fed as the CTC sequence length of every batch element
    maxTextLen = 100


    def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False):
        """ Build the CNN -> RNN -> CTC graph, the optimizer, the session and the TensorBoard writer.

        charList: sequence of characters the model can output (indexable, with .index())
        decoderType: DecoderType.BestPath or DecoderType.WordBeamSearch
        mustRestore: if True, setupTF() raises when no saved snapshot is found
        """
        self.charList = charList
        self.decoderType = decoderType
        self.mustRestore = mustRestore
        self.snapID = 0  # running number appended to saved snapshots, see save()

        # CNN
        with tf.name_scope('CNN'):
            with tf.name_scope('Input'):
                self.inputImgs = tf.placeholder(tf.float32, shape=(Model.batchSize, Model.imgSize[0], Model.imgSize[1]))
            cnnOut4d = self.setupCNN(self.inputImgs)

        # RNN
        with tf.name_scope('RNN'):
            rnnOut3d = self.setupRNN(cnnOut4d)

        # # Debugging CTC
        # self.rnnOutput = tf.transpose(rnnOut3d, [1, 0, 2])

        # CTC
        with tf.name_scope('CTC'):
            (self.loss, self.decoder) = self.setupCTC(rnnOut3d)
            self.training_loss_summary = tf.summary.scalar(
                'loss', self.loss)  # Tensorboard: Track loss

        # Optimize NN parameters
        with tf.name_scope('Optimizer'):
            self.batchesTrained = 0  # drives the learning-rate schedule in trainBatch()
            self.learningRate = tf.placeholder(tf.float32, shape=[])
            self.optimizer = tf.train.RMSPropOptimizer(
                self.learningRate).minimize(self.loss)

        # Initialize TensorFlow (must come after all variables have been created)
        (self.sess, self.saver) = self.setupTF()

        self.writer = tf.summary.FileWriter(
            '/content/drive/My Drive/LineHTR/src/logs', self.sess.graph)  # Tensorboard: Create writer
        self.merge = tf.summary.merge(
            [self.training_loss_summary])  # Tensorboard: Merge


    def setupCNN(self, cnnIn3d):
        """ Create CNN layers and return output of these layers

        cnnIn3d: batch of single-channel images; a channel axis is appended before
            the first convolution.
        Returns the 4D output of the last pooling layer. The per-layer output sizes
        in the comments below are given without the batch axis and hold for the
        800 x 64 input configured in Model.imgSize.
        """

        # Append the channel axis expected by conv2d
        cnnIn4d = tf.expand_dims(input=cnnIn3d, axis=3)

        # First Layer: Conv (5x5) + Pool (2x2) - Output size: 400 x 32 x 64
        with tf.name_scope('Conv_Pool_1'):
            kernel = tf.Variable(
                tf.truncated_normal([5, 5, 1, 64], stddev=0.1))
            conv = tf.nn.conv2d(
                cnnIn4d, kernel, padding='SAME', strides=(1, 1, 1, 1))
            relu = tf.nn.relu(conv)
            pool = tf.nn.max_pool(relu, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

        # Second Layer: Conv (5x5) - Output size: 400 x 32 x 128
        with tf.name_scope('Conv_2'):
            kernel = tf.Variable(tf.truncated_normal(
                [5, 5, 64, 128], stddev=0.1))
            conv = tf.nn.conv2d(
                pool, kernel, padding='SAME', strides=(1, 1, 1, 1))
            relu = tf.nn.relu(conv)

        # Third Layer: Conv (3x3) + Pool (2x2) + Simple Batch Norm - Output size: 200 x 16 x 128
        with tf.name_scope('Conv_Pool_BN_3'):
            kernel = tf.Variable(tf.truncated_normal(
                [3, 3, 128, 128], stddev=0.1))
            conv = tf.nn.conv2d(
                relu, kernel, padding='SAME', strides=(1, 1, 1, 1))
            # "Simple" batch norm: statistics over the batch axis only, no learned offset/scale
            mean, variance = tf.nn.moments(conv, axes=[0])
            batch_norm = tf.nn.batch_normalization(
                conv, mean, variance, offset=None, scale=None, variance_epsilon=0.001)
            relu = tf.nn.relu(batch_norm)
            pool = tf.nn.max_pool(relu, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

        # Fourth Layer: Conv (3x3) - Output size: 200 x 16 x 256
        with tf.name_scope('Conv_4'):
            kernel = tf.Variable(tf.truncated_normal(
                [3, 3, 128, 256], stddev=0.1))
            conv = tf.nn.conv2d(
                pool, kernel, padding='SAME', strides=(1, 1, 1, 1))
            relu = tf.nn.relu(conv)

        # Fifth Layer: Conv (3x3) - Output size: 200 x 16 x 256
        with tf.name_scope('Conv_5'):
            kernel = tf.Variable(tf.truncated_normal(
                [3, 3, 256, 256], stddev=0.1))
            conv = tf.nn.conv2d(
                relu, kernel, padding='SAME', strides=(1, 1, 1, 1))
            relu = tf.nn.relu(conv)

        # Sixth Layer: Conv (3x3) + Simple Batch Norm - Output size: 200 x 16 x 512
        with tf.name_scope('Conv_BN_6'):
            kernel = tf.Variable(tf.truncated_normal(
                [3, 3, 256, 512], stddev=0.1))
            conv = tf.nn.conv2d(
                relu, kernel, padding='SAME', strides=(1, 1, 1, 1))
            # Same simple batch norm as in the third layer
            mean, variance = tf.nn.moments(conv, axes=[0])
            batch_norm = tf.nn.batch_normalization(
                conv, mean, variance, offset=None, scale=None, variance_epsilon=0.001)
            relu = tf.nn.relu(batch_norm)

        # Seventh Layer: Conv (3x3) + Pool (2x2) - Output size: 100 x 8 x 512
        with tf.name_scope('Conv_Pool_7'):
            kernel = tf.Variable(tf.truncated_normal(
                [3, 3, 512, 512], stddev=0.1))
            conv = tf.nn.conv2d(
                relu, kernel, padding='SAME', strides=(1, 1, 1, 1))
            relu = tf.nn.relu(conv)
            pool = tf.nn.max_pool(relu, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

        return pool


    def setupRNN(self, rnnIn4d):
        """ Create RNN layers and return output of these layers

        rnnIn4d: 4D output of the CNN stage.
        Returns a BxTxC tensor with one score per time step and class, where
        C = len(self.charList) + 1 (the extra class is the CTC blank).
        """
        # Keep 100 time steps, a single height position and 512 features per batch element
        rnnIn4d = tf.slice(rnnIn4d, [0, 0, 0, 0], [
                           self.batchSize, 100, 1, 512])
        # Remove only the height axis. Without an explicit axis tf.squeeze drops
        # every size-1 axis, which would also swallow the batch axis for batchSize == 1.
        rnnIn3d = tf.squeeze(rnnIn4d, axis=[2])

        # 2 layers of LSTM cell used to build RNN
        numHidden = 512
        cells = [tf.nn.rnn_cell.LSTMCell(
            numHidden, name='basic_lstm_cell') for _ in range(2)]
        stacked = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

        # Bi-directional RNN
        # BxTxF -> BxTx2H
        # NOTE(review): the same cell object is passed for both directions - confirm
        # this is intended. Changing it would alter the variables of saved snapshots.
        ((forward, backward), _) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d, dtype=rnnIn3d.dtype)

        # BxTxH + BxTxH -> BxTx2H -> BxTx1x2H
        concat = tf.expand_dims(tf.concat([forward, backward], 2), 2)

        # Project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
        kernel = tf.Variable(tf.truncated_normal(
            [1, 1, numHidden*2, len(self.charList)+1], stddev=0.1))
        return tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2])


    def setupCTC(self, ctcIn3d):
        """ Create CTC loss and decoder and return them

        ctcIn3d: BxTxC output of the RNN stage.
        Returns (mean CTC loss over the batch, decoder operation). Also creates the
        placeholders self.gtTexts (sparse ground truth) and self.seqLen.

        Raises ValueError if self.decoderType is not a known DecoderType.
        """
        # BxTxC -> TxBxC
        ctcIn3dTBC = tf.transpose(ctcIn3d, [1, 0, 2])

        # Ground truth text as sparse tensor
        with tf.name_scope('CTC_Loss'):
            self.gtTexts = tf.SparseTensor(tf.placeholder(tf.int64, shape=[
                                           None, 2]), tf.placeholder(tf.int32, [None]), tf.placeholder(tf.int64, [2]))
            # Calculate loss for batch
            self.seqLen = tf.placeholder(tf.int32, [None])
            loss = tf.nn.ctc_loss(labels=self.gtTexts, inputs=ctcIn3dTBC, sequence_length=self.seqLen,
                                  ctc_merge_repeated=True, ignore_longer_outputs_than_inputs=True)
        with tf.name_scope('CTC_Decoder'):
            # Decoder: Best path decoding or Word beam search decoding
            if self.decoderType == DecoderType.BestPath:
                decoder = tf.nn.ctc_greedy_decoder(
                    inputs=ctcIn3dTBC, sequence_length=self.seqLen)
            elif self.decoderType == DecoderType.WordBeamSearch:
                # Import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch)
                word_beam_search_module = tf.load_op_library(
                    '/content/drive/My Drive/LineHTR/src/TFWordBeamSearch.so')

                # Prepare: dictionary, characters in dataset, characters forming words.
                # Each file is closed again as soon as it has been read.
                with codecs.open(FilePaths.fnCharList, 'r', 'utf8') as f:
                    chars = f.read()
                with codecs.open(FilePaths.fnWordCharList, 'r', 'utf8') as f:
                    wordChars = f.read()
                with codecs.open(FilePaths.fnCorpus, 'r', 'utf8') as f:
                    corpus = f.read()

                # # Decoder using the "NGramsForecastAndSample": restrict number of (possible) next words to at most 20 words: O(W) mode of word beam search
                # decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(ctcIn3dTBC, dim=2), 25, 'NGramsForecastAndSample', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8'))

                # Decoder using the "Words": only use dictionary, no scoring: O(1) mode of word beam search
                decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(
                    ctcIn3dTBC, dim=2), 25, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8'))
            else:
                # Fail with a clear message instead of an unbound `decoder` at the return below
                raise ValueError('Unknown decoder type: ' + str(self.decoderType))

        # Return a CTC operation to compute the loss and CTC operation to decode the RNN output
        return (tf.reduce_mean(loss), decoder)


    def setupTF(self):
        """ Create session and saver, then restore the latest snapshot or start from fresh values

        Returns (session, saver). Raises Exception when self.mustRestore is set
        and the model folder holds no snapshot.
        """
        print('Python: ' + sys.version)
        print('Tensorflow: ' + tf.__version__)

        session = tf.Session()
        # At most the 5 newest snapshots are kept on disk
        checkpointSaver = tf.train.Saver(max_to_keep=5)

        modelDir = '/content/drive/My Drive/LineHTR/model/'
        snapshot = tf.train.latest_checkpoint(modelDir)

        if not snapshot:
            # Nothing stored: fatal when a trained model is required, otherwise start fresh
            if self.mustRestore:
                raise Exception('No saved model found in: ' + modelDir)
            print('Init with new values')
            session.run(tf.global_variables_initializer())
        else:
            print('Init with stored values from ' + snapshot)
            checkpointSaver.restore(session, snapshot)

        return (session, checkpointSaver)

    def toSpare(self, texts):
        """ Turn a list of ground truth texts into the (indices, values, shape) triple of a sparse tensor

        indices: one [batchElement, position] pair per character
        values: index of that character in self.charList
        shape: [number of texts, length of the longest text]
        """
        indices = []
        values = []
        longest = 0
        for (row, text) in enumerate(texts):
            # Class id of every character; .index() raises ValueError for unknown characters
            labels = [self.charList.index(ch) for ch in text]
            longest = max(longest, len(labels))
            indices.extend([row, col] for col in range(len(labels)))
            values.extend(labels)

        return (indices, values, [len(texts), longest])

    def decoderOutputToText(self, ctcOutput):
        """ Convert the raw output of the CTC decoder into one string per batch element """
        # One (initially empty) label list per batch element
        labelsPerElement = [[] for _ in range(Model.batchSize)]

        if self.decoderType != DecoderType.WordBeamSearch:
            # TF decoders return a tuple whose first entry holds a SparseTensor;
            # each index row is [batchElement, position]
            sparse = ctcOutput[0][0]
            for (pos, coord) in enumerate(sparse.indices):
                labelsPerElement[coord[0]].append(sparse.values[pos])
        else:
            # Word beam search: every row is a label string terminated by the blank label
            blank = len(self.charList)
            for b in range(Model.batchSize):
                for label in ctcOutput[b]:
                    if label == blank:
                        break
                    labelsPerElement[b].append(label)

        # Map labels to chars for all batch elements
        return [''.join([self.charList[c] for c in labels]) for labels in labelsPerElement]

    def trainBatch(self, batch, batchNum):
        """ Run one optimizer step on `batch`, log the loss under step `batchNum` and return the loss value """
        # Stepwise learning-rate decay based on the number of batches trained so far
        if self.batchesTrained < 50:
            rate = 0.01
        elif self.batchesTrained < 2750:
            rate = 0.001
        else:
            rate = 0.0001

        feed = {
            self.inputImgs: batch.imgs,
            self.gtTexts: self.toSpare(batch.gtTexts),
            self.seqLen: [Model.maxTextLen] * Model.batchSize,
            self.learningRate: rate,
        }
        (lossSummary, _, lossVal) = self.sess.run(
            [self.merge, self.optimizer, self.loss], feed)

        # Tensorboard: Add loss summary to writer
        self.writer.add_summary(lossSummary, batchNum)
        self.batchesTrained += 1
        return lossVal

    def inferBatch(self, batch):
        """ Run the decoder on the images of `batch` and return the recognized texts """
        feed = {
            self.inputImgs: batch.imgs,
            self.seqLen: [Model.maxTextLen] * Model.batchSize,
        }
        # To inspect the raw RNN output while debugging, enable self.rnnOutput in
        # __init__ and fetch it together with self.decoder here.
        decoded = self.sess.run(self.decoder, feed)
        return self.decoderOutputToText(decoded)

    def save(self):
        """ Write a new numbered snapshot of the session into the model folder """
        self.snapID = self.snapID + 1
        snapshotPrefix = '/content/drive/My Drive/LineHTR/model/snapshot'
        self.saver.save(self.sess, snapshotPrefix, global_step=self.snapID)


## Main

In [12]:
# # Disable GPU
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

In [13]:
def train(model, loader):
    """ Train the neural network

    Repeats epochs of training on a random subset of the training set followed
    by validation. Whenever the validation character error rate improves, the
    model is saved and the rate is written to FilePaths.fnAccuracy. Training
    stops after `earlyStopping` consecutive epochs without improvement.

    model: Model instance to train
    loader: DataLoader providing the training and validation batches
    """
    epoch = 0  # Number of training epochs since start
    bestCharErrorRate = float('inf')  # Best validation character error rate
    noImprovementSince = 0  # Number of epochs without improvement of character error rate
    earlyStopping = 8  # Stop training after this number of epochs without improvement
    batchNum = 0  # Global batch counter, used as step for the loss summary

    while True:
        epoch += 1
        print('Epoch:', epoch)

        # Train
        print('Train neural network')
        loader.trainSet()
        while loader.hasNext():
            batchNum += 1
            iterInfo = loader.getIteratorInfo()
            batch = loader.getNext()
            loss = model.trainBatch(batch, batchNum)
            print('Batch:', iterInfo[0], '/', iterInfo[1], 'Loss:', loss)

        # Validate
        charErrorRate, textLineAccuracy, wordErrorRate = validate(model, loader)

        # Tensorboard: Track the validation metrics, one scalar per epoch
        for (tag, value) in (('charErrorRate', charErrorRate),
                             ('textLineAccuracy', textLineAccuracy),
                             ('wordErrorRate', wordErrorRate)):
            summary = tf.Summary(value=[tf.Summary.Value(
                tag=tag, simple_value=value)])
            model.writer.add_summary(summary, epoch)

        # If best validation accuracy so far, save model parameters
        if charErrorRate < bestCharErrorRate:
            print('Character error rate improved, save model')
            bestCharErrorRate = charErrorRate
            noImprovementSince = 0
            model.save()
            # Close the file right away so the value is on disk even if training is interrupted
            with open(FilePaths.fnAccuracy, 'w') as f:
                f.write('Validation character error rate of saved model: %f%%' % (charErrorRate*100.0))
        else:
            print('Character error rate not improved')
            noImprovementSince += 1

        # Stop training if no more improvement in the last x epochs
        if noImprovementSince >= earlyStopping:
            print('No more improvement since %d epochs. Training stopped.' %
                  earlyStopping)
            break

In [14]:
def validate(model, loader):
    """ Validate neural network

    Runs the model over the loader's validation set, prints every
    ground truth / recognized pair and returns the averaged metrics.

    Returns
    -------
    (charErrorRate, textLineAccuracy, wordErrorRate), each in the range 0..1:
      charErrorRate: mean over all lines of
          character edit distance / max(len(recognized), len(ground truth))
      textLineAccuracy: share of lines recognized without any error
      wordErrorRate: mean over all lines of
          word edit distance / max(#recognized words, #ground truth words)
    All three are 0 when the validation set yields no complete batch.
    """
    print('Validate neural network')
    loader.validationSet()
    numLineOK = 0
    numLineTotal = 0

    totalCER = []
    totalWER = []
    while loader.hasNext():
        iterInfo = loader.getIteratorInfo()
        print('Batch:', iterInfo[0], '/', iterInfo[1])
        batch = loader.getNext()
        recognized = model.inferBatch(batch)

        print('Ground truth -> Recognized')
        for i in range(len(recognized)):
            gtText = batch.gtTexts[i]
            recText = recognized[i]

            numLineOK += 1 if gtText == recText else 0
            numLineTotal += 1

            # Character level; two empty strings count as a perfect match
            # instead of dividing by zero
            dist = editdistance.eval(recText, gtText)
            maxChars = max(len(recText), len(gtText))
            totalCER.append(float(dist) / maxChars if maxChars else 0.0)

            # Word level; wer() returns an edit count, so normalize it the same
            # way as the character error rate to obtain an actual rate
            gtWords = gtText.split()
            recWords = recText.split()
            maxWords = max(len(gtWords), len(recWords))
            totalWER.append(float(wer(gtWords, recWords)) / maxWords if maxWords else 0.0)

            print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' +
                  gtText + '"', '->', '"' + recText + '"')

    # Print validation result
    if numLineTotal:
        charErrorRate = sum(totalCER) / len(totalCER)
        wordErrorRate = sum(totalWER) / len(totalWER)
        textLineAccuracy = float(numLineOK) / numLineTotal
    else:
        # No complete validation batch was available
        charErrorRate = 0
        wordErrorRate = 0
        textLineAccuracy = 0
    print('Character error rate: %f%%. Text line accuracy: %f%%. Word error rate: %f%%' %
          (charErrorRate*100.0, textLineAccuracy*100.0, wordErrorRate*100.0))
    return charErrorRate, textLineAccuracy, wordErrorRate


In [15]:
def infer(model, fnImg):
    """ Run the model on the single image at `fnImg` and print the recognized text """
    lineImg = preprocessor(fnImg, model.imgSize, binary=True)
    # The model works on batches of Model.batchSize images, so the same image
    # fills every slot; all slots then hold the same result
    filledBatch = Batch(None, [lineImg for _ in range(Model.batchSize)])
    texts = model.inferBatch(filledBatch)
    print('Recognized:', '"' + texts[0] + '"')

In [16]:
def main(operation):
    """ Main function

    operation: 'train' to train on the dataset at FilePaths.fnTrain, or
        'test' to recognize the image at FilePaths.fnInfer with the saved model.

    Raises ValueError for any other operation.
    """
    if operation not in ('train', 'test'):
        raise ValueError("Unknown operation %r, expected 'train' or 'test'" % (operation,))

    decoderType = DecoderType.BestPath

    if operation == 'train':
        # Load training data, create TF model
        loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen)
        model = Model(loader.charList, decoderType)
        train(model, loader)
    else:
        # Inference needs only the stored character list and a saved snapshot,
        # not the training dataset
        with open(FilePaths.fnAccuracy) as f:
            print(f.read())
        with open(FilePaths.fnCharList) as f:
            charList = f.read()
        # mustRestore=True: recognizing with freshly initialized weights would
        # only produce garbage, so a missing snapshot is an error here
        model = Model(charList, decoderType, mustRestore=True)
        infer(model, FilePaths.fnInfer)

In [None]:
# Run inference on FilePaths.fnInfer; pass 'train' instead to train the model.
main('test')