# LLM <-> DB
- LangGraph agent ↔ PostgreSQL DB 연결
- 사용자 요청 > LLM > SQL 쿼리 변환 > DB > LLM 답변 생성 > 사용자

`.env` 파일에 PostgreSQL 접속 정보(`POSTGRES_USER`, `POSTGRES_PASSWORD`, `POSTGRES_DB`)를 설정해야 함

In [None]:
from dotenv import load_dotenv

# Populate os.environ from a local `.env` file so the POSTGRES_* settings read
# further down are available. override=False (the library default) keeps any
# variable that is already set in the real environment. The call is left as
# the cell's last expression so the notebook displays the returned flag.
load_dotenv(override=False)

True

In [19]:
from langchain_openai import ChatOpenAI

# Chat model used to turn natural-language questions into SQL.
# temperature=0 asks for the least random output, which is what we want
# when the response is code that will be executed.
llm = ChatOpenAI(temperature=0, model='gpt-4.1')

In [None]:
import os
from urllib.parse import quote_plus

from langchain_community.utilities import SQLDatabase

# PostgreSQL connection settings, read from the environment (normally loaded
# from `.env` by load_dotenv()).
POSTGRES_USER = os.getenv('POSTGRES_USER')
POSTGRES_PASSWORD = os.getenv('POSTGRES_PASSWORD')
POSTGRES_DB = os.getenv('POSTGRES_DB')
# Host and port used to be hard-coded; they can now be overridden, and the
# previous values remain the defaults.
POSTGRES_HOST = os.getenv('POSTGRES_HOST', 'localhost')
POSTGRES_PORT = os.getenv('POSTGRES_PORT', '5432')

_missing_env = [
    name
    for name, value in (
        ('POSTGRES_USER', POSTGRES_USER),
        ('POSTGRES_PASSWORD', POSTGRES_PASSWORD),
        ('POSTGRES_DB', POSTGRES_DB),
    )
    if value is None
]
if _missing_env:
    # Without this check an unset variable is formatted as the literal string
    # "None" into the URI and only fails later as a confusing login error.
    raise RuntimeError(
        f'Missing environment variable(s): {", ".join(_missing_env)} '
        '- set them in .env'
    )

# Percent-encode the credentials so characters such as '@', ':', '/' or '#'
# in the user name or password cannot be mistaken for URI delimiters.
URI = (
    f'postgresql://{quote_plus(POSTGRES_USER)}:{quote_plus(POSTGRES_PASSWORD)}'
    f'@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}'
)

db = SQLDatabase.from_uri(URI)

In [None]:
# Sanity-check the database connection before involving the LLM.
print(db.dialect)  # dialect of the connected database
print(db.get_usable_table_names())  # names of the tables available for querying

db.run('SELECT * FROM artist LIMIT 10')
# >> db.run() returns the query result to Python, so SQL written by the LLM can be executed and read back the same way.

postgresql
['album', 'artist', 'customer', 'employee', 'genre', 'invoice', 'invoice_line', 'media_type', 'playlist', 'playlist_track', 'track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [18]:
from langchain_core.prompts import ChatPromptTemplate

# Template variables in system_message:
#   dialect    - SQL dialect name (filled from db.dialect)
#   top_k      - default cap on the number of rows the query should return
#   table_info - schema text for the tables the model may use
#                (filled from db.get_table_info())
system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""

# Human turn; {input} receives the user's question.
user_prompt = 'Question: {input}'

# Messages must be an ordered sequence: the system instructions have to come
# before the user's question. The previous set literal had no guaranteed
# iteration order (string hashing is randomized per process), so the two
# messages could be assembled in either order.
query_prompt_template = ChatPromptTemplate(
    [
        ('system', system_message),
        ('user', user_prompt),
    ]
)

# for msg in query_prompt_template.messages:
#     msg.pretty_print()

In [None]:
# Schemas used to check what kind of query this prompt produces
from typing_extensions import Annotated, TypedDict
from langgraph.graph import MessagesState

# Structured-output schema passed to llm.with_structured_output() in write_sql.
# NOTE(review): the docstring and the Annotated description string are
# presumably forwarded to the model as the schema/field descriptions, so they
# are runtime text - change them deliberately, not cosmetically.
class QueryOutput(TypedDict):
    '''Generate SQL query'''
    # The single field write_sql reads back (result['query']).
    query: Annotated[str, ..., '문법적으로 올바른 SQL 쿼리']

# Graph state shared between nodes; extends LangGraph's MessagesState.
class State(MessagesState):
    question: str  # the user's question
    sql: str  # the generated SQL statement
    result: str  # result of running the SQL query
    answer: str  # final answer produced for the user

In [28]:
# SQL-generation node
def write_sql(state: State):
    """Ask the LLM for one SQL query that answers ``state['question']``.

    The prompt is filled with the connected database's dialect, a row limit
    of 10, and the schema text returned by ``db.get_table_info()``. The
    model's reply is parsed into the ``QueryOutput`` schema.

    Returns:
        A partial state update of the form ``{'sql': <generated query>}``.
    """
    prompt_inputs = {
        'dialect': db.dialect,
        'top_k': 10,
        'table_info': db.get_table_info(),
        'input': state['question'],
    }
    prompt_value = query_prompt_template.invoke(prompt_inputs)
    query_writer = llm.with_structured_output(QueryOutput)
    generated = query_writer.invoke(prompt_value)
    return {'sql': generated['query']}

# Quick manual check of the node on a single question.
# NOTE(review): the LLM-generated SQL is executed as-is; make sure the
# database role used here has read-only permissions.
sql = write_sql({'question': '총 직원은 몇명이야?'})['sql']
print(sql)
print(db.run(sql))

SELECT COUNT(*) AS employee_count FROM employee;
[(8,)]
