In [15]:
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
import os

In [16]:
load_dotenv()
db_user = os.getenv('DB_USER')
db_pass = os.getenv('DB_PASSWORD')
db_host = os.getenv('DB_HOST')
db_name = os.getenv('DB_NAME')

In [17]:
examples = [
   {
        "input": "Give all Woman Owned companies with sales in 2023.",
        "query": "SELECT * FROM demo_vendors WHERE Hub_Type LIKE '%women%' OR Hub_Type LIKE '%female%' AND sales_2023 > 0;",
    },
    {
        "input": "Give top 10 HUB certified application development vendors.",
        "query": "SELECT * FROM demo_vendors WHERE `Technology_Category_1` LIKE '%Application Development%' AND hub_certified = 1 ORDER BY sales_2023 DESC LIMIT 10;",
    },
    {
        "input": "Give top 10 HUB certified software development vendors.",
        "query": "SELECT * FROM demo_vendors WHERE `Technology_Category_1` LIKE '%Application Development%' AND hub_certified = 1 ORDER BY sales_2023 DESC LIMIT 10;",
    },
    {
        "input": "Give top 10 HUB certified IT development vendors.",
        "query": "SELECT * FROM demo_vendors WHERE `Technology_Category_1` LIKE '%Application Development%' AND hub_certified = 1 ORDER BY sales_2023 DESC LIMIT 10;",
    },
    {
        "input": "Give top 10 HUB certified custom software development vendors.",
        "query": "SELECT * FROM demo_vendors WHERE `Technology_Category_1` LIKE '%Application Development%' AND hub_certified = 1 ORDER BY sales_2023 DESC LIMIT 10;",
    },
    {
        "input": "Give all application development vendors.",
        "query": "SELECT * FROM demo_vendors WHERE `Technology_Category_1` LIKE '%Application Development%';",
    },
    {
        "input": "Give all software development vendors.",
        "query": "SELECT * FROM demo_vendors WHERE `Technology_Category_1` LIKE '%Application Development%';",
    },
    {
        "input": "Give all IT development vendors.",
        "query": "SELECT * FROM demo_vendors WHERE `Technology_Category_1` LIKE '%Application Development%';",
    }
]

In [18]:
db = SQLDatabase.from_uri(f"mysql+mysqlconnector://{db_user}:{db_pass}@{db_host}/{db_name}")

In [19]:
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)

In [20]:
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)

In [21]:
system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
You can order the results by a relevant column to return the most interesting examples in the database.
Always query for all the columns from a specific table.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
If asked about women-owned companies, use %LIKE% to search for women or female in the Hub_Type column.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix="",
)

In [22]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

In [23]:
# Example formatted prompt
prompt_val = full_prompt.invoke(
    {
        "input": "How many vendors are there",
        "top_k": 100,
        "dialect": "mysql",
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())

System: You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct mysql query to run, then look at the results of the query and return the answer.
You can order the results by a relevant column to return the most interesting examples in the database.
Always query for all the columns from a specific table.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
If asked about women-owned companies, use %LIKE% to search for women or female in the Hub_Type column.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corre

In [24]:
agent = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    verbose=True,
    agent_type="openai-tools",
)

In [31]:
agent.invoke("Which vendor did highest sales in the year 2022?")



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT * FROM vendors ORDER BY sales_2022 DESC LIMIT 1'}`


[0m[36;1m[1;3m[('SAS', 'DIR-CPO-5005', 'SAS offers deliverables-based information technology services (DBITS) through this contract, specifically: Technology Category 2: Business Intelligence (BI), Data Management, Analytics, and Automation, including Data Warehousing. This contract is for services ONLY. No hardware or software...', '', 'Business Intelligence (BI), Data Management, Analytics, and Automation, including Data Warehousing', '', '', 'Matthew Arabshahi', '5/9/2025', 'others', 'Paul Graeber', 'paul.graeber@sas.com', '(512) 840-6219', 1, Decimal('1982606.35'), Decimal('1998518.37'), Decimal('1455989.91'), Decimal('660535.46'), 1)][0m[32;1m[1;3mThe vendor that did the highest sales in the year 2022 is SAS, with sales amounting to $1,998,518.37.[0m

[1m> Finished chain.[0m


{'input': 'Which vendor did highest sales in the year 2022?',
 'output': 'The vendor that did the highest sales in the year 2022 is SAS, with sales amounting to $1,998,518.37.'}