In [17]:
"""Arxiv client tools."""

from __future__ import annotations

import io
import re
import ssl
import tempfile
from typing import Iterable, cast

import fitz
from arxiv import Client as _ArxivClient
from arxiv import Result as ArxivResult
from arxiv import Search as ArxivSearch
from more_itertools import flatten, unique_everseen
from pydantic import AnyHttpUrl, BaseModel


class Paper(BaseModel):
    # One paper: its title, the plain text extracted from its PDF, its URL,
    # and optionally the papers it cites (each of them a ``Paper`` as well).
    title: str
    text: str
    url: AnyHttpUrl
    references: list[Paper] | None = None

    def flatten(self) -> list[Paper]:
        """Return this paper followed by all of its references, depth-first.

        The paper itself comes first; every reference is expanded recursively
        in the order it appears in ``references``.
        """
        collected: list[Paper] = [self]
        # ``None`` and an empty list both mean "no references".
        for child in self.references or ():
            collected += child.flatten()
        return collected


class ArxivClient:
    """Search arXiv and fetch, parse or download the matching papers.

    Wraps an ``arxiv`` client and adds helpers that turn search results into
    ``Paper`` objects (PDF text extracted with PyMuPDF) or save the PDFs to
    disk.
    """

    # Four digits, a dot, four or five digits, optionally followed by "vN".
    _ID_PATTERN = re.compile(r"(\d{4}\.\d{4,5})(v\d+)?")
    # Absolute arXiv abstract URL; the version suffix is kept in the match.
    _REFERENCE_URL_PATTERN = re.compile(r"https?://arxiv\.org/abs/\d{4}\.\d{4,5}(?:v\d+)?")

    def __init__(self) -> None:
        self.client = _ArxivClient()
        self.ensure_ssl_verified()

    @staticmethod
    def ensure_ssl_verified() -> None:
        """Install the unverified HTTPS context as the process-wide default.

        SECURITY NOTE: despite its name, this *disables* TLS certificate
        verification for every caller of ``ssl._create_default_https_context``
        in the process, not only for arXiv downloads. It is kept as-is because
        existing environments may rely on it; prefer fixing the local
        certificate store and removing this call.
        """
        ssl._create_default_https_context = ssl._create_unverified_context

    @staticmethod
    def _as_list(values: str | Iterable[str]) -> list[str]:
        """Return *values* as a list, treating a bare string as one item."""
        # Iterating a ``str`` would yield single characters, so wrap it.
        return [values] if isinstance(values, str) else list(values)

    @staticmethod
    def _safe_filename(title: str) -> str:
        """Build a ``.pdf`` file name from *title* (non-word runs become ``_``)."""
        return re.sub(r"\W+", "_", title) + ".pdf"

    def _extract_ids(self, urls: str | Iterable[str]) -> list[str]:
        """Return the arXiv IDs found in *urls*, skipping URLs without one."""
        candidates = (self.extract_id(url) for url in self._as_list(urls))
        return [paper_id for paper_id in candidates if paper_id]

    @staticmethod
    def _read_pdf_text(paper: ArxivResult) -> str:
        """Download *paper*'s PDF to a temporary directory and return its text."""
        with tempfile.TemporaryDirectory() as temp_dir:
            pdf_path = paper.download_pdf(dirpath=temp_dir)
            # Close the document before the temporary directory is removed.
            with cast(fitz.Document, fitz.open(pdf_path)) as doc:
                return "".join(page.get_text() for page in doc)  # type: ignore

    def _to_paper(self, paper: ArxivResult) -> Paper:
        """Convert a search result into a ``Paper`` with its extracted text."""
        text = self._read_pdf_text(paper)
        return Paper(title=paper.title, text=text, url=paper.entry_id)  # type: ignore

    def _attach_references(self, papers: list[Paper]) -> list[Paper]:
        """Fill ``references`` of each paper from arXiv URLs found in its text."""
        for paper in papers:
            reference_urls = self.parse_references(paper.text)
            if reference_urls:
                paper.references = self.fetch_papers_by_url(reference_urls)
        return papers

    def search_by_url(self, id_list: list[str]) -> Iterable[ArxivResult]:
        """Yield the search results for the given arXiv IDs.

        Note that, despite the name, *id_list* holds IDs, not URLs. An empty
        list yields nothing and issues no request.
        """
        if not id_list:
            return
        search = ArxivSearch(id_list=id_list)
        yield from self.client.results(search=search)

    def search_by_query(self, queries: Iterable[str]) -> Iterable[ArxivResult]:
        """Yield at most one search result (``max_results=1``) per query."""
        for query in queries:
            search = ArxivSearch(query=query, max_results=1)
            yield from self.client.results(search=search)

    @staticmethod
    def extract_id(url: str) -> str | None:
        """Return the arXiv ID in *url* without its version, or ``None``."""
        match = ArxivClient._ID_PATTERN.search(url)
        return match.group(1) if match else None

    @staticmethod
    def parse_references(text: str) -> list[str]:
        """Return every ``arxiv.org/abs`` URL in *text*, in order of appearance.

        Duplicates are kept and a trailing version suffix stays in the URL.
        """
        return ArxivClient._REFERENCE_URL_PATTERN.findall(text)

    def fetch_papers_by_url(self, urls: Iterable[str]) -> list[Paper]:
        """Fetch the papers behind *urls*; URLs without an arXiv ID are skipped.

        A single URL string is accepted as well. Returns an empty list when no
        URL contains an ID.
        """
        id_list = self._extract_ids(urls)
        return [self._to_paper(result) for result in self.search_by_url(id_list)]

    def fetch_papers_with_references_by_url(self, urls: Iterable[str]) -> list[Paper]:
        """Like ``fetch_papers_by_url``, plus one level of referenced papers."""
        return self._attach_references(self.fetch_papers_by_url(urls))

    def fetch_papers_by_query(self, queries: Iterable[str]) -> list[Paper]:
        """Fetch papers whose title equals a query, ignoring case.

        Each query is searched separately; a result is kept only when its
        title matches the query exactly (case-insensitively). A single query
        string is accepted as well.
        """
        papers: list[Paper] = []
        for requested_query in self._as_list(queries):
            wanted_title = requested_query.lower()
            for result in self.search_by_query([requested_query]):
                if result.title.lower() == wanted_title:
                    papers.append(self._to_paper(result))
        return papers

    def fetch_papers_with_references_by_query(self, queries: Iterable[str]) -> list[Paper]:
        """Like ``fetch_papers_by_query``, plus one level of referenced papers."""
        return self._attach_references(self.fetch_papers_by_query(queries))

    def download_papers_by_url(self, urls: str | Iterable[str], save_dir: str) -> None:
        """Save the PDFs behind *urls* (one URL or several) into *save_dir*.

        File names are derived from the paper titles.
        """
        for paper in self.search_by_url(self._extract_ids(urls)):
            paper.download_pdf(dirpath=save_dir, filename=self._safe_filename(paper.title))

    def download_papers_by_query(self, queries: str | Iterable[str], save_dir: str) -> None:
        """Save PDFs of papers whose title equals a query (ignoring case).

        Accepts one query string or several; file names are derived from the
        paper titles.
        """
        for requested_query in self._as_list(queries):
            wanted_title = requested_query.lower()
            for paper in self.search_by_query([requested_query]):
                if paper.title.lower() != wanted_title:
                    continue
                paper.download_pdf(dirpath=save_dir, filename=self._safe_filename(paper.title))

    def load_paper_as_file_by_url(self, url: str) -> io.BytesIO:
        """Return the PDF behind *url* as an in-memory binary buffer.

        Raises:
            ValueError: if *url* is not a string or contains no arXiv ID.
            LookupError: if the search returns no paper for the ID.
        """
        if not isinstance(url, str):
            raise ValueError("Only one URL is allowed.")
        paper_id = self.extract_id(url)
        if paper_id is None:
            # Searching for ``[None]`` would only fail later and less clearly.
            raise ValueError(f"Could not extract an arXiv ID from URL: {url!r}")
        for paper in self.search_by_url([paper_id]):
            with tempfile.TemporaryDirectory() as temp_dir:
                pdf_path = paper.download_pdf(dirpath=temp_dir)
                with open(pdf_path, "rb") as f:
                    return io.BytesIO(f.read())
        # Falling through would return ``None`` and break the declared contract.
        raise LookupError(f"No arXiv paper found for ID {paper_id!r}")


