In [1]:
import test_config as cfg
import numpy as np
import itertools, functools

%matplotlib inline
import matplotlib.pyplot as plt
import sys, os, os.path, time, datetime
import pickle, io, json

import skimage, skimage.io, skimage.transform, skimage.filters
import sklearn, sklearn.metrics

import importlib
sys.path.append('../src/')
import modutils
import word_processing as wp
import htr_model as hm
import tqdm
import tensorflow as tf
import editdistance

In [2]:
class BaseTransformer:
    """Identity stage and common base class for the pipeline transformers.

    Subclasses override ``transform`` to map one sample onto another; the
    base implementation hands the sample back unchanged.
    """

    def transform(self, x):
        """Return ``x`` exactly as received."""
        return x

    
class SequentialTransformer:
    """Chain of transformer stages applied one after another."""

    def __init__(self, *args):
        # Stages are kept in call order; each one must expose ``transform(x)``.
        self.stages_ = args

    def transform(self, x):
        """Feed ``x`` through every stage in order and return the final value.

        With no stages configured the input is returned unchanged.
        """
        return functools.reduce(
            lambda value, stage: stage.transform(value), self.stages_, x)
    
class LoadImageTransformer(BaseTransformer):
    """Load an image file from a fixed directory as a grayscale array."""

    def __init__(self, path):
        """
        Args:
            path: directory that the file names given to ``transform`` are
                joined onto.
        """
        self.path_ = path

    def transform(self, x):
        """Read the image called ``x`` from the configured directory.

        Args:
            x: file name (``str``) relative to the directory given at
                construction time.

        Returns:
            Whatever ``skimage.io.imread(..., as_gray=True)`` returns for
            the file.

        Raises:
            Exception: if ``x`` is not a string.
        """
        # isinstance also accepts str subclasses (e.g. numpy.str_), which the
        # former exact-type comparison rejected.
        if not isinstance(x, str):
            raise Exception("LoadImageTransformer: expects filename as argument!")
        # `as_grey` is the deprecated spelling (skimage warns about it and
        # later releases drop it); `as_gray` is the supported keyword.
        return skimage.io.imread(os.path.join(self.path_, x), as_gray=True)
    
class ConvertFloatTransformer(BaseTransformer):
    """Map unsigned-integer image arrays onto a floating-point value range."""

    def __init__(self, min_value = 0.0, max_value = 1.0):
        """
        Args:
            min_value: float value that integer 0 is mapped to.
            max_value: float value that the integer type's maximum is mapped to.
        """
        self.min_ = min_value
        self.max_ = max_value

    def transform(self, x):
        """Convert ``x`` to floats in ``[min_value, max_value]``.

        Floating-point input is returned unchanged (it is NOT rescaled to the
        configured range). ``uint8`` and ``uint16`` input is scaled linearly
        from its full integer range.

        Raises:
            Exception: for any other dtype.
        """
        # np.issubdtype replaces the former `np.float` alias check: that alias
        # was deprecated in NumPy 1.20 and removed in 1.24, where merely
        # evaluating it raises AttributeError.
        if np.issubdtype(x.dtype, np.floating):
            return x
        span = self.max_ - self.min_
        if x.dtype == np.uint8:
            return (x / 255.0) * span + self.min_
        if x.dtype == np.uint16:
            return (x / 65535.0) * span + self.min_
        raise Exception("ConvertFloatTransformer: unexpected argument type!")
    
class RandomStretchTransformer(BaseTransformer):
    """Augmentation stage: rescale the second axis by a random factor."""

    def __init__(self, min_scale = 0.66, max_scale = 1.5, fill_value=1.0):
        """
        Args:
            min_scale: lower bound of the random scale factor.
            max_scale: upper bound of the random scale factor.
            fill_value: constant passed to ``rescale`` as ``cval``.
        """
        self.min_ = min_scale
        self.max_ = max_scale
        self.fill_ = fill_value

    def transform(self, x):
        """Return ``x`` rescaled by ``(1.0, f)`` with ``f`` drawn uniformly.

        The first axis keeps scale 1.0; only the second axis is stretched or
        squeezed. The factor comes from NumPy's global random state.
        """
        factor = np.random.uniform(self.min_, self.max_)
        per_axis_scale = (1.0, factor)
        return skimage.transform.rescale(
            x, per_axis_scale, mode='constant', cval=self.fill_)
    
