In [351]:
import test_config as cfg
import numpy as np
import itertools, functools

%matplotlib inline
import matplotlib.pyplot as plt
import sys, os, os.path, time, datetime
import pickle, io, json

import skimage, skimage.io, skimage.transform, skimage.filters
import sklearn, sklearn.metrics

import importlib
sys.path.append('../src/')
import modutils
import word_processing as wp
import tqdm
import tensorflow as tf
import editdistance

In [215]:
class BaseTransformer:
    """Identity stage; also the base class of the other pipeline transformers."""

    def __init__(self):
        """Nothing to configure for the identity stage."""

    def transform(self, x):
        """Hand `x` back untouched."""
        return x

    
class SequentialTransformer:
    """Chains transformers: each stage's output becomes the next stage's input."""

    def __init__(self, *args):
        # The stages are kept, in call order, as the tuple captured by *args.
        self.stages_ = args

    def transform(self, x):
        """Feed `x` through every stage in order and return the final result.

        With no stages configured, `x` is returned as-is.
        """
        return functools.reduce(
            lambda value, stage: stage.transform(value), self.stages_, x)
    
class LoadImageTransformer(BaseTransformer):
    """Loads an image file from a fixed directory.

    Parameters
    ----------
    path : str
        Directory that the file names given to `transform` are joined onto.
    """

    def __init__(self, path):
        self.path_ = path

    def transform(self, x):
        """Read the image stored at `<path>/<x>`.

        Parameters
        ----------
        x : str
            File name relative to the directory given at construction.

        Returns
        -------
        Whatever `skimage.io.imread` returns for that file with the
        greyscale flag enabled.

        Raises
        ------
        TypeError
            If `x` is not a string. (`TypeError` is a subclass of `Exception`,
            so code that caught the previous generic `Exception` still works.)
        """
        # isinstance also accepts str subclasses (for example numpy.str_ items
        # taken out of a string array), which an exact type comparison rejects.
        if not isinstance(x, str):
            raise TypeError("LoadImageTransformer: expects filename as argument!")
        # NOTE(review): newer scikit-image releases spell this keyword
        # `as_gray`; confirm against the installed version before upgrading.
        return skimage.io.imread(os.path.join(self.path_, x), as_grey=True)
    
class ConvertFloatTransformer(BaseTransformer):
    """Rescales unsigned-integer images to floats in [min_value, max_value].

    Parameters
    ----------
    min_value : float
        Output value that integer 0 maps to.
    max_value : float
        Output value that the integer type's maximum (255 or 65535) maps to.
    """

    def __init__(self, min_value = 0.0, max_value = 1.0):
        self.min_ = min_value
        self.max_ = max_value

    def transform(self, x):
        """Convert `x` to floating point.

        float32/float64 input is returned unchanged (it is NOT rescaled to
        the configured range); uint8 and uint16 input is mapped linearly onto
        [min_value, max_value].

        Raises
        ------
        TypeError
            For any other dtype. (`TypeError` is a subclass of `Exception`,
            which is what was raised before.)
        """
        # `np.float` was merely an alias of the builtin float (float64) and
        # no longer exists in NumPy >= 1.24, so it is not listed here.
        if x.dtype in (np.float64, np.float32):
            return x
        if x.dtype == np.uint8:
            full_scale = 255.0
        elif x.dtype == np.uint16:
            full_scale = 65535.0
        else:
            raise TypeError("ConvertFloatTransformer: unexpected argument type!")
        return (x / full_scale) * (self.max_ - self.min_) + self.min_
    
class RandomStretchTransformer(BaseTransformer):
    """Augmentation stage that rescales the second image axis by a random factor.

    The factor is drawn uniformly from [min_scale, max_scale) on every call;
    the first axis keeps a scale of 1.0.
    """

    def __init__(self, min_scale = 0.66, max_scale = 1.5, fill_value=1.0):
        self.min_ = min_scale
        self.max_ = max_scale
        # Passed to skimage as `cval` for the 'constant' border mode.
        self.fill_ = fill_value

    def transform(self, x):
        """Return `x` rescaled by (1.0, random factor)."""
        stretch = np.random.uniform(self.min_, self.max_)
        scale = (1.0, stretch)
        return skimage.transform.rescale(x, scale, mode='constant', cval=self.fill_)
    
