In [1]:
import streamlit as st
import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
from dotenv import load_dotenv
from src.auth.hashing import authenticate  # Import authentication function
from src.agents.request_handler import GPTRequestHandler
from src.agents.sql_query_generator import SQLQueryGenerator
from src.agents.sql_query_executor import SQLQueryExecutor
from src.agents.business_intelligence_analyzer import BusinessIntelligenceAnalyzer
from src.agents.sql_query_reasoning_generation import SQLQueryReasoningGenerator
from src.agents.query_intent_classifier import QueryIntentClassifier
from src.agents.misleading_query_handler import MisleadingQueryHandler
from src.agents.sql_auto_fixer import SQLQueryAutoFixer
from src.agents.prompts import (
    SYSTEM_PROMPT_INTENTCLASSIFIER,
    CONTEXT_SCHEMA,
    SYSTEM_PROMPT_SQL_REASONING,
    SYSTEM_PROMPT_SQG,
    SYSTEM_PROMPT_BI_ANALYSIS,
    LANGUAGE,
    SYTEM_PROMPT_MISLEADING_QUERY_SUGGESTION
)
from src.agents.utils import display_refrence_table, display_and_pin_charts, display_pinned_charts
from src.memory.manager import TokenLimitedMemoryBuffer
from src.agents import styles
from pprint import pprint
load_dotenv()

# --- Pipeline components ------------------------------------------------------
# Output folders used by the chart helpers.
CHART_DIR = 'charts'
PINNED_CHART_DIR = 'pinned_charts'

# Credentials are read from the process environment; a .env file is loaded
# into it by the load_dotenv() call that precedes this block.
GPT4V_KEY = os.environ.get("GPT4V_KEY")
ENDPOINT = os.environ.get("GPT_ENDPOINT")

# A single request handler is shared by every LLM-backed agent below.
request_handler = GPTRequestHandler(api_key=GPT4V_KEY, endpoint=ENDPOINT)
_agent_kwargs = {"request_handler": request_handler}

query_intent_classifier = QueryIntentClassifier(**_agent_kwargs)
sql_reasoning_generator = SQLQueryReasoningGenerator(**_agent_kwargs)
sql_generator = SQLQueryGenerator(**_agent_kwargs)
sql_executor = SQLQueryExecutor()
bi_analyzer = BusinessIntelligenceAnalyzer(**_agent_kwargs)
misleading_query_handler = MisleadingQueryHandler(
    **_agent_kwargs,
    system_prompt=SYTEM_PROMPT_MISLEADING_QUERY_SUGGESTION,
)
sql_query_fixer = SQLQueryAutoFixer(**_agent_kwargs, database_schema=CONTEXT_SCHEMA)
del _agent_kwargs

# Conversation memory, bounded by a token budget.
memory_buffer = TokenLimitedMemoryBuffer(max_tokens=1_000)

In [2]:
# Natural-language question sent through the pipeline in the next cell.
user_question: str = "How do the sales of top 5 products compare across different regions?"

