In [1]:
from dotenv import load_dotenv
# Apply the variables (data) registered in the .env file to the OS environment variables
# (e.g. the OpenAI API key used by ChatOpenAI below). The bool it returns is shown as this cell's output.
load_dotenv()

True

In [2]:
import getpass
import os
from urllib.parse import quote_plus

from langchain_community.utilities import SQLDatabase

# Connection settings are read from the environment (populated from .env by the
# load_dotenv() cell) instead of being hardcoded, so credentials are never
# committed together with the notebook.
#
# If you are using SQLite instead:
# sqlite_uri = 'sqlite:///./Chinook.db'
# db = SQLDatabase.from_uri(sqlite_uri)
MYSQL_USER = os.environ.get("MYSQL_USER", "root")
# Fall back to an interactive prompt so the password never has to live in the file.
MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD") or getpass.getpass("MySQL password: ")
MYSQL_HOST = os.environ.get("MYSQL_HOST", "localhost")  # set to e.g. a LAN IP for a remote server
MYSQL_PORT = os.environ.get("MYSQL_PORT", "3306")
MYSQL_DATABASE = os.environ.get("MYSQL_DATABASE", "Chinook")

# Percent-encode user/password so characters such as '@', ':' or '/' do not break URI parsing.
mysql_uri = (
    f"mysql+mysqlconnector://{quote_plus(MYSQL_USER)}:{quote_plus(MYSQL_PASSWORD)}"
    f"@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DATABASE}"
)

db = SQLDatabase.from_uri(mysql_uri)

In [3]:
# Sanity-check the connection: SQL dialect, usable tables, then a small sample query
# whose result string is this cell's displayed output.
print(db.dialect, db.get_usable_table_names(), sep="\n")
db.run("SELECT * FROM Artist LIMIT 10;")


mysql
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [37]:
from langchain_openai import ChatOpenAI

# temperature is numeric: pass 0 (not the string "0") so the deterministic setting
# does not depend on implicit type coercion.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

In [38]:
from langchain.chains import create_sql_query_chain

In [39]:
# Build a runnable that turns a natural-language question into a SQL query string
# for `db` using `llm` (the generated query is printed a few cells below).
chain = create_sql_query_chain(llm=llm, db=db)

In [40]:
# Collect the prompt template(s) the SQL query chain uses internally.
prompts = chain.get_prompts()

In [41]:
# How many prompts the chain contains.
len(prompts)

1

In [42]:
# Display the dialect-specific (MySQL) prompt that drives query generation.
prompts[0].pretty_print()

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [43]:
# Generate (but do not yet execute) a SQL query for the question.
response = chain.invoke({
    "question": "How many employees are there"
})

In [45]:
# `response` is the raw SQL string produced by the chain.
print(response)

SELECT COUNT(EmployeeId) AS TotalEmployees FROM Employee;


In [46]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

In [47]:
# Tool that executes a SQL string against `db` and returns the rows as a string (e.g. '[(8,)]').
execute_query = QuerySQLDataBaseTool(db=db)

In [48]:
# Run the generated SQL. Tools are Runnables, so call .invoke(); calling the tool
# object directly (`execute_query(response)`) is deprecated in LangChain.
execute_query.invoke(response)

'[(8,)]'

In [49]:
# Compose generation and execution: question -> SQL string -> query result string.
finally_chain = chain | execute_query

In [50]:
# End to end: the chain's SQL output is piped straight into the query tool.
finally_chain.invoke({"question": "How many employees are there"})

'[(8,)]'

In [51]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Prompt that turns (question, SQL query, SQL result) into a natural-language answer.
# Built from plain string pieces so no source-code indentation, leading newline or
# trailing whitespace is sent to the model.
answer_prompt = PromptTemplate.from_template(
    "Given the following user question, corresponding SQL query, and SQL result, "
    "answer the user question.\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Result: {result}\n"
    "Answer:"
)

In [52]:
# Placeholders the upstream runnable must supply to fill the answer prompt.
answer_prompt.input_variables

['query', 'question', 'result']

In [53]:
# The query-generation chain on its own returns only the SQL string.
chain.invoke({"question": "How many employees are there"})

'SELECT COUNT(EmployeeId) AS TotalEmployees FROM Employee;'

In [57]:
# assign() passes the input dict through and adds the generated SQL under "query".
RunnablePassthrough.assign(query=chain).invoke(
    {"question": "How many employees are there"}
)

{'question': 'How many employees are there',
 'query': 'SELECT COUNT(EmployeeId) AS TotalEmployees FROM Employee;'}

In [55]:
# A second assign() executes the generated SQL and stores its output under "result".
(
    RunnablePassthrough.assign(query=chain)
    .assign(result=itemgetter("query") | execute_query)
    .invoke({"question": "How many employees are there"})
)

{'question': 'How many employees are there',
 'query': 'SELECT COUNT(EmployeeId) AS TotalEmployees FROM Employee;',
 'result': '[(8,)]'}

In [60]:
# Stop the pipeline at the prompt to inspect exactly what the LLM would receive.
tmp_chain = (
    RunnablePassthrough.assign(query=chain)
    .assign(result=itemgetter("query") | execute_query)
    | answer_prompt
)

tmp_prompt = tmp_chain.invoke({"question": "How many employees are there"})

print(tmp_prompt.text)


    Given the following user question, corresponding SQL query, and SQL result, answer the user question.

    Question: How many employees are there
    SQL Query: SELECT COUNT(EmployeeId) AS TotalEmployees FROM Employee;
    SQL Result: [(8,)]
    Answer: 
  


In [61]:
# Full pipeline: question -> SQL -> result -> answer prompt -> LLM -> plain string.
new_chain = (
    RunnablePassthrough.assign(query=chain)
    .assign(result=itemgetter("query") | execute_query)
    | answer_prompt
    | llm
    | StrOutputParser()
)

In [62]:
# Ask the question end to end and get a natural-language answer.
new_chain.invoke({"question": "How many employees are there"})

'There are 8 employees.'

In [63]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI

# Use a dedicated name for the agent's model. Rebinding the global `llm` would make the
# earlier `create_sql_query_chain(llm=llm, ...)` cell silently switch models if re-run.
agent_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Tool-calling SQL agent that lists tables, inspects schemas and runs queries on its own;
# verbose=True prints each tool invocation.
agent_executor = create_sql_agent(agent_llm, db=db, agent_type="openai-tools", verbose=True)

In [65]:
# The agent takes its question under the "input" key (not "question").
agent_executor.invoke({"input": "How many employees are there"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Employee'}`


[0m[33;1m[1;3m
CREATE TABLE `Employee` (
	`EmployeeId` INTEGER NOT NULL, 
	`LastName` VARCHAR(20) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, 
	`FirstName` VARCHAR(20) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, 
	`Title` VARCHAR(30) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`ReportsTo` INTEGER, 
	`BirthDate` DATETIME, 
	`HireDate` DATETIME, 
	`Address` VARCHAR(70) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`City` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`State` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`Country` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`Po

{'input': 'How many employees are there',
 'output': 'There are 8 employees in the database.'}