class TransposeTransformer(BaseTransformer):
    """Stage that reverses the axis order of its input."""

    def __init__(self):
        # Stateless stage: nothing to configure.
        pass

    def transform(self, x):
        """Return ``np.transpose(x)``."""
        flipped = np.transpose(x)
        return flipped
    
class FitSizeTransformer(BaseTransformer):
    """Resize a 2-D image to fit a fixed canvas, preserving aspect ratio.

    The resized image is placed in the top-left corner; the remainder of the
    canvas keeps the fill value.
    """

    def __init__(self, width, height, fill_value=1.0):
        """
        Args:
            width: canvas width (number of columns).
            height: canvas height (number of rows).
            fill_value: value used for the canvas background and handed to
                ``resize`` as ``cval``.
        """
        self.w_ = width
        self.h_ = height
        self.fill_ = fill_value
        # Pre-built background; every call works on a copy of it.
        self.template_ = np.full((self.h_, self.w_), self.fill_, dtype=np.float64)

    def transform(self, x):
        """Return an ``(height, width)`` array holding ``x`` scaled to fit.

        ``x`` must be two-dimensional (its shape is unpacked as rows, cols).
        """
        src_h, src_w = x.shape
        # The larger of the two ratios is the one that limits the fit; a
        # ratio below 1 means the image is enlarged rather than shrunk.
        ratio = max(src_w / self.w_, src_h / self.h_)
        # Clamp to [1, canvas size] so truncation never yields an empty or
        # oversized target.
        new_w = max(min(self.w_, int(src_w / ratio)), 1)
        new_h = max(min(self.h_, int(src_h / ratio)), 1)
        canvas = self.template_.copy()
        canvas[:new_h, :new_w] = skimage.transform.resize(
            x, (new_h, new_w), mode='constant', cval=self.fill_)
        return canvas
    
class StandardizeTransformer(BaseTransformer):
    """Shift the input to zero mean and, when possible, scale to unit std."""

    def __init__(self):
        # Stateless stage: statistics are computed per call.
        pass

    def transform(self, x):
        """Return ``(x - mean) / std``, or just ``x - mean`` for constant input.

        A standard deviation of at most 1e-9 is treated as zero so that flat
        inputs do not trigger a division by (almost) zero.
        """
        mean = np.mean(x)
        spread = np.std(x)
        if spread > 1e-9:
            return (x - mean) / spread
        if spread <= 1e-9:
            return x - mean
        # Reached only when the comparison is undecidable (NaN std).
        return (x - mean) / spread
    
class TruncateLabelTransform(BaseTransformer):
    """Cut a label string so that its accumulated cost stays within a budget.

    Every character costs 1; a character equal to its predecessor costs 2.
    NOTE(review): the extra unit for repeats presumably mirrors the blank a
    CTC decoder needs between identical neighbours - confirm against the model.
    """

    def __init__(self, max_cost):
        """
        Args:
            max_cost: largest total cost a returned label may have.
        """
        self.max_cost_ = max_cost

    def transform(self, x):
        """Return the longest prefix of ``x`` whose cost is <= ``max_cost``.

        Raises:
            Exception: if ``x`` is not a string.
        """
        if not isinstance(x, str):
            raise Exception("TruncateLabelTransform: input expected to be of type string!")
        cost = 0
        for i, ch in enumerate(x):
            repeated = i > 0 and ch == x[i - 1]
            cost += 1 + int(repeated)
            # Compare against the configured budget. The original read a bare
            # `max_cost`, which is undefined in this scope (NameError, or a
            # silent pick-up of an unrelated notebook global).
            if cost > self.max_cost_:
                return x[:i]
        return x

