<a href="https://colab.research.google.com/github/kuberiitb/retail_price_ai_agent/blob/main/notebooks/RetailPriceAgent_Text2Sql_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Agent created using code from [LangGraph Tutorial](https://python.langchain.com/docs/tutorials/sql_qa/)

In [1]:
!pip install langchain langchain_community langchain-groq langgraph ipython --quiet

# Setting up SQL data

In [36]:
from langchain_groq import ChatGroq
from IPython.display import Markdown, display

In [3]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///C:/Users/SYAMALA/ai_agent_course_github/retail_price_ai_agent/data/enhanced_retail_data.db")
print(db.get_usable_table_names())

['competitor_information', 'current_product_information', 'forecast_data', 'historical_data', 'inventory_data']


In [18]:
print(db.get_table_info(["competitor_information"]))


CREATE TABLE competitor_information (
	sku_id INTEGER, 
	product_name TEXT, 
	category TEXT, 
	unit_price REAL, 
	promotion TEXT, 
	discount_pct REAL
)

/*
3 rows from competitor_information table:
sku_id	product_name	category	unit_price	promotion	discount_pct
1	t-shirt	men	851.9	BTGO	0.7
8	t-shirt	women	750.0	0	1.0
*/


In [5]:
db.run("SELECT * FROM historical_data LIMIT 10;")

"[(1, 't-shirt', 'men', '2023-06-01 00:00:00', 861.0, 520.0, 20.0, 1.1, 77.0, 66297.0, 26257.0, -0.93, 0.4), (1, 't-shirt', 'men', '2023-07-01 00:00:00', 867.0, 520.0, 20.0, 1.1, 76.0, 65892.0, 26372.0, -0.93, 0.4), (1, 't-shirt', 'men', '2023-08-01 00:00:00', 1076.0, 520.0, 0.0, 1.1, 51.0, 54876.0, 28356.0, -0.93, 0.52), (1, 't-shirt', 'men', '2023-09-01 00:00:00', 1080.0, 520.0, 0.0, 0.9, 42.0, 45360.0, 23520.0, -0.93, 0.52), (1, 't-shirt', 'men', '2023-10-01 00:00:00', 1073.0, 520.0, 0.0, 0.9, 42.0, 45066.0, 23226.0, -0.93, 0.52), (1, 't-shirt', 'men', '2023-11-01 00:00:00', 1072.0, 520.0, 0.0, 1.3, 61.0, 65392.0, 33672.0, -0.93, 0.51), (1, 't-shirt', 'men', '2023-12-01 00:00:00', 857.0, 520.0, 20.0, 1.3, 91.0, 77987.0, 30667.0, -0.93, 0.39), (1, 't-shirt', 'men', '2024-01-01 00:00:00', 1075.0, 520.0, 0.0, 0.9, 42.0, 45150.0, 23310.0, -0.93, 0.52), (1, 't-shirt', 'men', '2024-02-01 00:00:00', 1070.0, 520.0, 0.0, 0.9, 42.0, 44940.0, 23100.0, -0.93, 0.51), (1, 't-shirt', 'men', '2024-

In [6]:
db.run("SELECT * FROM competitor_information LIMIT 10;")

"[(1, 't-shirt', 'men', 851.9, 'BTGO', 0.7), (8, 't-shirt', 'women', 750.0, '0', 1.0)]"

In [7]:
db.run("SELECT * FROM forecast_data LIMIT 10;")

"[(1, 't-shirt', 'men', '2025-09-01 00:00:00', 56.0), (1, 't-shirt', 'men', '2025-10-01 00:00:00', 56.0), (1, 't-shirt', 'men', '2025-11-01 00:00:00', 56.0), (1, 't-shirt', 'men', '2025-12-01 00:00:00', 56.0), (1, 't-shirt', 'men', '2026-01-01 00:00:00', 56.0), (1, 't-shirt', 'men', '2026-02-01 00:00:00', 56.0), (8, 't-shirt', 'women', '2025-09-01 00:00:00', 60.0), (8, 't-shirt', 'women', '2025-10-01 00:00:00', 60.0), (8, 't-shirt', 'women', '2025-11-01 00:00:00', 60.0), (8, 't-shirt', 'women', '2025-12-01 00:00:00', 60.0)]"

