Load CSV

In [1]:
import pandas as pd

# Load the mortgage CSV; the path is relative to the notebook's working directory.
df = pd.read_csv('./data/mortgage-300h-360m-5r.csv')

def rename_columns(df):
    """Replace spaces in column names with underscores.

    The frame's ``columns`` attribute is reassigned in place and the same
    DataFrame object is returned, so both ``rename_columns(df)`` and
    ``df = rename_columns(df)`` work.

    Args:
        df: DataFrame whose column labels should be normalized.

    Returns:
        The same DataFrame object, with updated column labels.
    """
    # Only string labels can contain spaces; leave any other label (for
    # example an integer column name) untouched instead of failing on it.
    df.columns = [
        col.replace(" ", "_") if isinstance(col, str) else col
        for col in df.columns
    ]
    return df

# Normalize the column labels, then show the frame's shape and column list
df = rename_columns(df)
print(df.shape, df.columns.tolist(), sep="\n")

(361, 5)
['Period', 'Monthly_Payment', 'Computed_Interest_Due', 'Principal_Due', 'Principal_Balance']


In [2]:
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine, text

# Create an engine to connect to the database
# (SQLite file "mortgage.db", addressed relative to the working directory).
engine = create_engine("sqlite:///mortgage.db")

def check_table_exists(table_name):
    """Return whether a table named ``table_name`` exists in the database.

    Looks the name up in ``sqlite_master`` through the module-level ``engine``.

    Args:
        table_name: Name of the table to look for.

    Returns:
        True if a row of type ``'table'`` with that name is found, else False.
    """
    # Bind the name as a parameter rather than formatting it into the SQL
    # text, so quotes or SQL fragments in the name cannot alter the statement.
    lookup = text(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;"
    )
    with engine.connect() as connection:
        result = connection.execute(lookup, {"table_name": table_name})
        return result.fetchone() is not None

# Look up the 'mortgage' table and load the DataFrame into it only when it is missing
table_exists = check_table_exists('mortgage')

if not table_exists:
    print("Table 'mortgage' does not exist. Creating table...")
    # First run against this database file: write the frame without its index
    df.to_sql("mortgage", engine, index=False)
else:
    print("Table 'mortgage' exists.")




Table 'mortgage' exists.


In [3]:
# Wrap the engine for LangChain, then sanity-check the connection:
# print the dialect, the usable table names, and the rows with Period < 3.
db = SQLDatabase(engine=engine)
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM mortgage WHERE Period < 3;"))

sqlite
['mortgage']
[(0, 0.0, 0.0, 0.0, 300000.0), (1, 1610.4648690364193, 1250.0, 360.46486903641926, 299639.5351309636), (2, 1610.4648690364193, 1248.4980630456705, 361.966805990749, 299277.5683249729)]


In [4]:
from langchain_ollama import ChatOllama

# LLM from Ollama
# `llm` is the chat model used by the toolkit/agent and by the query and answer steps below.
local_model = "mistral:latest"
llm = ChatOllama(model=local_model, temperature=0)

In [5]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.agent_toolkits import create_sql_agent
from langchain_core.messages import HumanMessage

# Build the SQL toolkit and a tool-calling SQL agent on top of `db` and `llm`.
# NOTE(review): the only call to `agent_executor` in the visible cells is commented out.
toolkit = SQLDatabaseToolkit(db=db, llm=llm) 
agent_executor = create_sql_agent(llm, toolkit=toolkit, agent_type="tool-calling")

In [6]:
# agent_executor.invoke({"input": "List tables in database"})

In [7]:
from langchain import hub

