### Import the feedbacks_db.db file

In [3]:
# test access to db file: import db tables into data frames and select by the column names

import pandas as pd
import sqlite3
from sqlalchemy import create_engine, inspect

engine = create_engine('sqlite:///feedbacks_db.db')
inspector = inspect(engine)

# Load each table into its own DataFrame, keeping only the columns listed in the SELECT.
df_company = pd.read_sql_query(
    sql='SELECT company_name,annual_revenue_usd FROM company',
    con=engine,
)
df_feedback = pd.read_sql_query(
    sql='SELECT feedback_id,feedback_date,product_id,product_company_name,feedback_text,"feedback_rating" FROM feedback',
    con=engine,
)
df_products = pd.read_sql_query(
    sql='SELECT product_id,product_name,product_brand,product_manufacturer,product_company_name,product_price,product_average_rating FROM products',
    con=engine,
)

### Instantiate chat model (OpenAI)

In [22]:
import langchain, langgraph, langchain_openai

import os
from dotenv import load_dotenv

load_dotenv()  # load variables from a local .env file into os.environ, if present

# os.getenv reads os.environ, so writing the value back there would be a no-op;
# we only need to fail early, with a clear message, when the key is missing.
openai_api_key = os.getenv('OPENAI_API_KEY')
if not openai_api_key:
    raise RuntimeError("OPENAI_API_KEY is not set: add it to your environment or to a .env file.")

from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model='gpt-4o',temperature=0)

### Define constants and the test state used for debugging

In [5]:
# the natural-language question the agent is asked to answer
question = "What can you tell me about the dataset?"

In [6]:
# empty test_state: same keys as the State schema, so individual nodes can be run by hand
test_state = {'question': '',
              'sql_query': [],
              'sql_query_explanation': [],
              'sql_query_result': [],
              'llm_answer': ''  # State declares llm_answer as a str, not a list
              }

def add_question_test_state(question: str):
    """Store ``question`` in the module-level ``test_state`` dict (mutates it in place)."""
    test_state['question'] = question

add_question_test_state(question)

### Create a function that generates the SQL queries to be executed

In [7]:
# Schema description inserted into the SQL-generation prompt. The LLM copies these
# table/column names into the SQL it writes, so they must match the database exactly
# (the columns selected in the data-loading cell use underscores, e.g. company_name).
objects_documentation = """
  Table company: List of public companies. Granularity is company. Key is company_name. Columns (prefixed with table name):
  company.company_name: the name of the public company.
  company.annual_revenue_usd: revenue in last 12 months ($).

  Table feedback: Feedbacks given by clients to products. Granularity is feedback. Key is feedback_id. Columns (prefixed with table name):
  feedback.feedback_id: id of the feedback.
  feedback.feedback_date: date of feedback.
  feedback.product_id: id of the product the feedback was given for.
  feedback.product_company_name: company owning the product.
  feedback.feedback_text: the text of the feedback.
  feedback.feedback_rating: rating of the feedback from 1 to 5, 5 being the highest score.

  Table products: Shows product metadata. Granularity is product. Key is product_id. Columns (prefixed with table name):
  products.product_id: id of the product.
  products.product_name: name of the product.
  products.product_brand: the brand under which the product was presented.
  products.product_manufacturer: product manufacturer.
  products.product_company_name: company owning the product.
  products.product_price: price of the product at crawling time.
  products.product_average_rating: average ratings across all feedbacks for the product, at crawling time.

  Table company -> column company_name relates to table feedback -> column product_company_name
  Table products -> column product_company_name relates to table feedback -> column product_company_name
  Table feedback -> column product_id relates to table products -> column product_id
  """

In [9]:
# create a function that generates the sql query to be executed


# https://python.langchain.com/docs/concepts/prompt_templates/
# API reference: https://python.langchain.com/api_reference/core/prompts/langchain_core.prompts.prompt.PromptTemplate.html

from typing_extensions import TypedDict, Annotated
from langchain_core.prompts import PromptTemplate

# define the state of the graph, which includes user's question, AI's answer, query that has been created and its result;
class State(TypedDict):
    """Shared state handed from one LangGraph node of the SQL agent to the next."""

    question: str                      # the user's natural-language question
    sql_query: list[str]               # SQL statements generated for the question
    sql_query_explanation: list[str]   # one explanation entry per executed query
    sql_query_result: list[str]        # one result entry per executed query
    llm_answer: str                    # final answer text shown to the user