class TransposeTransformer(BaseTransformer):
    """Pipeline stage that reverses the axis order of its input."""

    def __init__(self):
        """Stateless stage; nothing to set up."""

    def transform(self, x):
        """Return `np.transpose(x)`."""
        flipped = np.transpose(x)
        return flipped
    
class FitSizeTransformer(BaseTransformer):
    """Fits a 2-D image into a fixed (height, width) canvas, keeping aspect ratio.

    The resized image is placed in the top-left corner; the remaining canvas
    keeps `fill_value`.
    """

    def __init__(self, width, height, fill_value=1.0):
        self.w_ = width
        self.h_ = height
        self.fill_ = fill_value
        # Blank canvas; every call to transform() works on a fresh copy.
        self.template_ = np.ones((self.h_, self.w_)) * self.fill_

    def transform(self, x):
        """Return a (height, width) array containing `x` resized to fit."""
        src_h, src_w = x.shape
        # The larger of the two ratios is the divisor that makes both sides fit.
        ratio = max(src_w / self.w_, src_h / self.h_)
        # Clamp to [1, target] so truncation can neither exceed the canvas
        # nor produce an empty slice.
        dst_w = max(1, min(self.w_, int(src_w / ratio)))
        dst_h = max(1, min(self.h_, int(src_h / ratio)))
        canvas = self.template_.copy()
        canvas[:dst_h, :dst_w] = skimage.transform.resize(
            x, (dst_h, dst_w), mode='constant', cval=self.fill_)
        return canvas
    
class StandardizeTransformer(BaseTransformer):
    """Shifts the input to zero mean and, when possible, scales it to unit std."""

    def __init__(self):
        """Stateless stage; nothing to set up."""

    def transform(self, x):
        """Return `(x - mean) / std`, or just `x - mean` for (near-)constant input."""
        centered = x - np.mean(x)
        spread = np.std(x)
        # A spread this small would blow the values up, so only centre.
        return centered if spread <= 1e-9 else centered / spread
    
class TruncateLabelTransform(BaseTransformer):
    """Cuts a label short once its accumulated cost exceeds a budget.

    Every character costs 1; a character equal to the one directly before it
    costs 2 (presumably to leave room for the CTC blank that must separate
    repeated characters — confirm with the model's time-step budget).

    Parameters
    ----------
    max_cost : int
        Largest total cost a returned label may have.
    """

    def __init__(self, max_cost):
        self.max_cost_ = max_cost

    def transform(self, x):
        """Return the longest prefix of `x` whose cost is at most `max_cost`.

        Raises
        ------
        TypeError
            If `x` is not a string. (`TypeError` is a subclass of `Exception`,
            which is what was raised before.)
        """
        if not isinstance(x, str):
            raise TypeError("TruncateLabelTransform: input expected to be of type string!")
        cost = 0
        for i, ch in enumerate(x):
            repeated = i > 0 and ch == x[i - 1]
            cost += 2 if repeated else 1
            # Fixed: this compared against the undefined bare name `max_cost`,
            # raising NameError on the first character.
            if cost > self.max_cost_:
                return x[:i]
        return x

