In [None]:
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine

# Open the local Chinook SQLite file and wrap the engine in LangChain's SQLDatabase helper.
# The path is relative, so it is resolved against the kernel's working directory.
engine = create_engine("sqlite:///Chinook.db")
db = SQLDatabase(engine)

In [None]:
from langchain.chat_models import init_chat_model

# Chat model shared by every later cell ("gpt-4o" through the "azure_openai" provider).
# NOTE(review): no credentials/endpoint are passed here; presumably they come from the
# environment — confirm before running on a fresh kernel.
llm = init_chat_model("gpt-4o", model_provider="azure_openai")

In [None]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Build the stock LangChain SQL toolkit on top of the database and model created above.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

# Print each tool's name and description, separated by a blank line, for inspection.
for tool in tools:
    print(f"{tool.name}: {tool.description}\n")

In [None]:
import uuid
from typing import Annotated, Any, List
from uuid import uuid4

import pandas as pd
from langgraph.graph.message import MessagesState
from pydantic import BaseModel, Field


# Base model for everything attached to a tree node as an artifact.
class Artifact(BaseModel):
    # String form of a freshly generated UUID4, created per instance when no id is given.
    id: str = Field(default_factory=lambda: str(uuid4()))


# Artifact holding a query and its result; `type` is fixed to "table" by default.
class TableArtifact(Artifact):
    type: str = "table"
    # `query` and `result` accept None but declare no default value.
    query: str | None
    result: Any | None


# Artifact holding a query, chart options and the script that produced them;
# `type` is fixed to "chart" by default.
class ChartArtifact(Artifact):
    type: str = "chart"
    # All three fields accept None but declare no default value.
    query: str | None
    options: Any | None
    vis_script: str | None


class TreeState(MessagesState):
    """Graph state that doubles as a node of the problem tree.

    The root state and every entry of ``child_list`` share this shape, so the
    tree can be walked recursively. Code in this notebook checks for keys with
    ``"key" in state`` before reading them, so not every key is always present.
    """

    # The unique ID
    id: uuid.UUID
    # The height of the current node
    height: int
    # The focused sub-problem
    focused_problem: str
    # The maximum initial height for expanding
    max_height: int | None
    # Reason of associating this focused_problem
    # reason_of_associating: Optional[str]
    # Original problem that triggered this problem
    original_problem: str | None
    # Child list
    child_list: List["TreeState"] | None
    # Response to this node as leaf
    respond: str | None
    # Visualization info about this node.
    # This is a plain union: `Annotated[TableArtifact, ChartArtifact]` would declare a
    # TableArtifact carrying ChartArtifact as metadata, not "either of the two".
    artifact: TableArtifact | ChartArtifact | None
    # Type of Visualization
    artifact_type: str | None
    # Node to be expanded
    expanded_id: uuid.UUID | None
    # Questions asked previously
    prev_questions: List[str] | None
    # Regenerate time record, 0 means no need to regenerate
    regenerate_times: int | None
    # Number of sub-problems to generate
    max_subproblems: int | None
    # Tables in the dataset
    data_tables: str | None
    # Schema of data tables
    data_schemas: dict | None
    # Number of first-level sub-problems
    first_level_subproblems: int | None

In [None]:
from typing import Any, Dict, Tuple, Type

from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, model_validator


class BaseSQLDatabaseTool(BaseModel):
    """Base tool for interacting with a SQL database."""

    # Database wrapper used by every concrete tool; `exclude=True` keeps it out of
    # the model's serialized output.
    db: SQLDatabase = Field(exclude=True)

    # SQLDatabase is not a pydantic type, so arbitrary field types must be allowed.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


# Argument schema for QuerySQLDataBaseTool: a single required SQL string.
class _QuerySQLDataBaseToolInput(BaseModel):
    query: str = Field(..., description="A detailed and correct SQL query.")


class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):  # type: ignore[override, override]
    """Tool for querying a SQL database."""

    name: str = "sql_db_query"
    description: str = """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    Empty result is regarded correct.
    """
    args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput

    def _run(
        self,
        query: str,
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> Tuple[str, Any]:
        """Execute the query, return the results or an error message.

        Args:
            query: SQL text to execute.
            run_manager: Callback manager supplied by the tool framework (unused).

        Returns:
            A ``(content, artifact)`` pair. ``content`` is the stringified result
            (column names included) or an ``"Error: ..."`` message; ``artifact`` is
            the raw result, or ``None`` when the query failed.
        """
        # run_no_throw turns database errors into an "Error: ..." string, which is
        # what the tool description promises; plain `run` would raise instead.
        result = self.db.run_no_throw(query, include_columns=True)
        if isinstance(result, str) and result.startswith("Error:"):
            # Do not expose an error message as if it were query data.
            return result, None
        return str(result), result


# Argument schema for InfoSQLDatabaseTool: one required string of comma-separated table names.
class _InfoSQLDatabaseToolInput(BaseModel):
    table_names: str = Field(
        ...,
        description=(
            "A comma-separated list of the table names for which to return the schema. "
            "Example input: 'table1, table2, table3'"
        ),
    )


class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):  # type: ignore[override, override]
    """Tool for getting metadata about a SQL database."""

    name: str = "sql_db_schema"
    description: str = "Get the schema and sample rows for the specified SQL tables."
    args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput

    def _run(
        self,
        table_names: str,
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> str:
        """Look up table info for every name in a comma-separated string.

        Each name is stripped of surrounding whitespace before the list is handed
        to the database wrapper's non-throwing table-info lookup.
        """
        requested_tables = []
        for raw_name in table_names.split(","):
            requested_tables.append(raw_name.strip())
        return self.db.get_table_info_no_throw(requested_tables)


# Argument schema for ListSQLDatabaseTool: a single optional string that defaults to "".
class _ListSQLDataBaseToolInput(BaseModel):
    tool_input: str = Field("", description="An empty string")


class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):  # type: ignore[override, override]
    """Tool for getting tables names."""

    name: str = "sql_db_list_tables"
    description: str = (
        "Input is an empty string, output is a comma-separated list of tables in the database."
    )
    args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput

    def _run(
        self,
        tool_input: str = "",
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> str:
        """Get a comma-separated list of table names.

        `tool_input` and `run_manager` are accepted but not read; the names come
        from the database wrapper and are joined with ", ".
        """
        return ", ".join(self.db.get_usable_table_names())


# Argument schema for QuerySQLCheckerTool: the SQL text to be reviewed.
class _QuerySQLCheckerToolInput(BaseModel):
    query: str = Field(..., description="A detailed and SQL query to be checked.")


class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):  # type: ignore[override, override]
    """Use an LLM to check if a query is correct.
    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""

    # NOTE(review): `template` is declared but never read in this class; the validator
    # below builds its prompt from the QUERY_CHECKER constant directly.
    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    # Filled in by `initialize_llm_chain` when the caller does not supply one.
    llm_chain: Any = Field(init=False)
    name: str = "sql_db_query_checker"
    description: str = """
    Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with sql_db_query!
    """
    args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput

    @model_validator(mode="before")
    @classmethod
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Any:
        """Create a default checker chain if none was passed, then validate its inputs.

        Runs before field validation on the raw constructor values.

        Raises:
            ValueError: if the chain's prompt input variables are not exactly
                ``["dialect", "query"]`` (in that order).
        """
        if "llm_chain" not in values:
            # Imported lazily so the chain module is only loaded when a default is needed.
            from langchain.chains.llm import LLMChain

            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),  # type: ignore[arg-type]
                prompt=PromptTemplate(template=QUERY_CHECKER, input_variables=["dialect", "query"]),
            )

        # The comparison is order-sensitive: a list equal to ["dialect", "query"] is required.
        # NOTE(review): the message below lists the names in the opposite order.
        if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> str:
        """Use the LLM to check the query."""
        return self.llm_chain.predict(
            query=query,
            dialect=self.db.dialect,
            # Forward child callbacks only when the framework supplied a run manager.
            callbacks=run_manager.get_child() if run_manager else None,
        )

    async def _arun(
        self,
        query: str,
        run_manager: AsyncCallbackManagerForToolRun | None = None,
    ) -> str:
        """Async counterpart of `_run`; awaits the chain's `apredict`."""
        return await self.llm_chain.apredict(
            query=query,
            dialect=self.db.dialect,
            callbacks=run_manager.get_child() if run_manager else None,
        )