# Structured-output schema passed to llm.with_structured_output in
# create_sql_query_or_queries. NOTE(review): the docstring and the Annotated text are
# presumably forwarded to the model as the schema description, so edit them deliberately.
class OutputAsAQuery(TypedDict):
  """ generated sql query or sql queries if there are multiple """
  # list of SQL statements; create_sql_query_or_queries stores it under state['sql_query']
  query: Annotated[list[str],"clean sql query"]

def create_sql_query_or_queries(state:State):
  """Ask the LLM for the SQL statement(s) that answer ``state['question']``.

  The prompt combines the module-level ``objects_documentation`` schema text with the
  question, and the model's reply is parsed with the ``OutputAsAQuery`` schema.

  Returns:
    A partial state update: ``{'sql_query': [<sql statement>, ...]}``.
  """

  template_text = """You are a sql expert and an expert data modeler.

  You have access to the following tables and columns:
  {objects_documentation}

  Using only the objects you have access to, create a sql script to answer the following question:
  {question}

  Answer just with the resulting sql code.

  IMPORTANT:
    - Return only raw SQL strings in the list.
    - DO NOT include comments (like "-- Query 1"), labels, or explanations.
    - If only one SQL query is needed, just return a list with that one query.
    - Do not generate more than 5 queries.

  Example output:
    [
      "SELECT COUNT(*) FROM feedback;",
      "SELECT AVG(product_price) FROM products;"
    ]
  """

  def to_state_update(structured_output):
    # keep only the generated statements, under the State key the graph expects
    return {'sql_query': structured_output['query']}

  structured_llm = llm.with_structured_output(OutputAsAQuery)
  pipeline = PromptTemplate.from_template(template_text) | structured_llm | to_state_update

  prompt_inputs = {'objects_documentation': objects_documentation,
                   'question': state['question']}
  return pipeline.invoke(prompt_inputs)

In [10]:
# Debug: run the query-generation step on test_state and merge the returned
# {'sql_query': [...]} update into the dict.
test_state.update(create_sql_query_or_queries(test_state))
#test_state

### Create a function that executes the SQL queries

In [11]:
# Token budget for the result of a single SQL query.
# Reasoning used here: the author assumes gpt-4o has a completion (output) limit of about
# 4k tokens and keeps half of it (~2k) as the working context. The query result is only
# part of that context, so dividing by 5 gives roughly 400 tokens.
# NOTE(review): the value actually set below is 200, which is more conservative than the
# ~400 derived above. Confirm which one is intended.
import tiktoken

maximum_nr_tokens_sql_query = 200

# create a function that counts the tokens from a string
def count_tokens(string:str):
    """Return how many tokens ``string`` encodes to with the gpt-4o tokenizer."""
    return len(tiktoken.encoding_for_model("gpt-4o").encode(string))

# create a function that compares the tokens from the sql query result with the maximum token limit, and returns true if the context limit has been exceeded, false otherwise.
def check_if_exceed_maximum_context_limit(sql_query_result):
    """Tell whether a query result is too large to pass on to the LLM.

    Returns True when ``sql_query_result`` encodes to more than
    ``maximum_nr_tokens_sql_query`` tokens, False otherwise.
    """
    return count_tokens(sql_query_result) > maximum_nr_tokens_sql_query


In [12]:
# create a function that creates an explanation of a sql query

def create_sql_query_explanation(sql_query:str):
    """Ask the LLM for a short, plain-language explanation of ``sql_query``.

    Returns the text content of the model's reply.
    """
    explanation_prompt = PromptTemplate.from_template("""
 As a data expert, you are provided with this sql query:
 {sql_query}

 Create a brief explanation of this query in simple terms by taking into account these factors, if exist:
 - Pay attention to the nuances of the query: the filters, aggregations, groupings, etc.
 - Take into account underlying assumptions.
 - Query limitations.
 Keep it short.
 """)

    reply = (explanation_prompt | llm).invoke({'sql_query': sql_query})
    return reply.content

