In [121]:
# %pip install sqlparse

https://python.langchain.com/v0.2/docs/tutorials/sql_qa/

In [1]:
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase

engine = create_engine("sqlite:///data.db")

# Load the database using the engine
db = SQLDatabase(engine)
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM dim_daypart LIMIT 10;")

sqlite
['dim_daypart', 'locations', 'order_item', 'order_summary', 'product_margin']


"[('0:00', 'Latenight'), ('0:01', 'Latenight'), ('0:02', 'Latenight'), ('0:03', 'Latenight'), ('0:04', 'Latenight'), ('0:05', 'Latenight'), ('0:06', 'Latenight'), ('0:07', 'Latenight'), ('0:08', 'Latenight'), ('0:09', 'Latenight')]"

In [2]:
from langchain.chains.sql_database.prompt import SQL_PROMPTS

# supported dialects
list(SQL_PROMPTS)

['crate',
 'duckdb',
 'googlesql',
 'mssql',
 'mysql',
 'mariadb',
 'oracle',
 'postgresql',
 'sqlite',
 'clickhouse',
 'prestodb']

In [3]:
import os 
from dotenv import load_dotenv
import openai

# load the .env file
load_dotenv(".env")
openai.api_key = os.getenv("OPENAI_API_KEY")

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")

In [4]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [5]:
context = db.get_context()
print(list(context))
print(context["table_info"])

['table_info', 'table_names']

CREATE TABLE dim_daypart (
	minute_id TEXT, 
	daypart TEXT
)

/*
3 rows from dim_daypart table:
minute_id	daypart
0:00	Latenight
0:01	Latenight
0:02	Latenight
*/


CREATE TABLE locations (
	store_id INTEGER, 
	street_address TEXT, 
	city TEXT, 
	zip_code INTEGER, 
	dma_name TEXT, 
	franchise_owner TEXT, 
	franchise_name TEXT
)

/*
3 rows from locations table:
store_id	street_address	city	zip_code	dma_name	franchise_owner	franchise_name
1	123 Main St	Lindaview	10178	New York	Annette Turner	Turner Food Inc
2	124 Main St	Deannaland	60304	Chicago	Alexis Jefferson	Jefferson Food Inc
3	125 Main St	Port Kaylaville	19006	Philadelphia	Amanda Greene	Greene Food Inc
*/


CREATE TABLE order_item (
	oid_id INTEGER, 
	transaction_id INTEGER, 
	product_id INTEGER, 
	product_quantity INTEGER, 
	FOREIGN KEY(product_id) REFERENCES product_margin (product_id), 
	FOREIGN KEY(transaction_id) REFERENCES order_summary (transaction_id)
)