In [None]:
from typing import List

from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.caches import BaseCache as BaseCache
from langchain_core.callbacks import Callbacks as Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field


class SQLDatabaseToolkit_query(BaseToolkit):
    """Toolkit exposing the query, schema and query-checker tools."""

    db: SQLDatabase = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)

    @property
    def dialect(self) -> str:
        """Return string representation of SQL dialect to use."""
        return self.db.dialect

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit.

        Returns:
            ``[query tool, schema tool, query-checker tool]``. The query tool is
            created with ``response_format="content_and_artifact"`` so the raw
            result is available as the tool message artifact.
        """
        # Adjacent string literals are concatenated verbatim, so each fragment that is
        # followed by another sentence must end with a space.
        query_sql_database_tool_description = (
            "Input to this tool is a detailed and correct SQL query, output is a "
            "result from the database. If the query is not correct, an error message "
            "will be returned, rewrite the query, check the query, and try again. "
            "VERY IMPORTANT! LIMIT your query results at most 20."
        )
        query_sql_database_tool = QuerySQLDataBaseTool(
            db=self.db,
            description=query_sql_database_tool_description,
            response_format="content_and_artifact",
        )
        query_sql_checker_tool_description = (
            "Use this tool to double check if your query is correct before executing "
            f"it. Always use this tool before executing a query with {query_sql_database_tool.name}! "
            "VERY IMPORTANT! LIMIT your query results at most 20."
        )
        query_sql_checker_tool = QuerySQLCheckerTool(
            db=self.db, llm=self.llm, description=query_sql_checker_tool_description
        )
        info_sql_database_tool_description = (
            "Input to this tool is a comma-separated list of tables, output is the "
            "schema and sample rows for those tables. "
            "Example Input: table1, table2, table3"
        )
        info_sql_database_tool = InfoSQLDatabaseTool(
            db=self.db, description=info_sql_database_tool_description
        )
        return [
            query_sql_database_tool,
            info_sql_database_tool,
            query_sql_checker_tool,
        ]

    def get_context(self) -> dict:
        """Return db context that you may want in agent prompt."""
        return self.db.get_context()


# Toolkit that exposes only the schema-inspection tool.
class SQLDatabaseToolkit_schema(BaseToolkit):
    db: SQLDatabase = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @property
    def dialect(self) -> str:
        """SQL dialect name reported by the wrapped database."""
        return self.db.dialect

    def get_tools(self) -> List[BaseTool]:
        """Build and return the single schema tool offered by this toolkit."""
        schema_tool = InfoSQLDatabaseTool(
            db=self.db,
            description=(
                "Input to this tool is a comma-separated list of tables, output is the "
                "schema and sample rows for those tables. "
                "Example Input: table1, table2, table3"
            ),
        )
        return [schema_tool]

    def get_context(self) -> dict:
        """Database context that may be useful inside an agent prompt."""
        return self.db.get_context()


# Toolkit that exposes only the table-listing tool.
class SQLDatabaseToolkit_list(BaseToolkit):
    db: SQLDatabase = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @property
    def dialect(self) -> str:
        """SQL dialect name reported by the wrapped database."""
        return self.db.dialect

    def get_tools(self) -> List[BaseTool]:
        """Build and return the single table-listing tool offered by this toolkit."""
        listing_tool = ListSQLDatabaseTool(
            db=self.db,
            description="Use this tool to list the table names in the dataset. ",
        )
        return [listing_tool]

    def get_context(self) -> dict:
        """Database context that may be useful inside an agent prompt."""
        return self.db.get_context()


# Rebuild the pydantic models of the three toolkits defined above.
# NOTE(review): presumably needed so the field types imported in this cell are
# resolved — confirm whether these calls are still required.
SQLDatabaseToolkit_query.model_rebuild()
SQLDatabaseToolkit_schema.model_rebuild()
SQLDatabaseToolkit_list.model_rebuild()

In [None]:
# One tool list per graph stage. All three toolkits share the `db` wrapper created in
# the first cell instead of constructing a new SQLDatabase around the same engine each time.
tools_span = SQLDatabaseToolkit_schema(db=db, llm=llm).get_tools()  # schema lookup while splitting problems
tools_init = SQLDatabaseToolkit_list(db=db, llm=llm).get_tools()  # table listing during initialisation
tools_execute = SQLDatabaseToolkit_query(db=db, llm=llm).get_tools()  # query / schema / checker tools

In [None]:
from enum import Enum

from langchain_core.messages import HumanMessage, RemoveMessage, ToolMessage
from langgraph.graph import END

# First human message sent by `traverse_tree` when no table names are known yet.
init_prompt = """
Return the list of all the table names.
Do not contain any other words. Your response should exactly be of the form:
table1, table2, table3, table4
"""


class SearchingStrategy(Enum):
    """Order in which `traverse` visits the nodes of the problem tree."""

    BREADTH_FIRST = 0
    DEPTH_FIRST = 1


def traverse(root: TreeState, searching_strategy=SearchingStrategy.BREADTH_FIRST):
    """Yield every node of the tree rooted at ``root`` in the requested order.

    Args:
        root: Root node; each node may carry its children under ``"child_list"``.
        searching_strategy: ``SearchingStrategy.BREADTH_FIRST`` (default) or
            ``SearchingStrategy.DEPTH_FIRST`` (pre-order, children left to right).

    Yields:
        The node dictionaries themselves (not copies). A node's children are
        queued before the node is yielded, so children a consumer appends to
        the yielded node afterwards are not visited by this traversal.

    Raises:
        NotImplementedError: for any other strategy, raised on first iteration.
    """
    pending = [root]
    if SearchingStrategy.BREADTH_FIRST == searching_strategy:
        # BFS: `pending` is used as a queue.
        while pending:
            node = pending.pop(0)
            # `child_list` may be absent or None (it is typed `| None`); both mean "no children".
            pending.extend(node.get("child_list") or [])
            yield node
    elif SearchingStrategy.DEPTH_FIRST == searching_strategy:
        # DFS: `pending` is used as a stack; children are pushed reversed so the
        # left-most child is popped first.
        while pending:
            node = pending.pop()
            pending.extend(reversed(node.get("child_list") or []))
            yield node
    else:
        raise NotImplementedError


def traverse_tree(state: TreeState) -> TreeState:
    """Traverse the tree to decide the next action to take.

    While ``data_tables`` is unknown, the model is asked (with the table-listing
    tool bound) for the table names, and the function returns the mutated state
    with ``next`` set to ``"tools_init"`` or ``"tree_traverser"``. Afterwards it
    walks the tree breadth-first and returns only ``{"next": ...}`` with
    ``"execute"``, ``"span"`` or ``END``.
    """
    # get table names first
    if "data_tables" not in state or state["data_tables"] is None:
        if "messages" not in state or state["messages"] == []:
            state["messages"] = [HumanMessage(init_prompt)]
        llm_response = llm.bind_tools(tools_init).invoke(state["messages"])
        # The response is appended in place, so index -2 below refers to the message
        # that preceded this model call.
        state["messages"].append(llm_response)
        if (
            len(state["messages"]) > 2
            and isinstance(state["messages"][-2], ToolMessage)
            and state["messages"][-2].name == "sql_db_list_tables"
        ):
            # The table-listing tool has answered: keep its output and replace the
            # message list with removal markers for every message.
            state["data_tables"] = state["messages"][-2].content
            state["messages"] = [RemoveMessage(id=m.id) for m in state["messages"]]
            state["next"] = "tree_traverser"
            return state
        # NOTE(review): "next" is not declared on TreeState — confirm the graph's
        # state schema actually carries it.
        state["next"] = "tools_init"
        return state

    for node in traverse(state):
        # `and` binds tighter than `or`, so this reads:
        #   (node has no response yet) or (node has children and the last child
        #   has no "respond" key).
        if (
            ("respond" not in node or node["respond"] is None)
            or "child_list" in node
            and node["child_list"] != []
            and "respond" not in node["child_list"][-1]
        ):
            return {"next": "execute"}

        if node["focused_problem"]:
            # Expand a childless node that is still above the height limit, or the
            # node explicitly requested through `expanded_id`.
            if (
                ("child_list" not in node or node["child_list"] == [])
                and node["height"] < state["max_height"]
            ) or ("expanded_id" in state and state["expanded_id"] == node["id"]):
                return {"next": "span"}

    # Nothing left to answer or expand.
    return {"next": END}

In [None]:
import uuid
from typing import List

from langchain_core.messages import AIMessage, SystemMessage

# System prompt for splitting a problem into sub-problems. It has one `str.format`
# placeholder, {max_subproblems}, and asks for a "Final Answer:" block of
# "Sub-problem n: ..." lines, which `output_parser_division` later splits on.
prefix_for_analysis_planning = """
You are an agent designed to interact with a SQL database and act as an excellent data scientist.
Given an input problem, you should try to divide it into up to {max_subproblems} sub-problems.
Each sub-problem should correspond to a single section in the report.
For each sub-problem, you must just state the problem directly without any other information.