In [13]:
# Structured-output schema passed to llm.with_structured_output in refine_sql_query.
# NOTE(review): the docstring and the Annotated text are presumably forwarded to the model
# as the schema description, so edit them deliberately.
class OutputAsASingleQuery(TypedDict):
  """ generated sql query """
  # a single SQL statement; the `...` entry in Annotated conventionally marks the field as required
  query: Annotated[str,...,"the generated sql query"]

def refine_sql_query(question: str, sql_query: str, maximum_nr_tokens_sql_query: int):
    """Ask the LLM to rewrite ``sql_query`` so that its result fits a token budget.

    Args:
        question: the user's question, so the rewritten query still answers it.
        sql_query: the query whose result exceeded the budget.
        maximum_nr_tokens_sql_query: the token budget the result must stay within.

    Returns:
        The structured output parsed with ``OutputAsASingleQuery``; the rewritten SQL
        is under the ``'query'`` key.
    """
    refinement_prompt_string = """
 You are a sql expert and an expert data modeler.

 You are trying to answer the following question from the user:
 {question}

 The following sql query produces an output that exceeds {maximum_nr_tokens_sql_query} tokens:
 {sql_query}

 Please optimize this query so that its output stays within the token limit while still providing as much insight as possible to answer the question.
 Prefer using WHERE or LIMIT clauses to reduce the size of the result.
 """
    refinement_prompt = PromptTemplate.from_template(refinement_prompt_string)
    chain = refinement_prompt | llm.with_structured_output(OutputAsASingleQuery)

    return chain.invoke({'question': question,
                         'sql_query': sql_query,
                         'maximum_nr_tokens_sql_query': maximum_nr_tokens_sql_query})

In [14]:
# the function checks if the query output exceeds context window limit and if yes, send the query for refinement

from langchain_community.tools import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from typing import Iterator

db = SQLDatabase(engine)

# How many times a query whose result is too large may be rewritten by the LLM.
MAX_SQL_QUERY_REFINEMENTS = 3

def execute_sql_query(state:State):
  """Run every query in ``state['sql_query']`` and record its result and explanation.

  If a query's result exceeds ``maximum_nr_tokens_sql_query`` tokens, the query is
  rewritten with ``refine_sql_query`` and executed again. There can be up to
  ``MAX_SQL_QUERY_REFINEMENTS`` rewrites, and every rewrite is executed.

  When a rewritten query succeeds, it replaces the original entry in
  ``state['sql_query']``, so later steps describe the query that actually ran.

  Results and explanations are appended to the lists already in ``state``, so a direct
  call updates a dict such as ``test_state`` in place. The same lists are also returned
  so the graph receives them as this node's update.

  Returns:
    ``{'sql_query': [...], 'sql_query_result': [...], 'sql_query_explanation': [...]}``
  """
  query_tool = QuerySQLDataBaseTool(db=db)  # built once instead of per attempt
  total = len(state['sql_query'])

  for query_index, sql_query in enumerate(state['sql_query']):
    print(f"üöÄ Executing query {query_index+1}/{total}...")

    # attempt 0 runs the original query; attempts 1..N run the refined versions
    for attempt in range(MAX_SQL_QUERY_REFINEMENTS + 1):
      sql_query_result = query_tool.invoke(sql_query)

      # the result fits the budget: keep it, together with the query that produced it
      if not check_if_exceed_maximum_context_limit(sql_query_result):
        state['sql_query'][query_index] = sql_query
        state['sql_query_result'].append(sql_query_result)
        state['sql_query_explanation'].append(create_sql_query_explanation(sql_query))
        break

      # only ask for a rewrite when it will actually get executed
      if attempt < MAX_SQL_QUERY_REFINEMENTS:
        print(f"üîß Refining query {query_index+1}/{total} as its output is too large...")
        sql_query = refine_sql_query(state['question'], sql_query, maximum_nr_tokens_sql_query)['query']

    # no break: every attempt, including all refinements, was still too large
    else:
      print(f"‚ö†Ô∏è Query result too large after {MAX_SQL_QUERY_REFINEMENTS} refinements.")
      state['sql_query_result'].append(f'Query result too large after {MAX_SQL_QUERY_REFINEMENTS} refinements.')
      state['sql_query_explanation'].append("Refinement failed.")

  return {'sql_query': state['sql_query'],
          'sql_query_result': state['sql_query_result'],
          'sql_query_explanation': state['sql_query_explanation']}

