In [1]:
# titles = ["Training Large Language Models to Reason in a Continuous Latent Space", 
#           "Towards System 2 Reasoning in LLMs: Learning How to Think With Meta Chain-of-Thought",
#           "s1: Simple test-time scaling",
#           "From Few to Many: Self-Improving Many-Shot Reasoners Through Iterative Optimization and Generation",
#           "Satori: Reinforcement Learning with Chain-of-Action-Thought Enhances LLM Reasoning via Autoregressive Search"]
title = "Training Large Language Models to Reason in a Continuous Latent Space"

# pdf_pathes = ["./data/2412.06769v2.pdf",
#               "./data/2501.04682v1.pdf",
#               "./data/2501.19393v2.pdf",
#               "./data/2502.00330v1.pdf",
#               "./data/2502.02508v1.pdf"]

In [2]:
import sys
import os

# 获取当前脚本所在目录的父目录 (即 my_project)
parent_dir = os.path.dirname(os.getcwd())

# 将父目录添加到 sys.path
sys.path.append(parent_dir)

In [3]:
import time
import requests

from typing import List, Dict, Optional

In [4]:
def download_file(url, filename, timeout=60):
    """Download `url` and save the response body to `filename`.

    Args:
        url: address of the file to fetch.
        filename: local path the body is written to (binary mode).
        timeout: seconds to wait for the server before giving up. Without a
            timeout `requests.get` can block indefinitely on a stalled server.

    Returns:
        True when the file was downloaded and written, False when the HTTP
        request failed (the error is printed, not raised).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raise an exception for bad status codes
    except requests.exceptions.RequestException as e:
        print(f"Error downloading: {e}")
        return False

    # Only touch the target file once the request has succeeded.
    with open(filename, 'wb') as f:
        f.write(response.content)

    print(f"Successfully downloaded: {filename}")
    return True

In [5]:
import zipfile

def unzip_file(original_zip_file, destination_folder):
    """Extract every member of a `.zip` archive into `destination_folder`.

    The archive path must end with the exact suffix '.zip'; any other suffix
    trips the assertion before the file is opened.
    """
    _, suffix = os.path.splitext(original_zip_file)
    assert suffix == '.zip'
    with zipfile.ZipFile(original_zip_file, 'r') as archive:
        archive.extractall(destination_folder)

# Paper Metadata Extraction

### Basic Metadata

In [6]:
from apis.arxiv_tool import ArxivKit
from apis.semanticscholar_tool import SemanticScholarKit

In [7]:
arxiv = ArxivKit()

# arxiv_metadata = []
# for title in titles:
#     candit_arxiv_metadata = arxiv.retrieve_metadata_by_paper(query_term=title, max_cnt=3)
#     arxiv_metadata.append(candit_arxiv_metadata)
#     time.sleep(5)

arxiv_metadata = arxiv.retrieve_metadata_by_paper(query_term=title, max_cnt=3)

2025-02-13 09:08:31,568 - INFO - Requesting page (first: True, try: 0): https://export.arxiv.org/api/query?search_query=Training+Large+Language+Models+to+Reason+in+a+Continuous+Latent+Space&id_list=&sortBy=relevance&sortOrder=descending&start=0&max_results=100
2025-02-13 09:08:35,822 - INFO - Got first page: 100 of 2656582 total results


In [8]:
ss = SemanticScholarKit()

# ss_metadata = []
# for title in titles:
#     candit_ss_metadata = ss.search_paper_by_keywords(query=title, limit=3)
#     ss_metadata.append(candit_ss_metadata)
#     time.sleep(5)
ss_metadata = ss.search_paper_by_keywords(query=title, limit=3)

2025-02-13 09:08:37,611 - INFO - HTTP Request: GET https://api.semanticscholar.org/graph/v1/paper/search?query=Training%20Large%20Language%20Models%20to%20Reason%20in%20a%20Continuous%20Latent%20Space&fields=abstract,authors,citationCount,citationStyles,corpusId,externalIds,fieldsOfStudy,influentialCitationCount,isOpenAccess,journal,openAccessPdf,paperId,publicationDate,publicationTypes,publicationVenue,referenceCount,s2FieldsOfStudy,title,url,venue,year&offset=0&limit=3 "HTTP/1.1 200 OK"


### Reference and Citedby Data

In [9]:
# paper_ss_id = ss_metadata[0][0].get('paperId')
paper_ss_id = ss_metadata[0].get('paperId')
print(paper_ss_id)

673fbdd957cada770d10dffca5e45b53da43a3c6


In [10]:
reference_metadata = ss.get_semanticscholar_references(paper_id=paper_ss_id, limit=100)

2025-02-13 09:08:46,239 - INFO - HTTP Request: GET https://api.semanticscholar.org/graph/v1/paper/673fbdd957cada770d10dffca5e45b53da43a3c6/references?fields=contexts,intents,contextsWithIntent,isInfluential,abstract,authors,citationCount,citationStyles,corpusId,externalIds,fieldsOfStudy,influentialCitationCount,isOpenAccess,journal,openAccessPdf,paperId,publicationDate,publicationTypes,publicationVenue,referenceCount,s2FieldsOfStudy,title,url,venue,year&offset=0&limit=100 "HTTP/1.1 200 OK"


In [11]:
len(reference_metadata)

49

In [12]:
citedby_metadata = ss.get_semanticscholar_citedby(paper_id=paper_ss_id, limit=100)

2025-02-13 09:08:50,728 - INFO - HTTP Request: GET https://api.semanticscholar.org/graph/v1/paper/673fbdd957cada770d10dffca5e45b53da43a3c6/citations?fields=contexts,intents,contextsWithIntent,isInfluential,abstract,authors,citationCount,citationStyles,corpusId,externalIds,fieldsOfStudy,influentialCitationCount,isOpenAccess,journal,openAccessPdf,paperId,publicationDate,publicationTypes,publicationVenue,referenceCount,s2FieldsOfStudy,title,url,venue,year&offset=0&limit=100 "HTTP/1.1 200 OK"


In [13]:
len(citedby_metadata)

9

# Paper PDF Process

## Outline Detection

In [14]:
import fitz

def pdf_outline_detection(pdf_path, excpert_len:Optional[int]=300):
    """Read the outline (bookmarks) of a PDF and attach a text excerpt to each entry.

    Args:
        pdf_path: path of the PDF, opened with `fitz.open`.
        excpert_len: excerpt budget in characters; text blocks stop being added
            once the excerpt has reached this length. `None` means the default 300.

    Returns:
        A list with one dict per outline entry that has a page number, with the
        keys 'level', 'title', 'page', 'position', 'nameddest', 'if_collapse'
        and 'excerpt' (the collected text followed by "...").
    """
    if excpert_len is None:
        excpert_len = 300

    pdf_toc = []
    # The context manager closes the document even if text extraction fails.
    with fitz.open(pdf_path) as doc:
        toc_infos = doc.get_toc(simple=False) or []

        for item in toc_infos:
            lvl = item[0] if len(item) > 0 else None
            title = item[1] if len(item) > 1 else None
            start_page = item[2] if len(item) > 2 else None
            dest = item[3] if len(item) > 3 and item[3] else None
            end_pos = dest.get('to') if dest else None
            nameddest = dest.get('nameddest') if dest else None
            if_collapse = dest.get('collapse', False) if dest else None

            if start_page is None:
                continue

            lines = ""
            # Outline pages are 1-based. Entries whose page is outside the
            # document would otherwise wrap around to a page counted from the end.
            if 1 <= start_page <= len(doc):
                blocks = doc[start_page-1].get_text("blocks")
                for block in blocks:
                    x0, y0, x1, y1, text, _, _ = block
                    if len(lines) >= excpert_len:
                        break
                    # NOTE(review): this compares the block's x0 with the first
                    # coordinate of the destination point; confirm whether the
                    # vertical coordinate was intended here.
                    if end_pos and x0 >= end_pos[0]:
                        lines += text

            pdf_toc.append({
                "level": lvl,
                "title": title,
                "page": start_page,
                "position": end_pos,
                "nameddest": nameddest,
                'if_collapse': if_collapse,
                "excerpt": lines + "..."
            })
    return pdf_toc

In [15]:
pdf_path = "/home/jiezi/Code/Temp/data/2412.06769v2.pdf"
pdf_toc = pdf_outline_detection(pdf_path=pdf_path)

## Miner U PDF Process

In [None]:
from apis.mineru_tool import MinerUKit

# `pdf_pathes` is only present in a commented-out block at the top of the
# notebook, so build it from the single PDF configured in `pdf_path`.
pdf_pathes = [pdf_path]

mineru_api_key = os.getenv('MINERU_API_KEY_1')
mineru = MinerUKit(api_key=mineru_api_key)
upload_res = mineru.batch_process_files(pdf_files=pdf_pathes, if_ocr=False, lang='en')

In [None]:
batch_id = upload_res.json().get('data', {}).get('batch_id')
running_res = mineru.batch_status_check(batch_id=batch_id)

In [None]:
temp_path = "/home/jiezi/Code/Temp/tmp"
os.makedirs(temp_path, exist_ok=True)  # the download target must exist

# Parse the status response once instead of on every access.
running_payload = running_res.json()

if running_payload.get('msg') == 'ok':
    results = (running_payload.get('data') or {}).get('extract_result', [])
    for item in results:
        if item.get('state') != 'done':
            continue

        file_name = item.get('file_name')
        zip_url = item.get('full_zip_url')
        if not file_name or not zip_url:
            print(f"Skipping incomplete result entry: {item}")
            continue

        file_name_nosuffix = file_name.rsplit('.', 1)[0]
        download_file_name = os.path.join(temp_path, file_name_nosuffix+".zip")
        unzip_folder_name = os.path.join(temp_path, file_name_nosuffix)

        download_file(zip_url, download_file_name)
        # download_file reports HTTP errors by printing, so make sure an
        # archive is actually on disk before trying to extract it.
        if os.path.isfile(download_file_name):
            unzip_file(download_file_name, unzip_folder_name)
        else:
            print(f"No archive to extract for: {file_name}")


# Paper Syntax Conversion

In [16]:
import os
from pathlib import Path

temp_path = "/home/jiezi/Code/Temp/tmp"
file_name_nosuffix = "2412.06769v2"
file_path = os.path.join(temp_path, file_name_nosuffix)

# Tidy the extracted folder: delete the "*_origin.pdf" copy and give the
# "*_content_list.json" file a fixed name.
for file in Path(file_path).glob('*'):
    file_nm = file.name
    is_origin_pdf = "_origin.pdf" in file_nm
    is_content_list = "_content_list.json" in file_nm
    if is_origin_pdf:
        os.remove(file)
    elif is_content_list:
        os.rename(file, os.path.join(file_path, "content_list.json"))

# Paths of the files read by the following cells.
md_file = os.path.join(file_path, "full.md")
content_json_file = os.path.join(file_path, "content_list.json")
layout_json_file = os.path.join(file_path, "layout.json")

In [17]:
import json
with open("/home/jiezi/Code/Temp/tmp/2412.06769v2/content_list.json") as json_data:
    content_json = json.load(json_data)

In [18]:
md_file = "/home/jiezi/Code/Temp/tmp/2412.06769v2/full.md"
with open(md_file, 'r', encoding='utf-8') as f:
    markdown_content = f.read()

## Markdown Syntax Conversion

Convert pictures from HTML syntax to Markdown syntax

Convert tables from Markdown syntax to HTML form

In [19]:
import re

def markdown_table_to_html(markdown_text):
    """
    Replace every Markdown pipe table in a Markdown text with an HTML table.

    Consecutive lines whose stripped form starts with '|' are treated as one
    table and handed to `_convert_table_lines_to_html`; all other lines are
    passed through unchanged.

    Args:
        markdown_text: Markdown text that may contain pipe tables.

    Returns:
        The text with each table block replaced by its HTML rendering.
    """
    converted = []
    pending_rows = []  # rows of the table currently being collected

    for raw_line in markdown_text.splitlines():
        if raw_line.strip().startswith('|'):
            pending_rows.append(raw_line)
            continue

        # A non-table line ends the table collected so far.
        if pending_rows:
            converted.append(_convert_table_lines_to_html(pending_rows))
            pending_rows = []
        converted.append(raw_line)

    # A table may run to the very end of the text.
    if pending_rows:
        converted.append(_convert_table_lines_to_html(pending_rows))

    return "\n".join(converted)


def _convert_table_lines_to_html(table_lines):
    """
    将 Markdown 表格行转换为 HTML 表格。

    Args:
        table_lines: Markdown 表格行的列表。

    Returns:
        HTML 表格字符串。
    """
    html_lines = ["<table>", "  <thead>", "    <tr>"]
    header_cells = [cell.strip() for cell in table_lines[0].strip('|').split('|')]
    for header in header_cells:
        html_lines.append(f"      <th>{header}</th>")
    html_lines.append("    </tr>")
    html_lines.append("  </thead>")
    html_lines.append("  <tbody>")

    if len(table_lines) > 1 and re.match(r'^\|[-:| ]+\|[-:| ]*$', table_lines[1].strip()):
        # 存在分隔行，跳过分隔行，从第三行开始是数据行
        data_start_index = 2
    else:
        data_start_index = 1 # 没有分隔行，从第二行开始是数据行

    for i in range(data_start_index, len(table_lines)):
        html_lines.append("    <tr>")
        data_cells = [cell.strip() for cell in table_lines[i].strip('|').split('|')]
        for cell in data_cells:
            html_lines.append(f"      <td>{cell}</td>")
        html_lines.append("    </tr>")

    html_lines.append("  </tbody>")
    html_lines.append("</table>")
    return "\n".join(html_lines)

In [20]:
markdown_content = markdown_table_to_html(markdown_content)

## Align Markdown Titles

Align markdown title with pdf ToC

In [21]:
import re

def restore_md_toc(md_content, pdf_toc):
    """
    Align the headings of a markdown text with the PDF table of contents.

    For every markdown heading line:
      * if a ToC title occurs in the line, the heading is rewritten as
        '#' * level + ' ' + ToC title (first matching ToC entry wins);
      * otherwise, headings mentioning 'Acknowledgement', 'Reference' or
        'Appendix' are kept as they are;
      * any other heading is placed one level below the previous heading.
    Blank lines are dropped from the returned text.

    Args:
        md_content: markdown text.
        pdf_toc: ToC entries as produced by pdf_outline_detection (dicts with
            at least 'title' and 'level'). May be empty.

    Returns:
        A tuple (markdown_text, md_titles): the revised markdown and the list of
        heading lines in document order. When `pdf_toc` is empty there is
        nothing to align against, so the markdown is returned unchanged together
        with its heading lines.
    """
    title_pattern = r"^#{1,}\s*.*$"  # pattern of a markdown heading line
    heading_marks = r"^#{1,}"

    if not pdf_toc:
        md_titles = [
            line for line in md_content.splitlines()
            if line.strip() and re.search(title_pattern, line)
        ]
        return md_content, md_titles

    modified_lines = []  # lines of the revised markdown
    md_titles = []       # heading lines after alignment

    for line in md_content.splitlines():
        if not line.strip():
            continue  # blank lines are not carried over

        if not re.search(title_pattern, line):
            modified_lines.append(line)
            continue

        # 1) heading found in the ToC: take level and title from the ToC
        sec_title = None
        for entry in pdf_toc:
            toc_title = entry['title']
            if toc_title in line:
                sec_title = "#" * int(entry['level']) + " " + toc_title + "  "
                break

        if sec_title is None:
            if any(keyword in line for keyword in ('Acknowledgement', 'Reference', 'Appendix')):
                # 2) back-matter headings keep their original form
                sec_title = line
            elif md_titles:
                # 3) unknown heading: nest it one level below the previous heading
                previous_marks = re.match(heading_marks, md_titles[-1])
                if previous_marks:
                    sec_title = re.sub(heading_marks, previous_marks.group(0) + "#", line)
                else:
                    sec_title = "#" + line
            else:
                sec_title = line

        modified_lines.append(sec_title)
        md_titles.append(sec_title)

    return "\n".join(modified_lines), md_titles

In [22]:
md_content_rvsd, md_titles = restore_md_toc(markdown_content, pdf_toc)

## Process Json Content (Charts, Tables, and Equations)

Get id, title, desc, etc. for charts, tables, and equations in content json.

In [None]:
def get_first_lines(text, sentence_length):
    """Return the leading sentences of `text`, stopping once enough characters are collected.

    The text is split at whitespace that follows '.', '?', ';' or '!' (with two
    look-behind guards against splitting inside abbreviation-like tokens).
    Sentences are appended, separated by single spaces, until the collected
    string is at least `sentence_length` characters long.

    Returns an empty string for empty or missing text.
    """
    if not text:
        return ""

    sentence_boundary = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|;|!)\s'

    picked = []
    for piece in re.split(sentence_boundary, text):
        piece = piece.strip()
        if not piece:
            continue
        picked.append(piece)
        if len(" ".join(picked)) >= sentence_length:
            break

    return " ".join(picked)

In [7]:
import re 

def _find_label_ids(desc, ptrn):
    """Return every match of the label pattern `ptrn` in `desc` (e.g. "Figure 1"), in order."""
    return [match.group(0) for match in re.finditer(ptrn, desc or "", re.IGNORECASE)]


def _join_caption_footnote(caption, footnote):
    """Join caption lines and footnote lines into one description; missing parts count as empty."""
    return "\n".join(caption or []) + "\n" + "\n".join(footnote or [])


def process_content_json(content_json):
    """Assign ids, titles and descriptions to the images, tables and equations of a content list.

    Each entry of `content_json` is a dict with a 'type' key. Entries of type
    'image', 'table' and 'equation' are annotated IN PLACE with:
      * 'id'          - first label found in the caption/footnote (or equation
                        text), else a generated "<Kind>_Number_<n>";
      * 'related_ids' - any further labels found;
      * 'title'       - for images/tables the leading sentence(s) of the
                        description, for equations the id;
      * 'description' - caption + newline + footnote (equations: the id).
    Tables and equations that carry an 'img_path' additionally contribute a
    separate image entry, so their picture is listed with the other images.

    Args:
        content_json: list of content dicts.

    Returns:
        (img_lst, tbl_lst, formula_lst)
    """
    # NOTE(review): the trailing `[IVXLCDM]+|[a-zA-Z]+` alternatives are matched
    # case-insensitively, so a plain word after the keyword is also accepted as
    # a "number"; the patterns are kept exactly as they were.
    img_ptrn = r"(pic|picture|img|image|chart|figure|fig|table|tbl)\s*([0-9]+(?:\.[0-9]+)?|[0-9]+|[IVXLCDM]+|[a-zA-Z]+)"
    tbl_ptrn = r"(tbl|table|chart|figure|fig)\s*([0-9]+(?:\.[0-9]+)?|[0-9]+|[IVXLCDM]+|[a-zA-Z]+)"
    tbl_img_ptrn = r"(table|tbl|pic|picture|img|image|chart|figure|fig)\s*([0-9]+(?:\.[0-9]+)?|[0-9]+|[IVXLCDM]+|[a-zA-Z]+)"
    eq_ptrn = r"(formula|equation|notation|syntax)\s*([0-9]+(?:\.[0-9]+)?|[0-9]+|[IVXLCDM]+|[a-zA-Z]+)"

    img_lst, tbl_lst, formula_lst = [], [], []
    # Counters for generated ids: i = images of any origin, j = tables, k = equations.
    i, j, k = 1, 1, 1

    for x in content_json:
        x_type = x.get('type')

        if x_type == 'image':
            desc = _join_caption_footnote(x.get('img_caption'), x.get('img_footnote'))
            img_ids = _find_label_ids(desc, img_ptrn)
            if not img_ids:
                img_ids = [f"Image_Number_{i}"]
                i += 1
            x['id'] = img_ids[0]
            x['related_ids'] = img_ids[1:]
            x['title'] = get_first_lines(desc, 10)
            x['description'] = desc
            img_lst.append(x)

        elif x_type == 'table':
            desc = _join_caption_footnote(x.get('table_caption'), x.get('table_footnote'))
            tbl_ids = _find_label_ids(desc, tbl_ptrn)
            if not tbl_ids:
                tbl_ids = [f"Table_Number_{j}"]
                j += 1
            x['id'] = tbl_ids[0]
            x['related_ids'] = tbl_ids[1:]
            x['title'] = get_first_lines(desc, 10)
            x['description'] = desc
            tbl_lst.append(x)

            # A table that also has a picture: register the picture as an image.
            if x.get('img_path') is not None:
                item = {
                    'type': 'image',
                    'img_path': x.get('img_path'),
                    'img_caption': x.get('table_caption') or [],
                    # Stored under the key the image code reads; the original
                    # key is kept as well for existing consumers.
                    'img_footnote': x.get('table_footnote') or [],
                    'table_footnote': x.get('table_footnote'),
                    'page_idx': x.get('page_idx'),
                }
                desc = _join_caption_footnote(item['img_caption'], item['img_footnote'])
                img_ids = _find_label_ids(desc, tbl_img_ptrn)
                if not img_ids:
                    img_ids = [f"Table_Image_Number_{i}"]
                    i += 1
                item['id'] = img_ids[0]
                item['related_ids'] = img_ids[1:]
                item['title'] = get_first_lines(desc, 10)
                item['description'] = desc
                img_lst.append(item)

        elif x_type == 'equation':
            desc = x.get('text') or ""
            equation_ids = _find_label_ids(desc, eq_ptrn)
            if not equation_ids:
                equation_ids = [f"Equation_Number_{k}"]
                k += 1
            x['id'] = equation_ids[0]
            x['related_ids'] = equation_ids[1:]
            x['title'] = equation_ids[0]
            x['description'] = equation_ids[0]
            formula_lst.append(x)

            # An equation that also has a picture: register the picture as a
            # separate image entry and leave the equation entry untouched.
            if x.get('img_path') is not None:
                item = {
                    'type': 'image',
                    'img_path': x.get('img_path'),
                    'img_caption': x.get('img_caption') or [],
                    'img_footnote': x.get('img_footnote') or [],
                    'text': x.get('text'),
                    'page_idx': x.get('page_idx'),
                }
                img_ids = _find_label_ids(desc, eq_ptrn)
                if not img_ids:
                    img_ids = [f"Equation_Image_Number_{i}"]
                    i += 1
                item['id'] = img_ids[0]
                item['related_ids'] = img_ids[1:]
                item['title'] = equation_ids[0]
                item['description'] = equation_ids[0]
                img_lst.append(item)

    return img_lst, tbl_lst, formula_lst

In [8]:
img_lst, tbl_lst, formula_lst = process_content_json(content_json)

In [28]:
import os
from google import genai
from google.genai import types

def llm_gen_w_images(api_key, model_name, qa_prompt, pil_images, sys_prompt=None, temperature=0.3):
    """Send a text prompt together with images to a model and return the text reply.

    Args:
        api_key: key passed to `genai.Client`.
        model_name: model identifier, e.g. "gemini-2.0-flash-exp".
        qa_prompt: the prompt text; it is sent as the first content part.
        pil_images: list of images appended after the prompt, e.g.
            `PIL.Image.open('/path/to/image.png')`.
        sys_prompt: optional system instruction.
        temperature: sampling temperature.

    Returns:
        The `text` attribute of the generation response.
    """
    llm_client = genai.Client(api_key=api_key)

    generation_config = types.GenerateContentConfig(
        system_instruction=sys_prompt,
        temperature=temperature,
    )

    # Prompt first, then the images in the order given.
    request_parts = [qa_prompt] + pil_images
    reply = llm_client.models.generate_content(
        model=model_name,
        contents=request_parts,
        config=generation_config,
    )
    return reply.text

In [29]:
import PIL.Image

example = {"img_id":"[original image id provided]", "img_nm":"[original image name provided as attached]", "img_title": "[suggested image title]", "img_desc":"[a detailed description of the image]"}
img_analysis_prompt = """You are provided with multiple images. Please analyze the images, try to extract their image title and give a description of the images.
Output in json format like:
```json
{example_json}
```