Here are some hints for you:
- Read the SQL schema, and think about what properties can be analyzed.
- Check a column and consider how does a property change over time.
- Are there any relationship between two columns?
- Are there any noticeable value or trend in given data?

You should not mention words like SQL table schema or SQL clauses, or output with json-like formats, or table column names etc, just give your sub-problems.
Ensure the sub-problems are well-aligned with the input problem and data schema.
Your sub-problems should be meaningful for analysis.
Information like code, hot value and ID are meaningless. Information like number, price and quantity are meaningful.

Upon listing the sub-problems, you should start with "Final Answer:", and then state the n-th sub-problem by starting with 'Sub-problem n:', followed by
the sub-problem. You should only use this format in your final answer:
Final Answer:
Sub-problem 1: ...
Sub-problem 2: ...
"""

# Extra system prompt that lists previously generated problems so new sub-problems
# differ from them. It has one `str.format` placeholder, {previous_sub_problems};
# `divide_problem` fills it with the "**"-prefixed entries of `prev_questions`.
suffix_for_exclusion = """
Previous problem list (marked with **):
{previous_sub_problems}

VERY IMPORTANT! Please follow these instructions:

1. The user's input `problem` might be one of the existing problems in the previous problem list. Treat it as a valid input.
2. Generate DISTINCT sub-problems based on the input `problem`. Sub-problems MUST NOT be identical or overly similar to any problem in the previous problem list.
3. Focus on creating NEW sub-problems related to the input, even if the input matches a problem in the previous problem list.

Examples:
1. Input: Analyze the effects of climate change. (from the list)
   - Existing problems:
     - **Analyze the effects of climate change.
     - **Study the causes of global warming.
   - Output (Correct):
    Final Answer:
    Sub-problem 1: Investigate strategies to mitigate climate change.
    Sub-problem 2: Examine the economic impact of climate adaptation.

2. Input: Study the impact of renewable energy adoption. (not from the list)
   - Existing problems:
     - **Analyze barriers to renewable energy adoption.
   - Output (Correct):
    Final Answer:
    Sub-problem 1: Evaluate the role of policy in renewable energy growth.
    Sub-problem 2: Assess the social acceptance of renewable technologies.
