# Imports

In [None]:
# System
import os 
import sys
import re
import json
from pathlib import Path
from dotenv import load_dotenv

# Data Type
import json
from textwrap import dedent
from pprint import pprint

# Configs

In [None]:
# Make the project root (parent of the notebook's working directory) importable.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    # Guard so re-running this cell does not append duplicate entries.
    sys.path.append(str(project_root))
print("Project root added to path.")

# Load environment variables from <project root>/.env and report the result of
# THAT load. (Previously the print called load_dotenv() a second time, which
# performs its own .env search and so reported a different operation.)
env_path = project_root / '.env'
env_loaded = load_dotenv(dotenv_path=env_path)
print(f"Environment variables loaded: {env_loaded}.")

# Gemini Model

In [None]:
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI


# Chat model identifier (a Gemini preview release).
model = 'gemini-2.5-flash-preview-04-17'

# Gemini chat client, authenticated with GEMINI_API_KEY from the environment.
llm = ChatGoogleGenerativeAI(
    model=model,
    temperature=0.6,
    api_key=os.environ.get("GEMINI_API_KEY"),
)
# Smoke test: one round trip to confirm credentials and connectivity.
llm.invoke("Hello, testing connections!")

# PGVector

In [None]:
from langchain_postgres.vectorstores import PGVector
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Gemini embedding model used to vectorize documents for pgvector.
embeddings = GoogleGenerativeAIEmbeddings(
    google_api_key=os.environ.get("GEMINI_API_KEY"),
    model="models/gemini-embedding-exp-03-07",
)

# Postgres connection settings, read from DB_* environment variables.
# Requires a running Postgres instance with the pgvector extension enabled.
# NOTE(review): the values are interpolated into the URL unescaped, so a
# password containing characters such as '@', ':' or '/' would break it.
pg_user, pg_password, pg_db, pg_host, pg_port = (
    os.environ.get(var)
    for var in ("DB_USER", "DB_PASSWORD", "DB_NAME", "DB_HOST", "DB_PORT")
)
connection = f"postgresql+psycopg://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}"
collection_name = "steam_games"

# Vector store over the "steam_games" collection, storing metadata as JSONB.
vector_store = PGVector(
    connection=connection,
    collection_name=collection_name,
    embeddings=embeddings,
    use_jsonb=True,
)

In [None]:
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever


# Metadata attributes the self-query retriever may use to build filters.
# NOTE(review): apart from "name", these attributes, their descriptions and the
# content description below come from a movie example. They almost certainly
# do not match the "steam_games" collection. TODO: replace them with the
# collection's real metadata keys.
metadata_field_info = [
    AttributeInfo(
        name="name",
        description="The name of the game",
        type="string",
    ),
    AttributeInfo(
        name="year",
        description="The year the movie was released",
        type="integer",
    ),
    AttributeInfo(
        name="director",
        description="The name of the movie director",
        type="string",
    ),
    AttributeInfo(
        name="rating", description="A 1-10 rating for the movie", type="float"
    ),
]
document_content_description = "Brief summary of a movie"

# Reuse the Gemini `llm` configured earlier. A bare ChatGoogleGenerativeAI()
# here would pass neither the model name nor GEMINI_API_KEY, and it would
# also rebind `llm` for every later cell (tools, SQL toolkit, agent).
# Only one retriever is built; the earlier duplicate was discarded anyway.
retriever = SelfQueryRetriever.from_llm(
    llm,
    vector_store,
    document_content_description,
    metadata_field_info,
    enable_limit=True,  # lets the LLM translate "two movies" into a result limit
    verbose=True,
)

# This example only specifies a relevant query
retriever.invoke("what are two movies about dinosaurs")

# LangGraph

## Tools

Check out https://python.langchain.com/docs/integrations/tools/discord/  
for the Discord tool.

In [None]:
from langchain_community.tools import DuckDuckGoSearchResults

# DuckDuckGo search tool configured to return its results as JSON.
search = DuckDuckGoSearchResults(output_format='json')

# Manual smoke test of the search tool.
search.invoke(input="Obama")

In [None]:
from langchain.tools import tool
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode


@tool
def tavily_search(query: str):
    """Returns search results from Tavily."""
    # `query: str` matters: @tool builds the argument schema shown to the LLM
    # from this signature, and an unannotated parameter has no declared type.
    # The docstring above is used as the tool description, so it is kept as is.
    return TavilySearch(max_results=2).invoke(query)

@tool
def get_time():
    """Returns the current system time."""
    # The docstring is the tool description sent to the model; keep it verbatim.
    from datetime import datetime as _datetime

    current = _datetime.now()
    return current.strftime("%Y-%m-%d %H:%M:%S")


