<a href="https://colab.research.google.com/github/kuberiitb/retail_price_ai_agent/blob/main/notebooks/RetailPriceAgent_Text2Sql_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This agent is adapted from the code in the [LangChain SQL question-answering tutorial](https://python.langchain.com/docs/tutorials/sql_qa/), which builds the agent with LangGraph.

In [1]:
!pip install langchain langchain_community langchain-groq langgraph ipython --quiet

# Setting up the SQL database

In [None]:
from pathlib import Path

from langchain_community.utilities import SQLDatabase

# Connecting SQLite to a missing file silently creates an empty database, which
# would leave the agent with no tables at all. Fail early with a clear message.
DB_PATH = Path("retail_price_agent.db")
if not DB_PATH.exists():
    raise FileNotFoundError(
        f"{DB_PATH.resolve()} not found; place retail_price_agent.db in the working directory first."
    )

db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
print(db.get_usable_table_names())

[]


In [18]:
# Show the CREATE TABLE statement plus sample rows for historical_data.
print(db.get_table_info(table_names=["historical_data"]))


CREATE TABLE historical_data (
	sku_id INTEGER, 
	product_name TEXT, 
	category TEXT, 
	date TIMESTAMP, 
	unit_price REAL, 
	unit_cost REAL, 
	discount_pct REAL, 
	seasonality_factor REAL, 
	units_sold REAL, 
	revenue REAL, 
	profit REAL
)

/*
3 rows from historical_data table:
sku_id	product_name	category	date	unit_price	unit_cost	discount_pct	seasonality_factor	units_sold	revenue	profit
1	t-shirt	men	2023-06-01 00:00:00	1165.0	753.0	20.0	1.1	57.0	66405.0	23484.0
1	t-shirt	men	2023-07-01 00:00:00	1171.0	753.0	20.0	1.1	56.0	65576.0	23408.0
1	t-shirt	men	2023-08-01 00:00:00	1456.0	753.0	0.0	1.1	38.0	55328.0	26714.0
*/


In [19]:
# Peek at the first rows of the monthly sales history.
db.run(command="SELECT * FROM historical_data LIMIT 10;")

"[(1, 't-shirt', 'men', '2023-06-01 00:00:00', 1165.0, 753.0, 20.0, 1.1, 57.0, 66405.0, 23484.0), (1, 't-shirt', 'men', '2023-07-01 00:00:00', 1171.0, 753.0, 20.0, 1.1, 56.0, 65576.0, 23408.0), (1, 't-shirt', 'men', '2023-08-01 00:00:00', 1456.0, 753.0, 0.0, 1.1, 38.0, 55328.0, 26714.0), (1, 't-shirt', 'men', '2023-09-01 00:00:00', 1460.0, 753.0, 0.0, 0.9, 31.0, 45260.0, 21917.0), (1, 't-shirt', 'men', '2023-10-01 00:00:00', 1453.0, 753.0, 0.0, 0.9, 31.0, 45043.0, 21700.0), (1, 't-shirt', 'men', '2023-11-01 00:00:00', 1452.0, 753.0, 0.0, 1.3, 45.0, 65340.0, 31455.0), (1, 't-shirt', 'men', '2023-12-01 00:00:00', 1161.0, 753.0, 20.0, 1.3, 67.0, 77787.0, 27336.0), (1, 't-shirt', 'men', '2024-01-01 00:00:00', 1455.0, 753.0, 0.0, 0.9, 31.0, 45105.0, 21762.0), (1, 't-shirt', 'men', '2024-02-01 00:00:00', 1450.0, 753.0, 0.0, 0.9, 31.0, 44950.0, 21607.0), (1, 't-shirt', 'men', '2024-03-01 00:00:00', 1461.0, 753.0, 0.0, 0.9, 31.0, 45291.0, 21948.0)]"

In [20]:
# Competitor prices/promotions; the table name is spelled "competitior" in the database.
db.run(command="SELECT * FROM competitior_information LIMIT 10;")

"[(1, 't-shirt', 'men', 864.5, 'BTGO', 0.3), (8, 't-shirt', 'women', 946.4, 'BTGO', 0.3)]"

