In [None]:
from langgraph.graph import StateGraph, MessagesState, END
from langchain_groq import ChatGroq
from typing import TypedDict
import groq
import pymupdf
import pymupdf4llm
import requests


# Every prompt opens with the same persona line, followed by the
# task-specific instructions; each prompt begins and ends with a newline.
_PROFESSOR_PERSONA = (
    "\nYou are a professor teaching a course from the following paper.\n"
)

# Combined prompt asking for the whole companion document in one pass
# (summary, glossary and bibliography). Not used by any node below.
SYSTEM_PROMPT = _PROFESSOR_PERSONA + (
    "Given the contents of the paper you should output a companion document with three sections\n"
    "1) A summary of the paper\n"
    "2) A glossary of important terms and keywords, along with their definition and context in the paper\n"
    "3) A detailed bibliography that lists the references along with a description of where in the paper they are cited and how they relate to the paper\n"
)

# System prompt for input_node's free-text summary.
SUMMARY_PROMPT = _PROFESSOR_PERSONA + (
    "Given the contents of the paper you should output a summary of the paper\n"
)

# System prompt for keyword_extraction_node's glossary extraction.
KEYWORD_PROMPT = _PROFESSOR_PERSONA + (
    "Given the contents of the paper you should output a glossary of important terms and keywords, along with their definition and context in the paper\n"
)

# System prompt for citation_extraction_node; asks for unabbreviated author lists.
CITATION_PROMPT = _PROFESSOR_PERSONA + (
    "Given the contents of the paper you should report the full list of citations along with a description of how each is used in the paper.\n"
    'Each citation should report the list of authors (without "Et Al"), the title, and the year.\n'
)


# --- Structured-output schemas ---------------------------------------------
# These TypedDicts must be defined BEFORE the models below are built:
# ``with_structured_output`` receives the class objects themselves, so
# referencing them earlier raises NameError.
# (The "Conxualized" spelling is kept so existing references keep working.)


# One glossary entry extracted from the paper.
class ConxualizedKeyword(TypedDict):
    keyword: str
    definition: str
    # How the term is used within this paper.
    local_context: str
    # global_context: str


# One bibliography entry extracted from the paper.
class ConxualizedCitation(TypedDict):
    title: str
    # Full author list; CITATION_PROMPT asks the model not to abbreviate it.
    authors: list[str]
    year: int
    # How the citation is used in the paper.
    description: str
    # local_context: str
    # global_context: str


# Appears intended for citations paired with their fetched abstracts
# (see abstract_fetching_node); currently unused.
class ContextualizedCitationsAbstracts(TypedDict):
    citations: str
    context: str
    abstract: str


# Top-level schema given to keyword_model: an object holding a `keywords` list.
class ConxualizedKeywordList(TypedDict):
    keywords: list[ConxualizedKeyword]


# Top-level schema given to citation_model: an object holding a `citations` list.
class ConxualizedCitationList(TypedDict):
    citations: list[ConxualizedCitation]


# --- Models -----------------------------------------------------------------
model_name = "llama-3.1-70b-versatile"
# model_name = "llama3-70b-8192"
# model_name = "llama-3.1-8b-instant"

# NOTE(review): GROQ_API_KEY is not defined in this cell; it must already be
# bound (e.g. loaded from the environment) before this line runs.
model = ChatGroq(groq_api_key=GROQ_API_KEY, model=model_name)
# Free-text model for the summary; schema-constrained models for extraction.
summary_model = model
keyword_model = model.with_structured_output(ConxualizedKeywordList)
citation_model = model.with_structured_output(ConxualizedCitationList)

# Define the state with a built-in messages key
class ResearchState(MessagesState):
    """Shared LangGraph state for the paper-reading workflow.

    Inherits the ``messages`` key from ``MessagesState``. Each node returns a
    partial dict that updates some of the keys declared below; only declared
    keys are kept by the graph, so every key a node writes must appear here.
    """

    # Input: URL of the paper PDF, supplied when the graph is invoked.
    paper_url: str
    # The paper converted to Markdown (written by input_node).
    paper_md: str
    # Free-text summary of the paper (written by input_node).
    summary: str
    # Structured glossary (written by keyword_extraction_node).
    keywords: ConxualizedKeywordList
    # Structured bibliography (written by citation_extraction_node).
    citations: ConxualizedCitationList
    # Output of contextualization_node.
    context: str
    # Output of abstract_fetching_node.
    abstracts: list[str]
    # Reading-assistance output in Markdown.
    reading_assistance_md: str

# Define the logic for each node
def input_node(state: ResearchState) -> ResearchState:
    """Download the paper PDF, convert it to Markdown, and summarize it.

    Reads ``state["paper_url"]`` and returns a partial state update with
    ``paper_md`` (the paper as Markdown) and ``summary`` (the model's
    free-text summary produced with SUMMARY_PROMPT).

    Raises:
        requests.HTTPError: if the download returns an error status code.
        requests.Timeout: if the server does not respond in time.
    """
    url = state["paper_url"]
    # Bound the wait so a stalled server cannot hang the whole graph, and
    # fail fast on HTTP errors instead of handing an error page to PyMuPDF.
    r = requests.get(url, timeout=60)
    r.raise_for_status()
    # The bytes come from an arbitrary URL, so state the format explicitly;
    # the context manager closes the document once it has been converted.
    with pymupdf.Document(stream=r.content, filetype="pdf") as doc:
        paper = pymupdf4llm.to_markdown(doc)
    summary = summary_model.invoke([
        ("system", SUMMARY_PROMPT),
        ("human", paper),
    ]).content
    return {
        "paper_md": paper,
        "summary": summary,
    }


