In [1]:
import os
from uuid import uuid4
from typing import Optional, Union, List, Tuple
from pydantic import BaseModel
from pathlib import Path
import openparse
from openparse.schemas import TableElement, TextElement, ImageElement, LineElement
import pdfplumber
from IPython.display import HTML, display
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser


class TextIOU(BaseModel):
    """Record pairing a text element with a table element and an overlap score.

    `find_description_by_iou` reads the `iou` and `text` of these records to
    pick a description for a table.
    """

    # Text element that was scored.
    element: TextElement
    # Table element the score refers to (presumably the table the text was
    # compared against — confirm against the code that builds these records).
    target_table: TableElement
    # Overlap score; presumably the value returned by `calc_text_iou` — confirm
    # against the code that builds these records.
    iou: float

    @property
    def text(self) -> str:
        """Return the text of the wrapped text element."""
        return self.element.text

def get_bbox(element: Union[TableElement, TextElement]) -> Tuple[float, float, float, float]:
    """Return an element's bounding box as an ``(x0, y0, x1, y1)`` tuple.

    Args:
        element: Object exposing a ``bbox`` attribute with ``x0``, ``y0``,
            ``x1`` and ``y1`` fields.

    Returns:
        The four coordinates read from ``element.bbox``, in that order.
    """
    box = element.bbox
    return (box.x0, box.y0, box.x1, box.y1)

def calc_text_iou(table_element: TableElement, text_element: TextElement, y_offsets: Tuple[float, float]) -> float:
    """Return the share of a text element's area that lies inside a table's padded box.

    Despite the name, this is not a symmetric intersection-over-union: the
    intersection area is divided by the text element's own area, so a text box
    that sits entirely inside the padded table box scores 1.0.

    Args:
        table_element: Table whose bounding box is the reference region.
        text_element: Text element whose overlap with that region is measured.
        y_offsets: Two amounts applied to the table box before comparing; the
            first is subtracted from its ``y0`` and the second is added to its
            ``y1``.

    Returns:
        The overlap ratio, or ``0.0`` when the boxes do not overlap or the text
        box has zero area.

    Note:
        Only coordinates are compared. The page of each element is not checked
        here, so elements from different pages can still produce a non-zero
        score.
    """
    table_x0, table_y0, table_x1, table_y1 = get_bbox(table_element)
    text_x0, text_y0, text_x1, text_y1 = get_bbox(text_element)

    # Grow the table box vertically so that text lying just outside it can
    # still register an overlap (presumably to catch captions — confirm).
    table_y0 -= y_offsets[0]
    table_y1 += y_offsets[1]

    text_area = (text_x1 - text_x0) * (text_y1 - text_y0)
    if text_area == 0:
        # A zero-area text box has no meaningful overlap ratio; without this
        # guard the division below raises ZeroDivisionError when such a box
        # touches the table region.
        return 0.0

    left = max(table_x0, text_x0)
    top = max(table_y0, text_y0)
    right = min(table_x1, text_x1)
    bottom = min(table_y1, text_y1)

    if right < left or bottom < top:
        # The boxes are disjoint on at least one axis.
        return 0.0

    return (right - left) * (bottom - top) / text_area

def find_description_by_iou(candidates: List[TextIOU]) -> str:
    """Choose the text that best describes a table from scored candidates.

    Candidates whose text starts with "table" (case-insensitive) are preferred.
    If there are none, every candidate is considered. Within the chosen group
    the candidate with the highest ``iou`` wins.

    Args:
        candidates: Scored text elements to choose from.

    Returns:
        The winning candidate's text with newlines replaced by spaces, or
        ``"Not found"`` when ``candidates`` is empty.
    """
    if not candidates:
        return "Not found"

    captions = [c for c in candidates if c.text.lower().startswith("table")]

    if captions:
        # max() keeps the first maximal item it sees; iterating in reverse
        # therefore lets the LAST caption in input order win a tie, matching
        # the previous behavior. Using max() also removes the old `0` sentinel
        # that could leave the result unset when no score was >= 0.
        best = max(reversed(captions), key=lambda c: c.iou)
    else:
        # Ties go to the FIRST candidate in input order, as a stable
        # descending sort would have done, without sorting the whole list.
        best = max(candidates, key=lambda c: c.iou)

    return best.text.replace("\n", " ")