In [7]:
if user_question:
    # Start every run with an empty chart directory so only charts produced for
    # this question remain. Create it on first use (os.listdir() raises
    # FileNotFoundError otherwise) and skip sub-directories, which os.remove()
    # cannot delete.
    os.makedirs(CHART_DIR, exist_ok=True)
    for chart in os.listdir(CHART_DIR):
        chart_path = os.path.join(CHART_DIR, chart)
        if os.path.isfile(chart_path):
            os.remove(chart_path)

    # Defaults that stay valid whichever branch runs or fails, so the memory
    # update at the end of this cell and the inspection cells below never see
    # an unbound name or a value left over from a previous run.
    query_results = {}
    fixed_queries = []
    bi_analysis_result = {"business_analysis": {"summary": "Analysis not available. Try again", "chart-python-code": None}}
    assistant_response = bi_analysis_result["business_analysis"]["summary"]

    try:
        print("Understanding Intent...")
        intent_analysis = query_intent_classifier.classify(
            system_prompt=SYSTEM_PROMPT_INTENTCLASSIFIER,
            context=CONTEXT_SCHEMA,
            user_input=user_question
        )
        pprint(intent_analysis)
        intent = intent_analysis['intent'].lower()

        if intent == 'misleading_query':
            # Not answerable as asked: reply with suggested rephrasings.
            assistant_response = misleading_query_handler.suggest_better_questions(
                reasoning=intent_analysis['reasoning'],
                user_question=intent_analysis["rephrased_question"]
            )
        elif intent == 'general':
            assistant_response = (
                "I can only answer database-related queries.\n\n"
                "**Example Questions:**\n"
                "- *'What were the total sales in the last quarter?'*\n"
                "- *'How many customers placed an order last month?'*"
            )
        else:
            rephrased_question = intent_analysis['rephrased_question']

            try:
                print("Reasoning Optimal Query Plan...")
                # NOTE(review): the saved output of this cell shows
                # generate_reasoning() rejecting the `chat_memory` keyword, which
                # sends every question down the fallback path below. Confirm the
                # SQLQueryReasoningGenerator signature.
                reasoning = sql_reasoning_generator.generate_reasoning(
                    SYSTEM_PROMPT_SQL_REASONING, CONTEXT_SCHEMA, rephrased_question, LANGUAGE,
                    chat_memory=memory_buffer.get_context_markdown()
                )
                pprint(reasoning)

                print("Writing Query...")
                sql_query = sql_generator.generate_queries(
                    SYSTEM_PROMPT_SQG, CONTEXT_SCHEMA, rephrased_question, reasoning, time.time(), LANGUAGE
                )
                pprint(sql_query)

                print("Executing Queries...")
                query_results = sql_executor.execute_queries(sql_query["sql_query_steps"])
                pprint(query_results)
            except Exception as e:
                # Any failure in planning, generation or execution falls back to
                # the auto-fixer. query_results is passed as-is: it is still the
                # {} default unless execution completed.
                print("Fixing Queries...", e)
                fixed_queries = sql_query_fixer.fix_sql_errors(rephrased_question, query_results)
                pprint(fixed_queries)
                query_results = sql_executor.execute_queries(fixed_queries)
                pprint(query_results)
            finally:
                # Close on every path, including the fallback, so the connection
                # is never left open between runs. Assumes execute_queries()
                # reconnects on the next run, as the success path has always
                # closed the connection here.
                sql_executor.close_connection()

            try:
                print("Analyzing ...")
                bi_analysis_result = bi_analyzer.analyze_results(SYSTEM_PROMPT_BI_ANALYSIS, query_results, user_question=rephrased_question)
                pprint(bi_analysis_result)
            except Exception as e:
                print(f"Analysis failed: {e}")

            # Tolerate an analyzer result without the expected keys instead of
            # raising KeyError when recording memory.
            business_analysis = bi_analysis_result.get("business_analysis", {})
            assistant_response = business_analysis.get("summary", assistant_response)
            chart_code = business_analysis.get("chart-python-code")

            if chart_code:
                # SECURITY: chart_code is model-generated and is executed
                # verbatim with access to `os`; only run it against models and
                # prompts you trust.
                try:
                    exec_globals = {}  # receives the names the chart code defines (exec locals)
                    exec(
                        chart_code.replace("```python", "").replace("```", ""),
                        {"plt": plt, "sns": sns, "pd": pd, "st": st, "np": np, "os": os},
                        exec_globals,
                    )
                except Exception as e:
                    # Best-effort: a broken chart must not lose the text answer,
                    # but report the failure instead of swallowing it silently.
                    print(f"Chart generation failed: {e}")

    except Exception as e:
        print(f"Unexpected error: {e}")

    # Record this turn. The assistant side is the reply actually produced:
    # rephrasing suggestions, the general-question notice, or the BI summary.
    # NOTE(review): suggest_better_questions()' return type is not visible here;
    # presumably a str, as add_message() receives it directly. Verify.
    memory_buffer.add_message(role='User', content=user_question)
    memory_buffer.add_message(role='Assistant', content=assistant_response)

Understanding Intent...
{'intent': 'TEXT_TO_SQL',
 'reasoning': "The question needs to be specific about the metric for 'top 5 "
              "products' and the regions. 'Regions' is rephrased to "
              "'countries' to match the schema.",
 'rephrased_question': 'How do the sales of the top 5 products by quantity '
                       'compare across different countries?'}
Reasoning Optimal Query Plan...
Fixing Queries... SQLQueryReasoningGenerator.generate_reasoning() got an unexpected keyword argument 'chat_memory'
[{'query': 'SELECT sr.product_name, sr.country_id, SUM(sr.quantity) AS '
           'total_quantity, SUM(sr.line_amount) AS total_sales_amount FROM '
           'sales_report sr JOIN (SELECT product_name, SUM(quantity) AS '
           'total_quantity FROM sales_report GROUP BY product_name ORDER BY '
           'total_quantity DESC LIMIT 5) top_products ON sr.product_name = '
           'top_products.product_name GROUP BY sr.product_name, sr.country_id '
      

In [4]:
# Inspect the intent classification produced by the pipeline cell.
# NOTE(review): the saved output below predates the In [7] run (it says
# "revenue" where In [7] printed "quantity"); re-run the notebook top to bottom.
intent_analysis

{'rephrased_question': 'How do the sales of the top 5 products by revenue compare across different countries?',
 'reasoning': 'The question needs to be specific about the metric (revenue) and the geographical division (countries) to align with the schema.',
 'intent': 'TEXT_TO_SQL'}

In [5]:
# Inspect the queries returned by the SQL auto-fixer. The pipeline cell only
# assigns this name when its fallback path runs, so on a fresh kernel where that
# path was not taken this line raises NameError.
fixed_queries

[]

In [6]:
# Inspect the conversation turns the pipeline cell recorded via add_message().
memory_buffer

[2025-04-03 09:33:18.538505+00:00] User: How do the sales of top 5 products compare across different regions?
[2025-04-03 09:33:18.538505+00:00] Assistant: Error occurred