In [1]:
import pandas as pd
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine

In [2]:
from app.config import Config

config = Config()
db_uri = config.DATABASE_URI()

# 2. สร้าง Engine เชื่อมต่อไปที่ Docker
engine = create_engine(db_uri)
db = SQLDatabase(engine)
print("✅ Connected to PostgreSQL successfully!")
print("Tables found:", db.get_usable_table_names())

✅ Connected to PostgreSQL successfully!
Tables found: ['media_customer_reviews', 'media_gold_reviews_chunked', 'sales_customers', 'sales_franchises', 'sales_suppliers', 'sales_transactions']


In [None]:
# print(db.get_table_info())


CREATE TABLE media_customer_reviews (
	new_id INTEGER NOT NULL, 
	franchise_id BIGINT, 
	review_date TEXT, 
	review TEXT, 
	CONSTRAINT media_customer_reviews_pkey PRIMARY KEY (new_id), 
	CONSTRAINT fk_media_customer_reviews_franchise_id FOREIGN KEY(franchise_id) REFERENCES sales_franchises (franchise_id)
)

/*
3 rows from media_customer_reviews table:
new_id	franchise_id	review_date	review
1	3000037	2024-05-20 17:24:06.591000+00:00	Title: A Delightful Cookie Experience at Bakehouse in Tenmonkan, Kagoshima

Bakehouse in Tenmonkan, 
2	3000017	2024-05-20 17:17:03.052000+00:00	"Sweet tooth heaven on East 6th Street! I'm obsessed with Bakehouse's Outback Oatmeal cookies - crun
3	3000007	2024-05-20 17:17:03.052000+00:00	**4.5/5 stars**

I stumbled upon Bakehouse in the charming Gion district of Kyoto and was thrilled t
*/


