In [1]:
import os
import asyncio
from sqlalchemy import create_engine
from typing import Any
import asyncpg
from dotenv import load_dotenv, find_dotenv
from langchain_community.utilities.sql_database import SQLDatabase


In [2]:
# Load variables from the nearest .env file into os.environ (the boolean result is discarded).
_ = load_dotenv(dotenv_path=find_dotenv())
# Connection string for PostgreSQL; None when DATABASE is not set.
postgres_uri = os.getenv("DATABASE")

In [3]:
import pandas as pd

In [4]:
class PostgreSQL:
    def __init__(self,
                 *,
                 uri: Any | None=None,
                 **kwargs,
                 )-> None:
        self._dialect = "postgresql"
        self.uri = uri
        self.config = kwargs
        self.conn = None

    @property
    def dialect(self, ):
        return self._dialect

    async def connect(self, ) -> None:
        self.conn = await asyncpg.connect(
            dsn=self.uri,
            **self.config,                                          
        )

    async def create_table_schema(self, table_name: str) -> str:
        # Get columns
        columns = await self.conn.fetch(f"""
            SELECT column_name, data_type, is_nullable, character_maximum_length
            FROM information_schema.columns
            WHERE table_name = '{table_name}'
            ORDER BY ordinal_position;
        """)

        # Get primary keys
        pk_rows = await self.conn.fetch(f"""
            SELECT a.attname
            FROM pg_index i
            JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
            WHERE i.indrelid = '{table_name}'::regclass AND i.indisprimary;
        """)
        primary_keys = [row['attname'] for row in pk_rows]

        # Get foreign keys
        fk_rows = await self.conn.fetch(f"""
            SELECT
                kcu.column_name,
                ccu.table_name AS foreign_table,
                ccu.column_name AS foreign_column
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON tc.constraint_name = kcu.constraint_name
            JOIN information_schema.constraint_column_usage AS ccu
                ON ccu.constraint_name = tc.constraint_name
            WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = '{table_name}';
        """)

        # Build column definitions
        col_defs = []
        for col in columns:
            name = f'"{col["column_name"]}"'
            dtype = col["data_type"].upper()
            if dtype == "CHARACTER VARYING":
                dtype = f'NVARCHAR({col["character_maximum_length"]})'
            elif dtype == "CHARACTER":
                dtype = f'NCHAR({col["character_maximum_length"]})'
            elif dtype == "TEXT":
                dtype = "TEXT"
            elif dtype == "INTEGER":
                dtype = "INTEGER"
            # Add more type mappings if needed

            nullable = "NOT NULL" if col["is_nullable"] == "NO" else ""
            col_defs.append(f"{name} {dtype} {nullable}".strip())

        # Add primary key
        if primary_keys:
            pk = ', '.join(f'"{key}"' for key in primary_keys)
            col_defs.append(f"PRIMARY KEY ({pk})")

        # Add foreign keys
        for fk in fk_rows:
            col_defs.append(
                f'FOREIGN KEY("{fk["column_name"]}") REFERENCES "{fk["foreign_table"]}" ("{fk["foreign_column"]}")'
            )

        # Final SQL
        sql = f'CREATE TABLE "{table_name}" (\n\t' + ',\n\t'.join(col_defs) + '\n);'     
        return sql
    
    
    async def list_tables(self, )-> list[str]:
        rows = await self.conn.fetch("""
            SELECT tablename
            FROM pg_catalog.pg_tables
            WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';
        """)
        return [row['tablename'].title() for row in rows]

    async def get_example(self, table_name: str) -> str:
         # Execute a query
        values = await self.conn.fetch(
            f"SELECT * FROM {table_name} LIMIT 3;",

        )
        examples = []
        for item in values:
            examples.append(list(item.items()))

        records = [dict(row) for row in examples]
        
        table = pd.DataFrame(records).to_string()

        return ("/*\n" + f"3 rows from '{table_name}' table:\n" + table + "\n"+
                "*/"
        )

    async def get_all_table_schema_example(self) -> str:
        table_names = await self.list_tables()
        all_schema = ""
        for t_names in table_names:
            tab = await self.create_table_schema(t_names.lower())
            example = await self.get_example(t_names.lower())

            all_schema += tab + "\n\n" + example + "\n\n"
            
        return all_schema

    async def execute(self, query: str) -> asyncpg.Record:
        # Execute a query
        values = await self.conn.fetch(
            query,

        )
        await self.clean_up()
        return values
    
    async def clean_up(self) -> None:
        await self.conn.close()
        