In [None]:
# Debug: execute the queries in test_state. execute_sql_query appends results and
# explanations to the lists inside test_state, so re-running this cell appends them
# again; reset test_state first if you need a clean run.
execute_sql_query(test_state)
#test_state

### Extract metadata from a SQL query

In [16]:
import sqlglot
from sqlglot import parse_one, exp

def extract_metadata_from_sql_query(sql_query):
    """Parse ``sql_query`` with sqlglot and collect the parts worth explaining to a user.

    Args:
        sql_query: a single SQL statement.

    Returns:
        A dict with four de-duplicated lists, in first-seen order:
        - ``tables``: names of the tables read. Aliases are dropped, and names that are
          only defined by a WITH clause (CTEs) are excluded.
        - ``filters``: WHERE and HAVING conditions, as SQL text.
        - ``aggregations``: aggregate function calls, as SQL text.
        - ``groupings``: GROUP BY expressions, as SQL text.

    Raises:
        Whatever ``parse_one`` raises when the query cannot be tokenized or parsed.
    """
    ast = parse_one(sql_query)

    # CTE names appear as Table nodes where they are referenced, but they are not real tables
    cte_names = {cte.alias for cte in ast.find_all(exp.CTE)}
    tables = [table.name for table in ast.find_all(exp.Table)
              if table.name and table.name not in cte_names]

    # HAVING filters aggregated rows, so it belongs next to WHERE in a user-facing summary
    filters = [clause.this.sql() for clause in ast.find_all(exp.Where, exp.Having)]

    aggregations = [func.sql() for func in ast.find_all(exp.AggFunc)]

    # flatten() yields the individual expressions listed in each GROUP BY
    groupings = [grouped.sql()
                 for group in ast.find_all(exp.Group)
                 for grouped in group.flatten()]

    # dict.fromkeys drops duplicates while keeping first-seen order
    return {
        "tables": list(dict.fromkeys(tables)),
        "filters": list(dict.fromkeys(filters)),
        "aggregations": list(dict.fromkeys(aggregations)),
        "groupings": list(dict.fromkeys(groupings)),
    }

In [17]:
# test it:
#sql_query = test_state['sql_query'][0]
#extract_metadata_from_sql_query(sql_query)

### Create a natural-language explanation of the queries (`create_explanation`)

In [17]:
def create_explanation(sql_queries: list[str]):
    """Summarize, in user-facing text, which tables the queries read and what they compute.

    The tables, filters, aggregations and groupings of every query (as returned by
    ``extract_metadata_from_sql_query``) are merged and de-duplicated in first-seen order.
    Each non-empty category is rendered as one bulleted section under a fixed heading.

    A query that sqlglot cannot parse is skipped rather than failing the whole answer.

    Args:
        sql_queries: the SQL statements that were executed.

    Returns:
        The explanation text (only the heading when nothing could be extracted).
    """
    # (metadata key, display label) in the order the sections are shown
    sections = [
        ("tables", "üßä Tables"),
        ("filters", "üîç Filters"),
        ("aggregations", "üßÆ Aggregations"),
        ("groupings", "üì¶ Groupings"),
    ]
    collected = {key: [] for key, _ in sections}

    for sql_query in sql_queries:
        try:
            sql_query_metadata = extract_metadata_from_sql_query(sql_query)
        except sqlglot.errors.SqlglotError:
            # the query was already sent to the database; a gap in this summary is better
            # than losing the whole answer because the parser rejects its syntax
            continue
        for key in collected:
            collected[key].extend(sql_query_metadata[key])

    sql_query_explanation = "I analyzed data based on the following filters and transformations:"
    for key, label in sections:
        values = list(dict.fromkeys(collected[key]))  # de-duplicate across queries
        if values:
            sql_query_explanation += "\n\n" + f"{label}: ‚Ä¢ " + " ‚Ä¢ ".join(values)

    return sql_query_explanation

In [19]:
# test it
#print(create_explanation(test_state['sql_query']))

### Create a function that generates the agent's answer based on the SQL query results

In [18]:
## create a function that generates the agent answer based on sql query result