## INPUT
Here are images ids and names. 
{img_info}

## OUTPUT
Now get started!

"""

import os

tmp_path = "/home/jiezi/Code/Temp/tmp/2412.06769v2"

# Collect the images to send and one "id / file name" line per image so the
# model can refer back to them.
img_info = ""
pil_images = []
for img in img_lst:
    img_rel_path = img.get('img_path')
    if not img_rel_path:
        continue  # nothing to open for entries without an image file
    img_url = os.path.join(tmp_path, img_rel_path)
    pil_images.append(PIL.Image.open(img_url))
    img_info += f"img_id: {img.get('id')}   img_nm: {os.path.basename(img_rel_path)}\n"

qa_prompt = img_analysis_prompt.format(example_json=json.dumps(example, ensure_ascii=False), img_info=img_info)


In [30]:
api_key = os.getenv('GEMINI_API_KEY_1')
# Single source of truth for the sampling temperature. 0.6 is the value the
# call was actually using; the previous `temperature = 0.7` was never passed on.
temperature = 0.6
res = llm_gen_w_images(
    api_key=api_key, model_name='gemini-2.0-flash-thinking-exp',
    qa_prompt=qa_prompt, pil_images=pil_images, sys_prompt=None, temperature=temperature)

In [31]:
print(qa_prompt)

You are provided with multiple images. Please analyze the images, try to extract their image title and give a desctiption of the images.
Output in json format like:
```json
{"img_id": "[original imgage id provided]", "img_nm": "[original image name provided as attached]", "img_title": "[suggested image title]", "img_desc": "[a detailed description of the image]"}
```

## INPUTf
Here are images ids and names. 
img_id: Figure 1   img_nm: 3b0f18697b44445c12ab2b41e0ff7a5fa498867fbcd33644600f98041b2a9f6a.jpg
img_id: Figure 2   img_nm: e72083f8a262261062393a5c15691de83822ae7b995a8d74e27b22ad37bcb993.jpg
img_id: Table 1   img_nm: c81e5cdf7aa362b6915254737e207d052d121cdf144405aa782f3632384b8feb.jpg
img_id: Figure 3   img_nm: 4692454aa0746096ebd571ecdc3a93b136498d5be71a2a1a90c85d5df37c9864.jpg
img_id: Figure 4   img_nm: 1dffc8860659d93110f6f8489bbebab79afecb8b091a9ccd9d75f3d097fa6ccd.jpg
img_id: Figure 5   img_nm: cc15459750485d6d98eafccaf8e50ced906abc8a704d695310089ac62cca1f28.jpg
img_id: Figu

# Paper Context Modification

## Modify Image Information

Modify markdown image text to better align with standard syntax.

In [26]:
import copy

def modify_image_info(md_text, img_lst):
    """Rewrite markdown image references using the metadata in `img_lst`.

    Every `![alt](link "title")` / `![alt](link 'title')` reference whose link
    equals an item's 'img_path' is rewritten as `![alt](link 'title')` where
      * alt   = the existing alt text, or the item's 'description' if empty;
      * title = the existing title, or the item's 'title' if absent, prefixed
                with "<id>: " unless it already contains the id.
    Captions and footnotes longer than 20 characters that duplicate this
    information are removed from the directly neighbouring lines.

    Args:
        md_text: markdown text.
        img_lst: image items (dicts with 'img_path', 'id', 'title',
            'description', 'img_caption', 'img_footnote'). Not modified.

    Returns:
        (new_markdown, img_lst_rvsd) where img_lst_rvsd is a deep copy of
        `img_lst` whose matched items gained 'org_md_ref' (original reference
        markup) and 'mod_md_ref' (rewritten reference markup).
    """
    img_ptrn = re.compile(
        r'!\s*\[\s*(?P<alt>.*?)\s*\]'  # ![alt] part
        r'\s*\(\s*(?P<link>.*?)\s*'   # (link part
        r'(?:'                         # optional title
        r'(?P<quotetitle>"(?P<title_double_quote>.*?)"|'  # "title"
        r"'(?P<title_single_quote>.*?)')"                 # 'title'
        r')?'
        r'\s*\)'                      # closing parenthesis
    )
    lines = md_text.splitlines()
    img_lst_rvsd = copy.deepcopy(img_lst)

    for idx in range(len(lines)):
        line = lines[idx]
        if not line.strip():
            continue

        # Go right-to-left so the spans of the remaining matches stay valid
        # while the line is being rewritten.
        for match in reversed(list(img_ptrn.finditer(line))):
            alt_text = match.group('alt').strip()
            image_url = match.group('link')
            # The rewritten markup uses single quotes, so both quoting styles
            # have to be recognised as an existing title.
            raw_title = match.group('title_double_quote') or match.group('title_single_quote')
            title = raw_title.strip() if raw_title else None

            for item in img_lst_rvsd:
                if item.get('img_path') != image_url:
                    continue

                alt_text = alt_text or item.get('description') or ""
                title = title or item.get('title') or ""
                item_id = item.get('id') or ""
                if item_id and item_id.lower() not in title.lower():
                    title = f"{item_id}: {title}"
                img_md = f"![{alt_text.strip()}]({image_url.strip()} '{title.strip()}')"

                start, end = match.span()
                if item.get('org_md_ref') is None:
                    item['org_md_ref'] = line[start:end]  # original reference markup
                if item.get('mod_md_ref') is None:
                    item['mod_md_ref'] = img_md  # rewritten reference markup

                # Apply the replacement to the running copy of the line so
                # several images on one line are all rewritten.
                line = line[:start] + img_md + line[end:]
                lines[idx] = line

                caption = "\n".join(item.get('img_caption') or []).strip()
                footnote = "\n".join(item.get('img_footnote') or []).strip()

                # alt text and title now carry this information; drop the
                # duplicates from the neighbouring lines.
                for duplicate in (caption, footnote):
                    if duplicate and len(duplicate) > 20 and duplicate != title:
                        if idx > 0 and duplicate in lines[idx-1]:
                            lines[idx-1] = lines[idx-1].replace(duplicate, "")
                        if idx < len(lines) - 1 and duplicate in lines[idx+1]:
                            lines[idx+1] = lines[idx+1].replace(duplicate, "")
                break  # first matching item wins
    return "\n".join(lines), img_lst_rvsd

In [27]:
md_content_rvsd, img_lst_rvsd = modify_image_info(md_content_rvsd, img_lst)

## Modify Tables Information

In [28]:
import copy
from bs4 import BeautifulSoup

def modify_tables_info(md_text, tbl_lst):
    """Add a <caption> to every HTML table of the markdown that matches an item of `tbl_lst`.

    A markdown line matches an item when the <table> parsed from the line
    equals the <table> parsed from the item's 'table_body'. The matching line
    is replaced by the table with its caption set to the item's 'description',
    and captions/footnotes longer than 20 characters that duplicate this
    information are removed from the directly neighbouring lines.

    Args:
        md_text: markdown text with one HTML table per line.
        tbl_lst: table items (dicts with 'table_body', 'description', 'title',
            'table_caption', 'table_footnote'). Not modified.

    Returns:
        (new_markdown, tbl_lst_rvsd) where tbl_lst_rvsd is a deep copy of
        `tbl_lst` whose matched items gained 'org_md_ref' (original table
        markup) and 'mod_md_ref' (table markup with caption).
    """
    tbl_lst_rvsd = copy.deepcopy(tbl_lst)
    lines = md_text.splitlines()

    # Parse each item's table body once, not once per markdown line.
    candidates = []
    for item in tbl_lst_rvsd:
        body_html = item.get('table_body')
        tbl_body = BeautifulSoup(body_html, 'html.parser').find('table') if body_html else None
        if tbl_body is not None:
            candidates.append((item, tbl_body))

    for idx, line in enumerate(lines):
        # Only lines containing table markup can yield a <table>; skipping the
        # rest avoids parsing plain prose (and the parser warning about it).
        if '<table' not in line.lower():
            continue

        soup = BeautifulSoup(line, 'html.parser')
        table = soup.find('table')
        if table is None:
            continue

        for item, tbl_body in candidates:
            if table != tbl_body:
                continue

            tbl_desc = item.get('description')
            tbl_title = item.get('title')
            tbl_caption = "\n".join(item.get('table_caption') or []).strip()
            tbl_footnote = "\n".join(item.get('table_footnote') or []).strip()

            md_caption = table.find('caption')
            if md_caption:
                md_caption.string = tbl_desc  # overwrite the existing caption text
            else:
                # No caption yet: create one as the first child of the table.
                new_caption_tag = soup.new_tag('caption')
                new_caption_tag.string = tbl_desc
                table.insert(0, new_caption_tag)

            lines[idx] = f"<html><body>{table}</body></html>  "

            if item.get('org_md_ref') is None:
                item['org_md_ref'] = f"<html><body>{tbl_body}</body></html>  "  # original table

            if item.get('mod_md_ref') is None:
                item['mod_md_ref'] = f"<html><body>{table}</body></html>  "  # table with caption

            # The caption now carries this information; drop the duplicates
            # from the neighbouring lines.
            for duplicate in (tbl_caption, tbl_footnote):
                if duplicate and len(duplicate) > 20 and duplicate != tbl_title:
                    if idx > 0 and duplicate in lines[idx-1]:
                        lines[idx-1] = lines[idx-1].replace(duplicate, "")
                    if idx < len(lines) - 1 and duplicate in lines[idx+1]:
                        lines[idx+1] = lines[idx+1].replace(duplicate, "")

            break  # first matching item wins

    return "\n".join(lines), tbl_lst_rvsd

In [29]:
md_content_rvsd, tbl_lst_rvsd = modify_tables_info(md_content_rvsd, tbl_lst)

  soup = BeautifulSoup(line, 'html.parser')


## Modify Equations Information

In [103]:
# # 本处非必要
# def modify_formula_info(md_text, formula_lst):
#     """update table information with alternative text, image title, etc."""
#     md_content_rvsd = md_text
#     for formula in formula_lst:
#         text = formula.get('text', '')
#         id = formula.get('id', '')
#         format = formula.get('text_format', '')
#         md_content_rvsd = md_content_rvsd.replace(text, f"``'{format} {id}\n{text}\n```")

#         # 计算替换的起始和结束位置
#         if item.get('org_md_ref') is None:
#             item['org_md_ref'] = text # original formula

#         if item.get('mod_md_ref') is None:
#             item['mod_md_ref'] = f"``'{format} {id}\n{text}\n```"  # formula with format information
#     return md_content_rvsd


## Modify Reference Information

Sometimes the author list in a reference might be too long.
Cut it down and keep only the top 5 authors.

Add original text (citation) information

Deal with no reference cases

# Paper Segmentation Process

Process markdown into segments:
- cut markdown content into segments
- restore the positions and information of images, tables, and equations in each segment

### Segmentation

In [50]:
import re 
def md_seg_by_title(md_content, level):
    """Split markdown into segments at headings of exactly `level` '#' characters.

    Blank lines are skipped. Text before the first matching heading forms a
    segment with an empty title. A heading that is not followed by any text
    before the next heading produces no segment.

    Args:
        md_content: markdown text.
        level: heading level to split on (1 for '#', 2 for '##', ...).

    Returns:
        A list of dicts with the keys
          'level' - the requested level,
          'num'   - 1-based running number of the segment,
          'title' - heading text without the '#' marks,
          'text'  - the segment body (heading line excluded),
          'lines' - [{'id': index of the line in md_content, 'line': text}, ...].
    """
    title_pattern = re.compile(rf"^#{{{level}}}\s+(.+)$", re.MULTILINE)

    segments = []
    lines = []
    current_section = ""
    current_title = ""
    num = 1  # running segment number

    def close_segment():
        segments.append({
            'level': level,
            'num': num,
            'title': current_title,
            'text': current_section.strip(),
            'lines': lines
        })

    for idx, line in enumerate(md_content.splitlines()):
        if not line.strip():
            continue

        match = title_pattern.match(line)
        if match is None:
            current_section += line + "\n"
            lines.append({'id': idx, 'line': line})
            continue

        # A new heading closes the segment collected so far (if it has text).
        if current_section:
            close_segment()
            num += 1
        current_title = match.group(1).strip()
        current_section = ""
        lines = []

    if current_section:  # the last segment has no closing heading
        close_segment()

    return segments

In [51]:
lvl_1_segments = md_seg_by_title(md_content_rvsd, 1)

In [None]:


# Titles of the outline entries flagged 'if_collapse'. Only segments whose
# title contains one of them are split further.
# NOTE(review): this uses the outline's collapse flag as the signal that an
# entry has sub-sections - confirm that expanded entries should be skipped.
collapsed_toc_titles = [
    toc.get('title') for toc in pdf_toc
    if toc.get('title') and toc.get('if_collapse') == True
]
toc_levels = {toc.get('level') for toc in pdf_toc}


def _matches_collapsed_toc(title):
    """True when `title` contains the title of a collapsed outline entry."""
    return any(toc_title in (title or "") for toc_title in collapsed_toc_titles)


# level 1 segment -> level 2 sub-segments
if 2 in toc_levels:
    for seg in lvl_1_segments:
        if _matches_collapsed_toc(seg.get('title')):
            seg['sub_segmentations'] = md_seg_by_title(seg.get('text'), 2)

# level 2 segment -> level 3 sub-segments; walk the sub-segments of every
# level 1 segment and split each one on its own title and text
if 3 in toc_levels:
    for seg in lvl_1_segments:
        for sub_seg in seg.get('sub_segmentations', []):
            if _matches_collapsed_toc(sub_seg.get('title')):
                sub_seg['sub_segmentations'] = md_seg_by_title(sub_seg.get('text'), 3)

### Segmentation to Blocks

Further break segments into blocks for better embedding and more comprehensible topics.

### Restore Information

- restore the positions of images, tables and references
- specify external source information
  - external sources like images help the LLM with further analysis
  - external sources like references help build a knowledge graph
  - 

In [None]:
def _restore_media_refs(lines, items):
    """Reconcile the markup of images or tables with the lines that cite them.

    `lines` is modified in place. For each item ('id', 'title', optional
    'mod_md_ref'):
      * if some line mentions the id or title outside the item's own markup,
        the item is reported, and its markup is inserted after the first such
        line when it is not in the text yet;
      * if no line mentions it, its markup (when present) is blanked out.

    Returns the list of cited items.
    """
    cited_items = []
    for item in items:
        md_ref = (item.get('mod_md_ref') or '').strip()
        labels = [label for label in (item.get('id'), item.get('title')) if label]
        if not labels:
            continue

        # The markup itself contains id/title, so it must not count as a citation.
        citing_idx = None
        for idx, line in enumerate(lines):
            prose = line.replace(md_ref, "") if md_ref else line
            if prose.strip() and any(label in prose for label in labels):
                citing_idx = idx
                break

        ref_present = bool(md_ref) and md_ref in "\n".join(lines)

        if citing_idx is not None:
            # cited in the segment but markup missing: place it after the citation
            if md_ref and not ref_present:
                lines.insert(citing_idx + 1, md_ref)
            if item not in cited_items:
                cited_items.append(item)
        elif ref_present:
            # markup present but never cited in the segment: remove it
            for idx, line in enumerate(lines):
                if md_ref in line:
                    lines[idx] = line.replace(md_ref, "  ")
    return cited_items


def _collect_cited_refs(lines, ref_lst):
    """Return the 'citedPaper' of every reference with a citation context found in `lines`."""
    seg_refs = []
    for ref in ref_lst:
        contexts = [ctx.strip() for ctx in (ref.get('contexts') or []) if ctx and ctx.strip()]
        if any(ctx in line for ctx in contexts for line in lines):
            # keep only the paper information (not isInfluential, intents, ...)
            cited_paper = ref.get('citedPaper', {})
            if cited_paper not in seg_refs:
                seg_refs.append(cited_paper)
    return seg_refs


def restore_seg_image(md_text, img_lst, tbl_lst, ref_lst):
    """Restore images, tables and references for one markdown segment.

    Args:
        md_text: markdown text of the segment.
        img_lst: image items as returned by modify_image_info.
        tbl_lst: table items as returned by modify_tables_info.
        ref_lst: reference entries with 'contexts' (citation snippets) and
            'citedPaper'.

    Returns:
        (text, seg_images, seg_tbls, seg_refs): the segment text with
        image/table markup inserted where the item is cited and removed where
        it is not, the images and tables cited in the segment, and the cited
        papers (each at most once, in `ref_lst` order).
    """
    lines = md_text.splitlines()

    seg_images = _restore_media_refs(lines, img_lst)
    seg_tbls = _restore_media_refs(lines, tbl_lst)
    seg_refs = _collect_cited_refs(lines, ref_lst)

    # to-do: append references here

    return "\n".join(lines), seg_images, seg_tbls, seg_refs


In [60]:
for seg in lvl_1_segments:
    md_text = seg.get('text')
    md_text_new, seg_images, seg_tbls, seg_refs = restore_seg_image(md_text, img_lst_rvsd, tbl_lst_rvsd, reference_metadata)
    seg['refined_text'] = md_text_new
    seg['images'] = seg_images
    seg['tables'] = seg_tbls
    seg['references'] = seg_refs

In [64]:
tmp_wip_json_path = "pdf_processed_wip_20250212.json"

with open(tmp_wip_json_path, "w") as file:
    json.dump(lvl_1_segments, file, indent=4)

# Finalization

Finalize the processed documents.

save final markdown and processed json list