In [1]:
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase

# Location of the SQLite file queried throughout this notebook.
DATABASE_URL = "sqlite:///data.db"

engine = create_engine(DATABASE_URL)

# Wrap the engine in LangChain's SQLDatabase helper; later cells hand `db`
# to the query-writing chain and to the query-execution tool.
db = SQLDatabase(engine)

In [2]:
import os 
from dotenv import load_dotenv
import openai

# Read the local .env file into the process environment, then pass the
# OpenAI key on to the openai module.
load_dotenv(".env")
openai.api_key = os.environ.get("OPENAI_API_KEY")

from langchain_openai import ChatOpenAI

# Chat model used both to write SQL and to phrase the final answer.
llm = ChatOpenAI(model="gpt-4o-mini")

In [3]:
from langchain.chains import create_sql_query_chain

In [4]:
# Capture whatever SQLDatabase.get_context() reports for this database.
# NOTE(review): `context` is not referenced by any later cell visible here —
# confirm it is still needed.
context = db.get_context()

In [5]:
# EXAMPLE QUERIES
#
# Few-shot examples for the SQL-writing prompt. Each entry pairs a natural
# language question ("input") with a reference SQLite query ("query").
# The query text is sent to the model verbatim, so every example should be
# valid, portable SQL that is safe to imitate.

examples = [
    # order_item is joined before product_margin so that each ON clause only
    # references tables already introduced in the FROM clause.
    {"input": "What is the average check for each Product Name?",
     "query": 
            """
        SELECT
        product_margin.product_name, 
        AVG(order_summary.net_sales) AS avg_check
        
        FROM order_summary
        INNER JOIN order_item ON order_summary.transaction_id = order_item.transaction_id
        INNER JOIN product_margin ON order_item.product_id = product_margin.product_id
        GROUP BY product_margin.product_name
        ORDER BY avg_check DESC;
            """
    },

    # Ranks products inside each daypart with a window function, then keeps
    # the top three per daypart.
    {"input": "What are the top three products sold in each daypart?",
     "query": 
            """
        WITH ranked_products AS (
            SELECT
                dim_daypart.daypart,
                product_margin.product_name,
                SUM(order_item.product_quantity) AS total_quantity,
                ROW_NUMBER() OVER (PARTITION BY dim_daypart.daypart ORDER BY SUM(order_item.product_quantity) DESC) AS rank
            FROM order_item
            JOIN product_margin ON order_item.product_id = product_margin.product_id
            JOIN order_summary ON order_item.transaction_id = order_summary.transaction_id
            JOIN dim_daypart ON order_summary.order_time = dim_daypart.minute_id
            GROUP BY dim_daypart.daypart, product_margin.product_name
        )
        SELECT
            daypart,
            product_name,
            total_quantity
        FROM ranked_products
        WHERE rank <= 3
        ORDER BY 
            CASE daypart
                WHEN 'Breakfast' THEN 1
                WHEN 'Lunch' THEN 2
                WHEN 'Snack' THEN 3
                WHEN 'Dinner' THEN 4
                WHEN 'Latenight' THEN 5
            END,
            total_quantity DESC;
            """
    },

    # Case-insensitive match on the product name, aggregated per store.
    {"input": "Which Store sold the most waffles by quantity?",
     "query": 
            """
        SELECT
            locations.dma_name,
            locations.store_id,
            SUM(order_item.product_quantity) AS total_waffles
        FROM order_item
        JOIN product_margin ON order_item.product_id = product_margin.product_id
        JOIN order_summary ON order_item.transaction_id = order_summary.transaction_id
        JOIN locations ON order_summary.store_id = locations.store_id
        WHERE LOWER(product_margin.product_name) LIKE '%waffle%'
        GROUP BY locations.dma_name, locations.store_id
        ORDER BY total_waffles DESC;
            """
    },

    {"input": "What is the average check over time? Aggregate by day",
     "query":
            """
        SELECT
            order_summary.date,
            AVG(order_summary.net_sales) AS avg_check
        FROM order_summary
        GROUP BY order_summary.date
        ORDER BY order_summary.date;
            """
    },

    {"input": "what is the average check over time by DMA?",
     "query": 
            """
        SELECT
            order_summary.date,
            locations.dma_name,
            AVG(order_summary.net_sales) AS avg_check
        FROM order_summary
        JOIN locations ON order_summary.store_id = locations.store_id
        GROUP BY order_summary.date, locations.dma_name
        ORDER BY order_summary.date;
            """
    },

    # Template for adding another example:
    # {"input": "QUESTION",
    #  "query": 
    #         """
    #
    #         """
    # },

]