# Get from https://smith.langchain.com/hub
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# The template is expected to hold exactly one message; print it for inspection.
# NOTE(review): write_query below formats its own copy of this prompt text and does
# not read `query_prompt_template`.
assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[0].pretty_print()




Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Question: [33;1m[1;3m{input}[0m


In [10]:
from typing_extensions import Annotated, TypedDict
from typing import Optional, TypedDict
from pydantic import BaseModel, Field
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine

# Initialize the database connection
# NOTE(review): `engine` and `db` were already created in earlier cells against the
# same "sqlite:///mortgage.db" URL; this cell rebinds both names.
engine = create_engine("sqlite:///mortgage.db")
db = SQLDatabase(engine)

class State(TypedDict):
    """Keys passed between the question-answering steps defined in this notebook."""

    question: str  # user's question; read by write_query and generate_answer
    query: str  # SQL text returned by write_query; read by execute_query and generate_answer
    result: str  # output of the SQL query tool, stored by execute_query
    answer: str  # LLM response content, stored by generate_answer

class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Schema handed to llm.with_structured_output() in write_query.
    # NOTE(review): the docstring above and the Annotated description below are
    # presumably forwarded to the model as schema text — confirm before rewording.
    # NOTE(review): `TypedDict` resolves to typing.TypedDict here because the
    # `from typing import ...` line follows the typing_extensions import; verify
    # that is the intended one for structured output.
    query: Annotated[str, ..., "Syntactically valid SQL query."]

def write_query(state: State):
    """Generate SQL query to fetch information.

    Fills the query-writing prompt with the database dialect, a row limit of
    10, the table info from ``db`` and ``state["question"]``, then asks the
    LLM for output structured as ``QueryOutput``.

    Args:
        state: Pipeline state; only ``state["question"]`` is read.

    Returns:
        A dict with the single key ``"query"`` holding the generated SQL text.

    Raises:
        ValueError: If the model response does not contain a non-empty
            ``"query"`` entry.
    """
    # Trace the incoming state (this is the first line printed on each run).
    print(state)

    # Generate the SQL query using the template
    prompt = """
    ================================ System Message ================================
    Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.

    Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

    Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

    Only use the following tables:
    {table_info}

    Question: {input}
    """.format(
        dialect=db.dialect,
        top_k=10,
        table_info=db.get_table_info(),
        input=state["question"]
    )

    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)

    # Fail here with a clear message rather than letting a missing query
    # surface later as an opaque KeyError/TypeError in execute_query.
    if not isinstance(result, dict) or not result.get("query"):
        raise ValueError(
            f"LLM did not return a SQL query for question {state['question']!r}: {result!r}"
        )

    # Return only the key this step owns so unexpected extra keys in the model
    # output never leak into the pipeline state.
    return {"query": result["query"]}

# Example usage
# Only "question" is supplied; write_query reads no other key.
state = {
    "question": "What are the top 5 most recent mortgage payments?"
}

# Keep the generated query so the next cells can execute it.
query_result = write_query(state)
print(query_result)


{'question': 'What are the top 5 most recent mortgage payments?'}
{'query': 'SELECT Period, Monthly_Payment FROM mortgage ORDER BY Period DESC LIMIT 5'}


In [11]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

def execute_query(state: State):
    """Run the SQL stored in ``state["query"]`` against ``db``.

    Returns a dict whose ``"result"`` key holds whatever the query tool returned.
    """
    sql_tool = QuerySQLDatabaseTool(db=db)
    tool_output = sql_tool.invoke(state["query"])
    return {"result": tool_output}

In [12]:
# Run the SQL generated above; the bare last expression displays the result dict.
ex_result = execute_query(query_result)
ex_result

{'result': '[(360, 1610.4648690364193), (359, 1610.4648690364193), (358, 1610.4648690364193), (357, 1610.4648690364193), (356, 1610.4648690364193)]'}

In [13]:
def generate_answer(state: State):
    """Ask the LLM to answer the question from the SQL query and its result.

    Reads ``question``, ``query`` and ``result`` from ``state`` and returns a
    dict whose ``"answer"`` key holds the content of the model's response.
    """
    # Assemble the prompt from its fragments; joined with no separator they
    # form the instruction followed by the three labelled context lines.
    fragments = [
        "Given the following user question, corresponding SQL query, ",
        "and SQL result, answer the user question.\n\n",
        f'Question: {state["question"]}\n',
        f'SQL Query: {state["query"]}\n',
        f'SQL Result: {state["result"]}',
    ]
    reply = llm.invoke("".join(fragments))
    return {"answer": reply.content}

In [14]:
from langgraph.graph import START, StateGraph

# Chain the three steps in order. The nodes are addressed by function name, as the
# "write_query" entry edge below and the keys of the streamed updates show.
graph_builder = StateGraph(State).add_sequence(
    [write_query, execute_query, generate_answer]
)
# Enter the graph at the first step.
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()

In [15]:
# With stream_mode="updates" each printed step is a dict keyed by the node that
# just ran (see the output below); the leading state dict is printed by write_query.
for step in graph.stream(
    {"question": "What are the top 5 most recent mortgage payments?"}, stream_mode="updates"
):
    print(step)

{'question': 'What are the top 5 most recent mortgage payments?'}
{'write_query': {'query': 'SELECT Period, Monthly_Payment FROM mortgage ORDER BY Period DESC LIMIT 5'}}
{'execute_query': {'result': '[(360, 1610.4648690364193), (359, 1610.4648690364193), (358, 1610.4648690364193), (357, 1610.4648690364193), (356, 1610.4648690364193)]'}}
{'generate_answer': {'answer': ' Based on the SQL result provided, the top 5 most recent mortgage payments are as follows:\n\n1. Period 360 with a Monthly Payment of $1610.46\n2. Period 359 with a Monthly Payment of $1610.46\n3. Period 358 with a Monthly Payment of $1610.46\n4. Period 357 with a Monthly Payment of $1610.46\n5. Period 356 with a Monthly Payment of $1610.46\n\nEach payment has the same amount of $1610.46, but they are ordered by their respective periods in descending order (most recent first).'}}


In [16]:
# NOTE(review): the question is ambiguous — in the recorded run the model generated
# SUM(Monthly_Payment) over every row, so the reported "total monthly payment" is the
# sum across all periods rather than a single month's payment.
for step in graph.stream(
    {"question": "What is the total monthly payment?"}, stream_mode="updates"
):
    print(step)

{'question': 'What is the total monthly payment?'}
{'write_query': {'query': 'SELECT SUM(Monthly_Payment) FROM mortgage;'}}
{'execute_query': {'result': '[(579767.352853111,)]'}}
{'generate_answer': {'answer': ' The total monthly payment is 579767.35.'}}
