In [139]:
from pathlib import Path
from langchain.agents import initialize_agent, Tool
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_types import AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from sqlalchemy import create_engine
import sqlite3
from langchain_groq import ChatGroq
from dotenv import load_dotenv
from table_relationships import describe_table_relationships
from tbl_col_info import table_info_and_examples
import os
import pandas as pd
import re
import requests
import glob

In [140]:
from langgraph.graph import StateGraph, START, END
from typing import Dict, Any, TypedDict, Annotated
from operator import add
import pickle
from IPython.display import Image

from thefuzz import process
from datetime import datetime
import json
import tqdm


import pandas as pd
from sqlalchemy import create_engine,  text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.types import Integer, Float, String

In [141]:
from agents.sql_cleaned_query_agent import clean_query_node
from agents.find_tables import find_tables_node
from agents.create_sql_query import create_sql_query
from agents.execute_sql_query import execute_sql_query

In [142]:
# Pull the variables declared in a local .env file into the process
# environment, then look up the Groq key (None when the variable is unset).
# NOTE(review): api_key is not referenced again in the visible cells, and the
# model built below is ChatOpenAI — confirm whether this lookup is still needed.
load_dotenv()
api_key = os.environ.get("GROQ_API_KEY")

In [143]:
from langchain_openai import ChatOpenAI

In [144]:

# Chat model settings, gathered in one mapping so they are easy to tweak.
_llm_settings = dict(
    model="gpt-4o",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
# Further keyword arguments can be added to the mapping when needed, e.g.
# api_key="..." (to pass the key directly instead of using env vars),
# base_url="..." or organization="...".
llm = ChatOpenAI(**_llm_settings)

In [145]:
# Load the CSV exports into MySQL and return the engine plus a LangChain wrapper.

def configure_db(db_uri=None, csv_folder=None):
    """Load every CSV in ``csv_folder`` into MySQL and return database handles.

    Each ``*.csv`` file becomes one table named after the lower-cased file
    stem. Existing tables with the same name are replaced
    (``if_exists="replace"``), so re-running the cell is safe.

    Args:
        db_uri: SQLAlchemy URL of the target database, e.g.
            ``mysql+pymysql://user:password@host:3306/dbname``. Special
            characters in the password must be URL-encoded (``@`` -> ``%40``).
            When omitted, the ``MYSQL_URI`` environment variable is used so
            that credentials never live in the notebook source.
        csv_folder: Directory holding the CSV files. Defaults to
            ``<current working directory>/cooked_data_gk``.

    Returns:
        A ``(engine, SQLDatabase)`` tuple. The ``SQLDatabase`` wraps the same
        engine that was used for loading, so only one engine is created.

    Raises:
        RuntimeError: If no URI is passed and ``MYSQL_URI`` is not set.
    """
    # Resolve the connection string from the environment instead of
    # hard-coding a username/password in the notebook.
    db_uri = db_uri or os.getenv("MYSQL_URI")
    if not db_uri:
        raise RuntimeError(
            "MYSQL_URI is not set. Add a line such as "
            "MYSQL_URI=mysql+pymysql://<user>:<url-encoded-password>@127.0.0.1:3306/txt2sql "
            "to your .env file, or pass db_uri explicitly."
        )

    mysql_engine = create_engine(db_uri)

    if csv_folder is None:
        csv_folder = Path.cwd() / "cooked_data_gk"
    csv_folder = Path(csv_folder)

    # sorted() gives a deterministic load order regardless of the filesystem.
    csv_files = sorted(glob.glob(str(csv_folder / "*.csv")))
    if not csv_files:
        print(f"⚠️ No CSV files found in {csv_folder}")

    for csv_file in csv_files:
        table_name = Path(csv_file).stem.lower()
        df = pd.read_csv(csv_file)

        # Save each CSV as a table in MySQL, replacing any previous copy.
        df.to_sql(name=table_name, con=mysql_engine, index=False, if_exists="replace")
        print(f"✅ Loaded table: {table_name}")

    # Wrap the existing engine rather than opening a second one from the URI.
    return mysql_engine, SQLDatabase(mysql_engine)

# 🔌 Connect to the database and list the tables LangChain can query.
# get_usable_table_names() is used instead of the deprecated get_table_names().
mysql_engine, db = configure_db()
print("📄 Tables Loaded:", db.get_usable_table_names())

✅ Loaded table: distributor_closing_stock
✅ Loaded table: product_master
✅ Loaded table: retailer_master
✅ Loaded table: retailer_order_product_details
✅ Loaded table: retailer_order_summary
✅ Loaded table: scheme_details
📄 Tables Loaded: ['distributor_closing_stock', 'product_master', 'retailer_master', 'retailer_order_product_details', 'retailer_order_summary', 'scheme_details']


In [146]:
# Sanity check: show the SQL dialect and the tables exposed through `db`,
# one per line.
print(db.dialect, db.get_usable_table_names(), sep="\n")

mysql
['distributor_closing_stock', 'product_master', 'retailer_master', 'retailer_order_product_details', 'retailer_order_summary', 'scheme_details']


In [None]:
class finalstate(TypedDict):
    """State schema handed to ``StateGraph`` for the text-to-SQL workflow.

    Only ``user_query`` is supplied when the workflow is invoked; the other
    keys are presumably filled in by the graph nodes — verify against the
    node implementations in ``agents/``.
    """

    # Raw question as typed by the user.
    user_query: str
    # Rewritten form of the question (see the recorded workflow output).
    cleaned_user_query: str
    # Names of the tables selected for the query.
    tables: list[str]
    # NOTE(review): presumably holds a query that failed to run — confirm usage.
    failed_query: str
    # Disabled fields kept for reference:
    #   order_out: str, product_out: str, filtered_col: str,
    #   filter_extractor: list[str], fuzz_match: list[str]
    # Generated SQL text.
    sql_query: str
    # NOTE(review): the recorded workflow output in this notebook contains the
    # key 'query_result' (singular) — confirm which spelling the nodes emit.
    query_results: str

## Defining Nodes

In [148]:
# Assemble the linear pipeline on top of the shared state schema.
graph = StateGraph(finalstate)

# Steps in execution order; each callable is registered under its own name.
_pipeline_steps = (
    ("clean_query_node", clean_query_node),
    ("find_tables_node", find_tables_node),
    ("create_sql_query", create_sql_query),
    ("execute_sql_query", execute_sql_query),
)
for _step_name, _step_fn in _pipeline_steps:
    graph.add_node(_step_name, _step_fn)

# Chain START -> every step -> END with one edge per consecutive pair.
_route = [START, *[_step[0] for _step in _pipeline_steps], END]
for _src, _dst in zip(_route, _route[1:]):
    graph.add_edge(_src, _dst)

workflow = graph.compile()

In [149]:
# import nest_asyncio
# nest_asyncio.apply()


In [150]:
# from langchain_core.runnables.graph_mermaid import MermaidDrawMethod
# from IPython.display import Image

# # This will now work in Jupyter
# Image(workflow.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.PYPPETEER))



In [151]:
# Seed state for a run: only the raw user question is provided up front.
initial_state = dict(
    user_query="what is the discount or scheme I willl get 10 qts of Eno",
)


In [152]:
# Run the compiled graph on the seed state. Left as the cell's last expression
# so the notebook displays the returned state dict.
workflow.invoke(initial_state)

{'user_query': 'what is the discount or scheme I willl get 10 qts of Eno',
 'cleaned_user_query': 'To find out the discount or scheme applicable for purchasing 10 units of Eno, you should refer to the "scheme_details" table. Specifically, look for any active promotional schemes where the "sku_code" is associated with Eno, such as "ENO001". Check the "discount_percent" and "scheme_name" columns to identify the discount percentage and the name of the scheme. Additionally, ensure that the scheme is currently active by verifying the "is_active" column is \'Y\' and that the current date falls within the "start_date_time" and "end_date_time" range.',
 'tables': ['scheme_details'],
 'sql_query': "SELECT scheme_name, discount_percent\nFROM scheme_details\nWHERE sku_code = 'ENO001'\n  AND is_active = 'Y'\n  AND CURRENT_TIMESTAMP BETWEEN STR_TO_DATE(start_date_time, '%d/%m/%y %H:%i') AND STR_TO_DATE(end_date_time, '%d/%m/%y %H:%i')",
 'query_result': "The query is designed to retrieve informatio

# Old code

In [153]:
# # Tools
# relationship_tool = Tool(
#     name="TableRelationships",
#     func=describe_table_relationships,
#     description="Use this tool to understand how tables are related before writing SQL queries."
# )
# info_example_tool = Tool(
#     name="TableInfoAndExamples",
#     func=table_info_and_examples,
#     description="Use this tool to understand available tables and columns see example queries.",
# )

In [154]:
# toolkit = SQLDatabaseToolkit(db=db, llm=llm)
# tools = toolkit.get_tools() + [relationship_tool, info_example_tool]

In [155]:
# tools

In [156]:
# agent = initialize_agent(
#     tools=tools,
#     llm=llm,
#     agent_type=AgentType.OPENAI_FUNCTIONS,
#     verbose=True,
#     handle_parsing_errors=True,
#     agent_kwargs={
#         "system_message": """
# You are an expert SQL assistant helping query a SQLite-based retail database which gives one line answer.

# ✅ Instructions:
# 1. If table info is missing, always use `TableInfoAndExamples` first and then check for 
# 'TableRelationships' for the relationships between tables.

# 2. Use this response format:
#    - Action: <tool-name or Final Answer>
#    - Action Input: <input>
# 3. Execute real SQL queries after table identification.
# 4. Only show SELECT query results, not intermediate thoughts or tools.
# 5. When asked about "top discount schemes", use:
# SELECT name, discount_percent FROM tbl_scheme WHERE is_active = 1 ORDER BY discount_percent DESC;

# 6. When asked about MRP of the brands, check in the tbl_product_master
# 7. When asked about products with have top discounts then use 'TableRelationships' for joining and fetch products.
# """
#     }
# )

In [157]:
# # Chat input
# user_query = input()

In [158]:
# response = agent.run(user_query)

In [159]:
# response

In [160]:
# sql_match = re.search(r"(SELECT\s.+?;)", response, re.IGNORECASE | re.DOTALL)
# print(sql_match)

In [161]:
# try:
#     sql_match = re.search(r"(SELECT\s.+?;)", response, re.IGNORECASE | re.DOTALL)
#     if sql_match:
#         query = sql_match.group(1).strip()
#         result = db.run(query)

#         if isinstance(result, list):
#             df = pd.DataFrame(result)
#             print("Executed SQL Query:")
#             print(query)

#             # Clean column names for display
#             df.columns = [col.replace("_", " ").title() for col in df.columns]

#             # Round numeric columns (like discount)
#             for col in df.select_dtypes(include=['float']):
#                 df[col] = df[col].round(2)

#             # Format currency columns if MRP exists
#             if "Mrp" in df.columns:
#                 df["Mrp"] = df["Mrp"].apply(lambda x: f"₹{x:.2f}")

#             # Display DataFrame
#             display(df)

#         else:
#             print("Query result:")
#             print(result)
#     else:
#         print("No SQL query found in response.")
#         print(response)
# except Exception as e:
#     print("Error:", e)