def generate_answer(state:State):
  """Produce the final user-facing answer from the executed queries.

  The LLM receives the question and, for every executed query, its explanation and its
  result. The answer it writes is followed by the ``create_explanation`` summary of the
  tables, filters, aggregations and groupings the queries used.

  Returns:
    ``{'llm_answer': str}``
  """

  prompt_template = PromptTemplate.from_template(
     """ You are a decision support consultant helping users become more data-driven.
     Your task is to answer the user question based on the following information:

     - The sql query result which is the result of a query created for the purpose of answering the question.
     - The query explanation is a short explanation of the query making you aware of its limitations and underlying assumptions.

    User question:
    {question}

    SQL query explanation:
    {sql_query_explanation}

    SQL query result:
    {sql_query_result}

    Take into account the insights from this explanation in your answer.
    Answer in simple terms, conversational, non-technical language. Be concise.
    """)

  def as_numbered_blocks(items):
    # Interpolating the raw list would give the model a Python repr with escaped newlines
    # and quotes. Numbered "Query N:" blocks stay readable and let the model match each
    # explanation with its result.
    return "\n\n".join(f"Query {number}: {item}" for number, item in enumerate(items, start=1))

  chain = (prompt_template
        | llm
        | (lambda output: {'llm_answer': f"{output.content}\n\n{create_explanation(state['sql_query'])}"})
  )

  return chain.invoke({'question': state['question'],
                       'sql_query_explanation': as_numbered_blocks(state['sql_query_explanation']),
                       'sql_query_result': as_numbered_blocks(state['sql_query_result'])})

In [19]:
# Debug: generate the final answer from test_state and merge the returned
# {'llm_answer': ...} update into the dict.
test_state.update(generate_answer(test_state))
#test_state

### Assemble the graph

In [20]:
# assemble graph

from langgraph.graph import StateGraph, START, END

# Linear pipeline: generate SQL -> execute it -> answer in natural language.
builder = StateGraph(State)

builder.add_node("create_sql_query_or_queries", create_sql_query_or_queries)
builder.add_node("execute_sql_query", execute_sql_query)
builder.add_node("generate_answer", generate_answer)

pipeline_edges = [
    (START, "create_sql_query_or_queries"),
    ("create_sql_query_or_queries", "execute_sql_query"),
    ("execute_sql_query", "generate_answer"),
    ("generate_answer", END),
]
for source_node, target_node in pipeline_edges:
    builder.add_edge(source_node, target_node)

# keep the builder and the compiled, runnable graph under separate names
graph = builder.compile()

### Test the agent

In [21]:
# Graph input: one entry per State key, with empty lists for the accumulating fields.
# The schema text is not passed in: create_sql_query_or_queries reads the module-level
# objects_documentation itself, and it is not a State key.
initial_dict = {'question': question,
                'sql_query': [],
                'sql_query_result': [],
                'sql_query_explanation': [],
                'llm_answer': ''
                }

# with stream_mode="updates", each chunk maps a node name to the update that node returned
for step in graph.stream(initial_dict, stream_mode="updates"):
  for step_name, output in step.items():
    if step_name == 'create_sql_query_or_queries':
      print(f"‚úÖ SQL queries created:{len(output['sql_query'])}")
    elif step_name == 'execute_sql_query':
      print("‚öôÔ∏è Analysing results...")
    elif step_name == 'generate_answer':
      print("\nüì£ Final Answer:\n")
      print(output['llm_answer'])

‚úÖ SQL queries created:5
üöÄ Executing query 1/5...
üöÄ Executing query 2/5...
üöÄ Executing query 3/5...
üöÄ Executing query 4/5...
üöÄ Executing query 5/5...
‚öôÔ∏è Analysing results...

üì£ Final Answer:

The dataset includes information about companies, feedback, and products. Here's a quick overview:

1. **Companies**: There are 12 unique company names in the dataset. This means there are 12 different companies listed.

2. **Feedback**: There are 413,898 unique feedback entries. This is a large number, indicating a lot of feedback data is available.

3. **Products**: There are 8,145 unique products. This suggests a wide variety of products are included in the dataset.

4. **Feedback Ratings**: The average feedback rating is about 3.84. This gives a general idea of how users feel about the products or services.

5. **Product Prices**: The average price of the products is approximately 158.87. This provides a sense of the typical cost of products in the dataset.

Overall, the