In [1]:
import os
from dotenv import load_dotenv, find_dotenv

In [2]:
# Find a .env file and load it; the result is bound to `_` as in the original cell.
dotenv_path = find_dotenv()
_ = load_dotenv(dotenv_path)

# Value of the DATABASE environment variable (None when it is not set).
# Presumably the Postgres connection URI — confirm against db_tools.
postgres_uri = os.getenv("DATABASE")

In [3]:
from business_copilot.biz_analytics.db_tools import (list_tables, 
                                                     get_relevant_schema_example, 
                                                     resolve_error, 
                                                     execute_query,
                                                     double_check_query,
                                                     )
from business_copilot.biz_analytics.prompts import POSTGRES_SYSTEM_MESSAGE
from langchain_core.prompts import PromptTemplate
from langgraph.prebuilt import create_react_agent

In [4]:
# System prompt for the SQL agent.
# NOTE: this rebinds the POSTGRES_SYSTEM_MESSAGE name imported from
# business_copilot.biz_analytics.prompts, so the notebook-local text below is
# what gets passed to the agent.
TOP_K = 5  # row limit substituted into the prompt's {top_k} placeholder

POSTGRES_SYSTEM_MESSAGE = """
You are a Data Analyst agent designed to interact with a POSTGRESQL database.
Given an input question:
- Create a syntactically correct query to run,
- Look at the results of the query and return the answer.
- Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.

IMPORTANT:
- To start you should ALWAYS look at the list of tables in the database to be sure your predicted table exists. Do NOT skip this step.

Check the list of tables again to be sure they are all relevant before you get the example and schema for each of the tables you predicted.
Make sure that they exist, so you have context you can work with.

MUST DO: You MUST double check your query before executing it. DO NOT FORGET this step.
ONLY Then you execute the query with the most relevant table.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

Note:
    If you get an error while executing the final query, that is, when using the execute_query tool
    ONLY then can you use the resolve_error tool to try again and resolve the error.
VERY IMPORTANT:
    - DO NOT make any DATA DEFINITION LANGUAGE (DDL) statements such as (CREATE, ALTER, DROP, TRUNCATE, RENAME)
    - Also DO NOT make any DML statements (INSERT, UPDATE, DELETE etc.) to the database.
""".format(
    # str.format fills {top_k}; the template must not contain other literal braces.
    top_k=TOP_K,
)

# Tools handed to the agent, one per line, in the same order as before.
tools = [
    list_tables,
    get_relevant_schema_example,
    resolve_error,
    execute_query,
    double_check_query,
]

In [5]:
from langchain.chat_models import init_chat_model
from langgraph.prebuilt.chat_agent_executor import AgentState
from langmem.short_term import SummarizationNode, RunningSummary
from langchain_core.messages.utils import count_tokens_approximately
from langgraph.prebuilt.chat_agent_executor import AgentState
from langgraph.checkpoint.memory import InMemorySaver
from typing import Any

In [6]:

# Chat model used by the agent; the "openai:" prefix is prepended to the model name.
model = "gpt-4o"
llm = init_chat_model(
    "openai:" + model,
    temperature=0.0,
)

In [7]:
from langchain.chat_models import init_chat_model
# Separate chat-model handle that is passed as `model` to the summarization node.
summarization_llm = init_chat_model(
    "openai:gpt-4o",
    temperature=0.0,
)

In [8]:
# Summarization step that is plugged into the agent as its pre_model_hook.
summarization_node = SummarizationNode(
    token_counter=count_tokens_approximately,
    model=summarization_llm,
    max_tokens=4096,  # raise this when a question needs more context
    max_summary_tokens=2048,
    # After a pre_model_hook runs, create_react_agent builds the LLM input from
    # the `llm_input_messages` state key (falling back to `messages`). The
    # previous custom key ("summarized_llm_messages") is not one the agent
    # reads, so the summarized history never reached the model.
    output_messages_key="llm_input_messages",
)

In [9]:
class State(AgentState):
    """Agent state extended with a slot for running-summary bookkeeping.

    The extra ``context`` key keeps track of previous summary information so
    that summarization is not repeated on every LLM call.
    """

    context: dict[str, RunningSummary]

In [10]:
# Checkpointer passed to create_react_agent. InMemorySaver presumably keeps
# checkpoints in process memory only (lost on kernel restart) — TODO confirm.
checkpointer = InMemorySaver()

In [11]:
# ReAct agent wired with the DB tools, the system prompt, the summarization
# pre-model hook, the extended State schema and the checkpointer.
agent_mem = create_react_agent(
    llm,
    tools=tools,
    prompt=POSTGRES_SYSTEM_MESSAGE,
    pre_model_hook=summarization_node,
    state_schema=State,
    checkpointer=checkpointer,
)

# Run configuration: every call made with this config uses thread_id "1".
config = {"configurable": {"thread_id": "1"}}



In [14]:
question =  "Which country's customers spent the most?"
async for step in agent_mem.astream(
    {"messages": [{"role": "user", "content": question}]},
    config=config,
    stream_mode="values",
):
    step["messages"][-1].pretty_print()


Which country's customers spent the most?
Tool Calls:
  list_tables (call_N9M0M0hssLiQflqJCqyKdcVo)
 Call ID: call_N9M0M0hssLiQflqJCqyKdcVo
  Args:
Name: list_tables

["artist", "album", "employee", "customer", "invoice", "invoice_line", "track", "playlist", "playlist_track", "genre", "media_type"]
Tool Calls:
  get_relevant_schema_example (call_nAjGtVqEy2FQvaVmR0AyXw9Q)
 Call ID: call_nAjGtVqEy2FQvaVmR0AyXw9Q
  Args:
    table_names: ['customer', 'invoice']
Name: get_relevant_schema_example

CREATE TABLE "customer" (
	"customer_id" INTEGER NOT NULL,
	"first_name" NVARCHAR(40) NOT NULL,
	"last_name" NVARCHAR(20) NOT NULL,
	"company" NVARCHAR(80),
	"address" NVARCHAR(70),
	"city" NVARCHAR(40),
	"state" NVARCHAR(40),
	"country" NVARCHAR(40),
	"postal_code" NVARCHAR(10),
	"phone" NVARCHAR(24),
	"fax" NVARCHAR(24),
	"email" NVARCHAR(60) NOT NULL,
	"support_rep_id" INTEGER,
	PRIMARY KEY ("customer_id"),
	FOREIGN KEY("support_rep_id") REFERENCES "employee" ("employee_id")
);

/*
3 rows from

In [None]:
# Inspect the agent state for the thread identified by `config`
# (shown as the cell's output because it is the last expression).
agent_mem.get_state(config)