"""

# System prompt restricting schema lookups to known tables. It has one `str.format`
# placeholder, {table_names}; `divide_problem` fills it with `state["data_tables"]`.
table_name_info = """
When interacting with SQL tools to fetch schema details, make sure to only consider the following tables:
{table_names}
"""


def output_parser_division(ori_problem: str, parent_height: int, message: str) -> List["TreeState"]:
    """Turn a "Sub-problem n: ..." listing into child tree nodes.

    Args:
        ori_problem: Problem of the parent node, stored as ``original_problem``.
        parent_height: Height of the parent; children get ``parent_height + 1``.
        message: Model output containing zero or more "Sub-problem n: text" items.

    Returns:
        One new node per item, in order, each with a fresh UUID and an empty
        ``child_list``. An empty list when ``message`` contains no item.
    """
    children = []
    for chunk in message.split("Sub-problem ")[1:]:
        # Drop the "n:" index whatever its width; a fixed 3-character slice breaks
        # from item 10 onwards.
        index, colon, text = chunk.partition(":")
        if not (colon and index.strip().isdigit()):
            # Unexpected shape: fall back to skipping the first three characters.
            text = chunk[3:]
        children.append(
            TreeState(
                id=uuid.uuid4(),
                # Strip the newline that separates one item from the next.
                focused_problem=text.strip(),
                child_list=[],
                original_problem=ori_problem,
                height=parent_height + 1,
            )
        )
    return children


def _find_node_to_span(state):
    """Return the first node (breadth-first) that should be expanded, or None."""
    for node in traverse(state):
        if (
            ("child_list" not in node or node["child_list"] == [])
            and node["height"] < state["max_height"]
        ) or ("expanded_id" in state and state["expanded_id"] == node["id"]):
            return node
    return None


def divide_problem(state: TreeState) -> TreeState:
    """Generate the subproblems given the original problem.

    The function is entered repeatedly for one expansion:

    1. With no messages, it returns a human message holding the focused problem
       of the node to expand.
    2. Otherwise it calls the model with the schema tool bound. If the model
       requests tool calls, that response is returned so the tools can run.
    3. Once the model answers without tool calls, the answer is parsed into child
       nodes that are appended to the expanded node; the new problems are
       recorded in ``prev_questions`` and the message list is replaced with
       removal markers for every message.

    Raises:
        ValueError: if called with no messages while no node is eligible for
            expansion.
    """
    if "messages" not in state or state["messages"] == []:
        focused_node = _find_node_to_span(state)
        if focused_node is None:
            raise ValueError("divide_problem: no node in the tree is eligible for expansion")
        focused_problem = "\n" + focused_node["focused_problem"] + "\n"
        return {"messages": [HumanMessage(focused_problem)]}

    max_subproblems = 2  # default
    # Both settings are typed `| None`; a None value must not override the default.
    if state.get("max_subproblems") is not None:
        max_subproblems = state["max_subproblems"]
    if state.get("first_level_subproblems") is not None and (
        "child_list" not in state or state["child_list"] == []
    ):  # root span
        max_subproblems = state["first_level_subproblems"]

    _prefix_for_analysis_planning = prefix_for_analysis_planning.format(
        max_subproblems=max_subproblems
    )
    system_message = SystemMessage(content=_prefix_for_analysis_planning)
    messages = [system_message]
    if len(state["messages"]) > 0 and (
        isinstance(state["messages"][-1], HumanMessage) or "Error:" in state["messages"][-1].content
    ):
        # Only list earlier problems when there are some (missing, None and [] all skip).
        if state.get("prev_questions"):
            _suffix_for_exclusion = suffix_for_exclusion.format(
                previous_sub_problems="\n".join(state["prev_questions"])
            )
            messages.append(SystemMessage(content=_suffix_for_exclusion))
        messages.append(SystemMessage(table_name_info.format(table_names=state["data_tables"])))

    if len(state["messages"]) > 0:
        messages.append(state["messages"][0])
    if len(state["messages"]) > 1:
        # append last tool msg: everything from the most recent AI message to the end
        prev_tool_msg = []
        for m in state["messages"][::-1]:
            prev_tool_msg.append(m)
            if isinstance(m, AIMessage):
                break
        messages += prev_tool_msg[::-1]
    llm_response = llm.bind_tools(tools_span).invoke(messages)

    if hasattr(llm_response, "tool_calls") and len(llm_response.tool_calls) > 0:
        return {"messages": [llm_response]}

    node = _find_node_to_span(state)
    if node is not None:
        # The selection rule accepts nodes without a usable child list, and the
        # question history may not have been initialised yet.
        if node.get("child_list") is None:
            node["child_list"] = []
        if state.get("prev_questions") is None:
            state["prev_questions"] = []
        for child in output_parser_division(
            node["focused_problem"], node["height"], llm_response.content
        ):
            node["child_list"].append(child)
            state["prev_questions"].append("**" + child["focused_problem"])
    state["messages"] = [RemoveMessage(id=m.id) for m in state["messages"]]
    return state

In [None]:
import ast
import json
import re
import traceback
from contextlib import redirect_stdout
from io import StringIO
from json import JSONDecodeError
from typing import Any, Dict


def sanitize_input(query: str) -> str:
    """Clean up a code snippet before it is handed to the Python REPL.

    Leading whitespace and backticks are dropped, followed by an optional
    case-insensitive ``python`` tag and any whitespace after it; trailing
    whitespace and backticks are dropped as well.

    Args:
        query: The raw snippet, possibly wrapped in a Markdown code fence.

    Returns:
        str: The snippet without the surrounding fence characters.
    """
    leading_noise = r"^(\s|`)*(?i:python)?\s*"
    trailing_noise = r"(\s|`)*$"

    without_prefix = re.sub(leading_noise, "", query)
    return re.sub(trailing_noise, "", without_prefix)


class Echarts:
    """Runs a generated Python script that builds an ECharts options dict.

    The script sees two names: ``result`` (the evaluated query result passed to
    the constructor) and ``pd`` (pandas).
    """

    def __init__(self, result: Any | None):
        # Kept as given; it is evaluated with eval() each time `run` is called.
        self._result = result

    def run(self, script: str) -> Dict | str:
        """
        Use to run a python script for generating an Echarts options dict
        and return the dict.

        Args:
            script: Python source whose last statement must be an expression
                that evaluates to the options dict.

        Returns:
            The options dict on success, otherwise a "<ExceptionType>: <message>"
            string describing what went wrong inside the script handling.
        """
        # NOTE(review): eval() executes whatever text `_result` holds. This line also
        # sits outside the try block, so a failure here propagates to the caller
        # instead of being returned as a string.
        locals = {"result": eval(self._result), "pd": pd}
        try:
            query = sanitize_input(script)
            tree = ast.parse(query)
            # Split the script into "everything but the last statement" and "the last statement".
            module_end = ast.Module(tree.body[-1:], type_ignores=[])
            if not module_end.body or not isinstance(module_end.body[0], ast.Expr):
                raise self.ResultNotFoundException()
            module = ast.Module(tree.body[:-1], type_ignores=[])
            # Run the leading statements with `locals` as the script's global namespace.
            exec(ast.unparse(module), locals)  # type: ignore
            module_end_str = ast.unparse(module_end)  # type: ignore
            io_buffer = StringIO()
            has_json_decode_err = False
            try:
                json.loads(module_end_str)
            except JSONDecodeError:
                has_json_decode_err = True
            if not has_json_decode_err:
                # If the generated code uses the variables above, parsing its final
                # expression as JSON is bound to fail. If parsing succeeds, the script
                # wrote the data into the dict directly instead of loading it.
                raise self.DataframeNotUsedException()
            try:
                with redirect_stdout(io_buffer):
                    ret = eval(module_end_str, locals)
                    if ret is None:
                        ret = io_buffer.getvalue()
                    assert isinstance(ret, Dict)
                    return ret
            except Exception:
                # Fallback: execute the last statement and use whatever it printed.
                # NOTE(review): the captured output is a str, so the assertion below
                # cannot pass; this path always ends in the outer handler.
                with redirect_stdout(io_buffer):
                    exec(module_end_str, locals)
                ret = io_buffer.getvalue()
                assert isinstance(ret, Dict)
                return ret
        except Exception as e:
            # Report the failure as text so the caller can tell it from a dict.
            traceback.print_exc()
            return f"{type(e).__name__}: {str(e)}"

    class ResultNotFoundException(Exception):
        """Raised when the script's last statement is not an expression."""

        def __init__(self):
            super().__init__(
                "Please write the variable name of the echarts options dict in the last line."
            )

    class DataframeNotUsedException(Exception):
        """Raised when the script's last expression is itself valid JSON."""

        def __init__(self):
            super().__init__(
                "Please write python script to load data into the echarts options dict "
                "rather than write data directly to the dict."
            )


