Skip to content

Commit

Permalink
[Feature] Add User friendly API (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
samayala22 committed Jul 15, 2021
1 parent 782e966 commit 87bd295
Showing 1 changed file with 360 additions and 0 deletions.
360 changes: 360 additions & 0 deletions mmocr/utils/ocr.py
@@ -0,0 +1,360 @@
import os
from argparse import ArgumentParser, Namespace

import mmcv
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.utils.box_util import stitch_boxes_into_lines

# Bundled text-detection models: maps a user-facing model name to its config
# file (resolved relative to configs/textdet/ in the working directory) and
# its checkpoint file (resolved relative to the
# https://download.openmmlab.com/mmocr/textdet/ URL root — see MMOCR.__init__).
textdet_models = {
    'DB_r18': {
        'config': 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
        'ckpt':
        'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
    },
    'DB_r50': {
        'config':
        'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
        'ckpt':
        'dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth'
    },
    'DRRG': {
        'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
        'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth'
    },
    'FCE_ICDAR15': {
        'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
        'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth'
    },
    'FCE_CTW_DCNv2': {
        'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
        'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth'
    },
    'MaskRCNN_CTW': {
        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
        'ckpt': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
    },
    'MaskRCNN_ICDAR15': {
        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
        'ckpt':
        'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
    },
    'MaskRCNN_ICDAR17': {
        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
        'ckpt':
        'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
    },
    'PANet_CTW': {
        'config': 'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
        'ckpt':
        'panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
    },
    'PANet_ICDAR15': {
        'config': 'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
        'ckpt':
        'panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
    },
    'PS_CTW': {
        'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
        'ckpt': 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
    },
    'PS_ICDAR15': {
        'config': 'psenet/psenet_r50_fpnf_600e_icdar2015.py',
        'ckpt': 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
    },
    'TextSnake': {
        'config': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
        'ckpt': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
    }
}

# Bundled text-recognition models: maps a user-facing model name to its config
# file (resolved relative to configs/textrecog/ in the working directory) and
# its checkpoint file (resolved relative to the
# https://download.openmmlab.com/mmocr/textrecog/ URL root — see
# MMOCR.__init__).
textrecog_models = {
    'CRNN': {
        'config': 'crnn/crnn_academic_dataset.py',
        'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
    },
    'SAR': {
        'config': 'sar/sar_r31_parallel_decoder_academic.py',
        'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
    },
    'NRTR_1/16-1/8': {
        'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
        'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
    },
    'NRTR_1/8-1/4': {
        'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
        'ckpt': 'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
    },
    'RobustScanner': {
        'config': 'robust_scanner/robustscanner_r31_academic.py',
        'ckpt': 'robust_scanner/robustscanner_r31_academic-5f05874f.pth'
    },
    'SEG': {
        'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
        'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
    },
    'CRNN_TPS': {
        'config': 'tps/crnn_tps_academic_dataset.py',
        'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
    }
}


def det_recog_pp(args, det_recog_result):
    """Post-process an end-to-end (detection + recognition) result.

    Depending on the flags carried by ``args``, this may export json files,
    stitch box-level results into lines, render/save/show a visualization,
    and strip the result down to plain text strings.

    Args:
        args (Namespace): Options (``export_json``, ``ocr_in_lines``,
            ``out_img``, ``imshow``, ``details``, ``print_result``, ``img``).
        det_recog_result (dict): End-to-end result with a ``'result'`` list
            of per-box dicts; mutated in place when ``ocr_in_lines`` is set.

    Returns:
        The full result dict, or a list of text strings when
        ``args.details`` is False.
    """
    final_result = det_recog_result
    if args.export_json:
        # Raw (box-level) result dump; the json file shares the output
        # image's name.
        mmcv.dump(
            final_result,
            args.out_img + '.json',
            ensure_ascii=False,
            indent=4)
    if args.ocr_in_lines:
        # Merge neighbouring boxes into text lines and dump separately.
        merged = stitch_boxes_into_lines(final_result['result'], 10, 0.5)
        final_result['result'] = merged
        mmcv.dump(
            final_result,
            args.out_img + '.line.json',
            ensure_ascii=False,
            indent=4)
    if args.out_img or args.imshow:
        vis_img = det_recog_show_result(args.img, final_result)
        if args.out_img:
            mmcv.imwrite(vis_img, args.out_img)
        if args.imshow:
            mmcv.imshow(vis_img, 'predicted results')
    if not args.details:
        # Collapse to the recognized strings only.
        final_result = [box['text'] for box in final_result['result']]
    if args.print_result:
        print(final_result)
    return final_result


def single_pp(args, result, model):
    """Post-process a detection-only or recognition-only result.

    Args:
        args (Namespace): Options (``export_json``, ``out_img``, ``imshow``,
            ``print_result``, ``img``).
        result: Raw inference result from the model.
        model: The model that produced ``result``; used only for its
            ``show_result`` visualization method.

    Returns:
        The raw ``result``, unchanged.
    """
    if args.export_json:
        mmcv.dump(result, args.out_img + '.json', ensure_ascii=False, indent=4)
    wants_visualization = bool(args.out_img) or args.imshow
    if wants_visualization:
        model.show_result(
            args.img, result, out_file=args.out_img, show=args.imshow)
    if args.print_result:
        print(result)
    return result


def det_and_recog_inference(args, det_model, recog_model):
    """Detect text boxes in ``args.img``, crop each box, and recognize it.

    Args:
        args (Namespace): Options; uses ``img``, ``batch_mode`` and
            ``batch_size``.
        det_model: Text detection model for ``model_inference``.
        recog_model: Text recognition model for ``model_inference``.

    Returns:
        dict: ``{'filename': ..., 'result': [...]}`` where each entry holds
        ``box``, ``box_score``, ``text`` and ``text_score``. ``filename`` is
        only present when ``args.img`` is a path string.
    """
    image = args.img
    if isinstance(image, str):
        end2end_res = {'filename': image}
    else:
        # Image was passed as an in-memory array; there is no filename.
        end2end_res = {}
    end2end_res['result'] = []
    image = mmcv.imread(image)
    det_result = model_inference(det_model, image)
    # Each boundary is a flat list of polygon coordinates followed by a
    # confidence score as its last element.
    bboxes = det_result['boundary_result']

    box_imgs = []
    for bbox in bboxes:
        box_res = {}
        box_res['box'] = [round(x) for x in bbox[:-1]]
        box_res['box_score'] = float(bbox[-1])
        box = bbox[:8]
        if len(bbox) > 9:
            # Polygon with more than four corners: crop its axis-aligned
            # bounding rectangle instead (even indices are x, odd are y;
            # the trailing score is excluded by the -1 stop).
            min_x = min(bbox[0:-1:2])
            min_y = min(bbox[1:-1:2])
            max_x = max(bbox[0:-1:2])
            max_y = max(bbox[1:-1:2])
            box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
        box_img = crop_img(image, box)
        if args.batch_mode:
            # Defer recognition: collect crops and run them in chunks below.
            box_imgs.append(box_img)
        else:
            recog_result = model_inference(recog_model, box_img)
            text = recog_result['text']
            text_score = recog_result['score']
            if isinstance(text_score, list):
                # NOTE(review): averages per-character scores but divides by
                # len(text), not len(text_score) — assumes the two lengths
                # match; confirm against the recognizer's output contract.
                text_score = sum(text_score) / max(1, len(text))
            box_res['text'] = text
            box_res['text_score'] = text_score

        # In batch mode this entry has no 'text'/'text_score' yet; the loop
        # below fills them in by index.
        end2end_res['result'].append(box_res)

    if args.batch_mode:
        batch_size = args.batch_size
        # '+ 1' covers a final partial chunk; when the count divides evenly
        # the extra chunk is empty and skipped by the continue.
        for chunk_idx in range(len(box_imgs) // batch_size + 1):
            start_idx = chunk_idx * batch_size
            end_idx = (chunk_idx + 1) * batch_size
            chunk_box_imgs = box_imgs[start_idx:end_idx]
            if len(chunk_box_imgs) == 0:
                continue
            recog_results = model_inference(
                recog_model, chunk_box_imgs, batch_mode=True)
            for i, recog_result in enumerate(recog_results):
                text = recog_result['text']
                text_score = recog_result['score']
                if isinstance(text_score, list):
                    # Same len(text) averaging caveat as in the
                    # non-batch branch above.
                    text_score = sum(text_score) / max(1, len(text))
                end2end_res['result'][start_idx + i]['text'] = text
                end2end_res['result'][start_idx + i]['text_score'] = text_score

    return end2end_res


def main():
    """Command-line entry point: build an MMOCR pipeline from the parsed
    CLI options and run it on the given image."""
    cli_options = vars(parse_args())
    # Both the constructor and readtext accept **kwargs, so each simply
    # picks out the options it understands.
    ocr_pipeline = MMOCR(**cli_options)
    ocr_pipeline.readtext(**cli_options)


def parse_args():
    """Build and run the argument parser for the OCR demo script.

    Returns:
        Namespace: The parsed command-line options.
    """
    parser = ArgumentParser()
    parser.add_argument('img', type=str, help='Input Image file.')
    # String-valued options, registered in a uniform way.
    string_options = [
        ('--out_img', '', 'Output file name of the visualized image.'),
        ('--det', 'PANet_ICDAR15', 'Text detection algorithm'),
        ('--det-config', '',
         'Path to the custom config of the selected textdet model'),
        ('--recog', 'SEG', 'Text recognition algorithm'),
        ('--recog-config', '',
         'Path to the custom config of the selected textrecog model'),
    ]
    for flag, default, help_text in string_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    parser.add_argument(
        '--batch-mode',
        action='store_true',
        help='Whether use batch mode for text recognition.')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=4,
        help='Batch size for text recognition inference')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    # Boolean flags, all plain store_true switches.
    flag_options = [
        ('--export-json', 'Whether export the ocr results in a json file.'),
        ('--details',
         'Whether include the text boxes coordinates and confidence values'),
        ('--imshow', 'Whether show image with OpenCV.'),
        ('--ocr-in-lines', 'Whether group ocr results in lines.'),
        ('--print-result', 'Prints the recognised text'),
    ]
    for flag, help_text in flag_options:
        parser.add_argument(flag, action='store_true', help=help_text)
    return parser.parse_args()


class MMOCR:
    """User-friendly end-to-end OCR API.

    Wraps one bundled text-detection model and/or one text-recognition
    model behind a single object whose ``readtext()`` method runs
    inference on an image.

    Args:
        det (str): Key of ``textdet_models`` naming the detection model,
            or '' to disable detection.
        det_config (str): Path to a custom detection config. When empty,
            the bundled config under ``configs/textdet/`` (relative to the
            current working directory) is used.
        recog (str): Key of ``textrecog_models`` naming the recognition
            model, or '' to disable recognition.
        recog_config (str): Path to a custom recognition config. When
            empty, the bundled config under ``configs/textrecog/`` is used.
        device (str): Device used for inference, e.g. 'cuda:0' or 'cpu'.

    Raises:
        ValueError: If ``det``/``recog`` is not a supported model name, or
            if both are empty (there would be nothing to run).
    """

    def __init__(self,
                 det='PANet_ICDAR15',
                 det_config='',
                 recog='SEG',
                 recog_config='',
                 device='cuda:0',
                 **kwargs):
        self.td = det
        self.tr = recog
        # Pass the requested device straight through to the model builders.
        # (The previous code remapped 'cpu' to device index 0, which made a
        # CPU request silently run on GPU 0.)
        self.device = device

        # Validate both names independently so an invalid recognition name
        # is reported even when a detection name is also supplied.
        if self.td and self.td not in textdet_models:
            raise ValueError(
                self.td + ' is not a supported text detection algorithm')
        if self.tr and self.tr not in textrecog_models:
            raise ValueError(
                self.tr + ' is not a supported text recognition algorithm')
        if not self.td and not self.tr:
            # Fail fast: with both models disabled, readtext() would have
            # nothing to run.
            raise ValueError('det and recog cannot both be empty')

        # Bundled configs are resolved relative to the current working
        # directory (the repository root when run as a script).
        dir_path = os.getcwd()

        if self.td:
            # Build the detection model; the checkpoint is fetched from the
            # OpenMMLab download server.
            if not det_config:
                det_config = dir_path + '/configs/textdet/' + \
                    textdet_models[self.td]['config']
            det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
                textdet_models[self.td]['ckpt']
            self.detect_model = init_detector(
                det_config, det_ckpt, device=self.device)
        else:
            self.detect_model = None

        if self.tr:
            # Build the recognition model the same way.
            if not recog_config:
                recog_config = dir_path + '/configs/textrecog/' + \
                    textrecog_models[self.tr]['config']
            recog_ckpt = 'https://download.openmmlab.com/mmocr/textrecog/' + \
                textrecog_models[self.tr]['ckpt']
            self.recog_model = init_detector(
                recog_config, recog_ckpt, device=self.device)
        else:
            self.recog_model = None

        # Normalize test configs: unwrap DataParallel-style wrappers and,
        # when the test dataset is a ConcatDataset, use the pipeline of its
        # first sub-dataset so model_inference() sees a flat pipeline.
        for model in filter(None, [self.recog_model, self.detect_model]):
            if hasattr(model, 'module'):
                model = model.module
            if model.cfg.data.test['type'] == 'ConcatDataset':
                model.cfg.data.test.pipeline = \
                    model.cfg.data.test['datasets'][0].pipeline

    def readtext(self,
                 img,
                 out_img=None,
                 details=False,
                 export_json=False,
                 batch_mode=False,
                 batch_size=4,
                 imshow=False,
                 ocr_in_lines=False,
                 print_result=False,
                 **kwargs):
        """Run OCR on ``img`` and return the post-processed result.

        Args:
            img: Image path or loaded image accepted by mmcv.
            out_img (str): Output file name of the visualized image; also
                the stem of any exported json files.
            details (bool): Include box coordinates and confidence values
                instead of plain text strings (end-to-end mode only).
            export_json (bool): Dump the result to ``out_img + '.json'``.
            batch_mode (bool): Batch the recognition of cropped boxes.
            batch_size (int): Recognition batch size in batch mode.
            imshow (bool): Show the visualization with OpenCV.
            ocr_in_lines (bool): Stitch box-level results into lines.
            print_result (bool): Print the final result to stdout.

        Returns:
            The post-processed result; its exact shape depends on which
            models are enabled and on ``details``.
        """
        # Collect the options into a Namespace explicitly. (The previous
        # locals()-based capture was fragile: adding any local variable
        # before it would silently change the options.)
        args = Namespace(
            img=img,
            out_img=out_img,
            details=details,
            export_json=export_json,
            batch_mode=batch_mode,
            batch_size=batch_size,
            imshow=imshow,
            ocr_in_lines=ocr_in_lines,
            print_result=print_result)
        if self.detect_model and self.recog_model:
            det_recog_result = det_and_recog_inference(args, self.detect_model,
                                                       self.recog_model)
            pp_result = det_recog_pp(args, det_recog_result)
        elif self.detect_model:
            result = model_inference(self.detect_model, args.img)
            pp_result = single_pp(args, result, self.detect_model)
        else:
            # __init__ guarantees at least one model was built, so this
            # branch always has a recognition model.
            result = model_inference(self.recog_model, args.img)
            pp_result = single_pp(args, result, self.recog_model)
        return pp_result


# Allow running this module directly as a command-line OCR demo.
if __name__ == '__main__':
    main()

0 comments on commit 87bd295

Please sign in to comment.