In [1]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), "../../.."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print(f"Added to path: {project_root}")

Added to path: /home/yuvrajsinh/repos/factly/gopie/chat-server


In [2]:
from app.tests.e2e.multi_dataset_cases import COMPLEX_QUERY_CASES

TEST_CASES = COMPLEX_QUERY_CASES

In [3]:
from dotenv import load_dotenv
import os

load_dotenv()

api_key = os.getenv("PORTKEY_API_KEY")
virtual_key = os.getenv("GEMINI_VIRTUAL_KEY")

In [4]:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from portkey_ai import PORTKEY_GATEWAY_URL, createHeaders

model = ChatOpenAI(
    api_key="X",  # type: ignore
    base_url=PORTKEY_GATEWAY_URL,
    default_headers=createHeaders(
        api_key=api_key,
        virtual_key=virtual_key,
    ),
    model="gemini-2.5-flash-preview-04-17",
)

llm = model

In [5]:
template = """
    You are a judgemental data analyst assistant.

    You will be given:
    1. The answer generated by our system in response to a user query
    2. Details about the expected result

    generated_answer: {generated_answer}
    expected_result: {expected_result}

    The expected_result may contain some or all of these details:
    - dataset_identified: which dataset should be used (only present for multi-dataset cases)
    - sql_query_count: number of SQL queries expected
    - visualization_needed: whether visualization is required
    - visualization_type: what type of visualization is appropriate

    Your task is to evaluate if the generated answer meets the expected criteria:

    RETURN 'true' if:
    - The answer correctly identified the expected dataset (if dataset_identified is specified)
    - The number of SQL queries used is appropriate (matching sql_query_count if specified)
    - Visualization recommendations (if applicable) match the expected result
    - The numerical values are within acceptable ranges
    - The SQL queries were properly extracted and reported in the response

    RETURN 'false' if:
    - The answer identified the wrong dataset (when dataset_identified is specified)
    - The SQL query count is significantly off
    - Visualization was needed but not recommended (or vice versa)
    - The answer is factually incorrect
    - No SQL queries were detected when they should have been

    RETURN 'partial' if:
    - The answer is factually correct but there are minor issues with:
      - Dataset identification (when applicable)
      - SQL query count
      - Visualization recommendations
    - Or if the answer is partially correct but missing some components

    RETURN the response as JSON object with the following keys:
    {{
        "correct": "true" | "false" | "partial",
        "reasoning": "reasoning for why the answer is incorrect or partial,
                      no reasoning needed for correct answers",
    }}
"""

prompt = ChatPromptTemplate.from_template(template)
parser = JsonOutputParser()
chain = prompt | llm | parser

In [None]:
import json
import requests
import traceback

from app.core.constants import (
    INTERMEDIATE_MESSAGES,
    SQL_QUERIES_GENERATED,
    DATASETS_USED,
)

url = "http://localhost:8001/api/v1/chat/completions"

def process_tool_calls(tool_calls):
    tool_messages = []
    generated_sql_queries = []
    selected_datasets = []

    if not tool_calls:
        return tool_messages, generated_sql_queries, selected_datasets

    for tool_call in tool_calls:
        if not tool_call or "function" not in tool_call:
            continue

        function_data = tool_call.get("function", {})
        name = function_data.get("name")

        try:
            args = json.loads(function_data.get("arguments", "{}"))

            if name == SQL_QUERIES_GENERATED:
                if "query" in args:
                    query = args.get("query")
                    generated_sql_queries.append(query)
                elif "queries" in args:
                    queries = args.get("queries")
                    if isinstance(queries, list):
                        generated_sql_queries.extend(queries)
                    else:
                        query = str(queries)
                        generated_sql_queries.append(query)

            elif name == DATASETS_USED:
                datasets = args.get("datasets", [])
                if datasets:
                    if isinstance(datasets, list):
                        selected_datasets.extend(datasets)
                    else:
                        selected_datasets.append(str(datasets))

            elif name == INTERMEDIATE_MESSAGES:
                content = args.get("content", "")
                category = args.get("category", "")
                if args.get("role") == "intermediate":
                    tool_messages.append(content)

                    if category == DATASETS_USED or (content and "dataset" in content.lower()):
                        if "datasets" in args:
                            datasets = args.get("datasets")
                            if isinstance(datasets, list):
                                selected_datasets.extend(datasets)
                            else:
                                selected_datasets.append(str(datasets))
                        elif ":" in content:
                            if "using dataset" in content.lower() or "selected dataset" in content.lower():
                                datasets_part = content.split(":", 1)[1].strip()
                                selected_datasets.append(datasets_part)

                    if category == SQL_QUERIES_GENERATED or (content and (("sql" in content.lower()) or ("query" in content.lower()))):
                        if "query" in args:
                            query = args.get("query")
                            generated_sql_queries.append(query)
                        elif "queries" in args:
                            queries = args.get("queries")
                            if isinstance(queries, list):
                                generated_sql_queries.extend(queries)
                            else:
                                query = str(queries)
                                generated_sql_queries.append(query)
                        elif "SELECT" in content.upper() or "FROM" in content.upper() or "WHERE" in content.upper():
                            generated_sql_queries.append(content)

        except json.JSONDecodeError:
            continue

    return tool_messages, generated_sql_queries, selected_datasets

