### Import feedbacks.db file

In [92]:
# test access to db file: import db tables into data frames and select by the column names

import pandas as pd
import sqlite3
from sqlalchemy import create_engine, inspect
import uuid

engine = create_engine('sqlite:///feedbacks_db.db')
inspector = inspect(engine)

df_company = pd.read_sql_query('SELECT company_name,annual_revenue_usd FROM company', engine)
df_feedback = pd.read_sql_query('SELECT feedback_id,feedback_date,product_id,product_company_name,feedback_text,"feedback_rating" FROM feedback', engine)
df_products = pd.read_sql_query('SELECT product_id,product_name,product_brand,product_manufacturer,product_company_name,product_price,product_average_rating FROM products', engine)

### Instantiate chat model (OpenAI)

In [93]:
import langchain, langgraph, langchain_openai, langsmith

import os
from dotenv import load_dotenv
from langchain_core.runnables import RunnableConfig
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain_core.tools import tool

load_dotenv(override=True)
openai_api_key = os.getenv('OPENAI_API_KEY')
LANGSMITH_API_KEY = os.getenv('LANGSMITH_API_KEY')
os.environ['OPENAI_API_KEY'] = openai_api_key
os.environ['LANGSMITH_API_KEY'] = LANGSMITH_API_KEY
os.environ['LANGSMITH_TRACING'] = "true"
os.environ['LANGSMITH_ENDPOINT'] = "https://api.smith.langchain.com"
langsmith_project_name = "db_agent_v1"
os.environ['LANGSMITH_PROJECT'] = langsmith_project_name

# Set up LangSmith tracer manually
tracer = LangChainTracer(project_name=langsmith_project_name)

from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model='gpt-4o',temperature=0) # Smart & expensive
llm_fast = ChatOpenAI(model='gpt-4.1') # Faster

### Create config

In [94]:
import datetime

def create_config(run_name: str, is_new_thread_id: bool = False, thread_id: str = None):
    """
    Create a config dictionary for LCEL runnables.
    Includes LangSmith run tracing and optional thread_id management.

    Args:
        run_name (str): Descriptive run name shown in LangSmith.
        is_new_thread_id (bool): Whether to generate a new thread_id.
        thread_id (str): Optionally provide an existing thread_id to reuse.

    Returns:
        dict: Config dictionary with callbacks, run_name, and thread_id.

    Use it like so (example): 
        config, thread_id = create_config('create_sql_query_or_queries', True) (start a new thread)
        config, _ = create_config('generate_answer', False, thread_id) (re-use same thread)
    """

    time_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
    full_run_name = f"{run_name} {time_now}"
    if is_new_thread_id or not thread_id:
        thread_id = str(uuid.uuid4())

    config={'callbacks': [tracer],
            'run_name': full_run_name,
            'configurable' : { 'thread_id':thread_id }
            }

    return config,thread_id

### Initialize variables

In [170]:
vector_store = None

objects_documentation = '''Table company: List of public companies. Granularity is company-name. Column (prefixed with table name):
     company.company-name: the name of the public company.
     company.annual_revenue_usd: revenue in last 12 months ($).

     Table feedback: Feedbacks given by clients to products. Granularity is feedback. Key is feedback_id. Columns (prefixed with table name):
     feedback.feedback_id: id of the feedback.
     feedback.feedback_date: date of feedback.
     feedback.product_id: id of the product the feedback was given for.
     feedback.product_company_name: company owning the product.
     feedback.feedback_text: the text of the feedback.
     feedback.feedback_rating: rating of the feedback from 1 to 5, 5 being the highest score.

     Table products: Shows product metadata. Granularity is product. Key is product_id. Columns (prefixed with table name):
     products.product_id: id of the product.
     products.product_name: name of the product.
     products.product_brand: the brand under which the product was presented.
     products.product_manufacturer: product manufacturer.
     products.product_company_name: company owning the product.
     products.product_price: price of the product at crawling time.
     products.product_average-rating: average ratings across all feedbacks for the product, at crawling time.

     Table company -> column company_name relates to table feedback -> column product_company_name
     Table products -> column product_company_name relates to table feedback -> column product_company-name
     Table feedback -> column product_id relates to table products -> column product_id'''

### Define state

In [171]:
# define the state of the graph, which includes user's question, AI's answer, query that has been created and its result;
from typing_extensions import TypedDict, Annotated
from langgraph.graph.message import add_messages
from typing import Sequence
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, RemoveMessage
from langchain_core.agents import AgentAction
import operator

