In [1]:
from langchain_google_genai import ChatGoogleGenerativeAI


In [None]:
api_key = "enter_your_google_api_key_here"
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-pro",
    google_api_key=api_key, 
    temperature=0.2)
#response = llm.invoke("write a poem on my love for dosa")
#print(response.content)


In [3]:
from langchain_community.utilities import SQLDatabase

In [4]:
db_user = "root"
db_password= "root"
db_host = "localhost"
db_name = "atliq_tshirts"

db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}", sample_rows_in_table_info=3)
#print(db.table_info)

In [5]:
from langchain.chains import create_sql_query_chain

query_chain = create_sql_query_chain(llm, db)

In [6]:
def ask_databasef2(question):
    import ast
    import decimal

    # Step 1 — Generate SQL
    query = query_chain.invoke({"question": question})
    print("Generated SQL Query:\n", query)

    # Step 2 — Extract SQL
    sql = query.split("SQLQuery:")[-1].strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()

    # Step 3 — Run SQL
    raw = db.run(sql)
    print("Answer:", raw)

    # Step 4 — Convert ANY output into a Python list
    raw_str = str(raw)                      # convert object → string
    raw_str = raw_str.replace("Decimal", "")  # remove Decimal(...) wrapper
    try:
        result = ast.literal_eval(raw_str)
    except:
        result = raw

    # --------------------------
    # CASE 1: Single value
    # --------------------------
    if (isinstance(result, list)
        and len(result) == 1
        and isinstance(result[0], tuple)
        and len(result[0]) == 1):
        
        val = result[0][0]
        print(val)
        return val

    # --------------------------
    # CASE 2: Multiple rows
    # --------------------------
    if isinstance(result, list):
        clean = []
        for row in result:
            clean_row = []
            for col in row:
                clean_row.append(col)
            clean.append(tuple(clean_row))
        print(clean)
        return clean

    # Fallback
    return result


In [7]:
q1 = ask_databasef2("how many small size t-shirts do we have of Nike brand?")



Generated SQL Query:
 Question: how many small size t-shirts do we have of Nike brand?
SQLQuery: SELECT sum(`stock_quantity`) FROM t_shirts WHERE `brand` = 'Nike' AND `size` = 'S'
Answer: [(Decimal('178'),)]
178


In [20]:
q2 = ask_databasef2("If we sell all Levi's t-shirts today, how much revenue will we generate?")

Generated SQL Query:
 Question: If we sell all Levi's t-shirts today, how much revenue will we generate?
SQLQuery: SELECT SUM(T1.`price` * T1.`stock_quantity` * (1 - T2.`pct_discount` / 100)) FROM t_shirts AS T1 JOIN discounts AS T2 ON T1.`t_shirt_id` = T2.`t_shirt_id` WHERE T1.`brand` = 'Levi'
Answer: [(Decimal('876.600000'),)]
876.600000


In [21]:
q3 = ask_databasef2("How much total revenue will we generate after applying discounts on all t-shirts?")

Generated SQL Query:
 Question: How much total revenue will we generate after applying discounts on all t-shirts?
SQLQuery: SELECT SUM(T1.`price` * (1 - T2.`pct_discount` / 100) * T1.`stock_quantity`) FROM t_shirts AS T1 JOIN discounts AS T2 ON T1.`t_shirt_id` = T2.`t_shirt_id`
Answer: [(Decimal('13730.000000'),)]
13730.000000


In [22]:
q4 = ask_databasef2("How many t-shirts do not have any discount applied?")

Generated SQL Query:
 Question: How many t-shirts do not have any discount applied?
SQLQuery: SELECT COUNT(*) FROM t_shirts AS T1 LEFT JOIN discounts AS T2 ON T1.t_shirt_id  =  T2.t_shirt_id WHERE T2.discount_id IS NULL
Answer: [(70,)]
70


In [26]:
q5 = ask_databasef2("Find the total discounted revenue for each brand.")

Generated SQL Query:
 Question: Find the total discounted revenue for each brand.
