In [1]:
import test_config as cfg
import numpy as np
import itertools, functools

%matplotlib inline
import matplotlib.pyplot as plt
import sys, os, os.path, time, datetime
import pickle, io, json

import skimage, skimage.io, skimage.transform, skimage.filters
import sklearn, sklearn.metrics

import importlib
sys.path.append('../src/')
import modutils
import word_processing as wp
import tqdm
import tensorflow as tf
import editdistance

ImportError: cannot import name '_validate_lengths'

In [2]:
class BaseTransformer:
    """Identity stage: the common base for every transformer in this notebook."""

    def __init__(self):
        """The base stage keeps no state."""
        return None

    def transform(self, x):
        """Hand the input back untouched."""
        unchanged = x
        return unchanged

    
class SequentialTransformer:
    """Chain of stages; each stage's output is fed to the next one's `transform`."""

    def __init__(self, *args):
        """Remember the stages in the order they were given."""
        self.stages_ = args

    def transform(self, x):
        """Fold `x` through every stage, left to right, and return the final value."""
        return functools.reduce(
            lambda value, stage: stage.transform(value),
            self.stages_,
            x,
        )
    
class LoadImageTransformer(BaseTransformer):
    """Stage that turns a file name into a grayscale image read from disk."""

    def __init__(self, path):
        """
        Args:
            path: directory that file names passed to `transform` are relative to.
        """
        self.path_ = path

    def transform(self, x):
        """Read the image `x` (a file name relative to the configured directory).

        Args:
            x: file name as a string (str subclasses are accepted).

        Returns:
            Whatever `skimage.io.imread(..., as_gray=True)` returns for the file.

        Raises:
            TypeError: if `x` is not a string. TypeError derives from Exception,
                so callers catching Exception keep working.
        """
        if not isinstance(x, str):
            raise TypeError("LoadImageTransformer: expects filename as argument!")
        # `as_grey` was deprecated in favour of `as_gray` and later removed
        # from scikit-image; `as_gray` is the supported spelling.
        return skimage.io.imread(os.path.join(self.path_, x), as_gray=True)
    
class ConvertFloatTransformer(BaseTransformer):
    """Stage that maps unsigned-integer images onto a floating-point range."""

    def __init__(self, min_value = 0.0, max_value = 1.0):
        """
        Args:
            min_value: value that integer 0 is mapped to.
            max_value: value that the integer type's maximum is mapped to.
        """
        self.min_ = min_value
        self.max_ = max_value

    def transform(self, x):
        """Convert `x` to floating point.

        Floating-point arrays are returned unchanged (they are NOT rescaled
        to [min_value, max_value]). uint8 and uint16 arrays are scaled linearly
        so that 0 maps to min_value and 255 / 65535 maps to max_value.

        Raises:
            TypeError: for any other dtype. TypeError derives from Exception,
                so callers catching Exception keep working.
        """
        # `np.float` was removed from NumPy (1.24); np.issubdtype covers
        # every floating dtype without relying on the removed alias.
        if np.issubdtype(x.dtype, np.floating):
            return x
        span = self.max_ - self.min_
        if x.dtype == np.uint8:
            return (x / 255.0) * span + self.min_
        if x.dtype == np.uint16:
            return (x / 65535.0) * span + self.min_
        raise TypeError("ConvertFloatTransformer: unexpected argument type!")
    
class RandomStretchTransformer(BaseTransformer):
    """Stage that rescales the second axis of an image by a random factor."""

    def __init__(self, min_scale = 0.66, max_scale = 1.5, fill_value=1.0):
        """
        Args:
            min_scale: lower bound of the uniformly drawn scale factor.
            max_scale: upper bound of the uniformly drawn scale factor.
            fill_value: passed to `rescale` as `cval` (constant-mode padding value).
        """
        self.max_ = max_scale
        self.min_ = min_scale
        self.fill_ = fill_value

    def transform(self, x):
        """Draw a factor from U(min_scale, max_scale) and rescale `x` by (1.0, factor)."""
        stretch = np.random.uniform(self.min_, self.max_)
        scale = (1.0, stretch)  # first axis untouched, second axis stretched
        return skimage.transform.rescale(x, scale, mode='constant', cval=self.fill_)
    
