In [71]:
from pathlib import Path
from langchain.agents import initialize_agent, Tool
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_types import AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from sqlalchemy import create_engine
import sqlite3
from langchain_groq import ChatGroq
from dotenv import load_dotenv
from table_relationships import describe_table_relationships
from tbl_col_info import table_info_and_examples
import os
import pandas as pd
import re
import requests
import glob
from sqlalchemy.types import Date

In [72]:
from langgraph.graph import StateGraph, START, END
from typing import Dict, Any, TypedDict, Annotated
import operator
import pickle
from IPython.display import Image

from thefuzz import process
from datetime import datetime
import json
import tqdm


import pandas as pd
from sqlalchemy import create_engine,  text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.types import Integer, Float, String

In [73]:
from agents.sql_cleaned_query_agent import clean_query_node
from agents.find_tables import find_tables_node
from agents.create_sql_query import create_sql_query
from agents.execute_sql_query import execute_sql_query
from agents.rewrite_sql_query import rewrite_sql_query
from agents.summarise_query_results import summarise_results

In [74]:
# Copy KEY=value pairs from a local .env file into os.environ, then read the
# Groq API key (None when GROQ_API_KEY is not defined anywhere).
# NOTE(review): api_key is not referenced by any active cell below — confirm it is still needed.
load_dotenv()
api_key = os.environ.get("GROQ_API_KEY")

In [75]:
import os
import urllib.parse

# Password for the `postgres` database user. It is read from the environment
# (load_dotenv() above also exposes values from a local .env file) instead of
# being hard-coded in the notebook. The literal that previously lived here was
# committed in plain text and should be rotated.
_raw_db_password = os.getenv("POSTGRES_PASSWORD")
if _raw_db_password is None:
    raise RuntimeError(
        "POSTGRES_PASSWORD is not set; add it to your environment or .env file."
    )

# quote_plus escapes URL-reserved characters (e.g. "@") so the password can be
# embedded safely in the SQLAlchemy connection URL.
password = urllib.parse.quote_plus(_raw_db_password)

In [76]:
from langchain_openai import ChatOpenAI

In [77]:

# OpenAI chat model for this notebook: gpt-4o with temperature 0 (least random
# sampling), up to two automatic retries, and no explicit token cap or timeout.
# No credentials are passed here, so ChatOpenAI falls back to environment
# variables (OPENAI_API_KEY). Pass api_key=..., base_url=... or
# organization=... to override them explicitly.
# NOTE(review): `llm` is only referenced by the commented-out "Old code" cells;
# the agents.* nodes presumably configure their own model — confirm.
llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    max_retries=2,
    max_tokens=None,
    timeout=None,
)

In [78]:
# Configure the PostgreSQL database: load the cooked CSV extracts into it and
# return both a SQLAlchemy engine and a LangChain SQLDatabase wrapper.


def _convert_bill_date_to_date(engine):
    """Convert ``public.tbl_primary.bill_date`` from text to a PostgreSQL ``date``.

    Values are parsed with the ``DD/MM/YY`` format. ``engine.begin()`` commits
    when the block succeeds and rolls back if the ALTER fails, so a bad value
    does not leave a half-open transaction behind.
    """
    # NOTE(review): assumes pandas wrote bill_date as a text column; TO_DATE on a
    # non-text column (e.g. an all-NaN float column) would fail — confirm.
    with engine.begin() as conn:
        conn.execute(
            text("""
                ALTER TABLE public.tbl_primary
                ALTER COLUMN bill_date TYPE date
                USING TO_DATE(bill_date, 'DD/MM/YY')
            """)
        )
    print("âœ… Converted bill_date to DATE type")


