From ea704e980f748f0b3cf3b5558d95a5e1143bee3a Mon Sep 17 00:00:00 2001 From: fg-mindee Date: Sat, 5 Jun 2021 17:00:21 +0200 Subject: [PATCH] refactor: Refactored dataset constructors --- doctr/datasets/detection.py | 8 ++++---- doctr/datasets/ocr.py | 7 +++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index 03fc262858..ce992a948f 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -47,12 +47,12 @@ def __init__( with open(os.path.join(label_folder, img_path + '.json'), 'rb') as f: boxes = json.load(f) bboxes = np.asarray(boxes["boxes_1"] + boxes["boxes_2"] + boxes["boxes_3"], dtype=np.float32) - if not rotated_bbox: - # Switch to xmin, ymin, xmax, ymax - bboxes = np.concatenate((bboxes.min(axis=1), bboxes.max(axis=1)), axis=1) - else: + if rotated_bbox: # Switch to rotated rects bboxes = np.asarray([list(fit_rbbox(box)) for box in bboxes], dtype=np.float32) + else: + # Switch to xmin, ymin, xmax, ymax + bboxes = np.concatenate((bboxes.min(axis=1), bboxes.max(axis=1)), axis=1) is_ambiguous = [False] * (len(boxes["boxes_1"]) + len(boxes["boxes_2"])) + [True] * len(boxes["boxes_3"]) self.data.append((img_path, dict(boxes=bboxes, flags=np.asarray(is_ambiguous)))) diff --git a/doctr/datasets/ocr.py b/doctr/datasets/ocr.py index f34ff8f680..f73072bd0a 100644 --- a/doctr/datasets/ocr.py +++ b/doctr/datasets/ocr.py @@ -62,15 +62,14 @@ def __init__( for box in file_dic["coordinates"]: if rotated_bbox: x, y, w, h, alpha = fit_rbbox(np.asarray(box, dtype=np.float32)) + box = [x, y, w, h, alpha] is_valid.append(w > 0 and h > 0) - if is_valid[-1]: - box_targets.append([x, y, w, h, alpha]) else: xs, ys = zip(*box) box = [min(xs), min(ys), max(xs), max(ys)] is_valid.append(box[0] < box[2] and box[1] < box[3]) - if is_valid[-1]: - box_targets.append(box) + if is_valid[-1]: + box_targets.append(box) text_targets = [word for word, _valid in zip(file_dic["string"], is_valid) if _valid] self.data.append((img_name, dict(boxes=np.asarray(box_targets, dtype=np.float32), labels=text_targets)))