In [8]:
db.run("SELECT * FROM inventory_data LIMIT 10;")

"[(1, 't-shirt', 'men', 168.0), (8, 't-shirt', 'women', 120.0)]"

In [9]:
db.run("SELECT * FROM current_product_information LIMIT 10;")

'[(1, 1080, 100, -0.93, 1.075268817204301), (8, 1010, 300, -1.15, 0.8695652173913044)]'

In [10]:
from sqlalchemy import create_engine, inspect

engine = create_engine("sqlite:///C:/Users/SYAMALA/ai_agent_course_github/retail_price_ai_agent/data/enhanced_retail_data.db")
inspector = inspect(engine)

columns = inspector.get_columns("inventory_data")
for col in columns:
    print(col["name"], col["type"])

sku_id INTEGER
product_name TEXT
category TEXT
stock REAL


In [11]:
# Langgraph State
from typing_extensions import TypedDict

class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str

# Setup LLM

In [12]:
import os
from dotenv import load_dotenv
load_dotenv(".env")

True

In [50]:
def setup_llm(model_name):
    llm = ChatGroq(groq_api_key=os.environ["API_KEY"], model_name=model_name,temperature=0)
    return llm

In [51]:
llm = setup_llm("llama-3.1-8b-instant")

In [42]:
llm

ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x000001C126D169E0>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x000001C126D17360>, model_name='llama-3.1-8b-instant', model_kwargs={}, groq_api_key=SecretStr('**********'))

In [43]:
llm.invoke("Hi")

AIMessage(content="It's nice to meet you. Is there something I can help you with or would you like to chat?", additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 23, 'prompt_tokens': 36, 'total_tokens': 59, 'completion_time': 0.027284929, 'prompt_time': 0.002673068, 'queue_time': 0.044535599, 'total_time': 0.029957997}, 'model_name': 'llama-3.1-8b-instant', 'system_fingerprint': 'fp_33e8adf159', 'service_tier': 'on_demand', 'finish_reason': 'stop', 'logprobs': None}, id='run--97e17b89-9fdb-4888-9fc7-19dc34237294-0', usage_metadata={'input_tokens': 36, 'output_tokens': 23, 'total_tokens': 59})

# Using inbuilt Langgraph ReAct agent

In [44]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

for t in tools:
  print(t.name, ":", t.description)

sql_db_query : Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
sql_db_schema : Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3
sql_db_list_tables : Input is an empty string, output is a comma-separated list of tables in the database.
sql_db_query_checker : Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!


In [96]:
database_information = """"
TABLE: historical_data - Contains historical monthly-level sales, pricing, and profit data for each SKU.

COLUMN: sku_id - Unique identifier for each product-category combination.

COLUMN: product_name - Name of the specific product.

COLUMN: category - Category or department the product belongs to.

COLUMN: date - Month (or month-start date) representing the sales period.

COLUMN: unit_price - Actual selling price per unit after applying discounts.

COLUMN: unit_cost - Cost to produce or acquire one unit of the product.

COLUMN: discount_pct: Fractional discount applied to the base price (e.g., 0.1 = 10%, 0.2 = 20%).

COLUMN: seasonality_factor - Seasonal adjustment factor reflecting demand fluctuations.

COLUMN: units_sold - Total quantity sold for the SKU during the given period.

COLUMN: revenue - Total revenue generated = unit_price * units_sold.

COLUMN: profit - Total profit = (unit_price - unit_cost) * units_sold.

COLUMN: price_elasticity - Measures the sensitivity of demand to changes in price; negative values indicate that demand decreases as price increases.

COLUMN: profit_margin - Percentage of revenue remaining as profit after accounting for costs; calculated as (price - cost) / price.



#########################################################################################

TABLE: current_product_information - Stores reference or metadata information about each SKU used for pricing and forecasting.

COLUMN: sku_id - Unique identifier for each product-category combination.

COLUMN: base_price - Reference or standard list price of the product.

COLUMN: base_demand - Baseline expected demand level for the product.

COLUMN: elasticity - Price elasticity coefficient indicating sensitivity of demand to price changes.

COLUMN: margin - Target profit margin ratio derived from elasticity or business rules.

#########################################################################################
TABLE: forecast_data - Contains forecasted monthly unit sales for each SKU based on predictive modeling.

COLUMN: sku_id - Unique identifier for each product-category combination.

COLUMN: product_name - Name of the specific product.

COLUMN: category - Category or department the product belongs to.

COLUMN: date - Forecast month or future period.

COLUMN: units_sale - Forecasted number of units expected to be sold.

#########################################################################################

TABLE: inventory_data - Tracks current stock levels for each SKU in the inventory.

COLUMN: sku_id - Unique identifier for each product-category combination.

COLUMN: product_name - Name of the specific product.

COLUMN: category - Category or department the product belongs to.

COLUMN: stock - Current quantity of the SKU available in inventory.

#########################################################################################

TABLE: competitor_information - Captures competitor pricing and promotion details for comparative analysis.

COLUMN: sku_id - Unique identifier representing the same or equivalent SKU.

COLUMN: product_name - Name of the product for cross-reference with competitors.

COLUMN: category - Product category for comparison.

COLUMN: unit_price - Competitor’s selling price for the product.

COLUMN: promotion - Competitor’s promotion or offer label (e.g., "BOGO","NONE" or discount value).

COLUMN: discount_pct - Discount percentage applied by the competitor, if available.
"""