tools = [tavily_search, get_time]
# Build a graph node that executes tool calls, then a copy of the LLM bound to the same tools.
tool_node, llm_with_tools = ToolNode(tools), llm.bind_tools(tools)

In [None]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase

# SQLAlchemy handle for the LangChain SQL toolkit, built from the same DB_*
# settings as the vector store.
# NOTE(review): this URI uses the psycopg2 driver, while the vector store and
# checkpointer use psycopg (v3), so both drivers must be installed.
pg_uri = f"postgresql+psycopg2://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}"
db = SQLDatabase.from_uri(pg_uri)

print(f"Dialect: {db.dialect}")
# print(f"Available tables: {db.get_usable_table_names()}")
# print(f'Sample output: {db.run("SELECT * FROM bronze.details LIMIT 5;")}')


toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# NOTE: this replaces the Tavily/time `tools` list from the previous cell;
# the agent is created from whatever `tools` holds when its cell runs.
tools = toolkit.get_tools()

# Use a dedicated loop name. `for tool in tools` would rebind the `@tool`
# decorator imported from langchain.tools and break re-runs of the tools cell.
for sql_tool in tools:
    print(f"{sql_tool.name}: {sql_tool.description}\n")

# System prompt for the SQL agent, filled with the live dialect and a row cap.
system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
""".format(
    dialect=db.dialect,
    top_k=5,
)

## Agent Definition

In [None]:
from langgraph.prebuilt import create_react_agent

# ReAct agent built from the Gemini `llm`, the current `tools` list (the SQL
# toolkit tools when cells run in order) and the SQL system prompt.
agent = create_react_agent(
    prompt=system_prompt,
    tools=tools,
    model=llm,
)

## State Definition

In [None]:
from langgraph.graph.message import add_messages
from typing import TypedDict, Annotated, List
from langchain_core.documents import Document

class State(TypedDict):
    """Shared state schema for the LangGraph workflow (passed to StateGraph)."""

    # NOTE(review): `add_messages` is LangGraph's chat-message reducer, so
    # updates are merged as messages (the stub nodes' plain strings are
    # presumably coerced into messages). That does not match the
    # List[Document] hint. Confirm the intended reducer (e.g. operator.add).
    context: Annotated[List[Document], add_messages]
    # Same reducer caveat as `context`.
    answer: Annotated[List[Document], add_messages]
    # The string metadata below is descriptive only (presumably no reducer,
    # i.e. last write wins; verify).
    question: Annotated[str, 'user question']
    sql_query: Annotated[str, 'sql query']
    binary_score: Annotated[str, 'binary score yes or no']

## Node Definition

All nodes must be functions. The state class (`State` here) is passed to each of them, but because it is a `TypedDict` it is just a plain Python dictionary at runtime.

This means that:  
`return State(context=documents) == {'context': documents}`

In [None]:
def retrieve(state: State) -> State:
    """Retrieval step (stub): returns a fixed placeholder as the context."""
    return State(context="검색된 문서")


def rewrite_query(state: State) -> State:
    """Query-rewrite step (stub): returns a fixed placeholder as the context."""
    return {"context": "검색된 문서"}


def execute_gemini(state: State) -> State:
    """LLM step (stub): returns a placeholder answer instead of calling Gemini."""
    return {"answer": "Gemini's Response"}


def relevance_check(state: State) -> State:
    """Relevance-check step (stub): returns a placeholder binary score."""
    return {"binary_score": "Relevance Score"}


def sum_up(state: State) -> State:
    """Answer-synthesis step (stub): returns a placeholder combined answer."""
    return {"answer": "종합된 답변"}


def search_on_web(state: State) -> State:
    """Web-search step (stub): combine existing context with searched documents.

    Returns a partial state update; the input ``state`` is not modified.
    """
    # Placeholder for the existing documents. The previous chained assignment
    # (`documents = state["context"] = ...`) overwrote state["context"] in place,
    # which a node should never do; it should only return its update.
    # TODO: read the real existing context once its type is settled.
    existing_documents = "기존 문서"
    searched_documents = "검색된 문서"
    return State(context=existing_documents + searched_documents)


def get_table_info(state: State) -> State:
    """Table-info step (stub): returns placeholder table info as the context."""
    return {"context": "테이블 정보"}


def generate_sql_query(state: State) -> State:
    """SQL-generation step (stub): returns a placeholder SQL query."""
    return {"sql_query": "SQL 쿼리"}


def execute_sql_query(state: State) -> State:
    """SQL-execution step (stub): returns a placeholder result as the context."""
    return {"context": "SQL 결과"}


def validate_sql_query(state: State) -> State:
    """SQL-validation step (stub): returns a placeholder validation score."""
    return {"binary_score": "SQL 쿼리 검증 결과"}


def handle_error(state: State) -> State:
    """Error-handling step (stub): returns a placeholder error as the context."""
    return {"context": "에러 발생"}


def decision(state: State) -> str:
    """Route the graph after SQL validation.

    Returns a branch label. The label must be a key of the path map given to
    ``add_conditional_edges("validate_sql_query", decision, {...})``, whose
    keys are "QUERY ERROR", "UNKNOWN MEANING" and "PASS". The previous labels
    ("종료"/"재검색") were not keys of that map, so routing could not succeed.
    """
    # TODO: decide when the question itself is ambiguous and return
    # "UNKNOWN MEANING" so the graph rewrites the question.
    if state["binary_score"] == "yes":
        return "PASS"
    return "QUERY ERROR"

## Graph Definition

In [None]:
# LangGraph
from langgraph.graph import START,END, StateGraph

# Build the workflow graph over the shared State schema.
workflow = StateGraph(State)

# Register the nodes (node name -> node function) in a fixed order.
# Both "rewrite_query" and "rewrite_question" currently run `rewrite_query`.
for node_name, node_fn in (
    ("query", retrieve),
    ("rewrite_query", rewrite_query),
    ("rewrite_question", rewrite_query),
    ("execute_gemini", execute_gemini),
    ("relevance_check_gemini", relevance_check),
    ("결과 종합", sum_up),
    ("get_table_info", get_table_info),
    ("generate_sql_query", generate_sql_query),
    ("execute_sql_query", execute_sql_query),
    ("validate_sql_query", validate_sql_query),
):
    workflow.add_node(node_name, node_fn)

# Main path: question -> table info -> SQL generation -> execution -> validation.
for source, target in (
    (START, "query"),
    ("query", "get_table_info"),
    ("get_table_info", "generate_sql_query"),
    ("generate_sql_query", "execute_sql_query"),
    ("execute_sql_query", "validate_sql_query"),
):
    workflow.add_edge(source, target)

# Branch on the label returned by `decision`.
# NOTE(review): `decision` must return one of exactly these keys.
workflow.add_conditional_edges(
    "validate_sql_query",
    decision,
    {
        "QUERY ERROR": "rewrite_query",
        "UNKNOWN MEANING": "rewrite_question",
        "PASS": "execute_gemini",
    },
)

# Retry loops, then the answer tail ("결과 종합" = "summarize results").
for source, target in (
    ("rewrite_query", "execute_sql_query"),
    ("rewrite_question", "rewrite_query"),
    ("execute_gemini", "relevance_check_gemini"),
    ("relevance_check_gemini", "결과 종합"),
    ("결과 종합", END),
):
    workflow.add_edge(source, target)

## Memory Definition

In [None]:
from langgraph.checkpoint.memory import MemorySaver

# In-memory checkpoint store for conversation history. The graph is compiled
# with the Postgres `checkpointer` instead, so this object is currently unused.
memory: MemorySaver = MemorySaver()

In [None]:
from psycopg_pool import ConnectionPool
from langgraph.checkpoint.postgres import PostgresSaver
from psycopg.rows import dict_row

# Postgres-backed checkpointer so graph state persists across runs.
# NOTE(review): credentials are interpolated unescaped and SSL is disabled.
# Acceptable for a local database; revisit before pointing at a remote host.
conninfo = f"postgres://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}?sslmode=disable"
# Settings applied to every connection the pool opens.
connection_kwargs = {
    "autocommit": True,
    "prepare_threshold": 0,
    "row_factory": dict_row,
}

# open=True states explicitly the eager open the constructor already did.
# Recent psycopg_pool versions deprecate relying on that implicit default.
pool = ConnectionPool(
    conninfo=conninfo,
    max_size=20,
    kwargs=connection_kwargs,
    open=True,
)

checkpointer = PostgresSaver(pool)

# Prepare the checkpointer's tables in the database before first use.
checkpointer.setup()

## Graph Compilation

In [None]:
# Compile the workflow into a runnable app that persists state through the
# Postgres checkpointer.
app = workflow.compile(
    checkpointer=checkpointer,
)

## Graph Visualization

In [None]:
# LangGraph Visualization
from IPython.display import Image, display
from langchain_core.runnables.graph_mermaid import MermaidDrawMethod

# Render the compiled graph to PNG through the Mermaid web API (retrying on
# failure), save a copy under ../reports/figures and show it inline.
figure_path = Path('../reports/figures/langgraph_viz.png')
# draw_mermaid_png saves the PNG to output_file_path. Create the folder first
# so a fresh checkout without reports/figures does not fail the write.
figure_path.parent.mkdir(parents=True, exist_ok=True)

display(
    Image(
        app.get_graph().draw_mermaid_png(
            max_retries=5,
            retry_delay=5,
            draw_method=MermaidDrawMethod.API,
            output_file_path=str(figure_path),
        )
    )
)