Skip to content

Commit

Permalink
feat: added get_document and list_document functions (#7)
Browse files Browse the repository at this point in the history
* feat: add get_document and list_document functions

* fixed tests

* lint fix

* added tests and changed DocumentWrapper

* fixed failing test

* updated tests

* changed DocStrings and tests

* fixed failing test

* changed name and return type of list_documents

* updated failing tests

* lint fix

* updated get_document name to get_shards
  • Loading branch information
galz10 committed Oct 3, 2022
1 parent e360dce commit b5ac4ca
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 9 deletions.
57 changes: 54 additions & 3 deletions google/cloud/documentai_toolbox/wrappers/document_wrapper.py
Expand Up @@ -60,8 +60,8 @@ def _get_bytes(output_bucket: str, output_prefix: str) -> List[bytes]:
return result


def _read_output(gcs_prefix: str) -> List[documentai.Document]:
"""Returns a list of Document shards."""
def _get_shards(gcs_prefix: str) -> List[documentai.Document]:
"""Gets shards from gcs_prefix location and returns a list of shards."""

shards = []

Expand All @@ -85,6 +85,57 @@ def _read_output(gcs_prefix: str) -> List[documentai.Document]:
return shards


def print_gcs_document_tree(gcs_prefix: str) -> None:
    """Prints a tree of Documents in gcs_prefix location.

    Blobs listed under the prefix are grouped by their parent path. Each
    group is printed as the parent path followed by its file names. When a
    group holds more than four files, at most five leading names are shown,
    followed by a "...." marker line and the final name.

    Args:
        gcs_prefix: Location in the form ``gs://<bucket>/<prefix>``. The
            prefix part must not contain a ``.``.

    Raises:
        ValueError: If gcs_prefix does not match the ``gs://`` format, or if
            its prefix part contains a ``.`` (treated as a file name).
    """
    display_filename_prefix_middle = "├──"
    display_filename_prefix_last = "└──"
    # At most this many names are printed before the final one.
    max_leading_files = 5
    # Groups with more files than this get the "...." marker line.
    marker_threshold = 4

    match = re.match(r"gs://(.*?)/(.*)", gcs_prefix)
    if match is None:
        raise ValueError("gcs_prefix does not match accepted format")

    output_bucket, output_prefix = match.groups()

    # Any "." in the prefix is taken to mean a file name was supplied.
    if re.match(r"(.*[.].*$)", output_prefix) is not None:
        raise ValueError("gcs_prefix cannot contain file types")

    storage_client = storage.Client()
    blob_list = storage_client.list_blobs(output_bucket, prefix=output_prefix)

    # Group file names by parent path. Names are kept in lists (not a
    # comma-joined string) so a name containing "," stays a single entry.
    files_by_directory = {}
    for blob in blob_list:
        directory, _, file_name = blob.name.rpartition("/")
        files_by_directory.setdefault(directory, []).append(file_name)

    for directory, file_names in files_by_directory.items():
        print(f"{directory}")

        # Every group has at least one name, so this unpacking cannot fail.
        *leading_names, last_name = file_names
        for name in leading_names[:max_leading_files]:
            print(f"{display_filename_prefix_middle}{name}")

        # NOTE: the marker is printed whenever the group has more than four
        # files, even when no name was actually omitted (5 or 6 files).
        if len(file_names) > marker_threshold:
            print("│ ....")
        print(f"{display_filename_prefix_last}{last_name}\n")


@dataclasses.dataclass
class DocumentWrapper:
"""Represents a wrapped Document.
Expand All @@ -98,7 +149,7 @@ class DocumentWrapper:
gcs_prefix: str

def __post_init__(self):
self._shards = _read_output(self.gcs_prefix)
self._shards = _get_shards(gcs_prefix=self.gcs_prefix)
self.pages = _pages_from_shards(shards=self._shards)
self.entities = _entities_from_shards(shards=self._shards)

Expand Down
157 changes: 151 additions & 6 deletions tests/unit/test_document_wrapper.py
Expand Up @@ -27,6 +27,7 @@
from google.cloud.documentai_toolbox.wrappers import DocumentWrapper, document_wrapper

from google.cloud import documentai
from google.cloud import storage


def get_bytes(file_name):
Expand All @@ -38,22 +39,22 @@ def get_bytes(file_name):
return result


def test_read_output_with_gcs_uri_contains_file_type():
def test_get_shards_with_gcs_uri_contains_file_type():
with pytest.raises(ValueError, match="gcs_prefix cannot contain file types"):
document_wrapper._read_output(
document_wrapper._get_shards(
"gs://test-directory/documentai/output/123456789/0.json"
)


def test_read_output_with_invalid_gcs_uri():
def test_get_shards_with_invalid_gcs_uri():
with pytest.raises(ValueError, match="gcs_prefix does not match accepted format"):
document_wrapper._read_output("test-directory/documentai/output/")
document_wrapper._get_shards("test-directory/documentai/output/")


def test_read_output_with_valid_gcs_uri():
def test_get_shards_with_valid_gcs_uri():
with mock.patch.object(document_wrapper, "_get_bytes") as factory:
factory.return_value = get_bytes("tests/unit/resources/0")
actual = document_wrapper._read_output(
actual = document_wrapper._get_shards(
"gs://test-directory/documentai/output/123456789/0"
)
# We are testing only one of the fields to make sure the file content could be loaded.
Expand Down Expand Up @@ -92,3 +93,147 @@ def test_document_wrapper_with_multiple_shards():
factory.return_value = get_bytes("tests/unit/resources/1")
actual = DocumentWrapper("gs://test-directory/documentai/output/123456789/1")
assert len(actual.pages) == 48


@mock.patch("google.cloud.documentai_toolbox.wrappers.document_wrapper.storage")
def test_get_bytes(mock_storage):
    """Checks _get_bytes against a patched storage module listing two blobs.

    Asserts that exactly one storage client is created and that the result
    is ``[b"", b""]``.
    """
    storage_client = mock_storage.Client.return_value

    bucket = mock.Mock()
    bucket.blob.return_value.download_as_string.return_value = b"test"
    storage_client.Bucket.return_value = bucket

    # Real Blob objects (attached to the mocked bucket) stand in for the
    # listing returned by the patched client.
    shard_files = ("test_shard1.json", "test_shard2.json")
    storage_client.list_blobs.return_value = [
        storage.Blob(
            name=f"gs://test-directory/documentai/output/123456789/1/{shard_file}",
            bucket=bucket,
        )
        for shard_file in shard_files
    ]

    result = document_wrapper._get_bytes(
        "gs://test-directory/documentai/", "output/123456789/1"
    )

    mock_storage.Client.assert_called_once()
    assert result == [b"", b""]


@mock.patch("google.cloud.documentai_toolbox.wrappers.document_wrapper.storage")
def test_print_gcs_document_tree_with_3_documents(mock_storage, capfd):
    """Three listed blobs are printed in full, with no "...." marker line."""
    storage_client = mock_storage.Client.return_value
    storage_client.Bucket.return_value = mock.Mock()

    directory = "gs://test-directory/documentai/output/123456789/1"
    storage_client.list_blobs.return_value = [
        storage.Blob(
            name=f"gs://test-directory/documentai/output/123456789/1/test_shard{number}.json",
            bucket=directory,
        )
        for number in (1, 2, 3)
    ]

    document_wrapper.print_gcs_document_tree(directory)

    mock_storage.Client.assert_called_once()

    # Parent path first, then one line per file; the final file line is
    # followed by a blank line.
    expected_output = (
        "gs://test-directory/documentai/output/123456789/1\n"
        "├──test_shard1.json\n"
        "├──test_shard2.json\n"
        "└──test_shard3.json\n\n"
    )
    captured = capfd.readouterr()
    assert captured.out == expected_output


@mock.patch("google.cloud.documentai_toolbox.wrappers.document_wrapper.storage")
def test_print_gcs_document_tree_with_more_than_5_document(mock_storage, capfd):
    """Six listed blobs print five names, a "...." marker, then the last name."""
    storage_client = mock_storage.Client.return_value
    storage_client.Bucket.return_value = mock.Mock()

    directory = "gs://test-directory/documentai/output/123456789/1"
    storage_client.list_blobs.return_value = [
        storage.Blob(
            name=f"gs://test-directory/documentai/output/123456789/1/test_shard{number}.json",
            bucket=directory,
        )
        for number in range(1, 7)
    ]

    document_wrapper.print_gcs_document_tree(directory)

    mock_storage.Client.assert_called_once()

    # The marker line sits between the fifth name and the last name.
    expected_output = (
        "gs://test-directory/documentai/output/123456789/1\n"
        "├──test_shard1.json\n"
        "├──test_shard2.json\n"
        "├──test_shard3.json\n"
        "├──test_shard4.json\n"
        "├──test_shard5.json\n"
        "│ ....\n"
        "└──test_shard6.json\n\n"
    )
    captured = capfd.readouterr()
    assert captured.out == expected_output


def test_print_gcs_document_tree_with_gcs_uri_contains_file_type():
    """A prefix ending in a file name is rejected with ValueError."""
    uri_with_file_name = (
        "gs://test-directory/documentai/output/123456789/1/test_file.json"
    )
    expected_message = "gcs_prefix cannot contain file types"

    with pytest.raises(ValueError, match=expected_message):
        document_wrapper.print_gcs_document_tree(uri_with_file_name)


def test_print_gcs_document_tree_with_invalid_gcs_uri():
    """A location without the gs:// scheme is rejected with ValueError."""
    uri_without_scheme = "documentai/output/123456789/1"
    expected_message = "gcs_prefix does not match accepted format"

    with pytest.raises(ValueError, match=expected_message):
        document_wrapper.print_gcs_document_tree(uri_without_scheme)

0 comments on commit b5ac4ca

Please sign in to comment.