class State(TypedDict):
 objects_documentation: str
 messages_log: Sequence[BaseMessage]
 intermediate_steps: list[AgentAction]
 current_question: str
 current_sql_queries: list[dict]
 llm_answer: BaseMessage

### Create sql query or queries

In [172]:
# create a function that generates the sql query to be executed

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

class OutputAsAQuery(TypedDict):
  """ generated sql query or sql queries if there are multiple """
  query: Annotated[list[str],"clean sql query"]

@tool
def create_sql_query_or_queries(state:State):
  """ creates sql query/queries to anwser a question based on documentation of tables and columns available """

  system_prompt = """You are a sql expert and an expert data modeler.  

  Your task is to create a sql script to answer the user's question. In the sql script, use only these tables and columns you have access to:
  {objects_documentation}

  User question:
  {question}

  Answer just with the resulting sql code.

  IMPORTANT:
    - Return only raw SQL strings in the list.
    - DO NOT include comments (like "-- Query 1"), labels, or explanations.
    - If only one SQL query is needed, just return a list with that one query.
    - Do not generate more than 5 queries.

  Example output:
    [
      "SELECT COUNT(*) FROM feedback;",
      "SELECT AVG(product_price) FROM products;"
    ]
  """

  prompt = ChatPromptTemplate.from_messages(
    [('system', system_prompt)]
  )

  chain = prompt | llm.with_structured_output(OutputAsAQuery)

  result = chain.invoke({'objects_documentation':state['objects_documentation'], 'question': state['current_question']})
  for q in result['query']:
   state['current_sql_queries'].append( {'query': q,
                                     'explanation': '', ## add it later
                                     'result':'', ## add it later
                                     'insight': '', ## add it later
                                     'metadata':'' ## add it later
                                      } )
  
  print(f"✅ SQL queries created:{len(state['current_sql_queries'])}")
  return state

In [173]:
# since gpt-4o allows a maximum completion limit (output context limit) of 4k tokens, I half it to get maximum context size, so 2k. Assuming the entire context is not just the data,
# I divide this number by 5 and arrive at a limit of 400 tokens for the result of the sql query.

import tiktoken

maximum_nr_tokens_sql_query = 200

# create a function that counts the tokens from a string
def count_tokens(string:str):
 """ returns the number of tokens in a text string """
 encoding = tiktoken.encoding_for_model("gpt-4o")
 num_tokens = len(encoding.encode(string))
 return num_tokens

# create a function that compares the tokens from the sql query result with the maximum token limit, and returns true if the context limit has been exceeded, false otherwise.
def check_if_exceed_maximum_context_limit(sql_query_result):
 """ compares the tokens from the sql query result with the maximum token limit, and returns true if the context limit has been exceeded, false otherwise """
 tokens_sql_query_result = count_tokens(sql_query_result)
 if tokens_sql_query_result > maximum_nr_tokens_sql_query:
  return True
 else:
  return False

### Create query analysis

In [174]:
class QueryAnalysis(TypedDict):
    ''' complete analysis of a sql query, including its explanation, limitation and insight '''
    explanation: str
    limitation: str
    insight: str

def create_query_analysis(sql_query:str, sql_query_result:str):
   ''' creates: explanation - a concise explanation of what the sql query does.
                limitation - a concise explanation of the sql query by pointing out its limitations.
                insight - insight from the result of the sql query.
   '''
   system_prompt = """
   You are an expert data analyst.

   You are provided with the following SQL query:
   {sql_query}.

   Which yielded the following result:
   {sql_query_result}.

   Provide a structured analysis with three components:

   Step 1: Explanation: A concise description of what the query outputs, in one short phrase. 
                   Do not include introductory words like "The query" or "It outputs."

   Step 2: Limitation: Inherent limitations or assumptions of the query based strictly on its structure and logic.
                  Focus on:
                  - How LIMIT, ORDER BY, GROUP BY, or JOINs may introduce assumptions
                  - How filtering or aggregation logic may bias the output
                  - Situations where the query might **return incomplete or misleading results due to logic only**
                  - Cases where ORDER BY combined with LIMIT might exclude other rows with equal values (ties)

                  Only describe things that follow **logically from the query**, not from the dataset itself.

                  🚫 Do NOT mention:
                  - speculate on what the user is trying to analyze
                  - suggest what insights are missing
                  - mention field names being correct or incorrect
                  - mention data types, nulls, formatting, spelling, or schema correctness
                  - mention what other attributes, columns, filters, or relationships "could have" been used
                  - assume anything about the intent behind the query

                  If the query has no structural limitations or assumptions, respond with exactly "No comments for the query".

                  Respond in 1 to 3 concise sentences, or with the exact phrase above.
   
   Step 3: Insight: Key findings from the results, stating facts directly without technical terms.
               - Include the limitations discovered in step 2, as long as it's different than "No comments for the query".
               - Do not mention your subjective assessment over the results.
               - Avoid technical terms like "data","dataset","table","list","provided information","query" etc.
   """

   prompt = ChatPromptTemplate.from_messages(('system',system_prompt))
   chain = prompt | llm_fast.with_structured_output(QueryAnalysis)
   return chain.invoke({'sql_query':sql_query,
                        'sql_query_result':sql_query_result})   