# Module-level client shared by the helper functions below. It is created at
# import time, and the constructor calls ``ensure_ssl_verified()``, which
# replaces ``ssl._create_default_https_context`` for the whole process.
arxiv_client = ArxivClient()


def fetch_papers_by_url(urls: str | Iterable[str], parse_reference: bool = False) -> list[Paper]:
    """Fetch papers for one arXiv URL or several.

    Args:
        urls: A single URL string or an iterable of URL strings.
        parse_reference: When true, also fetch the papers whose arXiv URLs
            appear in each fetched paper's text.

    Returns:
        The fetched papers.
    """
    # A bare string must be wrapped, not iterated: ``flatten([urls])`` splits a
    # single URL into characters, none of which contains an arXiv ID.
    url_list = [urls] if isinstance(urls, str) else list(urls)

    if parse_reference:
        return arxiv_client.fetch_papers_with_references_by_url(url_list)

    return arxiv_client.fetch_papers_by_url(url_list)


def fetch_papers_by_query(queries: str | Iterable[str], parse_reference: bool = False) -> list[Paper]:
    """Fetch papers for one search query or several.

    Args:
        queries: A single query string or an iterable of query strings.
        parse_reference: When true, also fetch the papers whose arXiv URLs
            appear in each fetched paper's text.

    Returns:
        The fetched papers.
    """
    # A bare string must be wrapped, not iterated: ``flatten([queries])``
    # splits a single query into one-character queries.
    query_list = [queries] if isinstance(queries, str) else list(queries)

    if parse_reference:
        return arxiv_client.fetch_papers_with_references_by_query(query_list)

    return arxiv_client.fetch_papers_by_query(query_list)


