Skip to content

Commit

Permalink
Further refactoring
Browse files (browse the repository at this point in the history)
  • Loading branch information
holtskinner committed Jun 13, 2024
1 parent 46fcb37 commit 867b24e
Showing 1 changed file with 55 additions and 74 deletions.
129 changes: 55 additions & 74 deletions google/cloud/documentai_toolbox/wrappers/page.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,7 +18,7 @@
from abc import ABC
import dataclasses
from functools import cached_property
from typing import Iterable, List, Optional, Type, cast
from typing import Iterable, List, Optional, Type

import pandas as pd

Expand All @@ -45,56 +45,48 @@ class Table:

@cached_property
def body_rows(self) -> List[List[str]]:
return Table._extract_table_rows(
table_rows=self.documentai_object.body_rows,
text=self._page._document_text,
)
return self._extract_table_rows(self.documentai_object.body_rows)

@cached_property
def header_rows(self) -> List[List[str]]:
return Table._extract_table_rows(
table_rows=self.documentai_object.header_rows,
text=self._page._document_text,
)
return self._extract_table_rows(self.documentai_object.header_rows)

def to_dataframe(self) -> pd.DataFrame:
r"""Returns pd.DataFrame from documentai.table
"""Returns pd.DataFrame from documentai.table
Returns:
pd.DataFrame:
The DataFrame of the table.
"""
if not self.body_rows:
return pd.DataFrame(columns=self.header_rows)

if self.header_rows:
columns = pd.MultiIndex.from_arrays(self.header_rows)
else:
columns = [None] * len(self.body_rows[0])
columns = (
pd.MultiIndex.from_arrays(self.header_rows)
if self.header_rows
else [None] * len(self.body_rows[0])
)

return pd.DataFrame(self.body_rows, columns=columns)

@staticmethod
def _extract_table_rows(
table_rows: Iterable[documentai.Document.Page.Table.TableRow], text: str
self, table_rows: Iterable[documentai.Document.Page.Table.TableRow]
) -> List[List[str]]:
r"""Returns a list of rows from table_rows.
"""Returns a list of rows from table_rows.
Args:
table_rows (List[documentai.Document.Page.Table.TableRow]):
Required. A documentai.Document.Page.Table.TableRow.
text (str):
Required. UTF-8 encoded text in reading order
from the document.
Returns:
List[List[str]]:
A list of table rows.
"""
return [
[
_text_from_layout(cell.layout, text).replace("\n", "")
_text_from_layout(cell.layout, self._page._document_text).replace(
"\n", ""
)
for cell in row.cells
]
for row in table_rows
Expand All @@ -121,28 +113,29 @@ class FormField:

@cached_property
def field_name(self) -> str:
return FormField._trim_text(
return self._trim_text(
_text_from_layout(
self.documentai_object.field_name, self._page._document_text
self.documentai_object.field_name.layout, self._page._document_text
)
)

@cached_property
def field_value(self) -> str:
return FormField._trim_text(
return self._trim_text(
_text_from_layout(
self.documentai_object.field_value, self._page._document_text
self.documentai_object.field_value.layout, self._page._document_text
)
)

@staticmethod
def _trim_text(text: str) -> str:
r"""Remove extra space characters from text (blank, newline, tab, etc.)
"""Remove extra space characters from text (blank, newline, tab, etc.)
Args:
text (str):
Required. UTF-8 encoded text in reading order
from the document.
Returns:
str:
Text without trailing spaces/newlines
Expand All @@ -163,46 +156,47 @@ def text(self) -> str:
Text of the page element.
"""
return _text_from_layout(
layout=self.documentai_object.layout, text=self._page._document_text
self.documentai_object.layout, self._page._document_text
)

@cached_property
def hocr_bounding_box(self):
def hocr_bounding_box(self) -> Optional[str]:
"""
hOCR bounding box of the page element.
"""
return _get_hocr_bounding_box(
element_with_layout=self.documentai_object,
page_dimension=self._page.documentai_object.dimension,
self.documentai_object, self._page.documentai_object.dimension
)

@cached_property
def _text_segment(self) -> documentai.Document.TextAnchor.TextSegment:
"""
Element text section
"""
return self.documentai_object.layout.text_anchor.text_segments[0]

def _get_children_of_element(
self, children: List[ElementWithLayout]
) -> List[ElementWithLayout]:
r"""Returns a list of children inside element.
self, potential_children: List["_BasePageElement"]
) -> List["_BasePageElement"]:
"""Returns a list of children inside element.
Args:
children (List[ElementWithLayout]):
potential_children (List[_BasePageElement]):
Required. List of wrapped children.
Returns:
List[ElementWithLayout]:
List[_BasePageElement]:
A list of wrapped children that are inside an element.
"""
start_index = self.documentai_object.layout.text_anchor.text_segments[
0
].start_index
end_index = self.documentai_object.layout.text_anchor.text_segments[0].end_index

return [
child
for child in children
if start_index
<= child.documentai_object.layout.text_anchor.text_segments[0].start_index
< end_index
and start_index
< child.documentai_object.layout.text_anchor.text_segments[0].end_index
<= end_index
for child in potential_children
if self._text_segment.start_index
<= child._text_segment.start_index
< self._text_segment.end_index
and self._text_segment.start_index
< child._text_segment.end_index
<= self._text_segment.end_index
]


Expand All @@ -219,7 +213,7 @@ class Symbol(_BasePageElement):
"""

@cached_property
def hocr_bounding_box(self):
def hocr_bounding_box(self) -> Optional[str]:
# Symbols are not represented in hOCR
return None

Expand All @@ -238,11 +232,8 @@ class Token(_BasePageElement):
"""

@cached_property
def symbols(self):
return cast(
List[Symbol],
self._get_children_of_element(self._page.symbols),
)
def symbols(self) -> List[Symbol]:
return self._get_children_of_element(self._page.symbols)


@dataclasses.dataclass
Expand All @@ -259,11 +250,8 @@ class Line(_BasePageElement):
"""

@cached_property
def tokens(self):
return cast(
List[Token],
self._get_children_of_element(self._page.tokens),
)
def tokens(self) -> List[Token]:
return self._get_children_of_element(self._page.tokens)


@dataclasses.dataclass
Expand All @@ -280,11 +268,8 @@ class Paragraph(_BasePageElement):
"""

@cached_property
def lines(self):
return cast(
List[Line],
self._get_children_of_element(self._page.lines),
)
def lines(self) -> List[Line]:
return self._get_children_of_element(self._page.lines)


@dataclasses.dataclass
Expand All @@ -296,16 +281,13 @@ class Block(_BasePageElement):
Required. The original object.
text (str):
Required. The text of the Block.
_paragraphs (List[Paragraph]):
paragraphs (List[Paragraph]):
Optional. The Paragraphs contained within the Block.
"""

@cached_property
def paragraphs(self):
return cast(
List[Paragraph],
self._get_children_of_element(self._page.paragraphs),
)
def paragraphs(self) -> List[Paragraph]:
return self._get_children_of_element(self._page.paragraphs)


@dataclasses.dataclass
Expand All @@ -330,7 +312,7 @@ def _get_hocr_bounding_box(
element_with_layout: ElementWithLayout,
page_dimension: documentai.Document.Page.Dimension,
) -> Optional[str]:
r"""Returns a hOCR bounding box string.
"""Returns a hOCR bounding box string.
Args:
element_with_layout (ElementWithLayout):
Expand All @@ -340,7 +322,7 @@ def _get_hocr_bounding_box(
Returns:
Optional[str]:
hOCR bounding box sring.
hOCR bounding box string.
"""
if not element_with_layout.layout.bounding_poly:
return None
Expand Down Expand Up @@ -483,6 +465,5 @@ def blocks(self):
@cached_property
def hocr_bounding_box(self):
return _get_hocr_bounding_box(
element_with_layout=self.documentai_object,
page_dimension=self.documentai_object.dimension,
self.documentai_object, self.documentai_object.dimension
)

0 comments on commit 867b24e

Please sign in to comment.