Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add rotated boxes #281

Merged
merged 52 commits into from
Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
840524e
refacto: rotated boxes
charlesmindee May 21, 2021
62e73b7
feat: modify ref & scripts
charlesmindee May 21, 2021
68694ac
fix: flake8 ./
charlesmindee May 21, 2021
ecbd6c0
fix: flake8 ./
charlesmindee May 21, 2021
436c311
merged main
charlesmindee May 21, 2021
dabd924
fix: metric test
charlesmindee May 21, 2021
7391b94
fix: tests
charlesmindee May 21, 2021
396faf4
Merge branch 'main' into rotate
charlesmindee May 21, 2021
31dbf4f
fix: unitests
charlesmindee May 24, 2021
7409055
fix: test
charlesmindee May 24, 2021
f659a9e
fix: tests
charlesmindee May 24, 2021
f1ae4f1
fix: detection
charlesmindee May 24, 2021
212f9d6
fix: test
charlesmindee May 25, 2021
295976e
fix: metric
charlesmindee May 25, 2021
1da8879
Merge branch 'main' into rotate
charlesmindee May 25, 2021
dcfbd54
fix: postprocess
charlesmindee May 25, 2021
7d42c59
add: perf
charlesmindee May 25, 2021
abc49c2
fix: rot matrix
charlesmindee May 25, 2021
95919a8
fox: dtype
charlesmindee May 26, 2021
1713e9c
fox: dtype
charlesmindee May 26, 2021
e5c3908
fix: api
charlesmindee May 26, 2021
b3f03bb
Merge branch 'main' into rotate
charlesmindee May 26, 2021
ff00819
fix: api
charlesmindee May 26, 2021
88e8eb9
Merge branch 'main' into rotate
charlesmindee May 27, 2021
08bd90d
refacto: datasets + extract
charlesmindee May 27, 2021
0d44f1d
refacto:core
charlesmindee May 27, 2021
40d57ce
merging
charlesmindee May 28, 2021
3af7a72
refacto: model core
charlesmindee May 28, 2021
6eb0817
refacto utils + models
charlesmindee May 28, 2021
0657d76
refacto scripts
charlesmindee May 28, 2021
50e0f02
Merge branch 'main' into rotate
charlesmindee May 28, 2021
6afdbf6
refacto: models
charlesmindee May 28, 2021
370f98d
refacto: unitests
charlesmindee May 31, 2021
4332c75
refacto: api
charlesmindee May 31, 2021
bef6eb7
refacto: api
charlesmindee May 31, 2021
98b0ac8
refacto: api
charlesmindee May 31, 2021
633fbbc
fix: api
charlesmindee May 31, 2021
378b224
feat: updates cov
charlesmindee May 31, 2021
5ff470f
feat: add viz rotated
charlesmindee May 31, 2021
f0c9a24
feat: updated scripts
charlesmindee May 31, 2021
e0f153d
feat: change angle conv
charlesmindee Jun 1, 2021
543f0b3
fix: extract crops
charlesmindee Jun 1, 2021
e737c0b
merging
charlesmindee Jun 1, 2021
3aa47a1
feat: metric in scripts
charlesmindee Jun 1, 2021
ea9f3b8
feat: metric in scripts
charlesmindee Jun 1, 2021
cd6e59d
refacto: models
charlesmindee Jun 1, 2021
a4de158
refacto: requested changes
charlesmindee Jun 1, 2021
69edf7d
refacto: boxpoints
charlesmindee Jun 1, 2021
b5e076a
fix: type
charlesmindee Jun 1, 2021
69b76b1
fix: type
charlesmindee Jun 1, 2021
35c0c47
refacto: requested changes
charlesmindee Jun 2, 2021
feb522c
fix: docstring
charlesmindee Jun 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/app/routes/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ async def text_detection(file: UploadFile = File(...)):
"""Runs DocTR text detection model to analyze the input"""
img = decode_image(file.file.read())
out = det_predictor([img], training=False)
return [DetectionOut(box=box.tolist()) for box in out[0][:, :4]]
return [DetectionOut(box=box.tolist()) for box in out[0][:, :-1]]
6 changes: 3 additions & 3 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@ Predictors that localize and identify text elements in images
+-----------------------------+------------+---------------+---------+------------+---------------+---------+
| db_resnet50 + sar_resnet31 | N/A | N/A | 0.27 | N/A | N/A | 0.83 |
+-----------------------------+------------+---------------+---------+------------+---------------+---------+
| Gvision text detection | 0.595 | 0.625 | | 0.753 | 0.700 | |
| Gvision text detection | 59.50 | 62.50 | | 75.30 | 70.00 | |
+-----------------------------+------------+---------------+---------+------------+---------------+---------+
| Gvision doc. text detection | 0.640 | 0.533 | | 0.689 | 0.611 | |
| Gvision doc. text detection | 64.00 | 53.30 | | 68.90 | 61.10 | |
+-----------------------------+------------+---------------+---------+------------+---------------+---------+
| AWS textract | **0.781** | **0.830** | | **0.875** | 0.660 | |
| AWS textract | **78.10** | **83.00** | | **87.50** | 66.00 | |
+-----------------------------+------------+---------------+---------+------------+---------------+---------+

