In [127]:
import pandas as pd
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine

# df = pd.read_csv(
#     "/Users/job/Desktop/project_x/project-sql-agent/services/database/test/sales_suppliers.csv"
# )
# engine = create_engine("sqlite:///data.db")
# df.to_sql("data", engine, index=False, if_exists="replace")
# db = SQLDatabase(engine=engine)

In [128]:
db_uri = "postgresql+psycopg2://postgres:postgres@localhost:5432/postgres"

# 2. สร้าง Engine เชื่อมต่อไปที่ Docker
engine = create_engine(db_uri)
db = SQLDatabase(engine)
print("✅ Connected to PostgreSQL successfully!")
print("Tables found:", db.get_usable_table_names())

✅ Connected to PostgreSQL successfully!
Tables found: []


In [123]:
from pydantic import BaseModel, Field, ConfigDict
from langgraph.graph.message import add_messages
from typing import Literal, Optional

class State(BaseModel):
    sql_question: str          
    sql_query: str = ""             
    sql_query_execution_status: Literal["success", "failure"] = "failure"
    sql_error_count: int = 0  
    sql_query_error: str = ""  
    sql_result: str = ""       
    sql_answer: str = ""

In [124]:
# from typing_extensions import Annotated

# class QueryOutput(BaseModel):
#     """Generated SQL query."""
#     generated_sql_query: Annotated[str, ..., "Syntactically valid SQL query."]

In [125]:
from typing import Annotated
from pydantic import BaseModel, Field

class QueryOutput(BaseModel):
    generated_sql_query: Annotated[str, Field(description="Syntactically valid SQL query.")]

In [126]:
print(db.get_table_info())




In [82]:
from langchain_core.prompts import ChatPromptTemplate

system_message = """
Given an input question, create a syntactically correct {dialect} query to run to help find the answer.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
{table_info}

Schema highlights:
- TransactionYear: the 4-digit year (e.g., 2024)
- TransactionMonth: an integer from 1 to 12 representing the month
- TransactionDay: an integer from 1 to 31 representing the day of the month
    Example: To get all transactions that occurred on January 16, 2024, use: WHERE TransactionYear = 2024 AND TransactionMonth = 1 AND TransactionDay = 16
- Category (e.g. 'Healthcare', 'Electronics', 'Unknown')
- Merchant (e.g. 'MediCare', 'ElectroWorld', 'Unknown')
- Channel (one of 'Online', 'ATM', 'Branch')
- CustomerAge (INT, in years)
- CustomerOccupation (one of 'Retired', 'Engineer', 'Student', 'Teacher', 'Doctor')
- CreditScore (INT, 300-850)
- RiskProfile (one of 'Low', 'Medium', 'High')
- CustomerTenure (INT, years with bank)
- PreferredSpendingCategory (e.g. 'Electronics', 'Entertainment', ...)
- MostInterestedProduct (e.g. 'Home Loan', 'Credit Card', ...)
- IncomeBracket (string ranges: '<25K', '25K-50K','50K-100K', '150K-200K', '200K-250K', '300K-350K')
 - For comparisons (">200K"), include all brackets whose lower bound exceeds the threshold (e.g., IN ('200K-250K', '300K-350K')).
"""

user_prompt = """
Question: {input}
Use the following error information if there is any: {query_error}
"""

query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)

In [83]:
from config import GROQ_MODEL,Config
from langchain_groq import ChatGroq

llm = ChatGroq(model=GROQ_MODEL, groq_api_key=Config.groq_api_key)


In [84]:
def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "table_info": db.get_table_info(),
            "input": state.sql_question,  # Changed from state["sql_question"]
            "query_error": state.sql_query_error or "",  # Changed from state.get("sql_query_error", "")
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    state.sql_query = result.generated_sql_query  # Changed from state["sql_query"]

    return state

In [111]:
def write_query(state: State):
    """Generate SQL query to fetch information."""
    try:
        prompt = query_prompt_template.invoke(
            {
                "dialect": db.dialect,
                "table_info": db.get_table_info(),
                "input": state.sql_question,
                "query_error": state.sql_query_error or "",
            }
        )
        structured_llm = llm.with_structured_output(QueryOutput)
        result = structured_llm.invoke(prompt)
        state.sql_query = result.generated_sql_query
    except Exception as e:
        # Handle cases where LLM refuses to generate query (e.g., question can't be answered with available schema)
        error_msg = str(e)
        if hasattr(e, 'response') and hasattr(e.response, 'body'):
            # Try to extract more detailed error message
            try:
                import json
                error_body = json.loads(e.response.body)
                if 'error' in error_body and 'failed_generation' in error_body.get('error', {}):
                    error_msg = error_body['error']['failed_generation']
            except:
                pass
        
        state.sql_query = ""  # No query generated
        state.sql_query_execution_status = "failure"
        state.sql_query_error = f"Error generating query: {error_msg}"
        state.sql_error_count = state.sql_error_count + 1
    
    return state

In [112]:
state = State(**{
    "sql_question": "How many customer are there"
})
result_state = write_query(state)
print(result_state.sql_query, result_state.sql_query_error)
# Output: {'sql_query': "SELECT COUNT(TransactionID) FROM data WHERE UserID = 'U001';"}

