Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add obj-det pipeline support for LayoutLMV2 #13622

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
302831e
Add obj-det pipe support for LayoutLMV2
mishig25 Sep 17, 2021
2a265ef
Chore
mishig25 Sep 17, 2021
fd7d47b
Update src/transformers/models/layoutlmv2/feature_extraction_layoutlm…
mishig25 Sep 17, 2021
6b44c80
Add box normalization comments
mishig25 Sep 20, 2021
0964a63
Update src/transformers/models/layoutlmv2/feature_extraction_layoutlm…
mishig25 Sep 20, 2021
18d3805
Update src/transformers/models/layoutlmv2/feature_extraction_layoutlm…
mishig25 Sep 20, 2021
ad7d9b2
Update src/transformers/models/layoutlmv2/feature_extraction_layoutlm…
mishig25 Sep 20, 2021
a1267ca
Update src/transformers/models/layoutlmv2/feature_extraction_layoutlm…
mishig25 Sep 20, 2021
af6886b
Update src/transformers/models/layoutlmv2/feature_extraction_layoutlm…
mishig25 Sep 20, 2021
61c7df2
Chore
mishig25 Sep 20, 2021
85e40a4
Merge branch 'layout_detection' of github.com:mishig25/transformers i…
mishig25 Sep 20, 2021
59ec807
Add comments and make fixup
mishig25 Sep 20, 2021
6796027
Merge branch 'master' into layout_detection
mishig25 Sep 20, 2021
64c2df5
Refacotr is_subwords
mishig25 Sep 20, 2021
52bcdbe
Update src/transformers/models/auto/modeling_auto.py
mishig25 Sep 20, 2021
12d3205
Add post_process test and better offset naming
mishig25 Sep 21, 2021
6f69e80
Merge branch 'layout_detection' of github.com:mishig25/transformers i…
mishig25 Sep 21, 2021
f3248a6
Chore
mishig25 Sep 21, 2021
0704a05
Set Obj Det default threhsold 0.5
mishig25 Sep 21, 2021
9a8eadf
make fixup
mishig25 Sep 21, 2021
05db577
Rm print statements
mishig25 Sep 21, 2021
6589fb0
Update src/transformers/models/layoutlmv2/feature_extraction_layoutlm…
mishig25 Sep 23, 2021
01509bf
Better variable names in post_process
mishig25 Sep 23, 2021
3e3d626
Fix wrong comment
mishig25 Sep 23, 2021
df86d51
Extend run_pipeline_test test
mishig25 Sep 23, 2021
2cedc54
Fix tests
mishig25 Sep 23, 2021
d8bdd19
Merge branch 'master' into layout_detection
mishig25 Sep 23, 2021
3825a41
Add require backends to tests
mishig25 Sep 24, 2021
5ac49f8
Update get_config to use smaller resent
mishig25 Sep 24, 2021
907c4f4
Rm dev change
mishig25 Sep 24, 2021
c93385e
Disabling tests for bimodal models on pipelines that do not support it.
Narsil Sep 24, 2021
6004132
Merge branch 'master' into layout_detection
mishig25 Sep 24, 2021
defd574
Making all tests pass attempt #1.
Narsil Sep 25, 2021
e610734
Rm layout specific decorators from run_pipeline_ts
mishig25 Oct 4, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@
[
# Model for Object Detection mapping
("detr", "DetrForObjectDetection"),
# layoutlmv2 is able to detect entities (tokens) within an image with an OCR which can be interpreted as objects
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
]
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...file_utils import TensorType, is_pytesseract_available, requires_backends
from ...file_utils import TensorType, is_pytesseract_available, is_torch_available, requires_backends
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import logging


if is_torch_available():
import torch
from torch import nn

Comment on lines +30 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you still import LayoutLMv2FeatureExtractor when you don't have torch installed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With lazy loading, all classes are None if the various requirements are not met I think

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check this @mishig25?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this answers your question. Added import torch lines because the newly added post_process method depends on torch