In [411]:
class HTRModel:
    """Handwritten-text recogniser: CNN -> bidirectional LSTM -> CTC.

    The network is built with the TensorFlow 1.x graph API. Constructing an
    instance resets the default graph, so only one model can be alive per
    process at a time.

    Parameters
    ----------
    charlist : sequence of str
        Alphabet; a character's position is its class id. One extra output
        class is added for the CTC blank.
    img_size : (int, int)
        Shape of each input image as fed to the network (the first entry is
        the axis that survives as the time axis, the second must be pooled
        down to 1 by the CNN).
    text_len : int
        Number of time steps declared to CTC for every sample.
    cnn_kernels : list of int
        Square kernel size per convolution layer.
    cnn_features : list of int
        Channel counts; needs one more entry than `cnn_kernels` (input first).
    cnn_pools, cnn_strides : list of (int, int) or None
        Max-pool window and stride per layer. If only one is given, the other
        defaults to the same values.
    rnn_cells : list of int
        LSTM units per stacked layer.
    decoder : {'best-path', 'beam-search'}
        CTC decoder used for validation and inference.
    model_path : str
        Directory that checkpoints are read from and written to.
    restore : bool
        Stored but not consulted: a checkpoint found in `model_path` is always
        restored.

    Raises
    ------
    Exception
        If neither pools nor strides are given, if the CNN argument lengths
        disagree, or if the decoder name is unknown.
    """

    def __init__(self, charlist, img_size=(128, 32), text_len=32,
                 cnn_kernels = [5, 5, 3, 3, 3],
                 cnn_features = [1, 32, 64, 128, 128, 256],
                 cnn_pools = [(2,2), (2,2), (1,2), (1,2), (1,2)],
                 cnn_strides = None,
                 rnn_cells = [256, 256],
                 decoder = 'best-path',
                 model_path = '/htr-model/',
                 restore=False):
        self.chars_ = charlist
        self.restore_ = restore
        self.epochID_ = 0
        self.img_size_ = img_size
        self.text_len_ = text_len

        self.cnn_kernels_ = cnn_kernels
        self.cnn_features_ = cnn_features
        self.cnn_pools_ = cnn_pools
        self.cnn_strides_ = cnn_strides
        if cnn_pools is None and cnn_strides is None:
            raise Exception("Must specify at least one of `pools` and `strides`!")
        # Whichever of the two is missing mirrors the one that was supplied.
        if cnn_pools is None:
            self.cnn_pools_ = cnn_strides
        if cnn_strides is None:
            self.cnn_strides_ = cnn_pools
        self.rnn_cells_ = rnn_cells
        self.model_path_ = model_path

        if decoder not in ('best-path', 'beam-search'):
            raise Exception("HTRModel: unknown decoder name `{}`. Expected `best-path` or `beam-search`".format(decoder))
        self.decoder_ = decoder

        # Start from an empty graph; anything built earlier is discarded.
        tf.reset_default_graph()
        self.tf_is_train_ = tf.placeholder(tf.bool, name='is_train')
        self.tf_in_images_ = tf.placeholder(tf.float32, shape=(None, self.img_size_[0], self.img_size_[1]))

        self.tf_cnn_out_ = HTRModel.setupCNN_(self.tf_in_images_, self.tf_is_train_,
                    self.cnn_kernels_, self.cnn_features_, self.cnn_pools_, self.cnn_strides_)

        self.tf_rnn_out_ = HTRModel.setupRNN_(self.tf_cnn_out_, len(self.chars_), self.rnn_cells_)

        self.setupCTC_()

        self.snap_id_ = 0
        self.trained_samples_ = 0
        self.tf_learning_rate_ = tf.placeholder(tf.float32, shape=[])
        # Batch-norm moving averages live in UPDATE_OPS; tie them to the
        # training step so they are refreshed on every optimiser run.
        self.tf_update_ops_ = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.tf_update_ops_):
            self.tf_optimizer_ = tf.train.RMSPropOptimizer(self.tf_learning_rate_).minimize(self.tf_loss_)

        (self.tf_session_, self.tf_saver_) = HTRModel.setupTF_(model_path)

    @staticmethod
    def setupCNN_(tf_input, tf_is_train, kernels, features, pools, strides):
        """Build the convolutional stack and return its output tensor.

        Each layer is conv (stride 1, SAME) -> batch norm -> ReLU ->
        max-pool (VALID). A trailing channel axis is added to `tf_input`
        before the first convolution.

        Raises
        ------
        Exception
            If the per-layer argument lists have inconsistent lengths.
        """
        chk1 = len(kernels)+1 != len(features)
        chk2 = len(kernels) != len(pools)
        chk3 = len(pools) != len(strides)
        if chk1 or chk2 or chk3:
            print(len(kernels), len(pools), len(strides), len(features))
            raise Exception("HTRModel.setupCNN: lengths of arguments mismatch!")

        tf_cnn_input = tf.expand_dims(input=tf_input, axis=3)

        pool = tf_cnn_input
        for i in range(len(kernels)):
            kernel = tf.Variable(tf.truncated_normal([kernels[i], kernels[i], features[i], features[i + 1]], stddev=0.1))
            conv = tf.nn.conv2d(pool, kernel, padding='SAME',  strides=(1,1,1,1))
            conv_norm = tf.layers.batch_normalization(conv, training=tf_is_train)
            relu = tf.nn.relu(conv_norm)
            pool = tf.nn.max_pool(relu, (1, pools[i][0], pools[i][1], 1), (1, strides[i][0], strides[i][1], 1), 'VALID')

        return pool

    @staticmethod
    def setupRNN_(tf_input, charnum, cell_sizes):
        """Build the stacked bidirectional LSTM and the per-step class projection.

        `tf_input` must have size 1 on axis 2 (it is squeezed away). Returns
        a tensor with `charnum + 1` scores per time step (the extra one is
        the CTC blank).
        """
        rnn_input = tf.squeeze(tf_input, axis=[2])

        cells = [tf.contrib.rnn.LSTMCell(num_units=x, state_is_tuple=True) for x in cell_sizes]
        stacked = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

        # bidirectional RNN, BxTxF -> BxTx2H
        # NOTE(review): the same cell object serves both directions; depending
        # on the TensorFlow release this may tie the forward and backward
        # weights together. Left as-is so existing checkpoints keep loading.
        ((fw, bw), _) = tf.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnn_input, dtype=rnn_input.dtype)
        # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
        concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)

        # Each direction emits the top layer's width, so the concatenation is
        # 2 * cell_sizes[-1] deep. The former sum(cell_sizes) only happened to
        # match for exactly two equally sized layers.
        rnn_depth = 2 * cell_sizes[-1]

        # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
        kernel = tf.Variable(tf.truncated_normal([1, 1, rnn_depth, charnum + 1], stddev=0.1))
        return tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2])

    def setupCTC_(self):
        """Create the CTC loss tensors and the configured decoder on `self`."""
        self.tf_ctc_in_ = tf.transpose(self.tf_rnn_out_, [1, 0, 2]) # BxTxC -> TxBxC
        # ground truth text as sparse tensor
        self.tf_ctc_gt_ = tf.SparseTensor(tf.placeholder(tf.int64, shape=[None, 2]) , tf.placeholder(tf.int32, [None]), tf.placeholder(tf.int64, [2]))

        # calc loss for batch
        self.tf_seq_len_ = tf.placeholder(tf.int32, [None])
        self.tf_loss_ = tf.reduce_mean(tf.nn.ctc_loss(labels=self.tf_ctc_gt_, inputs=self.tf_ctc_in_, sequence_length=self.tf_seq_len_, ctc_merge_repeated=True))

        # calc loss for each element to compute label probability
        self.tf_ctc_in_saved_ = tf.placeholder(tf.float32, shape=[self.text_len_, None, len(self.chars_) + 1])
        self.tf_loss_per_elem_ = tf.nn.ctc_loss(labels=self.tf_ctc_gt_, inputs=self.tf_ctc_in_saved_, sequence_length=self.tf_seq_len_, ctc_merge_repeated=True)

        if self.decoder_ == 'best-path':
            self.tf_decoder_ = tf.nn.ctc_greedy_decoder(inputs=self.tf_ctc_in_, sequence_length=self.tf_seq_len_)
        elif self.decoder_ == 'beam-search':
            self.tf_decoder_ = tf.nn.ctc_beam_search_decoder(inputs=self.tf_ctc_in_, sequence_length=self.tf_seq_len_, beam_width=50, merge_repeated=False)

    @staticmethod
    def setupTF_(model_path, max_to_keep=1):
        """Open a session and either restore the newest checkpoint or initialise.

        Returns
        -------
        (tf.Session, tf.train.Saver)
        """
        print('Python: {}; TF: {}'.format(sys.version, tf.__version__))
        sess=tf.Session()
        saver = tf.train.Saver(max_to_keep=max_to_keep)
        latest_snapshot = tf.train.latest_checkpoint(model_path)
        if latest_snapshot:
            print('Starting hot: {}'.format(latest_snapshot))
            saver.restore(sess, latest_snapshot)
        else:
            print('Starting cold')
            sess.run(tf.global_variables_initializer())

        return (sess,saver)

    def encodeLabels(self, texts):
        """Turn a batch of strings into the (indices, values, shape) sparse triple.

        `indices[k]` is `[sample, position]`, `values[k]` the class id of
        that character, and `shape` is `[batch size, longest label]`.

        Raises
        ------
        ValueError
            If a character is not part of the alphabet, or `texts` is empty.
        """
        indices = []
        values = []
        shape = [len(texts), max(len(x) for x in texts)] # last entry must be max(labelList[i])

        for (i, text) in enumerate(texts):
            encoded_text = [self.chars_.index(c) for c in text]
            for (j, label) in enumerate(encoded_text):
                indices.append([i, j])
                values.append(label)
        return (indices, values, shape)

    def decodeOutput(self, ctc_output, batch_size):
        """Convert a fetched decoder result into one string per batch element.

        `ctc_output[0][0]` is read as the sparse result holding `.indices`
        and `.values`; samples without any decoded label yield ''.
        """
        encoded_labels = [[] for i in range(batch_size)]

        decoded=ctc_output[0][0]
        # go over all indices and save mapping: batch -> values
        for (k, (i, j)) in enumerate(decoded.indices):
            label = decoded.values[k]
            encoded_labels[i].append(label)

        return [''.join([self.chars_[c] for c in x]) for x in encoded_labels]

    def getLearningRate(self):
        """Step schedule: 0.01 below 1e4 trained samples, 0.001 below 1e5, then 0.0001."""
        return 0.01 if self.trained_samples_ < 1e4 else (0.001 if self.trained_samples_ < 1e5 else 0.0001)

    def trainBatch(self, imgs, texts):
        """Run one optimiser step on the batch and return its mean CTC loss.

        Also advances the trained-sample counter that drives the learning
        rate schedule.
        """
        batch_size = len(imgs)
        gt_sparse = self.encodeLabels(texts)
        rate =  self.getLearningRate()
        evalList = [self.tf_optimizer_, self.tf_loss_]
        feedDict = {self.tf_in_images_ : imgs,
                    self.tf_ctc_gt_ : gt_sparse,
                    self.tf_seq_len_ : [self.text_len_] * batch_size,
                    self.tf_learning_rate_ : rate,
                    self.tf_is_train_: True}
        (_, lossVal) = self.tf_session_.run(evalList, feedDict)
        self.trained_samples_ += batch_size
        return lossVal

    def validBatch(self, imgs, texts):
        """Decode the batch without training; return (recognised strings, mean loss)."""
        batch_size = len(imgs)
        gt_sparse = self.encodeLabels(texts)
        evalList = [self.tf_decoder_, self.tf_loss_]
        feedDict = {self.tf_in_images_ : imgs,
                    self.tf_ctc_gt_ : gt_sparse,
                    self.tf_seq_len_ : [self.text_len_] * batch_size,
                    self.tf_is_train_: False}
        (evalRes, lossVal) = self.tf_session_.run(evalList, feedDict)
        return self.decodeOutput(evalRes, batch_size), lossVal

    def inferBatch(self, imgs):
        """Decode the batch and return the recognised strings."""
        batch_size = len(imgs)
        evalList = [self.tf_decoder_]
        feedDict = {self.tf_in_images_ : imgs,
                    self.tf_seq_len_ : [self.text_len_] * batch_size,
                    self.tf_is_train_: False}
        evalRes = self.tf_session_.run(evalList, feedDict)
        return self.decodeOutput(evalRes[0], batch_size)

    def save(self):
        """Write a checkpoint named `snapshot-<n>` into the model directory.

        The counter starts at 0 for every new instance, including ones that
        restored an earlier checkpoint.
        """
        self.snap_id_ += 1
        # os.path.join keeps the file inside model_path even when the path has
        # no trailing separator (plain concatenation would write next to it,
        # where latest_checkpoint(model_path) never looks).
        self.tf_saver_.save(self.tf_session_, os.path.join(self.model_path_, 'snapshot'), global_step=self.snap_id_)

