Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions backend/app/api/v1/endpoints/assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.v1.schemas.assistant import AssistantRequest, AssistantResponse
from app.core.auth import AuthContext, get_org_context
from app.db.session import get_db
from app.services import assistant_service

# Mounted under ``/query`` so it inherits the query rate limiter, and uses
# ``get_org_context`` like the other LLM query endpoints — the service resolves
# and authorizes the connection from the body.
# With this prefix, the route declared below as ``/assistant`` is served at
# ``/query/assistant`` relative to wherever this router is included.
router = APIRouter(prefix="/query", tags=["assistant"])


@router.post("/assistant", response_model=AssistantResponse)
async def assistant_turn(
    body: AssistantRequest,
    ctx: AuthContext = Depends(get_org_context),
    db: AsyncSession = Depends(get_db),
):
    """Handle a single conversational turn and reply with ``{message, action?}``.

    The message is classified by the assistant service. When an ``action`` is
    returned it is a proposal only: a ``sql_preview`` is run by the client via
    /query/execute-sql after confirmation, and a ``glossary_draft`` is created via
    /connections/{id}/glossary after confirmation. The assistant itself never
    writes.
    """
    # NOTE(review): the response schema's action union also admits metric,
    # dictionary and knowledge drafts; which of those the service emits is not
    # visible from this endpoint — confirm and extend the docstring if needed.

    # The service takes plain dicts, so unwrap the validated history models.
    prior_turns = []
    for entry in body.history:
        prior_turns.append(entry.model_dump())

    outcome = await assistant_service.handle_turn(
        db,
        body.connection_id,
        body.message,
        prior_turns,
        ctx,
    )
    # ``outcome`` is expanded as keyword arguments, so its keys must match the
    # response model's fields.
    return AssistantResponse(**outcome)
