
Commit

feat: Add `entities_to_dict()` and `entities_to_bigquery()` to `Document` wrapper (#50)

* feat: Add `entities_to_dict()` and `entities_to_bigquery()` to Document wrapper

- Uploads extracted entities to an existing BigQuery dataset and creates a new table if one doesn't already exist (see the usage sketch after the example table below).

## Example Output Table

| supplier_iban | purchase_order | supplier_email      | freight_amount | supplier_address             | receiver_address             | total_amount | supplier_name | total_tax_amount | payment_terms    | line_item                                                                     | receiver_name | receiver_email    | due_date   | invoice_date | invoice_id | currency | receiver_tax_id | net_amount | vat |
|---------------|----------------|---------------------|----------------|------------------------------|------------------------------|--------------|---------------|------------------|------------------|-------------------------------------------------------------------------------|---------------|-------------------|------------|--------------|------------|----------|-----------------|------------|-----|
|            50 |              1 | user@companyabc.com |            600 | 111 Main Street Anytown, USA | 222 Main Street Anytown, USA |         2140 | Company ABC   |              140 | 6 month contract | [Tool A 500 1.00 500.00,Service B 1 900.00 900.00,Resource C 50 12.00 600.00] | John Doe      | johndoe@email.com | 2025-01-01 |   1970-01-01 | NO. 001    | $        |               1 |       2000 | 140 |
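
A minimal usage sketch of the two new methods (not part of this commit; the bucket, prefix, dataset, table, and project values below are placeholders):

```python
from google.cloud.documentai_toolbox import document

# Wrap a processed Document AI output stored in GCS (placeholder bucket/prefix).
wrapped_document = document.Document.from_gcs(
    gcs_bucket_name="bucket", gcs_prefix="path/to/folder"
)

# Entities are keyed by type_ (with "/" replaced by "_"); repeated types such as
# line_item are collapsed into a list of mention_text values.
entities = wrapped_document.entities_to_dict()
print(entities.get("supplier_name"))  # e.g. "Company ABC"
print(entities.get("line_item"))      # e.g. ["Tool A 500 1.00 500.00", ...]

# Load the same dictionary into BigQuery as a single row; the table is created
# in the existing dataset if it does not already exist.
job = wrapped_document.entities_to_bigquery(
    dataset_name="test_dataset",
    table_name="test_table",
    project_id="YOUR_PROJECT_ID",
)
print(f"Job ID: {job.job_id}")
```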

* Removed unneeded test code

* Added bigquery library to setup.py

* Updated Docstrings

* Fixed Test import linter error

* Added BigQuery library to testing constraints

* Added handling of Nested Entities (properties)

* Dependency Update for Tests

* Update Dependencies

* Fixed Test Output

* Updated DatasetReference based on Deprecation Warning

* samples: Added Entities to BigQuery Sample Code

* Added Required tag to `entities_to_bigquery()` arguments

* Fixed Issues from merge conflict

* Fixed numpy import

---------

Co-authored-by: Gal Zahavi <38544478+galz10@users.noreply.github.com>
holtskinner and galz10 committed Feb 15, 2023
1 parent 8359911 commit 494fa86
Showing 15 changed files with 255 additions and 12 deletions.
66 changes: 66 additions & 0 deletions google/cloud/documentai_toolbox/wrappers/document.py
@@ -21,6 +21,7 @@
from typing import Dict, List, Optional

from google.api_core import client_info
from google.cloud import bigquery
from google.cloud import documentai
from google.cloud import storage
from google.cloud import documentai_toolbox
@@ -50,6 +51,8 @@ def _entities_from_shards(
for shard in shards:
for entity in shard.entities:
result.append(Entity(documentai_entity=entity))
for prop in entity.properties:
result.append(Entity(documentai_entity=prop))
return result


@@ -368,6 +371,69 @@ def get_entity_by_type(self, target_type: str) -> List[Entity]:
"""
return [entity for entity in self.entities if entity.type_ == target_type]

def entities_to_dict(self) -> Dict:
r"""Returns Dictionary of entities in document.
Returns:
Dict:
The Dict of the entities indexed by type.
"""
entities_dict: Dict = {}
for entity in self.entities:
entity_type = entity.type_.replace("/", "_")

existing_entity = entities_dict.get(entity_type)
if not existing_entity:
entities_dict[entity_type] = entity.mention_text
continue

# For entities that can have multiple (e.g. line_item)
# Change Entity Type to a List
if not isinstance(existing_entity, list):
existing_entity = [existing_entity]

existing_entity.append(entity.mention_text)
entities_dict[entity_type] = existing_entity

return entities_dict

def entities_to_bigquery(
self, dataset_name: str, table_name: str, project_id: Optional[str] = None
) -> bigquery.job.LoadJob:
r"""Adds extracted entities to a BigQuery table.
Args:
dataset_name (str):
Required. Name of the BigQuery dataset.
table_name (str):
Required. Name of the BigQuery table.
project_id (Optional[str]):
Optional. Project ID containing the BigQuery table. If not passed, falls back to the default inferred from the environment.
Returns:
bigquery.job.LoadJob:
The BigQuery LoadJob for adding the entities.
"""
bq_client = bigquery.Client(project=project_id)
table_ref = bigquery.DatasetReference(
project=project_id, dataset_id=dataset_name
).table(table_name)

job_config = bigquery.LoadJobConfig(
schema_update_options=[
bigquery.SchemaUpdateOption.ALLOW_FIELD_ADDITION,
bigquery.SchemaUpdateOption.ALLOW_FIELD_RELAXATION,
],
source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
)

return bq_client.load_table_from_json(
json_rows=[self.entities_to_dict()],
destination=table_ref,
job_config=job_config,
)

def split_pdf(self, pdf_path: str, output_path: str) -> List[str]:
r"""Splits local PDF file into multiple PDF files based on output from a Splitter/Classifier processor.
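Note that `entities_to_bigquery()` returns the `bigquery.job.LoadJob` without waiting for the load to finish. A caller that needs the row committed before querying can block on the job — a sketch using the same placeholder document and BigQuery names as above:

```python
from google.cloud import bigquery
from google.cloud.documentai_toolbox import document

wrapped_document = document.Document.from_gcs(
    gcs_bucket_name="bucket", gcs_prefix="path/to/folder"
)

job = wrapped_document.entities_to_bigquery(
    dataset_name="test_dataset",
    table_name="test_table",
    project_id="YOUR_PROJECT_ID",
)

# result() blocks until the load job completes and raises if the job failed.
job.result()

# The loaded row can then be read back with a standard query.
client = bigquery.Client(project="YOUR_PROJECT_ID")
rows = client.query(
    "SELECT supplier_name, total_amount FROM test_dataset.test_table"
).result()
for row in rows:
    print(dict(row))
```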
7 changes: 7 additions & 0 deletions google/cloud/documentai_toolbox/wrappers/entity.py
@@ -37,13 +37,20 @@ class Entity:
documentai_entity: documentai.Document.Entity = dataclasses.field(repr=False)
type_: str = dataclasses.field(init=False)
mention_text: str = dataclasses.field(init=False, default="")
normalized_text: str = dataclasses.field(init=False, default="")
# Only Populated for Splitter/Classifier Output
start_page: int = dataclasses.field(init=False)
end_page: int = dataclasses.field(init=False)

def __post_init__(self):
self.type_ = self.documentai_entity.type_
self.mention_text = self.documentai_entity.mention_text
if (
self.documentai_entity.normalized_value
and self.documentai_entity.normalized_value.text
):
self.normalized_text = self.documentai_entity.normalized_value.text

if self.documentai_entity.page_anchor.page_refs:
self.start_page = int(self.documentai_entity.page_anchor.page_refs[0].page)
self.end_page = int(self.documentai_entity.page_anchor.page_refs[-1].page)
50 changes: 50 additions & 0 deletions samples/snippets/entities_to_bigquery_sample.py
@@ -0,0 +1,50 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


# [START documentai_toolbox_entities_to_bigquery]

from google.cloud.documentai_toolbox import document

# TODO(developer): Uncomment these variables before running the sample.
# Given a document.proto or sharded document.proto in path gs://bucket/path/to/folder
# gcs_bucket_name = "bucket"
# gcs_prefix = "path/to/folder"
# dataset_name = "test_dataset"
# table_name = "test_table"
# project_id = "YOUR_PROJECT_ID"


def entities_to_bigquery_sample(
gcs_bucket_name: str,
gcs_prefix: str,
dataset_name: str,
table_name: str,
project_id: str,
) -> None:
wrapped_document = document.Document.from_gcs(
gcs_bucket_name=gcs_bucket_name, gcs_prefix=gcs_prefix
)

job = wrapped_document.entities_to_bigquery(
dataset_name=dataset_name, table_name=table_name, project_id=project_id
)

print("Document entities loaded into BigQuery")
print(f"Job ID: {job.job_id}")
print(f"Table: {job.destination.path}")


# [END documentai_toolbox_entities_to_bigquery]
3 changes: 2 additions & 1 deletion samples/snippets/requirements-test.txt
@@ -1,2 +1,3 @@
pytest==7.2.1
mock==5.0.1
mock==5.0.1
google-cloud-bigquery==3.5.0
1 change: 1 addition & 0 deletions samples/snippets/requirements.txt
@@ -1,3 +1,4 @@
google-cloud-bigquery==3.5.0
google-cloud-documentai==2.12.0
google-cloud-storage==2.7.0
google-cloud-documentai-toolbox==0.1.1a0
54 changes: 54 additions & 0 deletions samples/snippets/test_entities_to_bigquery_sample.py
@@ -0,0 +1,54 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import uuid

import pytest
from samples.snippets import entities_to_bigquery_sample

from google.cloud import bigquery

location = "us"
project_id = os.environ["GOOGLE_CLOUD_PROJECT"]
gcs_bucket_name = "documentai_toolbox_samples"
gcs_input_uri = "output/123456789/0"
dataset_name = f"document_ai_toolbox_test_{uuid.uuid4().hex}"
table_name = f"test_table_{uuid.uuid4().hex}"


def test_entities_to_bigquery_sample(capsys: pytest.CaptureFixture) -> None:
client = bigquery.Client(project=project_id)
dataset = bigquery.Dataset(f"{project_id}.{dataset_name}")
dataset.location = "US"
dataset = client.create_dataset(dataset, timeout=30, exists_ok=True)

entities_to_bigquery_sample.entities_to_bigquery_sample(
gcs_bucket_name=gcs_bucket_name,
gcs_prefix=gcs_input_uri,
dataset_name=dataset_name,
table_name=table_name,
project_id=project_id,
)
out, _ = capsys.readouterr()

assert "Document entities loaded into BigQuery" in out
assert "Job ID:" in out
assert (
f"Table: /projects/{project_id}/datasets/{dataset_name}/tables/{table_name}"
in out
)

client.delete_dataset(dataset)
4 changes: 2 additions & 2 deletions samples/snippets/test_quickstart_sample.py
@@ -1,4 +1,4 @@
# Copyright 2020 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,4 +31,4 @@ def test_quickstart_sample(capsys: pytest.CaptureFixture) -> None:
out, _ = capsys.readouterr()

assert "Number of Pages: 1" in out
assert "Number of Entities: 22" in out
assert "Number of Entities: 35" in out
1 change: 1 addition & 0 deletions setup.py
@@ -49,6 +49,7 @@
"proto-plus >= 1.22.0, <2.0.0dev",
"proto-plus >= 1.22.2, <2.0.0dev; python_version>='3.11'",
"grpc-google-iam-v1 >= 0.12.4, < 0.13dev",
"google-cloud-bigquery >= 3.5.0, < 4.0.0dev",
"google-cloud-documentai >= 1.2.1, < 3.0.0dev",
"google-cloud-storage >= 1.31.0, < 3.0.0dev",
"numpy >= 1.18.1",
2 changes: 2 additions & 0 deletions testing/constraints-3.10.txt
@@ -6,6 +6,8 @@ libcst
pandas
proto-plus
grpc-google-iam-v1
google-cloud-bigquery
google-cloud-documentai
google-cloud-storage
numpy
pikepdf
2 changes: 2 additions & 0 deletions testing/constraints-3.11.txt
@@ -6,6 +6,8 @@ libcst
pandas
proto-plus
grpc-google-iam-v1
google-cloud-bigquery
google-cloud-documentai
google-cloud-storage
numpy
pikepdf
13 changes: 7 additions & 6 deletions testing/constraints-3.7.txt
@@ -4,12 +4,13 @@
# Pin the version to the lower bound.
# e.g., if setup.py has "google-cloud-foo >= 1.14.0, < 2.0.0dev",
# Then this file should have google-cloud-foo==1.14.0
google-api-core==1.31.5
libcst== 0.2.5
pandas== 1.0.0
proto-plus== 1.22.0
google-api-core==1.34.0
libcst==0.2.5
pandas==1.0.0
proto-plus==1.22.0
grpc-google-iam-v1==0.12.4
google-cloud-documentai==1.2.1
google-cloud-storage== 1.31.0
google-cloud-bigquery==3.5.0
google-cloud-documentai==2.12.0
google-cloud-storage==2.7.0
numpy==1.18.1
pikepdf==6.2.9
2 changes: 2 additions & 0 deletions testing/constraints-3.8.txt
@@ -6,6 +6,8 @@ libcst
pandas
proto-plus
grpc-google-iam-v1
google-cloud-bigquery
google-cloud-documentai
google-cloud-storage
numpy
pikepdf
2 changes: 2 additions & 0 deletions testing/constraints-3.9.txt
@@ -6,6 +6,8 @@ libcst
pandas
proto-plus
grpc-google-iam-v1
google-cloud-bigquery
google-cloud-documentai
google-cloud-storage
numpy
pikepdf
45 changes: 43 additions & 2 deletions tests/unit/test_document.py
@@ -104,6 +104,9 @@ def test_entities_from_shard():

assert actual[0].mention_text == "$140.00"
assert actual[0].type_ == "vat"
assert actual[1].mention_text == "$140.00"
assert actual[1].type_ == "vat/tax_amount"
assert actual[1].normalized_text == "140 USD"


def test_document_from_document_path_with_single_shard():
@@ -114,7 +117,9 @@ def test_document_from_document_path_with_single_shard():


def test_document_from_documentai_document_with_single_shard():
with open("tests/unit/resources/0/toolbox_invoice_test-0.json", "r") as f:
with open(
"tests/unit/resources/0/toolbox_invoice_test-0.json", "r", encoding="utf-8"
) as f:
doc = documentai.Document.from_json(f.read())

actual = document.Document.from_documentai_document(documentai_document=doc)
@@ -360,12 +365,48 @@ def test_get_form_field_by_name(get_bytes_form_parser_mock):
assert actual[0].field_value == "(906) 917-3486"


def test_entities_to_dict(get_bytes_single_file_mock):
doc = document.Document.from_gcs(
gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"
)
actual = doc.entities_to_dict()

get_bytes_single_file_mock.assert_called_once()

assert len(actual) == 25
assert actual.get("vat") == "$140.00"
assert actual.get("vat_tax_amount") == "$140.00"


@mock.patch("google.cloud.documentai_toolbox.wrappers.document.bigquery")
def test_entities_to_bigquery(mock_bigquery, get_bytes_single_file_mock):
client = mock_bigquery.Client.return_value

mock_table = mock.Mock()
client.dataset.table.return_value = mock_table

mock_load_job = mock.Mock()
client.load_table_from_json.return_value = mock_load_job

doc = document.Document.from_gcs(
gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"
)

actual = doc.entities_to_bigquery(
dataset_name="test_dataset", table_name="test_table", project_id="test_project"
)

get_bytes_single_file_mock.assert_called_once()
mock_bigquery.Client.assert_called_once()

assert actual


@mock.patch("google.cloud.documentai_toolbox.wrappers.document.Pdf")
def test_split_pdf(mock_Pdf, get_bytes_splitter_mock):
doc = document.Document.from_gcs(
gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"
)

mock_input_file = mock.Mock()
mock_Pdf.open.return_value.__enter__.return_value.name = mock_input_file
