Skip to content

Commit

Permalink
support batch inference for crnn and segocr (open-mmlab#407)
Browse files Browse the repository at this point in the history
* support batch inference for crnn and segocr
  • Loading branch information
cuhk-hbsun committed Aug 3, 2021
1 parent f96187d commit 2d51a7f
Show file tree
Hide file tree
Showing 20 changed files with 68 additions and 79 deletions.
10 changes: 6 additions & 4 deletions configs/_base_/recog_datasets/seg_toy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
meta_keys=['filename', 'ori_shape', 'img_shape'])
]

test_img_norm_cfg = dict(
mean=[x * 255 for x in img_norm_cfg['mean']],
std=[x * 255 for x in img_norm_cfg['std']])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
Expand All @@ -49,13 +52,12 @@
min_width=64,
max_width=None,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='CustomFormatBundle', call_super=False),
dict(type='Normalize', **test_img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
meta_keys=['filename', 'ori_shape', 'resize_shape'])
]

prefix = 'tests/data/ocr_char_ann_toy_dataset/'
Expand Down
4 changes: 2 additions & 2 deletions configs/_base_/recog_datasets/toy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -34,7 +34,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
Expand Down
18 changes: 8 additions & 10 deletions configs/textrecog/crnn/crnn_academic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
total_epochs = 5

# data
img_norm_cfg = dict(mean=[0.5], std=[0.5])
img_norm_cfg = dict(mean=[127], std=[127])

train_pipeline = [
dict(type='LoadImageFromFile', color_type='grayscale'),
Expand All @@ -49,29 +49,27 @@
min_width=100,
max_width=100,
keep_aspect_ratio=False),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='Normalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
]),
meta_keys=['filename', 'resize_shape', 'text', 'valid_ratio']),
]
test_pipeline = [
dict(type='LoadImageFromFile', color_type='grayscale'),
dict(
type='ResizeOCR',
height=32,
min_width=4,
min_width=32,
max_width=None,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='Normalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape', 'valid_ratio']),
meta_keys=['filename', 'resize_shape', 'valid_ratio']),
]

dataset_type = 'OCRDataset'
Expand Down
4 changes: 2 additions & 2 deletions configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -64,7 +64,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
Expand Down
4 changes: 2 additions & 2 deletions configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -64,7 +64,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -48,7 +48,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
Expand Down
4 changes: 2 additions & 2 deletions configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -70,7 +70,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
Expand Down
4 changes: 2 additions & 2 deletions configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -71,7 +71,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
Expand Down
4 changes: 2 additions & 2 deletions configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -41,7 +41,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio',
'filename', 'ori_shape', 'resize_shape', 'valid_ratio',
'img_norm_cfg', 'ori_filename'
])
]
Expand Down
4 changes: 2 additions & 2 deletions configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -70,7 +70,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
Expand Down
12 changes: 7 additions & 5 deletions configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,12 @@
dict(
type='Collect',
keys=['img', 'gt_kernels'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
meta_keys=['filename', 'ori_shape', 'resize_shape'])
]

test_img_norm_cfg = dict(
mean=[x * 255 for x in img_norm_cfg['mean']],
std=[x * 255 for x in img_norm_cfg['std']])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
Expand All @@ -83,13 +86,12 @@
min_width=64,
max_width=None,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='CustomFormatBundle', call_super=False),
dict(type='Normalize', **test_img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
meta_keys=['filename', 'ori_shape', 'resize_shape'])
]

train_img_root = 'data/mixture/'
Expand Down
4 changes: 2 additions & 2 deletions configs/textrecog/tps/crnn_tps_academic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -76,7 +76,7 @@
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape', 'valid_ratio']),
meta_keys=['filename', 'ori_shape', 'resize_shape', 'valid_ratio']),
]

dataset_type = 'OCRDataset'
Expand Down
10 changes: 0 additions & 10 deletions mmocr/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,10 @@ def disable_text_recog_aug_test(cfg, set_types=None):
cfg.data[set_type].pipeline[0],
*cfg.data[set_type].pipeline[1].transforms
]
assert_if_not_support_batch_mode(cfg, set_type)