def post_process(self, outputs, target_sizes, offset_mapping, bbox):
"""
Converts the output of :class:`~transformers.LayoutLMv2ForTokenClassification` into the format expected by the
COCO api. Only supports PyTorch.
Args:

Should I move the import torch statement inside post_process method ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc'ing @LysandreJik here to check what's the best option

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep the import torch statement at the top level. Moving it inside the post_process function means it might crash once a program is well into progress, which can be painful. I'd rather it fail early on.

# soft dependency
if is_pytesseract_available():
import pytesseract
Expand All @@ -39,6 +43,7 @@


def normalize_box(box, width, height):
# box values are normalized in the range 0-1000
return [
int(1000 * (box[0] / width)),
int(1000 * (box[1] / height)),
Expand All @@ -47,6 +52,16 @@ def normalize_box(box, width, height):
]


def unnormalize_box(box, height, width):
# box values are normalized in the range 0-1000
return [
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
int(width * (box[0] / 1000)),
int(height * (box[1] / 1000)),
int(width * (box[2] / 1000)),
int(height * (box[3] / 1000)),
]


def apply_tesseract(image: Image.Image):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""

Expand Down Expand Up @@ -220,3 +235,52 @@ def __call__(
encoded_inputs["boxes"] = boxes_batch

return encoded_inputs

def post_process(self, outputs, target_sizes, offset_mapping, bbox):
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
"""
Converts the output of :class:`~transformers.LayoutLMv2ForTokenClassification` into the format expected by the
COCO api. Only supports PyTorch.

Args:
outputs (:obj:`Dict`):
Raw outputs of the model.
target_sizes (:obj:`torch.Tensor` of shape :obj:`(batch_size, 2)`):
Tensor containing the original size (h, w) of each image of the batch.
offset_mapping (:obj:`torch.Tensor` of shape :obj:`(batch_size, x, 2)`):
Tensor coming from the "offset_mapping" field of the outputs of
:class:`~transformer.LayoutLMv2TokenizerFast`.
bbox (:obj:`torch.Tensor` of shape :obj:`(batch_size, x, 4)`):
Tensor coming from the "bbox" field of the outputs of :class:`~transformer.LayoutLMv2TokenizerFast`.

