In [1]:
import test_config as cfg
import numpy as np
import itertools, functools

%matplotlib inline
import matplotlib.pyplot as plt
import sys, os, os.path, time, datetime
import pickle, io, json

import skimage, skimage.io, skimage.transform, skimage.filters
import sklearn, sklearn.metrics

import importlib
sys.path.append('../src/')
import modutils
import word_processing as wp
import htr_model as hm
import tqdm
import tensorflow as tf
import editdistance

In [2]:
class BaseTransformer:
    def __init__(self):
        pass
    def transform(self, x):
        return x

    
class SequentialTransformer:
    def __init__(self, *args):
        self.stages_ = args
        
    def transform(self, x):
        res = x
        for s in self.stages_:
            res = s.transform(res)
        return res
    
class LoadImageTransformer(BaseTransformer):
    def __init__(self, path):
        self.path_ = path
        
    def transform(self, x):
        if type(x) != str:
            raise Exception("LoadImageTransformer: expects filename as argument!")
        return skimage.io.imread(os.path.join(self.path_, x), as_grey=True)
    
class ConvertFloatTransformer(BaseTransformer):
    def __init__(self, min_value = 0.0, max_value = 1.0):
        self.min_ = min_value
        self.max_ = max_value
        
    def transform(self, x):
        if x.dtype in (np.float, np.float64, np.float32):
            return x
        if x.dtype == np.uint8:
            return (x / 255.0) * (self.max_ - self.min_) + self.min_
        if x.dtype == np.uint16:
            return (x / 65535.0) * (self.max_ - self.min_) + self.min_
        raise Exception("ConvertFloatTransformer: unexpected argument type!")
    
class RandomStretchTransformer(BaseTransformer):
    def __init__(self, min_scale = 0.66, max_scale = 1.5, fill_value=1.0):
        self.max_ = max_scale
        self.min_ = min_scale
        self.fill_ = fill_value
        
    def transform(self, x):
        f = np.random.uniform(self.min_, self.max_)
        return skimage.transform.rescale(x, (1.0, f), mode='constant', cval=self.fill_)
    
class TransposeTransformer(BaseTransformer):
    def __init__(self):
        pass
    def transform(self, x):
        return np.transpose(x)
    
class FitSizeTransformer(BaseTransformer):
    def __init__(self, width, height, fill_value=1.0):
        self.w_ = width
        self.h_ = height
        self.fill_ = fill_value
        self.template_ = np.ones((self.h_, self.w_)) * self.fill_
        
    def transform(self, x):
        (h, w) = x.shape
        f = max(w / self.w_, h / self.h_)
        res = self.template_.copy()
        rw = max(min(self.w_, int(w / f)), 1)
        rh = max(min(self.h_, int(h / f)), 1)
        res[0:rh, 0:rw] = skimage.transform.resize(x, (rh, rw), mode='constant', cval=self.fill_)
        return res
    
class StandardizeTransformer(BaseTransformer):
    def __init__(self):
        pass
    
    def transform(self, x):
        m = np.mean(x)
        s = np.std(x)
        if s <= 1e-9:
            return x - m
        return (x - m) / s
    
class TruncateLabelTransform(BaseTransformer):
    def __init__(self, max_cost):
        self.max_cost_ = max_cost
        
    def transform(self, x):
        if type(x) != str:
            raise Exception("TruncateLabelTransform: input expected to be of type string!")
        cost = 0
        for i in range(len(x)):
            flg = (i > 0) and (x[i] == x[i-1])
            cost += 1 + int(flg)
            if cost > max_cost:
                return x[:i]
        return x

In [3]:
def format_timedelta(seconds):
    if seconds < 1e-10:
        return '0s'
    sf = seconds - np.floor(seconds)
    si = int(np.floor(seconds))
    d, s_h = divmod(si, 3600*24)
    h, s_m = divmod(s_h, 3600)
    m, s = divmod(s_m, 60)
    if d > 9:
        return '{}d'.format(d)
    elif d > 0:
        return '{}d {}h'.format(d, h)
    elif h > 9:
        return '{}h'.format(h)
    elif h > 0:
        return '{}h {}m'.format(h, m)
    elif m > 9:
        return '{}m'.format(m)
    elif m > 0:
        return '{}m {}s'.format(m, s)
    elif s > 9:
        return '{}s'.format(s)
    elif s > 0:
        return '{:.1f}s'.format(s + sf)
    else:
        return '{}ms'.format(int(sf*1000))
    
def apply_esmooth(array, factor):
    tmp = np.exp(np.cumsum([factor]*len(array)))
    tmp = tmp / np.sum(tmp)
    return np.sum(tmp * array)
    