SELECT COUNT(*) FROM sales_suppliers; 


In [113]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

def execute_query(state: State):
    """Execute SQL query and set query_execution_status."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    result = execute_query_tool.invoke(state.sql_query)  # Changed from state["sql_query"]

    if isinstance(result, str) and result.startswith("Error:"):
        state.sql_query_execution_status = "failure"  # Changed from state["sql_query_execution_status"]
        state.sql_query_error = result  # Changed from state["sql_query_error"]
        state.sql_error_count = state.sql_error_count + 1  # Changed from state.get() and state["sql_error_count"]

    else:
        state.sql_result = result  # Changed from state["sql_result"]
        state.sql_query_execution_status = "success"  # Changed from state["sql_query_execution_status"]
        state.sql_query_error = ""  # Changed from state["sql_query_error"]
        state.sql_error_count = 0   # Changed from state["sql_error_count"]
       
    return state

In [114]:
state = State(**{
    "sql_question": "How many city exist in the table",  # Required field
    "sql_query": "SELECT COUNT(DISTINCT city) FROM sales_suppl"
})
result_state = execute_query(state)
print(result_state.sql_result, result_state.sql_query_execution_status)

 failure


In [115]:
def cannot_answer(state: State):
  """Cannot answer question; return a default answer."""

  state.sql_error_count = 0  # Changed from state["sql_error_count"]
  state.sql_answer = "I'm sorry, but I cannot find the information you're looking for."  # Changed from state["sql_answer"]

  return state

In [116]:
def sql_router(state: State):
  """Routes to generate_answer, cannot_answer or write_query based on query_execution_status."""

  if state.sql_query_execution_status == "success":  # Changed from state["sql_query_execution_status"]
    return "generate_answer"

  elif state.sql_query_execution_status == "failure":  # Changed from state["sql_query_execution_status"]
    if state.sql_error_count < 2:  # Changed from state["sql_error_count"]
      return "write_query"

    else:
      return "cannot_answer"

In [107]:
def generate_answer(state: State):
    prompt = (
        "/no_think\n"
        "You are a business assistant responding to a manager's queriesthe Question. in short sentence\n"
        "Given the manager's question and the result of the internal SQL query used to retrieve the relevant data, answer the question clearly and professionally.\n"
        "Use a well-formatted table with clear headers **only if** the question requires structured data, such as a list of transactions, balances over time, or multiple entries.\n"
        "Otherwise, respond in plain text that reads naturally.\n"
        "Do not mention SQL queries, databases, or how the data was retrieved.\n"
        "Avoid phrases like 'Hello there!', 'I'm happy to help...', or anything overly formal or robotic.\n"
        "Give a direct, informative, human-like answer as if responding to a manager's internal query.\n\n"
        f"Manager's Question: {state.sql_question}\n"  # Changed from state['sql_question']
        f"Result: {state.sql_result}"  # Changed from state['sql_result']
    )
    llm = ChatGroq(model=GROQ_MODEL, groq_api_key=Config.groq_api_key)
    response = llm.invoke(prompt)
    state.sql_answer = response.content  # Changed from state["sql_answer"]

    return state

In [108]:
generate_answer({'sql_question': "How many city exist in the table", 'sql_query': "SELECT COUNT(DISTINCT city) AS city_count FROM sales_suppliers", 'sql_result': '[(27,)]'})

# Output: {'sql_answer': "There are 20 transactions for UserID 'U001'."}

AttributeError: 'dict' object has no attribute 'sql_question'

In [109]:
from langgraph.graph import StateGraph, END

sql_graph_builder = StateGraph(State)

sql_graph_builder.set_entry_point("write_query")
sql_graph_builder.add_node("write_query", write_query)
sql_graph_builder.add_node("execute_query", execute_query)
sql_graph_builder.add_node("generate_answer", generate_answer)
sql_graph_builder.add_node("cannot_answer", cannot_answer)

sql_graph_builder.add_edge("write_query", "execute_query")
sql_graph_builder.add_conditional_edges(
    "execute_query",
    sql_router,
    {
        "generate_answer": "generate_answer",
        "write_query": "write_query",
        "cannot_answer": "cannot_answer"
    }

)

sql_graph_builder.add_edge("generate_answer", END)
sql_graph_builder.add_edge("cannot_answer", END)

sql_graph_manager = sql_graph_builder.compile()

In [110]:
sql_graph_manager.invoke({"sql_question": "What is the total amount of transactions for UserID 'U001'?"})

# Output: 'sql_answer': "There are 20 transactions for UserID 'U001'."

{'sql_question': "What is the total amount of transactions for UserID 'U001'?",
 'sql_query': '',
 'sql_query_execution_status': 'failure',
 'sql_error_count': 0,
 'sql_query_error': "Error: (psycopg2.ProgrammingError) can't execute an empty query\n(Background on this error at: https://sqlalche.me/e/20/f405)",
 'sql_result': '',
 'sql_answer': "I'm sorry, but I cannot find the information you're looking for."}