# Shared helper instance used by the exploration cells below.
# NOTE(review): no visible cell calls psql.clean_up(), so this connection stays open.
psql = PostgreSQL(uri=postgres_uri)
await psql.connect()


In [5]:
# list_tables() returns the table names title-cased (e.g. "Invoice_Line").
await psql.list_tables()


['Artist',
 'Album',
 'Employee',
 'Customer',
 'Invoice',
 'Invoice_Line',
 'Track',
 'Playlist',
 'Playlist_Track',
 'Genre',
 'Media_Type']

In [6]:
res = await psql.get_example("customer")
print(res)

/*
3 rows from 'customer' table:
   customer_id first_name  last_name                                           company                          address                 city state  country postal_code               phone                 fax                  email  support_rep_id
0            1       Luís  Gonçalves  Embraer - Empresa Brasileira de Aeronáutica S.A.  Av. Brigadeiro Faria Lima, 2170  São José dos Campos    SP   Brazil   12227-000  +55 (12) 3923-5555  +55 (12) 3923-5566   luisg@embraer.com.br               3
1            2     Leonie     Köhler                                              None          Theodor-Heuss-Straße 34            Stuttgart  None  Germany       70174    +49 0711 2842222                None  leonekohler@surfeu.de               5
2            3   François   Tremblay                                              None                1498 rue Bélanger             Montréal    QC   Canada     H2G 1A7   +1 (514) 721-4711                None    ftremblay@gmail

In [7]:
# Dialect string exposed by the helper (set to "postgresql" in its __init__).
print(psql.dialect)


postgresql


In [8]:
schema = await psql.create_table_schema("album")
print(schema)

CREATE TABLE "album" (
	"album_id" INTEGER NOT NULL,
	"title" NVARCHAR(160) NOT NULL,
	"artist_id" INTEGER NOT NULL,
	PRIMARY KEY ("album_id"),
	FOREIGN KEY("artist_id") REFERENCES "artist" ("artist_id")
);


In [10]:
# DDL plus sample rows for every table returned by list_tables(), concatenated.
schemas_examples = await psql.get_all_table_schema_example()
print(schemas_examples)

CREATE TABLE "artist" (
	"artist_id" INTEGER NOT NULL,
	"name" NVARCHAR(120),
	PRIMARY KEY ("artist_id")
);

/*
3 rows from 'artist' table:
   artist_id       name
0          1      AC/DC
1          2     Accept
2          3  Aerosmith
*/

CREATE TABLE "album" (
	"album_id" INTEGER NOT NULL,
	"title" NVARCHAR(160) NOT NULL,
	"artist_id" INTEGER NOT NULL,
	PRIMARY KEY ("album_id"),
	FOREIGN KEY("artist_id") REFERENCES "artist" ("artist_id")
);

/*
3 rows from 'album' table:
   album_id                                  title  artist_id
0         1  For Those About To Rock We Salute You          1
1         2                      Balls to the Wall          2
2         3                      Restless and Wild          2
*/

