## SQL 에이전트

이 튜토리얼에서는 LangGraph를 사용하여 SQL 데이터베이스에 대한 질문에 답변할 수 있는 사용자 지정 에이전트를 구축합니다. LangGraph 기본 요소를 사용하여 SQL 에이전트의 구현 예시를 보여줍니다. 

LangChain은 LangGraph 기본 요소를 사용하여 구현된 내장 에이전트 구현체를 제공합니다. 상위 수준의 LangChain 추상화를 사용하여 SQL 에이전트를 구축하는 튜토리얼은 [여기](https://docs.langchain.com/oss/python/langchain/sql-agent)에서 확인할 수 있습니다.

랭그래프 공식 튜토리얼 참고: https://docs.langchain.com/oss/python/langgraph/sql-agent

## 환경 설정

In [None]:
import os
import getpass
from dotenv import load_dotenv

load_dotenv("../.env", override=True)


def _set_env(var: str):
    env_value = os.environ.get(var)
    if not env_value:
        env_value = getpass.getpass(f"{var}: ")

    os.environ[var] = env_value


_set_env("LANGSMITH_API_KEY")
os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_PROJECT"] = "langchain-academy"
_set_env("OPENAI_API_KEY")

In [None]:
from langchain.chat_models import init_chat_model

llm = init_chat_model("openai:gpt-4.1-mini")

## 데이터베이스 구성

SQLite 데이터베이스를 생성합니다. 공개 GCS 버킷에서 Chinook.db 데이터베이스 파일을 다운로드 합니다.

In [2]:
import requests
import pathlib

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
local_path = pathlib.Path("temp", "Chinook.db")

if local_path.exists():
    print(f"{local_path} already exists, skipping download.")
else:
    response = requests.get(url)
    if response.status_code == 200:
        local_path.write_bytes(response.content)
        print(f"File downloaded and saved as {local_path}")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")

File downloaded and saved as temp/Chinook.db


데이터베이스와 상호작용하기 위해 langchain_community 패키지에서 제공하는 SQLDatabase 래퍼를 사용합니다. SQLDatabase는 SQL 쿼리를 실행하고 결과를 가져오는 간단한 인터페이스를 제공합니다.

In [None]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri(f"sqlite:///{local_path}")

print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f"Sample output: {db.run('SELECT * FROM Artist LIMIT 5;')}")

Dialect: sqlite
Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]


## 데이터베이스 상호작용을 위한 도구 추가

이번에도 langchain_community 패키지에서 제공하는 SQLDatabaseToolkit 래퍼를 사용하여 데이터베이스와 상호작용합니다.

In [12]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

for tool in tools:
    print(f"* {tool.name}: {tool.description}\n")

* sql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.

* sql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3

* sql_db_list_tables: Input is an empty string, output is a comma-separated list of tables in the database.

* sql_db_query_checker: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!



## 노드 정의

다음 단계들을 위한 노드를 구성합니다:

- DB 테이블 목록 생성
- `get schema` 도구 호출
- 쿼리 생성
- 쿼리 검증

이러한 단계들을 노드에 배치함으로써 (1) 필요 시 도구 호출을 강제하고, (2) 각 단계와 연관된 프롬프트를 맞춤 설정할 수 있습니다.

In [None]:
from typing import Literal
from langchain_community.agents import ToolNode
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph, MessagesState, END

list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")

get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
get_schema_tool = ToolNode([get_schema_tool], name="get_schema")

run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
run_query_node = ToolNode([run_query_tool], name="run_query")


In [15]:
def list_tables(state: MessagesState):
    tool_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "abc123",
        "type": "tool_call",
    }
    tool_call_message = AIMessage("", tool_call=tool_call)
    tool_message = list_tables_tool.invoke(tool_call)
    response = AIMessage(f"Available tables: {tool_message.content}")

    return {"messages": [tool_call_message, tool_message, response]}

In [16]:
list_tables({})

{'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={}, tool_call={'name': 'sql_db_list_tables', 'args': {}, 'id': 'abc123', 'type': 'tool_call'}),
  ToolMessage(content='Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', name='sql_db_list_tables', tool_call_id='abc123'),
  AIMessage(content='Available tables: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', additional_kwargs={}, response_metadata={})]}

In [23]:
def call_get_schema(state: MessagesState):
    llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="required")
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}

In [24]:
call_get_schema({"messages": [("user", "사용 가능한 테이블은?")]})

ValueError: Unsupported function

get_schema(tags=None, recurse=True, explode_args=False, func_accepts={'config': ('N/A', <class 'inspect._empty'>), 'store': ('store', None)}, tools_by_name={'sql_db_schema': InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x1165e9af0>)}, tool_to_state_args={'sql_db_schema': {}}, tool_to_store_arg={'sql_db_schema': None}, handle_tool_errors=True, messages_key='messages')

Functions must be passed in as Dict, pydantic.BaseModel, or Callable. If they're a dict they must either be in OpenAI function format or valid JSON schema with top-level 'title' and 'description' keys.