In [21]:
# Forecasted monthly unit counts per SKU.
db.run(command="SELECT * FROM forecast_data LIMIT 10;")

"[(1, 't-shirt', 'men', '2025-09-01 00:00:00', 41.0), (1, 't-shirt', 'men', '2025-10-01 00:00:00', 41.0), (1, 't-shirt', 'men', '2025-11-01 00:00:00', 41.0), (1, 't-shirt', 'men', '2025-12-01 00:00:00', 41.0), (1, 't-shirt', 'men', '2026-01-01 00:00:00', 41.0), (1, 't-shirt', 'men', '2026-02-01 00:00:00', 41.0), (8, 't-shirt', 'women', '2025-09-01 00:00:00', 33.0), (8, 't-shirt', 'women', '2025-10-01 00:00:00', 33.0), (8, 't-shirt', 'women', '2025-11-01 00:00:00', 33.0), (8, 't-shirt', 'women', '2025-12-01 00:00:00', 33.0)]"

In [22]:
# Current stock level per SKU.
db.run(command="SELECT * FROM inventory_data LIMIT 10;")

"[(1, 't-shirt', 'men', 82.0), (8, 't-shirt', 'women', 99.0)]"

In [23]:
# Per-SKU reference values; note this table carries no product_name/category columns.
db.run(command="SELECT * FROM current_product_information LIMIT 10;")

'[(1, 1460.0, 500.0, -1.06, 0.94), (8, 1820.0, 800.0, -0.93, 1.08)]'

In [24]:
from sqlalchemy import create_engine, inspect

# Open a separate SQLAlchemy engine on the same SQLite file and list the
# column names and declared types of inventory_data.
engine = create_engine("sqlite:///retail_price_agent.db")
inspector = inspect(engine)
columns = inspector.get_columns("inventory_data")

for column_info in columns:
    name, sql_type = column_info["name"], column_info["type"]
    print(name, sql_type)

sku_id INTEGER
product_name TEXT
category TEXT
stock REAL


In [25]:
# Langgraph State
from typing_extensions import TypedDict

# Typed state schema from the step-by-step LangGraph tutorial; the fields
# presumably hold the question, the generated SQL query, its raw result and the
# final answer. NOTE(review): the prebuilt ReAct agent in this notebook is driven
# by a "messages" list and never references State.
State = TypedDict(
    "State",
    {
        "question": str,
        "query": str,
        "result": str,
        "answer": str,
    },
)

# Set up the LLM

In [26]:
import os
from dotenv import load_dotenv

# Populate os.environ from the local .env file (expected to provide OPENAI_API_KEY).
load_dotenv(dotenv_path=".env")

True

In [27]:
import os
from getpass import getpass

from langchain.chat_models import init_chat_model

# The key normally comes from .env via load_dotenv. Use os.environ.get so an
# absent key does not raise KeyError; prompt for it instead of continuing
# without credentials.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("Enter OPENAI_API_KEY: ")

# NOTE(review): an earlier ChatOpenAI setup used temperature=0.1; init_chat_model
# below uses the provider's default temperature.
llm = init_chat_model("gpt-4o-mini", model_provider="openai")

In [28]:
# Smoke test: send a trivial prompt and show the model's reply text.
llm.invoke(input="Hi").content

'Hello! How can I assist you today?'

# Using LangGraph's built-in ReAct agent

In [29]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Build the standard SQL tool set bound to `db` (and `llm`), then print each
# tool's name and description.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

for tool in tools:
    print(f"{tool.name} : {tool.description}")

sql_db_query : Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
sql_db_schema : Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3
sql_db_list_tables : Input is an empty string, output is a comma-separated list of tables in the database.
sql_db_query_checker : Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!