def configure_db(csv_folder=None, db_uri=None):
    """Load every CSV in ``csv_folder`` into PostgreSQL and return handles to it.

    Each ``*.csv`` file becomes a table named after the file's lower-cased stem.
    Any existing table of that name is replaced. Right after ``tbl_primary`` is
    loaded, its ``bill_date`` column is converted to a ``date``.

    Args:
        csv_folder: Directory containing the CSV files. Defaults to
            ``Path.cwd() / "cooked_data_gk"``.
        db_uri: SQLAlchemy URL of the target database. Defaults to the local
            ``LLM_Haldiram_primary`` database as user ``postgres``, using the
            module-level ``password``.

    Returns:
        Tuple ``(engine, db)``. ``engine`` is the SQLAlchemy engine used for
        loading, and ``db`` is a LangChain ``SQLDatabase`` that reuses it.
    """
    if db_uri is None:
        db_uri = f"postgresql+psycopg2://postgres:{password}@localhost:5432/LLM_Haldiram_primary"
    if csv_folder is None:
        csv_folder = Path.cwd() / "cooked_data_gk"

    engine = create_engine(db_uri)

    # Sorted so tables load (and are logged) in the same order on every OS.
    csv_files = sorted(glob.glob(str(Path(csv_folder) / "*.csv")))
    if not csv_files:
        print(f"No CSV files found in {csv_folder}")

    for csv_file in csv_files:
        table_name = Path(csv_file).stem.lower()
        frame = pd.read_csv(csv_file)
        # if_exists="replace" drops and recreates the table, so re-running this
        # cell reloads the data from scratch.
        frame.to_sql(name=table_name, con=engine, index=False, if_exists="replace")
        print(f"âœ… Loaded table: {table_name}")

        # Store bill_date as a real date so SQL date comparisons (e.g. the
        # date_trunc filters in generated queries) operate on dates, not text.
        if table_name == "tbl_primary":
            _convert_bill_date_to_date(engine)

    # Reuse the same engine. SQLDatabase.from_uri() would build a second engine
    # (and connection pool) pointing at the same database.
    return engine, SQLDatabase(engine)


# Connect to the DB, load the CSVs, and list the tables LangChain can see.
# `mysql_engine` keeps its historical name for other cells; it is a PostgreSQL engine.
mysql_engine, db = configure_db()
print("ðŸ“„ Tables Loaded:", db.get_usable_table_names())

âœ… Loaded table: tbl_distributor_master
âœ… Loaded table: tbl_primary
âœ… Converted bill_date to DATE type
âœ… Loaded table: tbl_product_master
âœ… Loaded table: tbl_superstockist_master
ðŸ“„ Tables Loaded: ['tbl_distributor_master', 'tbl_primary', 'tbl_product_master', 'tbl_superstockist_master']


In [79]:
# Sanity check: the SQL dialect LangChain detected, then the tables it exposes.
dialect_name = db.dialect
print(dialect_name)
usable_tables = db.get_usable_table_names()
print(usable_tables)

postgresql
['tbl_distributor_master', 'tbl_primary', 'tbl_product_master', 'tbl_superstockist_master']


In [80]:
class finalstate(TypedDict):
    """State dictionary shared by every node of the text-to-SQL LangGraph workflow.

    Fields annotated ``Annotated[list[str], operator.add]`` carry
    ``operator.add`` as their LangGraph reducer, so lists returned by
    successive nodes are concatenated instead of replacing each other.
    """

    # --- the user's question and its cleaned-up form ---
    user_query: str
    cleaned_user_query: str

    # --- tables selected for the query ---
    tables: list[str]

    # NOTE(review): annotated as a DataFrame, but the sample run stores a string
    # repr of the result rows here — confirm the intended type.
    dataframe: pd.DataFrame

    # --- SQL attempts that failed and their error messages (reducer: operator.add) ---
    failed_query: Annotated[list[str], operator.add]
    query_error_message: Annotated[list[str], operator.add]
    retry_count: int
    is_empty_result: bool

    # --- generated SQL, its execution status, and the final summary ---
    sql_query: str
    query_results: str
    summary_results: str
    

In [81]:
### Routing function: decide whether the last SQL execution succeeded or failed