def refine(table_html: str, window_context: List[str]):
    """Ask the chat model to correct an HTML table using surrounding text as context.

    Args:
        table_html: HTML markup of the table, inserted into the prompt's TABLE
            section.
        window_context: Text snippets near the table; they are joined with
            blank lines and inserted into the prompt's CONTEXT section.

    Returns:
        The output of the prompt -> chat model -> string parser chain (the
        prompt asks the model to answer with HTML only).
    """
    template = """# Role
You're an expert at refining and modifying HTML tables.
Given a table and its surrounding contextual information, you need to fix and correct what is wrong or incorrect about the table's columns and data.
Do not add data by referencing and inferring from the contextual information. The number of rows does not change.
If a description of the table exists in the contextual information, add it as a <h2> tag.
Your answers are formatted as HTML code only (e.g., <html> ... </html>).
# TABLE
{table_html}

# CONTEXT
{context}

html:"""

    # Pipeline: fill the prompt, call the chat model, reduce the reply to a string.
    chain = (
        PromptTemplate.from_template(template)
        | ChatOpenAI(model_name="gpt-4o")
        | StrOutputParser()
    )

    context = "\n\n".join(window_context)
    return chain.invoke({"table_html": table_html, "context": context})

In [2]:
# Keep a key that is already exported in the environment; the "..." placeholder
# is only a fallback, so a real secret never has to be written into the notebook.
os.environ.setdefault("OPENAI_API_KEY", "...")

# Input PDFs and output images, given relative to the working directory.
PDF_PATH = "../data/pdf"
IMAGE_PATH = "../data/images"
pdf_path = Path(PDF_PATH)
image_path = Path(IMAGE_PATH)

In [9]:
# Test PDF file (https://arxiv.org/pdf/2410.10315v1)
# NOTE(review): the URL above looks like it belongs to the commented-out EasyRAG
# sample rather than the LightRAG file selected below — confirm.
# filepath = pdf_path / "EasyRAG: Efficient Retrieval-Augmented Generation Framework for Automated Network Operations.pdf"
filepath = pdf_path.joinpath("LightRAG: Simple and Fast Retrieval-Augmented Generation.pdf")

In [10]:
# Table-parsing options handed to the document parser.
# NOTE(review): 0.8 is presumably the minimum detection confidence for a table
# to be kept — confirm the exact semantics in the openparse documentation.
table_args = dict(parsing_algorithm="unitable", min_table_confidence=0.8)

# Build the parser with those options and parse the selected PDF.
parser = openparse.DocumentParser(table_args=table_args)
parsed_basic_doc = parser.parse(filepath)

In [11]:
# Element classes collected from the parsed document (matched by exact type).
allowed_elements = [TextElement, LineElement, TableElement, ImageElement]

nodes = parsed_basic_doc.nodes
all_elements = []

for node in nodes:
    for item in node.elements:
        if type(item) in allowed_elements:
            # The entry is itself one of the wanted elements.
            all_elements.append(item)
        else:
            # Otherwise the entry is iterated as a container and its members
            # are filtered with the same exact-type check.
            all_elements.extend(
                child for child in item if type(child) in allowed_elements
            )

# Tables only, in document order.
table_elements = list(filter(lambda found: isinstance(found, TableElement), all_elements))

In [12]:
def extract_pdf_image(pdf_path: Union[str, Path], element, save_path: Union[str, Path], resolution: int = 500) -> None:
    """Crop the region covered by ``element`` from its PDF page and save it as an image.

    Args:
        pdf_path: Path of the PDF to read.
        element: Parsed element whose ``bbox`` supplies ``page``,
            ``page_height`` and the ``x0``/``y0``/``x1``/``y1`` coordinates.
        save_path: Destination the rendered crop is written to.
        resolution: Rendering resolution passed to ``to_image``.
    """
    def _get_bbox(element) -> Tuple[float, float, float, float]:
        # Mirror the vertical coordinates around the page height before cropping
        # (assumes the element bbox uses a bottom-left origin while pdfplumber
        # crops with a top-left origin — TODO confirm).
        return (element.bbox.x0,
                element.bbox.page_height - element.bbox.y1,
                element.bbox.x1,
                element.bbox.page_height - element.bbox.y0)

    # The context manager closes the PDF even when cropping or saving fails;
    # previously the opened file was never closed.
    with pdfplumber.open(pdf_path) as pdf_obj:
        page = pdf_obj.pages[element.bbox.page]
        cropped_page = page.crop(_get_bbox(element))
        image = cropped_page.to_image(resolution=resolution)
        # NOTE(review): recent pdfplumber releases appear to default
        # PageImage.save to PNG output regardless of the file suffix — confirm
        # that the written format matches save_path's extension.
        image.save(save_path)

In [15]:
# Create the output directory up front; saving into a missing directory fails.
image_path.mkdir(parents=True, exist_ok=True)

# Export every table ("t-" prefix) and image ("i-" prefix) element as a crop of
# its PDF page. A fresh uuid4 is used per file, so re-running this cell adds new
# files instead of overwriting earlier ones.
for element in all_elements:
    if isinstance(element, TableElement):
        prefix = "t"
    elif isinstance(element, ImageElement):
        prefix = "i"
    else:
        # Other element kinds are not exported.
        continue

    filename = f"{prefix}-{uuid4()}.jpg"
    save_path = image_path / filename
    extract_pdf_image(filepath, element, save_path)
    