In [47]:
# Human-written description of every table and column. It is injected into the
# agent's system prompt so the model can choose the right tables and columns.
# The literal opens with exactly three quotes so no stray quote character leaks
# into the prompt text.
database_information = """
TABLE: historical_data - Contains historical monthly-level sales, pricing, and profit data for each SKU.

COLUMN: sku_id - Unique identifier for each product-category combination.

COLUMN: product_name - Name of the specific product.

COLUMN: category - Category or department the product belongs to.

COLUMN: date - Month (or month-start date) representing the sales period.

COLUMN: unit_price - Actual selling price per unit after applying discounts.

COLUMN: unit_cost - Cost to produce or acquire one unit of the product.

COLUMN: discount_pct - Percentage discount applied to the base price.

COLUMN: seasonality_factor - Seasonal adjustment factor reflecting demand fluctuations.

COLUMN: units_sold - Total quantity sold for the SKU during the given period.

COLUMN: revenue - Total revenue generated = unit_price * units_sold.

COLUMN: profit - Total profit = (unit_price - unit_cost) * units_sold.

#########################################################################################

TABLE: current_product_information - Stores reference or metadata information about each SKU used for pricing and forecasting.

COLUMN: sku_id - Unique identifier for each product-category combination.

COLUMN: base_price - Reference or standard list price of the product.

COLUMN: base_demand - Baseline expected demand level for the product.

COLUMN: elasticity - Price elasticity coefficient indicating sensitivity of demand to price changes.

COLUMN: margin - Target profit margin ratio derived from elasticity or business rules.

#########################################################################################
TABLE: forecast_data - Contains forecasted monthly unit sales for each SKU based on predictive modeling.

COLUMN: sku_id - Unique identifier for each product-category combination.

COLUMN: product_name - Name of the specific product.

COLUMN: category - Category or department the product belongs to.

COLUMN: date - Forecast month or future period.

COLUMN: units_sale - Forecasted number of units expected to be sold.

#########################################################################################

TABLE: inventory_data - Tracks current stock levels for each SKU in the inventory.

COLUMN: sku_id - Unique identifier for each product-category combination.

COLUMN: product_name - Name of the specific product.

COLUMN: category - Category or department the product belongs to.

COLUMN: stock - Current quantity of the SKU available in inventory.

#########################################################################################

TABLE: competitior_information - Captures competitor pricing and promotion details for comparative analysis.

COLUMN: sku_id - Unique identifier representing the same or equivalent SKU.

COLUMN: product_name - Name of the product for cross-reference with competitors.

COLUMN: category - Product category for comparison.

COLUMN: unit_price - Competitor’s selling price for the product.

COLUMN: promotion - Competitor’s promotion or offer label (e.g., "BTGO", "BOGO", "NONE" or discount value).

COLUMN: discount_pct - Discount percentage applied by the competitor, if available. Note: sample values look like fractions (e.g. 0.3), whereas historical_data.discount_pct uses percentage points (e.g. 20.0); confirm the scale before comparing the two.
"""

In [48]:
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent

# Domain conventions the agent should apply when interpreting questions.
extra_details_about_data = """
1. SKU means product_name+category combination, can be represented by sku_id
2. product means product_name unless specified otherwise.
3. While returning SKU information, mention its details like product_name and category.
4. All the price and revenue information is in INR.
"""

# System prompt for the ReAct agent, adapted from the LangChain SQL Q&A tutorial.
# The "Business context" section must agree with rule 2 above (a bare "product"
# means product_name); conflicting definitions make the model's choice of
# grouping columns unpredictable.
system_message = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

{extra_details_about_data}

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Use this description of tables and columns for reference.
{database_information}

Then you should query the schema of the most relevant tables.