### Create query metadata

In [216]:
import sqlglot
from sqlglot import parse_one, exp

def extract_metadata_from_sql_query(sql_query):
   # returns a dictionary with parsed names of tables and columns used in filters, aggregations and groupings 
   
 ast = parse_one(sql_query)

 sql_query_metadata = {
    "tables": [],
    "filters": [],
    "aggregations": [],
    "groupings": []
 }

 # extract tables
 table_generator = ast.find_all(sqlglot.expressions.Table)
 for items in table_generator:
    sql_query_metadata['tables'].append(items.sql())
 # remove dups
 sql_query_metadata['tables'] = list(dict.fromkeys(sql_query_metadata['tables']))

 # extract filters
 where_conditions = ast.find_all(sqlglot.expressions.Where)
 for item in where_conditions:
  sql_query_metadata['filters'].append(item.this.sql())
  # remove dups
 sql_query_metadata['filters'] = list(dict.fromkeys(sql_query_metadata['filters']))

 # extract aggregate functions
 funcs = ast.find_all(sqlglot.expressions.AggFunc)
 for item in funcs:
  sql_query_metadata['aggregations'].append(item.sql())
 # remove dups
 sql_query_metadata['aggregations'] = list(dict.fromkeys(sql_query_metadata['aggregations']))

 # extract groupings
 groupings = ast.find_all(sqlglot.expressions.Group)
 for item in groupings:
  groupings_flattened = item.flatten()
  for item in groupings_flattened:
    sql_query_metadata['groupings'].append(item.sql())
 # remove dups
 sql_query_metadata['groupings'] = list(dict.fromkeys(sql_query_metadata['groupings']))

 return {'tables':sql_query_metadata.get('tables'),
         'filters':sql_query_metadata.get('filters'),
         'aggregations':sql_query_metadata.get('aggregations'),
         'groupings':sql_query_metadata.get('groupings'),
          }

def format_sql_metadata_explanation(tables:list=None, filters:list=None, aggregations:list=None, groupings:list=None,header :str='') -> str:
    # creates a string explanation of the filters, tables, aggregations and groupings used by the query
    explanation = header

    if tables:
        explanation += "\n\n🧊 Tables: • " + " • ".join(tables)
    if filters:
        explanation += "\n\n🔍 Filters applied: • " + " • ".join(filters)
    if aggregations:
        explanation += "\n\n🧮 Aggregations: • " + " • ".join(aggregations)
    if groupings:
        explanation += "\n\n📦 Groupings: • " + " • ".join(groupings)

    return explanation

def create_query_metadata(sql_query: str):
 """ creates an explanation for one single query """

 metadata = extract_metadata_from_sql_query(sql_query)
 return format_sql_metadata_explanation(metadata['tables'],metadata['filters'],metadata['aggregations'],metadata['groupings'])


def create_queries_metadata(sql_queries: list[dict]):
 """ creates an explanation for multiple queries: used in the generate_answer tool """

 all_tables = []
 all_filters = []
 all_aggregations = []
 all_groupings = []

 for q in sql_queries: 

  metadata = extract_metadata_from_sql_query(q['query'])
  all_tables.extend(metadata["tables"])
  all_filters.extend(metadata["filters"])
  all_aggregations.extend(metadata["aggregations"])
  all_groupings.extend(metadata["groupings"])
  
  # include all metadata
  #output = format_sql_metadata_explanation(all_tables,all_filters,all_aggregations,all_groupings,header='🔍 Filters applied:')

  # include the default min/max feedback filters if feedback table has been used and was not filtered at all
  if 'feedback' in all_tables and not any('feedback_date' in item for item in all_filters):
     all_filters.append('feedback_date between 11/18/2002 and 09/12/2023')
     output = format_sql_metadata_explanation(filters = all_filters,header='')
  # include just the filters if there are any
  elif all_filters:    
     output = format_sql_metadata_explanation(filters = all_filters,header='')
  # if no filters were applied, don't include other metadata for the sake of keeping the message simple
  else:
     output = ''   

 return output