def load_papers_by_url(urls: str | Iterable[str], save_dir: str = "./") -> None:
    """Download the PDFs for one arXiv URL or several into *save_dir*.

    Args:
        urls: A single URL string or an iterable of URL strings.
        save_dir: Directory the PDFs are written to.
    """
    # Wrap a bare string instead of iterating it: ``flatten([urls])`` splits a
    # single URL into characters.
    url_list = [urls] if isinstance(urls, str) else list(urls)
    arxiv_client.download_papers_by_url(url_list, save_dir)


def load_papers_by_query(queries: str | Iterable[str], save_dir: str = "./") -> None:
    """Download the PDFs for one search query or several into *save_dir*.

    Args:
        queries: A single query string or an iterable of query strings.
        save_dir: Directory the PDFs are written to.
    """
    # Wrap a bare string instead of iterating it: ``flatten([queries])``
    # splits a single query into one-character queries.
    query_list = [queries] if isinstance(queries, str) else list(queries)
    arxiv_client.download_papers_by_query(query_list, save_dir)


def load_paper_as_file_by_url(urls: str) -> io.BytesIO:
    """Return the PDF behind a single arXiv URL as an in-memory buffer.

    The parameter is named ``urls`` but must be one URL string; the call is
    delegated unchanged to the shared ``arxiv_client``.
    """
    pdf_buffer = arxiv_client.load_paper_as_file_by_url(urls)
    return pdf_buffer


def extract_refs(papers: Iterable[Paper]) -> list[Paper]:
    """Return every paper and nested reference once, in first-seen order.

    Each paper is expanded with ``Paper.flatten()``, the expansions are
    concatenated, and later duplicates are dropped.
    """
    expanded = [paper.flatten() for paper in papers]
    every_paper = flatten(expanded)
    distinct = unique_everseen(every_paper)
    return list(distinct)

In [None]:
import tempfile
from unstructured.partition.pdf import partition_pdf, Element

# Download one paper into memory and show the buffer object as a sanity check.
url = "https://arxiv.org/abs/2410.18057"
file = load_paper_as_file_by_url(url)
print(file)

# Split the PDF into elements. The image/table options point at a fresh
# temporary directory; a later cell reads ``metadata["image_path"]`` from the
# elements, so the crops are presumably written there (confirm against the
# ``unstructured`` documentation).
# NOTE: ``tempfile.mkdtemp()`` never removes the directory it creates, so each
# run of this cell leaves one behind.
elements = partition_pdf(
    file=file,
    strategy="hi_res",
    infer_table_structure=True,
    extract_images_in_pdf=True,
    extract_image_block_types=["Image", "Table"],
    extract_image_block_to_payload=False,
    extract_image_block_output_dir=tempfile.mkdtemp(),
)
# Disabled experiment. NOTE(review): no module-level ``download_papers_by_url``
# or ``urls`` variable is visible here; ``load_papers_by_url`` looks like the
# intended helper.
# with tempfile.TemporaryDirectory() as tmp_dir:
# download_papers_by_url(urls=urls, save_dir=tmp_dir)