All OCR models above have been evaluated using both the training and evaluation sets of FUNSD and CORD (cf. :ref:`datasets`).
Expand Down
19 changes: 15 additions & 4 deletions doctr/datasets/cord.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tensorflow as tf

from .core import VisionDataset
from doctr.utils.geometry import fit_rbbox

__all__ = ['CORD']

Expand Down Expand Up @@ -39,6 +40,7 @@ def __init__(
self,
train: bool = True,
sample_transforms: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
rotated_bbox: bool = False,
**kwargs: Any,
) -> None:

Expand All @@ -62,10 +64,19 @@ def __init__(
if len(word["text"]) > 0:
x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
# Reduce 8 coords to 4
left, right = min(x), max(x)
top, bot = min(y), max(y)
_targets.append((word["text"], [left, top, right, bot]))
if not rotated_bbox:
# Reduce 8 coords to 4
left, right = min(x), max(x)
top, bot = min(y), max(y)
_targets.append((word["text"], [left, top, right, bot]))
else:
x, y, w, h, alpha = fit_rbbox(np.array([
[x[0], y[0]],
[x[1], y[1]],
[x[2], y[2]],
[x[3], y[3]],
], np.float32))
_targets.append((word["text"], [x, y, w, h, alpha]))

text_targets, box_targets = zip(*_targets)

Expand Down
12 changes: 8 additions & 4 deletions doctr/datasets/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List, Tuple, Dict, Any, Optional, Callable

from .core import AbstractDataset
from doctr.utils.geometry import fit_rbbox

__all__ = ["DetectionDataset"]

Expand All @@ -32,6 +33,7 @@ def __init__(
img_folder: str,
label_folder: str,
sample_transforms: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
rotated_bbox: bool = False,
) -> None:
self.sample_transforms = sample_transforms
self.root = img_folder
Expand All @@ -42,13 +44,15 @@ def __init__(
raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_path)}")
with open(os.path.join(label_folder, img_path + '.json'), 'rb') as f:
boxes = json.load(f)

bboxes = np.asarray(boxes["boxes_1"] + boxes["boxes_2"] + boxes["boxes_3"], dtype=np.float32)
# Switch to xmin, ymin, xmax, ymax
bboxes = np.concatenate((bboxes.min(axis=1), bboxes.max(axis=1)), axis=1)
if not rotated_bbox:
# Switch to xmin, ymin, xmax, ymax
bboxes = np.concatenate((bboxes.min(axis=1), bboxes.max(axis=1)), axis=1)
else:
# Switch to rotated rects
bboxes = np.asarray([list(fit_rbbox(box)) for box in bboxes], dtype=np.float32)