# Status value expected in `query_results` after a successful execution
# (presumably written by execute_sql_query — confirm if that message changes).
QUERY_SUCCESS_MESSAGE = "Query executed successfully"


def check_query_failed(state: "finalstate") -> str:
    """Choose the next step after ``execute_sql_query``.

    Used as the path function of a conditional edge. The returned label is
    looked up in that edge's path map.

    Args:
        state: Current workflow state. Only ``query_results`` is read.

    Returns:
        ``"success"`` when ``query_results`` equals ``QUERY_SUCCESS_MESSAGE``.
        ``"failed"`` otherwise, including when ``query_results`` is missing.
    """
    # TODO: an "empty_result" route (based on state["is_empty_result"]) was
    # sketched here but disabled; empty results currently count as success.
    if state.get("query_results") == QUERY_SUCCESS_MESSAGE:
        return "success"
    return "failed"

## Building the workflow graph (nodes, edges, and the SQL retry loop)

In [82]:
# Assemble the text-to-SQL pipeline as a LangGraph state machine over `finalstate`.
graph = StateGraph(finalstate)

# Nodes, registered in pipeline order (node name -> node function).
pipeline_nodes = {
    "clean_query_node": clean_query_node,
    "find_tables_node": find_tables_node,
    "create_sql_query": create_sql_query,
    "execute_sql_query": execute_sql_query,
    "rewrite_sql_query": rewrite_sql_query,
    "summarise_results": summarise_results,
}
for node_name, node_fn in pipeline_nodes.items():
    graph.add_node(node_name, node_fn)

# Straight-line part of the pipeline: START -> clean -> find tables -> write SQL -> run SQL.
linear_path = [START, "clean_query_node", "find_tables_node", "create_sql_query", "execute_sql_query"]
for src, dst in zip(linear_path, linear_path[1:]):
    graph.add_edge(src, dst)

# After execution, either summarise or rewrite the SQL and execute it again.
# NOTE(review): this rewrite -> execute cycle has no explicit retry cap here; it
# is bounded only by LangGraph's recursion limit — confirm that is intended.
graph.add_conditional_edges(
    "execute_sql_query",
    check_query_failed,
    {"success": "summarise_results", "failed": "rewrite_sql_query"},
)
graph.add_edge("rewrite_sql_query", "execute_sql_query")
graph.add_edge("summarise_results", END)

workflow = graph.compile()

In [83]:
# import nest_asyncio
# nest_asyncio.apply()


In [84]:
# from langchain_core.runnables.graph_mermaid import MermaidDrawMethod
# from IPython.display import Image

# # This will now work in Jupyter
# Image(workflow.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.PYPPETEER))



In [85]:
# Input for one workflow run: the natural-language question, with the
# empty-result flag starting out False.
initial_state = dict(
    user_query="Total sales made last 8 months for each distributor",
    is_empty_result=False,
)


In [86]:
# Run the compiled pipeline once; `result` holds the final graph state.
result = workflow.invoke(input=initial_state)

In [87]:
# Display the final state dictionary returned by the workflow run.
result

