From 7b066323ce7e0cc6b7270c6187647b41e5f13689 Mon Sep 17 00:00:00 2001 From: cosmin chauciuc Date: Fri, 5 Jun 2026 15:46:41 +0300 Subject: [PATCH] Add conversational Assistant panel to /query MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A chat panel where editors ask data questions and grow the semantic layer in plain language. One stateless POST /query/assistant turn runs a single structured-JSON LLM call that classifies intent and extracts a draft; the assistant only drafts — every write rides the existing REST + auto-embed endpoints, so authorization is unchanged. Intents: question -> sql_preview (reuses generate_sql_only), and editor-gated glossary/metric/dictionary/knowledge drafts. Dictionary resolves (table_name, column_name) -> column_id server-side and supports multiple value mappings per column. Viewers get questions only. Frontend: AssistantPanel renders per-intent confirmation cards that POST to the existing create endpoints and invalidate the relevant query caches; QueryResultView extracted from QueryPage for reuse. MCP stays external-only. Tests: 33 unit tests (agent normalizers + downgrade, service branching incl. dictionary column resolution). Full backend suite green (131); frontend tsc + eslint + build clean. Co-Authored-By: Claude Opus 4.8 (1M context) --- backend/app/api/v1/endpoints/assistant.py | 31 + backend/app/api/v1/router.py | 2 + backend/app/api/v1/schemas/assistant.py | 111 ++++ backend/app/llm/agents/assistant_router.py | 175 +++++ backend/app/llm/prompts/assistant_prompts.py | 72 +++ backend/app/services/assistant_service.py | 122 ++++ backend/tests/test_assistant_agent.py | 173 +++++ backend/tests/test_assistant_service.py | 185 ++++++ docs/assistant-plan.md | 189 ++++++ frontend/src/api/assistantApi.ts | 10 + .../components/assistant/AssistantPanel.tsx | 611 ++++++++++++++++++ .../src/components/query/QueryResultView.tsx | 125 ++++ frontend/src/pages/QueryPage.tsx | 125 +--- frontend/src/types/api.ts | 62 ++ 14 files changed, 1874 insertions(+), 119 deletions(-) create mode 100644 backend/app/api/v1/endpoints/assistant.py create mode 100644 backend/app/api/v1/schemas/assistant.py create mode 100644 backend/app/llm/agents/assistant_router.py create mode 100644 backend/app/llm/prompts/assistant_prompts.py create mode 100644 backend/app/services/assistant_service.py create mode 100644 backend/tests/test_assistant_agent.py create mode 100644 backend/tests/test_assistant_service.py create mode 100644 docs/assistant-plan.md create mode 100644 frontend/src/api/assistantApi.ts create mode 100644 frontend/src/components/assistant/AssistantPanel.tsx create mode 100644 frontend/src/components/query/QueryResultView.tsx diff --git a/backend/app/api/v1/endpoints/assistant.py b/backend/app/api/v1/endpoints/assistant.py new file mode 100644 index 0000000..2542c5b --- /dev/null +++ b/backend/app/api/v1/endpoints/assistant.py @@ -0,0 +1,31 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.schemas.assistant import AssistantRequest, AssistantResponse +from app.core.auth import AuthContext, get_org_context +from app.db.session import get_db +from app.services import assistant_service + +# Mounted under ``/query`` so it inherits the query rate limiter, and uses +# ``get_org_context`` like the other LLM query endpoints — the service resolves +# and authorizes the connection from the body. +router = APIRouter(prefix="/query", tags=["assistant"]) + + +@router.post("/assistant", response_model=AssistantResponse) +async def assistant_turn( + body: AssistantRequest, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + """One conversational turn: classify the message and return ``{message, action?}``. + + ``action`` is either a ``sql_preview`` (confirm → run via /query/execute-sql) or a + ``glossary_draft`` (confirm → create via /connections/{id}/glossary). The assistant + itself never writes. + """ + history = [m.model_dump() for m in body.history] + result = await assistant_service.handle_turn( + db, body.connection_id, body.message, history, ctx + ) + return AssistantResponse(**result) diff --git a/backend/app/api/v1/router.py b/backend/app/api/v1/router.py index f237687..0a1e572 100644 --- a/backend/app/api/v1/router.py +++ b/backend/app/api/v1/router.py @@ -2,6 +2,7 @@ from app.api.v1.endpoints import ( api_keys, + assistant, auth, connections, dictionary, @@ -23,6 +24,7 @@ api_router.include_router(teams.router) api_router.include_router(api_keys.router) api_router.include_router(query.router) +api_router.include_router(assistant.router) api_router.include_router(connections.router) api_router.include_router(schemas.router) api_router.include_router(glossary.router) diff --git a/backend/app/api/v1/schemas/assistant.py b/backend/app/api/v1/schemas/assistant.py new file mode 100644 index 0000000..1b438b0 --- /dev/null +++ b/backend/app/api/v1/schemas/assistant.py @@ -0,0 +1,111 @@ +from typing import Annotated, Literal +from uuid import UUID + +from pydantic import BaseModel, Field + + +class AssistantMessage(BaseModel): + role: Literal["user", "assistant"] + content: str = Field(min_length=1, max_length=4000) + + +class AssistantRequest(BaseModel): + connection_id: UUID + message: str = Field(min_length=1, max_length=2000) + history: list[AssistantMessage] = Field(default_factory=list, max_length=20) + + +# --- Draft payloads ------------------------------------------------------- +# Each mirrors the corresponding *Create schema (a usable subset) so the +# frontend can POST it to the existing REST endpoint after confirmation. + + +class GlossaryDraft(BaseModel): + term: str + definition: str + sql_expression: str = "" + related_tables: list[str] = Field(default_factory=list) + related_columns: list[str] = Field(default_factory=list) + + +class MetricDraft(BaseModel): + metric_name: str + display_name: str + description: str = "" + sql_expression: str = "" + aggregation_type: str = "" + related_tables: list[str] = Field(default_factory=list) + dimensions: list[str] = Field(default_factory=list) + + +class DictionaryEntryDraft(BaseModel): + raw_value: str + display_value: str + description: str = "" + + +class DictionaryDraft(BaseModel): + """Coded-value mappings for one column. + + ``column_id`` is resolved server-side from ``table_name``/``column_name``; + the frontend POSTs each entry to ``/columns/{column_id}/dictionary``. + """ + + column_id: UUID + table_name: str + column_name: str + entries: list[DictionaryEntryDraft] + + +class KnowledgeDraft(BaseModel): + title: str + content: str + source_url: str = "" + + +class SqlPreviewPayload(BaseModel): + sql: str + explanation: str = "" + + +# --- Action discriminated union ------------------------------------------ + + +class GlossaryDraftAction(BaseModel): + type: Literal["glossary_draft"] = "glossary_draft" + payload: GlossaryDraft + + +class MetricDraftAction(BaseModel): + type: Literal["metric_draft"] = "metric_draft" + payload: MetricDraft + + +class DictionaryDraftAction(BaseModel): + type: Literal["dictionary_draft"] = "dictionary_draft" + payload: DictionaryDraft + + +class KnowledgeDraftAction(BaseModel): + type: Literal["knowledge_draft"] = "knowledge_draft" + payload: KnowledgeDraft + + +class SqlPreviewAction(BaseModel): + type: Literal["sql_preview"] = "sql_preview" + payload: SqlPreviewPayload + + +AssistantAction = Annotated[ + GlossaryDraftAction + | MetricDraftAction + | DictionaryDraftAction + | KnowledgeDraftAction + | SqlPreviewAction, + Field(discriminator="type"), +] + + +class AssistantResponse(BaseModel): + message: str + action: AssistantAction | None = None diff --git a/backend/app/llm/agents/assistant_router.py b/backend/app/llm/agents/assistant_router.py new file mode 100644 index 0000000..b5a92f7 --- /dev/null +++ b/backend/app/llm/agents/assistant_router.py @@ -0,0 +1,175 @@ +"""Agent: Assistant Router — classifies a chat message and extracts drafts. + +One structured-JSON LLM call. Mirrors :class:`QueryComposerAgent`: build messages, +call ``provider.complete``, parse with ``repair_json`` and degrade gracefully to a +plain chat reply when the model returns non-JSON or an unusable draft. +""" + +import json +from dataclasses import dataclass + +from app.llm.base_provider import BaseLLMProvider, LLMConfig, LLMMessage +from app.llm.prompts.assistant_prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE +from app.llm.utils import repair_json + +# Intents that carry a structured draft payload. +DRAFT_INTENTS = {"glossary", "metric", "dictionary", "knowledge"} +VALID_INTENTS = {"question", "chat"} | DRAFT_INTENTS + + +@dataclass +class AssistantDecision: + intent: str + message: str + payload: dict | None = None + + +class AssistantAgent: + def __init__(self, provider: BaseLLMProvider, config: LLMConfig): + self.provider = provider + self.config = config + + async def route( + self, + message: str, + context: str, + history: list[dict] | None = None, + ) -> AssistantDecision: + """Classify ``message`` and extract a draft payload when applicable.""" + user_prompt = USER_PROMPT_TEMPLATE.format( + context=context, + history=_format_history(history), + message=message, + ) + messages = [ + LLMMessage(role="system", content=SYSTEM_PROMPT), + LLMMessage(role="user", content=user_prompt), + ] + + response = await self.provider.complete(messages, self.config) + + try: + parsed = json.loads(repair_json(response.content)) + except (json.JSONDecodeError, TypeError): + # Non-JSON response — treat the raw text as a plain chat reply. + return AssistantDecision(intent="chat", message=response.content.strip()) + + intent = parsed.get("intent") + if intent not in VALID_INTENTS: + intent = "chat" + + message_text = (parsed.get("message") or "").strip() + payload = None + + if intent in DRAFT_INTENTS: + payload = _NORMALIZERS[intent](parsed.get("payload")) + if payload is None: + # Model claimed a draft intent but gave no usable payload. + intent = "chat" + + return AssistantDecision(intent=intent, message=message_text, payload=payload) + + +def _format_history(history: list[dict] | None) -> str: + if not history: + return "(no prior messages)" + lines = [] + for turn in history: + role = turn.get("role", "user") + content = (turn.get("content") or "").strip() + if content: + lines.append(f"{role}: {content}") + return "\n".join(lines) if lines else "(no prior messages)" + + +def _str(value) -> str: + return str(value).strip() if value is not None else "" + + +def _as_list(value) -> list[str]: + if isinstance(value, list): + return [str(v).strip() for v in value if str(v).strip()] + if isinstance(value, str) and value.strip(): + return [value.strip()] + return [] + + +def _normalize_glossary(raw) -> dict | None: + if not isinstance(raw, dict): + return None + term, definition = _str(raw.get("term")), _str(raw.get("definition")) + if not term or not definition: + return None + return { + "term": term, + "definition": definition, + "sql_expression": _str(raw.get("sql_expression")), + "related_tables": _as_list(raw.get("related_tables")), + "related_columns": _as_list(raw.get("related_columns")), + } + + +def _normalize_metric(raw) -> dict | None: + if not isinstance(raw, dict): + return None + display_name = _str(raw.get("display_name")) or _str(raw.get("metric_name")) + sql_expression = _str(raw.get("sql_expression")) + if not display_name or not sql_expression: + return None + metric_name = _str(raw.get("metric_name")) or _slug(display_name) + return { + "metric_name": metric_name, + "display_name": display_name, + "description": _str(raw.get("description")), + "sql_expression": sql_expression, + "aggregation_type": _str(raw.get("aggregation_type")), + "related_tables": _as_list(raw.get("related_tables")), + "dimensions": _as_list(raw.get("dimensions")), + } + + +def _normalize_dictionary(raw) -> dict | None: + if not isinstance(raw, dict): + return None + table_name, column_name = _str(raw.get("table_name")), _str(raw.get("column_name")) + if not table_name or not column_name: + return None + entries = [] + for item in raw.get("entries") or []: + if not isinstance(item, dict): + continue + raw_value, display_value = _str(item.get("raw_value")), _str(item.get("display_value")) + if not raw_value or not display_value: + continue + entries.append( + { + "raw_value": raw_value, + "display_value": display_value, + "description": _str(item.get("description")), + } + ) + if not entries: + return None + return {"table_name": table_name, "column_name": column_name, "entries": entries} + + +def _normalize_knowledge(raw) -> dict | None: + if not isinstance(raw, dict): + return None + title, content = _str(raw.get("title")), _str(raw.get("content")) + if not title or not content: + return None + return {"title": title, "content": content, "source_url": _str(raw.get("source_url"))} + + +_NORMALIZERS = { + "glossary": _normalize_glossary, + "metric": _normalize_metric, + "dictionary": _normalize_dictionary, + "knowledge": _normalize_knowledge, +} + + +def _slug(text: str) -> str: + cleaned = "".join(c if c.isalnum() else "_" for c in text.lower()) + return "_".join(part for part in cleaned.split("_") if part) or "metric" diff --git a/backend/app/llm/prompts/assistant_prompts.py b/backend/app/llm/prompts/assistant_prompts.py new file mode 100644 index 0000000..9bdead9 --- /dev/null +++ b/backend/app/llm/prompts/assistant_prompts.py @@ -0,0 +1,72 @@ +"""Prompts for the conversational Assistant router. + +A single structured-JSON call classifies the user's latest message into an intent +and, for semantic-layer additions, extracts a structured draft grounded in the +connection's real schema + semantic context. +""" + +SYSTEM_PROMPT = """You are the QueryWise Assistant, embedded in a text-to-SQL app with a business \ +semantic layer. The user is talking to you in a chat box scoped to one database connection. + +Classify the user's LATEST message into exactly one intent and respond with JSON containing +"intent", "message", and "payload" (the intent-specific object, or null). + +Intents and their payloads: + +1. "question" — the user is asking a data question that should be answered with a SQL query + (e.g. "what is total ECL by stage?"). Do NOT write SQL yourself; the app generates it. + payload: null. "message": a short friendly lead-in (e.g. "Here's a query for that:"). + +2. "glossary" — define/add a business glossary term (e.g. "NPL means loans where stage = 3"). + payload: { + "term": "...", "definition": "...", + "sql_expression": "SQL condition using REAL column names, e.g. stage = 3", + "related_tables": ["..."], "related_columns": ["..."] + } + +3. "metric" — define/add a reusable metric (an aggregate measure, e.g. + "ECL coverage ratio = sum(ecl)/sum(exposure)"). + payload: { + "metric_name": "machine_friendly_snake_case", + "display_name": "Human Friendly Name", + "description": "what it measures", + "sql_expression": "the SQL aggregate expression using REAL columns", + "aggregation_type": "SUM | AVG | COUNT | RATIO | ... or empty", + "related_tables": ["..."], "dimensions": ["columns it can be grouped by"] + } + +4. "dictionary" — explain coded/enumerated column VALUES (e.g. + "in the stage column, 1 = performing, 2 = underperforming, 3 = non-performing"). + payload: { + "table_name": "real table name", "column_name": "real column name", + "entries": [ {"raw_value": "1", "display_value": "Performing", "description": ""}, ... ] + } + +5. "knowledge" — add a knowledge/document snippet to the context (policy text, notes, + definitions in prose). e.g. "remember this about IFRS 9: ...". + payload: { "title": "short title", "content": "the document body", "source_url": "" } + +6. "chat" — anything else: greetings, capability questions, clarifications, or requests you + cannot fulfil. payload: null. Put your full natural-language reply in "message". + +Rules: +- Only ground table/column names to those that appear in the provided context. +- For "metric" set a sensible machine-friendly metric_name if the user didn't give one. +- For "dictionary" put EVERY value the user mentioned as a separate entry. +- For draft intents, set "message" to a short confirmation lead-in + (e.g. "Here's a draft — review and confirm:"). Set "payload" to null for question/chat. + +Respond with ONLY a JSON object: +{ "intent": "...", "message": "...", "payload": { ... } | null }""" + +USER_PROMPT_TEMPLATE = """Connection semantic context (real schema, glossary, metrics): + +{context} + +Recent conversation (oldest first): +{history} + +User's latest message: +"{message}" + +Respond with a JSON object: intent, message, payload (or null).""" diff --git a/backend/app/services/assistant_service.py b/backend/app/services/assistant_service.py new file mode 100644 index 0000000..393eabd --- /dev/null +++ b/backend/app/services/assistant_service.py @@ -0,0 +1,122 @@ +"""Assistant Service — orchestrates one conversational turn. + +Stateless: the caller passes recent history each turn. The assistant *never* +writes — it only classifies intent and drafts. Writes (creating a glossary term, +metric, dictionary entry, or knowledge document) ride the existing REST endpoints, +so authorization + embedding stay on their established paths. + +Returns a dict shaped as ``{message, action?}`` where ``action`` is a +discriminated union (``sql_preview`` | ``glossary_draft`` | ``metric_draft`` | +``dictionary_draft`` | ``knowledge_draft``). +""" + +import uuid + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.auth import AuthContext +from app.core.telemetry import start_span +from app.db.models.membership import ROLE_EDITOR +from app.db.models.schema_cache import CachedColumn, CachedTable +from app.llm.agents.assistant_router import AssistantAgent +from app.llm.router import route +from app.semantic.context_builder import build_context +from app.services.connection_service import get_connection +from app.services.query_service import generate_sql_only + +_NO_EDITOR = ( + "Adding to the semantic layer needs editor access on this workspace. " + "I can still help you query the data." +) + + +async def handle_turn( + db: AsyncSession, + connection_id: uuid.UUID, + message: str, + history: list[dict], + ctx: AuthContext, +) -> dict: + """Classify the user's message and return ``{message, action?}``. + + - ``question`` → generate SQL (reuses ``generate_sql_only``) → ``sql_preview`` + - ``glossary`` / ``metric`` / ``dictionary`` / ``knowledge`` → draft action (editors only) + - ``chat`` → plain message, no action + """ + # Read-authz + dialect. Raises 404/403 if the caller can't read the connection. + conn = await get_connection(db, connection_id, ctx) + + with start_span("assistant_build_context", **{"connection_id": str(connection_id)}): + context = await build_context(db, connection_id, message, dialect=conn.connector_type) + + provider, llm_config = route(message) + agent = AssistantAgent(provider, llm_config) + with start_span("assistant_route", **{"llm.model": llm_config.model}): + decision = await agent.route(message, context.prompt_context, history) + + if decision.intent == "question": + sql = await generate_sql_only(db, connection_id, message, ctx) + if not sql.get("generated_sql"): + return {"message": "I couldn't generate a query for that. Try rephrasing?"} + return _msg( + decision.message or "Here's a query for that:", + "sql_preview", + {"sql": sql["generated_sql"], "explanation": sql.get("explanation", "")}, + ) + + if decision.intent in {"glossary", "metric", "knowledge"}: + if not ctx.has_role(ROLE_EDITOR): + return {"message": _NO_EDITOR} + action_type = f"{decision.intent}_draft" + return _msg(_confirm(decision.message), action_type, decision.payload) + + if decision.intent == "dictionary": + if not ctx.has_role(ROLE_EDITOR): + return {"message": _NO_EDITOR} + payload = decision.payload + column_id = await _resolve_column_id( + db, connection_id, payload["table_name"], payload["column_name"] + ) + if column_id is None: + return { + "message": ( + f"I couldn't find a column named '{payload['column_name']}' on table " + f"'{payload['table_name']}' in this connection. " + "Has the schema been introspected?" + ) + } + return _msg( + _confirm(decision.message), + "dictionary_draft", + {**payload, "column_id": str(column_id)}, + ) + + return {"message": decision.message or "How can I help with your data?"} + + +def _msg(message: str, action_type: str, payload: dict) -> dict: + return {"message": message, "action": {"type": action_type, "payload": payload}} + + +def _confirm(message: str) -> str: + return message or "Here's a draft — review and confirm:" + + +async def _resolve_column_id( + db: AsyncSession, + connection_id: uuid.UUID, + table_name: str, + column_name: str, +) -> uuid.UUID | None: + """Find a cached column id by (connection, table_name, column_name), case-insensitive.""" + result = await db.execute( + select(CachedColumn.id) + .join(CachedTable, CachedColumn.table_id == CachedTable.id) + .where( + CachedTable.connection_id == connection_id, + func.lower(CachedTable.table_name) == table_name.lower(), + func.lower(CachedColumn.column_name) == column_name.lower(), + ) + ) + return result.scalar_one_or_none() diff --git a/backend/tests/test_assistant_agent.py b/backend/tests/test_assistant_agent.py new file mode 100644 index 0000000..2739f09 --- /dev/null +++ b/backend/tests/test_assistant_agent.py @@ -0,0 +1,173 @@ +"""Unit tests for the Assistant router agent (no DB/LLM needed).""" + +from types import SimpleNamespace + +import pytest + +from app.llm.agents.assistant_router import ( + AssistantAgent, + _normalize_dictionary, + _normalize_glossary, + _normalize_metric, + _slug, +) +from app.llm.base_provider import LLMConfig + + +class _FakeProvider: + """Returns a canned completion regardless of input.""" + + def __init__(self, content: str): + self._content = content + + async def complete(self, messages, config): + return SimpleNamespace(content=self._content) + + +def _agent(content: str) -> AssistantAgent: + return AssistantAgent(_FakeProvider(content), LLMConfig(model="fake")) + + +async def test_question_intent(): + agent = _agent('{"intent": "question", "message": "Here you go:", "payload": null}') + decision = await agent.route("what is total ECL by stage?", context="ctx") + assert decision.intent == "question" + assert decision.message == "Here you go:" + assert decision.payload is None + + +async def test_glossary_intent_extracts_draft(): + agent = _agent( + '{"intent": "glossary", "message": "Draft:", "payload": ' + '{"term": "NPL", "definition": "non-performing loans", ' + '"sql_expression": "stage = 3", "related_tables": ["exposures"], ' + '"related_columns": ["stage"]}}' + ) + decision = await agent.route("add term NPL = stage 3", context="ctx") + assert decision.intent == "glossary" + assert decision.payload["term"] == "NPL" + assert decision.payload["sql_expression"] == "stage = 3" + assert decision.payload["related_tables"] == ["exposures"] + + +async def test_metric_intent_extracts_draft(): + agent = _agent( + '{"intent": "metric", "message": "Draft:", "payload": ' + '{"display_name": "ECL Coverage Ratio", "sql_expression": "SUM(ecl)/SUM(exposure)", ' + '"aggregation_type": "RATIO", "related_tables": ["ecl_provisions"], "dimensions": ["stage"]}}' + ) + decision = await agent.route("define ECL coverage ratio", context="ctx") + assert decision.intent == "metric" + assert decision.payload["display_name"] == "ECL Coverage Ratio" + # metric_name is derived from display_name when omitted + assert decision.payload["metric_name"] == "ecl_coverage_ratio" + assert decision.payload["dimensions"] == ["stage"] + + +async def test_dictionary_intent_extracts_entries(): + agent = _agent( + '{"intent": "dictionary", "message": "Draft:", "payload": ' + '{"table_name": "exposures", "column_name": "stage", "entries": [' + '{"raw_value": "1", "display_value": "Performing"},' + '{"raw_value": "3", "display_value": "Non-performing", "description": "default"}]}}' + ) + decision = await agent.route("stage 1 = performing, 3 = non-performing", context="ctx") + assert decision.intent == "dictionary" + assert decision.payload["column_name"] == "stage" + assert len(decision.payload["entries"]) == 2 + assert decision.payload["entries"][1]["display_value"] == "Non-performing" + + +async def test_knowledge_intent_extracts_draft(): + agent = _agent( + '{"intent": "knowledge", "message": "Draft:", "payload": ' + '{"title": "IFRS 9 staging", "content": "Assets move through three stages..."}}' + ) + decision = await agent.route("remember this about IFRS 9...", context="ctx") + assert decision.intent == "knowledge" + assert decision.payload["title"] == "IFRS 9 staging" + assert decision.payload["source_url"] == "" + + +async def test_chat_intent(): + agent = _agent('{"intent": "chat", "message": "Hi! Ask me about your data.", "payload": null}') + decision = await agent.route("hello", context="ctx") + assert decision.intent == "chat" + assert "Ask me" in decision.message + + +async def test_malformed_json_degrades_to_chat(): + agent = _agent("Sorry, I'm just going to talk normally here.") + decision = await agent.route("hello", context="ctx") + assert decision.intent == "chat" + assert decision.message == "Sorry, I'm just going to talk normally here." + + +async def test_unknown_intent_falls_back_to_chat(): + agent = _agent('{"intent": "delete_everything", "message": "no", "payload": null}') + decision = await agent.route("x", context="ctx") + assert decision.intent == "chat" + + +@pytest.mark.parametrize( + "intent,payload", + [ + ("glossary", '{"definition": "x"}'), # missing term + ("metric", '{"display_name": "X"}'), # missing sql_expression + ("dictionary", '{"table_name": "t", "column_name": "c", "entries": []}'), # no entries + ("knowledge", '{"title": "t"}'), # missing content + ], +) +async def test_draft_intent_without_usable_payload_downgrades(intent, payload): + agent = _agent(f'{{"intent": "{intent}", "message": "hm", "payload": {payload}}}') + decision = await agent.route("x", context="ctx") + assert decision.intent == "chat" + assert decision.payload is None + + +@pytest.mark.parametrize( + "raw,expected_tables", + [ + ({"term": "t", "definition": "d", "related_tables": "exposures"}, ["exposures"]), + ({"term": "t", "definition": "d", "related_tables": ["a", "", "b"]}, ["a", "b"]), + ({"term": "t", "definition": "d"}, []), + ], +) +def test_normalize_glossary_coerces_lists(raw, expected_tables): + assert _normalize_glossary(raw)["related_tables"] == expected_tables + + +def test_normalize_glossary_requires_term_and_definition(): + assert _normalize_glossary({"term": "", "definition": "d"}) is None + assert _normalize_glossary({"term": "t", "definition": ""}) is None + assert _normalize_glossary("not a dict") is None + + +def test_normalize_metric_derives_name_and_requires_sql(): + assert _normalize_metric({"display_name": "My Ratio", "sql_expression": "a/b"})[ + "metric_name" + ] == "my_ratio" + assert _normalize_metric({"display_name": "X"}) is None # no sql + assert _normalize_metric({"sql_expression": "a"}) is None # no name + + +def test_normalize_dictionary_drops_incomplete_entries(): + result = _normalize_dictionary( + { + "table_name": "t", + "column_name": "c", + "entries": [ + {"raw_value": "1", "display_value": "One"}, + {"raw_value": "2"}, # missing display_value — dropped + ], + } + ) + assert len(result["entries"]) == 1 + + +@pytest.mark.parametrize( + "text,expected", + [("ECL Coverage Ratio", "ecl_coverage_ratio"), (" ", "metric"), ("A/B %", "a_b")], +) +def test_slug(text, expected): + assert _slug(text) == expected diff --git a/backend/tests/test_assistant_service.py b/backend/tests/test_assistant_service.py new file mode 100644 index 0000000..428778e --- /dev/null +++ b/backend/tests/test_assistant_service.py @@ -0,0 +1,185 @@ +"""Unit tests for assistant_service.handle_turn branching (no DB/LLM needed). + +The service's collaborators (connection lookup, context builder, LLM router, +agent, SQL generator, column resolver) are monkeypatched, so we assert only the +orchestration: intent → action shape, the editor-only gate, and dictionary's +column resolution / downgrade. +""" + +import uuid +from types import SimpleNamespace + +import pytest + +from app.db.models.membership import ROLE_EDITOR, ROLE_VIEWER +from app.llm.agents.assistant_router import AssistantDecision +from app.services import assistant_service + + +def _ctx(role: str): + rank = {ROLE_VIEWER: 1, ROLE_EDITOR: 2, "admin": 3} + return SimpleNamespace( + role=role, + has_role=lambda minimum, _r=role: rank[_r] >= rank[minimum], + ) + + +@pytest.fixture(autouse=True) +def _patch_collaborators(monkeypatch): + async def fake_get_connection(db, connection_id, ctx, write=False): + return SimpleNamespace(connector_type="postgresql") + + async def fake_build_context(db, connection_id, message, dialect): + return SimpleNamespace(prompt_context="CTX") + + def fake_route(message): + return (object(), SimpleNamespace(model="fake")) + + monkeypatch.setattr(assistant_service, "get_connection", fake_get_connection) + monkeypatch.setattr(assistant_service, "build_context", fake_build_context) + monkeypatch.setattr(assistant_service, "route", fake_route) + + +def _patch_decision(monkeypatch, decision: AssistantDecision): + class _FakeAgent: + def __init__(self, *_a, **_k): + pass + + async def route(self, *_a, **_k): + return decision + + monkeypatch.setattr(assistant_service, "AssistantAgent", _FakeAgent) + + +async def _run(role=ROLE_EDITOR, message="x"): + return await assistant_service.handle_turn( + db=None, connection_id=uuid.uuid4(), message=message, history=[], ctx=_ctx(role) + ) + + +# --- question ---------------------------------------------------------------- + + +async def test_question_returns_sql_preview(monkeypatch): + _patch_decision(monkeypatch, AssistantDecision(intent="question", message="Here:")) + + async def fake_sql_only(db, connection_id, message, ctx): + return {"generated_sql": "SELECT 1", "explanation": "ones"} + + monkeypatch.setattr(assistant_service, "generate_sql_only", fake_sql_only) + + result = await _run() + assert result["action"]["type"] == "sql_preview" + assert result["action"]["payload"]["sql"] == "SELECT 1" + + +async def test_question_with_no_sql_returns_plain_message(monkeypatch): + _patch_decision(monkeypatch, AssistantDecision(intent="question", message="Here:")) + + async def fake_sql_only(db, connection_id, message, ctx): + return {"generated_sql": "", "explanation": ""} + + monkeypatch.setattr(assistant_service, "generate_sql_only", fake_sql_only) + + result = await _run() + assert "action" not in result + + +# --- glossary / metric / knowledge (uniform draft branch) -------------------- + + +async def test_glossary_editor_gets_draft(monkeypatch): + draft = {"term": "NPL", "definition": "d", "sql_expression": "stage = 3", + "related_tables": [], "related_columns": []} + _patch_decision(monkeypatch, AssistantDecision("glossary", "Draft:", payload=draft)) + result = await _run() + assert result["action"]["type"] == "glossary_draft" + assert result["action"]["payload"]["term"] == "NPL" + + +async def test_metric_editor_gets_draft(monkeypatch): + draft = {"metric_name": "ecl_ratio", "display_name": "ECL Ratio", "description": "", + "sql_expression": "SUM(a)/SUM(b)", "aggregation_type": "RATIO", + "related_tables": [], "dimensions": []} + _patch_decision(monkeypatch, AssistantDecision("metric", "Draft:", payload=draft)) + result = await _run() + assert result["action"]["type"] == "metric_draft" + assert result["action"]["payload"]["metric_name"] == "ecl_ratio" + + +async def test_knowledge_editor_gets_draft(monkeypatch): + draft = {"title": "IFRS 9", "content": "body", "source_url": ""} + _patch_decision(monkeypatch, AssistantDecision("knowledge", "Draft:", payload=draft)) + result = await _run() + assert result["action"]["type"] == "knowledge_draft" + assert result["action"]["payload"]["title"] == "IFRS 9" + + +@pytest.mark.parametrize("intent", ["glossary", "metric", "knowledge"]) +async def test_draft_viewer_is_downgraded(monkeypatch, intent): + _patch_decision(monkeypatch, AssistantDecision(intent, "Draft:", payload={"any": "thing"})) + result = await _run(role=ROLE_VIEWER) + assert "action" not in result + assert "editor" in result["message"].lower() + + +# --- dictionary (column resolution) ------------------------------------------ + + +async def test_dictionary_resolves_column_and_drafts(monkeypatch): + draft = {"table_name": "exposures", "column_name": "stage", + "entries": [{"raw_value": "1", "display_value": "Performing", "description": ""}]} + _patch_decision(monkeypatch, AssistantDecision("dictionary", "Draft:", payload=draft)) + col_id = uuid.uuid4() + + async def fake_resolve(db, connection_id, table_name, column_name): + assert table_name == "exposures" and column_name == "stage" + return col_id + + monkeypatch.setattr(assistant_service, "_resolve_column_id", fake_resolve) + + result = await _run() + assert result["action"]["type"] == "dictionary_draft" + assert result["action"]["payload"]["column_id"] == str(col_id) + assert len(result["action"]["payload"]["entries"]) == 1 + + +async def test_dictionary_unknown_column_downgrades(monkeypatch): + draft = {"table_name": "nope", "column_name": "missing", "entries": [{"raw_value": "1", + "display_value": "x", "description": ""}]} + _patch_decision(monkeypatch, AssistantDecision("dictionary", "Draft:", payload=draft)) + + async def fake_resolve(db, connection_id, table_name, column_name): + return None + + monkeypatch.setattr(assistant_service, "_resolve_column_id", fake_resolve) + + result = await _run() + assert "action" not in result + assert "couldn't find" in result["message"].lower() + + +async def test_dictionary_viewer_downgraded_before_resolution(monkeypatch): + _patch_decision( + monkeypatch, + AssistantDecision("dictionary", "Draft:", payload={"table_name": "t", + "column_name": "c", "entries": [{"raw_value": "1", + "display_value": "x"}]}), + ) + + async def boom(*_a, **_k): + raise AssertionError("column resolution must not run for viewers") + + monkeypatch.setattr(assistant_service, "_resolve_column_id", boom) + + result = await _run(role=ROLE_VIEWER) + assert "action" not in result + + +# --- chat -------------------------------------------------------------------- + + +async def test_chat_returns_message_only(monkeypatch): + _patch_decision(monkeypatch, AssistantDecision("chat", "Hello there")) + result = await _run(role=ROLE_VIEWER) + assert result == {"message": "Hello there"} diff --git a/docs/assistant-plan.md b/docs/assistant-plan.md new file mode 100644 index 0000000..f2adf7c --- /dev/null +++ b/docs/assistant-plan.md @@ -0,0 +1,189 @@ +# Assistant Panel — Implementation Plan + +A conversational "Assistant" on the `/query` page. Users ask questions (→ existing +NL-to-SQL pipeline) and add glossary terms by talking (→ structured draft → +confirmation card → existing glossary REST POST). REST inside the app; MCP stays +external-only. + +## Principles + +- **One stateless endpoint**, one structured-JSON LLM call per turn. No tool-calling + loop, no provider changes, no chat-session table (history passed from the client). +- **The assistant never writes.** It only *drafts*. Every write rides the existing + `POST /connections/{id}/glossary` (auth + auto-embed unchanged). So the assistant + endpoint needs only **read** auth. +- **Discriminated action union** from day one (`glossary_draft | sql_preview`), so + `metric_draft` etc. are new cards later, not a refactor. +- **Reuse, don't re-derive:** `sql_preview` → existing `generate_sql_only`; + executing it → existing `execute_raw_sql` (`/query/execute-sql`). + +## Flow + +``` +user message ──► POST /connections/{id}/assistant (require_connection_read) + │ + ├─ build_context(db, id, message) # existing + ├─ AssistantAgent.route(message, ctx, history) # 1 LLM call + │ └─► {intent: question|glossary|chat, message, glossary?} + │ + ├─ intent=question → generate_sql_only() # existing, 2nd LLM call + │ └─► action = sql_preview {sql, explanation} + ├─ intent=glossary → action = glossary_draft {term,...} + │ └─ (if caller lacks write → downgrade to chat message) + └─ intent=chat → no action + ▼ + { message: string, action?: {type, payload} } + │ + frontend renders a card: + glossary_draft → editable card → [Create] → glossaryApi.create() → invalidate glossary + sql_preview → SQL card → [Execute] → queryApi.executeSql() → results inline +``` + +## Backend + +### 1. Prompt — `app/llm/prompts/assistant_prompts.py` (new) +System + user template instructing the model to return JSON: +```json +{ + "intent": "question | glossary | chat", + "message": "natural-language reply to show the user", + "glossary": { + "term": "...", "definition": "...", "sql_expression": "...", + "related_tables": ["..."], "related_columns": ["..."] + } +} +``` +- `glossary` present only when `intent == "glossary"`. +- Ground `related_tables`/`related_columns` and the `sql_expression` against the + assembled semantic context (real table/column names), same context string the + composer uses. + +### 2. Agent — `app/llm/agents/assistant_router.py` (new) +Mirror `QueryComposerAgent`: `__init__(provider, config)`, `async route(message, context, history) -> AssistantDecision`. +- `provider.complete(messages, config)` then `json.loads(repair_json(...))` with a + safe fallback (`intent="chat"`, echo message) on `JSONDecodeError`. +- `@dataclass AssistantDecision: intent: str; message: str; glossary: dict | None`. + +### 3. Service — `app/services/assistant_service.py` (new) +```python +async def handle_turn(db, connection_id, message, history, ctx) -> dict: + conn = await get_connection(db, connection_id, ctx) # read-authz + context = await build_context(db, connection_id, message, dialect=conn.connector_type) + provider, llm_config = route(message) + decision = await AssistantAgent(provider, llm_config).route(message, context.prompt_context, history) + + if decision.intent == "question": + sql = await generate_sql_only(db, connection_id, message, ctx) + return {"message": decision.message, + "action": {"type": "sql_preview", + "payload": {"sql": sql["generated_sql"], + "explanation": sql["explanation"]}}} + + if decision.intent == "glossary" and decision.glossary: + if not await _can_write(db, connection_id, ctx): # viewer downgrade + return {"message": "You need editor access to add glossary terms."} + return {"message": decision.message, + "action": {"type": "glossary_draft", "payload": decision.glossary}} + + return {"message": decision.message} # chat +``` +- `_can_write`: `try: await get_connection(db, id, ctx, write=True); return True / except AppError: return False` + (or read `ctx` role rank if exposed on `AuthContext`). + +### 4. Schemas — `app/api/v1/schemas/assistant.py` (new) +- `AssistantMessage{role: Literal["user","assistant"], content: str}` +- `AssistantRequest{message: str, history: list[AssistantMessage] = []}` +- `GlossaryDraft{term, definition, sql_expression, related_tables=[], related_columns=[]}` + (shape matches `GlossaryTermCreate` minus `examples`) +- `SqlPreviewPayload{sql, explanation}` +- `AssistantAction` = discriminated union on `type` (`glossary_draft` | `sql_preview`) +- `AssistantResponse{message: str, action: AssistantAction | None = None}` + +### 5. Endpoint — `app/api/v1/endpoints/assistant.py` (new) — AS BUILT +```python +router = APIRouter(prefix="/query", tags=["assistant"]) + +@router.post("/assistant", response_model=AssistantResponse) +async def assistant_turn(body: AssistantRequest, + ctx=Depends(get_org_context), db=Depends(get_db)): + history = [m.model_dump() for m in body.history] + return AssistantResponse(**await assistant_service.handle_turn( + db, body.connection_id, body.message, history, ctx)) +``` +- **Mounted under `/query`** with `connection_id` in the body, using `get_org_context` — + mirrors the existing query endpoints exactly. The service's `get_connection` enforces + read-authz. Chosen over a `/connections/{id}/assistant` path route because the existing + rate limiter is scoped to `/api/v1/query/*`, so the assistant gets rate-limiting for + free and stays consistent with `/query`, `/query/sql-only`, `/query/execute-sql`. +- Registered in `app/api/v1/router.py` (`assistant.router`). +- Rate limiting: no code change needed — `/api/v1/query/assistant` already matches the + `/query` prefix scope in `install_rate_limiting`. + +### Backend tests — `backend/tests/` +- Agent: feed canned provider responses (question / glossary / chat / malformed JSON) → + assert `AssistantDecision`. +- Service: monkeypatch agent + `build_context`; assert action shapes and viewer downgrade. + +## Frontend + +### 6. Types — `frontend/src/types/api.ts` (extend) +`GlossaryDraft`, `SqlPreviewPayload`, `AssistantAction` (discriminated on `type`), +`AssistantResponse`, `AssistantChatMessage{role, content, action?}`. + +### 7. API client — `frontend/src/api/assistantApi.ts` (new) +```ts +export const assistantApi = { + send: (connectionId: string, data: { message: string; history: {role:string;content:string}[] }) => + api.post(`/connections/${connectionId}/assistant`, data).then(r => r.data), +}; +``` + +### 8. Extract reusable result view +Pull `QueryResultView` out of `QueryPage.tsx` into +`frontend/src/components/query/QueryResultView.tsx` so the assistant's `sql_preview` +execution can render results with the same component. + +### 9. Component — `frontend/src/components/assistant/AssistantPanel.tsx` (new) +Props: `{ connectionId: string }`. +- Local `messages: AssistantChatMessage[]` state; input box; `useMutation` → + `assistantApi.send(connectionId, { message, history: last N })`. +- On response: append assistant message; if `action`, render the matching card: + - **GlossaryDraftCard** — editable fields (term, definition, sql_expression, + related_tables, related_columns). `[Create]` → `glossaryApi.create(connectionId, draft)`, + on success show confirmation + `queryClient.invalidateQueries(['glossary', connectionId])`. + `[Dismiss]` to discard. + - **SqlPreviewCard** — show SQL (Monaco/`Code`). `[Execute]` → + `queryApi.executeSql({connection_id, sql, original_question})` → render `` + in the thread. `[Edit]`/`[Cancel]` like the existing preview. +- Gate by role: hide the panel (or show read-only) for viewers via `useAuth().role` + rank; server already downgrades glossary drafts regardless. + +### 10. Mount on QueryPage — `frontend/src/pages/QueryPage.tsx` +Add `` as a side/below section. Inherits +the already-selected connection — no new connection selector. Defer the app-wide drawer. + +## Implemented draft intents +- `glossary` → `glossary_draft` → `POST /connections/{id}/glossary` +- `metric` → `metric_draft` → `POST /connections/{id}/metrics` +- `dictionary` → `dictionary_draft` → `POST /columns/{column_id}/dictionary` (per entry). + The service resolves `(table_name, column_name) → column_id` against the schema cache + (`_resolve_column_id`) and downgrades to a chat message if the column isn't found. + Drafts carry MULTIPLE value entries for one column. +- `knowledge` → `knowledge_draft` → `POST /connections/{id}/knowledge` +- `question` → `sql_preview` → `POST /query/execute-sql` + +All draft intents are editor-gated (`ctx.has_role(ROLE_EDITOR)`); viewers get questions only. +The agent returns a generic `AssistantDecision{intent, message, payload}`; per-intent +normalizers in `assistant_router.py` validate/shape the payload (and downgrade to chat if +unusable). + +## Out of scope (intentionally deferred) +- Sample-query drafts (same pattern — would be a 6th card). +- Persisted chat sessions (history + audit) and the app-wide drawer. +- True tool-calling agent loop (provider `tools=` support). + +## File summary +**New:** `assistant_prompts.py`, `assistant_router.py`, `assistant_service.py`, +`schemas/assistant.py`, `endpoints/assistant.py`, `assistantApi.ts`, +`components/assistant/AssistantPanel.tsx`, `components/query/QueryResultView.tsx`. +**Edit:** `api/v1/router.py`, rate-limit wiring, `types/api.ts`, `QueryPage.tsx`. diff --git a/frontend/src/api/assistantApi.ts b/frontend/src/api/assistantApi.ts new file mode 100644 index 0000000..115b4a5 --- /dev/null +++ b/frontend/src/api/assistantApi.ts @@ -0,0 +1,10 @@ +import { api } from './client'; +import type { AssistantResponse } from '../types/api'; + +export const assistantApi = { + send: (data: { + connection_id: string; + message: string; + history: { role: 'user' | 'assistant'; content: string }[]; + }) => api.post('/query/assistant', data).then((r) => r.data), +}; diff --git a/frontend/src/components/assistant/AssistantPanel.tsx b/frontend/src/components/assistant/AssistantPanel.tsx new file mode 100644 index 0000000..826b9c1 --- /dev/null +++ b/frontend/src/components/assistant/AssistantPanel.tsx @@ -0,0 +1,611 @@ +import { useRef, useState } from 'react'; +import { + Paper, + Stack, + Group, + Textarea, + Button, + Text, + TextInput, + TagsInput, + Alert, + Loader, + Badge, + ActionIcon, + ScrollArea, +} from '@mantine/core'; +import { IconSend, IconCheck, IconPlayerPlay, IconX, IconSparkles } from '@tabler/icons-react'; +import { useMutation, useQueryClient } from '@tanstack/react-query'; +import { assistantApi } from '../../api/assistantApi'; +import { glossaryApi, metricsApi, dictionaryApi } from '../../api/glossaryApi'; +import { knowledgeApi } from '../../api/knowledgeApi'; +import { queryApi } from '../../api/queryApi'; +import { QueryResultView } from '../query/QueryResultView'; +import type { + AssistantChatMessage, + DictionaryDraft, + GlossaryDraft, + KnowledgeDraft, + MetricDraft, + QueryResult, + SqlPreviewPayload, +} from '../../types/api'; + +export function AssistantPanel({ connectionId }: { connectionId: string | null }) { + const [messages, setMessages] = useState([]); + const [input, setInput] = useState(''); + const viewport = useRef(null); + + const scrollToBottom = () => + requestAnimationFrame(() => + viewport.current?.scrollTo({ top: viewport.current.scrollHeight, behavior: 'smooth' }), + ); + + const sendMutation = useMutation({ + mutationFn: (message: string) => + assistantApi.send({ + connection_id: connectionId!, + message, + // Send prior turns as plain {role, content}; cards aren't part of LLM history. + history: messages.map((m) => ({ role: m.role, content: m.content })), + }), + onSuccess: (data) => { + setMessages((prev) => [ + ...prev, + { role: 'assistant', content: data.message, action: data.action }, + ]); + scrollToBottom(); + }, + }); + + const handleSend = () => { + const message = input.trim(); + if (!message || !connectionId) return; + setMessages((prev) => [...prev, { role: 'user', content: message }]); + setInput(''); + sendMutation.mutate(message); + scrollToBottom(); + }; + + return ( + + + + Assistant + + ask questions or add glossary terms in plain language + + + + {messages.length > 0 && ( + + + {messages.map((m, i) => ( + + ))} + + + )} + + {sendMutation.isPending && ( + + + + Thinking… + + + )} + + {sendMutation.isError && ( + + {(sendMutation.error as Error).message} + + )} + +