is_ambiguous = [False] * (len(boxes["boxes_1"]) + len(boxes["boxes_2"])) + [True] * len(boxes["boxes_3"])

self.data.append((img_path, dict(boxes=bboxes, flags=np.asarray(is_ambiguous))))

def __getitem__(
Expand Down
9 changes: 8 additions & 1 deletion doctr/datasets/funsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
self,
train: bool = True,
sample_transforms: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
rotated_bbox: bool = False,
**kwargs: Any,
) -> None:

Expand All @@ -60,8 +61,14 @@ def __init__(

_targets = [(word['text'], word['box']) for block in data['form']
for word in block['words'] if len(word['text']) > 0]

text_targets, box_targets = zip(*_targets)
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
if rotated_bbox:
# box_targets: xmin, ymin, xmax, ymax -> x, y, w, h, alpha = 0
box_targets = [
[
(box[0] + box[2]) / 2, (box[1] + box[3]) / 2, box[2] - box[0], box[3] - box[1], 0
] for box in box_targets
]

self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np.int), labels=text_targets)))

Expand Down
18 changes: 13 additions & 5 deletions doctr/datasets/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tensorflow as tf

from .core import AbstractDataset
from doctr.utils.geometry import fit_rbbox


__all__ = ['OCRDataset']
Expand All @@ -31,6 +32,7 @@ def __init__(
img_folder: str,
label_file: str,
sample_transforms: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
rotated_bbox: bool = False,
**kwargs: Any,
) -> None:

Expand All @@ -56,11 +58,17 @@ def __init__(
is_valid: List[bool] = []
box_targets: List[List[float]] = []
for box in file_dic["coordinates"]:
xs, ys = zip(*box)
box = [min(xs), min(ys), max(xs), max(ys)]
is_valid.append(box[0] < box[2] and box[1] < box[3])
if is_valid[-1]:
box_targets.append(box)
if rotated_bbox:
x, y, w, h, alpha = fit_rbbox(np.asarray(box, dtype=np.float32))
is_valid.append(w > 0 and h > 0)
if is_valid[-1]:
box_targets.append([x, y, w, h, alpha])
else:
xs, ys = zip(*box)
box = [min(xs), min(ys), max(xs), max(ys)]
is_valid.append(box[0] < box[2] and box[1] < box[3])
if is_valid[-1]:
box_targets.append(box)

text_targets = [word for word, _valid in zip(file_dic["string"], is_valid) if _valid]
self.data.append((img_name, dict(boxes=np.asarray(box_targets, dtype=np.float32), labels=text_targets)))
4 changes: 4 additions & 0 deletions doctr/datasets/sroie.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self,
train: bool = True,
sample_transforms: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
rotated_bbox: bool = False,
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Any,
) -> None:

