In [1]:
import requests
from langchain import hub
from langchain_groq import ChatGroq
from langchain_community.utilities import SQLDatabase
from pprint import pprint

In [2]:
def db_read_test():
    """Smoke-test the local Chinook SQLite database through ``SQLDatabase``.

    Opens ``./db/Chinook.db`` and prints, in order: the dialect reported by
    the wrapper, the list of usable table names, and the result of selecting
    the first ten rows of the ``Artist`` table. Returns nothing.
    """
    chinook = SQLDatabase.from_uri("sqlite:///./db/Chinook.db")

    # Each value is fetched and printed before the next one is requested, so
    # a failure in a later step still leaves the earlier output visible.
    print(chinook.dialect)

    table_names = chinook.get_usable_table_names()
    print(table_names)

    artist_sample = chinook.run("SELECT * FROM Artist LIMIT 10;")
    print(artist_sample)

# Run the connectivity check defined above; it prints the dialect, the usable
# table names, and ten sample rows from the Artist table.
db_read_test()

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


In [3]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langgraph.prebuilt import create_react_agent
from langchain_ollama.chat_models import ChatOllama

In [4]:
def sql_agent(
    user_input: str,
    *,
    db_uri: str = "sqlite:///./db/Chinook.db",
    model_name: str = "llama-3.2-90b-text-preview",
    dialect: str = "SQLite",
    top_k: int = 5,
    debug: bool = True,
) -> list:
    """Answer a natural-language question about a SQL database with a ReAct agent.

    Builds a Groq-backed chat model, wraps the database in a
    ``SQLDatabaseToolkit``, pulls the ``langchain-ai/sql-agent-system-prompt``
    template from the LangChain hub, and streams a LangGraph ReAct agent over
    the question.

    Args:
        user_input: The question to send to the agent as the user message.
            Must be a non-blank string.
        db_uri: Database URI handed to ``SQLDatabase.from_uri``.
        model_name: Groq model identifier passed to ``ChatGroq``.
        dialect: Value substituted for ``{dialect}`` in the system prompt.
        top_k: Value substituted for ``{top_k}`` in the system prompt (the
            row limit the prompt asks the model to respect). Must be >= 1.
        debug: Forwarded to ``agent_executor.stream``. The default ``True``
            matches the original behaviour; pass ``False`` to avoid the
            step-by-step trace being printed into the notebook output.

    Returns:
        A list holding the last entry of ``event["messages"]`` for every
        event yielded by the stream, in the order they were produced.

    Raises:
        ValueError: If ``user_input`` is not a non-blank string, or if
            ``top_k`` is smaller than 1.
    """
    # Validate before touching the database, the hub, or the LLM so that a
    # bad call fails fast instead of spending a network/API round-trip.
    if not isinstance(user_input, str) or not user_input.strip():
        raise ValueError("user_input must be a non-empty string")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    db = SQLDatabase.from_uri(db_uri)
    llm = ChatGroq(temperature=0, model_name=model_name)
    # Local alternative kept from the original for reference:
    # llm = ChatOllama(base_url="http://localhost:11434", model="llama3.2:latest")

    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
    system_message = prompt_template.format(dialect=dialect, top_k=top_k)
    agent_executor = create_react_agent(
        llm,
        toolkit.get_tools(),
        state_modifier=system_message)
    events = agent_executor.stream(
        {"messages": [("user", user_input)]},
        stream_mode="values",
        debug=debug)

    # Keep only the newest message of each streamed event.
    return [event["messages"][-1] for event in events]

In [9]:
# Pull the SQL-agent system prompt from the LangChain hub and print it so the
# template text and its input variables (`dialect`, `top_k`) can be inspected.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
print(prompt_template)

input_variables=['dialect', 'top_k'] input_types={} partial_variables={} metadata={'lc_hub_owner': 'langchain-ai', 'lc_hub_repo': 'sql-agent-system-prompt', 'lc_hub_commit_hash': '31156d5fe3945188ee172151b086712d22b8c70f8f1c0505f5457594424ed352'} messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['dialect', 'top_k'], input_types={}, partial_variables={}, template='You are an agent designed to interact with a SQL database.\nGiven an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\nUnless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\nYou can order the results by a relevant column to return the most interesting examples in the database.\nNever query for all the columns from a specific table, only ask for the relevant columns given the question.\nYou have access to tools for interacting with the database.\

In [5]:
# Ask the agent, in Korean: "Which genre sold the most in 2009, and what was
# that genre's total revenue?" Then pretty-print the list of messages that
# sql_agent collected from the stream.
result = sql_agent("2009년에 가장 많이 팔린 장르는 무엇이며, 해당 장르의 총 매출액은 얼마인가요?")
pprint(result)


[36;1m[1;3m[-1:checkpoint][0m [1mState at the end of step -1:
[0m{'messages': []}
[36;1m[1;3m[0:tasks][0m [1mStarting step 0 with 1 task:
[0m- [32;1m[1;3m__start__[0m -> {'messages': [('user', '2009년에 가장 많이 팔린 장르는 무엇이며, 해당 장르의 총 매출액은 얼마인가요?')]}
[36;1m[1;3m[0:writes][0m [1mFinished step 0 with writes to 1 channel:
[0m- [33;1m[1;3mmessages[0m -> [('user', '2009년에 가장 많이 팔린 장르는 무엇이며, 해당 장르의 총 매출액은 얼마인가요?')]
[36;1m[1;3m[0:checkpoint][0m [1mState at the end of step 0:
[0m{'messages': [HumanMessage(content='2009년에 가장 많이 팔린 장르는 무엇이며, 해당 장르의 총 매출액은 얼마인가요?', additional_kwargs={}, response_metadata={}, id='52e4680b-3fed-4b51-af25-f6a87fc7e421')]}
[36;1m[1;3m[1:tasks][0m [1mStarting step 1 with 1 task:
[0m- [32;1m[1;3magent[0m -> {'is_last_step': False,
 'messages': [HumanMessage(content='2009년에 가장 많이 팔린 장르는 무엇이며, 해당 장르의 총 매출액은 얼마인가요?', additional_kwargs={}, response_metadata={}, id='52e4680b-3fed-4b51-af25-f6a87fc7e421')]}
[36;1m[1;3m[1:writes][0m [1mFinished

KeyboardInterrupt: 