# use it like so:
#sql_queries = [ 
#    {'query':'SELECT COUNT(DISTINCT company.company_name) FROM company;', 'result':''} ,
#    {'query':'SELECT COUNT(DISTINCT feedback.feedback_id) FROM feedback;', 'result':''} 
#    ]
#create_queries_metadata(sql_queries)

In [205]:
sql_queries = [ 
    {'query':'SELECT COUNT(DISTINCT company.company_name) FROM company;', 'result':''} ,
    {'query':'SELECT COUNT(DISTINCT feedback.feedback_id) FROM feedback;', 'result':''} 
    ]

all_tables = []
all_filters = []
all_aggregations = []
all_groupings = []

for q in sql_queries: 
  metadata = extract_metadata_from_sql_query(q['query'])
  all_tables.extend(metadata["tables"])
  all_filters.extend(metadata["filters"])
  all_aggregations.extend(metadata["aggregations"])
  all_groupings.extend(metadata["groupings"])

In [206]:
all_filters.append('feedback_date between 11/18/2002 and 09/12/2023')

In [207]:
all_filters

['feedback_date between 11/18/2002 and 09/12/2023']

### Retrieve insights

In [176]:
def create_or_retrieve_vector_store():
 global vector_store  
 if vector_store is None:
    vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings(model="text-embedding-3-small"))
 return vector_store

# use it like so: vector_store = create_or_retrieve_vector_store()


In [177]:
import re
def parse_explanation(content:str):
    ''' from a text with a format of Query Explanation: ... Query Insight: ...  parse just the explanation part '''
    match = re.search(r"Query Explanation:(.*?)Query Insight:", content, re.DOTALL)
    explanation = match.group(1).strip()  # removes leading/trailing whitespace including \n
    return explanation

@tool
def retrieve_insights(state:State):
    ''' Searches the vector store for relevant past query insights with similarity > 0.6,
    and appends them to state['current_sql_queries'] '''
    print("💭 Gathering my thoughts...")
    query = state['current_question']
    vector_store = create_or_retrieve_vector_store()
    result = vector_store.similarity_search_with_score(query,k=3)
    for doc,score in result:
     if score >= 0.6:   
      state['current_sql_queries'].append( {'query': doc.metadata.get('query'),
                                     'explanation': parse_explanation(doc.page_content), 
                                     'result':doc.metadata.get('result'), 
                                     'insight': doc.metadata.get('insight'),
                                     'metadata':doc.metadata.get('metadata')
                                      } )    
    return state

### Execute sql query and stores result

In [178]:
# the function checks if the query output exceeds context window limit and if yes, send the query for refinement

from langchain_community.tools import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from typing import Iterator
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document

db = SQLDatabase(engine)

def execute_sql_query(state:State):
  """ executes the sql query and retrieve the result """
  
  print("⚙️ Analysing results...")
  for query_index, q in enumerate(state['current_sql_queries']):
     
    if state['current_sql_queries'][query_index]['result'] == '':    
     sql_query = q['query'] 
    
     # refine the query 3 times if necessary.
     for i in range(3):

       sql_query_result = QuerySQLDataBaseTool(db=db).invoke(sql_query)

       # if the sql query does not exceed output context window return its result
       if not check_if_exceed_maximum_context_limit(sql_query_result):
         analysis = create_query_analysis(sql_query, sql_query_result)
         sql_query_metadata = create_query_metadata(sql_query)   

         # Update state
         state['current_sql_queries'][query_index]['result'] = sql_query_result
         state['current_sql_queries'][query_index]['insight'] = analysis['insight']
         state['current_sql_queries'][query_index]['query'] = sql_query
         state['current_sql_queries'][query_index]['metadata'] = sql_query_metadata
         state['current_sql_queries'][query_index]['explanation'] = analysis['explanation']   

         # add the queries to vector store
         vector_store = create_or_retrieve_vector_store()
         doc = [Document(
              id=len(vector_store.store)+1,
              page_content=f"Query Explanation:\n{analysis['explanation'] }\n\Query Insight:{analysis['insight']}",
              metadata={"query": sql_query,
                        "result": sql_query_result,
                        "insight": analysis['insight'],
                        "metadata": sql_query_metadata
                        })]
         vector_store.add_documents(documents=doc)                                                          
         break

       # if the sql query exceeds output context window and there is more room for iterations, refine the query
       else:
        print(f"🔧 Refining query {query_index+1}/{len(state['current_sql_queries'])} as its output its too large...")
        sql_query = refine_sql_query(state['current_question'],sql_query,maximum_nr_tokens_sql_query)['query']

       # if there is no more room for sql query iterations and the result still exceeds context window, throw a message
     else:
        print(f"⚠️ Query result too large after 3 refinements.")
        state['current_sql_queries'][query_index]['result'] = 'Query result too large after 3 refinements.'
        state['current_sql_queries'][query_index]['explanation'] = "Refinement failed."
      
  return state

