In [25]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///my.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [26]:
from langchain.chains import create_sql_query_chain
from langchain_community.chat_models import ChatOllama
model = ChatOllama(model="qwen:1.8b")

chain = create_sql_query_chain(model, db)
response = chain.invoke({"question": "有多少员工?"})
response

'The SQL query to find the number of employees in a given table is:\n```\nSELECT COUNT(*) FROM `Employee` ;\n```\nThis query selects all rows from the `Employee` table and returns the count of rows, which represents the number of employees in the given table.'

In [27]:
db.run(response)
chain.get_prompts()[0].pretty_print()

OperationalError: (sqlite3.OperationalError) near "The": syntax error
[SQL: The SQL query to find the number of employees in a given table is:
```
SELECT COUNT(*) FROM `Employee` ;
```
This query selects all rows from the `Employee` table and returns the count of rows, which represents the number of employees in the given table.]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [28]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(model, db)
chain = write_query | execute_query
chain.invoke({"question": "有多少员工?"})

'Error: (sqlite3.OperationalError) near "The": syntax error\n[SQL: The SQL query to count the number of employees in a given database would be:\n```\nSELECT COUNT(*) \nFROM Employee;\n```\nThis query selects the `COUNT(*)` column from the `Employee` table, which represents the total number of employees.\nNote that the exact syntax and structure of the SQL query may vary depending on the specific database management system being used.]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'

In [30]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """给定以下用户问题、相应的SQL查询和SQL结果，回答用户问题。

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

answer = answer_prompt | model | StrOutputParser()
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "有多少员工?"})

'The SQL query to count the number of employees in a given table and exclude non-nullable value in Bill Amount Column is as follows:\n```sql\nSELECT COUNT(*) \nFROM (\n    SELECT EmployeeID, BillAmount\n    FROM `Employee` Table\n    WHERE BillAmount IS NOT NULL\n    GROUP BY EmployeeID)\nWHERE Count > 0;\n```\nIn this query, we first select the column `BillAmount` from the table `Employee`. We then group the results by `EmployeeID` using the `GROUP BY` clause.\nNext, we use a subquery to filter out rows where `BillAmount` is null. This subquery is enclosed in parentheses and used as a condition inside the outer `WHERE` clause.\nFinally, we apply the `COUNT(*)` column value operator to count the number of non-null rows present in the result set of the subquery.\nThe SQL query executed successfully and returned the expected answer, which is the count of employees in a given table and exclude non-nullable value in Bill Amount Column.\n'