From 0b360a5afe611638900d9a0591fefa5872b41f60 Mon Sep 17 00:00:00 2001 From: lizz Date: Mon, 24 May 2021 14:01:42 +0800 Subject: [PATCH] Add list_from_file and list_to_file (#226) * Add list_from_file and list_to_file Signed-off-by: lizz * Add test list_to_file and list_from_file * more * Fix tests --- mmocr/datasets/kie_dataset.py | 7 +- mmocr/datasets/utils/loader.py | 5 +- mmocr/models/kie/extractors/sdmgr.py | 9 +- mmocr/models/textrecog/convertors/base.py | 10 +- mmocr/utils/__init__.py | 4 +- mmocr/utils/fileio.py | 31 ++++++ mmocr/utils/lmdb_util.py | 9 +- tests/test_utils/test_textio.py | 46 +++++++++ tools/data/textdet/coco_to_line_dict.py | 32 +++---- tools/data/textdet/ctw1500_converter.py | 8 +- tools/data/textdet/icdar_converter.py | 9 +- .../data/textrecog/seg_synthtext_converter.py | 94 +++++++++---------- tools/data/textrecog/svt_converter.py | 60 ++++++------ tools/ocr_test_imgs.py | 92 +++++++++--------- tools/test_imgs.py | 45 +++++---- 15 files changed, 261 insertions(+), 200 deletions(-) create mode 100644 mmocr/utils/fileio.py create mode 100644 tests/test_utils/test_textio.py diff --git a/mmocr/datasets/kie_dataset.py b/mmocr/datasets/kie_dataset.py index 8d5061dfdf..02b30c2fc0 100644 --- a/mmocr/datasets/kie_dataset.py +++ b/mmocr/datasets/kie_dataset.py @@ -1,15 +1,14 @@ import copy from os import path as osp -import mmcv import numpy as np import torch -import mmocr.utils as utils from mmdet.datasets.builder import DATASETS from mmocr.core import compute_f1_score from mmocr.datasets.base_dataset import BaseDataset from mmocr.datasets.pipelines import sort_vertex8 +from mmocr.utils import is_type_list, list_from_file @DATASETS.register_module() @@ -52,7 +51,7 @@ def __init__(self, '': 0, **{ line.rstrip('\r\n'): ind - for ind, line in enumerate(mmcv.list_from_file(dict_file), 1) + for ind, line in enumerate(list_from_file(dict_file), 1) } } @@ -79,7 +78,7 @@ def _parse_anno_info(self, annotations): box_num * (box_num + 1). """ - assert utils.is_type_list(annotations, dict) + assert is_type_list(annotations, dict) assert len(annotations) > 0, 'Please remove data with empty annotation' assert 'box' in annotations[0] assert 'text' in annotations[0] diff --git a/mmocr/datasets/utils/loader.py b/mmocr/datasets/utils/loader.py index b68eb87e53..e9c65f0de2 100644 --- a/mmocr/datasets/utils/loader.py +++ b/mmocr/datasets/utils/loader.py @@ -1,8 +1,7 @@ import os.path as osp -import mmcv - from mmocr.datasets.builder import LOADERS, build_parser +from mmocr.utils import list_from_file @LOADERS.register_module() @@ -60,7 +59,7 @@ class HardDiskLoader(Loader): """ def _load(self, ann_file): - return mmcv.list_from_file(ann_file) + return list_from_file(ann_file) @LOADERS.register_module() diff --git a/mmocr/models/kie/extractors/sdmgr.py b/mmocr/models/kie/extractors/sdmgr.py index c8275a28a9..9a754661dc 100644 --- a/mmocr/models/kie/extractors/sdmgr.py +++ b/mmocr/models/kie/extractors/sdmgr.py @@ -8,6 +8,7 @@ from mmdet.models.builder import DETECTORS, build_roi_extractor from mmdet.models.detectors import SingleStageDetector from mmocr.core import imshow_edge_node +from mmocr.utils import list_from_file @DETECTORS.register_module() @@ -126,11 +127,9 @@ def show_result(self, idx_to_cls = {} if self.class_list is not None: - with open(self.class_list, 'r') as fr: - for line in fr: - line = line.strip().split() - class_idx, class_label = line - idx_to_cls[class_idx] = class_label + for line in list_from_file(self.class_list): + class_idx, class_label = line.strip().split() + idx_to_cls[class_idx] = class_label # if out_file specified, do not show image in window if out_file is not None: diff --git a/mmocr/models/textrecog/convertors/base.py b/mmocr/models/textrecog/convertors/base.py index 9002d4fed0..e54a232a3f 100644 --- a/mmocr/models/textrecog/convertors/base.py +++ b/mmocr/models/textrecog/convertors/base.py @@ -1,4 +1,5 @@ from mmocr.models.builder import CONVERTORS +from mmocr.utils import list_from_file @CONVERTORS.register_module() @@ -27,11 +28,10 @@ def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None): assert dict_list is None or isinstance(dict_list, list) self.idx2char = [] if dict_file is not None: - with open(dict_file, encoding='utf-8') as fr: - for line in fr: - line = line.strip() - if line != '': - self.idx2char.append(line) + for line in list_from_file(dict_file): + line = line.strip() + if line != '': + self.idx2char.append(line) elif dict_list is not None: self.idx2char = dict_list else: diff --git a/mmocr/utils/__init__.py b/mmocr/utils/__init__.py index 98cb3657ac..3de0212e4f 100644 --- a/mmocr/utils/__init__.py +++ b/mmocr/utils/__init__.py @@ -4,6 +4,7 @@ is_none_or_type, is_type_list, valid_boundary) from .collect_env import collect_env from .data_convert_util import convert_annotations +from .fileio import list_from_file, list_to_file from .img_util import drop_orientation, is_not_png from .lmdb_util import lmdb_converter from .logger import get_root_logger @@ -12,5 +13,6 @@ 'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', 'is_3dlist', 'is_ndarray_list', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter', - 'drop_orientation', 'convert_annotations', 'is_not_png' + 'drop_orientation', 'convert_annotations', 'is_not_png', 'list_to_file', + 'list_from_file' ] diff --git a/mmocr/utils/fileio.py b/mmocr/utils/fileio.py new file mode 100644 index 0000000000..8dfd20c021 --- /dev/null +++ b/mmocr/utils/fileio.py @@ -0,0 +1,31 @@ +def list_to_file(filename, lines): + """Write a list of strings to a text file. + + Args: + filename (str): The output filename. It will be created/overwritten. + lines (list(str)): Data to be written. + """ + with open(filename, 'w', encoding='utf-8') as fw: + for line in lines: + fw.write(f'{line}\n') + + +def list_from_file(filename, encoding='utf-8'): + """Load a text file and parse the content as a list of strings. The + trailing "\\r" and "\\n" of each line will be removed. + + Note: + This will be replaced by mmcv's version after it supports encoding. + + Args: + filename (str): Filename. + encoding (str): Encoding used to open the file. Default utf-8. + + Returns: + list[str]: A list of strings. + """ + item_list = [] + with open(filename, 'r', encoding=encoding) as f: + for line in f: + item_list.append(line.rstrip('\n\r')) + return item_list diff --git a/mmocr/utils/lmdb_util.py b/mmocr/utils/lmdb_util.py index 704365d9cf..24edae96da 100644 --- a/mmocr/utils/lmdb_util.py +++ b/mmocr/utils/lmdb_util.py @@ -5,11 +5,12 @@ import lmdb +from mmocr.utils import list_from_file -def lmdb_converter(img_list, output, batch_size=1000, coding='utf-8'): - # read img_list - with open(img_list) as f: - lines = f.readlines() + +def lmdb_converter(img_list_file, output, batch_size=1000, coding='utf-8'): + # read img_list_file + lines = list_from_file(img_list_file) # create lmdb database if Path(output).is_dir(): diff --git a/tests/test_utils/test_textio.py b/tests/test_utils/test_textio.py new file mode 100644 index 0000000000..2b504dfd58 --- /dev/null +++ b/tests/test_utils/test_textio.py @@ -0,0 +1,46 @@ +import tempfile + +from mmocr.utils import list_from_file, list_to_file + +lists = [ + [], + [' '], + ['\t'], + ['a'], + [1], + [1.], + ['a', 'b'], + ['a', 1, 1.], + [1, 1., 'a'], + ['啊', '啊啊'], + ['選択', 'noël', 'Информацией', 'ÄÆä'], +] + + +def test_list_to_file(): + with tempfile.TemporaryDirectory() as tmpdirname: + for i, lines in enumerate(lists): + filename = f'{tmpdirname}/{i}.txt' + list_to_file(filename, lines) + lines2 = [ + line.rstrip('\r\n') + for line in open(filename, 'r', encoding='utf-8').readlines() + ] + lines = list(map(str, lines)) + assert len(lines) == len(lines2) + assert all(line1 == line2 for line1, line2 in zip(lines, lines2)) + + +def test_list_from_file(): + with tempfile.TemporaryDirectory() as tmpdirname: + for encoding in ['utf-8', 'utf-8-sig']: + for lineend in ['\n', '\r\n']: + for i, lines in enumerate(lists): + filename = f'{tmpdirname}/{i}.txt' + with open(filename, 'w', encoding=encoding) as f: + f.writelines(f'{line}{lineend}' for line in lines) + lines2 = list_from_file(filename, encoding=encoding) + lines = list(map(str, lines)) + assert len(lines) == len(lines2) + assert all(line1 == line2 + for line1, line2 in zip(lines, lines2)) diff --git a/tools/data/textdet/coco_to_line_dict.py b/tools/data/textdet/coco_to_line_dict.py index 44b14505ae..ed81d4031c 100644 --- a/tools/data/textdet/coco_to_line_dict.py +++ b/tools/data/textdet/coco_to_line_dict.py @@ -1,16 +1,13 @@ import argparse -import codecs import json +import mmcv -def read_json(fpath): - with codecs.open(fpath, 'r', 'utf-8') as f: - obj = json.load(f) - return obj +from mmocr.utils import list_to_file def parse_coco_json(in_path): - json_obj = read_json(in_path) + json_obj = mmcv.load(in_path) image_infos = json_obj['images'] annotations = json_obj['annotations'] imgid2imgname = {} @@ -35,18 +32,17 @@ def parse_coco_json(in_path): def gen_line_dict_file(out_path, imgid2imgname, imgid2anno): - # import pdb; pdb.set_trace() - with codecs.open(out_path, 'w', 'utf-8') as fw: - for key, value in imgid2imgname.items(): - if key in imgid2anno: - anno = imgid2anno[key] - line_dict = {} - line_dict['file_name'] = value['file_name'] - line_dict['height'] = value['height'] - line_dict['width'] = value['width'] - line_dict['annotations'] = anno - line_dict_str = json.dumps(line_dict) - fw.write(line_dict_str + '\n') + lines = [] + for key, value in imgid2imgname.items(): + if key in imgid2anno: + anno = imgid2anno[key] + line_dict = {} + line_dict['file_name'] = value['file_name'] + line_dict['height'] = value['height'] + line_dict['width'] = value['width'] + line_dict['annotations'] = anno + lines.append(json.dumps(line_dict)) + list_to_file(out_path, lines) def parse_args(): diff --git a/tools/data/textdet/ctw1500_converter.py b/tools/data/textdet/ctw1500_converter.py index 186181f970..18b6f7666a 100644 --- a/tools/data/textdet/ctw1500_converter.py +++ b/tools/data/textdet/ctw1500_converter.py @@ -8,7 +8,8 @@ import numpy as np from shapely.geometry import Polygon -from mmocr.utils import convert_annotations, drop_orientation, is_not_png +from mmocr.utils import (convert_annotations, drop_orientation, is_not_png, + list_from_file) def collect_files(img_dir, gt_dir, split): @@ -84,11 +85,8 @@ def collect_annotations(files, split, nproc=1): def load_txt_info(gt_file, img_info): - with open(gt_file) as f: - gt_list = f.readlines() - anno_info = [] - for line in gt_list: + for line in list_from_file(gt_file): # each line has one ploygen (n vetices), and one text. # e.g., 695,885,866,888,867,1146,696,1143,####Latin 9 line = line.strip() diff --git a/tools/data/textdet/icdar_converter.py b/tools/data/textdet/icdar_converter.py index cab8bfe6e0..2c6b3f6ec4 100644 --- a/tools/data/textdet/icdar_converter.py +++ b/tools/data/textdet/icdar_converter.py @@ -7,7 +7,8 @@ import numpy as np from shapely.geometry import Polygon -from mmocr.utils import convert_annotations, drop_orientation, is_not_png +from mmocr.utils import (convert_annotations, drop_orientation, is_not_png, + list_from_file) def collect_files(img_dir, gt_dir): @@ -96,11 +97,9 @@ def load_img_info(files, dataset): assert img.shape[0:2] == img_color.shape[0:2] if dataset == 'icdar2017': - with open(gt_file) as f: - gt_list = f.readlines() + gt_list = list_from_file(gt_file) elif dataset == 'icdar2015': - with open(gt_file, mode='r', encoding='utf-8-sig') as f: - gt_list = f.readlines() + gt_list = list_from_file(gt_file, encoding='utf-8-sig') else: raise NotImplementedError(f'Not support {dataset}') diff --git a/tools/data/textrecog/seg_synthtext_converter.py b/tools/data/textrecog/seg_synthtext_converter.py index fc4e060006..14ec35f497 100644 --- a/tools/data/textrecog/seg_synthtext_converter.py +++ b/tools/data/textrecog/seg_synthtext_converter.py @@ -4,64 +4,62 @@ import cv2 +from mmocr.utils import list_from_file, list_to_file + def parse_old_label(data_root, in_path, img_size=False): imgid2imgname = {} imgid2anno = {} idx = 0 - with open(in_path, 'r') as fr: - for line in fr: - line = line.strip().split() - img_full_path = osp.join(data_root, line[0]) - if not osp.exists(img_full_path): - continue - ann_file = osp.join(data_root, line[1]) - if not osp.exists(ann_file): - continue - - img_info = {} - img_info['file_name'] = line[0] - if img_size: - img = cv2.imread(img_full_path) - h, w = img.shape[:2] - img_info['height'] = h - img_info['width'] = w - imgid2imgname[idx] = img_info - - imgid2anno[idx] = [] - char_annos = [] - with open(ann_file, 'r') as fr: - t = 0 - for line in fr: - line = line.strip() - if t == 0: - img_info['text'] = line - else: - char_box = [float(x) for x in line.split()] - char_text = img_info['text'][t - 1] - char_ann = dict(char_box=char_box, char_text=char_text) - char_annos.append(char_ann) - t += 1 - imgid2anno[idx] = char_annos - idx += 1 + for line in list_from_file(in_path): + line = line.strip().split() + img_full_path = osp.join(data_root, line[0]) + if not osp.exists(img_full_path): + continue + ann_file = osp.join(data_root, line[1]) + if not osp.exists(ann_file): + continue + + img_info = {} + img_info['file_name'] = line[0] + if img_size: + img = cv2.imread(img_full_path) + h, w = img.shape[:2] + img_info['height'] = h + img_info['width'] = w + imgid2imgname[idx] = img_info + + imgid2anno[idx] = [] + char_annos = [] + for t, ann_line in enumerate(list_from_file(ann_file)): + ann_line = ann_line.strip() + if t == 0: + img_info['text'] = ann_line + else: + char_box = [float(x) for x in ann_line.split()] + char_text = img_info['text'][t - 1] + char_ann = dict(char_box=char_box, char_text=char_text) + char_annos.append(char_ann) + imgid2anno[idx] = char_annos + idx += 1 return imgid2imgname, imgid2anno def gen_line_dict_file(out_path, imgid2imgname, imgid2anno, img_size=False): - with open(out_path, 'w', encoding='utf-8') as fw: - for key, value in imgid2imgname.items(): - if key in imgid2anno: - anno = imgid2anno[key] - line_dict = {} - line_dict['file_name'] = value['file_name'] - line_dict['text'] = value['text'] - if img_size: - line_dict['height'] = value['height'] - line_dict['width'] = value['width'] - line_dict['annotations'] = anno - line_dict_str = json.dumps(line_dict) - fw.write(line_dict_str + '\n') + lines = [] + for key, value in imgid2imgname.items(): + if key in imgid2anno: + anno = imgid2anno[key] + line_dict = {} + line_dict['file_name'] = value['file_name'] + line_dict['text'] = value['text'] + if img_size: + line_dict['height'] = value['height'] + line_dict['width'] = value['width'] + line_dict['annotations'] = anno + lines.append(json.dumps(line_dict)) + list_to_file(out_path, lines) def parse_args(): diff --git a/tools/data/textrecog/svt_converter.py b/tools/data/textrecog/svt_converter.py index 1b340b3e0d..35e4d7f090 100644 --- a/tools/data/textrecog/svt_converter.py +++ b/tools/data/textrecog/svt_converter.py @@ -5,6 +5,8 @@ import cv2 +from mmocr.utils.fileio import list_to_file + def parse_args(): parser = argparse.ArgumentParser( @@ -43,35 +45,35 @@ def main(): root = tree.getroot() index = 1 - with open(dst_label_file, 'w', encoding='utf-8') as fw: - total_img_num = len(root) - i = 1 - for image_node in root.findall('image'): - image_name = image_node.find('imageName').text - print(f'[{i}/{total_img_num}] Process image: {image_name}') - i += 1 - lexicon = image_node.find('lex').text.lower() - lexicon_list = lexicon.split(',') - lex_size = len(lexicon_list) - src_img = cv2.imread(osp.join(src_image_root, image_name)) - for rectangle in image_node.find('taggedRectangles'): - x = int(rectangle.get('x')) - y = int(rectangle.get('y')) - w = int(rectangle.get('width')) - h = int(rectangle.get('height')) - rb, re = max(0, y), max(0, y + h) - cb, ce = max(0, x), max(0, x + w) - dst_img = src_img[rb:re, cb:ce] - text_label = rectangle.find('tag').text.lower() - if args.resize: - dst_img = cv2.resize(dst_img, (args.width, args.height)) - dst_img_name = f'img_{index:04}' + '.jpg' - index += 1 - dst_img_path = osp.join(dst_image_root, dst_img_name) - cv2.imwrite(dst_img_path, dst_img) - fw.write(f'{osp.basename(dst_image_root)}/{dst_img_name} ' - f'{text_label} {lex_size} {lexicon}\n') - + lines = [] + total_img_num = len(root) + i = 1 + for image_node in root.findall('image'): + image_name = image_node.find('imageName').text + print(f'[{i}/{total_img_num}] Process image: {image_name}') + i += 1 + lexicon = image_node.find('lex').text.lower() + lexicon_list = lexicon.split(',') + lex_size = len(lexicon_list) + src_img = cv2.imread(osp.join(src_image_root, image_name)) + for rectangle in image_node.find('taggedRectangles'): + x = int(rectangle.get('x')) + y = int(rectangle.get('y')) + w = int(rectangle.get('width')) + h = int(rectangle.get('height')) + rb, re = max(0, y), max(0, y + h) + cb, ce = max(0, x), max(0, x + w) + dst_img = src_img[rb:re, cb:ce] + text_label = rectangle.find('tag').text.lower() + if args.resize: + dst_img = cv2.resize(dst_img, (args.width, args.height)) + dst_img_name = f'img_{index:04}' + '.jpg' + index += 1 + dst_img_path = osp.join(dst_image_root, dst_img_name) + cv2.imwrite(dst_img_path, dst_img) + lines.append(f'{osp.basename(dst_image_root)}/{dst_img_name} ' + f'{text_label} {lex_size} {lexicon}') + list_to_file(dst_label_file, lines) print(f'Finish to generate svt testset, ' f'with label file {dst_label_file}') diff --git a/tools/ocr_test_imgs.py b/tools/ocr_test_imgs.py index d2118bc093..040d21ed5f 100755 --- a/tools/ocr_test_imgs.py +++ b/tools/ocr_test_imgs.py @@ -3,6 +3,7 @@ import shutil import time from argparse import ArgumentParser +from itertools import compress import mmcv import torch @@ -13,7 +14,7 @@ from mmocr.core.evaluation.ocr_metric import eval_ocr_metric from mmocr.datasets import build_dataset # noqa: F401 from mmocr.models import build_detector # noqa: F401 -from mmocr.utils import get_root_logger +from mmocr.utils import get_root_logger, list_from_file, list_to_file def save_results(img_paths, pred_labels, gt_labels, res_dir): @@ -26,21 +27,15 @@ def save_results(img_paths, pred_labels, gt_labels, res_dir): res_dir (str) """ assert len(img_paths) == len(pred_labels) == len(gt_labels) - res_file = osp.join(res_dir, 'results.txt') - correct_file = osp.join(res_dir, 'correct.txt') - wrong_file = osp.join(res_dir, 'wrong.txt') - with open(res_file, 'w') as fw, \ - open(correct_file, 'w') as fw_correct, \ - open(wrong_file, 'w') as fw_wrong: - for img_path, pred_label, gt_label in zip(img_paths, pred_labels, - gt_labels): - fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n') - if pred_label == gt_label: - fw_correct.write(img_path + ' ' + pred_label + ' ' + gt_label + - '\n') - else: - fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label + - '\n') + corrects = [pred == gt for pred, gt in zip(pred_labels, gt_labels)] + wrongs = [not c for c in corrects] + lines = [ + f'{img} {pred} {gt}' + for img, pred, gt in zip(img_paths, pred_labels, gt_labels) + ] + list_to_file(osp.join(res_dir, 'results.txt'), lines) + list_to_file(osp.join(res_dir, 'correct.txt'), compress(lines, corrects)) + list_to_file(osp.join(res_dir, 'wrong.txt'), compress(lines, wrongs)) def main(): @@ -80,39 +75,38 @@ def main(): total_img_num = sum([1 for _ in open(args.img_list)]) progressbar = ProgressBar(task_num=total_img_num) num_gt_label = 0 - with open(args.img_list, 'r') as fr: - for line in fr: - progressbar.update() - item_list = line.strip().split() - img_file = item_list[0] - gt_label = '' - if len(item_list) >= 2: - gt_label = item_list[1] - num_gt_label += 1 - img_path = osp.join(args.img_root_path, img_file) - if not osp.exists(img_path): - raise FileNotFoundError(img_path) - # Test a single image - result = model_inference(model, img_path) - pred_label = result['text'] - - out_img_name = '_'.join(img_file.split('/')) - out_file = osp.join(out_vis_dir, out_img_name) - kwargs_dict = { - 'gt_label': gt_label, - 'show': args.show, - 'out_file': '' if args.show else out_file - } - model.show_result(img_path, result, **kwargs_dict) - if gt_label != '': - if gt_label == pred_label: - dst_file = osp.join(correct_vis_dir, out_img_name) - else: - dst_file = osp.join(wrong_vis_dir, out_img_name) - shutil.copy(out_file, dst_file) - img_paths.append(img_path) - gt_labels.append(gt_label) - pred_labels.append(pred_label) + for line in list_from_file(args.img_list): + progressbar.update() + item_list = line.strip().split() + img_file = item_list[0] + gt_label = '' + if len(item_list) >= 2: + gt_label = item_list[1] + num_gt_label += 1 + img_path = osp.join(args.img_root_path, img_file) + if not osp.exists(img_path): + raise FileNotFoundError(img_path) + # Test a single image + result = model_inference(model, img_path) + pred_label = result['text'] + + out_img_name = '_'.join(img_file.split('/')) + out_file = osp.join(out_vis_dir, out_img_name) + kwargs_dict = { + 'gt_label': gt_label, + 'show': args.show, + 'out_file': '' if args.show else out_file + } + model.show_result(img_path, result, **kwargs_dict) + if gt_label != '': + if gt_label == pred_label: + dst_file = osp.join(correct_vis_dir, out_img_name) + else: + dst_file = osp.join(wrong_vis_dir, out_img_name) + shutil.copy(out_file, dst_file) + img_paths.append(img_path) + gt_labels.append(gt_label) + pred_labels.append(pred_label) # Save results save_results(img_paths, pred_labels, gt_labels, args.out_dir) diff --git a/tools/test_imgs.py b/tools/test_imgs.py index c981456456..1f87b5d501 100755 --- a/tools/test_imgs.py +++ b/tools/test_imgs.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -import codecs import os.path as osp from argparse import ArgumentParser @@ -11,6 +10,7 @@ from mmdet.apis import inference_detector, init_detector from mmocr.core.evaluation.utils import filter_result from mmocr.models import build_detector # noqa: F401 +from mmocr.utils import list_from_file, list_to_file def gen_target_path(target_root_path, src_name, suffix): @@ -25,9 +25,9 @@ def gen_target_path(target_root_path, src_name, suffix): assert isinstance(src_name, str) assert isinstance(suffix, str) - dir_name, file_name = osp.split(src_name) - name, file_suffix = osp.splitext(file_name) - return target_root_path + '/' + name + suffix + file_name = osp.split(src_name)[-1] + name = osp.splitext(file_name)[0] + return osp.join(target_root_path, name + suffix) def save_2darray(mat, file_name): @@ -37,10 +37,8 @@ def save_2darray(mat, file_name): mat (ndarray): 2d-array of shape (n, m). file_name (str): The output file name. """ - with codecs.open(file_name, 'w', 'utf-8') as fw: - for row in mat: - row_str = ','.join([str(x) for x in row]) - fw.write(row_str + '\n') + lines = [','.join([str(x) for x in row]) for row in mat] + list_to_file(file_name, lines) def save_bboxes_quadrangles(bboxes_with_scores, @@ -144,22 +142,21 @@ def main(): total_img_num = sum([1 for _ in open(args.img_list)]) progressbar = ProgressBar(task_num=total_img_num) - with codecs.open(args.img_list, 'r', 'utf-8') as fr: - for line in fr: - progressbar.update() - img_path = args.img_root + '/' + line.strip() - if not osp.exists(img_path): - raise FileNotFoundError(img_path) - # Test a single image - result = inference_detector(model, img_path) - img_name = osp.basename(img_path) - out_file = osp.join(out_vis_dir, img_name) - kwargs_dict = { - 'score_thr': args.score_thr, - 'show': False, - 'out_file': out_file - } - model.show_result(img_path, result, **kwargs_dict) + for line in list_from_file(args.img_list): + progressbar.update() + img_path = osp.join(args.img_root, line.strip()) + if not osp.exists(img_path): + raise FileNotFoundError(img_path) + # Test a single image + result = inference_detector(model, img_path) + img_name = osp.basename(img_path) + out_file = osp.join(out_vis_dir, img_name) + kwargs_dict = { + 'score_thr': args.score_thr, + 'show': False, + 'out_file': out_file + } + model.show_result(img_path, result, **kwargs_dict) print(f'\nInference done, and results saved in {args.out_dir}\n')