In [6]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

# Layout for a single few-shot example: the question followed by its SQL.
# {input} and {query} are filled from the matching keys of each example dict.
example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")

In [7]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

# Selector that chooses few-shot examples by semantic similarity, backed by
# OpenAI embeddings and a FAISS vector store. Only the "input" field is used
# as the key for matching.
# NOTE(review): k=5 equals the number of entries currently in `examples`, so
# presumably every example is selected for every question — confirm intended.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)

In [8]:
# Few-shot prompt for SQL generation: fixed instructions (prefix), then the
# examples chosen by the selector, then the user's question (suffix).
prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    # Adjacent literals are concatenated into one instruction string.
    prefix=(
        "You are a SQLite expert. Given an input question, create a syntactically "
        "correct SQLite query to run not including sql``` or any prefixes before "
        "the start of the query. Unless otherwise specified, do not return more "
        "than {top_k} rows.\n\n"
        "Here is the relevant table info: {table_info}\n\n"
        "Below are a number of examples of questions and their corresponding SQL queries."
    ),
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)

In [9]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Tool bound to `db`; in the chain below it receives the generated SQL and its
# output is stored as the "result".
# NOTE(review): the SQL it receives is model-generated — presumably run as-is;
# confirm the database connection is read-only or otherwise safe.
execute_query = QuerySQLDataBaseTool(db=db)
# Chain that turns a user question into a SQL string using `llm`, `db` and the
# few-shot `prompt`.
write_query = create_sql_query_chain(llm, db, prompt) # prompt included in the chain

In [10]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
import logging

In [11]:
import sqlparse

# Function to pretty print SQL query with indentation
def pretty_print_sql(query, indent="    "):
    """Return ``query`` reformatted by sqlparse with every line indented.

    Args:
        query: SQL text to format.
        indent: String prepended to each line of the formatted SQL.

    Returns:
        The formatted SQL (re-indented, keywords upper-cased) as one string,
        each line prefixed with ``indent`` and joined with newlines.
    """
    pretty = sqlparse.format(query, reindent=True, keyword_case='upper')
    prefixed_lines = [indent + row for row in pretty.splitlines()]
    return "\n".join(prefixed_lines)

In [12]:
### VALID with logging query to object ###

# Configure logging
class CustomFormatter(logging.Formatter):
    """Formatter that outputs only the log message text, with no prefix."""

    def format(self, record):
        """Return the record's message with its arguments already applied."""
        return record.getMessage()

# Root logger at INFO; `logger` is the root logger itself.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

# Replace the formatter on every root handler with the message-only one.
for handler in logger.handlers:
    handler.setFormatter(CustomFormatter())

# Quiet httpx: only WARNING and above from that logger get through.
logging.getLogger("httpx").setLevel(logging.WARNING)

# Prompt for the final step: the model is shown the user's question, the SQL
# that was generated for it, and the SQL result, and is asked to answer the
# question. Placeholders: {question}, {query}, {result}.
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# Every SQL string that has passed through log_query, oldest first.
# qna() reads the last entry to show which query was executed.
logged_queries = []