class TransposeTransformer(BaseTransformer):
    """Stage that applies `np.transpose` to its input."""

    def __init__(self):
        """Stateless stage; nothing to configure."""
        return None

    def transform(self, x):
        """Return `np.transpose(x)`."""
        flipped = np.transpose(x)
        return flipped
    
class FitSizeTransformer(BaseTransformer):
    """Stage that fits a 2-D image onto a fixed-size canvas.

    The image is scaled by a single factor (so its aspect ratio is kept) until
    it fits inside `height` x `width`, written into the top-left corner of the
    canvas, and the remaining area keeps `fill_value`.
    """

    def __init__(self, width, height, fill_value=1.0):
        """
        Args:
            width: canvas width (number of columns of the result).
            height: canvas height (number of rows of the result).
            fill_value: value of canvas cells not covered by the image; also
                passed to `resize` as `cval`.
        """
        self.w_ = width
        self.h_ = height
        self.fill_ = fill_value
        # Pre-built blank canvas; copied on every transform call.
        self.template_ = np.ones((self.h_, self.w_)) * self.fill_

    def transform(self, x):
        """Return a (height, width) array containing `x` scaled to fit.

        Args:
            x: 2-D array (its shape is unpacked as `(h, w)`).

        Returns:
            A new array of shape (height, width). A zero-sized input yields the
            blank canvas.
        """
        (h, w) = x.shape
        if h == 0 or w == 0:
            # Nothing to draw. Without this guard an all-zero shape makes the
            # scale factor 0 and the divisions below raise ZeroDivisionError.
            return self.template_.copy()
        # Largest of the two ratios, so both dimensions fit after dividing by it.
        f = max(w / self.w_, h / self.h_)
        res = self.template_.copy()
        # Clamp to the canvas and to at least one pixel.
        rw = max(min(self.w_, int(w / f)), 1)
        rh = max(min(self.h_, int(h / f)), 1)
        res[0:rh, 0:rw] = skimage.transform.resize(x, (rh, rw), mode='constant', cval=self.fill_)
        return res
    
class StandardizeTransformer(BaseTransformer):
    """Stage that shifts its input to zero mean and, when possible, unit std."""

    def __init__(self):
        """Stateless stage; nothing to configure."""
        return None

    def transform(self, x):
        """Subtract the mean; divide by the std unless the std is <= 1e-9."""
        centered = x - np.mean(x)
        spread = np.std(x)
        # A (near-)constant input is only centered, never divided.
        return centered if spread <= 1e-9 else centered / spread
    
class TruncateLabelTransform(BaseTransformer):
    """Stage that cuts a label string once its accumulated cost exceeds a budget.

    Every character costs 1; a character equal to the one immediately before
    it costs 2.
    """

    def __init__(self, max_cost):
        """
        Args:
            max_cost: largest total cost a returned label may have.
        """
        self.max_cost_ = max_cost

    def transform(self, x):
        """Return the longest prefix of `x` whose cost does not exceed `max_cost`.

        Args:
            x: label string (str subclasses are accepted).

        Returns:
            `x` itself when it fits the budget, otherwise a prefix of it.

        Raises:
            TypeError: if `x` is not a string. TypeError derives from Exception,
                so callers catching Exception keep working.
        """
        if not isinstance(x, str):
            raise TypeError("TruncateLabelTransform: input expected to be of type string!")
        cost = 0
        for i, ch in enumerate(x):
            repeated = i > 0 and ch == x[i - 1]
            cost += 2 if repeated else 1
            # The original compared against an undefined bare name `max_cost`,
            # which raised NameError; the budget lives on the instance.
            if cost > self.max_cost_:
                return x[:i]
        return x
    
def load_images(path, words, transform):
    """Read and preprocess the image of every word sample.

    Args:
        path: directory containing the image files.
        words: sequence of `(word, fname)` pairs; only `fname` is used here.
        transform: transform description handed to `wp.perform_transform`.

    Returns:
        List with one preprocessed image per entry of `words`, in the same order.
    """
    res = []
    for (_word, fname) in tqdm.tqdm(words):
        # `as_grey` was deprecated in favour of `as_gray` and later removed
        # from scikit-image; `as_gray` is the supported spelling.
        src_image = skimage.io.imread(os.path.join(path, fname), as_gray=True)
        res.append(wp.perform_transform(src_image, transform))
    return res