In [179]:
class OutputAsASingleQuery(TypedDict):
  """ generated sql query """
  query: Annotated[str,...,"the generated sql query"]

def refine_sql_query(question: str, sql_query: str, maximum_nr_tokens_sql_query: int):
 """ refines the sql query so that its output tokens do not exceed the maximum context limit """

 system_prompt = """
 You are a sql expert an an expert data modeler.

 You are tying to answer the following question from the user:
 {question}

 The following sql query produces an output that exceeds {maximum_nr_tokens_sql_query} tokens:
 {sql_query}

 Please optimize this query so that its output stays within the token limit while still providing as much insight as possible to answer the question.
 Prefer using WHERE or LIMIT clauses to reduce the size of the result.
 """
 
 prompt = ChatPromptTemplate.from_messages(
   ('system',system_prompt)
 )

 chain = (prompt
         | llm.with_structured_output(OutputAsASingleQuery)
 )

 sql_query = chain.invoke({'question': question,
               'sql_query':sql_query,
               'maximum_nr_tokens_sql_query':maximum_nr_tokens_sql_query}
               )
 return sql_query

### Generate answer

In [180]:
def format_sql_query_results_for_prompt (sql_queries : list[dict]) -> str:
    
    formatted_queries = []
    for query_index,q in enumerate(sql_queries):
        block = f"Insight {query_index+1}:\n{q['insight']}\n\nRaw Result of insight {query_index+1}:\n{q['result']}"
        formatted_queries.append(block)
    return "\n\n".join(formatted_queries)

# print(format_sql_query_results_for_prompt(test_state['sql_queries']))

In [181]:
## create a function that generates the agent answer based on sql query result

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

@tool
def generate_answer(state:State):
  """ generates the AI answer taking into consideration the explanation and the result of the sql query that was executed """

  system_prompt = """ You are a decision support consultant helping users become more data-driven.
     Continue the conversation by answering the following question: {question}.

     Use both the raw SQL results and the extracted insights below to form your answer: {insights}. Include all details from these insights.

     Respond in clear, concise, non-technical language.
     """

  prompt = ChatPromptTemplate.from_messages([
      MessagesPlaceholder("messages_log") ,
      ('system',system_prompt)            
  ] )

  llm_answer_chain = prompt | llm 
  final_answer_chain = { 'llm_answer': llm_answer_chain, 'input_state': RunnablePassthrough() } | RunnableLambda (lambda x: { 'ai_message': AIMessage( content = f"{x['llm_answer'].content}\n\n{create_queries_metadata(x['input_state']['current_sql_queries'])}", 
                                                                                                                                                        response_metadata = x['llm_answer'].response_metadata  ) } ) 

  result = final_answer_chain.invoke({ 'messages_log':state['messages_log'],
               'question':state['current_question'],
               'insights': format_sql_query_results_for_prompt(state['current_sql_queries']),
              'current_sql_queries': state['current_sql_queries'] })
  
  ai_msg = result['ai_message']

  explanation_token_count = llm.get_num_tokens(create_queries_metadata(state['current_sql_queries']))
  ai_msg.response_metadata['token_usage']['total_tokens'] += explanation_token_count

  state['llm_answer'] = ai_msg
  state['messages_log'].append(HumanMessage(state['current_question']))
  state['messages_log'].append(ai_msg)

  return state

### Manage memory and chat history

In [182]:
def manage_memory_chat_history(state:State):
    """ Manages the chat history so that it does not become too large in terms of output tokens.
    Specifically, it checks if the chat history is larger than 1000 tokens. If yes, keep just the last 4 pairs of human prompts and AI responses, and summarize the older messages.
    Additionally, check if the logs of sql queries is larger than 20 entries. If yes, delete the older records. """           

    tokens_chat_history = state['messages_log'][-1].response_metadata.get('token_usage', {}).get('total_tokens', 0) if state['messages_log'] else 0    

    if tokens_chat_history >= 1000 and len(state['messages_log']) > 4:
        message_history_to_summarize = [msg.content for msg in state['messages_log'][:-4]]
        prompt = ChatPromptTemplate.from_messages( [('user', 'Distill the below chat messages into a single summary paragraph.The summary paragraph should have maximum 400 tokens.Include as many specific details as you can.Chat messages:{message_history_to_summarize}') ])
        runnable = prompt | llm_fast # use the cheap model
        chat_history_summary = runnable.invoke({'message_history_to_summarize':message_history_to_summarize})
        last_4_messages = state['messages_log'][-4:]
        state['messages_log'] = [chat_history_summary] +[*last_4_messages]
    else:
        state['messages_log'] = state['messages_log']

    # Truncate SQL logs to the most recent 20
    #if len(state['log_sql_queries']) > 20:
    #    state['log_sql_queries']= state['log_sql_queries'][-20:]    
        
    return state