In [97]:
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent
from langchain.agents import create_sql_agent
from langgraph.checkpoint.memory import MemorySaver
from langchain import PromptTemplate



extra_details_about_data = """
1. SKU means product_name+category combination, can be represented by sku_id
2. product means product_name unless specified otherwise.
3. While returning SKU information, mention its details like product_name and category.
4. All the price and revenue information is in INR.
"""

system_message = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

{extra_details_about_data}

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.


Before executing a query, reason carefully about the query.
- Only call the SQL validation/execution tool **once per unique query**.
- If the query is invalid or returns an error, you may modify it and call the tool again, but avoid calling the tool multiple times for the same query.
- Always validate your query logically in your reasoning before calling the tool.
- Do not attempt to call the tool multiple times without changing the query.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Use this description of tables and columns for reference.
{database_information}

Then you should query the schema of the most relevant tables.

Business context:
Product usually means product-category combination.
""".format(
    dialect="SQLite",
    top_k=5,
    database_information=database_information,
    extra_details_about_data=extra_details_about_data
)

system_message_with_prompt_template = PromptTemplate(
    template=system_message
)

system_message = system_message + """
If a user question does not require data from the SQL database (for example, general advice, business strategy, or descriptive explanations),
you should answer directly using reasoning — do NOT call any SQL tool in those cases.
"""


memory = MemorySaver()

sql_agent_executor = create_react_agent(llm, tools, prompt=system_message,checkpointer=memory)

## Example with pretty printing

In [98]:
question = "Which is highest selling SKU_ID?"

for step in sql_agent_executor.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()

ValueError: Checkpointer requires one or more of the following 'configurable' keys: thread_id, checkpoint_ns, checkpoint_id

## Returning final output only

## skipping all the intermediate steps

In [99]:
def get_sql_reponse(question):
  config = {"configurable": {"thread_id": "session-1"}}
  for step in sql_agent_executor.stream(
      {"messages": [{"role": "user", "content": question}]},
      config=config,
      stream_mode="values",
  ):
    step["messages"][-1].pretty_print()

get_sql_reponse("Which is most profitable SKU_ID in the current time")


Which is most profitable SKU_ID in the current time
Tool Calls:
  sql_db_list_tables (axnvasbdz)
 Call ID: axnvasbdz
  Args:
  sql_db_list_tables (5bc71wnq8)
 Call ID: 5bc71wnq8
  Args:
  sql_db_schema (s6xx9n8g5)
 Call ID: s6xx9n8g5
  Args:
    table_names: historical_data, current_product_information, forecast_data, inventory_data, competitor_information
  sql_db_query_checker (0g6b1mjjv)
 Call ID: 0g6b1mjjv
  Args:
    query: SELECT sku_id FROM current_product_information ORDER BY base_price DESC LIMIT 1
  sql_db_query (84093b8q3)
 Call ID: 84093b8q3
  Args:
    query: SELECT sku_id FROM current_product_information ORDER BY base_price DESC LIMIT 1
Name: sql_db_query

[(1,)]

The most profitable SKU_ID in the current time is 1.


In [70]:
get_sql_reponse("Which is most profitable SKU of 2025?")


Which is most profitable SKU of 2025?
Tool Calls:
  sql_db_list_tables (911z1ezwe)
 Call ID: 911z1ezwe
  Args:
  sql_db_query_checker (414s1eynd)
 Call ID: 414s1eynd
  Args:
    query: SELECT sku_id, product_name, category, revenue FROM historical_data WHERE date LIKE '2025%' ORDER BY profit DESC LIMIT 5
  sql_db_query_checker (yfn1hgwht)
 Call ID: yfn1hgwht
  Args:
    query: SELECT sku_id, product_name, category, revenue FROM historical_data WHERE date LIKE '2025%' ORDER BY profit DESC LIMIT 5
  sql_db_query (q0mkmnmnc)
 Call ID: q0mkmnmnc
  Args:
    query: SELECT sku_id, product_name, category, revenue FROM historical_data WHERE date LIKE '2025%' ORDER BY profit DESC LIMIT 5
Name: sql_db_query

[(1, 't-shirt', 'men', 54825.0), (1, 't-shirt', 'men', 54723.0), (1, 't-shirt', 'men', 66374.0), (8, 't-shirt', 'women', 54972.0), (8, 't-shirt', 'women', 54648.0)]

The most profitable SKU of 2025 is SKU 1, which is a 't-shirt' in the 'men' category, with a revenue of 54825.0 INR.


In [71]:
get_sql_reponse("What is most elastic SKU?")


What is most elastic SKU?
Tool Calls:
  sql_db_query_checker (tnjywbnzm)
 Call ID: tnjywbnzm
  Args:
    query: SELECT sku_id FROM historical_data ORDER BY price_elasticity DESC LIMIT 1
  sql_db_query_checker (2ya9na6yy)
 Call ID: 2ya9na6yy
  Args:
    query: SELECT sku_id FROM historical_data ORDER BY price_elasticity DESC LIMIT 1
  sql_db_query_checker (n4sgxxdkb)
 Call ID: n4sgxxdkb
  Args:
    query: SELECT sku_id FROM historical_data ORDER BY price_elasticity DESC LIMIT 1
  sql_db_query_checker (97359ydkq)
 Call ID: 97359ydkq
  Args:
    query: SELECT sku_id FROM historical_data ORDER BY price_elasticity DESC LIMIT 1
  sql_db_query_checker (m3skt4btk)
 Call ID: m3skt4btk
  Args:
    query: SELECT sku_id FROM historical_data ORDER BY price_elasticity DESC LIMIT 1
  sql_db_query_checker (kd20xevax)
 Call ID: kd20xevax
  Args:
    query: SELECT sku_id FROM historical_data ORDER BY price_elasticity DESC LIMIT 1
  sql_db_query_checker (rxah5h42t)
 Call ID: rxah5h42t
  Args:
    query:

In [102]:
get_sql_reponse("for which products competitors are offering better prices than us")


for which products competitors are offering better prices than us


APIStatusError: Error code: 413 - {'error': {'message': 'Request too large for model `llama-3.1-8b-instant` in organization `org_01jam7vee2ejsrcm6rdz1rkfja` service tier `on_demand` on tokens per minute (TPM): Limit 6000, Requested 39800, please reduce your message size and try again. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'tokens', 'code': 'rate_limit_exceeded'}}

In [94]:
get_sql_reponse("Competitor ABC company is offering a 20 percent discount on t-shirts Should we match this offer, create a similar promotion, or use a different strategy?")


Competitor ABC company is offering a 20 percent discount on t-shirts Should we match this offer, create a similar promotion, or use a different strategy?


ValueError: Found AIMessages with tool_calls that do not have a corresponding ToolMessage. Here are the first few of those tool calls: [{'name': 'sql_db_list_tables', 'args': {}, 'id': 'mv40vpnf5', 'type': 'tool_call'}, {'name': 'sql_db_list_tables', 'args': {}, 'id': 'phr5s8jdd', 'type': 'tool_call'}, {'name': 'sql_db_list_tables', 'args': {}, 'id': '6wtqmxtbx', 'type': 'tool_call'}].

Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage (result of a tool invocation to return to the LLM) - this is required by most LLM providers.
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/INVALID_CHAT_HISTORY