{'user_query': 'Total sales made last 8 months for each distributor',
 'cleaned_user_query': 'Please specify the distributor names from the list of valid distributor entities, and I will provide the total sales made over the last 8 months for each distributor accordingly.',
 'tables': ['tbl_distributor_master', 'tbl_Primary'],
 'dataframe': "[('SAWARIYA TRADING 41496', Decimal('4512972')), ('V H TRADING COMPANY 41303 -DELHI', Decimal('2386210'))]",
 'failed_query': [],
 'query_error_message': [],
 'is_empty_result': False,
 'sql_query': "SELECT \n    dm.distributor_name, \n    COALESCE(SUM(p.invoiced_total_quantity), 0) AS total_sales\nFROM \n    tbl_distributor_master dm\nJOIN \n    tbl_Primary p ON dm.distributor_erp_id = p.distributor_id\nWHERE \n    p.bill_date >= date_trunc('month', CURRENT_DATE - INTERVAL '8 months')\nGROUP BY \n    dm.distributor_name",
 'query_results': 'Query executed successfully',
 'summary_results': 'SAWARIYA TRADING sold 4,512,972 units, and V H TRADING CO

In [88]:
# Show the SQL stored in the final state (prints None if the key is absent).
print(result.get("sql_query", None))

SELECT 
    dm.distributor_name, 
    COALESCE(SUM(p.invoiced_total_quantity), 0) AS total_sales
FROM 
    tbl_distributor_master dm
JOIN 
    tbl_Primary p ON dm.distributor_erp_id = p.distributor_id
WHERE 
    p.bill_date >= date_trunc('month', CURRENT_DATE - INTERVAL '8 months')
GROUP BY 
    dm.distributor_name


# Old code (previous agent-based implementation, kept for reference only)

In [89]:
# # Tools
# relationship_tool = Tool(
#     name="TableRelationships",
#     func=describe_table_relationships,
#     description="Use this tool to understand how tables are related before writing SQL queries."
# )
# info_example_tool = Tool(
#     name="TableInfoAndExamples",
#     func=table_info_and_examples,
#     description="Use this tool to understand available tables and columns see example queries.",
# )

In [90]:
# toolkit = SQLDatabaseToolkit(db=db, llm=llm)
# tools = toolkit.get_tools() + [relationship_tool, info_example_tool]

In [91]:
# tools

In [92]:
# agent = initialize_agent(
#     tools=tools,
#     llm=llm,
#     agent_type=AgentType.OPENAI_FUNCTIONS,
#     verbose=True,
#     handle_parsing_errors=True,
#     agent_kwargs={
#         "system_message": """
# You are an expert SQL assistant helping query a SQLite-based retail database which gives one line answer.

# âœ… Instructions:
# 1. If table info is missing, always use `TableInfoAndExamples` first and then check for 
# 'TableRelationships' for the relationships between tables.

# 2. Use this response format:
#    - Action: <tool-name or Final Answer>
#    - Action Input: <input>
# 3. Execute real SQL queries after table identification.
# 4. Only show SELECT query results, not intermediate thoughts or tools.
# 5. When asked about "top discount schemes", use:
# SELECT name, discount_percent FROM tbl_scheme WHERE is_active = 1 ORDER BY discount_percent DESC;

# 6. When asked about MRP of the brands, check in the tbl_product_master
# 7. When asked about products with have top discounts then use 'TableRelationships' for joining and fetch products.
# """
#     }
# )

In [93]:
# # Chat input
# user_query = input()

In [94]:
# response = agent.run(user_query)

In [95]:
# response

In [96]:
# sql_match = re.search(r"(SELECT\s.+?;)", response, re.IGNORECASE | re.DOTALL)
# print(sql_match)

In [97]:
# try:
#     sql_match = re.search(r"(SELECT\s.+?;)", response, re.IGNORECASE | re.DOTALL)
#     if sql_match:
#         query = sql_match.group(1).strip()
#         result = db.run(query)

#         if isinstance(result, list):
#             df = pd.DataFrame(result)
#             print("Executed SQL Query:")
#             print(query)

#             # Clean column names for display
#             df.columns = [col.replace("_", " ").title() for col in df.columns]

#             # Round numeric columns (like discount)
#             for col in df.select_dtypes(include=['float']):
#                 df[col] = df[col].round(2)

#             # Format currency columns if MRP exists
#             if "Mrp" in df.columns:
#                 df["Mrp"] = df["Mrp"].apply(lambda x: f"â‚¹{x:.2f}")

#             # Display DataFrame
#             display(df)

#         else:
#             print("Query result:")
#             print(result)
#     else:
#         print("No SQL query found in response.")
#         print(response)
# except Exception as e:
#     print("Error:", e)