def log_query(query):
    """Record a generated SQL query and pass it through unchanged.

    Intended as a pass-through step in the chain, placed just before the
    query is executed.

    The query is logged at DEBUG level, so it stays out of the notebook
    output while logging runs at INFO but can be surfaced by lowering the
    logger level. This deliberately does not call ``logging.disable()``:
    that switch is process-wide, and resetting it to ``NOTSET`` afterwards
    would also re-enable logging that had been disabled elsewhere.

    Args:
        query: The SQL text produced by the query-writing step.

    Returns:
        The same ``query`` object, so the next step receives it untouched.
    """
    # Lazy %-style arguments: the message is only built if DEBUG is enabled.
    logger.debug("Executing SQL Query: %s", query)
    logged_queries.append(query)  # Save the query to the list
    return query

# Define the chain
# Input: a dict with a "question" key (see qna()).
#   1. write_query generates SQL for the question; stored under "query".
#   2. That SQL is passed through log_query (which records it) and then to
#      execute_query; the output is stored under "result".
#   3. question/query/result fill answer_prompt, the LLM responds, and
#      StrOutputParser turns the response into a string.
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | (log_query | execute_query)
    )
    | answer_prompt
    | llm
    | StrOutputParser()
)



def qna(question: str):
    """Answer ``question`` from the database and print the query and the answer.

    Runs the full chain, then prints the most recently logged SQL query
    (formatted by ``pretty_print_sql``) followed by the chain's answer.
    Nothing is returned; all output goes to stdout.

    Args:
        question: The natural-language question to answer.
    """
    # The chain must run first: it appends the executed SQL to logged_queries.
    answer = chain.invoke({"question": question})

    executed_sql = pretty_print_sql(logged_queries[-1])
    print("Executed Query:")
    print(executed_sql, "\n")

    print("Summarized Answer: \n", answer)

In [13]:
# Smoke test: a simple counting question.
qna("How many dayparts are there?")

Executed Query:
    SELECT COUNT(DISTINCT daypart) AS total_dayparts
    FROM dim_daypart; 

Summarized Answer: 
 There are 5 distinct dayparts.


In [14]:
# Ordering question: expects a single row back.
qna("what is the first daypart in alphabetical order?")

Executed Query:
    SELECT daypart
    FROM dim_daypart
    ORDER BY daypart
    LIMIT 1; 

Summarized Answer: 
 The first daypart in alphabetical order is Breakfast.


In [15]:
# NOTE(review): identical to the question asked in the previous cell —
# presumably a repeat run to check the answer is stable; remove if redundant.
qna("what is the first daypart in alphabetical order?")

Executed Query:
    SELECT daypart
    FROM dim_daypart
    ORDER BY daypart
    LIMIT 1; 

Summarized Answer: 
 The first daypart in alphabetical order is Breakfast.


In [16]:
# Filtered, grouped question: per-store daily sales for one DMA and date range,
# with the answer requested as a table.
qna("what are the daily sales for each store in the Philadelphia DMA? Filter dates between 2024-08-01 and 2024-08-31. Return the results as a table")

Executed Query:
    SELECT order_summary.date,
           locations.store_id,
           SUM(order_summary.net_sales) AS daily_sales
    FROM order_summary
    JOIN locations ON order_summary.store_id = locations.store_id
    WHERE locations.dma_name = 'Philadelphia'
      AND order_summary.date BETWEEN '2024-08-01' AND '2024-08-31'
    GROUP BY order_summary.date,
             locations.store_id
    ORDER BY order_summary.date; 

Summarized Answer: 
 Here are the daily sales for each store in the Philadelphia DMA from August 1 to August 31, 2024:

| Date       | Store ID | Daily Sales |
|------------|----------|-------------|
| 2024-08-01 | 3        | 538.56      |
| 2024-08-01 | 17       | 282.99      |
| 2024-08-01 | 23       | 536.51      |
| 2024-08-01 | 33       | 531.15      |
| 2024-08-01 | 37       | 468.09      |
| 2024-08-01 | 45       | 649.76      |
| 2024-08-01 | 50       | 276.05      |
| 2024-08-01 | 56       | 217.62      |
| 2024-08-01 | 58       | 512.11      |
| 202