### Orchestrator

In [183]:
tools =[retrieve_insights,create_sql_query_or_queries,generate_answer]

def get_next_tool(state:State):
  ''' creates a list of actions taken by the agent from the intermediate steps '''  
  nr_executions_retrieve_insights = 0
  nr_executions_create_sql_query_or_queries = 0
  for index,action in enumerate(state['intermediate_steps']):
      
    if action.tool == 'retrieve_insights' and action.log == 'tool ran successfully':
      nr_executions_retrieve_insights +=1

    if action.tool == 'create_sql_query_or_queries' and action.log == 'tool ran successfully':
      nr_executions_create_sql_query_or_queries +=1

  if nr_executions_retrieve_insights == nr_executions_create_sql_query_or_queries == 1:
    next_tool = 'generate_answer'
  elif nr_executions_retrieve_insights == 0 and nr_executions_create_sql_query_or_queries == 0:
    next_tool = 'retrieve_insights'
  elif nr_executions_retrieve_insights > 0 and nr_executions_create_sql_query_or_queries == 0:
    next_tool = 'create_sql_query_or_queries'  

  return next_tool

def extract_msg_content_from_history(messages_log:list):
 ''' from a list of base messages, extract just the content '''
 content = []
 for msg in messages_log:
     content.append(msg.content)
 return "\n".join(content)

# use it like so: extract_msg_content_from_history(test_state['messages_log'])

In [184]:
def orchestrator(state:State):
  ''' Function that decides which tools to use '''

  # if all tools for retrieving insights have been used, go directly to answer.
  next_tool = get_next_tool(state)
  if next_tool == 'generate_answer':
     action = AgentAction(tool='generate_answer', tool_input='', log='')
     state['intermediate_steps'].append(action)
     return state

  else:
    system_prompt = f"""You are a decision support consultant helping users make data-driven decisions.

    Your task is to continue the conversation by answering the following question: {{question}}.

    Conversation history:
    {{messages_log}}.
  
    You have access to a database available to you with the following schema: {{objects_documentation}}.

    Current insights: "{{insights}}".
  
    Decision rules:
 
    Step 1. If current question was already answered in the conversation history, or current insights are sufficient to answer the question, return generate_answer tool.      

    Step 2. If current insights are not sufficient to answer the question, return {next_tool} tool.          
    """

    prompt = ChatPromptTemplate.from_messages(
      [('system', system_prompt)]
    )

    chain = prompt | llm_fast.bind_tools(tools) 
    result = chain.invoke({'messages_log':extract_msg_content_from_history(state['messages_log']),
                         'question': state['current_question'], 
                         'objects_documentation':state['objects_documentation'],
                         'insights': format_sql_query_results_for_prompt(state['current_sql_queries'])
                         #'next_tool':get_next_tool(state)
                         })

    tool_name = result.tool_calls[0]["name"]

    action = AgentAction( tool=tool_name, tool_input='', log = '' )
    state['intermediate_steps'].append(action)  
    return state     

### Run control flow

In [185]:
# run the nodes

def run_control_flow(state:State):
    ''' Based on the last tool name stored in intermediate_steps (generated by the orchestrator), it executes the next node that will trigger the control flow '''
    
    # get the next tool to execute by looking in the last tool_name in the intermediate steps
    tool_name = state['intermediate_steps'][-1].tool
    
    # control flow 1: retrieve insights from previous run queries
    if tool_name == 'retrieve_insights':
      state = retrieve_insights.invoke({'state':state})   
      # log the tool call
      action = AgentAction(tool=tool_name, tool_input='',log='tool ran successfully')
      state['intermediate_steps'].append(action)

    # control flow 2: generate new insights by creating & executing new queries
    elif tool_name == 'create_sql_query_or_queries':
      state = create_sql_query_or_queries.invoke({'state':state})
      execute_sql_query(state)
      # log the tool call
      action = AgentAction(tool=tool_name, tool_input='',log='tool ran successfully')
      state['intermediate_steps'].append(action)

    # control flow 3: generate answer & manage chat history.
    elif tool_name == 'generate_answer':  
      state = generate_answer.invoke({'state':state}) 
      manage_memory_chat_history(state) 
      # log the tool call
      action = AgentAction(tool=tool_name, tool_input='',log='tool ran successfully')
      state['intermediate_steps'].append(action)

    return state