In [3]:
%%time
# Load the dataset description (JSON) and extract the word samples from it.
# NOTE(review): absolute local path; consider a configurable data directory.
fname = 'D:/Data/bujo_sample_v2/dataset.json'
# Directory next to the JSON file, named after the file up to its first dot.
extraction_path = os.path.join(os.path.dirname(fname),
                               os.path.basename(fname).split('.')[0])
with open(fname, 'r', encoding='utf-8') as f:
    src = json.load(f)
    
# NOTE(review): the meaning of the (1,) argument is defined in word_processing; confirm there.
words = wp.extract_words_from_dataset(src, (1,))

# Declarative preprocessing steps; load_images forwards this list to wp.perform_transform.
transform_pipeline = [
    {'type':'cutoff', 'cutoff':0.7},
    {'type':'trimx'}, {'type':'trimy'}, {'type':'resize', 'y':32}, {'type':'invert'}
]

Wall time: 119 ms


In [4]:
# Read and preprocess every word image.
# NOTE(review): this literal duplicates the directory already computed as `extraction_path`.
raw_imgs = load_images('D:/Data/bujo_sample_v2/dataset/', words, transform_pipeline)

100%|█████████████████████████████████████████████████████████████████████████████| 1579/1579 [00:13<00:00, 117.19it/s]


In [5]:
# Sanity check: shape of the first preprocessed image.
raw_imgs[0].shape

(32, 82)

In [6]:
# Frozen TensorFlow graph (serialized GraphDef) of the recognition model.
# NOTE(review): absolute local path; consider a configurable model directory.
model_fname = 'D:/htr-model/frozen_model.pb'

In [7]:
print('Loading model...')
tf.reset_default_graph()
tf_graph = tf.Graph()
# NOTE(review): this session is bound to the empty graph created above and is
# never closed in this notebook; confirm whether anything still needs it.
tf_session = tf.InteractiveSession(graph = tf_graph)

# Deserialize the frozen GraphDef from disk.
with tf.gfile.GFile(model_fname, 'rb') as f:
    tf_graph_def = tf.GraphDef()
    tf_graph_def.ParseFromString(f.read())

print('Check out the input placeholders:')
# `n.op in ('Placeholder')` was a substring test against a plain string
# (the parentheses do not make a tuple); compare for equality instead.
nodes = [n.name + ' => ' + n.op for n in tf_graph_def.node if n.op == 'Placeholder']
for node in nodes:
    print(node)

# Import the model into a fresh graph; `tf_graph` is rebound to that graph.
with tf.Graph().as_default() as tf_graph:
    tf.import_graph_def(tf_graph_def, name="")

print('Model loading complete!')

Loading model...
Check out the input placeholders:
is_train => Placeholder
Placeholder => Placeholder
Placeholder_4 => Placeholder
Model loading complete!


In [11]:
# Import the frozen model into the current default graph; the later
# tf.get_default_graph().get_tensor_by_name(...) lookups rely on its nodes being there.
# NOTE(review): this cell was executed out of order (In [11]); on Restart & Run All
# it must run before the tensor-lookup cell.
tf.import_graph_def(tf_graph_def, name="")

In [8]:
# Load the JSON metadata that accompanies the frozen model.
# NOTE(review): no explicit encoding is given although 'chars' holds Cyrillic text;
# confirm how the file was written before relying on the platform default.
with open('D:/htr-model/frozen_model.info', 'r') as fp:
    info = json.load(fp)

In [9]:
# Display the loaded metadata (alphabet, input/output node names, text length).
info

{'chars': '!"(),-.0123456789:?FLOPTabcdefhiklmnorstuwyАБВГДЕЖЗИКЛМНОПРСТУФЦЧЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё',
 'input': {'images': 'Placeholder',
  'is_train': 'is_train',
  'textlen': 'Placeholder_4'},
 'output': 'CTCGreedyDecoder',
 'textlen': 32}

In [12]:
# Resolve the model's input and output tensors by the names stored in `info`.
_graph = tf.get_default_graph()
_input_names = info['input']
_decoder_name = info['output']