def keyword_extraction_node(state: ResearchState) -> ResearchState:
    """Extract a contextualized glossary from the paper's Markdown.

    Sends ``state["paper_md"]`` to ``keyword_model`` with KEYWORD_PROMPT and
    returns a partial state update ``{"keywords": ...}`` holding the
    structured output for the ``ConxualizedKeywordList`` schema. If the model
    yields no structured output, an empty keyword list is stored instead.
    """
    response = keyword_model.invoke([
        ("system", KEYWORD_PROMPT),
        ("human", state["paper_md"]),
    ])
    # Keep the state shape stable even when nothing could be parsed.
    return {"keywords": response or {"keywords": []}}


def citation_extraction_node(state: ResearchState) -> ResearchState:
    """Extract the paper's bibliography as structured citations.

    Sends ``state["paper_md"]`` to ``citation_model`` with CITATION_PROMPT
    ``num_runs`` times and concatenates the ``citations`` lists from every
    run. Returns a partial state update ``{"citations": {"citations": [...]}}``,
    matching the ``ConxualizedCitationList`` type used for that state key.
    """
    # Sometimes the model only extracts a small number of citations, or
    # returns the wrong format; raise this to repeat the extraction and merge
    # the results. NOTE: duplicates across runs are not removed.
    num_runs = 1
    messages = [
        ("system", CITATION_PROMPT),
        ("human", state["paper_md"]),
    ]
    citations = []
    for _ in range(num_runs):
        result = citation_model.invoke(messages)
        # The structured result is an object with a `citations` list; extend
        # with that list (extending with the object itself would only add its
        # keys). Skip runs that produced nothing parseable.
        if result:
            citations.extend(result.get("citations") or [])
    return {"citations": {"citations": citations}}


def contextualization_node(state: ResearchState) -> ResearchState:
    """Placeholder for attaching context to keywords and citations.

    Ignores ``state`` for now and returns a fixed ``context`` update.
    """
    placeholder = "contextualized information"
    return dict(context=placeholder)


def abstract_fetching_node(state: ResearchState) -> ResearchState:
    """Placeholder for fetching the abstracts of cited works.

    Ignores ``state`` for now and returns two fixed stand-in abstracts.
    """
    stand_in_abstracts = [f"abstract{n}" for n in (1, 2)]
    return dict(abstracts=stand_in_abstracts)


def reading_assistance_node(state: ResearchState) -> ResearchState:
    """Placeholder for building the reading-assistance document.

    Ignores ``state`` for now and returns fixed text. The update is written
    to ``reading_assistance_md``, the key ``ResearchState`` declares for this
    output (a key missing from the state schema would be dropped).
    """
    return {"reading_assistance_md": "assistance context"}


def final_contextualization_node(state: ResearchState) -> ResearchState:
    """Placeholder final step: append a completion note to ``messages``.

    Ignores ``state`` and emits a single system message as a
    ``(role, content)`` tuple.
    """
    completion_note = ("system", "Final contextualization complete")
    return dict(messages=[completion_note])


def _build_research_graph() -> StateGraph:
    """Assemble the (uncompiled) research-assistant workflow.

    Each node is registered under its function's name. ``input_node`` fans
    out to both extraction nodes; keyword and citation extraction both feed
    ``contextualization_node``, citation extraction also feeds
    ``abstract_fetching_node``, and both of those feed
    ``reading_assistance_node`` before the final node ends the run.
    """
    builder = StateGraph(ResearchState)

    node_functions = (
        input_node,
        keyword_extraction_node,
        citation_extraction_node,
        contextualization_node,
        abstract_fetching_node,
        reading_assistance_node,
        final_contextualization_node,
    )
    for node_fn in node_functions:
        builder.add_node(node_fn.__name__, node_fn)

    builder.set_entry_point("input_node")
    edges = (
        ("input_node", "keyword_extraction_node"),
        ("input_node", "citation_extraction_node"),
        ("keyword_extraction_node", "contextualization_node"),
        ("citation_extraction_node", "contextualization_node"),
        ("citation_extraction_node", "abstract_fetching_node"),
        ("contextualization_node", "reading_assistance_node"),
        ("abstract_fetching_node", "reading_assistance_node"),
        ("reading_assistance_node", "final_contextualization_node"),
    )
    for source, target in edges:
        builder.add_edge(source, target)
    builder.set_finish_point("final_contextualization_node")
    return builder


# Build the graph, then compile it into a runnable app.
graph = _build_research_graph()
app = graph.compile()

In [None]:
from IPython.display import Image, display

# Render the compiled workflow as a Mermaid PNG inline in the notebook.
_workflow_png = app.get_graph().draw_mermaid_png()
display(Image(_workflow_png))

In [None]:
# Run the full pipeline on arXiv paper 2310.04406; the resulting state is the
# cell's last expression, so the notebook displays it.
url = "https://arxiv.org/pdf/2310.04406"
app.invoke(dict(paper_url=url))