# Methods added to ``Table`` in docling_core/types/doc/base.py, reconstructed
# from the collapsed patch text and shown dedented to module level.
import html
from typing import List, Optional

import pandas as pd


def _get_tablecell_span(self, cell: "TableCell", ix: int):
    """Measure how far ``cell`` extends along one axis of the table grid.

    ``cell.spans`` is read as a collection of index pairs; ``ix`` selects
    which element of every pair is inspected (``export_to_html`` passes 0
    for rows and 1 for columns).

    Args:
        cell: Table cell whose ``spans`` attribute is inspected.
        ix: Position inside each span pair to collect.

    Returns:
        A tuple ``(size, first, last)``. ``size`` is the number of distinct
        indices covered, ``first``/``last`` are the smallest and largest of
        them. When the cell has no span information (``None`` or empty) the
        result is ``(1, None, None)``.
    """
    if not cell.spans:
        return 1, None, None

    covered = {pair[ix] for pair in cell.spans}
    return len(covered), min(covered), max(covered)


def export_to_dataframe(self) -> pd.DataFrame:
    """Export the table as a Pandas DataFrame.

    The leading rows that contain at least one ``"col_header"`` cell are
    treated as header rows. Their texts are joined per column with ``"."``
    to form the column names; all remaining rows become the frame's data.

    Returns:
        The table content as a DataFrame. An empty DataFrame is returned
        when there is no data or one of the table dimensions is zero.

    Raises:
        RuntimeError: If an empty row is met while scanning for headers.
    """
    if self.data is None or self.num_rows == 0 or self.num_cols == 0:
        return pd.DataFrame()

    # Count how many leading rows are column headers.
    num_headers = 0
    for row in self.data:
        if len(row) == 0:
            raise RuntimeError(f"Invalid table. {len(row)=} but {self.num_cols=}.")
        if not any(cell.obj_type == "col_header" for cell in row):
            break
        num_headers += 1

    # Build one name per column by chaining the texts of stacked header rows.
    columns: Optional[List[str]] = None
    if num_headers > 0:
        columns = ["" for _ in range(self.num_cols)]
        for header_row in self.data[:num_headers]:
            for j, cell in enumerate(header_row):
                # Only put a separator between non-empty existing prefixes.
                columns[j] += f".{cell.text}" if columns[j] != "" else cell.text

    table_data = [[cell.text for cell in row] for row in self.data[num_headers:]]
    return pd.DataFrame(table_data, columns=columns)


def export_to_html(self) -> str:
    """Export the table as html.

    Every grid position is visited row by row. A cell that spans several
    rows or columns is emitted once, at its first row/column, carrying
    ``rowspan``/``colspan`` attributes; the other positions it covers are
    skipped. Column headers are rendered as ``<th>``, everything else
    (including row headers) as ``<td>``. Cell text is stripped and
    HTML-escaped so that characters such as ``<`` or ``&`` cannot break or
    inject markup.

    Returns:
        A ``<table>...</table>`` string, or ``""`` when the table has no
        data.
    """
    if self.data is None:
        return ""

    rows_html: List[str] = []
    # Slicing (instead of indexing by num_rows/num_cols) keeps the declared
    # bounds but cannot raise IndexError on a ragged or short grid.
    for i, row in enumerate(self.data[: self.num_rows]):
        cells_html: List[str] = []
        for j, cell in enumerate(row[: self.num_cols]):
            rowspan, rowstart, _ = self._get_tablecell_span(cell, 0)
            colspan, colstart, _ = self._get_tablecell_span(cell, 1)

            # A spanning cell is only written at its top-left position.
            if rowstart is not None and rowstart != i:
                continue
            if colstart is not None and colstart != j:
                continue

            content = html.escape(cell.text.strip())
            if cell.obj_type in ("col_header", "col_multi_header"):
                celltag = "th"
            else:
                celltag = "td"

            opening_tag = celltag
            if rowspan > 1:
                opening_tag += f' rowspan="{rowspan}"'
            if colspan > 1:
                opening_tag += f' colspan="{colspan}"'

            cells_html.append(f"<{opening_tag}>{content}</{celltag}>")

        rows_html.append("<tr>" + "".join(cells_html) + "</tr>")

    return "<table>" + "".join(rows_html) + "</table>"