feat: Added form_fields_to_bigquery() method #104

Merged · 7 commits · Apr 17, 2023
33 changes: 16 additions & 17 deletions google/cloud/documentai_toolbox/utilities/gcs_utilities.py
@@ -26,33 +26,32 @@
 from google.cloud.documentai_toolbox import constants
 
 
-def _get_storage_client(module: str = None):
-    r"""Returns a Storage client with custom user agent header.
+def _get_client_info(module: str = None) -> client_info.ClientInfo:
+    r"""Returns a custom user agent header.
 
     Returns:
-        storage.Client.
+        client_info.ClientInfo.
 
     """
+    client_library_version = documentai_toolbox.__version__
+
     if module:
-        user_agent = (
-            f"{constants.USER_AGENT_PRODUCT}/{documentai_toolbox.__version__}-{module}"
-        )
+        client_library_version = f"{client_library_version}-{module}"
 
-        info = client_info.ClientInfo(
-            client_library_version=f"{documentai_toolbox.__version__}-{module}",
-            user_agent=user_agent,
-        )
-        return storage.Client(client_info=info)
+    return client_info.ClientInfo(
+        client_library_version=client_library_version,
+        user_agent=f"{constants.USER_AGENT_PRODUCT}/{client_library_version}",
+    )
 
-    user_agent = f"{constants.USER_AGENT_PRODUCT}/{documentai_toolbox.__version__}"
 
-    info = client_info.ClientInfo(
-        client_library_version=documentai_toolbox.__version__,
-        user_agent=user_agent,
-    )
+def _get_storage_client(module: str = None) -> storage.Client:
+    r"""Returns a Storage client with custom user agent header.
 
-    return storage.Client(client_info=info)
+    Returns:
+        storage.Client.
 
+    """
+    return storage.Client(client_info=_get_client_info(module))
 
 
 def get_bytes(gcs_bucket_name: str, gcs_prefix: str) -> List[bytes]:
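
With this refactor, user-agent construction lives in one place and is shared by the Storage client here and the BigQuery client introduced below. A minimal sketch of the resulting header values, assuming an illustrative version of "1.0.0" and product token "documentai-toolbox" (the real values come from documentai_toolbox.__version__ and constants.USER_AGENT_PRODUCT):

    from google.cloud.documentai_toolbox.utilities import gcs_utilities

    # Illustrative values only; see the assumptions above.
    info = gcs_utilities._get_client_info(module="document")
    print(info.client_library_version)  # 1.0.0-document
    print(info.user_agent)              # documentai-toolbox/1.0.0-document
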
182 changes: 152 additions & 30 deletions google/cloud/documentai_toolbox/wrappers/document.py
@@ -234,6 +234,106 @@ def _get_batch_process_metadata(
    return metadata


def _insert_into_dictionary_with_list(dic: Dict, key: str, value: str) -> Dict:
    r"""Inserts a value into a dictionary, collecting values for duplicate keys into a list.

    Args:
        dic (Dict):
            Required. The dictionary to insert into.
        key (str):
            Required. The key to create or append to.
        value (str):
            Required. The value to be inserted.

    Returns:
        Dict:
            The dictionary after adding the key-value pair.
    """
    existing_value = dic.get(key)

    if existing_value:
        # For duplicate keys, change the existing value
        # to a list if it is not one already.
        if not isinstance(existing_value, list):
            existing_value = [existing_value]

        existing_value.append(value)
        dic[key] = existing_value
    else:
        dic[key] = value

    return dic
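
A quick sketch of the duplicate-key behavior (the key and values are made up): the first insert stores a scalar, and any repeat promotes the value to a list:

    fields: Dict = {}
    fields = _insert_into_dictionary_with_list(fields, "line_item", "10.00")
    print(fields)  # {'line_item': '10.00'}
    fields = _insert_into_dictionary_with_list(fields, "line_item", "20.00")
    print(fields)  # {'line_item': ['10.00', '20.00']}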


def _bigquery_column_name(input_string: str) -> str:
    r"""Converts a string into a valid BigQuery column name.
    https://cloud.google.com/bigquery/docs/schemas#column_names

    Args:
        input_string (str):
            Required. The string to convert.
    Returns:
        str:
            The converted string.

    """
    char_map: Dict[str, str] = {
        r":|;|\(|\)|\[|\]|,|\.|\?|\!|\'|\n": "",
        r"/| ": "_",
        r"#": "num",
        r"@": "at",
    }

    for key, value in char_map.items():
        input_string = re.sub(key, value, input_string)

    return input_string.lower()
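
For example (the first case mirrors the unit test added below; the other two are invented inputs): punctuation is stripped, slashes and spaces become underscores, "#" and "@" are spelled out, and the result is lowercased:

    assert _bigquery_column_name("Phone #:") == "phone_num"
    assert _bigquery_column_name("Name/Date") == "name_date"
    assert _bigquery_column_name("Email @ work?") == "email_at_work"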


def _dict_to_bigquery(
    dic: Dict,
    dataset_name: str,
    table_name: str,
    project_id: Optional[str],
) -> bigquery.job.LoadJob:
    r"""Loads a dictionary into a BigQuery table.

    Args:
        dic (Dict):
            Required. The dictionary to insert.
        dataset_name (str):
            Required. Name of the BigQuery dataset.
        table_name (str):
            Required. Name of the BigQuery table.
        project_id (Optional[str]):
            Optional. Project ID containing the BigQuery table. If not passed,
            falls back to the default inferred from the environment.
    Returns:
        bigquery.job.LoadJob:
            The BigQuery LoadJob for adding the dictionary.

    """
    bq_client = bigquery.Client(
        project=project_id, client_info=gcs_utilities._get_client_info()
    )
    table_ref = bigquery.DatasetReference(
        project=project_id, dataset_id=dataset_name
    ).table(table_name)

    job_config = bigquery.LoadJobConfig(
        schema_update_options=[
            bigquery.SchemaUpdateOption.ALLOW_FIELD_ADDITION,
            bigquery.SchemaUpdateOption.ALLOW_FIELD_RELAXATION,
        ],
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
    )

    return bq_client.load_table_from_json(
        json_rows=[dic],
        destination=table_ref,
        job_config=job_config,
    )
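
A hypothetical call, to make the contract concrete (the dataset and table names are placeholders, and the dataset is assumed to exist already). Because the job permits field addition and relaxation over newline-delimited JSON, previously unseen keys can become new columns on later loads:

    job = _dict_to_bigquery(
        {"invoice_id": "1234", "total": "42.00"},
        dataset_name="my_dataset",
        table_name="extracted_fields",
        project_id=None,  # falls back to the environment's default project
    )
    job.result()  # the load job is asynchronous; result() blocks until it finishes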


@dataclasses.dataclass
class Document:
r"""Represents a wrapped `Document`.
@@ -476,6 +576,49 @@ def get_form_field_by_name(self, target_field: str) -> List[FormField]:

        return found_fields

    def form_fields_to_dict(self) -> Dict:
        r"""Returns a dictionary of the form fields in the document.

        Returns:
            Dict:
                The form fields, indexed by field name.

        """
        form_fields_dict: Dict = {}
        for p in self.pages:
            for form_field in p.form_fields:
                field_name = _bigquery_column_name(form_field.field_name)
                form_fields_dict = _insert_into_dictionary_with_list(
                    form_fields_dict, field_name, form_field.field_value
                )

        return form_fields_dict

    def form_fields_to_bigquery(
        self, dataset_name: str, table_name: str, project_id: Optional[str] = None
    ) -> bigquery.job.LoadJob:
        r"""Adds extracted form fields to a BigQuery table.

        Args:
            dataset_name (str):
                Required. Name of the BigQuery dataset.
            table_name (str):
                Required. Name of the BigQuery table.
            project_id (Optional[str]):
                Optional. Project ID containing the BigQuery table. If not passed,
                falls back to the default inferred from the environment.
        Returns:
            bigquery.job.LoadJob:
                The BigQuery LoadJob for adding the form fields.

        """
        return _dict_to_bigquery(
            self.form_fields_to_dict(),
            dataset_name,
            table_name,
            project_id,
        )
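
End to end, the new surface is two calls on a wrapped document. A sketch under assumed names (the bucket, prefix, dataset, and table are placeholders; the sample values echo the unit tests below):

    wrapped_document = Document.from_gcs(
        gcs_bucket_name="my-bucket", gcs_prefix="documentai/output/123456789/0"
    )

    fields = wrapped_document.form_fields_to_dict()
    print(fields)  # e.g. {'address': '24 Barney Lane', 'city': 'Towaco', ...}

    job = wrapped_document.form_fields_to_bigquery(
        dataset_name="my_dataset", table_name="form_fields", project_id="my-project"
    )
    print(f"Job ID: {job.job_id}")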

    def get_entity_by_type(self, target_type: str) -> List[Entity]:
        r"""Returns the list of Entities of target_type.

@@ -500,20 +643,10 @@ def entities_to_dict(self) -> Dict:
         """
         entities_dict: Dict = {}
         for entity in self.entities:
-            entity_type = entity.type_.replace("/", "_")
-
-            existing_entity = entities_dict.get(entity_type)
-            if not existing_entity:
-                entities_dict[entity_type] = entity.mention_text
-                continue
-
-            # For entities that can have multiple (e.g. line_item)
-            # Change Entity Type to a List
-            if not isinstance(existing_entity, list):
-                existing_entity = [existing_entity]
-
-            existing_entity.append(entity.mention_text)
-            entities_dict[entity_type] = existing_entity
+            entity_type = _bigquery_column_name(entity.type_)
+            entities_dict = _insert_into_dictionary_with_list(
+                entities_dict, entity_type, entity.mention_text
+            )
 
         return entities_dict
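
Note what the refactor preserves: entity types containing a slash (for example, a hypothetical "line_item/amount") still map to underscores, since "/" is in the helper's substitution map, and repeated types still collapse into lists:

    assert _bigquery_column_name("line_item/amount") == "line_item_amount"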

@@ -534,23 +667,12 @@ def entities_to_bigquery(
                 The BigQuery LoadJob for adding the entities.
 
         """
-        bq_client = bigquery.Client(project=project_id)
-        table_ref = bigquery.DatasetReference(
-            project=project_id, dataset_id=dataset_name
-        ).table(table_name)
-
-        job_config = bigquery.LoadJobConfig(
-            schema_update_options=[
-                bigquery.SchemaUpdateOption.ALLOW_FIELD_ADDITION,
-                bigquery.SchemaUpdateOption.ALLOW_FIELD_RELAXATION,
-            ],
-            source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
-        )
-
-        return bq_client.load_table_from_json(
-            json_rows=[self.entities_to_dict()],
-            destination=table_ref,
-            job_config=job_config,
+        return _dict_to_bigquery(
+            self.entities_to_dict(),
+            dataset_name,
+            table_name,
+            project_id,
         )

    def split_pdf(self, pdf_path: str, output_path: str) -> List[str]:
5 changes: 5 additions & 0 deletions samples/snippets/entities_to_bigquery_sample.py
@@ -42,6 +42,11 @@ def entities_to_bigquery_sample(
        dataset_name=dataset_name, table_name=table_name, project_id=project_id
    )

    # Also supported:
    # job = wrapped_document.form_fields_to_bigquery(
    #     dataset_name=dataset_name, table_name=table_name, project_id=project_id
    # )

    print("Document entities loaded into BigQuery")
    print(f"Job ID: {job.job_id}")
    print(f"Table: {job.destination.path}")
52 changes: 51 additions & 1 deletion tests/unit/test_document.py
@@ -186,7 +186,7 @@ def test_get_batch_process_metadata_with_no_metadata(mock_docai):
 
 
 @mock.patch("google.cloud.documentai_toolbox.wrappers.document.documentai")
-def test_document_from_batch_process_operation_with_invalid_metadata_type(mock_docai):
+def test_get_batch_process_metadata_with_invalid_metadata_type(mock_docai):
     with pytest.raises(
         ValueError,
         match="Operation metadata type is not",
@@ -206,6 +206,19 @@ def test_document_from_batch_process_operation_with_invalid_metadata_type(mock_docai):
        document._get_batch_process_metadata(location, operation_name)


def test_bigquery_column_name():
    string_map = {
        "Phone #:": "phone_num",
        "Emergency Contact:": "emergency_contact",
        "Marital Status:": "marital_status",
        "Are you currently taking any medication? (If yes, please describe):": "are_you_currently_taking_any_medication_if_yes_please_describe",
        "Describe your medical concerns (symptoms, diagnoses, etc):": "describe_your_medical_concerns_symptoms_diagnoses_etc",
    }

    for key, value in string_map.items():
        assert document._bigquery_column_name(key) == value


def test_document_from_document_path_with_single_shard():
    actual = document.Document.from_document_path(
        document_path="tests/unit/resources/0/toolbox_invoice_test-0.json"
@@ -401,6 +414,43 @@ def test_get_form_field_by_name(get_bytes_form_parser_mock):

    assert actual[0].field_value == "(906) 917-3486"


def test_form_fields_to_dict(get_bytes_form_parser_mock):
    doc = document.Document.from_gcs(
        gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"
    )
    actual = doc.form_fields_to_dict()

    get_bytes_form_parser_mock.assert_called_once()

    assert len(actual) == 17
    assert actual.get("address") == "24 Barney Lane"
    assert actual.get("city") == "Towaco"


@mock.patch("google.cloud.documentai_toolbox.wrappers.document.bigquery")
def test_form_fields_to_bigquery(mock_bigquery, get_bytes_form_parser_mock):
    client = mock_bigquery.Client.return_value

    mock_table = mock.Mock()
    client.dataset.table.return_value = mock_table

    mock_load_job = mock.Mock()
    client.load_table_from_json.return_value = mock_load_job

    doc = document.Document.from_gcs(
        gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"
    )

    actual = doc.form_fields_to_bigquery(
        dataset_name="test_dataset", table_name="test_table", project_id="test_project"
    )

    get_bytes_form_parser_mock.assert_called_once()
    mock_bigquery.Client.assert_called_once()

    assert actual


def test_entities_to_dict(get_bytes_single_file_mock):
    doc = document.Document.from_gcs(
        gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"