In [1]:
import os
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from urllib.parse import quote


In [14]:
os.environ["OPENAI_API_KEY"] = "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" 

In [16]:
hostname='localhost'
database='inventoryManagement'
username='postgres'
pwd='kri@123'
port_id=5432
encoded_pwd = quote(pwd)
connection_uri = f"postgresql+psycopg2://{username}:{encoded_pwd}@{hostname}:{port_id}/{database}"
db = SQLDatabase.from_uri(connection_uri)


In [4]:
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
generate_query = create_sql_query_chain(llm, db)
query = generate_query.invoke({"question": "how many products are there"})
print(query)

SELECT COUNT("productId") AS total_products
FROM "Products";


In [5]:
execute_query = QuerySQLDataBaseTool(db=db)
print(execute_query.invoke(query))

chain=generate_query | execute_query
chain.get_prompts()[0].pretty_print()

[(36,)]
You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to 

In [22]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

rephrase_answer = answer_prompt | llm | StrOutputParser()

chain = (
    RunnablePassthrough.assign(query=generate_query).assign(
        result=lambda inputs: execute_query.invoke(inputs["query"])
    )
    | rephrase_answer
)
user_question = "how many customers are their whose name starts with letter h? Show at most 5."
response = chain.invoke({"question": user_question})
print(response)



There is 1 customer whose name starts with the letter 'H'.


In [11]:
examples = [
    {
        "input": "List all customers in France with a credit limit over 20,000.",
        "query": "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"
    },
    {
        "input": "Get the highest payment amount made by any customer.",
        "query": "SELECT MAX(amount) FROM payments;"
    },
  
]

In [23]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    # input_variables=["input","top_k"],
    input_variables=["input"],
)
print(few_shot_prompt.format(input1="How many products are there?"))


Human: List all customers in France with a credit limit over 20,000.
SQLQuery:
AI: SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;
Human: Get the highest payment amount made by any customer.
SQLQuery:
AI: SELECT MAX(amount) FROM payments;


In [24]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

vectorstore = Chroma()
vectorstore.delete_collection()
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    vectorstore,
    k=2,
    input_keys=["input"],
)
example_selector.select_examples({"input": "how many users we have?"})
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input","top_k"],
)
print(few_shot_prompt.format(input="How many products are there?"))


  vectorstore = Chroma()


ImportError: Could not import chromadb python package. Please install it with `pip install chromadb`.

In [None]:
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="How many products are there?",table_info="some table info"))
generate_query = create_sql_query_chain(llm, db,final_prompt)
chain = (
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)
chain.invoke({"question": "How many cutomers who have purchased more than one item"})


In [None]:
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
import pandas as pd

def get_table_details():
    table_description = pd.read_csv("table_descriptions.csv")
    table_docs = []

    table_details = ""
    for index, row in table_description.iterrows():
        table_details = table_details + "Table Name:" + row['Table'] + "\n" + "Table Description:" + row['Description'] + "\n\n"

    return table_details


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")

# table_names = "\n".join(db.get_usable_table_names())
table_details = get_table_details()
print(table_details)


In [None]:
table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_details}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

table_chain = create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)
tables = table_chain.invoke({"input": "give me details of customer and their order count"})
tables


In [None]:
def get_tables(tables: List[Table]) -> List[str]:
    tables  = [table.name for table in tables]
    return tables

select_table = {"input": itemgetter("question")} | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt) | get_tables
select_table.invoke({"question": "give me details of customer and their order count"})


In [None]:
chain = (
RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)
chain.invoke({"question": "How many cutomers with order count more than 5"})


In [None]:
from langchain.memory import ChatMessageHistory
history = ChatMessageHistory()


In [None]:
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions"),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="How many products are there?",table_info="some table info",messages=[]))


In [None]:
generate_query = create_sql_query_chain(llm, db,final_prompt)

chain = (
RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)


In [None]:
question = "How many cutomers with order count more than 5"
response = chain.invoke({"question": question,"messages":history.messages})

In [None]:
history.add_user_message(question)
history.add_ai_message(response)
history.messages

In [None]:
response = chain.invoke({"question": "Can you list there names?","messages":history.messages})
response

In [20]:

# Initialize language model
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# Create SQL query chain
generate_query = create_sql_query_chain(llm, db)

# Example query
question = "How many products are there?"
response = generate_query.invoke({"question": question})
print(f"Generated Query:{response}")


Generated Query:SELECT COUNT("productId") AS total_products
FROM "Products";


In [21]:

# Execute the query
execute_query = QuerySQLDataBaseTool(db=db)
result = execute_query.invoke(response)

# Format results for readability
def format_results(result):
    if isinstance(result, list) and len(result) > 0:
        return "".join([str(row) for row in result])
    return "No results found."

formatted_result = format_results(result)
print(f"Query Result:{formatted_result}")


Query Result:No results found.


In [None]:

# Store context using ChatMessageHistory
from langchain.memory import ChatMessageHistory

history = ChatMessageHistory()
history.add_user_message(question)
history.add_ai_message(formatted_result)

# View stored context
for msg in history.messages:
    print(f"{msg.type.capitalize()}: {msg.content}")