In [430]:
def format_timedelta(seconds):
    """Render a duration in seconds as a short human-readable string.

    At most two adjacent units are shown ('1d 3h', '4m 20s'); once the leading
    unit reaches 10 only that unit is shown ('12m'). Durations under 10 s get
    one decimal, durations under 1 s are shown in whole milliseconds, and
    anything below 1e-10 (including negatives) is '0s'.
    """
    if seconds < 1e-10:
        return '0s'

    whole = int(np.floor(seconds))
    frac = seconds - np.floor(seconds)
    days, rest = divmod(whole, 24 * 3600)
    hours, rest = divmod(rest, 3600)
    minutes, secs = divmod(rest, 60)

    # (leading value, next finer value, template when > 9, template when 1..9)
    ladder = (
        (days, hours, '{}d', '{}d {}h'),
        (hours, minutes, '{}h', '{}h {}m'),
        (minutes, secs, '{}m', '{}m {}s'),
    )
    for value, finer, coarse_fmt, fine_fmt in ladder:
        if value > 9:
            return coarse_fmt.format(value)
        if value > 0:
            return fine_fmt.format(value, finer)

    if secs > 9:
        return '{}s'.format(secs)
    if secs > 0:
        return '{:.1f}s'.format(secs + frac)
    return '{}ms'.format(int(frac*1000))
    