class PDFExtractor:
    """Partition a PDF into elements, using one output directory per instance.

    The directory is created with ``tempfile.mkdtemp()`` when the extractor is
    constructed and is passed to ``partition_pdf`` as the image-block output
    directory.
    """

    def __init__(self):
        self.output_dir = tempfile.mkdtemp()

    def partition(self, file: io.BytesIO) -> list[Element]:
        """Run ``partition_pdf`` on *file* and return the resulting elements."""
        options = {
            "strategy": "hi_res",
            "infer_table_structure": True,
            "extract_images_in_pdf": True,
            "extract_image_block_types": ["Image", "Table"],
            "extract_image_block_to_payload": False,
            "extract_image_block_output_dir": self.output_dir,
        }
        return partition_pdf(file=file, **options)

    def run(self, file: io.BytesIO) -> None:
        """Partition *file*; the elements are currently discarded."""
        self.partition(file=file)


<_io.BytesIO object at 0x131686480>


In [19]:
import os
from pathlib import Path
from dotenv import load_dotenv
import base64
import openai
import plotly.graph_objects as go
from PIL import Image
from io import BytesIO
from langfuse.openai import OpenAI

# OpenAI client imported from ``langfuse.openai``, configured from the
# environment.
# NOTE(review): ``load_dotenv`` is imported but not called in the visible
# cells, so both variables must already be set in the process environment.
# ``llm`` is not used by the visible code; the commented-out request further
# down calls ``openai.chat.completions.create`` directly instead.
llm = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)


def encode_image(image_path):
    """Return the contents of the file at *image_path* as base64 text.

    The file is read in binary mode and the base64 bytes are decoded to a
    ``str`` so they can be embedded in a data URI.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


# Preview the first image or table element: display its saved image with the
# element's text as the figure title. ``elements`` comes from the
# ``partition_pdf`` call in the previous cell.
for element in elements:
    element_dict = element.to_dict()
    type_ = element_dict["type"]
    text = element_dict["text"]
    if type_ == "Image" or type_ == "Table":
        # Path of the image file recorded in this element's metadata.
        img_path = element_dict["metadata"]["image_path"]
        base64_img = encode_image(img_path)

        # NOTE(review): this data URI declares image/png while the disabled
        # request below declares image/jpeg; confirm which format the files
        # actually have.
        fig = go.Figure(go.Image(source="data:image/png;base64," + base64_img))
        fig.update_layout(title=text)
        fig.show()

        # Disabled: ask a vision model to describe the image.
        # response = openai.chat.completions.create(
        #     model="gpt-4o-mini",
        #     messages=[
        #         {
        #             "role": "user",
        #             "content": [
        #                 {
        #                     "type": "text",
        #                     "text": "Given the image of an paper, explain the details of the image.",
        #                 },
        #                 {
        #                     "type": "image_url",
        #                     "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
        #                 },
        #             ],
        #         }
        #     ],
        # )
        # print(response.choices[0].message.content)
        # Only the first matching element is shown.
        break

In [14]:
# NOTE(review): ``response`` is only assigned in the commented-out request in
# the previous cell, and this cell's counter (14) is lower than that cell's
# (19), so the output below comes from an earlier kernel state. Re-running the
# notebook top to bottom raises NameError here.
response.choices[0].message.content

"The image contains a grid with four sections, each featuring a different concept related to visual contexts. \n\n1. **Retain**: This section poses a question about the image of an individual likely posing for a portrait, with a focus on a specific theme, possibly involving a notable figure in literature.\n\n2. **Forget**: Here, another question relates to an image that features a recognizable landscape, notably Mount Fuji, along with a subject holding a flower. The description indicates a cultural or geographical significance.\n\n3. **Real Faces**: This part questions the identity of a person in the image, suggesting a well-known personality from the entertainment industry. \n\n4. **Real World**: It features a question about a cat's appearance, focusing on the direction of its gaze, making it more relatable to everyday life and pets.\n\nEach section emphasizes different aspects of recognition and identification based on visual cues."