In [3]:
def format_timedelta(seconds):
    """Render a duration in seconds as a short human-readable string.

    At most two adjacent units are shown (e.g. ``'1d 1h'``, ``'2m 5s'``); once
    the leading unit reaches 10 the smaller unit is dropped (``'10m'``).
    Durations of 1-9 s keep one decimal, sub-second durations are shown in
    whole milliseconds, and anything below 1e-10 (including negative values)
    is reported as ``'0s'``.
    """
    if seconds < 1e-10:
        return '0s'
    whole = int(np.floor(seconds))
    frac = seconds - np.floor(seconds)
    days, rest = divmod(whole, 3600*24)
    hours, rest = divmod(rest, 3600)
    minutes, secs = divmod(rest, 60)

    # (leading value, template when >9, template when 1..9, trailing value),
    # checked from the coarsest unit down.
    tiers = (
        (days, '{}d', '{}d {}h', hours),
        (hours, '{}h', '{}h {}m', minutes),
        (minutes, '{}m', '{}m {}s', secs),
    )
    for lead, coarse_fmt, fine_fmt, trail in tiers:
        if lead > 9:
            return coarse_fmt.format(lead)
        if lead > 0:
            return fine_fmt.format(lead, trail)

    if secs > 9:
        return '{}s'.format(secs)
    if secs > 0:
        return '{:.1f}s'.format(secs + frac)
    return '{}ms'.format(int(frac*1000))
    
def apply_esmooth(array, factor):
    """Exponentially weighted average of ``array``.

    Element ``k`` (0-based) gets weight ``exp(factor * (k + 1))``, normalised
    so the weights sum to 1. With a negative ``factor`` the first elements
    dominate.
    """
    exponents = np.cumsum([factor for _ in range(len(array))])
    weights = np.exp(exponents)
    weights = weights / np.sum(weights)
    return np.sum(weights * array)
    
def train(model, imgs, labels, batch_size, transform_pipeline=BaseTransformer()):
    """Run one training epoch over ``imgs``/``labels`` in shuffled mini-batches.

    Args:
        model: object exposing ``trainBatch(batch_imgs, batch_lbls)`` that
            returns the batch loss.
        imgs: indexable collection of samples; each is passed through
            ``transform_pipeline`` before batching.
        labels: indexable collection of labels aligned with ``imgs``.
        batch_size: number of samples per batch (the last batch may be smaller).
        transform_pipeline: object with ``transform(x)`` applied per sample.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by
        ``len(labels)``.
    """
    total = len(imgs)
    num_batches = total // batch_size
    # One extra, shorter batch picks up the remainder.
    num_batches += int(len(imgs) > num_batches * batch_size)

    order = np.arange(total)
    np.random.shuffle(order)

    text_template = 'Train batch {}/{}. Loss: {:.2f}. Time: {}. ETA: {}.'
    batch_durations = []
    weighted_loss = 0
    for batch_no in range(num_batches):
        started = time.perf_counter()
        picked = order[(batch_no*batch_size):((batch_no+1)*batch_size)]
        batch_imgs = np.array([transform_pipeline.transform(imgs[k]) for k in picked])
        batch_lbls = np.array([labels[k] for k in picked])
        loss = model.trainBatch(batch_imgs, batch_lbls)
        weighted_loss += loss * len(batch_lbls)
        finished = time.perf_counter()

        batch_durations.append(finished - started)
        # Newest duration first, so the smoothing favours recent batches.
        smoothed = apply_esmooth(np.array(batch_durations)[::-1], -0.5)
        remaining = smoothed * (num_batches - batch_no - 1)
        print(text_template.format(batch_no+1, num_batches, loss,
                                   format_timedelta(finished - started),
                                   format_timedelta(remaining)))
    return weighted_loss / len(labels)
        