def show_table(tool_call_msg: AIMessage, tool_msg: ToolMessage):
    """Build a TableArtifact from a tool result and the call that produced it.

    Args:
        tool_call_msg: AI message whose ``additional_kwargs["tool_calls"]`` holds
            the raw tool calls.
        tool_msg: Tool message whose ``artifact`` is the query result.

    Returns:
        A TableArtifact with the tool result and the SQL text of the matching
        tool call, or ``query=None`` when that text cannot be recovered.
    """
    tool_call_id = tool_msg.tool_call_id
    # A message without raw tool calls simply yields no query text.
    tool_calls = tool_call_msg.additional_kwargs.get("tool_calls") or []
    tool_call_content = next(
        (tool_call for tool_call in tool_calls if tool_call["id"] == tool_call_id), None
    )

    try:
        query = json.loads(tool_call_content["function"]["arguments"])["query"]
    except Exception:
        # Best effort only: covers a missing call, malformed JSON or a missing key.
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
        traceback.print_exc()
        query = None
    return TableArtifact(query=query, result=tool_msg.artifact)


def show_charts(script: str, tool_call_msg: AIMessage, tool_msg: ToolMessage):
    """Run a chart script against a tool result and wrap the outcome.

    Args:
        script: Python source expected to produce an ECharts options dict.
        tool_call_msg: AI message whose ``additional_kwargs["tool_calls"]`` holds
            the raw tool calls.
        tool_msg: Tool message whose ``artifact`` is the query result.

    Returns:
        - a ChartArtifact when the script produced an options dict;
        - ``(error_text, None)`` when the script runner reported an error string;
        - ``None`` when anything else failed.

        NOTE(review): callers store this value as the node's artifact, so they
        must cope with all three shapes — confirm this is intended.
    """
    tool_call_id = tool_msg.tool_call_id
    # A message without raw tool calls simply yields no query text.
    tool_calls = tool_call_msg.additional_kwargs.get("tool_calls") or []
    tool_call_content = next(
        (tool_call for tool_call in tool_calls if tool_call["id"] == tool_call_id), None
    )
    try:
        options_dict = Echarts(tool_msg.artifact).run(script)
        if isinstance(options_dict, str):
            return options_dict, None
        try:
            query = json.loads(tool_call_content["function"]["arguments"])["query"]
        except Exception:
            # Best effort only: the chart is still usable without its query text.
            traceback.print_exc()
            query = None
        return ChartArtifact(
            query=query,
            options=options_dict,
            vis_script=script,
        )
    except Exception:
        # Still best effort, but leave a trace instead of failing silently, and let
        # KeyboardInterrupt/SystemExit propagate.
        traceback.print_exc()
        return None

In [None]:
# Role-setting prompt prefix: the agent writes report sections for the given problem(s).
highlevel_prefix = """You are an excellent data scientist agent for business intelligence.
You are responsible for writing section(s) of a report related to the given problem(s).
"""


# Prompt for the query stage: tells the model to look up schemas, check and run SQL
# through the tools named 'sql_db_schema', 'sql_db_query_checker' and 'sql_db_query',
# and to answer with tool calls only. It contains no format placeholders.
sql_agent_prefix = """Given the input problem(s), create syntactically correct SQL queries to run, then look at the results of the query
and return the answers.
You can order the query results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the problem.
You have access to tools for interacting with the database. Please make full use of them.
Pay attention to use functions to get current date, current month or current year, if the problem does not explicitly provide a date.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
For the input problem(s), look at the tables in the database to see what you can query.
and use 'sql_db_schema' to retrieve the schema of the most relevant tables first.
Query for meaningful table columns. Information like code, hot value and ID are meaningless. Information like price, quantity and so on are meaningful.
If get schema successfully, call 'sql_db_query_checker' to write and validate query that strongly related to the problem and schema.
If 'sql_db_query_checker' returns right queries, call 'sql_db_query' to get the data result.

VERY IMPORTANT! If a table contains both ID and name columns,
you MUST query for the name rather than ID because no one understands meanings of IDs.
VERY IMPORTANT! If the input problem asks you to predict something
(e.g., an input contains word `will`), you should answer it based on your analysis but do not try to query it
directly because the data user wants does not actually exist.
VERY IMPORTANT! Unless the user specifies the exact number of examples they want, always limit your response to a maximum of 5 results.
Even if the user requests a specific number of examples, do not provide more than 20 results under any circumstances.

In this current stage, do not write anything, always call tools. Do not give the final conclusion.
"""