SQLQuery: SELECT T1.`brand`, SUM(T1.`price` * (1 - T2.`pct_discount` / 100) * T1.`stock_quantity`) AS `total_discounted_revenue` FROM t_shirts AS T1 JOIN discounts AS T2 ON T1.`t_shirt_id` = T2.`t_shirt_id` GROUP BY T1.`brand`
Answer: [('Levi', Decimal('876.600000')), ('Nike', Decimal('6911.300000')), ('Van Huesen', Decimal('5942.100000'))]
[('Levi', '876.600000'), ('Nike', '6911.300000'), ('Van Huesen', '5942.100000')]


In [28]:
print(q1, q2, q3, q4, q5, sep="\n")


178
876.600000
13730.000000
70
[('Levi', '876.600000'), ('Nike', '6911.300000'), ('Van Huesen', '5942.100000')]


Few Shot Learnig

In [None]:
few_shots = [
    {
        "Question": "How many small size t-shirts do we have of Nike brand?",
        "SQLQuery": "SELECT SUM(stock_quantity) FROM t_shirts WHERE brand = 'Nike' AND size = 'S';",
        "SQLResult": "Result of this SQL query",
        "Answer": q1
    },
    {
        "Question": "If we sell all Levi's t-shirts today, how much revenue will we generate?",
        "SQLQuery": "SELECT SUM(price * stock_quantity) FROM t_shirts WHERE brand='Levi';",
        "SQLResult": "Result of the SQL query",
        "Answer": q2
    },
    {
        "Question": "How much total revenue will we generate after applying discounts on all t-shirts?",
        "SQLQuery": "SELECT SUM(T1.price * T1.stock_quantity * (1 - IFNULL(T2.pct_discount,0)/100)) FROM t_shirts AS T1 LEFT JOIN discounts AS T2 ON T1.t_shirt_id = T2.t_shirt_id;",
        "SQLResult": "Result of the SQL query",
        "Answer": q3
    },
    {
        "Question": "How many t-shirts do not have any discount applied?",
        "SQLQuery": "SELECT COUNT(*) FROM t_shirts WHERE t_shirt_id NOT IN (SELECT t_shirt_id FROM discounts);",
        "SQLResult": "Result of the SQL query",
        "Answer": q4
    },
    {
        "Question": "Find the total discounted revenue for each brand.",
        "SQLQuery": "SELECT T1.brand, SUM(T1.price * T1.stock_quantity * (1 - IFNULL(T2.pct_discount,0)/100)) AS total_revenue FROM t_shirts AS T1 LEFT JOIN discounts AS T2 ON T1.t_shirt_id = T2.t_shirt_id GROUP BY T1.brand ORDER BY total_revenue DESC;",
        "SQLResult": "Result of the SQL query",
        "Answer": q5
}
]



In [23]:
from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


In [24]:
e = embeddings.embed_query("How many small size t-shirts do we have of Nike brand?") 

In [25]:
e[ :5]

[0.06397144496440887,
 0.029311122372746468,
 -0.004600989632308483,
 -0.02196970023214817,
 0.05475088953971863]

In [56]:
to_vectorize = [" ".join([str(v) for v in example.values()]) for example in few_shots]



In [31]:
from langchain.vectorstores import Chroma

clean_metadata = [
    {
        "Question": fs["Question"],
        "SQLQuery": fs["SQLQuery"],
        "SQLResult": fs["SQLResult"],
        "Answer": str(fs["Answer"])
    }
    for fs in few_shots
]

vectorstore = Chroma.from_texts(to_vectorize, embedding=embeddings, metadatas=clean_metadata)

Failed to send telemetry event ClientStartEvent: capture() takes 1 positional argument but 3 were given
Failed to send telemetry event ClientCreateCollectionEvent: capture() takes 1 positional argument but 3 were given


In [32]:
from langchain.prompts import SemanticSimilarityExampleSelector
exemple_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=3
)
exemple_selector.select_examples({"Question": "How many small size t-shirts do we have of Levi's brand?"})

Failed to send telemetry event CollectionQueryEvent: capture() takes 1 positional argument but 3 were given