def apply_esmooth(array, factor):
    """Exponentially weighted average of `array`.

    Element k (0-based) gets weight exp(factor * (k + 1)), normalised so the
    weights sum to 1; with a negative `factor` the first element dominates.
    """
    exponents = np.cumsum(np.full(len(array), factor))
    weights = np.exp(exponents)
    normalised = weights / np.sum(weights)
    return np.sum(normalised * array)
    
def train(model, imgs, labels, batch_size, transform_pipeline=BaseTransformer()):
    """Run one epoch of training over a fresh random permutation of the data.

    Only full batches are used: the trailing `len(imgs) % batch_size` samples
    of the permutation are skipped for this epoch.

    Parameters
    ----------
    model
        Object exposing `trainBatch(imgs, labels) -> mean batch loss`.
    imgs, labels
        Parallel indexable collections of images and label strings.
    batch_size : int
    transform_pipeline
        Applied to every image each time it is drawn (so random augmentation
        differs between epochs).

    Returns
    -------
    float
        Mean loss per sample actually trained on (0.0 if there was not even
        one full batch).
    """
    num = len(imgs)
    num_batches = num // batch_size
    ids = np.arange(num)
    np.random.shuffle(ids)
    text_template = 'Train batch {}/{}. Loss: {:.2f}. Time: {}. ETA: {}.'
    hist_times = []
    sum_loss = 0
    n_seen = 0
    for i in range(num_batches):
        t0 = time.perf_counter()
        batch_ids = ids[(i*batch_size):((i+1)*batch_size)]
        batch_imgs = np.array([transform_pipeline.transform(imgs[j]) for j in batch_ids])
        batch_lbls = np.array([labels[j] for j in batch_ids])
        loss = model.trainBatch(batch_imgs, batch_lbls)
        # `loss` is the batch mean; weight it by batch size to get a sum.
        sum_loss += loss * len(batch_lbls)
        n_seen += len(batch_lbls)
        t1 = time.perf_counter()
        hist_times.append(t1-t0)
        # Newest batch first, so recent timings dominate the ETA estimate.
        t_delta = apply_esmooth(np.array(hist_times)[::-1], -0.5)
        t_eta = t_delta * (num_batches - i - 1)
        print(text_template.format(i+1, num_batches, loss, format_timedelta(t1-t0), format_timedelta(t_eta)))
    # Average over the samples that were really processed; dividing by
    # len(labels) under-reported the loss whenever a partial batch was dropped.
    return sum_loss / n_seen if n_seen else 0.0
        
