In [None]:
import os
import asyncio
from typing import Any
import asyncpg
import pandas as pd
from dotenv import load_dotenv, find_dotenv

In [None]:
# Populate os.environ from the nearest .env file (result bound to `_` so it is not displayed).
_ = load_dotenv(find_dotenv())
# Connection URI for the database; None when the DATABASE variable is not set.
postgres_uri = os.getenv("DATABASE")

In [None]:
class PostgreSQL:
    def __init__(self,
                 *,
                 uri: Any | None=None,
                 **kwargs,
                 )-> None:
        self._dialect = "postgresql"
        self.uri = uri
        self.config = kwargs
        self.conn = None

    @property
    def dialect(self, ):
        return self._dialect

    async def connect(self, ) -> None:
        self.conn = await asyncpg.connect(
            dsn=self.uri,
            **self.config,                                          
        )

    async def create_table_schema(self, table_name: str) -> str:
        # Get columns
        columns = await self.conn.fetch(f"""
            SELECT column_name, data_type, is_nullable, character_maximum_length
            FROM information_schema.columns
            WHERE table_name = '{table_name}'
            ORDER BY ordinal_position;
        """)

        # Get primary keys
        pk_rows = await self.conn.fetch(f"""
            SELECT a.attname
            FROM pg_index i
            JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
            WHERE i.indrelid = '{table_name}'::regclass AND i.indisprimary;
        """)
        primary_keys = [row['attname'] for row in pk_rows]

        # Get foreign keys
        fk_rows = await self.conn.fetch(f"""
            SELECT
                kcu.column_name,
                ccu.table_name AS foreign_table,
                ccu.column_name AS foreign_column
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON tc.constraint_name = kcu.constraint_name
            JOIN information_schema.constraint_column_usage AS ccu
                ON ccu.constraint_name = tc.constraint_name
            WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = '{table_name}';
        """)

        # Build column definitions
        col_defs = []
        for col in columns:
            name = f'"{col["column_name"]}"'
            dtype = col["data_type"].upper()
            if dtype == "CHARACTER VARYING":
                dtype = f'NVARCHAR({col["character_maximum_length"]})'
            elif dtype == "CHARACTER":
                dtype = f'NCHAR({col["character_maximum_length"]})'
            elif dtype == "TEXT":
                dtype = "TEXT"
            elif dtype == "INTEGER":
                dtype = "INTEGER"
            # Add more type mappings if needed

            nullable = "NOT NULL" if col["is_nullable"] == "NO" else ""
            col_defs.append(f"{name} {dtype} {nullable}".strip())

        # Add primary key
        if primary_keys:
            pk = ', '.join(f'"{key}"' for key in primary_keys)
            col_defs.append(f"PRIMARY KEY ({pk})")

        # Add foreign keys
        for fk in fk_rows:
            col_defs.append(
                f'FOREIGN KEY("{fk["column_name"]}") REFERENCES "{fk["foreign_table"]}" ("{fk["foreign_column"]}")'
            )

        # Final SQL
        sql = f'CREATE TABLE "{table_name}" (\n\t' + ',\n\t'.join(col_defs) + '\n);'     
        return sql
    
    
    async def list_tables(self, )-> list[str]:
        rows = await self.conn.fetch("""
            SELECT tablename
            FROM pg_catalog.pg_tables
            WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';
        """)
        return [row['tablename'].title() for row in rows]

    async def get_example(self, table_name: str) -> str:
         # Execute a query
        values = await self.conn.fetch(
            f"SELECT * FROM {table_name} LIMIT 3;",

        )
        examples = []
        for item in values:
            examples.append(list(item.items()))

        records = [dict(row) for row in examples]
        
        table = pd.DataFrame(records).to_string()

        return ("/*\n" + f"3 rows from '{table_name}' table:\n" + table + "\n"+
                "*/"
        )

    async def get_all_table_schema_example(self) -> str:
        table_names = await self.list_tables()
        all_schema = ""
        for t_names in table_names:
            tab = await self.create_table_schema(t_names.lower())
            example = await self.get_example(t_names.lower())

            all_schema += tab + "\n\n" + example + "\n\n"
            
        return all_schema

    async def execute(self, query: str) -> asyncpg.Record:
        # Execute a query
        values = await self.conn.fetch(
            query,

        )
        await self.clean_up()
        return values
    
    async def clean_up(self) -> None:
        await self.conn.close()
        
