In [None]:
# %load_ext autoreload
%reload_ext autoreload
%autoreload 2
import os
import sys
sys.path.insert(0,'../')
import time
import datetime
import cv2
import numpy as np
import uuid
import json
import matplotlib.pyplot as plt
import functools
import logging
import collections
import tensorflow as tf
import model
from icdar import restore_rectangle
import lanms
from eval import resize_image, sort_poly, detect
# Logger for this notebook's code; the level is lowered to INFO so that the
# logger.info() calls made further down are not filtered out by this logger.
# NOTE(review): no handler is configured here, so whether INFO records are
# actually displayed depends on logging setup done elsewhere - confirm.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


In [None]:
class Config:
    """Static settings for this notebook."""
    # Root directory under which each saved result gets its own sub-directory.
    SAVE_DIR = 'static/results'
class EAST_Pridictor(object):
    """Run EAST text detection with a model restored from a checkpoint directory.

    The graph and the session are built once in ``__init__``; ``predict`` can
    then be called repeatedly, one image at a time.
    """

    # Detections whose first or last edge (after ``sort_poly`` and after the
    # boxes have been divided by the resize ratios) is shorter than this are
    # discarded.
    MIN_BOX_SIDE = 5

    def __init__(self,checkpoint_path):
        """Build the inference graph and restore its weights.

        :param checkpoint_path: directory holding the TensorFlow checkpoint
            state and checkpoint files.
        :raises IOError: if no checkpoint state can be found in that directory.
        """
        (self.sess,
         self.input_images,
         self.f_score,
         self.f_geometry) = self.load_model(checkpoint_path)

    def load_model(self, checkpoint_path):
        """Create the graph, open a session and restore the averaged weights.

        The default graph is reset first, so any graph previously built in
        this process is discarded.

        :param checkpoint_path: directory holding the checkpoint state file.
        :return: tuple ``(sess, input_images, f_score, f_geometry)``.
        :raises IOError: if ``tf.train.get_checkpoint_state`` finds no
            checkpoint in ``checkpoint_path``.
        """
        tf.reset_default_graph()
        input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        f_score, f_geometry = model.model(input_images, is_training=False)

        # Restore the exponential-moving-average shadow values of the variables.
        variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        # Resolve the checkpoint before opening a session so that a missing
        # checkpoint fails with a clear message and leaves no session behind.
        ckpt_state = tf.train.get_checkpoint_state(checkpoint_path)
        if ckpt_state is None:
            raise IOError('No checkpoint found in {}'.format(checkpoint_path))
        # Only the basename recorded in the state file is used, re-anchored to
        # the directory that was passed in.
        model_path = os.path.join(checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))

        # A single session is created (the original opened an extra, unused one
        # that was never closed).
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        logger.info('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        return sess,input_images,f_score, f_geometry

    def predict(self,img):
        """
        Detect text boxes in a single image.

        :param img: image array indexed as (height, width, channel). The
            channel axis is reversed before the image is fed to the network
            (presumably BGR as loaded by OpenCV -> RGB; confirm).
        :raises ValueError: if ``img`` is None (e.g. a failed ``cv2.imread``).
        :return: {
            'text_lines': [
                {
                    'x0': ,
                    'y0': ,
                    'x1': ,
                    ...
                    'y3': ,
                    'score': ,
                }
            ],
            'rtparams': {  # runtime parameters
                'start_time': ,    # ISO-8601 timestamp
                'image_size': ,    # '<width>x<height>' of the input
                'working_size': ,  # '<width>x<height>' after resize_image
            },
            'timing': {  # seconds
                'net': ,
                'restore': ,
                'nms': ,
                'overall': ,
            }
        }
        """
        if img is None:
            raise ValueError('img is None; was the image file read successfully?')

        start_time = time.time()
        rtparams = collections.OrderedDict()
        rtparams['start_time'] = datetime.datetime.now().isoformat()
        rtparams['image_size'] = '{}x{}'.format(img.shape[1], img.shape[0])
        timer = collections.OrderedDict([
            ('net', 0),
            ('restore', 0),
            ('nms', 0)
        ])

        im_resized, (ratio_h, ratio_w) = resize_image(img)
        rtparams['working_size'] = '{}x{}'.format(
            im_resized.shape[1], im_resized.shape[0])

        start = time.time()
        score, geometry = self.sess.run(
            [self.f_score, self.f_geometry],
            feed_dict={self.input_images: [im_resized[:,:,::-1]]})
        timer['net'] = time.time() - start

        # detect() hands back the timer it was given; it is presumably where
        # the 'restore' and 'nms' entries get filled in - confirm in eval.py.
        boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
        logger.info('net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format(
            timer['net']*1000, timer['restore']*1000, timer['nms']*1000))

        scores = None
        if boxes is not None:
            # Column 8 carries the score, columns 0-7 the four (x, y) corners.
            scores = boxes[:,8].reshape(-1)
            boxes = boxes[:, :8].reshape((-1, 4, 2))
            # Undo the resize so coordinates refer to the input image.
            boxes[:, :, 0] /= ratio_w
            boxes[:, :, 1] /= ratio_h

        # 'overall' is taken here, before the result dicts are assembled.
        duration = time.time() - start_time
        timer['overall'] = duration
        logger.info('[timing] {}'.format(duration))

        text_lines = []
        if boxes is not None:
            for box, score in zip(boxes, scores):
                box = sort_poly(box.astype(np.int32))
                # Skip degenerate quadrilaterals with a very short side.
                if (np.linalg.norm(box[0] - box[1]) < self.MIN_BOX_SIDE
                        or np.linalg.norm(box[3]-box[0]) < self.MIN_BOX_SIDE):
                    continue
                tl = collections.OrderedDict(zip(
                    ['x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3'],
                    map(float, box.flatten())))
                tl['score'] = float(score)
                text_lines.append(tl)
        ret = {
            'text_lines': text_lines,
            'rtparams': rtparams,
            'timing': timer,
        }
        return ret