def validate(model, imgs, labels, batch_size,
             transform_pipeline=BaseTransformer()):
    """Evaluate the model on full batches of the given data, in order.

    The trailing `len(imgs) % batch_size` samples are not evaluated.

    Parameters
    ----------
    model
        Object exposing `validBatch(imgs, labels) -> (strings, mean loss)`.
    imgs, labels
        Parallel sliceable collections of images and label strings.
    batch_size : int
    transform_pipeline
        Applied to every image before it is fed to the model.

    Returns
    -------
    (loss, cer, wa)
        Mean loss per evaluated sample, character error rate (total edit
        distance / total label characters) and word accuracy (fraction of
        exact matches). A metric with nothing to average over is NaN.
    """
    n_char_err = 0
    n_char = 0
    n_word_ok = 0
    n_word = 0
    n_seen = 0
    num_batches = len(imgs) // batch_size
    text_template = 'Validation batch {}/{}. Time: {}. ETA: {}.'
    t_start = time.perf_counter()
    sum_loss = 0
    for i in range(num_batches):
        t0 = time.perf_counter()
        lo, hi = i * batch_size, (i + 1) * batch_size
        batch_imgs = np.array([transform_pipeline.transform(x) for x in imgs[lo:hi]])
        batch_lbls = np.array(labels[lo:hi])
        recognized, loss = model.validBatch(batch_imgs, batch_lbls)
        sum_loss += loss * len(batch_lbls)
        n_seen += len(batch_lbls)
        for j in range(len(recognized)):
            n_word_ok += int(batch_lbls[j] == recognized[j])
            n_word += 1
            n_char_err += editdistance.eval(recognized[j], batch_lbls[j])
            n_char += len(batch_lbls[j])
        t1 = time.perf_counter()
        # ETA from the running average batch time. (An exponentially smoothed
        # estimate used to be computed here too but was overwritten before
        # use, so it has been dropped.)
        t_eta = (t1 - t_start) / (i + 1) * (num_batches - i - 1)
        print(text_template.format(i+1, num_batches, format_timedelta(t1-t0), format_timedelta(t_eta)))

    # Guard the divisions: with fewer samples than one batch nothing was
    # evaluated and the plain ratios raised ZeroDivisionError.
    undefined = float('nan')
    cer = n_char_err / n_char if n_char else undefined
    wa = n_word_ok / n_word if n_word else undefined
    # Average over evaluated samples, not len(labels), which also counts the
    # skipped remainder.
    mean_loss = sum_loss / n_seen if n_seen else undefined
    print('Validation results: CER: {:.3f}, WA: {:.3f}.'.format(cer, wa))
    return mean_loss, cer, wa