/*
3 rows from order_item table:
oid_id	

In [6]:
prompt_with_context = chain.get_prompts()[0].partial(table_info=context["table_info"])
print(prompt_with_context.pretty_repr()[:1500])

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [7]:
# EXAMPLE QUERIES

examples = [
    {"input": "What is the average check for each Product Name?",
     "query": 
            """
        SELECT
        product_margin.product_name, 
        AVG(order_summary.net_sales) AS avg_check
        
        FROM order_summary
        INNER JOIN product_margin ON order_item.product_id = product_margin.product_id
        INNER JOIN order_item ON order_summary.transaction_id = order_item.transaction_id
        GROUP BY product_margin.product_name
        ORDER BY avg_check DESC;
            """
    },

    {"input": "What are the top three products sold in each daypart?",
     "query": 
            """
        WITH ranked_products AS (
            SELECT
                dim_daypart.daypart,
                product_margin.product_name,
                SUM(order_item.product_quantity) AS total_quantity,
                ROW_NUMBER() OVER (PARTITION BY dim_daypart.daypart ORDER BY SUM(order_item.product_quantity) DESC) AS rank
            FROM order_item
            JOIN product_margin ON order_item.product_id = product_margin.product_id
            JOIN order_summary ON order_item.transaction_id = order_summary.transaction_id
            JOIN dim_daypart ON order_summary.order_time = dim_daypart.minute_id
            GROUP BY dim_daypart.daypart, product_margin.product_name
        )
        SELECT
            daypart,
            product_name,
            total_quantity
        FROM ranked_products
        WHERE rank <= 3
        ORDER BY 
            CASE daypart
                WHEN 'Breakfast' THEN 1
                WHEN 'Lunch' THEN 2
                WHEN 'Snack' THEN 3
                WHEN 'Dinner' THEN 4
                WHEN 'Latenight' THEN 5
            END,
            total_quantity DESC;
            """
    },

    {"input": "Which Store sold the most waffles by quantity?",
     "query": 
            """
        SELECT
            locations.dma_name,
            locations.store_id,
            SUM(order_item.product_quantity) AS total_waffles
        FROM order_item
        JOIN product_margin ON order_item.product_id = product_margin.product_id
        JOIN order_summary ON order_item.transaction_id = order_summary.transaction_id
        JOIN locations ON order_summary.store_id = locations.store_id
        WHERE LOWER(product_margin.product_name) LIKE '%waffle%'
        GROUP BY locations.dma_name, locations.store_id
        ORDER BY total_waffles DESC;
            """
    },

    {"input": "What is the average check over time? Aggregate by day",
     "query":
            """
        SELECT
            order_summary.date,
            AVG(order_summary.net_sales) AS avg_check
        FROM order_summary
        GROUP BY order_summary.date
        ORDER BY order_summary.date;
            """
    },

    {"input": "what is the average check over time by DMA?",
     "query": 
            """
        SELECT
            order_summary.date,
            locations.dma_name,
            AVG(order_summary.net_sales) AS avg_check
        FROM order_summary
        JOIN locations ON order_summary.store_id = locations.store_id
        GROUP BY order_summary.date, locations.dma_name
        ORDER BY order_summary.date;
            """
    },

    # {"input": "QUESTION",
    #  "query": 
    #         """
    #
    #         """
    # },

]

In [8]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
prompt = FewShotPromptTemplate(
    examples=examples[:5],
    example_prompt=example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run not including sql``` or any prefixes before the start of the query. Unless otherwise specificed, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)

In [9]:
print(prompt.format(input="How many dayparts are there?", top_k=6, table_info="foo"))

You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run not including sql``` or any prefixes before the start of the query. Unless otherwise specificed, do not return more than 6 rows.

Here is the relevant table info: foo

Below are a number of examples of questions and their corresponding SQL queries.

User input: What is the average check for each Product Name?
SQL query: 
        SELECT
        product_margin.product_name, 
        AVG(order_summary.net_sales) AS avg_check
        
        FROM order_summary
        INNER JOIN product_margin ON order_item.product_id = product_margin.product_id
        INNER JOIN order_item ON order_summary.transaction_id = order_item.transaction_id
        GROUP BY product_margin.product_name
        ORDER BY avg_check DESC;
            

User input: What are the top three products sold in each daypart?
SQL query: 
        WITH ranked_products AS (
            SELECT
                dim_daypart.daypart,


In [10]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)

In [11]:
example_selector.select_examples({"input": "how many dayparts are there?"})

[{'input': 'What are the top three products sold in each daypart?',
  'query': "\n        WITH ranked_products AS (\n            SELECT\n                dim_daypart.daypart,\n                product_margin.product_name,\n                SUM(order_item.product_quantity) AS total_quantity,\n                ROW_NUMBER() OVER (PARTITION BY dim_daypart.daypart ORDER BY SUM(order_item.product_quantity) DESC) AS rank\n            FROM order_item\n            JOIN product_margin ON order_item.product_id = product_margin.product_id\n            JOIN order_summary ON order_item.transaction_id = order_summary.transaction_id\n            JOIN dim_daypart ON order_summary.order_time = dim_daypart.minute_id\n            GROUP BY dim_daypart.daypart, product_margin.product_name\n        )\n        SELECT\n            daypart,\n            product_name,\n            total_quantity\n        FROM ranked_products\n        WHERE rank <= 3\n        ORDER BY \n            CASE daypart\n                WHEN 

In [12]:
prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run not including sql``` or any prefixes before the start of the query. Unless otherwise specified, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)

In [13]:
print(prompt.format(input="how many dayparts are there?", top_k=3, table_info="foo"))

You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run not including sql``` or any prefixes before the start of the query. Unless otherwise specified, do not return more than 3 rows.

Here is the relevant table info: foo

Below are a number of examples of questions and their corresponding SQL queries.

User input: What are the top three products sold in each daypart?
SQL query: 
        WITH ranked_products AS (
            SELECT
                dim_daypart.daypart,
                product_margin.product_name,
                SUM(order_item.product_quantity) AS total_quantity,
                ROW_NUMBER() OVER (PARTITION BY dim_daypart.daypart ORDER BY SUM(order_item.product_quantity) DESC) AS rank
            FROM order_item
            JOIN product_margin ON order_item.product_id = product_margin.product_id
            JOIN order_summary ON order_item.transaction_id = order_summary.transaction_id
            JOIN dim_daypart ON order_sum

In [14]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db, prompt) # prompt included in the chain
# chain = write_query | execute_query
# chain.invoke({"question": "how many dayparts are there?"})

In [15]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

In [16]:
import logging

In [17]:
import sqlparse

# Function to pretty print SQL query with indentation
def pretty_print_sql(query, indent="    "):
    formatted_query = sqlparse.format(query, reindent=True, keyword_case='upper')
    indented_query = "\n".join(indent + line for line in formatted_query.splitlines())
    return indented_query

In [18]:
### VALID with logging query to object ###

# Configure logging
class CustomFormatter(logging.Formatter):
    def format(self, record):
        return f"{record.getMessage()}"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

# Set custom formatter for all handlers
for handler in logging.root.handlers:
    handler.setFormatter(CustomFormatter())

# Set logging level for httpx to WARNING to suppress its INFO logs
logging.getLogger("httpx").setLevel(logging.WARNING)

# Define the prompt template
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# List to store logged queries
logged_queries = []

# Function to log the query
def log_query(query):
    # Temporarily disable logging
    logging.disable(logging.INFO)
    logger.info(f"Executing SQL Query: {query}")
    # Re-enable logging
    logging.disable(logging.NOTSET)
    logged_queries.append(query)  # Save the query to the list
    return query

# Define the chain
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | (log_query | execute_query)
    )
    | answer_prompt
    | llm
    | StrOutputParser()
)



def qna(question: str):
    # Invoke the chain and capture the result
    result = chain.invoke({"question": question})
    # Print the logged query with indentation
    formatted_query = pretty_print_sql(logged_queries[-1])
    print("Executed Query:")
    print(formatted_query, "\n")

    # Print the result
    print("Summarized Answer: \n", result)

In [19]:
qna("How many dayparts are there?")

Executed Query:
    SELECT COUNT(DISTINCT daypart) AS total_dayparts
    FROM dim_daypart; 

Summarized Answer: 
 There are 5 distinct dayparts.


In [20]:
qna("what is the first daypart in alphabetical order?")

Executed Query:
    SELECT daypart
    FROM dim_daypart
    ORDER BY daypart
    LIMIT 1; 

Summarized Answer: 
 The first daypart in alphabetical order is Breakfast.


In [21]:
qna("what is the first daypart in alphabetical order?")

Executed Query:
    SELECT daypart
    FROM dim_daypart
    ORDER BY daypart
    LIMIT 1; 

Summarized Answer: 
 The first daypart in alphabetical order is Breakfast.


In [22]:
qna("what are the daily sales for each store in the Philadelphia DMA? Filter dates between 2024-08-01 and 2024-08-31. Return the results as a table")

Executed Query:
    SELECT order_summary.date,
           locations.store_id,
           SUM(order_summary.net_sales) AS daily_sales
    FROM order_summary
    JOIN locations ON order_summary.store_id = locations.store_id
    WHERE locations.dma_name = 'Philadelphia'
      AND order_summary.date BETWEEN '2024-08-01' AND '2024-08-31'
    GROUP BY order_summary.date,
             locations.store_id
    ORDER BY order_summary.date,
             locations.store_id; 

Summarized Answer: 
 Here is the table showing the daily sales for each store in the Philadelphia DMA from August 1, 2024, to August 31, 2024:

| Date       | Store ID | Daily Sales |
|------------|----------|-------------|
| 2024-08-01 | 3        | 538.56      |
| 2024-08-01 | 17       | 282.99      |
| 2024-08-01 | 23       | 536.51      |
| 2024-08-01 | 33       | 531.15      |
| 2024-08-01 | 37       | 468.09      |
| 2024-08-01 | 45       | 649.76      |
| 2024-08-01 | 50       | 276.05      |
| 2024-08-01 | 56       | 21