tf_is_train = _graph.get_tensor_by_name(_input_names['is_train']+':0')
tf_in_images = _graph.get_tensor_by_name(_input_names['images']+':0')
tf_in_seqlens = _graph.get_tensor_by_name(_input_names['textlen']+':0')
# The decoder op's outputs 0 and 1 are fetched separately.
tf_out_ctc0 = _graph.get_tensor_by_name(_decoder_name+':0')
tf_out_ctc1 = _graph.get_tensor_by_name(_decoder_name+':1')

In [17]:
def runBatch(session, imgs, chars, seq_len=32):
    """Run the recognition graph on one batch of images and decode the text.

    Uses the module-level tensors `tf_in_images`, `tf_in_seqlens`,
    `tf_is_train`, `tf_out_ctc0` and `tf_out_ctc1`.

    Args:
        session: object whose `run(fetches, feed_dict)` evaluates the graph.
        imgs: sequence of preprocessed images forming the batch.
        chars: alphabet; decoder output values are used as indices into it.
        seq_len: sequence length fed for every batch item (defaults to 32,
            the value that used to be hard-coded).

    Returns:
        List of decoded strings, one per image, in batch order. An empty
        batch yields an empty list without running the graph.
    """
    batch_size = len(imgs)
    if batch_size == 0:
        # Nothing to recognise; skip the graph run for an empty batch.
        return []

    feedDict = {tf_in_images : imgs,
                tf_in_seqlens : [seq_len] * batch_size,
                tf_is_train: False}
    indices, values = session.run([tf_out_ctc0, tf_out_ctc1], feedDict)

    # The first column of each index row selects the batch item that the
    # matching entry of `values` belongs to; the second column is not needed.
    encoded_labels = [[] for _ in range(batch_size)]
    for (row, _col), value in zip(indices, values):
        encoded_labels[row].append(value)

    return [''.join(chars[c] for c in label) for label in encoded_labels]

In [14]:
# Per-image preparation for the model: fit onto a 128x32 (width x height) canvas,
# transpose, then standardize.
transform_pipeline_img = SequentialTransformer(
    FitSizeTransformer(128, 32),
    TransposeTransformer(),
    StandardizeTransformer())

In [15]:
# NOTE(review): only the first 10 images are prepared here, so everything
# downstream is evaluated on a 10-sample subset.
tmp_imgs = [transform_pipeline_img.transform(x) for x in raw_imgs[:10]]

In [18]:
# Recognise the prepared images in fixed-size batches.
BATCH_SIZE = 256
tmp = []
with tf.Session() as tfs:
    # Stepping by BATCH_SIZE never yields an empty trailing batch; the previous
    # `len // 256 + 1` count did whenever the number of images was a multiple
    # of 256 (including zero).
    for i, s0 in enumerate(range(0, len(tmp_imgs), BATCH_SIZE)):
        print(i)
        tmp += runBatch(tfs, tmp_imgs[s0:s0 + BATCH_SIZE], info['chars'])

0


In [19]:
# Predicted strings; an alias of (not a copy of) the list built in the batching cell.
pred = tmp

In [20]:
# Ground-truth strings: the first element of every entry in `words`.
act = [x[0] for x in words]

In [21]:
import editdistance

In [22]:
# Edit distance between each prediction and its ground truth.
# NOTE(review): zip stops at the shorter sequence, so only as many pairs as
# there are predictions are scored; the remaining entries of `act` are ignored.
dists = np.array([editdistance.distance(x, y) for (x, y) in zip(pred, act)])

In [23]:
# Summary: number scored, exact matches, distance exactly 1, distance at most 2.
len(dists), sum(dists==0), sum(dists==1), sum(dists<=2)

(10, 4, 4, 9)

In [24]:
# Convert the default graph to TFLite and write it to disk.
# The converter lives under `tf.lite` in later TF 1.x releases and under
# `tf.contrib.lite` in earlier ones; use whichever this installation provides.
tf_lite = getattr(tf, 'lite', None) or tf.contrib.lite
with tf.Session() as tfs:
    tflite_model = tf_lite.toco_convert(tfs.graph_def,
                    [tf_is_train, tf_in_images, tf_in_seqlens], [tf_out_ctc0, tf_out_ctc1])
# The context manager closes the file handle; the original left it open.
with open('D:/htr-model/frozen_model.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_model)

AttributeError: module 'tensorflow.contrib' has no attribute 'lite'

In [26]:
import sys