Business context:
SKU means the product-category combination (product_name + category, identified
by sku_id); "product" alone means product_name unless the user specifies otherwise.
""".format(
    dialect="SQLite",
    top_k=5,
    database_information=database_information,
    extra_details_about_data=extra_details_about_data
)

sql_agent_executor = create_react_agent(llm, tools, prompt=system_message)

## Example: streaming every agent step with pretty printing

In [49]:
question = "Which is highest selling SKU_ID?"

# Stream full state snapshots and pretty-print the newest message of each,
# which surfaces every tool call and tool result along the way.
agent_input = {"messages": [{"role": "user", "content": question}]}
for state_snapshot in sql_agent_executor.stream(agent_input, stream_mode="values"):
    latest_message = state_snapshot["messages"][-1]
    latest_message.pretty_print()


Which is highest selling SKU_ID?
Tool Calls:
  sql_db_list_tables (call_1KnJN0cOVnlsjKuRnV8Mq651)
 Call ID: call_1KnJN0cOVnlsjKuRnV8Mq651
  Args:
Name: sql_db_list_tables

competitior_information, current_product_information, forecast_data, historical_data, inventory_data
Tool Calls:
  sql_db_schema (call_6zpVn3n8wHK8af8nliEBOceq)
 Call ID: call_6zpVn3n8wHK8af8nliEBOceq
  Args:
    table_names: historical_data
Name: sql_db_schema


CREATE TABLE historical_data (
	sku_id INTEGER, 
	product_name TEXT, 
	category TEXT, 
	date TIMESTAMP, 
	unit_price REAL, 
	unit_cost REAL, 
	discount_pct REAL, 
	seasonality_factor REAL, 
	units_sold REAL, 
	revenue REAL, 
	profit REAL
)

/*
3 rows from historical_data table:
sku_id	product_name	category	date	unit_price	unit_cost	discount_pct	seasonality_factor	units_sold	revenue	profit
1	t-shirt	men	2023-06-01 00:00:00	1165.0	753.0	20.0	1.1	57.0	66405.0	23484.0
1	t-shirt	men	2023-07-01 00:00:00	1171.0	753.0	20.0	1.1	56.0	65576.0	23408.0
1	t-shirt	men	202

## Returning only the final output

The intermediate agent steps (tool calls and their results) are skipped.

In [68]:
from IPython.display import Markdown, display


def get_sql_reponse(question):
    """Ask the SQL agent `question` and render only its final answer as Markdown.

    The agent state is streamed with ``stream_mode="values"``. Intermediate
    messages are skipped, and the first message whose
    ``response_metadata["finish_reason"]`` is ``"stop"`` is displayed. If no
    such message arrives, or its content is empty or not a string,
    "I don't know" is displayed instead.

    Returns None because output is shown via ``display``. Calling the function
    as the last expression of a cell therefore does not echo the answer twice.
    """
    final_answer = None
    for step in sql_agent_executor.stream(
        {"messages": [{"role": "user", "content": question}]},
        stream_mode="values",
    ):
        last_message = step["messages"][-1]
        if last_message.response_metadata.get("finish_reason") == "stop":
            final_answer = last_message.content
            break

    # Non-string content (e.g. a list of content blocks) cannot be rendered as
    # Markdown text, so treat it like a missing answer rather than crashing.
    if isinstance(final_answer, str) and final_answer.strip():
        display(Markdown(final_answer))
    else:
        display(Markdown("I don't know"))


get_sql_reponse("Which is most profitable SKU_ID?")

The most profitable SKU_ID is 8, which corresponds to the product "t-shirt" in the "women" category, with a total profit of 693,898 INR.

In [69]:
# Same question phrased with "SKU" rather than "SKU_ID"; the prompt asks for product_name and category in SKU answers.
get_sql_reponse(question="Which is most profitable SKU?")

The most profitable SKUs are:

1. SKU ID: 8
   - Product Name: t-shirt
   - Category: women
   - Total Profit: ₹693,898.00

2. SKU ID: 1
   - Product Name: t-shirt
   - Category: men
   - Total Profit: ₹656,949.00

In [70]:
# Identical question re-run: the LLM's wording and level of detail can differ between runs.
get_sql_reponse(question="Which is most profitable SKU?")

The most profitable SKU is for the **t-shirt** in the **women** category, with a total profit of **INR 693,898**.

In [46]:
# NOTE(review): the saved output quotes "$" and has execution count 46, i.e. it predates the prompt cell that states prices are in INR; re-run to refresh.
get_sql_reponse(question="Which is most profitable SKU of 2025?")

"The most profitable SKU of 2025 is a t-shirt in the women's category, with a total profit of $196,767. The second most profitable SKU is a t-shirt in the men's category, with a total profit of $188,747."

In [71]:
# Elasticity lives in current_product_information, which has no product_name/category columns.
get_sql_reponse(question="What is most elastic SKU?")

The most elastic SKU has an elasticity of **-1.06** and a base price of **INR 1460**. The SKU ID for this product-category combination is **1**.