Skip to content

Commit

Permalink
page.py Class restructuring for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
holtskinner committed Jun 13, 2024
1 parent 1e871f6 commit 46fcb37
Showing 1 changed file with 84 additions and 82 deletions.
166 changes: 84 additions & 82 deletions google/cloud/documentai_toolbox/wrappers/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from abc import ABC
import dataclasses
from functools import cached_property
from typing import List, Optional, Type, cast
from typing import Iterable, List, Optional, Type, cast

import pandas as pd

Expand All @@ -44,16 +44,16 @@ class Table:
_page: "Page" = dataclasses.field(repr=False)

@cached_property
def body_rows(self):
return _table_rows_from_documentai_table_rows(
table_rows=list(self.documentai_object.body_rows),
def body_rows(self) -> List[List[str]]:
return Table._extract_table_rows(
table_rows=self.documentai_object.body_rows,
text=self._page._document_text,
)

@cached_property
def header_rows(self):
return _table_rows_from_documentai_table_rows(
table_rows=list(self.documentai_object.header_rows),
def header_rows(self) -> List[List[str]]:
return Table._extract_table_rows(
table_rows=self.documentai_object.header_rows,
text=self._page._document_text,
)

Expand All @@ -75,6 +75,31 @@ def to_dataframe(self) -> pd.DataFrame:

return pd.DataFrame(self.body_rows, columns=columns)

@staticmethod
def _extract_table_rows(
    table_rows: Iterable[documentai.Document.Page.Table.TableRow], text: str
) -> List[List[str]]:
    r"""Converts Document AI table rows into nested lists of cell text.

    Args:
        table_rows (Iterable[documentai.Document.Page.Table.TableRow]):
            Required. The table rows to convert.
        text (str):
            Required. UTF-8 encoded text in reading order
            from the document.

    Returns:
        List[List[str]]:
            One inner list per row, holding the text of every cell in
            that row with newline characters removed.
    """
    extracted: List[List[str]] = []
    for row in table_rows:
        cell_texts = []
        for cell in row.cells:
            cell_text = _text_from_layout(cell.layout, text)
            # Newlines are deleted outright, not replaced by a space.
            cell_texts.append(cell_text.replace("\n", ""))
        extracted.append(cell_texts)
    return extracted


@dataclasses.dataclass
class FormField:
Expand All @@ -95,21 +120,35 @@ class FormField:
_page: "Page" = dataclasses.field(repr=False)

@cached_property
def field_name(self):
return _trim_text(
def field_name(self) -> str:
return FormField._trim_text(
_text_from_layout(
self.documentai_object.field_name, self._page._document_text
)
)

@cached_property
def field_value(self):
return _trim_text(
def field_value(self) -> str:
return FormField._trim_text(
_text_from_layout(
self.documentai_object.field_value, self._page._document_text
)
)

@staticmethod
def _trim_text(text: str) -> str:
r"""Remove extra space characters from text (blank, newline, tab, etc.)
Args:
text (str):
Required. UTF-8 encoded text in reading order
from the document.
Returns:
str:
Text without trailing spaces/newlines
"""
return text.strip().replace("\n", " ")


@dataclasses.dataclass
class _BasePageElement(ABC):
Expand All @@ -119,7 +158,7 @@ class _BasePageElement(ABC):
_page: "Page" = dataclasses.field(repr=False)

@cached_property
def text(self):
def text(self) -> str:
"""
Text of the page element.
"""
Expand All @@ -137,6 +176,35 @@ def hocr_bounding_box(self):
page_dimension=self._page.documentai_object.dimension,
)

def _get_children_of_element(
self, children: List[ElementWithLayout]
) -> List[ElementWithLayout]:
r"""Returns a list of children inside element.
Args:
children (List[ElementWithLayout]):
Required. List of wrapped children.
Returns:
List[ElementWithLayout]:
A list of wrapped children that are inside an element.
"""
start_index = self.documentai_object.layout.text_anchor.text_segments[
0
].start_index
end_index = self.documentai_object.layout.text_anchor.text_segments[0].end_index

return [
child
for child in children
if start_index
<= child.documentai_object.layout.text_anchor.text_segments[0].start_index
< end_index
and start_index
< child.documentai_object.layout.text_anchor.text_segments[0].end_index
<= end_index
]


@dataclasses.dataclass
class Symbol(_BasePageElement):
Expand Down Expand Up @@ -173,7 +241,7 @@ class Token(_BasePageElement):
def symbols(self):
return cast(
List[Symbol],
_get_children_of_element(self.documentai_object, self._page.symbols),
self._get_children_of_element(self._page.symbols),
)


Expand All @@ -194,7 +262,7 @@ class Line(_BasePageElement):
def tokens(self):
return cast(
List[Token],
_get_children_of_element(self.documentai_object, self._page.tokens),
self._get_children_of_element(self._page.tokens),
)


Expand All @@ -215,7 +283,7 @@ class Paragraph(_BasePageElement):
def lines(self):
return cast(
List[Line],
_get_children_of_element(self.documentai_object, self._page.lines),
self._get_children_of_element(self._page.lines),
)


Expand All @@ -236,7 +304,7 @@ class Block(_BasePageElement):
def paragraphs(self):
return cast(
List[Paragraph],
_get_children_of_element(self.documentai_object, self._page.paragraphs),
self._get_children_of_element(self._page.paragraphs),
)


Expand All @@ -258,28 +326,6 @@ def hocr_bounding_box(self):
return None


def _table_rows_from_documentai_table_rows(
    table_rows: List[documentai.Document.Page.Table.TableRow], text: str
) -> List[List[str]]:
    r"""Converts Document AI table rows into nested lists of cell text.

    Args:
        table_rows (List[documentai.Document.Page.Table.TableRow]):
            Required. The table rows to convert.
        text (str):
            Required. UTF-8 encoded text in reading order
            from the document.

    Returns:
        List[List[str]]:
            One inner list per row, holding the text of every cell in
            that row with newline characters removed.
    """
    rows: List[List[str]] = []
    for table_row in table_rows:
        row_cells = []
        for cell in table_row.cells:
            raw_text = _text_from_layout(cell.layout, text)
            # Newlines are deleted outright, not replaced by a space.
            row_cells.append(raw_text.replace("\n", ""))
        rows.append(row_cells)
    return rows


def _get_hocr_bounding_box(
element_with_layout: ElementWithLayout,
page_dimension: documentai.Document.Page.Dimension,
Expand Down Expand Up @@ -334,50 +380,6 @@ def _text_from_layout(layout: documentai.Document.Page.Layout, text: str) -> str
)


def _get_children_of_element(
element: ElementWithLayout, children: List[ElementWithLayout]
) -> List[ElementWithLayout]:
r"""Returns a list of children inside element.
Args:
element (ElementWithLayout):
Required. A element in a page.
children (List[ElementWithLayout]):
Required. List of wrapped children.
Returns:
List[ElementWithLayout]:
A list of wrapped children that are inside a element.
"""
start_index = element.layout.text_anchor.text_segments[0].start_index
end_index = element.layout.text_anchor.text_segments[0].end_index

return [
child
for child in children
if start_index
<= child.documentai_object.layout.text_anchor.text_segments[0].start_index
< end_index
and start_index
< child.documentai_object.layout.text_anchor.text_segments[0].end_index
<= end_index
]


def _trim_text(text: str) -> str:
r"""Remove extra space characters from text (blank, newline, tab, etc.)
Args:
text (str):
Required. UTF-8 encoded text in reading order
from the document.
Returns:
str:
Text without trailing spaces/newlines
"""
return text.strip().replace("\n", " ")


@dataclasses.dataclass
class Page:
"""Represents a wrapped documentai.Document.Page .
Expand Down

0 comments on commit 46fcb37

Please sign in to comment.