2 changes: 2 additions & 0 deletions backend/app/api/v1/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from app.api.v1.endpoints import (
api_keys,
assistant,
auth,
connections,
dictionary,
Expand All @@ -23,6 +24,7 @@
api_router.include_router(teams.router)
api_router.include_router(api_keys.router)
api_router.include_router(query.router)
api_router.include_router(assistant.router)
api_router.include_router(connections.router)
api_router.include_router(schemas.router)
api_router.include_router(glossary.router)
Expand Down
111 changes: 111 additions & 0 deletions backend/app/api/v1/schemas/assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Annotated, Literal
from uuid import UUID

from pydantic import BaseModel, Field


class AssistantMessage(BaseModel):
    # One earlier turn of the conversation, as carried in
    # ``AssistantRequest.history``.

    # Who spoke; any value other than these two literals fails validation.
    role: Literal["user", "assistant"]
    # Turn text: must be non-empty and at most 4000 characters.
    content: str = Field(min_length=1, max_length=4000)


class AssistantRequest(BaseModel):
    # Request body for one assistant turn.

    # Connection the message refers to.
    connection_id: UUID
    # The new user message: non-empty, at most 2000 characters.
    message: str = Field(min_length=1, max_length=2000)
    # Earlier turns supplied by the client; at most 20 items, empty by default.
    history: list[AssistantMessage] = Field(default_factory=list, max_length=20)


# --- Draft payloads -------------------------------------------------------
# Each mirrors the corresponding *Create schema (a usable subset) so the
# frontend can POST it to the existing REST endpoint after confirmation.


class GlossaryDraft(BaseModel):
    # Proposed glossary entry. Only ``term`` and ``definition`` are required;
    # the remaining fields default to empty values.
    term: str
    definition: str
    sql_expression: str = ""
    # ``default_factory`` gives each instance its own fresh list.
    related_tables: list[str] = Field(default_factory=list)
    related_columns: list[str] = Field(default_factory=list)


class MetricDraft(BaseModel):
    # Proposed metric definition. ``metric_name`` and ``display_name`` are
    # required; every other field defaults to an empty string or list.
    metric_name: str
    display_name: str
    description: str = ""
    sql_expression: str = ""
    # Free-form string here; allowed values are not constrained by this schema.
    aggregation_type: str = ""
    related_tables: list[str] = Field(default_factory=list)
    dimensions: list[str] = Field(default_factory=list)


class DictionaryEntryDraft(BaseModel):
    # One coded-value mapping inside a ``DictionaryDraft``.

    # Value as stored in the column.
    raw_value: str
    # Human-readable label for that value.
    display_value: str
    # Optional longer explanation; empty when not provided.
    description: str = ""


class DictionaryDraft(BaseModel):
    """Coded-value mappings for one column.

    ``column_id`` is resolved server-side from ``table_name``/``column_name``;
    the frontend POSTs each entry to ``/columns/{column_id}/dictionary``.
    """

    # All four fields are required (none declares a default).
    column_id: UUID
    table_name: str
    column_name: str
    # The mappings themselves; this schema does not enforce a minimum count.
    entries: list[DictionaryEntryDraft]


class KnowledgeDraft(BaseModel):
    # Proposed knowledge entry: ``title`` and ``content`` are required,
    # ``source_url`` is optional and empty by default.
    title: str
    content: str
    # Plain string; not validated as a URL by this schema.
    source_url: str = ""


class SqlPreviewPayload(BaseModel):
    # Payload of a ``sql_preview`` action: the SQL text plus an optional
    # explanation (empty string when absent).
    sql: str
    explanation: str = ""


# --- Action discriminated union ------------------------------------------


class GlossaryDraftAction(BaseModel):
    # ``type`` is the tag that selects this variant in the ``AssistantAction``
    # union; it is fixed to a single literal and defaults to it.
    type: Literal["glossary_draft"] = "glossary_draft"
    payload: GlossaryDraft


class MetricDraftAction(BaseModel):
    # ``type`` is the tag that selects this variant in the ``AssistantAction``
    # union; it is fixed to a single literal and defaults to it.
    type: Literal["metric_draft"] = "metric_draft"
    payload: MetricDraft


class DictionaryDraftAction(BaseModel):
    # ``type`` is the tag that selects this variant in the ``AssistantAction``
    # union; it is fixed to a single literal and defaults to it.
    type: Literal["dictionary_draft"] = "dictionary_draft"
    payload: DictionaryDraft


class KnowledgeDraftAction(BaseModel):
    # ``type`` is the tag that selects this variant in the ``AssistantAction``
    # union; it is fixed to a single literal and defaults to it.
    type: Literal["knowledge_draft"] = "knowledge_draft"
    payload: KnowledgeDraft


class SqlPreviewAction(BaseModel):
    # ``type`` is the tag that selects this variant in the ``AssistantAction``
    # union; it is fixed to a single literal and defaults to it.
    type: Literal["sql_preview"] = "sql_preview"
    payload: SqlPreviewPayload


# Tagged union of every action variant a response may carry. The ``type``
# field is declared as the discriminator, so validation picks the variant
# from that field's literal value.
AssistantAction = Annotated[
    GlossaryDraftAction
    | MetricDraftAction
    | DictionaryDraftAction
    | KnowledgeDraftAction
    | SqlPreviewAction,
    Field(discriminator="type"),
]


class AssistantResponse(BaseModel):
    # Reply for one assistant turn.

    # Text shown to the user; always present.
    message: str
    # Optional proposed action; ``None`` when the turn carries no action.
    action: AssistantAction | None = None
175 changes: 175 additions & 0 deletions backend/app/llm/agents/assistant_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Agent: Assistant Router — classifies a chat message and extracts drafts.

One structured-JSON LLM call. Mirrors :class:`QueryComposerAgent`: build messages,
call ``provider.complete``, parse with ``repair_json`` and degrade gracefully to a
plain chat reply when the model returns non-JSON or an unusable draft.
"""

import json
from dataclasses import dataclass

from app.llm.base_provider import BaseLLMProvider, LLMConfig, LLMMessage
from app.llm.prompts.assistant_prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
from app.llm.utils import repair_json

# Intents that carry a structured draft payload.
DRAFT_INTENTS = {"glossary", "metric", "dictionary", "knowledge"}
# Every intent the router accepts from the model: the draft intents plus the
# two payload-free ones. Anything else is coerced to "chat" by the agent.
VALID_INTENTS = {"question", "chat"} | DRAFT_INTENTS


@dataclass
class AssistantDecision:
    # Result of routing one chat message.

    # One of ``VALID_INTENTS`` when produced by ``AssistantAgent.route``.
    intent: str
    # Reply text taken from the model output (may be empty).
    message: str
    # Normalized draft dict for draft intents; ``None`` otherwise.
    payload: dict | None = None


class AssistantAgent:
    """Classify a chat message with one structured-JSON LLM call.

    The model reply is turned into an :class:`AssistantDecision`. Output that
    cannot be used as-is — non-JSON text, JSON that is not an object, an
    unknown or malformed intent, a draft intent without a usable payload —
    degrades to a plain ``chat`` decision rather than raising.
    """

    def __init__(self, provider: BaseLLMProvider, config: LLMConfig):
        self.provider = provider
        self.config = config

    async def route(
        self,
        message: str,
        context: str,
        history: list[dict] | None = None,
    ) -> AssistantDecision:
        """Classify ``message`` and extract a draft payload when applicable.

        Args:
            message: The new user message.
            context: Text interpolated into the user prompt's ``context`` slot.
            history: Prior turns as dicts with ``role``/``content`` keys, or
                ``None`` when there are none.

        Returns:
            An :class:`AssistantDecision` whose ``intent`` is always one of
            ``VALID_INTENTS`` and whose ``payload`` is set only for draft
            intents that produced a usable, normalized draft.
        """
        user_prompt = USER_PROMPT_TEMPLATE.format(
            context=context,
            history=_format_history(history),
            message=message,
        )
        messages = [
            LLMMessage(role="system", content=SYSTEM_PROMPT),
            LLMMessage(role="user", content=user_prompt),
        ]

        response = await self.provider.complete(messages, self.config)

        # Raw reply used whenever structured parsing fails. Guarded so that a
        # missing or non-string content cannot crash the fallback itself.
        content = response.content
        raw_reply = content.strip() if isinstance(content, str) else ""

        try:
            parsed = json.loads(repair_json(content))
        except (json.JSONDecodeError, TypeError):
            # Non-JSON response — treat the raw text as a plain chat reply.
            return AssistantDecision(intent="chat", message=raw_reply)

        if not isinstance(parsed, dict):
            # Valid JSON but not an object (string, list, number, null): there
            # are no fields to read, so fall back to a chat reply. A bare JSON
            # string is used unquoted.
            reply = parsed.strip() if isinstance(parsed, str) else raw_reply
            return AssistantDecision(intent="chat", message=reply)

        intent = parsed.get("intent")
        # The isinstance check also protects the set lookup from unhashable
        # values such as a list or object.
        if not isinstance(intent, str) or intent not in VALID_INTENTS:
            intent = "chat"

        raw_message = parsed.get("message")
        message_text = raw_message.strip() if isinstance(raw_message, str) else ""
        payload = None

        if intent in DRAFT_INTENTS:
            payload = _NORMALIZERS[intent](parsed.get("payload"))
            if payload is None:
                # Model claimed a draft intent but gave no usable payload.
                intent = "chat"

        return AssistantDecision(intent=intent, message=message_text, payload=payload)


def _format_history(history: list[dict] | None) -> str:
if not history:
return "(no prior messages)"
lines = []
for turn in history:
role = turn.get("role", "user")
content = (turn.get("content") or "").strip()
if content:
lines.append(f"{role}: {content}")
return "\n".join(lines) if lines else "(no prior messages)"


def _str(value) -> str:
return str(value).strip() if value is not None else ""


def _as_list(value) -> list[str]:
if isinstance(value, list):
return [str(v).strip() for v in value if str(v).strip()]
if isinstance(value, str) and value.strip():
return [value.strip()]
return []


def _normalize_glossary(raw) -> dict | None:
    """Clean a model-supplied glossary payload.

    Returns ``None`` unless ``raw`` is a dict with a non-empty ``term`` and
    ``definition``; otherwise returns a dict with exactly the glossary draft
    fields, with strings stripped and list fields coerced to lists.
    """
    if not isinstance(raw, dict):
        return None

    term = _str(raw.get("term"))
    definition = _str(raw.get("definition"))
    if not (term and definition):
        return None

    draft = {"term": term, "definition": definition}
    draft["sql_expression"] = _str(raw.get("sql_expression"))
    for key in ("related_tables", "related_columns"):
        draft[key] = _as_list(raw.get(key))
    return draft


def _normalize_metric(raw) -> dict | None:
    """Clean a model-supplied metric payload.

    A usable metric needs a SQL expression and a display name; the display
    name falls back to ``metric_name`` when absent. When ``metric_name`` itself
    is absent it is derived from the display name via ``_slug``. Returns
    ``None`` when ``raw`` is not a dict or a required value is missing.
    """
    if not isinstance(raw, dict):
        return None

    given_name = _str(raw.get("metric_name"))
    display_name = _str(raw.get("display_name")) or given_name
    sql_expression = _str(raw.get("sql_expression"))
    if not (display_name and sql_expression):
        return None

    return {
        "metric_name": given_name or _slug(display_name),
        "display_name": display_name,
        "description": _str(raw.get("description")),
        "sql_expression": sql_expression,
        "aggregation_type": _str(raw.get("aggregation_type")),
        "related_tables": _as_list(raw.get("related_tables")),
        "dimensions": _as_list(raw.get("dimensions")),
    }


def _normalize_dictionary(raw) -> dict | None:
    """Clean a model-supplied dictionary (coded-value) payload.

    Requires a non-empty ``table_name`` and ``column_name`` plus at least one
    entry that has both ``raw_value`` and ``display_value``; incomplete or
    non-dict entries are dropped. Returns ``None`` when nothing usable remains.

    ``entries`` is expected to be a list. A single entry object is accepted as
    a one-item list, and any other non-list value is treated as "no entries"
    instead of raising while iterating it.
    """
    if not isinstance(raw, dict):
        return None
    table_name, column_name = _str(raw.get("table_name")), _str(raw.get("column_name"))
    if not table_name or not column_name:
        return None

    candidates = raw.get("entries")
    if isinstance(candidates, dict):
        # The model returned one entry without wrapping it in an array.
        candidates = [candidates]
    elif not isinstance(candidates, list):
        # Numbers, strings, null, ...: nothing that can hold entry objects.
        candidates = []

    entries = []
    for item in candidates:
        if not isinstance(item, dict):
            continue
        raw_value, display_value = _str(item.get("raw_value")), _str(item.get("display_value"))
        if not raw_value or not display_value:
            continue
        entries.append(
            {
                "raw_value": raw_value,
                "display_value": display_value,
                "description": _str(item.get("description")),
            }
        )
    if not entries:
        return None
    return {"table_name": table_name, "column_name": column_name, "entries": entries}


def _normalize_knowledge(raw) -> dict | None:
    """Clean a model-supplied knowledge payload.

    Returns the stripped ``title``, ``content`` and ``source_url`` when ``raw``
    is a dict with a non-empty title and content, otherwise ``None``.
    """
    if not isinstance(raw, dict):
        return None
    draft = {key: _str(raw.get(key)) for key in ("title", "content", "source_url")}
    if not draft["title"] or not draft["content"]:
        return None
    return draft


# Dispatch table from draft intent to its payload normalizer. Its keys match
# ``DRAFT_INTENTS``; each normalizer returns a cleaned dict or ``None`` when
# the payload is unusable.
_NORMALIZERS = {
    "glossary": _normalize_glossary,
    "metric": _normalize_metric,
    "dictionary": _normalize_dictionary,
    "knowledge": _normalize_knowledge,
}


def _slug(text: str) -> str:
cleaned = "".join(c if c.isalnum() else "_" for c in text.lower())
return "_".join(part for part in cleaned.split("_") if part) or "metric"
Loading
Loading