Skip to content

Commit

Permalink
fix: Prevent sorting entities labeled in Document AI Workbench (#200)
Browse files Browse the repository at this point in the history
* Refactored Entity loading and changed sorting to only occur if all ids are digits.
---------
Co-authored-by: Holt Skinner <holtskinner@google.com>
  • Loading branch information
lababidi committed Nov 30, 2023
1 parent 378ebd6 commit d843e51
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
31 changes: 15 additions & 16 deletions google/cloud/documentai_toolbox/wrappers/document.py
Expand Up @@ -50,23 +50,22 @@ def _entities_from_shards(
List[Entity]:
a list of Entities.
"""
result = []
# Needed to load the correct page index for sharded documents.
page_offset = 0
for shard in shards:
entities = [
Entity(documentai_object=entity, page_offset=page_offset)
for entity in shard.entities
]
properties = [
Entity(documentai_object=prop, page_offset=page_offset)
for entity in shard.entities
for prop in entity.properties
]
result.extend(entities + properties)
page_offset += len(shard.pages)
result = [
Entity(
documentai_object=item,
# Needed to load the correct page index for sharded documents.
page_offset=sum(len(shard.pages) for shard in shards[:i]),
)
for i, shard in enumerate(shards)
for entity in shard.entities
for item in (entity, *entity.properties)
]

if len(result) > 1 and result[0].documentai_object.id:
# https://github.com/googleapis/python-documentai-toolbox/issues/199
# Only sort entities if the ids are all numeric.
# Document AI Workbench labeling outputs hexadecimal ids which should not be sorted.
# Sorting numeric ids is needed for backwards-compatible behavior.
if len(result) > 1 and all(item.documentai_object.id.isdigit() for item in result):
result.sort(key=lambda x: int(x.documentai_object.id))
return result

Expand Down
1 change: 1 addition & 0 deletions tests/unit/resources/hex_ids/patent.json

Large diffs are not rendered by default.

19 changes: 18 additions & 1 deletion tests/unit/test_document.py
Expand Up @@ -109,6 +109,7 @@ def create_document_with_images_without_bbox(get_bytes_images_mock):
doc = document.Document.from_gcs(
gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"
)
get_bytes_images_mock.assert_called_once()

del (
doc.entities[0]
Expand Down Expand Up @@ -169,7 +170,7 @@ def test_pages_from_shards():
assert page.page_number == page_index + 1


def test_entities_from_shard():
def test_entities_from_shards():
shards = []
for byte in get_bytes("tests/unit/resources/0"):
shards.append(documentai.Document.from_json(byte))
Expand All @@ -183,6 +184,22 @@ def test_entities_from_shard():
assert actual[1].normalized_text == "140 USD"


# For documents labeled in Document AI Workbench
def test_entities_from_shards_with_hex_ids():
    """Check entity loading for shards whose entity ids are hexadecimal.

    Parses every shard under ``tests/unit/resources/hex_ids``, passes the
    shards to ``document._entities_from_shards`` and verifies the id,
    mention text and type of the first two returned entities.
    """
    parsed_shards = [
        documentai.Document.from_json(raw)
        for raw in get_bytes("tests/unit/resources/hex_ids")
    ]

    actual = document._entities_from_shards(shards=parsed_shards)

    # (id, mention_text, type_) expected at positions 0 and 1, in that order.
    expected = [
        ("ef4fd8a921c0ea81", "453,945", "application_number"),
        ("ef4fd8a921c0e000", "G06F 1/26", "class_international"),
    ]
    for position, (entity_id, mention_text, entity_type) in enumerate(expected):
        entity = actual[position]
        assert entity.documentai_object.id == entity_id
        assert entity.mention_text == mention_text
        assert entity.type_ == entity_type


@mock.patch("google.cloud.documentai_toolbox.wrappers.document.documentai")
def test_get_batch_process_metadata_with_valid_operation(
mock_docai,
Expand Down

0 comments on commit d843e51

Please sign in to comment.