Skip to content

Commit

Permalink
fix #279: save detect results (#281)
Browse files Browse the repository at this point in the history
* fix #279: save detect results

* rename

* set device as arg

* rm bash file
  • Loading branch information
cuhk-hbsun committed Jun 15, 2021
1 parent 3bfbb2b commit 87a7dce
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 222 deletions.
4 changes: 2 additions & 2 deletions docs/getting_started.md
Expand Up @@ -36,10 +36,10 @@ The predicted result will be saved as `demo/output.jpg`.

```shell
# for text detection
sh tools/test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR}
./tools/det_test_imgs.py ${IMG_ROOT_PATH} ${IMG_LIST} ${CONFIG_FILE} ${CHECKPOINT_FILE} --out-dir ${RESULTS_DIR}

# for text recognition
sh tools/ocr_test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR}
./tools/recog_test_imgs.py ${IMG_ROOT_PATH} ${IMG_LIST} ${CONFIG_FILE} ${CHECKPOINT_FILE} --out-dir ${RESULTS_DIR}
```
It will save both the prediction results and visualized images to `${RESULTS_DIR}`

Expand Down
113 changes: 113 additions & 0 deletions tools/det_test_imgs.py
@@ -0,0 +1,113 @@
#!/usr/bin/env python
import os.path as osp
from argparse import ArgumentParser

import mmcv
from mmcv.utils import ProgressBar

from mmdet.apis import inference_detector, init_detector
from mmocr.models import build_detector # noqa: F401
from mmocr.utils import list_from_file, list_to_file


def gen_target_path(target_root_path, src_name, suffix):
"""Gen target file path.
Args:
target_root_path (str): The target root path.
src_name (str): The source file name.
suffix (str): The suffix of target file.
"""
assert isinstance(target_root_path, str)
assert isinstance(src_name, str)
assert isinstance(suffix, str)

file_name = osp.split(src_name)[-1]
name = osp.splitext(file_name)[0]
return osp.join(target_root_path, name + suffix)


def save_results(result, out_dir, img_name, score_thr=0.3):
"""Save result of detected bounding boxes (quadrangle or polygon) to txt
file.
Args:
result (dict): Text Detection result for one image.
img_name (str): Image file name.
out_dir (str): Dir of txt files to save detected results.
score_thr (float, optional): Score threshold to filter bboxes.
"""
assert 'boundary_result' in result
assert score_thr > 0 and score_thr < 1

txt_file = gen_target_path(out_dir, img_name, '.txt')
valid_boundary_res = [
res for res in result['boundary_result'] if res[-1] > score_thr
]
lines = [
','.join([str(round(x)) for x in row]) for row in valid_boundary_res
]
list_to_file(txt_file, lines)


def main():
parser = ArgumentParser()
parser.add_argument('img_root', type=str, help='Image root path')
parser.add_argument('img_list', type=str, help='Image path list file')
parser.add_argument('config', type=str, help='Config file')
parser.add_argument('checkpoint', type=str, help='Checkpoint file')
parser.add_argument(
'--score-thr', type=float, default=0.5, help='Bbox score threshold')
parser.add_argument(
'--out-dir',
type=str,
default='./results',
help='Dir to save '
'visualize images '
'and bbox')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference.')
args = parser.parse_args()

assert 0 < args.score_thr < 1

# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)
if hasattr(model, 'module'):
model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline

# Start Inference
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
mmcv.mkdir_or_exist(out_vis_dir)
out_txt_dir = osp.join(args.out_dir, 'out_txt_dir')
mmcv.mkdir_or_exist(out_txt_dir)

total_img_num = sum([1 for _ in open(args.img_list)])
progressbar = ProgressBar(task_num=total_img_num)
for line in list_from_file(args.img_list):
progressbar.update()
img_path = osp.join(args.img_root, line.strip())
if not osp.exists(img_path):
raise FileNotFoundError(img_path)
# Test a single image
result = inference_detector(model, img_path)
img_name = osp.basename(img_path)
# save result
save_results(result, out_txt_dir, img_name, score_thr=args.score_thr)
# show result
out_file = osp.join(out_vis_dir, img_name)
kwargs_dict = {
'score_thr': args.score_thr,
'show': False,
'out_file': out_file
}
model.show_result(img_path, result, **kwargs_dict)

print(f'\nInference done, and results saved in {args.out_dir}\n')


if __name__ == '__main__':
main()
25 changes: 0 additions & 25 deletions tools/ocr_test_imgs.sh

This file was deleted.

14 changes: 7 additions & 7 deletions tools/ocr_test_imgs.py → tools/recog_test_imgs.py
Expand Up @@ -6,7 +6,6 @@
from itertools import compress

import mmcv
import torch
from mmcv.utils import ProgressBar

from mmdet.apis import init_detector
Expand Down Expand Up @@ -40,14 +39,16 @@ def save_results(img_paths, pred_labels, gt_labels, res_dir):

def main():
parser = ArgumentParser()
parser.add_argument('--img_root_path', type=str, help='Image root path')
parser.add_argument('--img_list', type=str, help='Image path list file')
parser.add_argument('--config', type=str, help='Config file')
parser.add_argument('--checkpoint', type=str, help='Checkpoint file')
parser.add_argument('img_root_path', type=str, help='Image root path')
parser.add_argument('img_list', type=str, help='Image path list file')
parser.add_argument('config', type=str, help='Config file')
parser.add_argument('checkpoint', type=str, help='Checkpoint file')
parser.add_argument(
'--out_dir', type=str, default='./results', help='Dir to save results')
parser.add_argument(
'--show', action='store_true', help='show image or save')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference.')
args = parser.parse_args()

# init the logger before other steps
Expand All @@ -56,8 +57,7 @@ def main():
logger = get_root_logger(log_file=log_file, log_level='INFO')

# build the model from a config file and a checkpoint file
device = 'cuda:' + str(torch.cuda.current_device())
model = init_detector(args.config, args.checkpoint, device=device)
model = init_detector(args.config, args.checkpoint, device=args.device)
if hasattr(model, 'module'):
model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
Expand Down
165 changes: 0 additions & 165 deletions tools/test_imgs.py

This file was deleted.

23 changes: 0 additions & 23 deletions tools/test_imgs.sh

This file was deleted.

0 comments on commit 87a7dce

Please sign in to comment.