def validate(model, imgs, labels, batch_size,
             transform_pipeline=BaseTransformer()):
    """Evaluate ``model`` on a labelled set and report CER and word accuracy.

    Args:
        model: object exposing ``validBatch(batch_imgs, batch_lbls)`` that
            returns ``(recognized, loss)``.
        imgs: sliceable collection of samples; each is passed through
            ``transform_pipeline`` before batching.
        labels: sliceable collection of ground-truth strings aligned with
            ``imgs``.
        batch_size: number of samples per batch (the last batch may be smaller).
        transform_pipeline: object with ``transform(x)`` applied per sample.

    Returns:
        ``(mean_loss, cer, wa)``: the size-weighted mean batch loss, the
        character error rate (total edit distance / total label characters)
        and the word accuracy (exact matches / words). Any metric whose
        denominator is zero (e.g. an empty set) is returned as NaN instead
        of raising ZeroDivisionError.
    """
    n_char_err = 0
    n_char = 0
    n_word_ok = 0
    n_word = 0
    num_batches = len(imgs) // batch_size
    if len(imgs) > num_batches * batch_size:
        num_batches += 1
    text_template = 'Validation batch {}/{}. Time: {}. ETA: {}.'
    t_start = time.perf_counter()
    sum_loss = 0
    for i in range(num_batches):
        t0 = time.perf_counter()
        lo, hi = i * batch_size, (i + 1) * batch_size
        batch_imgs = np.array([transform_pipeline.transform(x) for x in imgs[lo:hi]])
        batch_lbls = np.array(labels[lo:hi])
        recognized, loss = model.validBatch(batch_imgs, batch_lbls)
        sum_loss += loss * len(batch_lbls)
        for j, guess in enumerate(recognized):
            truth = batch_lbls[j]
            n_word_ok += int(truth == guess)
            n_word += 1
            n_char_err += editdistance.eval(guess, truth)
            n_char += len(truth)
        t1 = time.perf_counter()
        # ETA = mean wall time per finished batch x batches still to run.
        # (An exponentially smoothed estimate used to be computed here too,
        # but it was overwritten by this value before ever being printed.)
        t_eta = (t1 - t_start) / (i + 1) * (num_batches - i - 1)
        print(text_template.format(i+1, num_batches, format_timedelta(t1-t0), format_timedelta(t_eta)))

    undefined = float('nan')
    cer = n_char_err / n_char if n_char else undefined
    wa = n_word_ok / n_word if n_word else undefined
    mean_loss = sum_loss / len(labels) if len(labels) else undefined
    print('Validation results: CER: {:.3f}, WA: {:.3f}.'.format(cer, wa))
    return mean_loss, cer, wa


def run_training(model, train_imgs, train_labels, valid_imgs, valid_labels,
                 batch_size=128, train_pipeline=BaseTransformer(), valid_pipeline=None,
                 max_epochs=None):
    """Alternate training and validation epochs, saving the model after each.

    Args:
        model: object accepted by ``train``/``validate`` that also exposes
            ``save()``.
        train_imgs, train_labels: training samples and their labels.
        valid_imgs, valid_labels: validation samples and their labels.
        batch_size: batch size used for both training and validation.
        train_pipeline: per-sample transform applied during training.
        valid_pipeline: per-sample transform applied during validation;
            defaults to ``train_pipeline`` when ``None``.
        max_epochs: stop after this many epochs. ``None`` (the default) keeps
            the original behaviour of looping until interrupted, e.g. with
            KeyboardInterrupt.
    """
    epoch = 0
    if valid_pipeline is None:
        valid_pipeline = train_pipeline
    text_template = 'Epoch {} complete in {}. T-loss is {:.2f}, V-loss is {:.2f}'
    while max_epochs is None or epoch < max_epochs:
        t0 = time.perf_counter()
        epoch += 1
        print('Epoch: {}'.format(epoch))

        tloss = train(model, train_imgs, train_labels, batch_size, transform_pipeline=train_pipeline)
        # CER and word accuracy are printed by validate(); only the loss is
        # needed for the epoch summary line.
        vloss, _cer, _wa = validate(model, valid_imgs, valid_labels, batch_size, transform_pipeline=valid_pipeline)
        # Checkpoint every epoch so an interrupted run loses at most one epoch.
        model.save()
        t1 = time.perf_counter()
        print(text_template.format(epoch, format_timedelta(t1-t0), tloss, vloss))
        
