Permalink
Browse files

Add predict sample code for lstm+ctc ocr. Also update its README.md (#…

  • Loading branch information...
1 parent 0e9f56a commit d6328a5087e4fb8bb11d39aaadf8f01745afe11c @BobLiu20 BobLiu20 committed with piiswrong Jan 11, 2017
@@ -34,9 +34,11 @@ I implement two examples, one is just a toy example which can be used to prove c
cd examples/warpctc
python lstm_ocr.py
```
-Note:
-* Please modify ```contexts = [mx.context.gpu(1)]``` in this file according to your hardware. If you only have one GPU pelase change 1 to 0(which GPU is selected.)
-* Please copy your font file to current folder. And instend of './data/Xerox.ttf' by your font file name. Maybe you can get a font from /usr/share/fonts/truetype/ in ubuntu.
+
+Notes:
+* Please modify ```contexts = [mx.context.gpu(0)]``` in this file according to your hardware.
+* Please check the font path ```'./font/Ubuntu-M.ttf'``` in the code. Copy your font to font/yourfont.ttf and update the path accordingly. You can get a free font from [here](http://font.ubuntu.com/).
+* A checkpoint is saved automatically at the end of each epoch. You can then use this checkpoint to run a prediction.
The OCR example is constructed as follows:
@@ -92,3 +94,15 @@ Following code show detail construction of the net:
If your label length is smaller than or equal to b, you should provide labels with length b. For those samples whose label length is smaller than b, you should append 0s to the label data to pad it to length b.
Here, 0 is reserved for blank label.
+
+## Do a predict
+
+Please run:
+
+```
+python ocr_predict.py
+```
+
+Notes:
+* Change the file names in the code to match the names of your own params and json files.
+* You have to run ```make``` in the amalgamation folder first (this creates libmxnet_predict.so in the lib folder).
@@ -48,7 +48,8 @@ def get_label(buf):
class OCRIter(mx.io.DataIter):
def __init__(self, count, batch_size, num_label, init_states):
super(OCRIter, self).__init__()
- self.captcha = ImageCaptcha(fonts=['./data/Xerox.ttf'])
+ # you can get this font from http://font.ubuntu.com/
+ self.captcha = ImageCaptcha(fonts=['./font/Ubuntu-M.ttf'])
self.batch_size = batch_size
self.count = count
self.num_label = num_label
@@ -140,7 +141,7 @@ def Accuracy(label, pred):
momentum = 0.9
num_label = 4
- contexts = [mx.context.gpu(1)]
+ contexts = [mx.context.gpu(0)]
def sym_gen(seq_len):
return lstm_unroll(num_lstm_layer, seq_len,
@@ -172,6 +173,7 @@ def sym_gen(seq_len):
model.fit(X=data_train, eval_data=data_val,
eval_metric = mx.metric.np(Accuracy),
- batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),)
+ batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),
+ epoch_end_callback = mx.callback.do_checkpoint(prefix, 1))
@handong1587
handong1587 Jan 12, 2017

Hi, I might have missed something, I am wondering where is variable "prefix" defined? I cannot find any clue about the definition :-(
epoch_end_callback = mx.callback.do_checkpoint(prefix, 1))

model.save("ocr")
@@ -0,0 +1,83 @@
+#!/usr/bin/env python2.7
+# coding=utf-8
+from __future__ import print_function
+import sys, os
# Directory containing this script. The MXNet python packages are located
# relative to it (not relative to the current working directory), so the
# imports below work no matter where the script is launched from.
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../amalgamation/python/"))
sys.path.append(os.path.join(curr_path, "../../python/"))
+
+from mxnet_predict import Predictor
+import mxnet as mx
+
+import numpy as np
+import cv2
+import os
+
class lstm_ocr_model(object):
    """Run a trained LSTM+CTC OCR network through the amalgamation Predictor.

    The symbol JSON and the parameter file are loaded once at construction
    time; ``forward_ocr`` then turns an image into the recognised string.
    """

    # Characters the network can emit. Class index 0 is reserved for the CTC
    # blank, so class index k (k >= 1) maps to CONST_CHAR[k - 1].
    CONST_CHAR = '0123456789'

    def __init__(self, path_of_json, path_of_params):
        """Remember the model file paths and build the predictor.

        Args:
            path_of_json: path of the network symbol (``*-symbol.json``).
            path_of_params: path of the trained weights (``*.params``).
        """
        super(lstm_ocr_model, self).__init__()
        self.path_of_json = path_of_json
        self.path_of_params = path_of_params
        self.predictor = None
        self.__init_ocr()

    def __init_ocr(self):
        """Create ``self.predictor`` with the input shapes of the network."""
        # NOTE(review): these values presumably have to match the ones used
        # for training in lstm_ocr.py -- confirm before changing them.
        num_label = 4  # length of the (zero-padded) label vector
        batch_size = 1

        num_hidden = 100
        num_lstm_layer = 2
        state_shape = (batch_size, num_hidden)
        init_c = [('l%d_init_c' % i, state_shape) for i in range(num_lstm_layer)]
        init_h = [('l%d_init_h' % i, state_shape) for i in range(num_lstm_layer)]
        init_states = init_c + init_h

        all_shapes = ([('data', (batch_size, 80 * 30))]
                      + init_states
                      + [('label', (batch_size, num_label))])
        all_shapes_dict = dict(all_shapes)

        # Use context managers so both file handles are closed right away,
        # and read the parameter blob in binary mode because it is raw bytes.
        with open(self.path_of_json) as symbol_file:
            symbol_json = symbol_file.read()
        with open(self.path_of_params, 'rb') as params_file:
            params_raw = params_file.read()
        self.predictor = Predictor(symbol_json, params_raw, all_shapes_dict)

    def forward_ocr(self, img):
        """Recognise the text in ``img`` and return it as a string.

        Args:
            img: 2-D (single channel) image array, e.g. the result of
                ``cv2.imread(path, 0)``. It is resized to 80x30 pixels.

        Returns:
            The decoded string (characters taken from ``CONST_CHAR``).
        """
        img = cv2.resize(img, (80, 30))
        # Put the width axis first, then flatten to the 80*30 'data' input.
        img = img.transpose(1, 0)
        img = img.reshape((80 * 30))
        # Scale pixel values from [0, 255] to [0, 1].
        img = np.multiply(img, 1 / 255.0)
        self.predictor.forward(data=img)
        prob = self.predictor.get_output(0)
        # Greedy decoding: keep the most probable class of every output row.
        label_list = [np.argmax(p) for p in prob]
        return self.__get_string(label_list)

    def __get_string(self, label_list):
        """Collapse a per-timestep label sequence into the final string.

        Applies the usual greedy CTC rule: consecutive repeats of the same
        label are merged into one and blank labels (index 0) are dropped.
        Labels outside the range of ``CONST_CHAR`` are ignored.
        """
        kept = []
        previous = 0  # behave as if a blank preceded the sequence
        for current in label_list:
            if current != 0 and current != previous:
                kept.append(current)
            previous = current

        # Map class indices to characters (index 0 is the blank, hence -1).
        charset = lstm_ocr_model.CONST_CHAR
        return ''.join(charset[idx - 1] for idx in kept
                       if 0 < idx <= len(charset))
+
if __name__ == '__main__':
    # Adjust the two file names to the checkpoint produced by your training.
    _lstm_ocr_model = lstm_ocr_model('ocr-symbol.json', 'ocr-0010.params')
    # Flag 0 makes OpenCV load the image as single-channel grayscale.
    img = cv2.imread('sample.jpg', 0)
    # cv2.imread does not raise on failure, it returns None; fail early with
    # a clear message instead of an obscure error inside cv2.resize.
    if img is None:
        raise IOError("Cannot read image file 'sample.jpg'")
    _str = _lstm_ocr_model.forward_ocr(img)
    print('Result: ', _str)
+
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit d6328a5

Please sign in to comment.