[{'Answer': '178',
  'Question': 'How many small size t-shirts do we have of Nike brand?',
  'SQLQuery': "SELECT SUM(stock_quantity) FROM t_shirts WHERE brand = 'Nike' AND size = 'S';",
  'SQLResult': 'Result of this SQL query'},
 {'Answer': '876.600000',
  'Question': "If we sell all Levi's t-shirts today, how much revenue will we generate?",
  'SQLQuery': "SELECT SUM(price * stock_quantity) FROM t_shirts WHERE brand='Levi';",
  'SQLResult': 'Result of the SQL query'},
 {'Answer': '70',
  'Question': 'How many t-shirts do not have any discount applied?',
  'SQLQuery': 'SELECT COUNT(*) FROM t_shirts WHERE t_shirt_id NOT IN (SELECT t_shirt_id FROM discounts);',
  'SQLResult': 'Result of the SQL query'}]

In [38]:
from langchain.chains.sql_database.prompt import MYSQL_PROMPT, PROMPT_SUFFIX


In [42]:
from langchain.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
    template="""
Question: {Question}
SQLQuery: {SQLQuery}
SQLResult: {SQLResult}
Answer: {Answer}
"""
)


In [44]:
from langchain.prompts import FewShotPromptTemplate

few_shot_prompt = FewShotPromptTemplate(
    example_selector=exemple_selector,
    example_prompt=example_prompt,
    prefix=MYSQL_PROMPT.template,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"],  # These variables are used in the prompt
)


In [None]:
#from langchain.chains.sql_database.query import create_sql_query_chain

new_chain = create_sql_query_chain(
    llm=llm,
    db=db,
    prompt=few_shot_prompt
)

In [53]:
def ask_databasel1(question):
    import ast
    import decimal

    # Step 1 — Generate SQL using the NEW CHAIN
    query = new_chain.invoke({"question": question})
    print("Generated SQL Query:\n", query)

    # Step 2 — Extract SQL
    sql = query.split("SQLQuery:")[-1].strip()
    sql = sql.replace("```sql", "").replace("```", "").strip()

    # Step 3 — Run SQL
    raw = db.run(sql)
    print("Answer:", raw)

    # Step 4 — Convert ANY output into a Python list
    raw_str = str(raw)
    raw_str = raw_str.replace("Decimal", "")  
    try:
        result = ast.literal_eval(raw_str)
    except:
        result = raw

    # CASE 1: Single value
    if (isinstance(result, list)
        and len(result) == 1
        and isinstance(result[0], tuple)
        and len(result[0]) == 1):
        
        val = result[0][0]
        print(val)
        return val

    # CASE 2: Multiple rows
    if isinstance(result, list):
        clean = []
        for row in result:
            clean_row = []
            for col in row:
                clean_row.append(col)
            clean.append(tuple(clean_row))
        print(clean)
        return clean

    # Fallback
    return result


In [59]:
q = ask_databasel1("Which brand has the highest total discounted revenue, and what is that revenue?")

Generated SQL Query:
 SELECT T1.`brand`, SUM(T1.`price` * T1.`stock_quantity` * (100 - IFNULL(T2.`pct_discount`, 0)) / 100) AS `total_revenue` FROM t_shirts AS T1 LEFT JOIN discounts AS T2 ON T1.`t_shirt_id` = T2.`t_shirt_id` GROUP BY T1.`brand` ORDER BY `total_revenue` DESC LIMIT 1
Answer: [('Van Huesen', Decimal('34486.100000'))]
[('Van Huesen', '34486.100000')]


In [2]:
import Main


In [3]:
print(Main.get_few_shot_db_chain)

<function get_few_shot_db_chain at 0x0000020342318FE0>


In [2]:
import sys
print(sys.path)


['C:\\Users\\karan\\AppData\\Local\\Programs\\Python\\Python311\\python311.zip', 'C:\\Users\\karan\\AppData\\Local\\Programs\\Python\\Python311\\DLLs', 'C:\\Users\\karan\\AppData\\Local\\Programs\\Python\\Python311\\Lib', 'C:\\Users\\karan\\AppData\\Local\\Programs\\Python\\Python311', 'd:\\VS\\MINI Project\\palm_env_new', '', 'd:\\VS\\MINI Project\\palm_env_new\\Lib\\site-packages', 'd:\\VS\\MINI Project\\palm_env_new\\Lib\\site-packages\\win32', 'd:\\VS\\MINI Project\\palm_env_new\\Lib\\site-packages\\win32\\lib', 'd:\\VS\\MINI Project\\palm_env_new\\Lib\\site-packages\\Pythonwin']