def train(model, imgs, labels, batch_size, transform_pipeline=BaseTransformer()):
    num = len(imgs)
    num_batches = num // batch_size
    ids = np.arange(num)
    np.random.shuffle(ids)
    text_template = 'Train batch {}/{}. Loss: {:.2f}. Time: {}. ETA: {}.'
    hist_times = []
    t_start = time.perf_counter()
    sum_loss = 0
    for i in range(num_batches):
        t0 = time.perf_counter()
        batch_ids = ids[(i*batch_size):((i+1)*batch_size)]
        batch_imgs = np.array([transform_pipeline.transform(imgs[j]) for j in batch_ids])
        batch_lbls = np.array([labels[j] for j in batch_ids])
        loss = model.trainBatch(batch_imgs, batch_lbls)
        sum_loss += loss * len(batch_lbls)
        t1 = time.perf_counter()
        hist_times.append(t1-t0)
        t_delta = apply_esmooth(np.array(hist_times)[::-1], -0.5)
        t_eta = t_delta * (num_batches - i - 1)
        print(text_template.format(i+1, num_batches, loss, format_timedelta(t1-t0), format_timedelta(t_eta)))
    return sum_loss / len(labels)
        
def validate(model, imgs, labels, batch_size,
             transform_pipeline=BaseTransformer()):
    n_char_err = 0
    n_char = 0
    n_word_ok = 0
    n_word = 0
    num_batches = len(imgs) // batch_size
    text_template = 'Validation batch {}/{}. Time: {}. ETA: {}.'
    hist_times = []
    t_start = time.perf_counter()
    sum_loss = 0
    for i in range(num_batches):
        t0 = time.perf_counter()
        batch_imgs = np.array([transform_pipeline.transform(x)
                               for x in imgs[(i*batch_size):((i+1)*batch_size)]])
        batch_lbls = np.array(labels[(i*batch_size):((i+1)*batch_size)])
        recognized, loss = model.validBatch(batch_imgs, batch_lbls)
        sum_loss += loss * len(batch_lbls)
        for j in range(len(recognized)):
            n_word_ok += int(batch_lbls[j] == recognized[j])
            n_word += 1
            dist = editdistance.eval(recognized[j], batch_lbls[j])
            n_char_err += dist
            n_char += len(batch_lbls[j])
        t1 = time.perf_counter()
        hist_times.append(t1-t0)
        t_delta = apply_esmooth(np.array(hist_times)[::-1], -0.5)
        t_eta = t_delta * (num_batches - i - 1)
        t_eta = (t1 - t_start) / (i + 1) * (num_batches - i - 1)
        print(text_template.format(i+1, num_batches, format_timedelta(t1-t0), format_timedelta(t_eta)))

    cer = n_char_err / n_char
    wa = n_word_ok / n_word
    print('Validation results: CER: {:.3f}, WA: {:.3f}.'.format(cer, wa))
    return sum_loss/len(labels), cer, wa


def run_training(model, train_imgs, train_labels, valid_imgs, valid_labels,
                 batch_size=128, train_pipeline=BaseTransformer(), valid_pipeline=None):
    epoch = 0
    if valid_pipeline is None:
        valid_pipeline = train_pipeline
    text_template = 'Epoch {} complete in {}. T-loss is {:.2f}, V-loss is {:.2f}'
    while True:
        t0 = time.perf_counter()
        epoch += 1
        print('Epoch: {}'.format(epoch))

        tloss = train(model, train_imgs, train_labels, batch_size, transform_pipeline=train_pipeline)
        vloss, cer, wa = validate(model, valid_imgs, valid_labels, batch_size, transform_pipeline=valid_pipeline)
        model.save()
        t1 = time.perf_counter()
        print(text_template.format(epoch, format_timedelta(t1-t0), tloss, vloss))
        
def load_sample(fname):
    path = '.'.join(fname.split('.')[:-1])
    sample = [tuple(y.strip() for y in x.split(' ')) for x in open(fname, 'r').readlines()]
    load_pipeline = SequentialTransformer(LoadImageTransformer(path), ConvertFloatTransformer())
    return [x for x in sample if len(x[1])>0], load_pipeline

        
def prepare_sample(sample, pipeline):
    imgs = [pipeline.transform('{}.png'.format(x)) for (x,_) in sample]
    lbls = [x for (_,x) in sample]
    return imgs, lbls

In [4]:
%%time
train_sample, train_load_pipeline = load_sample('D:/Data/HTR/train.txt')
train_imgs, train_lbls = prepare_sample(train_sample, train_load_pipeline)

  warn('`as_grey` has been deprecated in favor of `as_gray`')


Wall time: 35.8 s


In [5]:
%%time
valid_sample, valid_load_pipeline = load_sample('D:/Data/HTR/valid.txt')
valid_imgs, valid_lbls = prepare_sample(valid_sample, valid_load_pipeline)

Wall time: 25 s


In [6]:
train_pipeline = SequentialTransformer(
    RandomStretchTransformer(),
    FitSizeTransformer(128, 32),
    TransposeTransformer(),
    StandardizeTransformer())

valid_pipeline = SequentialTransformer(
    FitSizeTransformer(128, 32),
    TransposeTransformer(),
    StandardizeTransformer())

charlist = sorted(list(functools.reduce(set.union, [set(x) for x in train_lbls])))