Returns:
:obj:`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an
image in the batch as predicted by the model.
"""
out_logits = outputs.logits

assert len(out_logits) == len(
target_sizes
), "Make sure that you pass in as many target sizes as the batch dimension of the logits"
assert (
target_sizes.shape[1] == 2
), "Each element of target_sizes must contain the size (h, w) of each image of the batch"

prob = nn.functional.softmax(out_logits, -1)
scores, labels = prob[..., :-1].max(-1)
scores = scores.tolist()
labels = labels.tolist()
boxes = bbox.tolist()
target_sizes = target_sizes.tolist()
offsets = offset_mapping.tolist()

results = []

for s, l, b, o, (height, width) in zip(scores, labels, boxes, offsets, target_sizes):
# only keep start of a particular word (i.e. offset start == 0) (e.g. [San, Fran, ##cis, ##co] -> [San, Fran])
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
l = [label for label, (offset_start, _) in zip(l, o) if offset_start == 0]
s = [score for score, (offset_start, _) in zip(s, o) if offset_start == 0]
b = [unnormalize_box(box, height, width) for box, (offset_start, _) in zip(b, o) if offset_start == 0]
results.append(
{"scores": torch.FloatTensor(s), "labels": torch.IntTensor(l), "boxes": torch.FloatTensor(b)}
)
return results
35 changes: 29 additions & 6 deletions src/transformers/pipelines/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:

The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
same format: all as HTTP(S) links, all as local paths, or all as PIL images.
threshold (:obj:`float`, `optional`, defaults to 0.9):
threshold (:obj:`float`, `optional`, defaults to 0.5):
The probability necessary to make a prediction.

Return:
Expand All @@ -108,22 +108,42 @@ def preprocess(self, image):
image = self.load_image(image)
target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt")
if self._is_layout_detection():
encoded_inputs = self.tokenizer(
text=inputs["words"],
boxes=inputs["boxes"],
truncation=True,
return_offsets_mapping=True,
return_tensors="pt",
)
encoded_inputs["image"] = inputs.pop("pixel_values")
encoded_inputs["target_size"] = target_size
return encoded_inputs
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
inputs["target_size"] = target_size
return inputs

def _forward(self, model_inputs):
target_size = model_inputs.pop("target_size")
offset_mapping = model_inputs.pop("offset_mapping", None)
outputs = self.model(**model_inputs)
model_outputs = {"outputs": outputs, "target_size": target_size}
if offset_mapping is not None:
model_outputs["bbox"] = model_inputs["bbox"]
model_outputs["offset_mapping"] = offset_mapping
return model_outputs

def postprocess(self, model_outputs, threshold=0.9):
raw_annotations = self.feature_extractor.post_process(model_outputs["outputs"], model_outputs["target_size"])
def postprocess(self, model_outputs, threshold=0.5):
post_process_inputs = [model_outputs["outputs"], model_outputs["target_size"]]
offset_mapping = model_outputs.get("offset_mapping", None)
bbox = model_outputs.get("bbox", None)
if offset_mapping is not None and bbox is not None:
post_process_inputs.extend([offset_mapping, bbox])
raw_annotations = self.feature_extractor.post_process(*post_process_inputs)
raw_annotation = raw_annotations[0]

scores, labels, boxes = [raw_annotation[key] for key in ["scores", "labels", "boxes"]]
keep = raw_annotation["scores"] > threshold
scores = raw_annotation["scores"][keep]
labels = raw_annotation["labels"][keep]
boxes = raw_annotation["boxes"][keep]
scores, labels, boxes = scores[keep], labels[keep], boxes[keep]

raw_annotation["scores"] = scores.tolist()
raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
Expand Down Expand Up @@ -158,3 +178,6 @@ def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
"ymax": ymax,
}
return bbox

def _is_layout_detection(self) -> bool:
return self.model.__class__.__name__.endswith("ForTokenClassification")
39 changes: 38 additions & 1 deletion tests/test_feature_extraction_layoutlmv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np

from transformers.file_utils import is_pytesseract_available, is_torch_available
from transformers.file_utils import ModelOutput, is_pytesseract_available, is_torch_available
from transformers.testing_utils import require_pytesseract, require_torch

from .test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
Expand Down Expand Up @@ -219,3 +219,40 @@ def test_layoutlmv2_integration_test(self):
224,
),
)

def test_post_process(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# Initialize test inputs
outputs = ModelOutput(
logits=torch.FloatTensor(
[
[
[2.9213, -0.6988, -0.3930, -0.7230, -0.6250, -0.8034, -0.2588],
[3.3945, -0.7632, -0.8638, -0.6346, -0.7023, -0.6832, -0.5296],
[3.3949, -0.7318, -0.8169, -0.6397, -0.6878, -0.7218, -0.5687],
]
]
)
)
target_sizes = torch.IntTensor([[1000, 754]])
offset_mapping = torch.IntTensor([[[0, 0], [0, 1], [0, 2]]])
bbox = torch.IntTensor([[[0, 0, 0, 0], [632, 59, 647, 67], [137, 88, 168, 98]]])

outputs = feature_extractor.post_process(outputs, target_sizes, offset_mapping, bbox)

self.assertTrue(isinstance(outputs, list))
self.assertEqual(len(outputs), 1)
output = outputs[0]
self.assertEqual(
output["scores"].shape,
(3,),
)
self.assertEqual(
output["labels"].shape,
(3,),
)
self.assertEqual(
output["boxes"].shape,
(3, 4),
)
64 changes: 64 additions & 0 deletions tests/test_pipelines_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MODEL_FOR_OBJECT_DETECTION_MAPPING,
AutoFeatureExtractor,
AutoModelForObjectDetection,
AutoTokenizer,
ObjectDetectionPipeline,
is_vision_available,
pipeline,
Expand All @@ -26,6 +27,8 @@
is_pipeline_test,
nested_simplify,
require_datasets,
require_detectron2,
require_pytesseract,
require_tf,
require_timm,
require_torch,
Expand Down Expand Up @@ -251,3 +254,64 @@ def test_threshold(self):
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
],
)

@require_detectron2
@require_datasets
@require_pytesseract
@require_torch
@slow
def test_layout_architecture(self):
import datasets

dataset = datasets.load_dataset("nielsr/funsd", split="test")

object_detector = pipeline("object-detection", model="nielsr/layoutlmv2-finetuned-funsd")

outputs = object_detector([dataset[0]["image_path"], dataset[1]["image_path"]])
outputs = [o[:2] for o in outputs] # trimming the output

self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.8447, "label": "other", "box": {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0}},
{"score": 0.9083, "label": "other", "box": {"xmin": 476, "ymin": 59, "xmax": 487, "ymax": 67}},
],
[
{"score": 0.7886, "label": "other", "box": {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0}},
{"score": 0.9086, "label": "other", "box": {"xmin": 85, "ymin": 78, "xmax": 135, "ymax": 87}},
],
],
)

@require_detectron2
@require_datasets
@require_pytesseract
@require_torch
@slow
def test_layout_prcoessing(self):
model_id = "nielsr/layoutlmv2-finetuned-funsd"

model = AutoModelForObjectDetection.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
object_detector = ObjectDetectionPipeline(
model=model, feature_extractor=feature_extractor, tokenizer=tokenizer
)
mishig25 marked this conversation as resolved.
Show resolved Hide resolved

import datasets

dataset = datasets.load_dataset("nielsr/funsd", split="test")
doc_img = dataset[0]["image_path"]

pipe_inputs = object_detector.preprocess(doc_img)
self.assertEqual(
set(pipe_inputs.keys()),
{"input_ids", "token_type_ids", "attention_mask", "offset_mapping", "bbox", "image", "target_size"},
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
)

model_outputs = object_detector.forward(pipe_inputs)
self.assertEqual(
set(model_outputs.keys()),
{"outputs", "target_size", "bbox", "offset_mapping"},
)