def load_sample(fname):
    """Read a sample listing and build the matching image-loading pipeline.

    Each line of ``fname`` is split on single spaces into stripped fields;
    the first two fields are used by ``prepare_sample`` as image stem and
    label. Lines with fewer than two fields (e.g. blank lines) or with an
    empty second field are skipped.

    Args:
        fname: path to the listing file. Images are expected in the directory
            whose name is ``fname`` without its extension.

    Returns:
        ``(sample, load_pipeline)``: the list of field tuples and a
        ``SequentialTransformer`` that loads an image from that directory and
        converts it to floats.
    """
    # 'dir/train.txt' -> 'dir/train'; splitext only looks at the final path
    # component, so dots in directory names cannot confuse it.
    path = os.path.splitext(fname)[0]
    # The context manager closes the handle; the original left it open.
    with open(fname, 'r') as listing:
        rows = [tuple(field.strip() for field in line.split(' ')) for line in listing]
    # The len(row) > 1 guard avoids the IndexError a blank or one-token line
    # used to trigger on row[1].
    sample = [row for row in rows if len(row) > 1 and len(row[1]) > 0]
    load_pipeline = SequentialTransformer(LoadImageTransformer(path), ConvertFloatTransformer())
    return sample, load_pipeline

        
def prepare_sample(sample, pipeline):
    """Load every sample's image and split the sample into images and labels.

    Args:
        sample: sequence of ``(stem, label)`` pairs.
        pipeline: object whose ``transform`` receives ``'<stem>.png'``.

    Returns:
        ``(imgs, lbls)``: two lists in the same order as ``sample``.
    """
    imgs = []
    for stem, _label in sample:
        imgs.append(pipeline.transform('{}.png'.format(stem)))
    lbls = []
    for _stem, label in sample:
        lbls.append(label)
    return imgs, lbls

In [4]:
%%time
# Read the training listing, then eagerly load every listed '<stem>.png' into
# memory through the load pipeline (image read + float conversion).
train_sample, train_load_pipeline = load_sample('D:/Data/HTR/train.txt')
train_imgs, train_lbls = prepare_sample(train_sample, train_load_pipeline)

  warn('`as_grey` has been deprecated in favor of `as_gray`')


Wall time: 22.7 s


In [5]:
%%time
# Same loading steps as for the training set, applied to the validation listing.
valid_sample, valid_load_pipeline = load_sample('D:/Data/HTR/valid.txt')
valid_imgs, valid_lbls = prepare_sample(valid_sample, valid_load_pipeline)

Wall time: 3.43 s


In [6]:
# Training pipeline: random stretch along the second axis (augmentation), fit
# onto a 128-wide x 32-high canvas, transpose to (128, 32), then standardize.
train_pipeline = SequentialTransformer(
    RandomStretchTransformer(),
    FitSizeTransformer(128, 32),
    TransposeTransformer(),
    StandardizeTransformer())

# Validation pipeline: identical except that the random stretch is left out,
# so evaluation is deterministic.
valid_pipeline = SequentialTransformer(
    FitSizeTransformer(128, 32),
    TransposeTransformer(),
    StandardizeTransformer())

# Sorted list of every distinct character occurring in the training labels.
# A set comprehension also copes with an empty label list, where
# functools.reduce without an initial value raises TypeError.
charlist = sorted({ch for lbl in train_lbls for ch in lbl})

In [7]:
# Re-import htr_model so that edits to its source file are picked up without
# restarting the kernel.
importlib.reload(hm)

<module 'htr_model' from '../src\\htr_model.py'>

In [8]:
%%time
# Build the recognizer over the training character set. img_size=(128, 32)
# matches what the pipelines emit: a 32-row x 128-column canvas, transposed.
# The bracketed values in the trailing comments appear to be the model's
# default settings that this run deviates from - confirm against htr_model.py.
model = hm.HTRModel(charlist, img_size=(128, 32),
                cnn_kernels = [5, 5, 3, 3], #[5, 5, 3, 3, 3],#default
                cnn_features = [1, 32, 64, 64, 128], #[1, 32, 64, 128, 128, 256],#default
                cnn_pools = [(2,2), (2,2), (1,2), (1,4)], #[(2,2), (2,2), (1,2), (1,2), (1,2)] #default
                rnn_cells = [128, 128], #default 
                model_path='D:/models/htr-static-128/', decoder='best-path'
                )

W0722 21:22:06.779354 13824 deprecation_wrapper.py:119] From ../src\htr_model.py:76: The name tf.truncated_normal is deprecated. Please use tf.random.truncated_normal instead.

W0722 21:22:06.799336 13824 deprecation_wrapper.py:119] From ../src\htr_model.py:82: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.

