# LangGraph 기반 다단계 문서-그래프 구축 예제

In [None]:
%%capture

# langgraph and pypdf are required by the imports below (StateGraph, PyPDFLoader);
# streamlit, networkx and matplotlib are required by the visualization dashboard.
!pip install openai langchain langgraph pypdf python-arango streamlit networkx matplotlib

In [None]:
# LangGraph 기반 다단계 문서-그래프 구축 예제

import functools
import json
from typing import TypedDict

from arango import ArangoClient
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.prompts import PromptTemplate
from langgraph.graph import StateGraph

# 1. Document Loader

@functools.lru_cache(maxsize=4)
def _load_pdf_pages(doc_path):
    """Parse the PDF at ``doc_path`` and return its pages as a tuple.

    The result is cached per path so that visiting every page of a document
    parses the file once instead of once per page. The cache is bounded and
    is not invalidated if the file changes on disk during the process.
    """
    return tuple(PyPDFLoader(doc_path).load())


def load_document(state):
    """Load the text of the current page of the PDF named in ``state``.

    Args:
        state: Mapping with ``"doc_path"`` (path of the PDF) and optionally
            ``"page"`` (zero-based page index, defaults to 0).

    Returns:
        ``{"end": True}`` when ``page`` is past the last page; otherwise a
        dict with the page's ``"text"``, the ``"page"`` index and
        ``"end": False``.
    """
    doc_path = state["doc_path"]
    page = state.get("page", 0)
    pages = _load_pdf_pages(doc_path)
    if page >= len(pages):
        return {"end": True}
    return {"text": pages[page].page_content, "page": page, "end": False}

# 2. Entity Extraction

# Prompt asking the model to return the page's entities and relations as JSON.
#
# The template is delimited with ''' because its body contains """ around the
# {text} placeholder; delimiting it with """ would end the string literal
# early and make the file a SyntaxError.
#
# PromptTemplate.from_template uses f-string style formatting, so the literal
# braces of the JSON example are doubled ({{ and }}); only {text} is an input
# variable. The rendered prompt shows single braces.
ENTITY_PROMPT = PromptTemplate.from_template('''
다음 문서 내용을 분석하여 nodes와 edges를 추출하세요. 결과는 JSON 형식으로 출력하세요.

문서:
"""{text}"""

결과 예시:
{{
  "nodes": [
    {{"_key": "steve_jobs", "name": "Steve Jobs", "type": "person"}},
    {{"_key": "apple", "name": "Apple", "type": "company"}}
  ],
  "edges": [
    {{"_from": "nodes/steve_jobs", "_to": "nodes/apple", "relation": "founded"}}
  ]
}}
''')

# temperature=0 is passed to the chat model used for extraction.
llm = ChatOpenAI(temperature=0)
extract_chain = LLMChain(llm=llm, prompt=ENTITY_PROMPT)

def _parse_graph_json(raw):
    """Return the JSON object contained in the model reply ``raw``.

    Chat models frequently wrap JSON in a Markdown code fence or add a
    sentence around it, so only the span from the first ``{`` to the last
    ``}`` is parsed when such a span exists.

    Raises:
        json.JSONDecodeError: If the reply does not contain valid JSON.
        ValueError: If the parsed JSON is not an object.
    """
    candidate = raw.strip()
    first = candidate.find("{")
    last = candidate.rfind("}")
    if first != -1 and last > first:
        candidate = candidate[first:last + 1]
    data = json.loads(candidate)
    if not isinstance(data, dict):
        raise ValueError("LLM output is not a JSON object")
    return data


def extract_entities_with_llm(state):
    """Extract graph nodes and edges from the current page text via the LLM.

    Args:
        state: Mapping with ``"text"`` (page content), ``"page"`` and
            ``"doc_path"``.

    Returns:
        Dict with ``"nodes"`` and ``"edges"`` (empty lists when the model
        omits them) plus the unchanged ``"page"`` and ``"doc_path"``.

    Raises:
        json.JSONDecodeError: If the model reply contains no valid JSON.
        ValueError: If the model reply is valid JSON but not an object.
    """
    result = extract_chain.run(text=state["text"])
    data = _parse_graph_json(result)
    return {
        # A reply without one of the keys is treated as having no items of
        # that kind instead of aborting the whole run.
        "nodes": data.get("nodes", []),
        "edges": data.get("edges", []),
        "page": state["page"],
        "doc_path": state["doc_path"],
    }

# 3. Validation (간단히 중복 필터링만 예시)

def validate_entities(state):
    """Drop nodes whose ``_key`` repeats an earlier node's ``_key``.

    Args:
        state: Mapping containing a ``"nodes"`` list of dicts, each with a
            ``"_key"`` entry.

    Returns:
        A new dict with every entry of ``state`` and ``"nodes"`` replaced by
        the de-duplicated list; the first node seen for each key is kept and
        the original order is preserved. ``state`` itself is not modified.
    """
    first_by_key = {}
    for node in state["nodes"]:
        # setdefault keeps the first node registered under each key.
        first_by_key.setdefault(node["_key"], node)

    validated = {**state}
    validated["nodes"] = list(first_by_key.values())
    return validated

# 4. ArangoDB 삽입

