In [1]:
import sqlite3

import requests
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool


def get_engine_for_chinook_db():
    """Pull sql file, populate in-memory database, and create engine."""
    url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
    response = requests.get(url)
    sql_script = response.text

    connection = sqlite3.connect(":memory:", check_same_thread=False)
    connection.executescript(sql_script)
    return create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )


engine = get_engine_for_chinook_db()

db = SQLDatabase(engine)

In [2]:
import getpass
import os

if not os.environ.get("GROQ_API_KEY"):
  os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

from langchain.chat_models import init_chat_model

llm = init_chat_model("llama3-8b-8192", model_provider="groq")

In [3]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [4]:
toolkit.get_tools()

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x77f455b21d20>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x77f455b21d20>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x77f455b21d20>),
 QuerySQLCheckerTool(description='Use this tool to double check

In [5]:
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDatabaseTool,
)

In [6]:
from langchain import hub

prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

assert len(prompt_template.messages) == 1
print(prompt_template.input_variables)



['dialect', 'top_k']


In [7]:
system_message = prompt_template.format(dialect="SQLite", top_k=5)

In [8]:
from langgraph.prebuilt import create_react_agent

agent_executor = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)

In [15]:
# example_query = 'Retorne os generos que começam com letra r'
example_query = "Who are Brazil's customers?"


events = agent_executor.invoke(
    {"messages": [("user", example_query)]},
    # stream_mode="values",
)

# for event in events:
  # event["messages"][-1].pretty_print()

events['messages']

[HumanMessage(content="Who are Brazil's customers?", additional_kwargs={}, response_metadata={}, id='6c5b2b85-b508-4838-abe7-abcf64ad6e12'),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_81by', 'function': {'arguments': '{"tool_input":""}', 'name': 'sql_db_list_tables'}, 'type': 'function'}, {'id': 'call_6s78', 'function': {'arguments': '{"table_names":""}', 'name': 'sql_db_schema'}, 'type': 'function'}, {'id': 'call_sabz', 'function': {'arguments': '{"query":""}', 'name': 'sql_db_query_checker'}, 'type': 'function'}, {'id': 'call_sh86', 'function': {'arguments': '{"query":""}', 'name': 'sql_db_query'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 235, 'prompt_tokens': 1633, 'total_tokens': 1868, 'completion_time': 0.195833333, 'prompt_time': 0.223540467, 'queue_time': 0.05357554100000003, 'total_time': 0.4193738}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_a97cfe35ae', 'finish_reason': 'tool_calls', 'logprobs': N