W0722 21:22:07.770168 13824 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0722 21:22:07.770168 13824 deprecation.py:323] From ../src\htr_model.py:107: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is equivalent as 

Python: 3.7.3 (default, Mar 27 2019, 17:13:21) [MSC v.1915 64 bit (AMD64)]; TF: 1.14.0
Starting cold


W0722 21:22:13.514518 13824 deprecation_wrapper.py:119] From ../src\htr_model.py:145: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.



Wall time: 7.01 s


In [9]:
# Train with augmentation and validate without it. With no epoch limit given,
# run_training loops until the cell is interrupted; the model is saved after
# every completed epoch.
run_training(model, train_imgs, train_lbls, valid_imgs, valid_lbls,
            batch_size=256, train_pipeline=train_pipeline, valid_pipeline=valid_pipeline)

Epoch: 1


  warn('The default multichannel argument (None) is deprecated.  Please '
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


Train batch 1/54. Loss: 126.99. Time: 5.0s. ETA: 4m 24s.
Train batch 2/54. Loss: 69.25. Time: 2.5s. ETA: 2m 58s.
Train batch 3/54. Loss: 25.63. Time: 2.4s. ETA: 2m 29s.
Train batch 4/54. Loss: 23.14. Time: 2.4s. ETA: 2m 14s.
Train batch 5/54. Loss: 23.98. Time: 2.4s. ETA: 2m 5s.
Train batch 6/54. Loss: 23.53. Time: 2.6s. ETA: 2m 3s.
Train batch 7/54. Loss: 22.58. Time: 2.6s. ETA: 2m 0s.
Train batch 8/54. Loss: 22.66. Time: 3.4s. ETA: 2m 12s.
Train batch 9/54. Loss: 21.84. Time: 4.7s. ETA: 2m 42s.
Train batch 10/54. Loss: 22.44. Time: 2.8s. ETA: 2m 23s.
Train batch 11/54. Loss: 21.57. Time: 2.4s. ETA: 2m 6s.
Train batch 12/54. Loss: 21.28. Time: 2.5s. ETA: 1m 57s.
Train batch 13/54. Loss: 19.55. Time: 2.5s. ETA: 1m 49s.
Train batch 14/54. Loss: 20.29. Time: 2.4s. ETA: 1m 42s.
Train batch 15/54. Loss: 20.93. Time: 2.3s. ETA: 1m 35s.
Train batch 16/54. Loss: 20.19. Time: 2.7s. ETA: 1m 37s.
Train batch 17/54. Loss: 19.11. Time: 3.3s. ETA: 1m 45s.
Train batch 18/54. Loss: 20.11. Time: 2.4s.

W0722 21:28:01.506611 13824 deprecation.py:323] From C:\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py:960: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.


Epoch 2 complete in 2m 42s. T-loss is 17.24, V-loss is 17.44
Epoch: 3
Train batch 1/54. Loss: 17.65. Time: 2.4s. ETA: 2m 4s.
Train batch 2/54. Loss: 18.20. Time: 2.4s. ETA: 2m 3s.
Train batch 3/54. Loss: 15.80. Time: 3.0s. ETA: 2m 17s.
Train batch 4/54. Loss: 17.75. Time: 2.5s. ETA: 2m 10s.
Train batch 5/54. Loss: 17.71. Time: 2.7s. ETA: 2m 9s.
Train batch 6/54. Loss: 18.59. Time: 2.4s. ETA: 2m 1s.
Train batch 7/54. Loss: 18.26. Time: 2.4s. ETA: 1m 56s.
Train batch 8/54. Loss: 16.50. Time: 2.4s. ETA: 1m 51s.
Train batch 9/54. Loss: 16.65. Time: 2.4s. ETA: 1m 47s.
Train batch 10/54. Loss: 16.28. Time: 2.3s. ETA: 1m 44s.
Train batch 11/54. Loss: 18.16. Time: 2.4s. ETA: 1m 41s.
Train batch 12/54. Loss: 16.99. Time: 2.3s. ETA: 1m 38s.
Train batch 13/54. Loss: 17.32. Time: 2.5s. ETA: 1m 38s.
Train batch 14/54. Loss: 18.59. Time: 2.4s. ETA: 1m 35s.
Train batch 15/54. Loss: 16.23. Time: 2.3s. ETA: 1m 32s.
Train batch 16/54. Loss: 16.66. Time: 2.3s. ETA: 1m 29s.
Train batch 17/54. Loss: 16.82.