# Prompt for the answer stage. It fixes the reply layout that `output_parser` splits on:
# "**Section n: ...", "**Vis n: table|chart|none" and "**Tool n: x", where x is the
# 1-based position of the ToolMessage backing the section and 0 means "none".
# It contains no format placeholders.
answer_format_info = """
Now you should not call Tools anymore. You need to use previous query result to give section content.
DO NOT mention words like SQL table schema or SQL clauses, or output with json-like formats, etc, just give your section(s).
Your answer should solely based on actual data of SQL results without general information.
Your answer should be a meaningful analysis.
Information like code, hot value and ID are meaningless. Information like number, price and quantity are meaningful.

For each section, give suitable visualization suggestion according to the problem and query result.
Only give suggestion, do not really visualize the data.
You have 3 choices for visualization: 'table', 'chart', 'none', all in lowercase. Write your choice after '**Vis n: '.
If not specified, use a table when precise values and detailed comparisons are needed.
If not specified, use a chart when highlighting trends, relationships, or making comparison.
Use none if get no query result or Error.
Counting from 1, for N sections, suppose retrieved X ToolMessages, you should indicate x-th ToolMessage corresponding to n-th section, write: '**Tool n: x'.
If there are multiple ToolMessages correspond to one section, only save the first non-empty ToolMessage if exists.

If retrieved query result data for some problem is empty, error or meaningless, do not write any section content. Leave that section empty like:
**Section n:
**Vis n: none
**Tool n: 0

***VERY IMPORTANT!!! The total number of section(s) should strictly be the same as the total number of the input problem(s).
***VERY IMPORTANT!!! Start your answer by 'Final Answer:'. Your final response MUST strictly follow this format:
Final Answer:
**Section 1: ...
**Vis 1: ...
**Tool 1: ...
**Section 2: ...
**Vis 2: ...
**Tool 2: ...
...
"""


# System prompt asking the model for a Python script that builds an ECharts options
# dict from the variable `result` and names that dict on the last line.
# `output_parser` passes it to the model unformatted, so the literal braces are intended.
# NOTE(review): the example dict below lists the 'tooltip' key twice.
script_writing_prompt = """
You are an outstanding python developer who understands data well with a good command of Apache Echarts.
The result of last query has been stored in variable `result`,
you should write Python script to load data from variable `result` to
an echarts options dict (not json string) to display these data.
IMPORTANT! Do not write constants in your code directly.
You must use your code to load value from variable `result`.

Please select an appropriate chart type.
You should always add appropriate tooltips for your chart.
You should output your code only.

Output Instruction:
1. Remember to import necessary packages in your code if needed.
For example:
from decimal import Decimal

2. Do NOT import pyecharts!!!

3. You should write the variable name of the echarts options dict in the last line.
You should output your echarts options as dict. The echarts option dict of your output should be of the format like:
option = {
    'tooltip': {
        'trigger': ...,
        'axisPointer': {
            'type': ...,
            ...
        },
        ...
    },
    'legend': {
        'data': ...,
        ...
    },
    'xAxis': {
        'type': ...,
        'data': ...,
        ...
    },
    'yAxis': {
        'type': ...,
        ...
    },
    'series': [
        {
            'name': ...,
            'type': ...,
            'data': ...,
            ...
        },
        ...
    ],
    'tooltip': {
        'trigger': ...,
        'formatter': ...,
        ...
    },
    ...
}

option
"""

# System prompt restricting schema lookups to known tables; {table_names} is a
# `str.format` placeholder.
# NOTE(review): an identical `table_name_info` is already assigned in an earlier cell;
# this assignment rebinds the name to the same text.
table_name_info = """
When interacting with SQL tools to fetch schema details, make sure to only consider the following tables:
{table_names}
"""


def _strip_numeric_label(text):
    """Remove a leading ``<number>:`` label, e.g. ``"12: foo"`` -> ``" foo"``."""
    label, sep, rest = text.partition(":")
    if sep and label.strip().isdigit():
        return rest
    # No recognizable label: keep the legacy behaviour of skipping three characters.
    return text[3:]


def _parse_tool_order(text):
    """Return the integer that starts ``text``, or ``None`` when there is none."""
    tokens = text.split()
    if not tokens:
        return None
    try:
        return int(tokens[0])
    except ValueError:
        return None


def output_parser(focused_node, last_tool_msg, llm_response):
    """Copy the sections of an LLM answer onto the unanswered children of a node.

    The answer is expected to contain one ``**Section n: ... **Vis n: ...
    **Tool n: ...`` group per problem. Sections are assigned in order, starting
    at the index equal to the number of children that already carry a
    ``"respond"`` key.

    Args:
        focused_node: Mapping with a ``"child_list"`` of child mappings; the
            children are updated in place (``"respond"``, ``"artifact_type"``
            and, when possible, ``"artifact"``).
        last_tool_msg: Sequence of messages; element 0 is handed to
            ``show_table`` / ``show_charts`` as the tool-call message and the
            element at the parsed tool index as the tool message.
        llm_response: Object whose ``content`` attribute holds the answer text.

    Behaviour on imperfect answers:
        * Sections beyond the number of available children are ignored.
        * The artifact is attached only when the tool index is an integer in
          ``1 .. len(last_tool_msg) - 1``; index 0 (the tool-call message
          itself), negative, missing or non-numeric values are skipped.
        * The paragraph is always recorded, even when ``**Vis`` / ``**Tool``
          markers are missing.
    """
    children = focused_node["child_list"]
    nth_node = sum(1 for child in children if "respond" in child)

    for content in llm_response.content.split("**Section ")[1:]:
        if nth_node >= len(children):
            # The model produced more sections than there are problems left.
            break

        before_vis, _, after_vis = content.partition("**Vis ")
        _, _, after_tool = content.partition("**Tool ")

        paragraph = _strip_numeric_label(before_vis).strip()
        vis_choice = _strip_numeric_label(after_vis.partition("**Tool ")[0])
        tool_order = _parse_tool_order(_strip_numeric_label(after_tool))

        # Index 0 is the tool-call message, so valid tool results start at 1.
        tool_available = tool_order is not None and 1 <= tool_order <= len(last_tool_msg) - 1

        child = children[nth_node]
        child["respond"] = paragraph
        if "table" in vis_choice:
            child["artifact_type"] = "table"
            if tool_available:
                child["artifact"] = show_table(last_tool_msg[0], last_tool_msg[tool_order])
        elif "chart" in vis_choice:
            child["artifact_type"] = "chart"
            if tool_available:
                # Ask the model for the chart-building script, then render it.
                vis_script = llm.invoke(
                    [SystemMessage(script_writing_prompt)] + last_tool_msg
                ).content
                child["artifact"] = show_charts(
                    vis_script, last_tool_msg[0], last_tool_msg[tool_order]
                )
        nth_node += 1