def run_training(model, train_imgs, train_labels, valid_imgs, valid_labels,
                 batch_size=128, transform_pipeline=BaseTransformer(),
                 valid_transform_pipeline=None, max_epochs=None):
    """Alternate training and validation epochs, checkpointing after each.

    Parameters
    ----------
    model
        Passed to `train` / `validate`; must also expose `save()`.
    train_imgs, train_labels, valid_imgs, valid_labels
        Training and validation data.
    batch_size : int
    transform_pipeline
        Pipeline applied to training images.
    valid_transform_pipeline
        Pipeline applied to validation images. Defaults to
        `transform_pipeline`; pass a pipeline without random stages to get a
        deterministic validation score.
    max_epochs : int or None
        Stop after this many epochs. None (the default) loops until
        interrupted.

    Returns
    -------
    list of (train_loss, valid_loss, cer, wa)
        One entry per completed epoch (only reachable when `max_epochs` is set).
    """
    if valid_transform_pipeline is None:
        valid_transform_pipeline = transform_pipeline
    text_template = 'Epoch {} complete in {}. T-loss is {:.2f}, V-loss is {:.2f}'
    history = []
    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        t0 = time.perf_counter()
        epoch += 1
        print('Epoch: {}'.format(epoch))

        tloss = train(model, train_imgs, train_labels, batch_size, transform_pipeline=transform_pipeline)
        vloss, cer, wa = validate(model, valid_imgs, valid_labels, batch_size, transform_pipeline=valid_transform_pipeline)
        model.save()
        t1 = time.perf_counter()
        print(text_template.format(epoch, format_timedelta(t1-t0), tloss, vloss))
        history.append((tloss, vloss, cer, wa))
    return history
        
def load_sample(fname):
    """Parse an index file and build the pipeline that loads its images.

    Each line is split on single spaces and every field is stripped; the
    second field is the label. Lines with fewer than two fields or an empty
    label are dropped. Images are looked up in the directory named like
    `fname` without its extension.

    Returns
    -------
    (sample, load_pipeline)
        `sample` is a list of field tuples (image name first, label second);
        `load_pipeline` reads an image file name into a float array.
    """
    # Directory with the images: the index file's path minus its extension.
    path = os.path.splitext(fname)[0]
    sample = []
    # The context manager closes the file; it used to be left open.
    with open(fname, 'r') as index_file:
        for line in index_file:
            fields = tuple(part.strip() for part in line.split(' '))
            # Lines without a second field (e.g. a blank last line) used to
            # raise IndexError; they are skipped like empty labels.
            if len(fields) > 1 and len(fields[1]) > 0:
                sample.append(fields)
    load_pipeline = SequentialTransformer(LoadImageTransformer(path), ConvertFloatTransformer())
    return sample, load_pipeline

        
def prepare_sample(sample, pipeline):
    """Load every image of `sample` and split off the labels.

    `sample` holds (name, label) pairs; the image for a pair is obtained by
    passing '<name>.png' through `pipeline`.

    Returns
    -------
    (imgs, lbls)
        Two lists in the order of `sample`.
    """
    imgs = []
    for stem, _label in sample:
        imgs.append(pipeline.transform('{}.png'.format(stem)))

    lbls = []
    for _stem, label in sample:
        lbls.append(label)

    return imgs, lbls

In [431]:
%%time
# Parse the training index file, then read every listed image into memory
# (prepare_sample runs the load pipeline once per entry).
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
train_sample, train_load_pipeline = load_sample('D:/Data/HTR/train.txt')
train_imgs, train_lbls = prepare_sample(train_sample, train_load_pipeline)

