docling_core/experimental/serializer/doctags.py (104 changes: 73 additions & 31 deletions)
@@ -3,7 +3,7 @@
 import html
 from enum import Enum
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
 
 from pydantic import AnyUrl, BaseModel
 from typing_extensions import override
@@ -23,6 +23,7 @@
 from docling_core.experimental.serializer.common import CommonParams, DocSerializer
 from docling_core.types.doc.document import (
     CodeItem,
+    DocItem,
     DoclingDocument,
     Formatting,
     FormItem,
@@ -54,7 +55,6 @@ class Mode(str, Enum):
         MINIFIED = "minified"
         HUMAN_FRIENDLY = "human_friendly"
 
-    new_line: str = ""
     xsize: int = 500
     ysize: int = 500
     add_location: bool = True
@@ -67,13 +67,13 @@ class Mode(str, Enum):
     mode: Mode = Mode.HUMAN_FRIENDLY
 
 
-def _get_delim(mode: DocTagsParams.Mode) -> str:
-    if mode == DocTagsParams.Mode.HUMAN_FRIENDLY:
+def _get_delim(params: DocTagsParams) -> str:
+    if params.mode == DocTagsParams.Mode.HUMAN_FRIENDLY:
         delim = "\n"
-    elif mode == DocTagsParams.Mode.MINIFIED:
+    elif params.mode == DocTagsParams.Mode.MINIFIED:
         delim = ""
     else:
-        raise RuntimeError(f"Unknown DocTags mode: {mode}")
+        raise RuntimeError(f"Unknown DocTags mode: {params.mode}")
     return delim
 
 
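Note: _get_delim now receives the whole DocTagsParams object rather than the bare Mode value, so call sites below shrink to _get_delim(params=params) and the helper can later consult other fields without another signature change. A minimal, runnable sketch of the contract (defaults taken from this diff):

    from docling_core.experimental.serializer.doctags import DocTagsParams, _get_delim

    assert _get_delim(params=DocTagsParams()) == "\n"  # mode defaults to HUMAN_FRIENDLY
    assert _get_delim(params=DocTagsParams(mode=DocTagsParams.Mode.MINIFIED)) == ""
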
@@ -102,7 +102,6 @@ def serialize(
         if params.add_location:
             location = item.get_location_tokens(
                 doc=doc,
-                new_line=params.new_line,
                 xsize=params.xsize,
                 ysize=params.ysize,
             )
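
Note: new_line disappears from every get_location_tokens call site in this file because the field itself was removed from DocTagsParams above; location tokens are now always emitted inline. For orientation, a self-contained sketch of the quantization these tokens encode (the <loc_N> token text and the exact rounding are assumptions based on the usual DocTags convention, not shown in this diff):

    page_w, xsize = 612.0, 500   # hypothetical page width and grid size
    x0 = 61.2                    # hypothetical bbox coordinate
    print(round(x0 / page_w * xsize))  # 50, rendered as <loc_50>
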
@@ -158,7 +157,6 @@ def serialize(
         if params.add_location:
             body += item.get_location_tokens(
                 doc=doc,
-                new_line=params.new_line,
                 xsize=params.xsize,
                 ysize=params.ysize,
             )
@@ -178,15 +176,14 @@
body += f"<{DocumentToken.CAPTION.value}>"
for caption in item.captions:
if caption.cref not in doc_serializer.get_excluded_refs(**kwargs):
body += caption.resolve(doc).get_location_tokens(
doc=doc,
new_line=params.new_line,
xsize=params.xsize,
ysize=params.ysize,
)
if isinstance(cap := caption.resolve(doc), DocItem):
body += cap.get_location_tokens(
doc=doc,
xsize=params.xsize,
ysize=params.ysize,
)
body += f"{text.strip()}"
body += f"</{DocumentToken.CAPTION.value}>"
body += f"{params.new_line}"

if body:
body = _wrap(text=body, wrap_tag=DocumentToken.OTSL.value)
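
Note: caption.resolve(doc) returns whatever the reference points at, so the walrus-plus-isinstance guard narrows the result to DocItem before asking for location tokens; anything else is now skipped rather than hitting an attribute error, and the narrowing keeps the type checker happy. The pattern in isolation (excerpted from the hunk above):

    if isinstance(cap := caption.resolve(doc), DocItem):
        body += cap.get_location_tokens(doc=doc, xsize=params.xsize, ysize=params.ysize)

The same guard is applied to picture captions further down.
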
@@ -208,15 +205,13 @@ def serialize(
     ) -> SerializationResult:
         """Serializes the passed item."""
         params = DocTagsParams(**kwargs)
 
         parts: list[str] = []
-
         if item.self_ref not in doc_serializer.get_excluded_refs(**kwargs):
             body = ""
             if params.add_location:
                 body += item.get_location_tokens(
                     doc=doc,
-                    new_line=params.new_line,
                     xsize=params.xsize,
                     ysize=params.ysize,
                 )
@@ -246,13 +241,13 @@ def serialize(
body = ""
for caption in item.captions:
if caption.cref not in doc_serializer.get_excluded_refs(**kwargs):
body += caption.resolve(doc).get_location_tokens(
doc=doc,
new_line=params.new_line,
xsize=params.xsize,
ysize=params.ysize,
)
body += f"{text.strip()}"
if isinstance(cap := caption.resolve(doc), DocItem):
body += cap.get_location_tokens(
doc=doc,
xsize=params.xsize,
ysize=params.ysize,
)
body += f"{text.strip()}"
if body:
body = _wrap(text=body, wrap_tag=DocumentToken.CAPTION.value)
parts.append(body)
@@ -279,9 +274,56 @@ def serialize(
         **kwargs,
     ) -> SerializationResult:
         """Serializes the passed item."""
-        # TODO add actual implementation
-        text_res = ""
-        return SerializationResult(text=text_res)
+        params = DocTagsParams(**kwargs)
+
+        body = ""
+
+        page_no = 1
+        if len(item.prov) > 0:
+            page_no = item.prov[0].page_no
+
+        if params.add_location:
+            body += item.get_location_tokens(
+                doc=doc,
+                xsize=params.xsize,
+                ysize=params.ysize,
+            )
+
+        # mapping from source_cell_id to a list of target_cell_ids
+        source_to_targets: Dict[int, List[int]] = {}
+        for link in item.graph.links:
+            source_to_targets.setdefault(link.source_cell_id, []).append(
+                link.target_cell_id
+            )
+
+        for cell in item.graph.cells:
+            cell_txt = ""
+            if cell.prov is not None:
+                if len(doc.pages.keys()):
+                    page_w, page_h = doc.pages[page_no].size.as_tuple()
+                    cell_txt += DocumentToken.get_location(
+                        bbox=cell.prov.bbox.to_top_left_origin(page_h).as_tuple(),
+                        page_w=page_w,
+                        page_h=page_h,
+                        xsize=params.xsize,
+                        ysize=params.ysize,
+                    )
+            if params.add_content:
+                cell_txt += cell.text.strip()
+
+            if cell.cell_id in source_to_targets:
+                targets = source_to_targets[cell.cell_id]
+                for target in targets:
+                    # TODO centralize token creation
+                    cell_txt += f"<link_{target}>"
+
+            # TODO centralize token creation
+            tok = f"{cell.label.value}_{cell.cell_id}"
+            cell_txt = _wrap(text=cell_txt, wrap_tag=tok)
+            body += cell_txt
+
+        body = _wrap(body, DocumentToken.KEY_VALUE_REGION.value)
+        return SerializationResult(text=body)
 
 
 class DocTagsFormSerializer(BaseFormSerializer):
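
Note: this fills in the previously stubbed key-value serializer. The flow: invert item.graph.links into a source-to-targets map; then, per cell, optionally emit location tokens (using the page of the item's first provenance, defaulting to page 1) and the stripped cell text, append one <link_N> token per outgoing link, wrap the cell in a tag derived from its label and id, and finally wrap the whole body in the key-value-region tag. A purely hypothetical sketch of the resulting token stream for one key cell linked to one value cell (tag names, ids, and coordinates invented for illustration):

    <key_value_region>
      <key_0><loc_10><loc_20><loc_80><loc_30>Invoice no.<link_1></key_0>
      <value_1><loc_90><loc_20><loc_150><loc_30>A-12345</value_1>
    </key_value_region>

(Line breaks added here for readability; the serializer concatenates cells without a separator.)
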
@@ -329,7 +371,7 @@ def serialize(
             visited=my_visited,
             **kwargs,
         )
-        delim = _get_delim(mode=params.mode)
+        delim = _get_delim(params=params)
         if parts:
             text_res = delim.join(
                 [
@@ -374,7 +416,7 @@ def serialize(
             **kwargs,
         )
         wrap_tag = DocumentToken.INLINE.value
-        delim = _get_delim(mode=params.mode)
+        delim = _get_delim(params=params)
         text_res = delim.join([p.text for p in parts if p.text])
         if text_res:
             text_res = f"{text_res}{delim}"
@@ -437,14 +479,14 @@ def post_process(
     @override
     def serialize_page(self, parts: list[SerializationResult]) -> SerializationResult:
         """Serialize a page out of its parts."""
-        delim = _get_delim(mode=self.params.mode)
+        delim = _get_delim(params=self.params)
         text_res = delim.join([p.text for p in parts])
         return SerializationResult(text=text_res)
 
     @override
     def serialize_doc(self, pages: list[SerializationResult]) -> SerializationResult:
         """Serialize a document out of its pages."""
-        delim = _get_delim(mode=self.params.mode)
+        delim = _get_delim(params=self.params)
         if self.params.add_page_break:
             page_sep = f"{delim}<{DocumentToken.PAGE_BREAK.value}>{delim}"
             content = page_sep.join([p.text for p in pages if p.text])
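
Note: with new_line removed, all whitespace in the output now derives from mode via _get_delim, including the page-break separator built just above. A sketch of constructing the full parameter set this diff leaves in play (every field below appears as a default or a call site somewhere in the file; keyword construction assumes DocTagsParams stays a plain pydantic model):

    from docling_core.experimental.serializer.doctags import DocTagsParams

    params = DocTagsParams(
        mode=DocTagsParams.Mode.HUMAN_FRIENDLY,  # "\n" between parts and around page breaks
        xsize=500,            # location-token grid width
        ysize=500,            # location-token grid height
        add_location=True,    # emit location token runs per item
        add_content=True,     # emit cell/item text
        add_page_break=True,  # join pages with the page-break token
    )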