In [1]:
# %load text_extractor.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from scipy import misc
import sys
import os
import math
import argparse
import numpy as np
import mxnet as mx
import random
import cv2
import sklearn
from skimage import transform as trans
from easydict import EasyDict as edict
from text_detector import TextDetector
from crnn.symbols.crnn_test import crnn_lstm_test, BucketingSymTest
from rcnn.config import default, generate_config


class TextRecognizer:
  def __init__(self, prefix, epoch, ctx, network='simplenet', use_lstm = True, char_dict=None):
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    all_layers = sym.get_internals()
    sym = all_layers['fc7']
    default_data_width = 280
    #sym = crnn_lstm_test('simplenet', len(char_dict), data_width//8)
    #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
    self.batch_size = 1
    num_hidden = 100
    num_lstm_layer = 2
    default_data_shape = (1, 32, default_data_width)
    data_names = ['data']
    data_shapes=[('data', (self.batch_size,)+default_data_shape)]
    self.use_lstm = use_lstm
    if use_lstm:
      init_c = [('l%d_init_c' % l, (self.batch_size, num_hidden)) for l in range(num_lstm_layer * 2)]
      init_h = [('l%d_init_h' % l, (self.batch_size, num_hidden)) for l in range(num_lstm_layer * 2)]
      self.init_states = init_c + init_h
      data_names += [x[0] for x in self.init_states]
      data_shapes += self.init_states
      self.init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states]

    #model = mx.mod.Module(symbol=sym, context=ctx, data_names=data_names, label_names = None)
    G = BucketingSymTest(network, use_lstm, data_names = data_names, num_lstm_layer=num_lstm_layer, num_hidden=num_hidden, num_classes=len(char_dict))
    sym = G.get
    model = mx.mod.BucketingModule(
          sym_gen             = sym,
          default_bucket_key  = default_data_width//8,
          context             = ctx)
    model.bind(data_shapes=data_shapes)
    model.set_params(arg_params, aux_params)
    self.model = model
    self.char_dict = char_dict
    self.default_data_shape = default_data_shape
    self.bucket_group_size = 80

  def get_text(self, img):
    #TODO img padding if width is smaller than data_shape?
    #if img.shape[0]!=data_shape[1] or img.shape[1]!=data_shape[2]:
    #  img = cv2.resize(img, (data_shape[2], data_shape[1]))
    default_height = self.default_data_shape[1]
    scale = float(default_height)/img.shape[0]
    img = cv2.resize(img, None, None, fx=scale, fy=scale)
    assert img.shape[0]==default_height
    img_width = img.shape[1]
    if img_width%self.bucket_group_size!=0:
      expect_width = (img_width//self.bucket_group_size+1)*self.bucket_group_size
      img = cv2.copyMakeBorder(img, 0, 0, 0, expect_width-img_width, cv2.BORDER_CONSTANT, value=(255,255,255))
      img_width = img.shape[1]
    #if img.shape[1]<self.data_shape[2]:
    print(img.shape)
    if self.default_data_shape[0]==3:
      nimg = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
      nimg = np.transpose(nimg, (2,0,1))
      nimg = np.expand_dims(nimg, axis=0)
    else:
      nimg = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
      nimg = nimg.reshape( (1,1)+nimg.shape )
    data = [mx.nd.array(nimg)]
    seq_length = img.shape[1]//8
    provide_data=[('data', (self.batch_size,)+self.default_data_shape[0:2]+(img.shape[1],))]
    if self.use_lstm:
      data+=self.init_state_arrays
      provide_data += self.init_states
    db = mx.io.DataBatch(data=data, bucket_key = seq_length, provide_data=provide_data)
    self.model.forward(db, is_train=False)
    acts = self.model.get_outputs()[0].asnumpy()
    sel = np.argmax(acts, axis=1)
    ret = []
    for charid in sel:
      if charid==0:
        continue
      char = self.char_dict[charid]
      print('char', char)
      ret.append(char)
    return ret


dict_file = 'charset.txt'
idx = 0
char_dict = {}
for line in open(dict_file, 'r'):
  line = line.strip()
  char_dict[idx] = line
  idx+=1
assert len(char_dict)==5990

ctx_id = 7
ctx = mx.gpu(ctx_id)

#recognizer = TextRecognizer('./model/fp1', 0, ctx, network='resnet', use_lstm=False, char_dict=char_dict)

In [2]:
char_dict = {}
dict_file = 'charset.txt'
idx = 0
for line in open(dict_file, 'r'):
  line = line.strip()
  char_dict[idx] = line
  idx+=1

In [3]:
from text_recognizer import TextRecognizer

In [4]:
from matplotlib import pyplot as plt

In [5]:
import os

In [6]:
detector = TextDetector(default.network, './model/final', 0, ctx_id, mask_nms=True)

In [7]:
print(char_dict)

{0: 'blank', 1: '\xef\xbc\x8c', 2: '\xe7\x9a\x84', 3: '\xe3\x80\x82', 4: '\xe4\xb8\x80', 5: '\xe6\x98\xaf', 6: '0', 7: '\xe4\xb8\x8d', 8: '\xe5\x9c\xa8', 9: '\xe6\x9c\x89', 10: '\xe3\x80\x81', 11: '\xe4\xba\xba', 12: '\xe2\x80\x9c', 13: '\xe2\x80\x9d', 14: '\xe4\xba\x86', 15: '\xe4\xb8\xad', 16: '\xe5\x9b\xbd', 17: '\xe5\xa4\xa7', 18: '\xe4\xb8\xba', 19: '1', 20: ':', 21: '\xe4\xb8\x8a', 22: '2', 23: '\xe8\xbf\x99', 24: '\xe4\xb8\xaa', 25: '\xe4\xbb\xa5', 26: '\xe5\xb9\xb4', 27: '\xe7\x94\x9f', 28: '\xe5\x92\x8c', 29: '\xe6\x88\x91', 30: '\xe6\x97\xb6', 31: '\xe4\xb9\x8b', 32: '\xe4\xb9\x9f', 33: '\xe6\x9d\xa5', 34: '\xe5\x88\xb0', 35: '\xe8\xa6\x81', 36: '\xe4\xbc\x9a', 37: '\xe5\xad\xa6', 38: '\xe5\xaf\xb9', 39: '\xe4\xb8\x9a', 40: '\xe5\x87\xba', 41: '\xe8\xa1\x8c', 42: '\xe5\x85\xac', 43: '\xe8\x83\xbd', 44: '\xe4\xbb\x96', 45: '\xe4\xba\x8e', 46: '5', 47: 'e', 48: '3', 49: '\xe8\x80\x8c', 50: '\xe5\x8f\x91', 51: '\xe5\x9c\xb0', 52: '\xe5\x8f\xaf', 53: '\xe4\xbd\x9c', 54: '\xe5\xb0

In [8]:
from text_recognizer import TextRecognizer
#from crnn.symbols.crnn_test import crnn_lstm_test, BucketingSymTest
import sys
import mxnet as mx
stdi, stdo, stde = sys.stdin, sys.stdout, sys.stderr
reload(sys)
import cv2
sys.setdefaultencoding('utf-8')
sys.stdin, sys.stdout, sys.stderr = stdi, stdo, stde
dict_file = 'charset.txt'
idx = 0
char_dict = {}
for line in open(dict_file, 'r'):
  line = line.strip()
  char_dict[idx] = line
  idx+=1
assert len(char_dict)==5990




In [9]:
recognizer = TextRecognizer('./model/fp6', 0, ctx = mx.gpu(ctx_id), network='resnet', char_dict=char_dict)

In [23]:
img = cv2.imread("9743FFF007BA48F70CDDC5EDE3D7332C.jpg")
boxes = detector.detect(img,scales=[0.8,1.3],thresh =0.85)
# f = file(os.path.join("aligned",path.split(".")[0]+".txt"),"w")
for box in boxes:
    x1,y1,w1,h1 = cv2.boundingRect(box)
    x1,y1,x2,y2 = x1,y1,(x1+w1),(y1+h1)
    sub_image = img[y1:y2,x1:x2]
    txt = ""
    sub_image = cv2.cvtColor(sub_image,cv2.COLOR_BGR2GRAY)

    flag,sub_image = cv2.threshold(sub_image, 0, 255, cv2.THRESH_BINARY|cv2.THRESH_OTSU)
    sub_image = cv2.cvtColor(sub_image,cv2.COLOR_GRAY2BGR)
    for one in recognizer.get_text(sub_image):
        txt+=one
    print(txt)
    
cv2.imwrite("./test.png",img)



232101198704080831
黑龙江省双城市双城镇车
证
站街三委九组
公民身份号码~~〇
姓.名____
郝洪刚
生1987、年、4月_8.日
住址.___
民、族-汉


True

In [12]:
for path in os.listdir("aligned"):
    if path.endswith(".tif"):
        p = os.path.join("aligned",path)
        img = cv2.imread(p)
        print(img.shape)
        boxes = detector.detect(img,scales=[0.8,1.3],thresh =0.85)
        f = file(os.path.join("aligned",path.split(".")[0]+".txt"),"w")
        for box in boxes:
            x1,y1,w1,h1 = cv2.boundingRect(box)
            cv2.polylines(img,[box],True,(255,0,0))
            f.write("[{0},{1},{2},{3}] \n".format(x1,y1,w1,h1))
        f.close()
        cv2.imwrite(os.path.join("temps",path.split(".")[0]+".jpg"),img)
        print(path)

(1084, 1723, 3)
30.tif
(1084, 1723, 3)
38.tif
(1084, 1723, 3)
40.tif
(1084, 1723, 3)
14.tif
(1084, 1723, 3)
27.tif
(1084, 1723, 3)
29.tif
(1084, 1723, 3)
25.tif
(1084, 1723, 3)
34.tif
(1084, 1723, 3)
35.tif
(1084, 1723, 3)
33.tif
(1084, 1723, 3)
24.tif
(1084, 1723, 3)
39.tif
(1084, 1723, 3)
26.tif