def solve_problem(state: TreeState) -> TreeState:
    """Graph node that answers the pending problem(s) with the execute tool set.

    The function behaves differently depending on ``state["messages"]``:

    1. No messages yet: build a single ``HumanMessage`` listing the problems to
       solve and return it. For a state without ``"respond"`` this is the
       state's own ``focused_problem``; otherwise it is every child without
       ``"respond"`` of the first node (in ``traverse`` order) whose last child
       is still unanswered.
    2. Messages present: collect the trailing messages back to (and including)
       the most recent ``AIMessage``, call the LLM, and
       - return the LLM message if it requests tool calls, or
       - parse the ``**Section / **Vis / **Tool`` answer, store the result on
         the state (or on the pending children via ``output_parser``), reset
         ``expanded_id`` and schedule every message for removal.

    Raises:
        ValueError: if the answer for a state without ``"respond"`` lacks any
            of the ``**Section``, ``**Vis`` or ``**Tool`` markers (``int()``
            may also raise it for a non-numeric tool index).
    """
    # retrieve messages related to the current sub-problem
    if "messages" not in state or state["messages"] == []:
        if "respond" not in state:  # empty state: the state itself is the only problem
            focused_problem = "\nProblem 1: " + state["focused_problem"] + "\n"
            return {"messages": [HumanMessage(focused_problem)]}
        focused_node = None  # non-empty state: look for a node with pending children
        for node in traverse(state):
            if (
                "child_list" in node
                and node["child_list"] != []
                and "respond" not in node["child_list"][-1]
            ):
                focused_node = node
                break
        # NOTE(review): if no node matches, focused_node stays None and the loop
        # below fails on subscripting None — confirm callers never reach this
        # point without a pending child.
        focused_problem = ""
        problem_no = 1
        # Number only the children that have not been answered yet.
        for child_node in focused_node["child_list"]:
            if "respond" not in child_node:
                focused_problem += (
                    "\nProblem " + str(problem_no) + ": " + child_node["focused_problem"] + "\n"
                )
                problem_no += 1
        return {"messages": [HumanMessage(focused_problem)]}

    last_tool_msg = []
    if len(state["messages"]) > 1:
        # Walk backwards and keep everything up to and including the most recent
        # AIMessage, then restore chronological order.
        prev_tool_msg = []
        for m in state["messages"][::-1]:
            prev_tool_msg.append(m)
            if isinstance(m, AIMessage):
                break
        last_tool_msg += prev_tool_msg[::-1]
    # True once any of the collected messages is named "sql_db_query".
    get_query = False
    for tool_msg in last_tool_msg:
        if tool_msg.name == "sql_db_query":
            get_query = True

    highlevel_message = SystemMessage(content=highlevel_prefix)
    # After a query has run, ask for the formatted answer; before that, use the
    # SQL agent instructions.
    if get_query:
        system_message = SystemMessage(content=answer_format_info)
    else:
        system_message = SystemMessage(content=sql_agent_prefix)
    messages = [highlevel_message, system_message]

    # Add the allowed-table reminder when the last message is a HumanMessage or a
    # "sql_db_schema" message whose content contains "Error:".
    if len(state["messages"]) > 0 and (
        isinstance(state["messages"][-1], HumanMessage)
        or (
            "Error:" in state["messages"][-1].content
            and state["messages"][-1].name == "sql_db_schema"
        )
    ):
        messages.append(SystemMessage(table_name_info.format(table_names=state["data_tables"])))

    # Always include the first stored message (presumably the problem statement
    # created when the message list was empty — verify against the graph flow).
    if len(state["messages"]) > 0:
        messages.append(state["messages"][0])

    # Tools are bound only while no query result is available yet.
    if get_query:
        llm_response = llm.invoke(messages + last_tool_msg)
    else:
        llm_response = llm.bind_tools(tools_execute).invoke(messages + last_tool_msg)

    # DO NOT REGENERATE TEMPORARILY DUE TO TOKEN COST ISSUE
    # if state['height'] > 0 and state['messages'][-1].name == 'sql_db_query' and \
    #   (state['messages'][-1].content is None or state['messages'][-1].content == ''):
    #     # regenerate subproblem due to empty query result
    #     if state['regenerate_times'] < 1:
    #         state['messages'] = [RemoveMessage(id=m.id) for m in state['messages']]
    #         state['regenerate_times'] += 1
    #         return state
    #     # regenerated too many times, just ignore
    #     else:
    #         state['regenerate_times'] = 0

    if hasattr(llm_response, "tool_calls") and len(llm_response.tool_calls) > 0:
        # has more tool call request
        return {"messages": [llm_response]}
    # no more tool call request, go to the next node
    if "respond" not in state:  # for root
        if not (
            "**Section " in llm_response.content
            and "**Vis " in llm_response.content
            and "**Tool " in llm_response.content
        ):
            raise ValueError("LLM paragraph response format error")
        # Only the first section is used here; the [3:] slices drop the "n: " label.
        content = llm_response.content.split("**Section ")[1]
        paragraph = content.split("**Vis ")[0][3:].strip()
        vis_choice = content.split("**Vis ")[1][3:].split("**Tool ")[0]
        tool_order = int(content.split("**Tool ")[1][3:])
        state["respond"] = paragraph
        if "table" in vis_choice:
            # state['artifact_type'] = 'chart'
            # if tool_order <= len(last_tool_msg) - 1:
            #     tool_call_msg = last_tool_msg[0]
            #     tool_msg = last_tool_msg[tool_order]
            #     prompt = llm.invoke([SystemMessage(script_writing_prompt)] + last_tool_msg).content
            #     state['artifact'] = show_charts(prompt, tool_call_msg, tool_msg)
            state["artifact_type"] = "table"
            # NOTE(review): a tool_order of 0 or below also passes this check and
            # would select the first (or a wrapped-around) message — confirm intended.
            if tool_order <= len(last_tool_msg) - 1:
                tool_call_msg = last_tool_msg[0]
                tool_msg = last_tool_msg[tool_order]
                state["artifact"] = show_table(tool_call_msg, tool_msg)
        elif "chart" in vis_choice:
            state["artifact_type"] = "chart"
            if tool_order <= len(last_tool_msg) - 1:
                tool_call_msg = last_tool_msg[0]
                tool_msg = last_tool_msg[tool_order]
                # Second LLM call: generate the chart-building script.
                prompt = llm.invoke([SystemMessage(script_writing_prompt)] + last_tool_msg).content
                state["artifact"] = show_charts(prompt, tool_call_msg, tool_msg)
    else:  # state already has "respond": fill in the pending children instead
        for node in traverse(state):
            if (
                "child_list" in node
                and node["child_list"] != []
                and "respond" not in node["child_list"][-1]
            ):
                output_parser(node, last_tool_msg, llm_response)
                break
    if "expanded_id" in state:
        state["expanded_id"] = None

    # Replace the message list with one RemoveMessage per stored message.
    state["messages"] = [RemoveMessage(id=m.id) for m in state["messages"]]
    # state['regenerate_times'] = 0
    return state

