support batch inference for crnn and segocr #407

Merged: 2 commits merged on Aug 3, 2021
10 changes: 6 additions & 4 deletions configs/_base_/recog_datasets/seg_toy_dataset.py
@@ -41,6 +41,9 @@
meta_keys=['filename', 'ori_shape', 'img_shape'])
]

test_img_norm_cfg = dict(
mean=[x * 255 for x in img_norm_cfg['mean']],
std=[x * 255 for x in img_norm_cfg['std']])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
@@ -49,13 +52,12 @@
min_width=64,
max_width=None,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='CustomFormatBundle', call_super=False),
dict(type='Normalize', **test_img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
meta_keys=['filename', 'ori_shape', 'resize_shape'])
]

prefix = 'tests/data/ocr_char_ann_toy_dataset/'
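Why the mean and std are scaled by 255 here: `ToTensorOCR` + `NormalizeOCR` normalize a float tensor already scaled to [0, 1], whereas MMCV's `Normalize` (paired with `DefaultFormatBundle`) operates on raw pixel values in [0, 255]. Multiplying the statistics by 255 keeps the network input numerically unchanged. A minimal sketch of the equivalence (the mean/std values are illustrative, since the config's own `img_norm_cfg` is defined outside this hunk):

```python
import numpy as np

# Statistics as consumed by NormalizeOCR on a tensor scaled to [0, 1]
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Equivalent statistics for mmcv's Normalize, which sees raw [0, 255] pixels
test_img_norm_cfg = dict(
    mean=[x * 255 for x in img_norm_cfg['mean']],
    std=[x * 255 for x in img_norm_cfg['std']])

img = np.random.randint(0, 256, (32, 100, 3)).astype(np.float32)

# NormalizeOCR-style: rescale to [0, 1] first, then normalize
a = (img / 255.0 - np.array(img_norm_cfg['mean'])) / np.array(img_norm_cfg['std'])
# Normalize-style: normalize the [0, 255] image with the rescaled statistics
b = (img - np.array(test_img_norm_cfg['mean'])) / np.array(test_img_norm_cfg['std'])

assert np.allclose(a, b, atol=1e-5)
```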
4 changes: 2 additions & 2 deletions configs/_base_/recog_datasets/toy_dataset.py
@@ -13,7 +13,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
@@ -34,7 +34,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
18 changes: 8 additions & 10 deletions configs/textrecog/crnn/crnn_academic_dataset.py
@@ -39,7 +39,7 @@
total_epochs = 5

# data
img_norm_cfg = dict(mean=[0.5], std=[0.5])
img_norm_cfg = dict(mean=[127], std=[127])

train_pipeline = [
dict(type='LoadImageFromFile', color_type='grayscale'),
@@ -49,29 +49,27 @@
min_width=100,
max_width=100,
keep_aspect_ratio=False),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='Normalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
]),
meta_keys=['filename', 'resize_shape', 'text', 'valid_ratio']),
]
test_pipeline = [
dict(type='LoadImageFromFile', color_type='grayscale'),
dict(
type='ResizeOCR',
height=32,
min_width=4,
min_width=32,
max_width=None,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='Normalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape', 'valid_ratio']),
meta_keys=['filename', 'resize_shape', 'valid_ratio']),
]

dataset_type = 'OCRDataset'
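With `keep_aspect_ratio=True` and `max_width=None`, `ResizeOCR` now produces variable-width test images (height fixed at 32, width at least `min_width`), and collation pads each image on the right to the widest one in the batch; the recognizers then recover the per-image content fraction as `valid_ratio` (see the recognizer changes below). A simplified sketch of that behaviour, not the actual `ResizeOCR`/collate implementation (function names are illustrative):

```python
import math

import cv2
import numpy as np


def resize_keep_aspect(img, height=32, min_width=32):
    """Resize to a fixed height, keep the aspect ratio, enforce a width floor."""
    h, w = img.shape[:2]
    new_w = max(min_width, math.ceil(w * height / h))
    return cv2.resize(img, (new_w, height))


def pad_to_batch(imgs):
    """Right-pad variable-width crops to the widest one so they can be stacked."""
    max_w = max(im.shape[1] for im in imgs)
    padded, valid_ratios = [], []
    for im in imgs:
        canvas = np.zeros((im.shape[0], max_w) + im.shape[2:], dtype=im.dtype)
        canvas[:, :im.shape[1]] = im
        padded.append(canvas)
        valid_ratios.append(im.shape[1] / max_w)  # fraction that is real content
    return np.stack(padded), valid_ratios


crops = [np.zeros((48, w), dtype=np.uint8) for w in (60, 200)]
batch, ratios = pad_to_batch([resize_keep_aspect(c) for c in crops])
print(batch.shape, ratios)  # (2, 32, 134) [0.298..., 1.0]
```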
4 changes: 2 additions & 2 deletions configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py
@@ -42,7 +42,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
@@ -64,7 +64,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
4 changes: 2 additions & 2 deletions configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py
@@ -42,7 +42,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
Expand All @@ -64,7 +64,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
@@ -26,7 +26,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
@@ -48,7 +48,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
4 changes: 2 additions & 2 deletions configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
@@ -48,7 +48,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
@@ -70,7 +70,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
4 changes: 2 additions & 2 deletions configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py
@@ -49,7 +49,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
@@ -71,7 +71,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
4 changes: 2 additions & 2 deletions configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
@@ -24,7 +24,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
@@ -41,7 +41,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio',
'filename', 'ori_shape', 'resize_shape', 'valid_ratio',
'img_norm_cfg', 'ori_filename'
])
]
4 changes: 2 additions & 2 deletions configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
@@ -48,7 +48,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
@@ -70,7 +70,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
]
12 changes: 7 additions & 5 deletions configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py
@@ -72,9 +72,12 @@
dict(
type='Collect',
keys=['img', 'gt_kernels'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
meta_keys=['filename', 'ori_shape', 'resize_shape'])
]

test_img_norm_cfg = dict(
mean=[x * 255 for x in img_norm_cfg['mean']],
std=[x * 255 for x in img_norm_cfg['std']])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
@@ -83,13 +86,12 @@
min_width=64,
max_width=None,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='CustomFormatBundle', call_super=False),
dict(type='Normalize', **test_img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
meta_keys=['filename', 'ori_shape', 'resize_shape'])
]

train_img_root = 'data/mixture/'
4 changes: 2 additions & 2 deletions configs/textrecog/tps/crnn_tps_academic_dataset.py
@@ -60,7 +60,7 @@
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
@@ -76,7 +76,7 @@
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape', 'valid_ratio']),
meta_keys=['filename', 'ori_shape', 'resize_shape', 'valid_ratio']),
]

dataset_type = 'OCRDataset'
10 changes: 0 additions & 10 deletions mmocr/apis/inference.py
@@ -26,20 +26,10 @@ def disable_text_recog_aug_test(cfg, set_types=None):
cfg.data[set_type].pipeline[0],
*cfg.data[set_type].pipeline[1].transforms
]
assert_if_not_support_batch_mode(cfg, set_type)

return cfg


def assert_if_not_support_batch_mode(cfg, set_type='test'):
if cfg.data[set_type].pipeline[1].type == 'ResizeOCR':
if cfg.data[set_type].pipeline[1].max_width is None:
raise Exception('Batch mode is not supported '
'since the image width is not fixed, '
'in the case that keeping aspect ratio but '
'max_width is none when do resize.')


def model_inference(model, imgs, batch_mode=False):
"""Inference image(s) with the detector.

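With `assert_if_not_support_batch_mode` removed, batch-mode recognition no longer requires a fixed `max_width`: padded collation plus `valid_ratio` handle variable widths. A hedged usage sketch in the style of the existing demo scripts (the checkpoint path is a placeholder):

```python
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference

config_file = 'configs/textrecog/crnn/crnn_academic_dataset.py'
checkpoint_file = 'crnn_academic.pth'  # placeholder; substitute a real checkpoint

model = init_detector(config_file, checkpoint_file, device='cpu')

# Two crops of different widths can now be recognized in one padded batch,
# even though the test pipeline keeps the aspect ratio with max_width=None.
results = model_inference(
    model,
    ['demo/demo_text_recog.jpg', 'demo/demo_text_recog.jpg'],
    batch_mode=True)
assert len(results) == 2
```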
2 changes: 2 additions & 0 deletions mmocr/core/visualize.py
@@ -533,6 +533,8 @@ def draw_texts_by_pil(img, texts, boxes=None):
out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
out_draw = ImageDraw.Draw(out_img)
for idx, (box, text) in enumerate(zip(boxes, texts)):
if len(text) == 0:
continue
min_x, max_x = min(box[0::2]), max(box[0::2])
min_y, max_y = min(box[1::2]), max(box[1::2])
color = tuple(list(color_list[idx % len(color_list)])[::-1])
5 changes: 4 additions & 1 deletion mmocr/models/textrecog/convertors/seg.py
@@ -66,8 +66,11 @@ def tensor2str(self, output, img_metas=None):
texts, scores = [], []
for b in range(output.size(0)):
seg_pred = output[b].detach()
valid_width = int(
output.size(-1) * img_metas[b]['valid_ratio'] + 1)
seg_res = torch.argmax(
seg_pred, dim=0).cpu().numpy().astype(np.int32)
seg_pred[:, :, :valid_width],
dim=0).cpu().numpy().astype(np.int32)

seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
_, labels, stats, centroids = cv2.connectedComponentsWithStats(
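The segmentation map of a right-padded sample contains columns with no real content; `tensor2str` now derives `valid_width` from `valid_ratio` and restricts the argmax to those columns so padding cannot produce spurious characters. A small sketch of the idea on a dummy tensor (the shapes and the 0.75 ratio are invented for illustration):

```python
import torch

# Per-image head output after padded batching: (num_classes, H, W)
num_classes, H, W = 37, 16, 64
seg_pred = torch.randn(num_classes, H, W)

# Suppose only the left 75% of the width is real content; the rest is padding
valid_ratio = 0.75
valid_width = int(W * valid_ratio + 1)

# Argmax over classes, restricted to the valid columns (as in tensor2str above)
seg_res = torch.argmax(seg_pred[:, :, :valid_width], dim=0).cpu().numpy()
print(seg_res.shape)  # (16, 49)
```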
8 changes: 8 additions & 0 deletions mmocr/models/textrecog/recognizer/encode_decode_recognizer.py
@@ -91,6 +91,10 @@ def forward_train(self, img, img_metas):
Returns:
dict[str, tensor]: A dictionary of loss components.
"""
for img_meta in img_metas:
valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
img_meta['valid_ratio'] = valid_ratio

feat = self.extract_feat(img)

gt_labels = [img_meta['text'] for img_meta in img_metas]
@@ -123,6 +127,10 @@ def simple_test(self, img, img_metas, **kwargs):
Returns:
list[str]: Text label result of each image.
"""
for img_meta in img_metas:
valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
img_meta['valid_ratio'] = valid_ratio

feat = self.extract_feat(img)

out_enc = None
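`valid_ratio` used to be attached by the pipeline; with padding deferred to collation it is recomputed inside the recognizers from `resize_shape` (the per-image shape recorded by `ResizeOCR` before padding) and the padded batch width `img.size(-1)`. A worked example of the formula with invented shapes:

```python
# The padded batch tensor has shape (N, C, H, W_max); here img.size(-1) == 160
W_max = 160

# resize_shape recorded for each image: (height, width, channels) before padding
img_metas = [
    {'resize_shape': (32, 100, 1)},
    {'resize_shape': (32, 160, 1)},
]

for img_meta in img_metas:
    img_meta['valid_ratio'] = 1.0 * img_meta['resize_shape'][1] / W_max

print([m['valid_ratio'] for m in img_metas])  # [0.625, 1.0]
# Decoders (e.g. SAR) and the seg label convertor use this ratio to ignore
# the padded region on the right of each feature map.
```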
4 changes: 4 additions & 0 deletions mmocr/models/textrecog/recognizer/seg_recognizer.py
@@ -110,6 +110,10 @@ def simple_test(self, img, img_metas, **kwargs):

out_head = self.head(out_neck)

for img_meta in img_metas:
valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
img_meta['valid_ratio'] = valid_ratio

texts, scores = self.label_convertor.tensor2str(out_head, img_metas)

# flatten batch results
28 changes: 0 additions & 28 deletions tests/test_apis/test_model_inference.py
@@ -102,31 +102,3 @@ def test_model_batch_inference_recog(cfg_file):
results = model_inference(model, [img, img], batch_mode=True)

assert len(results) == 2


@pytest.mark.parametrize(
'cfg_file', ['../configs/textrecog/crnn/crnn_academic_dataset.py'])
def test_model_batch_inference_raises_exception_error_free_resize_recog(
cfg_file):
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
config_file = os.path.join(tmp_dir, cfg_file)
model = build_model(config_file)

with pytest.raises(
Exception,
match='Batch mode is not supported '
'since the image width is not fixed, '
'in the case that keeping aspect ratio but '
'max_width is none when do resize.'):
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
model_inference(
model, [sample_img_path, sample_img_path], batch_mode=True)

with pytest.raises(
Exception,
match='Batch mode is not supported '
'since the image width is not fixed, '
'in the case that keeping aspect ratio but '
'max_width is none when do resize.'):
img = imread(sample_img_path)
model_inference(model, [img, img], batch_mode=True)
2 changes: 1 addition & 1 deletion tests/test_models/test_recog_config.py
@@ -26,7 +26,7 @@ def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
img_metas = [{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'resize_shape': (H, W, C),
'filename': '<demo>.png',
'text': 'hello',
'valid_ratio': 1.0,