# Create the helper with the URI read from the environment and open the
# connection that the following cells reuse.
psql = PostgreSQL(uri=postgres_uri)
await psql.connect()


In [None]:
# Show the table names reported by the helper (title-cased by list_tables).
await psql.list_tables()


In [None]:
# Fetch up to three rows from the "customer" table, rendered as a SQL comment block.
res = await psql.get_example("customer")
print(res)

In [None]:
# The dialect property is fixed to "postgresql" in PostgreSQL.__init__.
print(psql.dialect)


In [None]:
# Render the CREATE TABLE-style description of the "album" table.
schema = await psql.create_table_schema("album")
print(schema)

In [None]:
# Schema plus example rows for every listed table, concatenated into one string.
schemas_examples = await psql.get_all_table_schema_example()
print(schemas_examples)

In [None]:
from langchain.chat_models import init_chat_model

# Model name and provider, combined below into "<provider>:<model>".
model = "gpt-4o"
model_provider = "openai"
# Instantiate the chat model from the combined identifier
llm = init_chat_model(f"{model_provider}:{model}")

In [None]:
# Smoke test: send one message to the model and display the reply text.
output = await llm.ainvoke("hi!") # single asynchronous call to the model
output.content  # last expression, so the notebook displays the reply content

In [None]:
# Stream the reply and print each chunk's content as it arrives, without newlines.
async for step in llm.astream("hello how are you"):
    print(step.content, end="")

In [None]:
from business_copilot.biz_analytics.db_tools import (list_tables, 
                                                     get_relevant_schema_example, 
                                                     resolve_error, 
                                                     execute_query,
                                                     double_check_query,
                                                     )
from business_copilot.biz_analytics.prompts import POSTGRES_SYSTEM_MESSAGE
from langchain_core.prompts import PromptTemplate
from langgraph.prebuilt import create_react_agent

In [None]:
# System prompt for the SQL agent. Note that this assignment replaces the
# POSTGRES_SYSTEM_MESSAGE imported from business_copilot.biz_analytics.prompts.
# {top_k} is filled in by the .format() call at the end.
POSTGRES_SYSTEM_MESSAGE = """
You are a Data Analyst agent designed to interact with a POSTGRESQL database.
Given an input question:
- Create a syntactically correct query to run,
- Look at the results of the query and return the answer.
- Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.

IMPORTANT:
- To start you should ALWAYS look at the list of tables in the database to be sure your predicted table exists. Do NOT skip this step.

Check the list of tables again to be sure they are all relevant before, Get example and schema for each of the tables you predicted.
Make sure that they exist. So, You can have context you can work with.

MUST DO: You MUST double check your query before executing it. DO NOT FORGET this step.
ONLY Then you execute the query with the most relevant table.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

Note:
    If you get an error while executing the final query, that is, when using the execute_query tool
    ONLY then can you use the resolve_error tool to try again and resolve the error.
VERY IMPORTANT:
    - DO NOT make any DATA DEFINITION LANGUAGE (DDL) statements such as (CREATE, ALTER, DROP, TRUNCATE, RENAME)
    - Also DO NOT make any DML statements (INSERT, UPDATE, DELETE etc.) to the database.
""".format(
    top_k=5,
)

# Tools handed to the agent, in the same order as before.
tools = [
    list_tables,
    get_relevant_schema_example,
    resolve_error,
    execute_query,
    double_check_query,
]

In [None]:
from langchain.chat_models import init_chat_model
# Create the agent graph: a chat model with temperature 0.0, the tool list,
# and the system prompt defined above. This rebinds `model` and `llm`.
model = "gpt-4o"
llm = init_chat_model(f"openai:{model}", temperature=0.0)
agent_executor = create_react_agent(llm , tools=tools, prompt=POSTGRES_SYSTEM_MESSAGE)

In [None]:
# Ask the agent one question and pretty-print the latest message of each
# streamed step (stream_mode="values").
question =  "Which country's customers spent the most?"
async for step in agent_executor.astream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()

### Memory