async def process_test_cases(test_cases):
    for index, query in enumerate(test_cases):
        query_copy = query.copy()
        expected_result = query_copy.get("expected_result", {})
        if "expected_result" in query_copy:
            del query_copy["expected_result"]

        print(f"{index+1}/{len(test_cases)}: Processing query: {query['messages'][0]['content']}")

        try:
            response = requests.post(
                url,
                json={
                    **query_copy,
                    "chat_id": "",
                    "trace_id": "",
                },
                headers={
                    "Content-Type": "application/json",
                    "Accept": "text/event-stream"
                },
                stream=True
            )
            response.raise_for_status()

            final_response = ""
            tool_messages = []
            generated_sql_queries = []
            selected_datasets = []

            for line in response.iter_lines():
                if not line:
                    continue

                decoded_line = line.decode('utf-8').strip()
                if not decoded_line.startswith("data:"):
                    continue

                if decoded_line == "data: [DONE]":
                    break

                try:
                    chunk_data = json.loads(decoded_line[len("data: "):])

                    if "choices" in chunk_data and chunk_data["choices"]:
                        delta = chunk_data["choices"][0].get("delta", {})

                        if "content" in delta and delta["content"] is not None:
                            final_response += delta["content"] or ""

                        tool_calls = delta.get("tool_calls")
                        if tool_calls is not None:
                            msgs, queries, datasets = process_tool_calls(tool_calls)
                            tool_messages.extend(msgs)
                            generated_sql_queries.extend(queries)
                            selected_datasets.extend(datasets)
                except json.JSONDecodeError:
                    continue

            final_response += f"\n     used_datasets: {selected_datasets}"
            final_response += f"\n     generated_sql queries: {generated_sql_queries}"
            final_response += f"\n     tool_messages: {tool_messages}"

            print(f"    Final AI Response: {final_response}")

            result = await chain.ainvoke({
                "generated_answer": final_response,
                "expected_result": expected_result
            })

            if result["correct"] == "true":
                print("    ✅ Query passed")
                print(f"    Used dataset: {', '.join(selected_datasets) if selected_datasets else 'None detected'}")
                print(f"    SQL queries: {len(generated_sql_queries)}")
            elif result["correct"] == "false":
                print("    ❌ Query failed")
                print(f"    Reasoning: {result.get('reasoning', 'No reasoning provided')}")
                print(f"    Used dataset: {', '.join(selected_datasets) if selected_datasets else 'None detected'}")
                print(f"    SQL queries: {len(generated_sql_queries)}")
            elif result["correct"] == "partial":
                print("    🟡 Query partially correct")
                print(f"    Reasoning: {result.get('reasoning', 'No reasoning provided')}")
                print(f"    Used dataset: {', '.join(selected_datasets) if selected_datasets else 'None detected'}")
                print(f"    SQL queries: {len(generated_sql_queries)}")

            if 'dataset_identified' in expected_result:
                print(f"    Expected dataset: {expected_result['dataset_identified']}")
            else:
                print("    Expected dataset: Not applicable for single dataset case")
            print(f"    Expected SQL count: {expected_result.get('sql_query_count', 'Not specified')}")
            if expected_result.get('visualization_needed'):
                print(f"    Expected visualization: {expected_result.get('visualization_type', 'Not specified')}")

            print("")
        except Exception as e:
            print(f"API request failed: {str(e)}")
            print(traceback.format_exc())
            print(f"For query: {query['messages'][0]['content']}")

    print("\n--------------------------------\n\n")

await process_test_cases(TEST_CASES)


1/3: Processing query: First find the top 5 companies with highest average net profit in 2022-23, then analyze their CSR spending trends from 2016 to 2023
    Final AI Response: I was unable to retrieve the CSR spending trends for the top 5 companies with the highest average net profit in 2022-23.

To get the information you're looking for, please try breaking down your request into two steps:

1.  First, ask: "What are the top 5 companies with the highest average net profit in 2022-23?"
2.  Once you have that list, you can then ask for the CSR spending trends for those specific companies, for example: "What was the CSR spending trend for [Company A], [Company B], [Company C], [Company D], and [Company E] from 2016 to 2023?"

This two-step approach should help in fetching the required details.
     used_datasets: ['9975bec1-1133-4b94-8ab2-179bba9ba887', '9975bec1-1133-4b94-8ab2-179bba9ba887']
     generated_sql queries: ["WITH top_companies AS (\n  SELECT company_name\n  FROM gp_JrlIdI