CREATE TABLE "employee" (
	"employee_id" INTEGER NOT NULL,
	"last_name" NVARCHAR(20) NOT NULL,
	"first_name" NVARCHAR(20) NOT NULL,
	"title" NVARCHAR(30),
	"reports_to" INTEGER,
	"birth_date" TIMESTAMP WITHOUT TIME ZONE,
	"hire_date" TIMESTAMP WITHOUT TIME ZONE,
	"addre

In [11]:
from langchain.chat_models import init_chat_model

model = "gpt-4o"
model_provider = "openai"
# Build the chat model, passing the provider separately instead of as a "provider:model" prefix.
llm = init_chat_model(model, model_provider=model_provider)

In [12]:
# test that it works
output = await llm.ainvoke("hi!") # call the model
output.content  # Get the content

'Hello! How can I assist you today?'

In [14]:
from business_copilot.biz_analytics.db_tools import (list_tables, 
                                                     get_relevant_schema_example, 
                                                     resolve_error, 
                                                     execute_query,
                                                     double_check_query,
                                                     clean_up)
from business_copilot.biz_analytics.prompts import POSTGRES_SYSTEM_MESSAGE
from langchain_core.prompts import PromptTemplate
from langgraph.prebuilt import create_react_agent

In [16]:
import pprint
# Inspect the system prompt that is passed to create_react_agent below.
pprint.pprint(POSTGRES_SYSTEM_MESSAGE)

('\n'
 'You are an agent designed to interact with a POSTGRESQL database.\n'
 'Given an input question, create a syntactically correct query to run,\n'
 'then look at the results of the query and return the answer. Unless the '
 'user\n'
 'specifies a specific number of examples they wish to obtain, always limit '
 'your\n'
 'query to at most 5 results.\n'
 '\n'
 'To start you should ALWAYS look at the list of tables in the database to to '
 'be sure your predicted table exist \n'
 'so you can query. Do NOT skip this step.\n'
 '\n'
 'You will have to get examples and schema for each relevant tables to have a '
 'context you can work with.\n'
 '\n'
 'You can order the results by a relevant column to return the most '
 'interesting\n'
 'examples in the database. Never query for all the columns from a specific '
 'table,\n'
 'only ask for the relevant columns given the question.\n'
 '\n'
 'You MUST double check your query before executing it.\n'
 '\n'
 'If you get an error while executing

In [17]:
# Tools the ReAct agent is allowed to call.
tools = [
    list_tables,
    get_relevant_schema_example,
    resolve_error,
    execute_query,
    double_check_query,
    clean_up,
]

In [19]:
# Build the ReAct agent graph around a fresh chat model.
model = "gpt-4o-mini"
llm = init_chat_model(f"openai:{model}")
agent_executor = create_react_agent(
    llm,
    tools=tools,
    prompt=POSTGRES_SYSTEM_MESSAGE,
)

In [20]:
question =  "Which country's customers spent the most?"
async for step in agent_executor.astream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()


Which country's customers spent the most?
Tool Calls:
  list_tables (call_yUBnwCoH4IPl77aklGJBDGZn)
 Call ID: call_yUBnwCoH4IPl77aklGJBDGZn
  Args:
Name: list_tables

["artist", "album", "employee", "customer", "invoice", "invoice_line", "track", "playlist", "playlist_track", "genre", "media_type"]
Tool Calls:
  get_relevant_schema_example (call_3TymFE9LNMOyZosgn4q2MkoB)
 Call ID: call_3TymFE9LNMOyZosgn4q2MkoB
  Args:
    table_names: ['customer']
  get_relevant_schema_example (call_8CMJiBzmWGAI2qmU0Fv3sZyO)
 Call ID: call_8CMJiBzmWGAI2qmU0Fv3sZyO
  Args:
    table_names: ['invoice']
Name: get_relevant_schema_example

CREATE TABLE "invoice" (
	"invoice_id" INTEGER NOT NULL,
	"customer_id" INTEGER NOT NULL,
	"invoice_date" TIMESTAMP WITHOUT TIME ZONE NOT NULL,
	"billing_address" NVARCHAR(70),
	"billing_city" NVARCHAR(40),
	"billing_state" NVARCHAR(40),
	"billing_country" NVARCHAR(40),
	"billing_postal_code" NVARCHAR(10),
	"total" NUMERIC NOT NULL,
	PRIMARY KEY ("invoice_id"),
	FOREIGN 