CREATE TABLE media_gold_reviews_chunked (
	franchise_id BIGINT, 
	review_date TEXT, 
	chunked_text TEXT, 
	chunk_id TEXT, 
	review_uri TEXT, 
	CONSTRAINT fk_media_gold_r

In [44]:
from pydantic import BaseModel, Field, ConfigDict
from langgraph.graph.message import add_messages
from typing import Literal, Optional, Annotated
from pydantic import BaseModel, Field
from enum import Enum

class QueryOutput(BaseModel):
    generated_sql_query: Annotated[str, Field(description="Syntactically valid SQL query.")]

class RouterOutput(BaseModel):
    selected_agent: Literal["SQL_AGENT", "PLOT_AGENT", "GENERAL_AGENT"] = Field(
        description="The agent to route the message to"
    )

class State(BaseModel):
    user_message: str = ""       
    sql_query: str = ""             
    sql_query_execution_status: Literal["success", "failure"] = "failure"
    sql_error_count: int = 0  
    sql_query_error: str = ""  
    sql_result: str = ""       
    sql_agent_answer: str = ""
    plot_agent_answer: str = ""
    need_visualise: bool = False
    selected_agent: Optional[str] = None 

In [45]:
from app.config import GROQ_MODEL,Config
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate

In [46]:
def router_agent(state: State):
    system_message = """
    You are a router agent. You are given a user's message and you need to determine which agent to route the message to. 
    If the user's message is about the data in the database, you should route the message to the SQL_AGENT.
    If the user want to know about the plot, visualisation, graphing of the data, you should route the message to the PLOT_AGENT.
    Otherwise, you will need to route the agent to the GENERAL_AGENT.

    Only use the following agents:
    {agent_list}
    """

    user_prompt = """
    User's message: {user_message}
    """

    router_prompt_template = ChatPromptTemplate(
        [("system", system_message), ("user", user_prompt)]
    )

    prompt = router_prompt_template.invoke(
        {
            "agent_list": ["SQL_AGENT", "PLOT_AGENT", "GENERAL_AGENT"],
            "user_message": state.user_message
        }
    )
    
    # Call LLM with structured output
    llm = ChatGroq(model=GROQ_MODEL, groq_api_key=Config.groq_api_key)
    structured_llm = llm.with_structured_output(RouterOutput)
    result = structured_llm.invoke(prompt)
    
    # Store the selected agent in state (you might want to add this field to State)
    state.selected_agent = result.selected_agent
    
    return state

In [51]:
state = State(**{
    "user_message": "What are the top 3 most popular products by total quantity sold?"
})
result_state = router_agent(state)
print(result_state.selected_agent)

SQL_AGENT


In [52]:
def write_query(state: State):
    
    system_message = """
    Given an input question, create a syntactically correct {dialect} query to run to help find the answer.
    Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
    Pay attention to use only the column names that you can see in the schema description. 
    Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
    
    Only use the following tables:
    {table_info}
    """

    user_prompt = """
    Question: {input}
    Use the following error information if there is any: {query_error}
    """

    query_prompt_template = ChatPromptTemplate(
        [("system", system_message), ("user", user_prompt)]
    )

    try:
        prompt = query_prompt_template.invoke(
            {
                "dialect": db.dialect,
                "table_info": db.get_table_info(),
                "input": state.user_message,
                "query_error": state.sql_query_error,
            }
        )

        llm = ChatGroq(model=GROQ_MODEL, groq_api_key=Config.groq_api_key)
        structured_llm = llm.with_structured_output(QueryOutput)
        result = structured_llm.invoke(prompt)
        state.sql_query = result.generated_sql_query
    except Exception as e:
        error_msg = str(e)
        if hasattr(e, 'response') and hasattr(e.response, 'body'):
            # Try to extract more detailed error message
            try:
                import json
                error_body = json.loads(e.response.body)
                if 'error' in error_body and 'failed_generation' in error_body.get('error', {}):
                    error_msg = error_body['error']['failed_generation']
            except:
                pass
        
        state.sql_query = ""  # No query generated
        state.sql_query_execution_status = "failure"
        state.sql_query_error = f"Error generating query: {error_msg}"
        state.sql_error_count = state.sql_error_count + 1
    
    return state

In [55]:
state = State(**{
    "user_message": "Visualize daily sales trends over time. Create a line chart showing total revenue per day and identify any patterns."
})
result_state = write_query(state)
print(result_state.sql_query, result_state.sql_query_error)
# Output: {'sql_query': "SELECT COUNT(TransactionID) FROM data WHERE UserID = 'U001';"}

SELECT TO_DATE(date_time, 'YYYY-MM-DD') AS sales_date, SUM(total_price) AS daily_revenue FROM sales_transactions GROUP BY sales_date ORDER BY sales_date; 


In [56]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

def execute_query(state: State):
    """Execute SQL query and set query_execution_status."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    result = execute_query_tool.invoke(state.sql_query)

    if isinstance(result, str) and result.startswith("Error:"):
        state.sql_query_execution_status = "failure"
        state.sql_query_error = result  # Changed from state["sql_query_error"]
        state.sql_error_count = state.sql_error_count + 1  # Changed from state.get() and state["sql_error_count"]

    else:
        state.sql_result = result  # Changed from state["sql_result"]
        state.sql_query_execution_status = "success"  # Changed from state["sql_query_execution_status"]
        state.sql_query_error = ""  # Changed from state["sql_query_error"]
        state.sql_error_count = 0   # Changed from state["sql_error_count"]
       
    return state

In [57]:
state = State(**{
    "user_message": "Visualize daily sales trends over time. Create a line chart showing total revenue per day and identify any patterns.",
    "sql_query": "SELECT TO_DATE(date_time, 'YYYY-MM-DD') AS sales_date, SUM(total_price) AS daily_revenue FROM sales_transactions GROUP BY sales_date ORDER BY sales_date;"
})
result_state = execute_query(state)
print(result_state.sql_result, result_state.sql_query_execution_status)

[(datetime.date(2024, 5, 1), 4128.0), (datetime.date(2024, 5, 2), 4074.0), (datetime.date(2024, 5, 3), 4278.0), (datetime.date(2024, 5, 4), 3822.0), (datetime.date(2024, 5, 5), 3945.0), (datetime.date(2024, 5, 6), 4500.0), (datetime.date(2024, 5, 7), 3894.0), (datetime.date(2024, 5, 8), 3921.0), (datetime.date(2024, 5, 9), 4320.0), (datetime.date(2024, 5, 10), 3729.0), (datetime.date(2024, 5, 11), 3747.0), (datetime.date(2024, 5, 12), 4398.0), (datetime.date(2024, 5, 13), 4044.0), (datetime.date(2024, 5, 14), 4221.0), (datetime.date(2024, 5, 15), 3804.0), (datetime.date(2024, 5, 16), 3714.0), (datetime.date(2024, 5, 17), 1932.0)] success


In [15]:
def cannot_answer(state: State):
  """Cannot answer question; return a default answer."""

  state.sql_error_count = 0  # Changed from state["sql_error_count"]
  state.sql_answer = "I'm sorry, but I cannot find the information you're looking for."  # Changed from state["sql_answer"]

  return state

In [16]:
def sql_router(state: State):
  """Routes to generate_answer, cannot_answer or write_query based on query_execution_status."""

  if state.sql_query_execution_status == "success":  # Changed from state["sql_query_execution_status"]
    return "generate_answer"

  elif state.sql_query_execution_status == "failure":  # Changed from state["sql_query_execution_status"]
    if state.sql_error_count < 2:  # Changed from state["sql_error_count"]
      return "write_query"

    else:
      return "cannot_answer"

In [17]:
def generate_answer(state: State):
    prompt = (
        "/no_think\n"
        "You are a business assistant responding to a manager's queriesthe Question. in short sentence\n"
        "Given the manager's question and the result of the internal SQL query used to retrieve the relevant data, answer the question clearly and professionally.\n"
        "Use a well-formatted table with clear headers **only if** the question requires structured data, such as a list of transactions, balances over time, or multiple entries.\n"
        "Otherwise, respond in plain text that reads naturally.\n"
        "Do not mention SQL queries, databases, or how the data was retrieved.\n"
        "Avoid phrases like 'Hello there!', 'I'm happy to help...', or anything overly formal or robotic.\n"
        "Give a direct, informative, human-like answer as if responding to a manager's internal query.\n\n"
        f"Manager's Question: {state.sql_question}\n"  # Changed from state['sql_question']
        f"Result: {state.sql_result}"  # Changed from state['sql_result']
    )
    llm = ChatGroq(model=GROQ_MODEL, groq_api_key=Config.groq_api_key)
    response = llm.invoke(prompt)
    state.sql_answer = response.content  # Changed from state["sql_answer"]

    return state

In [18]:
generate_answer({'sql_question': "How many city exist in the table", 'sql_query': "SELECT COUNT(DISTINCT city) AS city_count FROM sales_suppliers", 'sql_result': '[(27,)]'})

# Output: {'sql_answer': "There are 20 transactions for UserID 'U001'."}

AttributeError: 'dict' object has no attribute 'sql_question'

## plot agent

In [61]:
def plot_agent(state: State):
    system_message = """
        You are a data visualization expert and use your favourite graphing library Plotly only. Suppose, that
        the data is provided as {sql_result}. Follow the user's indications when creating the graph.
    """
    user_prompt = """
        User message: {user_message}
    """
    plot_agent_prompt_template = ChatPromptTemplate(
        [("system", system_message), ("user", user_prompt)]
    )

    prompt = plot_agent_prompt_template.invoke(
        {
            "sql_result": state.sql_result,
            "user_message": state.user_message,
        }
    )

    llm = ChatGroq(model=GROQ_MODEL, groq_api_key=Config.groq_api_key)
    response = llm.invoke(prompt)
    state.plot_agent_answer = response.content
    return state

In [62]:
def get_fig_from_code(code):
    local_variables = {}
    exec(code, {}, local_variables)
    return local_variables.get("fig")

In [64]:
import re
def create_plot(state: State):
    response = plot_agent(state)
    result_output = response.plot_agent_answer
    print(result_output)

    code_block_match = re.search(r"```(?:python)?(.*)```", result_output, re.DOTALL)
    print(code_block_match)

    # checkt if code block is found
    if code_block_match:
        code_block = code_block_match.group(1).strip()
        cleaned_code = re.sub(r"(?m)^\s*fig\.show\(\)\s*$", "", code_block)
        fig = get_fig_from_code(cleaned_code)
        return fig, result_output
    else:
        return "", result_output

In [68]:
import re
from IPython.display import display

def create_plot(state: State):
    response = plot_agent(state)
    result_output = response.plot_agent_answer
    print(result_output)

    code_block_match = re.search(r"```(?:python)?(.*)```", result_output, re.DOTALL)

    if code_block_match:
        code_block = code_block_match.group(1).strip()
        cleaned_code = re.sub(r"(?m)^\s*fig\.show\(\)\s*$", "", code_block)
        fig = get_fig_from_code(cleaned_code)
        
        if fig is not None:
            display(fig)  # This will display in Jupyter
            return fig, result_output
        else:
            print("Warning: Could not generate figure from code")
            return None, result_output
    else:
        print("No code block found in response")
        return None, result_output

In [69]:
state = State(**{
    "user_message": "Visualize daily sales trends over time. Create a line chart showing total revenue per day and identify any patterns.",
    "sql_result": "[(datetime.date(2024, 5, 1), 4128.0), (datetime.date(2024, 5, 2), 4074.0), (datetime.date(2024, 5, 3), 4278.0), (datetime.date(2024, 5, 4), 3822.0), (datetime.date(2024, 5, 5), 3945.0), (datetime.date(2024, 5, 6), 4500.0), (datetime.date(2024, 5, 7), 3894.0), (datetime.date(2024, 5, 8), 3921.0), (datetime.date(2024, 5, 9), 4320.0), (datetime.date(2024, 5, 10), 3729.0), (datetime.date(2024, 5, 11), 3747.0), (datetime.date(2024, 5, 12), 4398.0), (datetime.date(2024, 5, 13), 4044.0), (datetime.date(2024, 5, 14), 4221.0), (datetime.date(2024, 5, 15), 3804.0), (datetime.date(2024, 5, 16), 3714.0), (datetime.date(2024, 5, 17), 1932.0)] success"
})
result_state = create_plot(state)
print(result_state.plot_agent_answer)

<think>
Okay, the user wants me to visualize daily sales trends using a line chart. Let me start by understanding the data provided. The data is a list of tuples with dates and corresponding sales figures. The dates start from May 1, 2024, up to May 17, 2024. The sales numbers vary each day.

First, I need to convert the dates into a more readable format. Since the dates are in datetime.date objects, in Plotly I can pass them directly, but it's good practice to ensure they're in a format that Plotly can handle as dates. The values are all floats, which should be straightforward.

Next, I should structure the data into a DataFrame. That way, it's easier to manipulate and plot with Plotly. I'll create columns for 'Date' and 'Revenue'. Then, I'll sort the data by date to ensure the line chart is in chronological order. 

The user mentioned identifying patterns. Patterns in time series data can include trends, seasonality, or anomalies. Looking at the data, I'll plot the line chart and see

ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed

AttributeError: 'tuple' object has no attribute 'plot_agent_answer'

In [60]:
state = State(**{
    "user_message": "Visualize daily sales trends over time. Create a line chart showing total revenue per day and identify any patterns.",
    "sql_result": "[(datetime.date(2024, 5, 1), 4128.0), (datetime.date(2024, 5, 2), 4074.0), (datetime.date(2024, 5, 3), 4278.0), (datetime.date(2024, 5, 4), 3822.0), (datetime.date(2024, 5, 5), 3945.0), (datetime.date(2024, 5, 6), 4500.0), (datetime.date(2024, 5, 7), 3894.0), (datetime.date(2024, 5, 8), 3921.0), (datetime.date(2024, 5, 9), 4320.0), (datetime.date(2024, 5, 10), 3729.0), (datetime.date(2024, 5, 11), 3747.0), (datetime.date(2024, 5, 12), 4398.0), (datetime.date(2024, 5, 13), 4044.0), (datetime.date(2024, 5, 14), 4221.0), (datetime.date(2024, 5, 15), 3804.0), (datetime.date(2024, 5, 16), 3714.0), (datetime.date(2024, 5, 17), 1932.0)] success"
})
result_state = plot_agent(state)
print(result_state.plot_agent_answer)

<think>
Okay, let's start by understanding the user's request. They want a line chart showing daily sales trends over time, specifically total revenue per day, and to identify any patterns. 

First, I need to process the data they provided. The data is a list of tuples with dates and corresponding sales figures. The dates start from May 1st to May 17th, 2024. Each tuple has a date and a revenue value.

So, the first step is to extract the dates and the revenue values into separate lists. That way, I can plot the dates on the x-axis and the revenue on the y-axis. Since the dates are in Python's datetime.date format, Plotly should handle them correctly. 

Next, I need to create a line chart using Plotly. The layout should have a title, labeled axes, and a line representing the sales over time. The x-axis will be the dates, and the y-axis will be the total revenue. 

I should also consider the user's mention of identifying patterns. Looking at the data, I can see that there's a significan

In [19]:
from langgraph.graph import StateGraph, END

sql_graph_builder = StateGraph(State)

sql_graph_builder.set_entry_point("write_query")
sql_graph_builder.add_node("write_query", write_query)
sql_graph_builder.add_node("execute_query", execute_query)
sql_graph_builder.add_node("generate_answer", generate_answer)
sql_graph_builder.add_node("cannot_answer", cannot_answer)

sql_graph_builder.add_edge("write_query", "execute_query")
sql_graph_builder.add_conditional_edges(
    "execute_query",
    sql_router,
    {
        "generate_answer": "generate_answer",
        "write_query": "write_query",
        "cannot_answer": "cannot_answer"
    }

)

sql_graph_builder.add_edge("generate_answer", END)
sql_graph_builder.add_edge("cannot_answer", END)

sql_graph_manager = sql_graph_builder.compile()

In [26]:
sql_graph_manager.invoke({"sql_question": "Which franchises have received the most customer reviews? List the top 5 franchise  names with their review counts."})

# Output: 'sql_answer': "There are 20 transactions for UserID 'U001'."

{'sql_question': 'Which franchises have received the most customer reviews? List the top 5 franchise  names with their review counts.',
 'sql_query': 'SELECT sf.name, COUNT(*) AS review_count FROM media_customer_reviews mcr JOIN sales_franchises sf ON mcr.franchise_id = sf.franchise_id GROUP BY sf.franchise_id, sf.name ORDER BY review_count DESC LIMIT 5;',
 'sql_query_execution_status': 'success',
 'sql_error_count': 0,
 'sql_query_error': '',
 'sql_result': "[('Hiroshima Delicacies', 6), ('The Baking Lab', 6), ('Nagoya Nibbles', 6), ('Dough Dreamers', 6), ('Tokyo Treats', 6)]",
 'sql_answer': '<think>\n\n</think>\n\nHere are the top 5 franchises with the most customer reviews:\n\n| Franchise Name          | Review Count |\n|-------------------------|--------------|\n| Hiroshima Delicacies    | 6            |\n| The Baking Lab          | 6            |\n| Nagoya Nibbles          | 6            |\n| Dough Dreamers          | 6            |\n| Tokyo Treats            | 6            |'}