### Assemble graph

In [186]:
# assemble graph

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

# function to reset the state current queries (to add in the start of graph execution)
def reset_state(state:State):
    state['current_sql_queries'] = []
    state['intermediate_steps'] = []
    state['llm_answer'] = AIMessage(content='')
    state['objects_documentation'] = """
     Table company: List of public companies. Granularity is company-name. Column (prefixed with table name):
     company.company-name: the name of the public company.
     company.annual_revenue_usd: revenue in last 12 months ($).

     Table feedback: Feedbacks given by clients to products. Granularity is feedback. Key is feedback_id. Columns (prefixed with table name):
     feedback.feedback_id: id of the feedback.
     feedback.feedback_date: date of feedback.
     feedback.product_id: id of the product the feedback was given for.
     feedback.product_company_name: company owning the product.
     feedback.feedback_text: the text of the feedback.
     feedback.feedback_rating: rating of the feedback from 1 to 5, 5 being the highest score.

     Table products: Shows product metadata. Granularity is product. Key is product_id. Columns (prefixed with table name):
     products.product_id: id of the product.
     products.product_name: name of the product.
     products.product_brand: the brand under which the product was presented.
     products.product_manufacturer: product manufacturer.
     products.product_company_name: company owning the product.
     products.product_price: price of the product at crawling time.
     products.product_average-rating: average ratings across all feedbacks for the product, at crawling time.

     Table company -> column company_name relates to table feedback -> column product_company_name
     Table products -> column product_company_name relates to table feedback -> column product_company-name
     Table feedback -> column product_id relates to table products -> column product_id
     """
    return state

def router(state:State):
    # returns the tool name to use
    return state['intermediate_steps'][-1].tool

graph= StateGraph(State)
graph.add_node("reset_state",reset_state)
graph.add_node("orchestrator",orchestrator)

# here you add the node corresponding to the first tool of each control flow, as the subsequent tools are run by the run_control_flow node
graph.add_node("retrieve_insights",run_control_flow)
graph.add_node("create_sql_query_or_queries",run_control_flow)
graph.add_node("generate_answer",run_control_flow)

# starting the agent
graph.add_edge(START,"reset_state")
graph.add_edge("reset_state","orchestrator")
graph.add_conditional_edges(source='orchestrator',path=router)

# here you add a link from each the control flow node back to the orchestator - except for the generate_answer node.
graph.add_edge("retrieve_insights","orchestrator")
graph.add_edge("create_sql_query_or_queries","orchestrator")

# last control flow is generate_answer
graph.add_edge("generate_answer",END)

memory = MemorySaver()
graph = graph.compile(checkpointer=memory)


### test the agent

In [187]:
# Start a new conversation

question = 'How many products have a rating of 5?'
messages_log = []

initial_dict = {'objects_documentation':objects_documentation,
     'messages_log': messages_log,
     'intermediate_steps':[],
     'current_question':question,
     'current_sql_queries': [],
     'llm_answer': AIMessage(content='')
     }

vector_store = None  # reset vector store
config, thread_id = create_config('Run Agent',True)

result = graph.invoke(initial_dict, config = config)
print(result['llm_answer'].content)

💭 Gathering my thoughts...
✅ SQL queries created:1
⚙️ Analysing results...
There are 6,932 unique products that have received a rating of 5.

I analyzed data based on the following filters and transformations:

🧊 Tables: • feedback • products

🔍 Filters: • feedback.feedback_rating = 5

🧮 Aggregations: • COUNT(DISTINCT products.product_id)


In [None]:
for step in graph.stream(initial_dict, config = config, stream_mode="updates"):
 step_name, output = list(step.items())[0]
 if step_name == 'create_sql_query_or_queries':
   print(f"✅ SQL queries created:{len(output['current_sql_queries'])}")
 elif step_name == 'execute_sql_query':
   print("⚙️ Analysing results...")
 elif step_name == 'generate_answer':
   print("\n📣 Final Answer:\n")
   print(output['llm_answer'].content)