def insert_to_arango(state):
    """Store the extracted nodes and edges in ArangoDB and advance the page.

    The ``nodes`` document collection and the ``edges`` edge collection are
    created in the ``_system`` database when missing. Insertion is
    idempotent: a node whose ``_key`` already exists and an edge whose
    attributes all match an existing edge are skipped, so processing the same
    page twice does not duplicate data.

    Args:
        state: Mapping with ``"nodes"`` (dicts carrying ``"_key"``),
            ``"edges"`` (dicts carrying ``"_from"`` / ``"_to"``) and
            ``"page"``.

    Returns:
        A copy of ``state`` with ``"page"`` incremented by one.
    """
    # NOTE: the password is a placeholder; replace it before running.
    db = ArangoClient().db("_system", username="root", password="your_password")
    if not db.has_collection("nodes"):
        db.create_collection("nodes")
    if not db.has_collection("edges"):
        db.create_collection("edges", edge=True)
    node_collection = db.collection("nodes")
    edge_collection = db.collection("edges")

    for node in state["nodes"]:
        if not node_collection.has(node["_key"]):
            node_collection.insert(node)

    for edge in state["edges"]:
        # Edges carry no _key, so look for an existing edge with the same
        # attributes before inserting.
        if not list(edge_collection.find(edge, limit=1)):
            edge_collection.insert(edge)

    return {**state, "page": state["page"] + 1}

# 5. 조건 분기

def check_more_pages(state):
    """Return the routing label for the conditional edge.

    Returns:
        ``"end"`` when ``state`` has a truthy ``"end"`` entry, otherwise
        ``"continue"``.
    """
    if state.get("end"):
        return "end"
    return "continue"

# 6. LangGraph 정의

class GraphState(TypedDict, total=False):
    """State shared by the pipeline nodes; each node returns a partial update."""

    doc_path: str  # path of the PDF, supplied in the initial state
    page: int      # zero-based index of the page being processed
    text: str      # page text produced by the Load node
    nodes: list    # node dicts produced by the Extract node
    edges: list    # edge dicts produced by the Extract node
    end: bool      # set by the Load node once the page index is past the end


# StateGraph needs the state schema to know which keys the nodes exchange.
graph = StateGraph(GraphState)
graph.add_node("Load", load_document)
graph.add_node("Extract", extract_entities_with_llm)
graph.add_node("Validate", validate_entities)
graph.add_node("Save", insert_to_arango)

graph.set_entry_point("Load")
# The end-of-document check must directly follow Load: Load is the node that
# sets "end", and the downstream nodes must not run again once it is set.
graph.add_conditional_edges("Load", check_more_pages, {
    "continue": "Extract",
    "end": "__end__",
})
graph.add_edge("Extract", "Validate")
graph.add_edge("Validate", "Save")
# Save increments the page index, so looping back to Load reads the next page.
graph.add_edge("Save", "Load")

workflow = graph.compile()

# 7. 실행

initial_state = {"doc_path": "sample.pdf", "page": 0}
# LangGraph stops a run with GraphRecursionError after 25 steps by default.
# The pipeline spends several steps on every page, so the limit is raised to
# let multi-page documents run to completion.
workflow.invoke(initial_state, config={"recursion_limit": 10_000})


# 8. Streamlit 시각화 대시보드

def visualize_graph(start_vertex="nodes/steve_jobs"):
    """Render the neighbourhood of a vertex as a Streamlit chart.

    Traverses the ``edges`` collection up to two hops in any direction from
    ``start_vertex`` and draws the edges found as a directed graph, labelling
    each edge with its ``relation`` attribute.

    Args:
        start_vertex: Document id (``collection/key``) of the vertex to start
            from. Defaults to the example vertex used in the extraction
            prompt.
    """
    # Imported here because the file does not import these modules at the top
    # level; without them every use of st / nx / plt raises NameError.
    import matplotlib.pyplot as plt
    import networkx as nx
    import streamlit as st

    st.title("📊 ArangoDB 지식 그래프 시각화")
    # NOTE: the password is a placeholder; replace it before running.
    db = ArangoClient().db("_system", username="root", password="your_password")
    # Traverse the edge collection directly: this file only creates the
    # "nodes" and "edges" collections and never defines a named graph.
    query = """
    FOR v, e IN 1..2 ANY @start_vertex edges
      RETURN { vertex: v, edge: e }
    """
    cursor = db.aql.execute(query, bind_vars={"start_vertex": start_vertex})

    G = nx.DiGraph()
    for item in cursor:
        edge = item["edge"]
        # Document ids have the form "collection/key"; keep the key as label.
        from_key = edge["_from"].split("/")[1]
        to_key = edge["_to"].split("/")[1]
        # add_edge also creates both endpoints as nodes.
        G.add_edge(from_key, to_key, label=edge.get("relation", ""))

    pos = nx.spring_layout(G)
    # Draw on an explicit figure and hand that figure to Streamlit rather
    # than relying on pyplot's implicit global figure.
    fig, ax = plt.subplots(figsize=(10, 7))
    nx.draw(G, pos, ax=ax, with_labels=True, node_color="skyblue",
            node_size=2000, font_size=12)
    edge_labels = nx.get_edge_attributes(G, "label")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels,
                                 font_size=10, ax=ax)
    st.pyplot(fig)

# Streamlit 실행 (main.py로 저장 후 실행)
# streamlit run main.py