Expand All @@ -47,6 +48,9 @@ def __init__(
self.sample_transforms = sample_transforms
self.train = train

if rotated_bbox:
raise NotImplementedError

# # List images
self.root = os.path.join(self._root, 'images')
self.data: List[Tuple[str, Dict[str, Any]]] = []
Expand Down
15 changes: 11 additions & 4 deletions doctr/documents/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import matplotlib.pyplot as plt
from typing import Tuple, Dict, List, Any, Optional

from doctr.utils.geometry import resolve_enclosing_bbox
from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
from doctr.utils.visualization import visualize_page
from doctr.utils.common_types import BoundingBox
from doctr.utils.repr import NestedObject
Expand Down Expand Up @@ -111,7 +111,9 @@ def __init__(
) -> None:
# Resolve the geometry using the smallest enclosing bounding box
if geometry is None:
geometry = resolve_enclosing_bbox([w.geometry for w in words])
# Check whether this is a rotated or straight box
box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 5 else resolve_enclosing_bbox
geometry = box_resolution_fn([w.geometry for w in words])

super().__init__(words=words)
self.geometry = geometry
Expand Down Expand Up @@ -146,7 +148,9 @@ def __init__(
if geometry is None:
line_boxes = [word.geometry for line in lines for word in line.words]
artefact_boxes = [artefact.geometry for artefact in artefacts]
geometry = resolve_enclosing_bbox(line_boxes + artefact_boxes)
box_resolution_fn = resolve_enclosing_rbbox if len(lines[0].geometry) == 5 else resolve_enclosing_bbox
geometry = box_resolution_fn(line_boxes + artefact_boxes)

super().__init__(lines=lines, artefacts=artefacts)
self.geometry = geometry

Expand Down Expand Up @@ -190,7 +194,9 @@ def render(self, block_break: str = '\n\n') -> str:
def extra_repr(self) -> str:
return f"dimensions={self.dimensions}"

def show(self, page: np.ndarray, interactive: bool = True, **kwargs) -> None:
def show(
self, page: np.ndarray, interactive: bool = True, **kwargs
) -> None:
"""Overlay the result on a given image

Args:
Expand Down Expand Up @@ -225,6 +231,7 @@ def show(self, pages: List[np.ndarray], **kwargs) -> None:

Args:
pages: list of images encoded as numpy arrays in uint8
rotation: display rotated_bboxes if True
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
"""
for img, result in zip(pages, self.pages):
result.show(img, **kwargs)
61 changes: 57 additions & 4 deletions doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import numpy as np
from typing import List
import cv2
import tensorflow as tf
from typing import List, Union

__all__ = ['extract_crops']
__all__ = ['extract_crops', 'extract_rcrops']


def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
def extract_crops(img: Union[np.ndarray, tf.Tensor], boxes: np.ndarray) -> List[Union[np.ndarray, tf.Tensor]]:
"""Created cropped images from list of bounding boxes

Args:
Expand All @@ -20,7 +22,6 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
Returns:
list of cropped images
"""

if boxes.shape[0] == 0:
return []
if boxes.shape[1] != 4:
Expand All @@ -35,3 +36,55 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
# Add last index
_boxes[2:] += 1
return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes]


def extract_rcrops(img: Union[np.ndarray, tf.Tensor], boxes: np.ndarray) -> List[np.ndarray]:
"""Created cropped images from list of rotated bounding boxes

Args:
img: input image
boxes: bounding boxes of shape (N, 5) where N is the number of boxes, and the relative
coordinates (x, y, w, h, alpha)

Returns:
list of cropped images
"""
if boxes.shape[0] == 0:
return []
if boxes.shape[1] != 5:
raise AssertionError("boxes are expected to be relative and in order (x, y, w, h, alpha)")

if isinstance(img, tf.Tensor):
img = img.numpy().astype(np.uint8)

# Project relative coordinates
_boxes = boxes.copy()
if _boxes.dtype != np.int:
_boxes[:, [0, 2]] *= img.shape[1]
_boxes[:, [1, 3]] *= img.shape[0]

crops = []
# Determine rotation direction (clockwise/counterclockwise)
# Angle coverage: [-90°, +90°], half of the quadrant
clockwise = False
if np.sum(boxes[:, 2]) > np.sum(boxes[:, 3]):
clockwise = True

for box in _boxes:
x, y, w, h, alpha = box.astype(np.float32)
src_pts = cv2.boxPoints(((x, y), (w, h), alpha))[1:, :]
# Preserve size
if clockwise:
dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1]], dtype=np.float32)
else:
dst_pts = np.array([[h - 1, 0], [h - 1, w - 1], [0, w - 1]], dtype=np.float32)
# The transformation matrix
M = cv2.getAffineTransform(src_pts, dst_pts)
# Warp the rotated rectangle
if clockwise:
crop = cv2.warpAffine(img, M, (int(w), int(h)))
else:
crop = cv2.warpAffine(img, M, (int(h), int(w)))
crops.append(crop)

return crops
Loading