In [169]:
# Continue the conversation
config, _ = create_config('Run Agent',False,thread_id)
result = graph.invoke({
    'current_question': 'How many products have a rating of 5?'
}, config=config)

print(result['llm_answer'].content)

There are 6,932 unique products that have received a perfect rating of 5. This means that these products have been highly rated by customers, indicating a high level of satisfaction with their quality or performance. If you're looking for top-rated products, these might be a good place to start.

I analyzed data based on the following filters and transformations:


In [78]:
# Continue the conversation

config, _ = create_config('Run Agent',False,thread_id)
if __name__ == "__main__":
 for step in graph.stream(initial_dict, config = config, stream_mode="updates"):
   step_name, output = list(step.items())[0]
   if step_name == 'create_sql_query_or_queries':
    print(f"✅ SQL queries created:{len(output['current_sql_queries'])}")
   elif step_name == 'execute_sql_query':
    print("⚙️ Analysing results...")
   elif step_name == 'generate_answer':
    print("\n📣 Final Answer:\n")
    print(output['llm_answer'].content)


📣 Final Answer:

Based on the data, there are 1,028 products that have a rating of 5. This means that these products have received the highest possible rating from customers, indicating a high level of satisfaction.

I analyzed data based on the following filters and transformations:


### Testing Locally

### test orchestrator

In [134]:
vector_store = None  # reset vector store
question = 'What is the firm with the most feedback entries?'
test_state = {
'objects_documentation':objects_documentation,
'messages_log':[],
'intermediate_steps' : [],
'current_question':question,
'current_sql_queries': [],
'llm_answer': AIMessage(content='')
}

orchestrator(test_state)
#test_state = run_control_flow(test_state) # retrieve insights
#orchestrator(test_state)
#test_state = run_control_flow(test_state) # create sql query + execute sql query
#orchestrator(test_state)
#test_state = run_control_flow(test_state) # generate answer + manage memory

{'objects_documentation': 'Table company: List of public companies. Granularity is company-name. Column (prefixed with table name):\n     company.company-name: the name of the public company.\n     company.annual_revenue_usd: revenue in last 12 months ($).\n\n     Table feedback: Feedbacks given by clients to products. Granularity is feedback. Key is feedback_id. Columns (prefixed with table name):\n     feedback.feedback_id: id of the feedback.\n     feedback.feedback_date: date of feedback.\n     feedback.product_id: id of the product the feedback was given for.\n     feedback.product_company_name: company owning the product.\n     feedback.feedback_text: the text of the feedback.\n     feedback.feedback_rating: rating of the feedback from 1 to 5, 5 being the highest score.\n\n     Table products: Shows product metadata. Granularity is product. Key is product_id. Columns (prefixed with table name):\n     products.product_id: id of the product.\n     products.product_name: name of the

In [None]:
orchestrator(test_state)

In [None]:
test_state

In [141]:
# continue the conversation
test_state['current_question'] = 'How many feedback entries does Samsung have?'
test_state['intermediate_steps'] = []
test_state['current_sql_queries'] = []
test_state['llm_answer'] = AIMessage(content='')

In [132]:
test_state = run_control_flow(test_state)

In [None]:
# test the orchestrator

vector_store = None  # reset vector store
question = 'How many feedback entries does Samsung have?'
test_state = {
'objects_documentation':objects_documentation,
'messages_log':[ HumanMessage(content='What is the firm with the most feedback entries?'),
                 AIMessage(content='''The firm with the most feedback entries is Samsung, with a total of 280,625 feedback entries. However, it's important to note that while Samsung is highlighted as having the highest number, there might be other companies with the same number of feedback entries that are not mentioned in the data provided.

I analyzed data based on the following filters and transformations:

🧊 Tables: • feedback

🧮 Aggregations: • COUNT(feedback.feedback_id)

📦 Groupings: • feedback.product_company_name''')],
'intermediate_steps' : [],
'current_question':question,
'current_sql_queries': [],
'llm_answer': AIMessage(content='')
}

orchestrator(test_state)

In [55]:
vector_store = None  # reset vector store
question = 'What is the firm with the most feedback entries?'
test_state = {
'objects_documentation':objects_documentation,
'messages_log':[],
'intermediate_steps' : [],
'current_question':question,
'current_sql_queries': [],
'llm_answer': AIMessage(content='')
}
orchestrator(test_state)
test_state = run_control_flow(test_state) # retrieve insights
orchestrator(test_state)
test_state = run_control_flow(test_state) # create sql query + execute sql query
orchestrator(test_state)
test_state = run_control_flow(test_state) # generate answer + manage memory

In [None]:
run_control_flow(test_state)