Wall time: 31.1 s


In [432]:
%%time
# Same as for the training set: parse the validation index file and read
# every listed image into memory.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
valid_sample, valid_load_pipeline = load_sample('D:/Data/HTR/valid.txt')
valid_imgs, valid_lbls = prepare_sample(valid_sample, valid_load_pipeline)

Wall time: 19.3 s


In [433]:
# Per-sample preprocessing, applied in this order:
#   1. random stretch along the second image axis (augmentation),
#   2. fit into a 128-wide, 32-high canvas,
#   3. transpose, giving the (128, 32) layout the model is built for below,
#   4. standardise to zero mean / unit std.
# The run_training call further down hands this one pipeline to both the
# training and the validation pass, so validation images are randomly
# stretched as well.
transform_pipeline = SequentialTransformer(
    RandomStretchTransformer(),
    FitSizeTransformer(128, 32),
    TransposeTransformer(),
    StandardizeTransformer())

In [434]:
# Alphabet: every distinct character that occurs in a training label, sorted.
# set().union(*...) also copes with an empty label list, for which reduce()
# without an initial value raises TypeError; sorted() already returns a list.
# Characters that appear only in validation labels are not part of it.
charlist = sorted(set().union(*train_lbls))

In [435]:
%%time
# Smaller network than the class defaults (those are quoted in the trailing
# comments). Geometry check for the four pooling layers: the second image
# axis shrinks 32 / (2*2*2*4) = 1, which the RNN stage requires, and the
# first axis shrinks 128 / (2*2*1*1) = 32, matching the default text_len.
model = HTRModel(charlist, img_size=(128, 32),
                cnn_kernels = [5, 5, 3, 3], #[5, 5, 3, 3, 3],#default
                cnn_features = [1, 32, 64, 64, 128], #[1, 32, 64, 128, 128, 256],#default
                cnn_pools = [(2,2), (2,2), (1,2), (1,4)], #[(2,2), (2,2), (1,2), (1,2), (1,2)] #default
                rnn_cells = [128, 128], # class default is [256, 256]
                )

Python: 3.6.1 |Anaconda 4.4.0 (64-bit)| (default, May 11 2017, 13:25:24) [MSC v.1900 64 bit (AMD64)]; TF: 1.2.1
Starting cold
Wall time: 4.56 s


In [None]:
# Start training. As called here there is no stopping condition: the cell
# runs (saving a checkpoint after every epoch) until the kernel is interrupted.
run_training(model, train_imgs, train_lbls, valid_imgs, valid_lbls,
            batch_size=256, transform_pipeline=transform_pipeline)

Epoch: 1
Train batch 1/54. Loss: 130.38. Time: 10s. ETA: 9m 14s.
Train batch 2/54. Loss: 73.45. Time: 10s. ETA: 9m 20s.
Train batch 3/54. Loss: 21.90. Time: 13s. ETA: 10m.
Train batch 4/54. Loss: 22.68. Time: 16s. ETA: 11m.
Train batch 5/54. Loss: 21.78. Time: 17s. ETA: 12m.
Train batch 6/54. Loss: 21.40. Time: 13s. ETA: 11m.
Train batch 7/54. Loss: 22.34. Time: 12s. ETA: 10m.
Train batch 8/54. Loss: 21.49. Time: 14s. ETA: 10m.
Train batch 9/54. Loss: 20.81. Time: 11s. ETA: 9m 52s.
Train batch 10/54. Loss: 21.57. Time: 11s. ETA: 9m 10s.
Train batch 11/54. Loss: 21.32. Time: 12s. ETA: 8m 54s.
Train batch 12/54. Loss: 19.23. Time: 11s. ETA: 8m 21s.
Train batch 13/54. Loss: 20.66. Time: 11s. ETA: 8m 9s.
Train batch 14/54. Loss: 20.31. Time: 11s. ETA: 7m 52s.
Train batch 15/54. Loss: 19.33. Time: 11s. ETA: 7m 42s.
Train batch 16/54. Loss: 18.83. Time: 11s. ETA: 7m 20s.
Train batch 17/54. Loss: 19.49. Time: 11s. ETA: 7m 6s.
Train batch 18/54. Loss: 18.70. Time: 10s. ETA: 6m 42s.
Train batch