def draw_illu(illu, rst):
    """Outline every detected text quadrilateral on ``illu``.

    :param illu: image array to draw on; it is modified in place.
    :param rst: prediction dict whose ``'text_lines'`` entries each provide
        the corner coordinates ``x0, y0 ... x3, y3``.
    :return: the same ``illu`` object that was passed in.
    """
    corner_keys = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')
    for line in rst['text_lines']:
        # Gather the eight coordinates as int32 and pair them up as (x, y).
        flat = np.array([line[key] for key in corner_keys], dtype='int32')
        quad = flat.reshape(-1, 2)
        cv2.polylines(illu, [quad], isClosed=True, color=(255, 255, 0),thickness=2)
    return illu

def save_result(img, rst, save_dir=None):
    """Write the input image, an annotated copy and the JSON result to disk.

    A fresh sub-directory named after a ``uuid1`` session id is created under
    ``save_dir`` and receives ``input.png``, ``output.png`` and
    ``result.json``.

    :param img: image array that was passed to the predictor.
    :param rst: prediction dict; must contain ``'text_lines'`` and be JSON
        serialisable. It is mutated: ``'session_id'`` is added after the JSON
        file has been written, so the file on disk does not contain that key.
    :param save_dir: root output directory. Defaults to ``Config.SAVE_DIR``,
        so the function no longer depends on a ``config`` global that only
        exists once a later cell has run.
    :return: tuple ``(rst, output_img)`` where ``output_img`` is the annotated
        copy of ``img``.
    :raises OSError: if the session directory cannot be created.
    """
    if save_dir is None:
        save_dir = Config.SAVE_DIR

    session_id = str(uuid.uuid1())
    dirpath = os.path.join(save_dir, session_id)
    os.makedirs(dirpath)

    # save input image
    output_path = os.path.join(dirpath, 'input.png')
    # cv2.imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(output_path, img):
        logger.warning('Failed to write {}'.format(output_path))

    # save illustration (drawn on a copy so the caller's image stays untouched)
    output_path = os.path.join(dirpath, 'output.png')
    output_img = draw_illu(img.copy(), rst)
    if not cv2.imwrite(output_path, output_img):
        logger.warning('Failed to write {}'.format(output_path))

    # save json data
    output_path = os.path.join(dirpath, 'result.json')
    with open(output_path, 'w') as f:
        json.dump(rst, f)

    rst['session_id'] = session_id
    return rst,output_img

In [None]:
# Directory holding the checkpoint to restore (path is relative to the
# notebook's working directory).
checkpoint_path = '../../EAST/model/east_icdar2015_resnet_v1_50_rbox/'
# Constructing the predictor builds the graph and restores the weights once;
# the resulting object is reused by the cells below.
predictor = EAST_Pridictor(checkpoint_path)

In [None]:
import glob
import random
img_dir = '/clever/cabernet/wuwenhui/data/invoice/small_test/*jpg'
# img_dir = '../data/icdar2015/test_images/*jpg'
img_path = random.choice(glob.glob(img_dir))
img = cv2.imread(img_path)
# ret = predictor(img)
config = Config()

%time rst = predictor.predict(img)
ret,ouput_img = save_result(img, rst)
plt.figure(figsize=(20,20))
plt.imshow(ouput_img[:,:,::-1])