return cfg


def assert_if_not_support_batch_mode(cfg, set_type='test'):
if cfg.data[set_type].pipeline[1].type == 'ResizeOCR':
if cfg.data[set_type].pipeline[1].max_width is None:
raise Exception('Batch mode is not supported '
'since the image width is not fixed, '
'in the case that keeping aspect ratio but '
'max_width is none when do resize.')


def model_inference(model, imgs, batch_mode=False):
"""Inference image(s) with the detector.
Expand Down
2 changes: 2 additions & 0 deletions mmocr/core/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ def draw_texts_by_pil(img, texts, boxes=None):
out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
out_draw = ImageDraw.Draw(out_img)
for idx, (box, text) in enumerate(zip(boxes, texts)):
if len(text) == 0:
continue
min_x, max_x = min(box[0::2]), max(box[0::2])
min_y, max_y = min(box[1::2]), max(box[1::2])
color = tuple(list(color_list[idx % len(color_list)])[::-1])
Expand Down
5 changes: 4 additions & 1 deletion mmocr/models/textrecog/convertors/seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ def tensor2str(self, output, img_metas=None):
texts, scores = [], []
for b in range(output.size(0)):
seg_pred = output[b].detach()
valid_width = int(
output.size(-1) * img_metas[b]['valid_ratio'] + 1)
seg_res = torch.argmax(
seg_pred, dim=0).cpu().numpy().astype(np.int32)
seg_pred[:, :, :valid_width],
dim=0).cpu().numpy().astype(np.int32)

seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
_, labels, stats, centroids = cv2.connectedComponentsWithStats(
Expand Down
8 changes: 8 additions & 0 deletions mmocr/models/textrecog/recognizer/encode_decode_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def forward_train(self, img, img_metas):
Returns:
dict[str, tensor]: A dictionary of loss components.
"""
for img_meta in img_metas:
valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
img_meta['valid_ratio'] = valid_ratio

feat = self.extract_feat(img)

gt_labels = [img_meta['text'] for img_meta in img_metas]
Expand Down Expand Up @@ -123,6 +127,10 @@ def simple_test(self, img, img_metas, **kwargs):
Returns:
list[str]: Text label result of each image.
"""
for img_meta in img_metas:
valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
img_meta['valid_ratio'] = valid_ratio

feat = self.extract_feat(img)

out_enc = None
Expand Down
4 changes: 4 additions & 0 deletions mmocr/models/textrecog/recognizer/seg_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def simple_test(self, img, img_metas, **kwargs):

out_head = self.head(out_neck)

for img_meta in img_metas:
valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
img_meta['valid_ratio'] = valid_ratio

texts, scores = self.label_convertor.tensor2str(out_head, img_metas)

# flatten batch results
Expand Down
28 changes: 0 additions & 28 deletions tests/test_apis/test_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,31 +102,3 @@ def test_model_batch_inference_recog(cfg_file):
results = model_inference(model, [img, img], batch_mode=True)

assert len(results) == 2


@pytest.mark.parametrize(
'cfg_file', ['../configs/textrecog/crnn/crnn_academic_dataset.py'])
def test_model_batch_inference_raises_exception_error_free_resize_recog(
cfg_file):
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
config_file = os.path.join(tmp_dir, cfg_file)
model = build_model(config_file)

with pytest.raises(
Exception,
match='Batch mode is not supported '
'since the image width is not fixed, '
'in the case that keeping aspect ratio but '
'max_width is none when do resize.'):
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
model_inference(
model, [sample_img_path, sample_img_path], batch_mode=True)

with pytest.raises(
Exception,
match='Batch mode is not supported '
'since the image width is not fixed, '
'in the case that keeping aspect ratio but '
'max_width is none when do resize.'):
img = imread(sample_img_path)
model_inference(model, [img, img], batch_mode=True)
2 changes: 1 addition & 1 deletion tests/test_models/test_recog_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
img_metas = [{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'resize_shape': (H, W, C),
'filename': '<demo>.png',
'text': 'hello',
'valid_ratio': 1.0,
Expand Down

0 comments on commit 2d51a7f

Please sign in to comment.