In [2]:
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
import pandas as pd
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

In [4]:
llm = Ollama(model="llama3.1", temperature=0)

In [5]:
df = pd.read_csv("budget_tracker.csv")
print(df.shape)
print(df.columns.tolist())

(547, 7)
['Date', 'Description', 'Category', 'Expenditure', 'Year', 'Month', 'Day']


In [6]:
engine = create_engine("sqlite:///budget_tracker.db")
df.to_sql("budget_tracker", engine, index=False)

ValueError: Table 'budget_tracker' already exists.

In [7]:
db = SQLDatabase(engine=engine)
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM budget_tracker WHERE Category = 'Home';")

sqlite
['budget_tracker']


"[('01/10/23', 'Rent payment', 'Home', 15119.0, 2023.0, 10.0, 1.0), ('26/10/23', 'Ac bill', 'Home', 825.0, 2023.0, 10.0, 26.0), ('01/11/23', 'Rent payment', 'Home', 16206.0, 2023.0, 11.0, 1.0), ('24/11/23', 'Ac bill', 'Home', 750.0, 2023.0, 11.0, 24.0), ('01/12/23', 'Rent', 'Home', 16206.0, 2023.0, 12.0, 1.0), ('01/01/24', 'Pay rent', 'Home', 16206.0, 2024.0, 1.0, 1.0), ('02/02/24', 'Rent', 'Home', 16206.0, 2024.0, 2.0, 2.0), ('01/03/24', 'Pay rent', 'Home', 16206.0, 2024.0, 3.0, 1.0), ('01/04/24', 'Rent', 'Home', 16206.0, 2024.0, 4.0, 1.0), ('01/05/24', 'Rent', 'Home', 16206.0, 2024.0, 5.0, 1.0), ('01/06/24', 'rent', 'Home', 16207.0, 2024.0, 6.0, 1.0), ('01/07/24', 'Rent', 'Home', 16207.0, 2024.0, 7.0, 1.0), ('01/08/24', 'rent', 'Home', 17442.0, 2024.0, 8.0, 1.0), ('12/08/24', 'Security deposit', 'Home', 6000.0, 2024.0, 8.0, 12.0), ('31/08/24', 'Electricity bill', 'Home', 300.0, 2024.0, 8.0, 31.0), ('01/09/24', 'Rent', 'Home', 985.0, 2024.0, 9.0, 1.0), ('02/09/24', 'Lunch', 'Home', 25

In [8]:
template = '''Given an input question, return the SQLite query for the same. Return the SQLite Query only, no other English terms. Just the query, since your output will directly be executed.

Only use the following tables:

{table_info}.
top_k={top_k}

Question: {input}'''
prompt = PromptTemplate.from_template(template=template)

In [14]:
chain = create_sql_query_chain(llm, db, prompt, k=1)
response = chain.invoke({"question": "What are the total number of transactions made ?"})
print(response)

SELECT COUNT(*) FROM budget_tracker


In [15]:
db.run(response)

'[(547,)]'

In [16]:
chain.get_prompts()[0].pretty_print()

Given an input question, return the SQLite query for the same. Return the SQLite Query only, no other English terms. Just the query, since your output will directly be executed.

Only use the following tables:

[33;1m[1;3m{table_info}[0m.
top_k=1

Question: [33;1m[1;3m{input}[0m


In [18]:
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db, prompt=prompt)
chain = write_query | execute_query
chain.invoke({"question": "What are the total number of transactions made ?"})

'[(547,)]'

In [19]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

answer = answer_prompt | llm | StrOutputParser()
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "What are the total number of transactions made ?"})

'The total number of transactions made is 547.'