KeyboardInterrupt: 

In [10]:
# Score the model on the TRAINING set, using the non-augmenting validation
# pipeline so the numbers are comparable with the validation scores.
tloss, tcer, twa = validate(model, train_imgs, train_lbls, 256, valid_pipeline)

Validation batch 1/54. Time: 1.3s. ETA: 1m 6s.
Validation batch 2/54. Time: 1.2s. ETA: 1m 4s.
Validation batch 3/54. Time: 1.3s. ETA: 1m 4s.
Validation batch 4/54. Time: 1.3s. ETA: 1m 3s.
Validation batch 5/54. Time: 1.2s. ETA: 1m 1s.
Validation batch 6/54. Time: 1.3s. ETA: 1m 0s.
Validation batch 7/54. Time: 1.3s. ETA: 59s.
Validation batch 8/54. Time: 1.3s. ETA: 58s.
Validation batch 9/54. Time: 1.3s. ETA: 57s.
Validation batch 10/54. Time: 1.2s. ETA: 55s.
Validation batch 11/54. Time: 1.2s. ETA: 54s.
Validation batch 12/54. Time: 1.5s. ETA: 53s.
Validation batch 13/54. Time: 1.3s. ETA: 52s.
Validation batch 14/54. Time: 1.2s. ETA: 51s.
Validation batch 15/54. Time: 1.3s. ETA: 49s.
Validation batch 16/54. Time: 1.3s. ETA: 48s.
Validation batch 17/54. Time: 1.3s. ETA: 47s.
Validation batch 18/54. Time: 1.2s. ETA: 45s.
Validation batch 19/54. Time: 1.3s. ETA: 44s.
Validation batch 20/54. Time: 1.3s. ETA: 43s.
Validation batch 21/54. Time: 1.3s. ETA: 42s.
Validation batch 22/54. Time: 1

In [11]:
# Score the model on the held-out validation set.
vloss, vcer, vwa = validate(model, valid_imgs, valid_lbls, 256, valid_pipeline)

Validation batch 1/36. Time: 1.3s. ETA: 45s.
Validation batch 2/36. Time: 1.2s. ETA: 43s.
Validation batch 3/36. Time: 1.3s. ETA: 42s.
Validation batch 4/36. Time: 1.3s. ETA: 41s.
Validation batch 5/36. Time: 1.2s. ETA: 39s.
Validation batch 6/36. Time: 1.3s. ETA: 38s.
Validation batch 7/36. Time: 1.3s. ETA: 37s.
Validation batch 8/36. Time: 1.2s. ETA: 35s.
Validation batch 9/36. Time: 1.3s. ETA: 34s.
Validation batch 10/36. Time: 1.4s. ETA: 33s.
Validation batch 11/36. Time: 1.4s. ETA: 32s.
Validation batch 12/36. Time: 1.4s. ETA: 31s.
Validation batch 13/36. Time: 1.6s. ETA: 30s.
Validation batch 14/36. Time: 1.4s. ETA: 29s.
Validation batch 15/36. Time: 1.3s. ETA: 27s.
Validation batch 16/36. Time: 1.4s. ETA: 26s.
Validation batch 17/36. Time: 1.3s. ETA: 25s.
Validation batch 18/36. Time: 1.2s. ETA: 23s.
Validation batch 19/36. Time: 1.3s. ETA: 22s.
Validation batch 20/36. Time: 1.4s. ETA: 21s.
Validation batch 21/36. Time: 1.2s. ETA: 19s.
Validation batch 22/36. Time: 1.2s. ETA: 18

In [13]:
# Display (train CER, valid CER), (train word accuracy, valid word accuracy).
# Requires both evaluation cells above to have been run first.
(tcer, vcer), (twa, vwa)

((0.10549381491036582, 0.43740463561723464),
 (0.6814236111111112, 0.23274739583333334))