In [None]:
from langgraph.prebuilt import ToolNode

# One ToolNode per tool set (init / span / execute); the regenerate variant is disabled.
use_tools_init = ToolNode(tools=tools_init)
use_tools_span = ToolNode(tools=tools_span)
# use_tools_regenerate = ToolNode(tools=tools_regenerate)
use_tools_execute = ToolNode(tools=tools_execute)

In [None]:
from langgraph.graph import START, StateGraph


def route_span(
    state: TreeState,
):
    """
    Choose the node that follows "span" based on the current message list.

    Returns:
        "tree_traverser" when there are no messages,
        "span" when the last message is a HumanMessage (or a subclass of it),
        "tools_span" for any other last message.
    """
    messages = state["messages"]
    if not messages:
        return "tree_traverser"
    # isinstance (not an exact type comparison) so HumanMessage subclasses are
    # routed like HumanMessage, matching the check used in solve_problem.
    if isinstance(messages[-1], HumanMessage):
        return "span"
    return "tools_span"


# def route_regenerate(
#         state: TreeState,
# ):
#     """
#     Use in the conditional_edge to route to the ToolNode if the last message
#     has tool calls. Otherwise, route to the next.
#     """
#     if state['messages'] == []:
#         return 'execute'
#     elif type(state['messages'][-1]) == HumanMessage:
#         return "regenerate"
#     else:
#         return "tools_span"


def route_execute(
    state: TreeState,
):
    """
    Choose the node that follows "execute" based on the current message list.

    Returns:
        "tree_traverser" when there are no messages,
        "execute" when the last message is a HumanMessage (or a subclass of it),
        "tools_execute" for any other last message.
    """
    messages = state["messages"]
    if not messages:
        # if state['regenerate_times'] > 0:
        #     return 'regenerate'
        return "tree_traverser"
    # isinstance (not an exact type comparison) so HumanMessage subclasses are
    # routed like HumanMessage, matching the check used in solve_problem.
    if isinstance(messages[-1], HumanMessage):
        return "execute"
    return "tools_execute"


# Assemble the graph: three worker nodes (tree_traverser, span, execute) plus one
# tool node per tool set. The regenerate / summarize nodes are currently disabled.
builder = StateGraph(TreeState)
builder.add_node("tree_traverser", traverse_tree)
builder.add_node("span", divide_problem)
builder.add_node("execute", solve_problem)
# builder.add_node("regenerate", regenerate_problem)
# builder.add_node("summarize", summarize_subproblems)
builder.add_node("tools_init", use_tools_init)
builder.add_node("tools_span", use_tools_span)
# builder.add_node("tools_regenerate", use_tools_regenerate)
builder.add_node("tools_execute", use_tools_execute)

# Execution starts at the traverser; its routing result is whatever it stored in
# state["next"] (no path map is given for this edge).
builder.add_edge(START, "tree_traverser")
builder.add_conditional_edges("tree_traverser", lambda state: state["next"])
# "span" either loops on itself, calls its tools, or hands back to the traverser.
builder.add_conditional_edges(
    "span",
    route_span,
    {"span": "span", "tools_span": "tools_span", "tree_traverser": "tree_traverser"},
)
# builder.add_conditional_edges(
#     "regenerate",
#     route_regenerate,
#     {"regenerate": "regenerate", "tools_regenerate": "tools_regenerate", 'execute': 'execute'},
# )
# "execute" follows the same three-way pattern with its own tool node.
builder.add_conditional_edges(
    "execute",
    route_execute,
    {
        "execute": "execute",
        "tools_execute": "tools_execute",
        "tree_traverser": "tree_traverser",
    },  # 'regenerate': 'regenerate'},
)

# Each tool node returns control to the node it serves.
builder.add_edge("tools_init", "tree_traverser")
builder.add_edge("tools_span", "span")
# builder.add_edge("tools_regenerate", "regenerate")
builder.add_edge("tools_execute", "execute")

In [None]:
from langgraph.checkpoint.memory import MemorySaver

# Compile the graph with an in-memory checkpointer (MemorySaver).
agent = builder.compile(checkpointer=MemorySaver())

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid-generated PNG and show it inline.
display(Image(agent.get_graph().draw_mermaid_png()))

In [None]:
# Question used for the demo run.
example_query = "Conduct a detailed analysis of orders."


# Initial state handed to the graph: height 0, no children yet, and the query
# recorded as the first previously-asked question.
# NOTE(review): max_subproblems / first_level_subproblems presumably cap how many
# sub-problems are generated — confirm against the TreeState definition.
root = TreeState(
    id=uuid.uuid4(),
    max_height=1,
    focused_problem=example_query,
    height=0,
    prev_questions=["**" + str(example_query) + "\n"],
    regenerate_times=0,
    max_subproblems=4,
    first_level_subproblems=2,
    child_list=[],
)

# Base value for the checkpointer thread id; the run cell increments it before use.
config_thread = 100

In [None]:
# Every execution of this cell bumps the counter, so each run uses a new thread_id
# (presumably to start from a fresh checkpoint thread — confirm). Note that this
# makes the cell non-idempotent by design.
config_thread += 1
config = {
    "configurable": {"thread_id": config_thread},
    "recursion_limit": 50,
}
print(root, config)

# Stream the graph run and print each emitted step.
for step in agent.stream(root, config):
    print(step)