From 61ef0d3095c91c05d94527c713cf6b2dc21c4c23 Mon Sep 17 00:00:00 2001 From: cosmin chauciuc Date: Fri, 5 Jun 2026 13:22:13 +0300 Subject: [PATCH] Phase 1: identity, teams & ownership (backend) Add the identity layer that gates the rest of the roadmap: real users, teams/workspaces, roles, ownership, and audit FKs. Single-tenant per deployment (isolation by workspace_id within an auto-created default org); organization_id carried on every core table so a future managed-SaaS fleet needs no migration. Auth: magic-link + local password login issuing HS256 session JWTs delivered as HTTP-only cookies; OIDC is a registered (unimplemented) provider seam; DISABLE_AUTH escape hatch for local dev. - Models: Organization, User, Team, Membership (admin|editor|viewer), ApiKey - Core: security.py (PBKDF2 + JWT + API keys), auth.py (get_current_user / get_org_context / require_role + cookie helpers), auth_providers.py seam - Services: auth_service (login/magic-link/register), identity_service (org/team/membership/api-key CRUD, bootstrap + system_context) - Endpoints: /auth/*, /teams (+members), /api-keys - AuthZ wired through connection_service, query/schema services, all metadata endpoints (via connection cascade root), the MCP server, and startup auto-setup (system context). created_by/user_id promoted to real User FKs. - Migration 004: nullable -> seed default org/workspace/admin -> backfill -> NOT NULL, with (org_id, created_at) indexes and a full downgrade. - Tests: password/JWT/API-key crypto, provider registry, role hierarchy, 401 on protected routes (98 pass). - docker-compose: DISABLE_AUTH=true bridge until the Phase 1 frontend lands. Co-Authored-By: Claude Opus 4.8 (1M context) --- CLAUDE.md | 23 ++ .../versions/004_identity_and_ownership.py | 252 ++++++++++++++++++ backend/app/api/v1/deps.py | 75 ++++++ backend/app/api/v1/endpoints/api_keys.py | 51 ++++ backend/app/api/v1/endpoints/auth.py | 98 +++++++ backend/app/api/v1/endpoints/connections.py | 42 ++- backend/app/api/v1/endpoints/dictionary.py | 6 + backend/app/api/v1/endpoints/glossary.py | 9 + backend/app/api/v1/endpoints/knowledge.py | 7 + backend/app/api/v1/endpoints/metrics.py | 9 + backend/app/api/v1/endpoints/query.py | 25 +- backend/app/api/v1/endpoints/query_history.py | 38 ++- .../app/api/v1/endpoints/sample_queries.py | 13 +- backend/app/api/v1/endpoints/schemas.py | 10 +- backend/app/api/v1/endpoints/teams.py | 86 ++++++ backend/app/api/v1/router.py | 6 + backend/app/api/v1/schemas/auth.py | 64 +++++ backend/app/api/v1/schemas/team.py | 61 +++++ backend/app/config.py | 30 +++ backend/app/core/auth.py | 208 +++++++++++++++ backend/app/core/auth_providers.py | 132 +++++++++ backend/app/core/exceptions.py | 10 + backend/app/core/security.py | 110 ++++++++ backend/app/db/models/__init__.py | 10 + backend/app/db/models/api_key.py | 36 +++ backend/app/db/models/connection.py | 15 +- backend/app/db/models/glossary.py | 7 +- backend/app/db/models/knowledge.py | 5 + backend/app/db/models/membership.py | 40 +++ backend/app/db/models/metric.py | 7 +- backend/app/db/models/organization.py | 35 +++ backend/app/db/models/query_history.py | 8 +- backend/app/db/models/sample_query.py | 9 +- backend/app/db/models/team.py | 40 +++ backend/app/db/models/user.py | 50 ++++ backend/app/mcp/server.py | 44 ++- backend/app/services/auth_service.py | 127 +++++++++ backend/app/services/connection_service.py | 77 +++++- backend/app/services/identity_service.py | 250 +++++++++++++++++ backend/app/services/knowledge_service.py | 2 + backend/app/services/query_service.py | 18 +- backend/app/services/schema_service.py | 18 +- backend/app/services/setup_service.py | 31 ++- backend/pyproject.toml | 1 + backend/tests/test_auth_endpoints.py | 53 ++++ backend/tests/test_auth_providers.py | 99 +++++++ backend/tests/test_security.py | 109 ++++++++ docker-compose.yml | 4 + planfull.md | 2 +- 49 files changed, 2384 insertions(+), 78 deletions(-) create mode 100644 backend/alembic/versions/004_identity_and_ownership.py create mode 100644 backend/app/api/v1/deps.py create mode 100644 backend/app/api/v1/endpoints/api_keys.py create mode 100644 backend/app/api/v1/endpoints/auth.py create mode 100644 backend/app/api/v1/endpoints/teams.py create mode 100644 backend/app/api/v1/schemas/auth.py create mode 100644 backend/app/api/v1/schemas/team.py create mode 100644 backend/app/core/auth.py create mode 100644 backend/app/core/auth_providers.py create mode 100644 backend/app/core/security.py create mode 100644 backend/app/db/models/api_key.py create mode 100644 backend/app/db/models/membership.py create mode 100644 backend/app/db/models/organization.py create mode 100644 backend/app/db/models/team.py create mode 100644 backend/app/db/models/user.py create mode 100644 backend/app/services/auth_service.py create mode 100644 backend/app/services/identity_service.py create mode 100644 backend/tests/test_auth_endpoints.py create mode 100644 backend/tests/test_auth_providers.py create mode 100644 backend/tests/test_security.py diff --git a/CLAUDE.md b/CLAUDE.md index 687c571..db1ced8 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -129,6 +129,17 @@ frontend/src/ | `AZURE_OPENAI_API_KEY` | — | Azure OpenAI key | | `AZURE_OPENAI_API_VERSION` | `2024-10-21` | Azure OpenAI API version | | `AZURE_OPENAI_DEPLOYMENT` | — | Azure OpenAI embedding deployment name | +| `DISABLE_AUTH` | `false` | Local-dev escape hatch — treat every request as the default admin (no login). **Never enable in production** | +| `AUTH_PROVIDER` | `local` | Interactive login backend: `local` (password + magic-link), `magic_link`, or `oidc` (registered seam, not yet implemented) | +| `JWT_SECRET` | `dev-jwt-secret-change-in-production` | HS256 signing secret for session + magic-link JWTs | +| `JWT_ACCESS_TTL_MINUTES` | `720` | Session lifetime (minutes) | +| `MAGIC_LINK_TTL_MINUTES` | `15` | Magic-link token lifetime (minutes) | +| `AUTH_COOKIE_NAME` | `qw_session` | Session cookie name (HTTP-only) | +| `AUTH_COOKIE_SECURE` | `false` | Set `true` behind TLS (HTTPS-only cookie) | +| `AUTH_COOKIE_SAMESITE` | `lax` | Session cookie SameSite (`lax`/`strict`/`none`) | +| `DEFAULT_ORG_SLUG` | `default` | Slug of the auto-created default organization | +| `DEFAULT_ADMIN_EMAIL` | `admin@querywise.local` | Bootstrapped admin user (created on boot + in migration 004) | +| `DEFAULT_ADMIN_PASSWORD` | — | If set, the bootstrapped admin gets this local-login password | ## Ollama (Local LLM) @@ -223,3 +234,15 @@ dependencies degrade gracefully — the app boots without `structlog` / - **Health** (`app/api/v1/endpoints/health.py`): `GET /health/live` (process) and `GET /health/ready` (DB + job queue + LLM provider, 503 on failure) for K8s probes. - **LLM endpoints:** Azure OpenAI provider (`azure_openai`) added so the pipeline can run inside a customer VPC; registered in `provider_registry`. - **Tests/CI:** unit tests in `backend/tests/` (no DB/LLM needed); `.github/workflows/ci.yml` runs pytest (gating) + ruff/mypy/frontend build (advisory until pre-existing lint debt is cleared). Optional deps: `pip install -e ".[observability,jobs]"`. + +## Identity & auth (Phase 1) + +Real users, teams, roles, and ownership. Single-tenant per deployment; isolation is by `workspace_id` (a `Team`) within the auto-created default `Organization`. `organization_id` is carried on every core table so a future managed-SaaS fleet needs no migration. Migration `004` creates the identity tables, seeds the default org/workspace/admin, backfills all existing rows, and promotes the free-text `created_by`/`user_id` columns to real `User` FKs. + +- **Identity models** (`app/db/models/`): `Organization`, `User`, `Team` (= workspace), `Membership` (role `admin|editor|viewer`, ranked in `ROLE_RANK`), `ApiKey` (only the SHA-256 hash stored). +- **Primitives** (`app/core/security.py`): PBKDF2 password hashing (stdlib), HS256 JWTs with a `purpose` claim (`session` / `magic_link`), and API-key gen/hash. Dependency-light + unit-tested. +- **Request plumbing** (`app/core/auth.py`): `get_current_user` (API key → Bearer → HTTP-only `qw_session` cookie), `get_org_context` → `AuthContext` (active workspace via `X-Workspace-Id` header, else earliest membership), and `require_role(...)`. `DISABLE_AUTH=true` short-circuits to the bootstrapped admin for local dev. +- **Login** (`app/services/auth_service.py`, `app/api/v1/endpoints/auth.py`): password + magic-link, both issuing a session-cookie JWT. Magic-link delivery (email/Slack) lands in Phase 4 — the token is logged and, outside production, returned by `POST /auth/magic-link`. `app/core/auth_providers.py` is a name-keyed seam (`local`/`magic_link`/`oidc`); **OIDC is registered but not implemented**. +- **AuthZ in services** (per the existing convention): `connection_service` scopes by org+workspace and enforces role; metadata endpoints authorize through the connection (the cascade root) via `app/api/v1/deps.py` (`require_connection_read/write`, `require_column_read/write`). Non-request entry points — startup auto-setup, the MCP server, the seed script via `DISABLE_AUTH` — act under `identity_service.system_context()` (admin in the default workspace). +- **Endpoints:** `/auth/*` (login, register, magic-link request/verify, logout, me, providers), `/teams` + `/teams/{id}/members` (admin-managed), `/api-keys` (per-user, plaintext shown once). +- **Heads-up:** once auth is enforced, the current (pre-auth) frontend gets 401s — run with `DISABLE_AUTH=true` until the Phase 1 frontend (login + auth context + workspace switcher) lands. diff --git a/backend/alembic/versions/004_identity_and_ownership.py b/backend/alembic/versions/004_identity_and_ownership.py new file mode 100644 index 0000000..3d864b4 --- /dev/null +++ b/backend/alembic/versions/004_identity_and_ownership.py @@ -0,0 +1,252 @@ +"""Identity, teams & ownership (Phase 1) + +Revision ID: 004 +Revises: 003 +Create Date: 2026-06-05 + +Adds the identity layer (organizations, users, teams, memberships, api_keys), +re-keys all core tables with organization_id, scopes connections to a workspace +(team) + owner, and promotes the free-text created_by / user_id columns to real +User foreign keys. + +Migration strategy per the roadmap: add nullable → create the default org / +workspace / admin → backfill every existing row → enforce NOT NULL. Rollback is +``DISABLE_AUTH=true`` + this downgrade (+ pg_dump restore if needed). +""" + +from collections.abc import Sequence +from typing import Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import JSONB, UUID + +from app.config import settings + +# revision identifiers, used by Alembic. +revision: str = "004" +down_revision: str = "003" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +# Core tables that gain organization_id (SaaS-ready scoping). +ORG_SCOPED_TABLES = [ + "database_connections", + "glossary_terms", + "metric_definitions", + "sample_queries", + "knowledge_documents", + "query_executions", +] +# Tables whose free-text created_by becomes a created_by_id User FK. +CREATED_BY_TABLES = ["glossary_terms", "metric_definitions", "sample_queries"] + + +def _q(value: str) -> str: + """Single-quote-escape a trusted config string for inline SQL.""" + return value.replace("'", "''") + + +def upgrade() -> None: + org_slug = _q(settings.default_org_slug) + admin_email = _q(settings.default_admin_email) + + # --- Identity tables --------------------------------------------------- + op.create_table( + "organizations", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("slug", sa.String(255), nullable=False, unique=True), + sa.Column("settings", JSONB, server_default=sa.text("'{}'::jsonb")), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + op.create_table( + "users", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column("email", sa.String(320), nullable=False, unique=True), + sa.Column("name", sa.String(255), nullable=True), + sa.Column("sso_subject", sa.String(255), nullable=True, unique=True), + sa.Column("password_hash", sa.String(255), nullable=True), + sa.Column("status", sa.String(20), nullable=False, server_default="active"), + sa.Column("last_login_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + op.create_table( + "teams", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "organization_id", + UUID(as_uuid=True), + sa.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("slug", sa.String(255), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + op.create_table( + "memberships", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ), + sa.Column( + "team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=False + ), + sa.Column("role", sa.String(20), nullable=False, server_default="viewer"), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.UniqueConstraint("user_id", "team_id", name="uq_membership_user_team"), + ) + op.create_table( + "api_keys", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column( + "user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("key_hash", sa.String(64), nullable=False, unique=True), + sa.Column("key_prefix", sa.String(16), nullable=False), + sa.Column("permissions", JSONB, server_default=sa.text("'{}'::jsonb")), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # --- Seed the default org / workspace / admin -------------------------- + op.execute( + f"INSERT INTO organizations (name, slug) " + f"VALUES ('{_q(settings.default_org_name)}', '{org_slug}')" + ) + op.execute( + f"INSERT INTO teams (organization_id, name, slug) " + f"SELECT id, '{_q(settings.default_workspace_name)}', 'default-workspace' " + f"FROM organizations WHERE slug = '{org_slug}'" + ) + op.execute( + f"INSERT INTO users (email, name, status) " + f"VALUES ('{admin_email}', 'Administrator', 'active')" + ) + op.execute( + f"INSERT INTO memberships (user_id, team_id, role) " + f"SELECT u.id, t.id, 'admin' FROM users u, teams t " + f"JOIN organizations o ON o.id = t.organization_id " + f"WHERE u.email = '{admin_email}' AND o.slug = '{org_slug}'" + ) + + org_subq = f"(SELECT id FROM organizations WHERE slug = '{org_slug}')" + team_subq = ( + f"(SELECT t.id FROM teams t JOIN organizations o ON o.id = t.organization_id " + f"WHERE o.slug = '{org_slug}' ORDER BY t.created_at LIMIT 1)" + ) + admin_subq = f"(SELECT id FROM users WHERE email = '{admin_email}')" + + # --- organization_id on every core table ------------------------------- + for table in ORG_SCOPED_TABLES: + op.add_column(table, sa.Column("organization_id", UUID(as_uuid=True), nullable=True)) + op.execute(f"UPDATE {table} SET organization_id = {org_subq}") + op.alter_column(table, "organization_id", nullable=False) + op.create_foreign_key( + f"fk_{table}_organization_id", + table, + "organizations", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_index(f"ix_{table}_org_created", table, ["organization_id", "created_at"]) + + # --- database_connections: workspace + owner + privacy ----------------- + op.add_column("database_connections", sa.Column("workspace_id", UUID(as_uuid=True), nullable=True)) + op.add_column("database_connections", sa.Column("owner_id", UUID(as_uuid=True), nullable=True)) + op.add_column( + "database_connections", + sa.Column("is_private", sa.Boolean(), nullable=False, server_default=sa.false()), + ) + op.execute( + f"UPDATE database_connections SET workspace_id = {team_subq}, owner_id = {admin_subq}" + ) + op.alter_column("database_connections", "workspace_id", nullable=False) + op.create_foreign_key( + "fk_database_connections_workspace_id", + "database_connections", + "teams", + ["workspace_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + "fk_database_connections_owner_id", + "database_connections", + "users", + ["owner_id"], + ["id"], + ondelete="SET NULL", + ) + + # --- created_by → created_by_id User FK -------------------------------- + for table in CREATED_BY_TABLES: + op.add_column(table, sa.Column("created_by_id", UUID(as_uuid=True), nullable=True)) + # Existing rows were created by the system; attribute them to the admin. + op.execute(f"UPDATE {table} SET created_by_id = {admin_subq} WHERE created_by IS NOT NULL") + op.create_foreign_key( + f"fk_{table}_created_by_id", + table, + "users", + ["created_by_id"], + ["id"], + ondelete="SET NULL", + ) + op.drop_column(table, "created_by") + + # --- query_executions.user_id: free-text string → User FK -------------- + # Old free-text values cannot be mapped to real users; they become NULL. + op.drop_column("query_executions", "user_id") + op.add_column("query_executions", sa.Column("user_id", UUID(as_uuid=True), nullable=True)) + op.create_foreign_key( + "fk_query_executions_user_id", + "query_executions", + "users", + ["user_id"], + ["id"], + ondelete="SET NULL", + ) + + +def downgrade() -> None: + # query_executions.user_id back to free-text + op.drop_constraint("fk_query_executions_user_id", "query_executions", type_="foreignkey") + op.drop_column("query_executions", "user_id") + op.add_column("query_executions", sa.Column("user_id", sa.String(255), nullable=True)) + + # created_by_id → created_by string + for table in CREATED_BY_TABLES: + op.drop_constraint(f"fk_{table}_created_by_id", table, type_="foreignkey") + op.drop_column(table, "created_by_id") + op.add_column(table, sa.Column("created_by", sa.String(255), nullable=True)) + + # database_connections extras + op.drop_constraint("fk_database_connections_owner_id", "database_connections", type_="foreignkey") + op.drop_constraint( + "fk_database_connections_workspace_id", "database_connections", type_="foreignkey" + ) + op.drop_column("database_connections", "is_private") + op.drop_column("database_connections", "owner_id") + op.drop_column("database_connections", "workspace_id") + + # organization_id on core tables + for table in ORG_SCOPED_TABLES: + op.drop_index(f"ix_{table}_org_created", table_name=table) + op.drop_constraint(f"fk_{table}_organization_id", table, type_="foreignkey") + op.drop_column(table, "organization_id") + + # Identity tables + op.drop_table("api_keys") + op.drop_table("memberships") + op.drop_table("teams") + op.drop_table("users") + op.drop_table("organizations") diff --git a/backend/app/api/v1/deps.py b/backend/app/api/v1/deps.py new file mode 100644 index 0000000..6e9b31e --- /dev/null +++ b/backend/app/api/v1/deps.py @@ -0,0 +1,75 @@ +"""Shared FastAPI dependencies for authorization. + +Metadata endpoints (glossary, metrics, dictionary, sample queries, knowledge, +schema) are keyed by ``connection_id`` — the workspace cascade root. These +dependencies resolve the caller's :class:`AuthContext` and assert access to the +connection in the path, so handlers can stay thin. +""" + +from __future__ import annotations + +import uuid + +from fastapi import Depends +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.auth import AuthContext, get_org_context +from app.core.exceptions import NotFoundError +from app.db.models.schema_cache import CachedColumn, CachedTable +from app.db.session import get_db +from app.services import connection_service + + +async def require_connection_read( + connection_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +) -> AuthContext: + """Caller must be able to read the connection in the path.""" + await connection_service.get_connection(db, connection_id, ctx) + return ctx + + +async def require_connection_write( + connection_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +) -> AuthContext: + """Caller must be an editor (or above) on the connection in the path.""" + await connection_service.get_connection(db, connection_id, ctx, write=True) + return ctx + + +async def _connection_id_for_column(db: AsyncSession, column_id: uuid.UUID) -> uuid.UUID: + result = await db.execute( + select(CachedTable.connection_id) + .join(CachedColumn, CachedColumn.table_id == CachedTable.id) + .where(CachedColumn.id == column_id) + ) + connection_id = result.scalar_one_or_none() + if connection_id is None: + raise NotFoundError("Column", str(column_id)) + return connection_id + + +async def require_column_read( + column_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +) -> AuthContext: + """Caller must be able to read the connection owning the column in the path.""" + connection_id = await _connection_id_for_column(db, column_id) + await connection_service.get_connection(db, connection_id, ctx) + return ctx + + +async def require_column_write( + column_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +) -> AuthContext: + """Caller must be an editor on the connection owning the column in the path.""" + connection_id = await _connection_id_for_column(db, column_id) + await connection_service.get_connection(db, connection_id, ctx, write=True) + return ctx diff --git a/backend/app/api/v1/endpoints/api_keys.py b/backend/app/api/v1/endpoints/api_keys.py new file mode 100644 index 0000000..65bb655 --- /dev/null +++ b/backend/app/api/v1/endpoints/api_keys.py @@ -0,0 +1,51 @@ +import uuid + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.schemas.team import ApiKeyCreate, ApiKeyCreatedResponse, ApiKeyResponse +from app.core.auth import get_current_user +from app.db.models.user import User +from app.db.session import get_db +from app.services import identity_service + +router = APIRouter(prefix="/api-keys", tags=["api-keys"]) + + +@router.get("", response_model=list[ApiKeyResponse]) +async def list_api_keys( + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + keys = await identity_service.list_api_keys(db, user) + return [ApiKeyResponse.model_validate(k) for k in keys] + + +@router.post("", response_model=ApiKeyCreatedResponse, status_code=201) +async def create_api_key( + body: ApiKeyCreate, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + api_key, plaintext = await identity_service.create_api_key( + db, user, body.name, body.expires_at + ) + return ApiKeyCreatedResponse( + id=api_key.id, + name=api_key.name, + key_prefix=api_key.key_prefix, + expires_at=api_key.expires_at, + last_used_at=api_key.last_used_at, + revoked_at=api_key.revoked_at, + created_at=api_key.created_at, + key=plaintext, + ) + + +@router.delete("/{key_id}", status_code=204) +async def revoke_api_key( + key_id: uuid.UUID, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + await identity_service.revoke_api_key(db, user, key_id) diff --git a/backend/app/api/v1/endpoints/auth.py b/backend/app/api/v1/endpoints/auth.py new file mode 100644 index 0000000..5fae5a4 --- /dev/null +++ b/backend/app/api/v1/endpoints/auth.py @@ -0,0 +1,98 @@ +import logging + +from fastapi import APIRouter, Depends, Response +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.schemas.auth import ( + AuthProviderInfo, + LoginRequest, + MagicLinkRequest, + MagicLinkResponse, + MagicLinkVerifyRequest, + MeResponse, + RegisterRequest, + UserResponse, + WorkspaceMembershipResponse, +) +from app.config import settings +from app.core.auth import clear_session_cookie, get_current_user, set_session_cookie +from app.core.auth_providers import get_auth_provider +from app.db.models.user import User +from app.db.session import get_db +from app.services import auth_service, identity_service + +logger = logging.getLogger("querywise.auth") + +router = APIRouter(prefix="/auth", tags=["auth"]) + + +def _login(response: Response, user: User) -> UserResponse: + set_session_cookie(response, auth_service.issue_session_token(user)) + return UserResponse.model_validate(user) + + +@router.get("/providers", response_model=AuthProviderInfo) +async def auth_providers(): + """Advertise the configured login method so the frontend renders the right UI.""" + provider = get_auth_provider() + return AuthProviderInfo(**provider.describe(), disable_auth=settings.disable_auth) + + +@router.post("/login", response_model=UserResponse) +async def login(body: LoginRequest, response: Response, db: AsyncSession = Depends(get_db)): + user = await auth_service.authenticate_password(db, body.email, body.password) + return _login(response, user) + + +@router.post("/register", response_model=UserResponse, status_code=201) +async def register(body: RegisterRequest, response: Response, db: AsyncSession = Depends(get_db)): + user = await auth_service.register_user(db, body.email, body.password, body.name) + return _login(response, user) + + +@router.post("/magic-link", response_model=MagicLinkResponse) +async def request_magic_link(body: MagicLinkRequest, db: AsyncSession = Depends(get_db)): + token = await auth_service.request_magic_link(db, body.email) + # Delivery (email/Slack) is wired in Phase 4; for now log it and, outside + # production, surface it so local dev can complete the flow. + frontend = settings.cors_origins[0] if settings.cors_origins else None + verify_url = f"{frontend}/login/verify?token={token}" if frontend else None + logger.info("Magic link issued for %s: %s", body.email, verify_url or token) + expose = settings.environment != "production" + return MagicLinkResponse( + sent=True, + dev_token=token if expose else None, + dev_verify_url=verify_url if expose else None, + ) + + +@router.post("/magic-link/verify", response_model=UserResponse) +async def verify_magic_link( + body: MagicLinkVerifyRequest, + response: Response, + db: AsyncSession = Depends(get_db), +): + user = await auth_service.verify_magic_link(db, body.token) + return _login(response, user) + + +@router.post("/logout", status_code=204) +async def logout(response: Response): + clear_session_cookie(response) + + +@router.get("/me", response_model=MeResponse) +async def me( + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + memberships = await identity_service.list_my_memberships(db, user) + return MeResponse( + user=UserResponse.model_validate(user), + workspaces=[ + WorkspaceMembershipResponse( + team_id=m.team_id, team_name=m.team.name, role=m.role + ) + for m in memberships + ], + ) diff --git a/backend/app/api/v1/endpoints/connections.py b/backend/app/api/v1/endpoints/connections.py index 523b7a0..661200f 100644 --- a/backend/app/api/v1/endpoints/connections.py +++ b/backend/app/api/v1/endpoints/connections.py @@ -9,6 +9,7 @@ ConnectionTestResult, ConnectionUpdate, ) +from app.core.auth import AuthContext, get_org_context from app.db.session import get_db from app.services import connection_service @@ -16,8 +17,11 @@ @router.get("", response_model=list[ConnectionResponse]) -async def list_connections(db: AsyncSession = Depends(get_db)): - connections = await connection_service.list_connections(db) +async def list_connections( + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + connections = await connection_service.list_connections(db, ctx) return [ ConnectionResponse( id=c.id, @@ -37,9 +41,14 @@ async def list_connections(db: AsyncSession = Depends(get_db)): @router.post("", response_model=ConnectionResponse, status_code=201) -async def create_connection(body: ConnectionCreate, db: AsyncSession = Depends(get_db)): +async def create_connection( + body: ConnectionCreate, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): conn = await connection_service.create_connection( db, + ctx, name=body.name, connector_type=body.connector_type, connection_string=body.connection_string, @@ -63,8 +72,12 @@ async def create_connection(body: ConnectionCreate, db: AsyncSession = Depends(g @router.get("/{connection_id}", response_model=ConnectionResponse) -async def get_connection(connection_id: uuid.UUID, db: AsyncSession = Depends(get_db)): - conn = await connection_service.get_connection(db, connection_id) +async def get_connection( + connection_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + conn = await connection_service.get_connection(db, connection_id, ctx) return ConnectionResponse( id=conn.id, name=conn.name, @@ -84,10 +97,11 @@ async def get_connection(connection_id: uuid.UUID, db: AsyncSession = Depends(ge async def update_connection( connection_id: uuid.UUID, body: ConnectionUpdate, + ctx: AuthContext = Depends(get_org_context), db: AsyncSession = Depends(get_db), ): conn = await connection_service.update_connection( - db, connection_id, **body.model_dump(exclude_none=True) + db, connection_id, ctx, **body.model_dump(exclude_none=True) ) return ConnectionResponse( id=conn.id, @@ -105,11 +119,19 @@ async def update_connection( @router.delete("/{connection_id}", status_code=204) -async def delete_connection(connection_id: uuid.UUID, db: AsyncSession = Depends(get_db)): - await connection_service.delete_connection(db, connection_id) +async def delete_connection( + connection_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + await connection_service.delete_connection(db, connection_id, ctx) @router.post("/{connection_id}/test", response_model=ConnectionTestResult) -async def test_connection(connection_id: uuid.UUID, db: AsyncSession = Depends(get_db)): - success, message = await connection_service.test_connection(db, connection_id) +async def test_connection( + connection_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + success, message = await connection_service.test_connection(db, connection_id, ctx) return ConnectionTestResult(success=success, message=message) diff --git a/backend/app/api/v1/endpoints/dictionary.py b/backend/app/api/v1/endpoints/dictionary.py index 7debfa1..69a0157 100644 --- a/backend/app/api/v1/endpoints/dictionary.py +++ b/backend/app/api/v1/endpoints/dictionary.py @@ -4,11 +4,13 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.api.v1.deps import require_column_read, require_column_write from app.api.v1.schemas.dictionary import ( DictionaryEntryCreate, DictionaryEntryResponse, DictionaryEntryUpdate, ) +from app.core.auth import AuthContext from app.core.exceptions import NotFoundError from app.db.models.dictionary import DictionaryEntry from app.db.session import get_db @@ -19,6 +21,7 @@ @router.get("", response_model=list[DictionaryEntryResponse]) async def list_dictionary_entries( column_id: uuid.UUID, + _ctx: AuthContext = Depends(require_column_read), db: AsyncSession = Depends(get_db), ): result = await db.execute( @@ -33,6 +36,7 @@ async def list_dictionary_entries( async def create_dictionary_entry( column_id: uuid.UUID, body: DictionaryEntryCreate, + _ctx: AuthContext = Depends(require_column_write), db: AsyncSession = Depends(get_db), ): entry = DictionaryEntry(column_id=column_id, **body.model_dump()) @@ -46,6 +50,7 @@ async def update_dictionary_entry( column_id: uuid.UUID, entry_id: uuid.UUID, body: DictionaryEntryUpdate, + _ctx: AuthContext = Depends(require_column_write), db: AsyncSession = Depends(get_db), ): entry = await db.get(DictionaryEntry, entry_id) @@ -63,6 +68,7 @@ async def update_dictionary_entry( async def delete_dictionary_entry( column_id: uuid.UUID, entry_id: uuid.UUID, + _ctx: AuthContext = Depends(require_column_write), db: AsyncSession = Depends(get_db), ): entry = await db.get(DictionaryEntry, entry_id) diff --git a/backend/app/api/v1/endpoints/glossary.py b/backend/app/api/v1/endpoints/glossary.py index bb7c260..1187a84 100644 --- a/backend/app/api/v1/endpoints/glossary.py +++ b/backend/app/api/v1/endpoints/glossary.py @@ -4,11 +4,13 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.api.v1.deps import require_connection_read, require_connection_write from app.api.v1.schemas.glossary import ( GlossaryTermCreate, GlossaryTermResponse, GlossaryTermUpdate, ) +from app.core.auth import AuthContext from app.core.exceptions import NotFoundError from app.db.models.glossary import GlossaryTerm from app.db.session import get_db @@ -23,6 +25,7 @@ ) async def list_glossary_terms( connection_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), db: AsyncSession = Depends(get_db), ): result = await db.execute( @@ -41,10 +44,13 @@ async def list_glossary_terms( async def create_glossary_term( connection_id: uuid.UUID, body: GlossaryTermCreate, + ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): term = GlossaryTerm( connection_id=connection_id, + organization_id=ctx.organization_id, + created_by_id=ctx.user_id, term=body.term, definition=body.definition, sql_expression=body.sql_expression, @@ -68,6 +74,7 @@ async def create_glossary_term( async def get_glossary_term( connection_id: uuid.UUID, term_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), db: AsyncSession = Depends(get_db), ): term = await db.get(GlossaryTerm, term_id) @@ -84,6 +91,7 @@ async def update_glossary_term( connection_id: uuid.UUID, term_id: uuid.UUID, body: GlossaryTermUpdate, + _ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): term = await db.get(GlossaryTerm, term_id) @@ -108,6 +116,7 @@ async def update_glossary_term( async def delete_glossary_term( connection_id: uuid.UUID, term_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): term = await db.get(GlossaryTerm, term_id) diff --git a/backend/app/api/v1/endpoints/knowledge.py b/backend/app/api/v1/endpoints/knowledge.py index 6966312..829d36a 100644 --- a/backend/app/api/v1/endpoints/knowledge.py +++ b/backend/app/api/v1/endpoints/knowledge.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from app.api.v1.deps import require_connection_read, require_connection_write from app.api.v1.schemas.knowledge import ( FetchUrlRequest, FetchUrlResponse, @@ -13,6 +14,7 @@ KnowledgeDocumentDetail, KnowledgeDocumentResponse, ) +from app.core.auth import AuthContext from app.core.exceptions import NotFoundError from app.db.models.knowledge import KnowledgeDocument from app.db.session import get_db @@ -31,6 +33,7 @@ ) async def list_knowledge_documents( connection_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), db: AsyncSession = Depends(get_db), ): result = await db.execute( @@ -49,6 +52,7 @@ async def list_knowledge_documents( async def create_knowledge_document( connection_id: uuid.UUID, body: KnowledgeDocumentCreate, + ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): doc = await import_document( @@ -56,6 +60,7 @@ async def create_knowledge_document( connection_id=connection_id, title=body.title, content=body.content, + organization_id=ctx.organization_id, source_url=body.source_url, ) return doc @@ -68,6 +73,7 @@ async def create_knowledge_document( async def get_knowledge_document( connection_id: uuid.UUID, document_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), db: AsyncSession = Depends(get_db), ): result = await db.execute( @@ -88,6 +94,7 @@ async def get_knowledge_document( async def delete_knowledge_document( connection_id: uuid.UUID, document_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): doc = await db.get(KnowledgeDocument, document_id) diff --git a/backend/app/api/v1/endpoints/metrics.py b/backend/app/api/v1/endpoints/metrics.py index 34498ce..3a77ef6 100644 --- a/backend/app/api/v1/endpoints/metrics.py +++ b/backend/app/api/v1/endpoints/metrics.py @@ -4,7 +4,9 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.api.v1.deps import require_connection_read, require_connection_write from app.api.v1.schemas.metric import MetricCreate, MetricResponse, MetricUpdate +from app.core.auth import AuthContext from app.core.exceptions import NotFoundError from app.db.models.metric import MetricDefinition from app.db.session import get_db @@ -19,6 +21,7 @@ ) async def list_metrics( connection_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), db: AsyncSession = Depends(get_db), ): result = await db.execute( @@ -37,10 +40,13 @@ async def list_metrics( async def create_metric( connection_id: uuid.UUID, body: MetricCreate, + ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): metric = MetricDefinition( connection_id=connection_id, + organization_id=ctx.organization_id, + created_by_id=ctx.user_id, **body.model_dump(), ) db.add(metric) @@ -59,6 +65,7 @@ async def create_metric( async def get_metric( connection_id: uuid.UUID, metric_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), db: AsyncSession = Depends(get_db), ): metric = await db.get(MetricDefinition, metric_id) @@ -75,6 +82,7 @@ async def update_metric( connection_id: uuid.UUID, metric_id: uuid.UUID, body: MetricUpdate, + _ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): metric = await db.get(MetricDefinition, metric_id) @@ -99,6 +107,7 @@ async def update_metric( async def delete_metric( connection_id: uuid.UUID, metric_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): metric = await db.get(MetricDefinition, metric_id) diff --git a/backend/app/api/v1/endpoints/query.py b/backend/app/api/v1/endpoints/query.py index d6db4c7..22bbabe 100644 --- a/backend/app/api/v1/endpoints/query.py +++ b/backend/app/api/v1/endpoints/query.py @@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.v1.schemas.query import ExecuteSQLRequest, QueryRequest, SQLOnlyResponse +from app.core.auth import AuthContext, get_org_context from app.db.session import get_db from app.services import query_service @@ -9,27 +10,39 @@ @router.post("") -async def execute_query(body: QueryRequest, db: AsyncSession = Depends(get_db)): +async def execute_query( + body: QueryRequest, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): """Submit a natural language question and get SQL + results + interpretation.""" result = await query_service.execute_nl_query( - db, body.connection_id, body.question + db, body.connection_id, body.question, ctx ) return result @router.post("/execute-sql") -async def execute_sql(body: ExecuteSQLRequest, db: AsyncSession = Depends(get_db)): +async def execute_sql( + body: ExecuteSQLRequest, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): """Execute user-provided SQL directly (no LLM generation).""" result = await query_service.execute_raw_sql( - db, body.connection_id, body.sql, body.original_question + db, body.connection_id, body.sql, ctx, body.original_question ) return result @router.post("/sql-only", response_model=SQLOnlyResponse) -async def generate_sql_only(body: QueryRequest, db: AsyncSession = Depends(get_db)): +async def generate_sql_only( + body: QueryRequest, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): """Generate SQL without executing it.""" result = await query_service.generate_sql_only( - db, body.connection_id, body.question + db, body.connection_id, body.question, ctx ) return SQLOnlyResponse(**result) diff --git a/backend/app/api/v1/endpoints/query_history.py b/backend/app/api/v1/endpoints/query_history.py index e2a0092..a413636 100644 --- a/backend/app/api/v1/endpoints/query_history.py +++ b/backend/app/api/v1/endpoints/query_history.py @@ -5,21 +5,36 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.v1.schemas.query import QueryHistoryResponse +from app.core.auth import AuthContext, get_org_context from app.core.exceptions import NotFoundError +from app.db.models.connection import DatabaseConnection from app.db.models.query_history import QueryExecution from app.db.session import get_db router = APIRouter(prefix="/query-history", tags=["query_history"]) +def _workspace_scoped(ctx: AuthContext): + """History rows whose connection lives in the caller's workspace.""" + return ( + select(QueryExecution) + .join(DatabaseConnection, QueryExecution.connection_id == DatabaseConnection.id) + .where( + QueryExecution.organization_id == ctx.organization_id, + DatabaseConnection.workspace_id == ctx.workspace_id, + ) + ) + + @router.get("", response_model=list[QueryHistoryResponse]) async def list_query_history( connection_id: uuid.UUID | None = Query(default=None), limit: int = Query(default=50, ge=1, le=200), offset: int = Query(default=0, ge=0), + ctx: AuthContext = Depends(get_org_context), db: AsyncSession = Depends(get_db), ): - stmt = select(QueryExecution).order_by(QueryExecution.created_at.desc()) + stmt = _workspace_scoped(ctx).order_by(QueryExecution.created_at.desc()) if connection_id: stmt = stmt.where(QueryExecution.connection_id == connection_id) stmt = stmt.offset(offset).limit(limit) @@ -27,25 +42,32 @@ async def list_query_history( return list(result.scalars().all()) +async def _get_scoped_execution( + db: AsyncSession, ctx: AuthContext, execution_id: uuid.UUID +) -> QueryExecution: + stmt = _workspace_scoped(ctx).where(QueryExecution.id == execution_id) + execution = (await db.execute(stmt)).scalar_one_or_none() + if not execution: + raise NotFoundError("QueryExecution", str(execution_id)) + return execution + + @router.get("/{execution_id}", response_model=QueryHistoryResponse) async def get_query_execution( execution_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), db: AsyncSession = Depends(get_db), ): - execution = await db.get(QueryExecution, execution_id) - if not execution: - raise NotFoundError("QueryExecution", str(execution_id)) - return execution + return await _get_scoped_execution(db, ctx, execution_id) @router.patch("/{execution_id}/favorite") async def toggle_favorite( execution_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), db: AsyncSession = Depends(get_db), ): - execution = await db.get(QueryExecution, execution_id) - if not execution: - raise NotFoundError("QueryExecution", str(execution_id)) + execution = await _get_scoped_execution(db, ctx, execution_id) execution.is_favorite = not execution.is_favorite await db.flush() return {"is_favorite": execution.is_favorite} diff --git a/backend/app/api/v1/endpoints/sample_queries.py b/backend/app/api/v1/endpoints/sample_queries.py index 28eee5f..69c4489 100644 --- a/backend/app/api/v1/endpoints/sample_queries.py +++ b/backend/app/api/v1/endpoints/sample_queries.py @@ -5,6 +5,8 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.api.v1.deps import require_connection_read, require_connection_write +from app.core.auth import AuthContext from app.core.exceptions import NotFoundError from app.db.models.sample_query import SampleQuery from app.db.session import get_db @@ -49,6 +51,7 @@ class SampleQueryResponse(BaseModel): ) async def list_sample_queries( connection_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_read), db: AsyncSession = Depends(get_db), ): result = await db.execute( @@ -67,9 +70,15 @@ async def list_sample_queries( async def create_sample_query( connection_id: uuid.UUID, body: SampleQueryCreate, + ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): - sq = SampleQuery(connection_id=connection_id, **body.model_dump()) + sq = SampleQuery( + connection_id=connection_id, + organization_id=ctx.organization_id, + created_by_id=ctx.user_id, + **body.model_dump(), + ) db.add(sq) await db.flush() try: @@ -87,6 +96,7 @@ async def update_sample_query( connection_id: uuid.UUID, sq_id: uuid.UUID, body: SampleQueryUpdate, + _ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): sq = await db.get(SampleQuery, sq_id) @@ -109,6 +119,7 @@ async def update_sample_query( async def delete_sample_query( connection_id: uuid.UUID, sq_id: uuid.UUID, + _ctx: AuthContext = Depends(require_connection_write), db: AsyncSession = Depends(get_db), ): sq = await db.get(SampleQuery, sq_id) diff --git a/backend/app/api/v1/endpoints/schemas.py b/backend/app/api/v1/endpoints/schemas.py index 09e71eb..c809358 100644 --- a/backend/app/api/v1/endpoints/schemas.py +++ b/backend/app/api/v1/endpoints/schemas.py @@ -10,6 +10,7 @@ TableDetailResponse, TableResponse, ) +from app.core.auth import AuthContext, get_org_context from app.db.session import get_db from app.services import schema_service from app.services.setup_service import launch_background_embeddings @@ -23,9 +24,10 @@ ) async def introspect_connection( connection_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), db: AsyncSession = Depends(get_db), ): - result = await schema_service.introspect_and_cache(db, connection_id) + result = await schema_service.introspect_and_cache(db, connection_id, ctx) launch_background_embeddings(connection_id) return IntrospectionResult(**result) @@ -36,9 +38,10 @@ async def introspect_connection( ) async def list_tables( connection_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), db: AsyncSession = Depends(get_db), ): - tables = await schema_service.get_tables(db, connection_id) + tables = await schema_service.get_tables(db, connection_id, ctx) return [ TableResponse( id=t.id, @@ -60,9 +63,10 @@ async def list_tables( ) async def get_table_detail( table_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), db: AsyncSession = Depends(get_db), ): - table = await schema_service.get_table_detail(db, table_id) + table = await schema_service.get_table_detail(db, table_id, ctx) columns = [ ColumnResponse( diff --git a/backend/app/api/v1/endpoints/teams.py b/backend/app/api/v1/endpoints/teams.py new file mode 100644 index 0000000..539b3dd --- /dev/null +++ b/backend/app/api/v1/endpoints/teams.py @@ -0,0 +1,86 @@ +import uuid + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.v1.schemas.team import ( + MembershipCreate, + MembershipResponse, + TeamCreate, + TeamResponse, +) +from app.core.auth import AuthContext, get_org_context +from app.db.session import get_db +from app.services import identity_service + +router = APIRouter(prefix="/teams", tags=["teams"]) + + +@router.get("", response_model=list[TeamResponse]) +async def list_teams( + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + teams = await identity_service.list_teams(db, ctx) + return [TeamResponse.model_validate(t) for t in teams] + + +@router.post("", response_model=TeamResponse, status_code=201) +async def create_team( + body: TeamCreate, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + team = await identity_service.create_team(db, ctx, body.name) + return TeamResponse.model_validate(team) + + +@router.get("/{team_id}/members", response_model=list[MembershipResponse]) +async def list_members( + team_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + memberships = await identity_service.list_memberships(db, ctx, team_id) + return [ + MembershipResponse( + id=m.id, + team_id=m.team_id, + user_id=m.user_id, + user_email=m.user.email, + user_name=m.user.name, + role=m.role, + created_at=m.created_at, + ) + for m in memberships + ] + + +@router.post("/{team_id}/members", response_model=MembershipResponse, status_code=201) +async def add_member( + team_id: uuid.UUID, + body: MembershipCreate, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + membership = await identity_service.add_membership(db, ctx, team_id, body.email, body.role) + await db.refresh(membership, ["user"]) + return MembershipResponse( + id=membership.id, + team_id=membership.team_id, + user_id=membership.user_id, + user_email=membership.user.email, + user_name=membership.user.name, + role=membership.role, + created_at=membership.created_at, + ) + + +@router.delete("/{team_id}/members/{user_id}", status_code=204) +async def remove_member( + team_id: uuid.UUID, + user_id: uuid.UUID, + ctx: AuthContext = Depends(get_org_context), + db: AsyncSession = Depends(get_db), +): + await identity_service.remove_membership(db, ctx, team_id, user_id) diff --git a/backend/app/api/v1/router.py b/backend/app/api/v1/router.py index c569da5..f237687 100644 --- a/backend/app/api/v1/router.py +++ b/backend/app/api/v1/router.py @@ -1,6 +1,8 @@ from fastapi import APIRouter from app.api.v1.endpoints import ( + api_keys, + auth, connections, dictionary, glossary, @@ -11,11 +13,15 @@ query_history, sample_queries, schemas, + teams, ) api_router = APIRouter() api_router.include_router(health.router) +api_router.include_router(auth.router) +api_router.include_router(teams.router) +api_router.include_router(api_keys.router) api_router.include_router(query.router) api_router.include_router(connections.router) api_router.include_router(schemas.router) diff --git a/backend/app/api/v1/schemas/auth.py b/backend/app/api/v1/schemas/auth.py new file mode 100644 index 0000000..ba0a837 --- /dev/null +++ b/backend/app/api/v1/schemas/auth.py @@ -0,0 +1,64 @@ +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + +# A pragmatic email pattern — full RFC validation would pull in email-validator; +# we keep deps light and rely on a length/format-bounded string. +_EMAIL_PATTERN = r"^[^@\s]+@[^@\s]+\.[^@\s]+$" + + +class LoginRequest(BaseModel): + email: str = Field(min_length=3, max_length=320, pattern=_EMAIL_PATTERN) + password: str = Field(min_length=1) + + +class RegisterRequest(BaseModel): + email: str = Field(min_length=3, max_length=320, pattern=_EMAIL_PATTERN) + password: str = Field(min_length=8, max_length=256) + name: str | None = Field(default=None, max_length=255) + + +class MagicLinkRequest(BaseModel): + email: str = Field(min_length=3, max_length=320, pattern=_EMAIL_PATTERN) + + +class MagicLinkVerifyRequest(BaseModel): + token: str = Field(min_length=1) + + +class MagicLinkResponse(BaseModel): + sent: bool + # Surfaced only in non-production so local dev can complete the flow. + dev_token: str | None = None + dev_verify_url: str | None = None + + +class UserResponse(BaseModel): + id: UUID + email: str + name: str | None + status: str + last_login_at: datetime | None + created_at: datetime + + model_config = {"from_attributes": True} + + +class WorkspaceMembershipResponse(BaseModel): + team_id: UUID + team_name: str + role: str + + +class MeResponse(BaseModel): + user: UserResponse + workspaces: list[WorkspaceMembershipResponse] + + +class AuthProviderInfo(BaseModel): + name: str + supports_password: bool + supports_magic_link: bool + is_sso: bool + disable_auth: bool diff --git a/backend/app/api/v1/schemas/team.py b/backend/app/api/v1/schemas/team.py new file mode 100644 index 0000000..f844b7b --- /dev/null +++ b/backend/app/api/v1/schemas/team.py @@ -0,0 +1,61 @@ +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + +from app.db.models.membership import ROLES + + +class TeamCreate(BaseModel): + name: str = Field(min_length=1, max_length=255) + + +class TeamResponse(BaseModel): + id: UUID + organization_id: UUID + name: str + slug: str + created_at: datetime + + model_config = {"from_attributes": True} + + +class MembershipCreate(BaseModel): + email: str = Field(min_length=3, max_length=320) + role: str = Field(default="viewer") + + +class MembershipResponse(BaseModel): + id: UUID + team_id: UUID + user_id: UUID + user_email: str + user_name: str | None + role: str + created_at: datetime + + +class ApiKeyCreate(BaseModel): + name: str = Field(min_length=1, max_length=255) + expires_at: datetime | None = None + + +class ApiKeyResponse(BaseModel): + id: UUID + name: str + key_prefix: str + expires_at: datetime | None + last_used_at: datetime | None + revoked_at: datetime | None + created_at: datetime + + model_config = {"from_attributes": True} + + +class ApiKeyCreatedResponse(ApiKeyResponse): + # The plaintext key — returned exactly once, on creation. + key: str + + +# Re-exported so endpoints can validate role values against the model layer. +VALID_ROLES = ROLES diff --git a/backend/app/config.py b/backend/app/config.py index 61595bd..162be4f 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -61,6 +61,36 @@ class Settings(BaseSettings): max_queries_per_minute: int = 30 rate_limit_enabled: bool = True + # Authentication & identity (Phase 1) + # Master switch for local dev: when true, every request is treated as a + # synthetic admin user — no login required. NEVER enable in production. + disable_auth: bool = False + # Auth provider for interactive login: local (password) | magic_link | oidc. + # `local` and `magic_link` are implemented; `oidc` is a registered seam. + auth_provider: str = "local" + # JWT signing — sessions are stateless HS256 tokens delivered as a cookie. + jwt_secret: str = "dev-jwt-secret-change-in-production" + jwt_algorithm: str = "HS256" + jwt_access_ttl_minutes: int = 60 * 12 # session lifetime + magic_link_ttl_minutes: int = 15 # magic-link token lifetime + # Session cookie delivery (HTTP-only; Secure should be true behind TLS). + auth_cookie_name: str = "qw_session" + auth_cookie_secure: bool = False # set true in production (HTTPS only) + auth_cookie_samesite: str = "lax" # lax | strict | none + auth_cookie_domain: str | None = None + # Default organization + first admin, created on boot (and in migration 004). + default_org_name: str = "Default Organization" + default_org_slug: str = "default" + default_workspace_name: str = "Default Workspace" + default_admin_email: str = "admin@querywise.local" + # When set, the bootstrapped admin gets this password (local login). + default_admin_password: str | None = None + # OIDC seam (not implemented yet — placeholders for the provider stub). + oidc_issuer: str | None = None + oidc_client_id: str | None = None + oidc_client_secret: str | None = None + oidc_redirect_url: str | None = None + # Context builder max_context_tables: int = 8 max_sample_queries: int = 3 diff --git a/backend/app/core/auth.py b/backend/app/core/auth.py new file mode 100644 index 0000000..6e7ab09 --- /dev/null +++ b/backend/app/core/auth.py @@ -0,0 +1,208 @@ +"""Request-level authentication & authorization. + +Provides the FastAPI dependencies that the API and service layers build on: + +* :func:`get_current_user` — resolves the caller (session cookie, Bearer JWT, + or ``X-API-Key``) to a :class:`User`, or 401. +* :func:`get_org_context` — resolves the active workspace + role into an + :class:`AuthContext` (the object threaded into services for scoping). +* :func:`require_role` — dependency factory enforcing a minimum role. + +``DISABLE_AUTH=true`` short-circuits authentication to the bootstrapped default +admin for local development. Never enable it in production. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Literal, cast + +from fastapi import Depends, Request, Response +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.config import settings +from app.core.exceptions import AuthenticationError, AuthorizationError +from app.core.security import ( + TOKEN_PURPOSE_SESSION, + decode_token, + hash_api_key, +) +from app.db.models.api_key import ApiKey +from app.db.models.membership import ROLE_RANK, Membership +from app.db.models.user import User +from app.db.session import get_db + + +@dataclass +class AuthContext: + """The authenticated caller plus their active workspace + role. + + Threaded into service functions so they can scope queries by + ``organization_id`` / ``workspace_id`` and enforce role checks. + """ + + user: User + organization_id: uuid.UUID + workspace_id: uuid.UUID + role: str + + @property + def user_id(self) -> uuid.UUID: + return self.user.id + + def has_role(self, minimum: str) -> bool: + return ROLE_RANK.get(self.role, 0) >= ROLE_RANK.get(minimum, 99) + + def require_role(self, minimum: str) -> None: + if not self.has_role(minimum): + raise AuthorizationError(f"This action requires the '{minimum}' role.") + + +# --- Cookie helpers --------------------------------------------------------- + + +def set_session_cookie(response: Response, token: str) -> None: + """Attach the session JWT as an HTTP-only cookie.""" + response.set_cookie( + key=settings.auth_cookie_name, + value=token, + max_age=settings.jwt_access_ttl_minutes * 60, + httponly=True, + secure=settings.auth_cookie_secure, + samesite=cast(Literal["lax", "strict", "none"], settings.auth_cookie_samesite), + domain=settings.auth_cookie_domain, + path="/", + ) + + +def clear_session_cookie(response: Response) -> None: + response.delete_cookie( + key=settings.auth_cookie_name, + domain=settings.auth_cookie_domain, + path="/", + ) + + +# --- Authentication --------------------------------------------------------- + + +async def _user_from_api_key(db: AsyncSession, raw_key: str) -> User: + result = await db.execute( + select(ApiKey) + .where(ApiKey.key_hash == hash_api_key(raw_key)) + .options(selectinload(ApiKey.user)) + ) + api_key = result.scalar_one_or_none() + if api_key is None or api_key.revoked_at is not None: + raise AuthenticationError("Invalid API key") + now = datetime.now(UTC) + if api_key.expires_at is not None and api_key.expires_at < now: + raise AuthenticationError("API key expired") + if api_key.user is None or not api_key.user.is_active: + raise AuthenticationError("API key owner is inactive") + api_key.last_used_at = now + return api_key.user + + +def _extract_bearer_or_cookie(request: Request) -> str | None: + auth_header = request.headers.get("authorization") + if auth_header and auth_header.lower().startswith("bearer "): + return auth_header[7:].strip() + return request.cookies.get(settings.auth_cookie_name) + + +async def _dev_admin(db: AsyncSession) -> User: + """Load the bootstrapped default admin for DISABLE_AUTH local dev.""" + result = await db.execute(select(User).where(User.email == settings.default_admin_email)) + user = result.scalar_one_or_none() + if user is None: + raise AuthenticationError( + "DISABLE_AUTH is set but the default admin " + f"({settings.default_admin_email}) does not exist. Run migrations / boot once." + ) + return user + + +async def get_current_user( + request: Request, + db: AsyncSession = Depends(get_db), +) -> User: + """Resolve the calling user from API key, Bearer token, or session cookie.""" + if settings.disable_auth: + return await _dev_admin(db) + + api_key = request.headers.get("x-api-key") + if api_key: + return await _user_from_api_key(db, api_key) + + token = _extract_bearer_or_cookie(request) + if not token: + raise AuthenticationError() + + payload = decode_token(token, TOKEN_PURPOSE_SESSION) + try: + user_id = uuid.UUID(payload["sub"]) + except (KeyError, ValueError) as exc: + raise AuthenticationError("Malformed token subject") from exc + + user = await db.get(User, user_id) + if user is None or not user.is_active: + raise AuthenticationError("User not found or inactive") + return user + + +# --- Authorization (workspace + role) --------------------------------------- + + +async def get_org_context( + request: Request, + user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +) -> AuthContext: + """Resolve the active workspace + role for the caller. + + The active workspace is chosen from the ``X-Workspace-Id`` header when the + user is a member of it; otherwise the earliest-joined membership is used. + """ + result = await db.execute( + select(Membership) + .where(Membership.user_id == user.id) + .options(selectinload(Membership.team)) + .order_by(Membership.created_at) + ) + memberships = list(result.scalars().all()) + if not memberships: + raise AuthorizationError("User is not a member of any workspace.") + + selected = memberships[0] + requested = request.headers.get("x-workspace-id") + if requested: + try: + requested_id = uuid.UUID(requested) + except ValueError as exc: + raise AuthorizationError("Invalid X-Workspace-Id header.") from exc + match = next((m for m in memberships if m.team_id == requested_id), None) + if match is None: + raise AuthorizationError("You are not a member of the requested workspace.") + selected = match + + return AuthContext( + user=user, + organization_id=selected.team.organization_id, + workspace_id=selected.team_id, + role=selected.role, + ) + + +def require_role(minimum: str): + """Dependency factory: require at least ``minimum`` role in the workspace.""" + + async def _dependency(ctx: AuthContext = Depends(get_org_context)) -> AuthContext: + ctx.require_role(minimum) + return ctx + + return _dependency diff --git a/backend/app/core/auth_providers.py b/backend/app/core/auth_providers.py new file mode 100644 index 0000000..f738072 --- /dev/null +++ b/backend/app/core/auth_providers.py @@ -0,0 +1,132 @@ +"""Pluggable interactive-login backends. + +Mirrors :mod:`app.core.secrets`: a small ``AuthProvider`` interface behind a +name-keyed registry so a deployment selects its login method with +``AUTH_PROVIDER`` and SSO backends can be registered without touching the core. + +* ``local`` — email + password *and* magic-link (default; fully implemented + in :mod:`app.services.auth_service`). +* ``magic_link`` — passwordless email magic-link only. +* ``oidc`` — registered seam for OIDC/OAuth2 (Google/Okta/Entra). Not yet + implemented; ``authorization_url`` / ``exchange_code`` raise + until an implementation is registered via + :func:`register_auth_provider`. + +The password / magic-link flows are driven by the service layer; this seam +exists so the configured provider advertises its capabilities to the frontend +(:func:`AuthProvider.describe`) and so SSO can slot in later. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from app.config import settings + + +@dataclass +class VerifiedIdentity: + """An identity verified by a provider (e.g. after an SSO code exchange).""" + + email: str + name: str | None = None + sso_subject: str | None = None + + +class AuthProvider: + """Base class for login backends. Subclasses set capability flags and, for + SSO, override :meth:`authorization_url` / :meth:`exchange_code`.""" + + name: str = "base" + supports_password: bool = False + supports_magic_link: bool = False + is_sso: bool = False + + def authorization_url(self, state: str) -> str: + """Return the IdP URL to redirect the browser to (SSO providers).""" + raise NotImplementedError(f"Auth provider '{self.name}' does not support SSO redirects.") + + async def exchange_code(self, code: str, state: str | None = None) -> VerifiedIdentity: + """Exchange an SSO authorization code for a verified identity.""" + raise NotImplementedError(f"Auth provider '{self.name}' does not support code exchange.") + + def describe(self) -> dict[str, Any]: + return { + "name": self.name, + "supports_password": self.supports_password, + "supports_magic_link": self.supports_magic_link, + "is_sso": self.is_sso, + } + + +class LocalAuthProvider(AuthProvider): + """Self-hosted accounts — password login plus magic-link as a fallback.""" + + name = "local" + supports_password = True + supports_magic_link = True + + +class MagicLinkAuthProvider(AuthProvider): + """Passwordless email magic-link only (lowest-friction first-run).""" + + name = "magic_link" + supports_magic_link = True + + +class OIDCAuthProvider(AuthProvider): + """Seam for OIDC/OAuth2 SSO. Implementation deferred (Phase 1 follow-up).""" + + name = "oidc" + is_sso = True + + def _unimplemented(self) -> NotImplementedError: + return NotImplementedError( + "OIDC auth is not yet implemented. Configure AUTH_PROVIDER=local or " + "magic_link, or register an implementation with " + "register_auth_provider('oidc', factory)." + ) + + def authorization_url(self, state: str) -> str: + raise self._unimplemented() + + async def exchange_code(self, code: str, state: str | None = None) -> VerifiedIdentity: + raise self._unimplemented() + + +_PROVIDER_FACTORIES: dict[str, Callable[[], AuthProvider]] = { + "local": LocalAuthProvider, + "magic_link": MagicLinkAuthProvider, + "oidc": OIDCAuthProvider, +} + +_instance: AuthProvider | None = None + + +def register_auth_provider(name: str, factory: Callable[[], AuthProvider]) -> None: + """Register (or override) an auth provider factory by name.""" + global _instance + _PROVIDER_FACTORIES[name] = factory + _instance = None + + +def get_auth_provider() -> AuthProvider: + """Return the process-wide provider for the configured ``AUTH_PROVIDER``.""" + global _instance + if _instance is None: + factory = _PROVIDER_FACTORIES.get(settings.auth_provider) + if factory is None: + raise ValueError( + f"Unknown auth provider '{settings.auth_provider}'. " + f"Available: {sorted(_PROVIDER_FACTORIES)}" + ) + _instance = factory() + return _instance + + +def reset_auth_provider() -> None: + """Clear the cached provider. Test/reconfiguration hook.""" + global _instance + _instance = None diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py index 9386574..85e4f41 100644 --- a/backend/app/core/exceptions.py +++ b/backend/app/core/exceptions.py @@ -20,6 +20,16 @@ def __init__(self, message: str): super().__init__(message, status_code=422) +class AuthenticationError(AppError): + def __init__(self, message: str = "Not authenticated"): + super().__init__(message, status_code=401) + + +class AuthorizationError(AppError): + def __init__(self, message: str = "Not authorized"): + super().__init__(message, status_code=403) + + class SQLSafetyError(AppError): def __init__(self, message: str): super().__init__(f"SQL safety violation: {message}", status_code=403) diff --git a/backend/app/core/security.py b/backend/app/core/security.py new file mode 100644 index 0000000..d41e77c --- /dev/null +++ b/backend/app/core/security.py @@ -0,0 +1,110 @@ +"""Low-level auth primitives: password hashing, JWTs, and API keys. + +Kept dependency-light and free of FastAPI/DB imports so it is trivially unit +testable. Higher-level request plumbing (dependencies, cookies, AuthContext) +lives in :mod:`app.core.auth`. +""" + +from __future__ import annotations + +import hashlib +import hmac +import secrets +from datetime import UTC, datetime, timedelta +from typing import Any + +import jwt + +from app.config import settings +from app.core.exceptions import AuthenticationError + +# --- Passwords (PBKDF2-HMAC-SHA256, stdlib — no native build deps) ---------- + +_PBKDF2_ALGO = "pbkdf2_sha256" +_PBKDF2_ROUNDS = 390_000 + + +def hash_password(password: str) -> str: + """Return an encoded ``algo$rounds$salt$hash`` string for ``password``.""" + salt = secrets.token_bytes(16) + digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, _PBKDF2_ROUNDS) + return f"{_PBKDF2_ALGO}${_PBKDF2_ROUNDS}${salt.hex()}${digest.hex()}" + + +def verify_password(password: str, encoded: str | None) -> bool: + """Constant-time check of ``password`` against an encoded hash.""" + if not encoded: + return False + try: + algo, rounds_s, salt_hex, hash_hex = encoded.split("$") + if algo != _PBKDF2_ALGO: + return False + rounds = int(rounds_s) + salt = bytes.fromhex(salt_hex) + expected = bytes.fromhex(hash_hex) + except (ValueError, AttributeError): + return False + digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, rounds) + return hmac.compare_digest(digest, expected) + + +# --- JWTs (HS256, stateless sessions + magic-link tokens) ------------------- + +TOKEN_PURPOSE_SESSION = "session" +TOKEN_PURPOSE_MAGIC_LINK = "magic_link" + + +def create_token( + subject: str, + purpose: str, + ttl_minutes: int, + **extra_claims: Any, +) -> str: + """Sign a JWT for ``subject`` with a ``purpose`` claim and TTL.""" + now = datetime.now(UTC) + payload: dict[str, Any] = { + "sub": subject, + "purpose": purpose, + "iat": now, + "exp": now + timedelta(minutes=ttl_minutes), + **extra_claims, + } + return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm) + + +def decode_token(token: str, expected_purpose: str) -> dict[str, Any]: + """Decode and validate a JWT, enforcing signature, expiry, and purpose.""" + try: + payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm]) + except jwt.PyJWTError as exc: + raise AuthenticationError("Invalid or expired token") from exc + if payload.get("purpose") != expected_purpose: + raise AuthenticationError("Token purpose mismatch") + return payload + + +def create_session_token(user_id: str) -> str: + return create_token(user_id, TOKEN_PURPOSE_SESSION, settings.jwt_access_ttl_minutes) + + +def create_magic_link_token(email: str) -> str: + # Magic-link subject is the email — the user may not exist yet at request time. + return create_token(email, TOKEN_PURPOSE_MAGIC_LINK, settings.magic_link_ttl_minutes) + + +# --- API keys --------------------------------------------------------------- + +API_KEY_PREFIX = "qw_" + + +def generate_api_key() -> tuple[str, str, str]: + """Return ``(plaintext, sha256_hash, display_prefix)`` for a new API key. + + The plaintext is shown to the user exactly once; only the hash is stored. + """ + plaintext = API_KEY_PREFIX + secrets.token_urlsafe(32) + return plaintext, hash_api_key(plaintext), plaintext[:10] + + +def hash_api_key(plaintext: str) -> str: + return hashlib.sha256(plaintext.encode()).hexdigest() diff --git a/backend/app/db/models/__init__.py b/backend/app/db/models/__init__.py index 6e07a5e..6da6387 100644 --- a/backend/app/db/models/__init__.py +++ b/backend/app/db/models/__init__.py @@ -1,13 +1,23 @@ +from app.db.models.api_key import ApiKey from app.db.models.connection import DatabaseConnection from app.db.models.dictionary import DictionaryEntry from app.db.models.glossary import GlossaryTerm from app.db.models.knowledge import KnowledgeChunk, KnowledgeDocument +from app.db.models.membership import Membership from app.db.models.metric import MetricDefinition +from app.db.models.organization import Organization from app.db.models.query_history import QueryExecution from app.db.models.sample_query import SampleQuery from app.db.models.schema_cache import CachedColumn, CachedRelationship, CachedTable +from app.db.models.team import Team +from app.db.models.user import User __all__ = [ + "Organization", + "User", + "Team", + "Membership", + "ApiKey", "DatabaseConnection", "CachedTable", "CachedColumn", diff --git a/backend/app/db/models/api_key.py b/backend/app/db/models/api_key.py new file mode 100644 index 0000000..d5760fb --- /dev/null +++ b/backend/app/db/models/api_key.py @@ -0,0 +1,36 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, String, func +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db.base import Base + + +class ApiKey(Base): + """A programmatic credential bound to a user. + + Only the SHA-256 ``key_hash`` is stored — the plaintext key is shown once at + creation. ``key_prefix`` is the non-secret leading segment kept for display + ("qw_ab12…"). The key inherits the user's memberships/roles at request time. + """ + + __tablename__ = "api_keys" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) + name: Mapped[str] = mapped_column(String(255), nullable=False) + key_hash: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) + key_prefix: Mapped[str] = mapped_column(String(16), nullable=False) + permissions: Mapped[dict | None] = mapped_column(JSONB, default=dict) + expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + last_used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + + user: Mapped["User"] = relationship(back_populates="api_keys") # noqa: F821 diff --git a/backend/app/db/models/connection.py b/backend/app/db/models/connection.py index 026260e..459a376 100644 --- a/backend/app/db/models/connection.py +++ b/backend/app/db/models/connection.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime -from sqlalchemy import Boolean, DateTime, Integer, String, Text, func +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text, func from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -12,6 +12,19 @@ class DatabaseConnection(Base): __tablename__ = "database_connections" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Identity scoping (Phase 1). organization_id is SaaS-ready; workspace_id + # (a Team) is the isolation unit and the cascade root for all metadata. + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) + workspace_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), nullable=False + ) + owner_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) + # Private connections are visible only to their owner within the workspace. + is_private: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) name: Mapped[str] = mapped_column(String(255), nullable=False) connector_type: Mapped[str] = mapped_column(String(50), nullable=False) # Connection string stored encrypted; handled at service layer diff --git a/backend/app/db/models/glossary.py b/backend/app/db/models/glossary.py index c8b0de5..14c47dd 100644 --- a/backend/app/db/models/glossary.py +++ b/backend/app/db/models/glossary.py @@ -14,6 +14,9 @@ class GlossaryTerm(Base): __tablename__ = "glossary_terms" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) connection_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), nullable=False ) @@ -24,7 +27,9 @@ class GlossaryTerm(Base): related_columns: Mapped[list[str] | None] = mapped_column(ARRAY(Text)) examples: Mapped[dict | None] = mapped_column(JSONB, default=list) term_embedding = mapped_column(Vector(settings.embedding_dimension), nullable=True) - created_by: Mapped[str | None] = mapped_column(String(255)) + created_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) diff --git a/backend/app/db/models/knowledge.py b/backend/app/db/models/knowledge.py index edfdb28..4b874bd 100644 --- a/backend/app/db/models/knowledge.py +++ b/backend/app/db/models/knowledge.py @@ -16,6 +16,11 @@ class KnowledgeDocument(Base): id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 ) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ) connection_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), diff --git a/backend/app/db/models/membership.py b/backend/app/db/models/membership.py new file mode 100644 index 0000000..eb34414 --- /dev/null +++ b/backend/app/db/models/membership.py @@ -0,0 +1,40 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db.base import Base + +# Roles, ordered by privilege. `admin` can manage the team and its members; +# `editor` can mutate connections + the semantic layer; `viewer` is read-only. +ROLE_ADMIN = "admin" +ROLE_EDITOR = "editor" +ROLE_VIEWER = "viewer" +ROLES = (ROLE_ADMIN, ROLE_EDITOR, ROLE_VIEWER) + +# Higher number = more privilege. Used by require_role() comparisons. +ROLE_RANK = {ROLE_VIEWER: 1, ROLE_EDITOR: 2, ROLE_ADMIN: 3} + + +class Membership(Base): + """Links a :class:`User` to a :class:`Team` with a role.""" + + __tablename__ = "memberships" + __table_args__ = (UniqueConstraint("user_id", "team_id", name="uq_membership_user_team"),) + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) + team_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), nullable=False + ) + role: Mapped[str] = mapped_column(String(20), nullable=False, default=ROLE_VIEWER) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + + user: Mapped["User"] = relationship(back_populates="memberships") # noqa: F821 + team: Mapped["Team"] = relationship(back_populates="memberships") # noqa: F821 diff --git a/backend/app/db/models/metric.py b/backend/app/db/models/metric.py index 6c6adf1..df77766 100644 --- a/backend/app/db/models/metric.py +++ b/backend/app/db/models/metric.py @@ -14,6 +14,9 @@ class MetricDefinition(Base): __tablename__ = "metric_definitions" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) connection_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), nullable=False ) @@ -26,7 +29,9 @@ class MetricDefinition(Base): dimensions: Mapped[list[str] | None] = mapped_column(ARRAY(Text)) filters: Mapped[dict | None] = mapped_column(JSONB, default=dict) metric_embedding = mapped_column(Vector(settings.embedding_dimension), nullable=True) - created_by: Mapped[str | None] = mapped_column(String(255)) + created_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) diff --git a/backend/app/db/models/organization.py b/backend/app/db/models/organization.py new file mode 100644 index 0000000..ac26332 --- /dev/null +++ b/backend/app/db/models/organization.py @@ -0,0 +1,35 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, String, func +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db.base import Base + + +class Organization(Base): + """The top-level tenant boundary. + + QueryWise is single-tenant per deployment — one default org is auto-created + on boot. ``organization_id`` is carried on every core table from day one so + the future managed-SaaS fleet (a set of isolated single-tenant instances) + needs no schema migration. + """ + + __tablename__ = "organizations" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(255), nullable=False) + slug: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) + settings: Mapped[dict | None] = mapped_column(JSONB, default=dict) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + teams: Mapped[list["Team"]] = relationship( # noqa: F821 + back_populates="organization", cascade="all, delete-orphan" + ) diff --git a/backend/app/db/models/query_history.py b/backend/app/db/models/query_history.py index 1651654..aedb647 100644 --- a/backend/app/db/models/query_history.py +++ b/backend/app/db/models/query_history.py @@ -12,10 +12,16 @@ class QueryExecution(Base): __tablename__ = "query_executions" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) connection_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("database_connections.id"), nullable=False ) - user_id: Mapped[str | None] = mapped_column(String(255)) + # Promoted from a free-text string to a real FK (Phase 1). + user_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) natural_language: Mapped[str] = mapped_column(Text, nullable=False) generated_sql: Mapped[str | None] = mapped_column(Text) final_sql: Mapped[str | None] = mapped_column(Text) diff --git a/backend/app/db/models/sample_query.py b/backend/app/db/models/sample_query.py index 5ade1d7..0060fe8 100644 --- a/backend/app/db/models/sample_query.py +++ b/backend/app/db/models/sample_query.py @@ -2,7 +2,7 @@ from datetime import datetime from pgvector.sqlalchemy import Vector -from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, func +from sqlalchemy import Boolean, DateTime, ForeignKey, Text, func from sqlalchemy.dialects.postgresql import ARRAY, UUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -14,6 +14,9 @@ class SampleQuery(Base): __tablename__ = "sample_queries" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) connection_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("database_connections.id", ondelete="CASCADE"), nullable=False ) @@ -23,7 +26,9 @@ class SampleQuery(Base): tags: Mapped[list[str] | None] = mapped_column(ARRAY(Text)) is_validated: Mapped[bool] = mapped_column(Boolean, default=False) question_embedding = mapped_column(Vector(settings.embedding_dimension), nullable=True) - created_by: Mapped[str | None] = mapped_column(String(255)) + created_by_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL") + ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) diff --git a/backend/app/db/models/team.py b/backend/app/db/models/team.py new file mode 100644 index 0000000..73d07e9 --- /dev/null +++ b/backend/app/db/models/team.py @@ -0,0 +1,40 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, String, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db.base import Base + + +class Team(Base): + """A team — also the isolation unit (workspace) within an organization. + + Connections and (later) artifacts are scoped to a team via ``workspace_id``. + Users join teams through :class:`Membership` with a role. + """ + + __tablename__ = "teams" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + organization_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ) + name: Mapped[str] = mapped_column(String(255), nullable=False) + slug: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + organization: Mapped["Organization"] = relationship( # noqa: F821 + back_populates="teams" + ) + memberships: Mapped[list["Membership"]] = relationship( # noqa: F821 + back_populates="team", cascade="all, delete-orphan" + ) diff --git a/backend/app/db/models/user.py b/backend/app/db/models/user.py new file mode 100644 index 0000000..280b6ea --- /dev/null +++ b/backend/app/db/models/user.py @@ -0,0 +1,50 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, String, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db.base import Base + +# User account states. +USER_STATUS_ACTIVE = "active" +USER_STATUS_DISABLED = "disabled" +USER_STATUSES = (USER_STATUS_ACTIVE, USER_STATUS_DISABLED) + + +class User(Base): + """A real person who authenticates and owns/contributes artifacts. + + ``password_hash`` is set for local-login accounts and null for SSO-only + users; ``sso_subject`` holds the stable IdP subject claim once OIDC lands. + """ + + __tablename__ = "users" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + email: Mapped[str] = mapped_column(String(320), nullable=False, unique=True) + name: Mapped[str | None] = mapped_column(String(255)) + # Stable identity provider subject (OIDC `sub`); null for local accounts. + sso_subject: Mapped[str | None] = mapped_column(String(255), unique=True) + # PBKDF2 hash for local login; null for SSO-only / magic-link-only users. + password_hash: Mapped[str | None] = mapped_column(String(255)) + status: Mapped[str] = mapped_column(String(20), nullable=False, default=USER_STATUS_ACTIVE) + last_login_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + memberships: Mapped[list["Membership"]] = relationship( # noqa: F821 + back_populates="user", cascade="all, delete-orphan" + ) + api_keys: Mapped[list["ApiKey"]] = relationship( # noqa: F821 + back_populates="user", cascade="all, delete-orphan" + ) + + @property + def is_active(self) -> bool: + return self.status == USER_STATUS_ACTIVE diff --git a/backend/app/mcp/server.py b/backend/app/mcp/server.py index 2ceddcb..b1f57d9 100644 --- a/backend/app/mcp/server.py +++ b/backend/app/mcp/server.py @@ -39,7 +39,7 @@ from app.db.models.schema_cache import CachedColumn, CachedTable from app.db.session import async_session_factory from app.semantic.context_builder import build_context -from app.services import connection_service, query_service, schema_service +from app.services import connection_service, identity_service, query_service, schema_service from app.services.embedding_service import ( embed_glossary_term, embed_metric, @@ -234,7 +234,8 @@ async def _mcp_lifespan(_server: FastMCP) -> AsyncIterator[dict]: async def list_connections() -> list[dict]: """List configured database connections (id, name, type, limits).""" async with _session_scope() as db: - conns = await connection_service.list_connections(db) + ctx = await identity_service.system_context(db) + conns = await connection_service.list_connections(db, ctx) return [_conn_dict(c) for c in conns] @@ -274,8 +275,10 @@ async def create_connection( with test_connection, then introspect_connection. Returns the new id. """ async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await connection_service.create_connection( db, + ctx, name=name, connector_type=connector_type, connection_string=connection_string, @@ -290,8 +293,9 @@ async def create_connection( async def test_connection(connection: Connection) -> dict: """Check that a connection can be reached and authenticated.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) - ok, message = await connection_service.test_connection(db, conn.id) + ok, message = await connection_service.test_connection(db, conn.id, ctx) return {"success": ok, "message": message} @@ -319,9 +323,10 @@ async def introspect_connection( Idempotent — re-running refreshes the cache. """ async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) cid = conn.id - counts = await schema_service.introspect_and_cache(db, cid) + counts = await schema_service.introspect_and_cache(db, cid, ctx) if generate_embeddings: launch_background_embeddings(cid) return {**counts, "embeddings_started": bool(generate_embeddings)} @@ -331,8 +336,9 @@ async def introspect_connection( async def delete_connection(connection: Connection) -> dict: """Permanently delete a connection and all its cached schema + semantic metadata.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) - await connection_service.delete_connection(db, conn.id) + await connection_service.delete_connection(db, conn.id, ctx) return {"deleted": True} @@ -345,8 +351,9 @@ async def delete_connection(connection: Connection) -> dict: async def list_tables(connection: Connection) -> list[dict]: """List a connection's cached tables with their columns.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) - tables = await schema_service.get_tables(db, conn.id) + tables = await schema_service.get_tables(db, conn.id, ctx) return [ { "id": str(t.id), @@ -379,12 +386,13 @@ async def describe_table( ) -> dict: """Describe one cached table in detail, including its foreign keys.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) - tables = await schema_service.get_tables(db, conn.id) + tables = await schema_service.get_tables(db, conn.id, ctx) match = next((t for t in tables if t.table_name == table_name), None) if not match: raise ValueError(f"Table '{table_name}' not found on '{conn.name}'.") - detail = await schema_service.get_table_detail(db, match.id) + detail = await schema_service.get_table_detail(db, match.id, ctx) return { "schema": detail.schema_name, "name": detail.table_name, @@ -460,8 +468,9 @@ async def run_sql( ) -> dict: """Execute read-only SQL against the target database and return the rows.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) - result = await query_service.execute_raw_sql(db, conn.id, sql) + result = await query_service.execute_raw_sql(db, conn.id, sql, ctx) return { "columns": result.get("columns"), "rows": result.get("rows"), @@ -478,8 +487,9 @@ async def generate_sql( ) -> dict: """Translate a natural-language question into SQL without executing it.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) - return await query_service.generate_sql_only(db, conn.id, question) + return await query_service.generate_sql_only(db, conn.id, question, ctx) @mcp.tool(annotations=ToolAnnotations(title="Ask (NL->answer)", **_READ_ONLY_EXTERNAL)) @@ -493,8 +503,9 @@ async def ask( results. Returns Markdown. """ async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) - res = await query_service.execute_nl_query(db, conn.id, question) + res = await query_service.execute_nl_query(db, conn.id, question, ctx) return _format_ask_result(res) @@ -572,9 +583,12 @@ async def add_glossary_term( ) -> dict: """Define a business glossary term that maps business language to a SQL expression.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) obj = GlossaryTerm( connection_id=conn.id, + organization_id=ctx.organization_id, + created_by_id=ctx.user_id, term=term, definition=definition, sql_expression=sql_expression, @@ -672,9 +686,12 @@ async def add_metric( Use add_glossary_term for phrase-to-SQL mappings; use this for aggregate KPIs. """ async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) obj = MetricDefinition( connection_id=conn.id, + organization_id=ctx.organization_id, + created_by_id=ctx.user_id, metric_name=metric_name, display_name=display_name, sql_expression=sql_expression, @@ -806,9 +823,12 @@ async def add_sample_query( ) -> dict: """Save an NL -> SQL example pair used as a few-shot to steer generation.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) obj = SampleQuery( connection_id=conn.id, + organization_id=ctx.organization_id, + created_by_id=ctx.user_id, natural_language=natural_language, sql_query=sql_query, description=description, @@ -868,12 +888,14 @@ async def add_knowledge( ) -> dict: """Import a knowledge document (text or HTML). Chunked and embedded for retrieval.""" async with _session_scope() as db: + ctx = await identity_service.system_context(db) conn = await _resolve_connection(db, connection) doc = await import_document( db, connection_id=conn.id, title=title, content=content, + organization_id=ctx.organization_id, source_url=source_url, ) return {"id": str(doc.id), "title": doc.title} diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py new file mode 100644 index 0000000..e3d925f --- /dev/null +++ b/backend/app/services/auth_service.py @@ -0,0 +1,127 @@ +"""Interactive authentication flows: password login and magic-link. + +The configured :mod:`app.core.auth_providers` advertises which of these a +deployment exposes; the actual credential handling lives here. New users +discovered via magic-link are auto-provisioned into the default workspace as +viewers (first-run, single-company convenience). +""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.exceptions import AuthenticationError, ValidationError +from app.core.security import ( + TOKEN_PURPOSE_MAGIC_LINK, + create_magic_link_token, + create_session_token, + decode_token, + hash_password, + verify_password, +) +from app.db.models.membership import ROLE_VIEWER, Membership +from app.db.models.user import User +from app.services import identity_service + + +def _normalize_email(email: str) -> str: + return email.lower().strip() + + +async def _get_user_by_email(db: AsyncSession, email: str) -> User | None: + result = await db.execute(select(User).where(User.email == _normalize_email(email))) + return result.scalar_one_or_none() + + +async def _ensure_default_membership(db: AsyncSession, user: User) -> None: + """Give a freshly provisioned user viewer access to the default workspace.""" + org, team, _admin = await identity_service.bootstrap_default_identity(db) + result = await db.execute( + select(Membership).where( + Membership.user_id == user.id, Membership.team_id == team.id + ) + ) + if result.scalar_one_or_none() is None: + db.add(Membership(user_id=user.id, team_id=team.id, role=ROLE_VIEWER)) + await db.flush() + + +async def find_or_create_user( + db: AsyncSession, + email: str, + name: str | None = None, + sso_subject: str | None = None, +) -> User: + user = await _get_user_by_email(db, email) + if user is None: + user = User(email=_normalize_email(email), name=name, sso_subject=sso_subject) + db.add(user) + await db.flush() + await _ensure_default_membership(db, user) + return user + + +async def register_user( + db: AsyncSession, + email: str, + password: str, + name: str | None = None, +) -> User: + """Create a local password account and provision default access.""" + if await _get_user_by_email(db, email) is not None: + raise ValidationError("A user with this email already exists.") + if len(password) < 8: + raise ValidationError("Password must be at least 8 characters.") + user = User( + email=_normalize_email(email), + name=name, + password_hash=hash_password(password), + ) + db.add(user) + await db.flush() + await _ensure_default_membership(db, user) + return user + + +async def authenticate_password(db: AsyncSession, email: str, password: str) -> User: + user = await _get_user_by_email(db, email) + # Verify even when the user is missing to avoid timing-based enumeration. + if user is None or not verify_password(password, user.password_hash): + raise AuthenticationError("Invalid email or password.") + if not user.is_active: + raise AuthenticationError("This account is disabled.") + await _touch_login(db, user) + return user + + +async def request_magic_link(db: AsyncSession, email: str) -> str: + """Issue a magic-link token for ``email``. + + Returns the signed token; delivery (email/Slack) is the caller's concern. + For local dev the token is surfaced directly by the endpoint. + """ + return create_magic_link_token(_normalize_email(email)) + + +async def verify_magic_link(db: AsyncSession, token: str) -> User: + payload = decode_token(token, TOKEN_PURPOSE_MAGIC_LINK) + email = payload.get("sub") + if not email: + raise AuthenticationError("Malformed magic-link token.") + user = await find_or_create_user(db, email) + if not user.is_active: + raise AuthenticationError("This account is disabled.") + await _touch_login(db, user) + return user + + +async def _touch_login(db: AsyncSession, user: User) -> None: + user.last_login_at = datetime.now(UTC) + await db.flush() + + +def issue_session_token(user: User) -> str: + return create_session_token(str(user.id)) diff --git a/backend/app/services/connection_service.py b/backend/app/services/connection_service.py index 2199a1d..c887611 100644 --- a/backend/app/services/connection_service.py +++ b/backend/app/services/connection_service.py @@ -1,12 +1,18 @@ import uuid -from sqlalchemy import select +from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession -from app.connectors.connector_registry import get_connector_class, get_or_create_connector, remove_connector -from app.core.exceptions import NotFoundError +from app.connectors.connector_registry import ( + get_connector_class, + get_or_create_connector, + remove_connector, +) +from app.core.auth import AuthContext +from app.core.exceptions import AuthorizationError, NotFoundError from app.core.secrets import get_secrets_provider from app.db.models.connection import DatabaseConnection +from app.db.models.membership import ROLE_ADMIN, ROLE_EDITOR # Encryption of connection strings is delegated to the configured secrets # backend (env/Fernet by default — see app.core.secrets). @@ -20,33 +26,75 @@ def _decrypt(value: str) -> str: return get_secrets_provider().decrypt(value) -async def list_connections(db: AsyncSession) -> list[DatabaseConnection]: - result = await db.execute( - select(DatabaseConnection).order_by(DatabaseConnection.created_at.desc()) +def _assert_access(conn: DatabaseConnection, ctx: AuthContext, *, write: bool = False) -> None: + """Enforce workspace scoping + role for a connection. + + Cross-workspace access raises 404 (don't leak existence); private + connections are visible only to their owner or a workspace admin. + """ + if conn.organization_id != ctx.organization_id or conn.workspace_id != ctx.workspace_id: + raise NotFoundError("Connection", str(conn.id)) + if conn.is_private and conn.owner_id != ctx.user_id and not ctx.has_role(ROLE_ADMIN): + raise AuthorizationError("This connection is private to its owner.") + if write: + ctx.require_role(ROLE_EDITOR) + + +async def list_connections(db: AsyncSession, ctx: AuthContext) -> list[DatabaseConnection]: + stmt = ( + select(DatabaseConnection) + .where( + DatabaseConnection.organization_id == ctx.organization_id, + DatabaseConnection.workspace_id == ctx.workspace_id, + ) + .order_by(DatabaseConnection.created_at.desc()) ) + # Non-admins don't see other people's private connections. + if not ctx.has_role(ROLE_ADMIN): + stmt = stmt.where( + or_( + DatabaseConnection.is_private.is_(False), + DatabaseConnection.owner_id == ctx.user_id, + ) + ) + result = await db.execute(stmt) return list(result.scalars().all()) -async def get_connection(db: AsyncSession, connection_id: uuid.UUID) -> DatabaseConnection: +async def get_connection( + db: AsyncSession, + connection_id: uuid.UUID, + ctx: AuthContext, + *, + write: bool = False, +) -> DatabaseConnection: conn = await db.get(DatabaseConnection, connection_id) if not conn: raise NotFoundError("Connection", str(connection_id)) + _assert_access(conn, ctx, write=write) return conn async def create_connection( db: AsyncSession, + ctx: AuthContext, name: str, connector_type: str, connection_string: str, default_schema: str = "public", max_query_timeout_seconds: int = 30, max_rows: int = 1000, + is_private: bool = False, ) -> DatabaseConnection: + ctx.require_role(ROLE_EDITOR) # Validate connector type exists get_connector_class(connector_type) conn = DatabaseConnection( + organization_id=ctx.organization_id, + workspace_id=ctx.workspace_id, + owner_id=ctx.user_id, + is_private=is_private, name=name, connector_type=connector_type, connection_string_encrypted=_encrypt(connection_string), @@ -62,9 +110,10 @@ async def create_connection( async def update_connection( db: AsyncSession, connection_id: uuid.UUID, + ctx: AuthContext, **updates: object, ) -> DatabaseConnection: - conn = await get_connection(db, connection_id) + conn = await get_connection(db, connection_id, ctx, write=True) if "connection_string" in updates and updates["connection_string"] is not None: conn.connection_string_encrypted = _encrypt(str(updates.pop("connection_string"))) @@ -79,15 +128,19 @@ async def update_connection( return conn -async def delete_connection(db: AsyncSession, connection_id: uuid.UUID) -> None: - conn = await get_connection(db, connection_id) +async def delete_connection( + db: AsyncSession, connection_id: uuid.UUID, ctx: AuthContext +) -> None: + conn = await get_connection(db, connection_id, ctx, write=True) await remove_connector(str(connection_id)) await db.delete(conn) await db.flush() -async def test_connection(db: AsyncSession, connection_id: uuid.UUID) -> tuple[bool, str]: - conn = await get_connection(db, connection_id) +async def test_connection( + db: AsyncSession, connection_id: uuid.UUID, ctx: AuthContext +) -> tuple[bool, str]: + conn = await get_connection(db, connection_id, ctx) connection_string = _decrypt(conn.connection_string_encrypted) try: connector = await get_or_create_connector( diff --git a/backend/app/services/identity_service.py b/backend/app/services/identity_service.py new file mode 100644 index 0000000..0560ca0 --- /dev/null +++ b/backend/app/services/identity_service.py @@ -0,0 +1,250 @@ +"""Organizations, teams (workspaces), memberships, and API keys. + +Also owns identity bootstrap — ensuring the default org/workspace/admin exist — +which runs both in migration 004 (for existing deployments) and on every boot +(idempotent, for fresh databases). +""" + +from __future__ import annotations + +import re +import uuid +from datetime import UTC, datetime + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.config import settings +from app.core.auth import AuthContext +from app.core.exceptions import NotFoundError, ValidationError +from app.core.security import generate_api_key, hash_password +from app.db.models.api_key import ApiKey +from app.db.models.membership import ROLE_ADMIN, ROLES, Membership +from app.db.models.organization import Organization +from app.db.models.team import Team +from app.db.models.user import User + + +def _slugify(value: str) -> str: + slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-") + return slug or "team" + + +# --- Bootstrap -------------------------------------------------------------- + + +async def get_default_organization(db: AsyncSession) -> Organization | None: + result = await db.execute( + select(Organization).where(Organization.slug == settings.default_org_slug) + ) + return result.scalar_one_or_none() + + +async def bootstrap_default_identity(db: AsyncSession) -> tuple[Organization, Team, User]: + """Idempotently ensure the default org, workspace, and admin user exist.""" + org = await get_default_organization(db) + if org is None: + org = Organization(name=settings.default_org_name, slug=settings.default_org_slug) + db.add(org) + await db.flush() + + team_res = await db.execute( + select(Team).where(Team.organization_id == org.id).order_by(Team.created_at) + ) + team = team_res.scalars().first() + if team is None: + team = Team( + organization_id=org.id, + name=settings.default_workspace_name, + slug=_slugify(settings.default_workspace_name), + ) + db.add(team) + await db.flush() + + user_res = await db.execute(select(User).where(User.email == settings.default_admin_email)) + admin = user_res.scalar_one_or_none() + if admin is None: + admin = User( + email=settings.default_admin_email, + name="Administrator", + password_hash=( + hash_password(settings.default_admin_password) + if settings.default_admin_password + else None + ), + ) + db.add(admin) + await db.flush() + + mem_res = await db.execute( + select(Membership).where( + Membership.user_id == admin.id, Membership.team_id == team.id + ) + ) + if mem_res.scalar_one_or_none() is None: + db.add(Membership(user_id=admin.id, team_id=team.id, role=ROLE_ADMIN)) + await db.flush() + + return org, team, admin + + +async def system_context(db: AsyncSession) -> AuthContext: + """An admin :class:`AuthContext` bound to the default org/workspace. + + Used by non-request entry points (startup auto-setup, the MCP server, seed + scripts) that act on behalf of the deployment rather than an end user. + """ + org, team, admin = await bootstrap_default_identity(db) + return AuthContext( + user=admin, + organization_id=org.id, + workspace_id=team.id, + role=ROLE_ADMIN, + ) + + +# --- Teams ------------------------------------------------------------------ + + +async def list_teams(db: AsyncSession, ctx: AuthContext) -> list[Team]: + """Teams in the caller's organization.""" + result = await db.execute( + select(Team) + .where(Team.organization_id == ctx.organization_id) + .order_by(Team.created_at) + ) + return list(result.scalars().all()) + + +async def list_my_memberships(db: AsyncSession, user: User) -> list[Membership]: + result = await db.execute( + select(Membership) + .where(Membership.user_id == user.id) + .options(selectinload(Membership.team)) + .order_by(Membership.created_at) + ) + return list(result.scalars().all()) + + +async def create_team(db: AsyncSession, ctx: AuthContext, name: str) -> Team: + ctx.require_role(ROLE_ADMIN) + team = Team(organization_id=ctx.organization_id, name=name, slug=_slugify(name)) + db.add(team) + await db.flush() + # The creator joins their new team as admin. + db.add(Membership(user_id=ctx.user_id, team_id=team.id, role=ROLE_ADMIN)) + await db.flush() + return team + + +async def _get_team_in_org(db: AsyncSession, ctx: AuthContext, team_id: uuid.UUID) -> Team: + team = await db.get(Team, team_id) + if team is None or team.organization_id != ctx.organization_id: + raise NotFoundError("Team", str(team_id)) + return team + + +async def list_memberships( + db: AsyncSession, ctx: AuthContext, team_id: uuid.UUID +) -> list[Membership]: + await _get_team_in_org(db, ctx, team_id) + result = await db.execute( + select(Membership) + .where(Membership.team_id == team_id) + .options(selectinload(Membership.user)) + .order_by(Membership.created_at) + ) + return list(result.scalars().all()) + + +async def add_membership( + db: AsyncSession, + ctx: AuthContext, + team_id: uuid.UUID, + email: str, + role: str, +) -> Membership: + """Add a user (by email) to a team. Admin-only.""" + ctx.require_role(ROLE_ADMIN) + if role not in ROLES: + raise ValidationError(f"Invalid role '{role}'. Must be one of {list(ROLES)}.") + await _get_team_in_org(db, ctx, team_id) + + user_res = await db.execute(select(User).where(User.email == email.lower().strip())) + user = user_res.scalar_one_or_none() + if user is None: + raise NotFoundError("User", email) + + mem_res = await db.execute( + select(Membership).where( + Membership.user_id == user.id, Membership.team_id == team_id + ) + ) + membership = mem_res.scalar_one_or_none() + if membership is not None: + membership.role = role + else: + membership = Membership(user_id=user.id, team_id=team_id, role=role) + db.add(membership) + await db.flush() + return membership + + +async def remove_membership( + db: AsyncSession, ctx: AuthContext, team_id: uuid.UUID, user_id: uuid.UUID +) -> None: + ctx.require_role(ROLE_ADMIN) + await _get_team_in_org(db, ctx, team_id) + result = await db.execute( + select(Membership).where( + Membership.user_id == user_id, Membership.team_id == team_id + ) + ) + membership = result.scalar_one_or_none() + if membership is None: + raise NotFoundError("Membership", f"{user_id} in {team_id}") + await db.delete(membership) + await db.flush() + + +# --- API keys --------------------------------------------------------------- + + +async def create_api_key( + db: AsyncSession, + user: User, + name: str, + expires_at: datetime | None = None, +) -> tuple[ApiKey, str]: + """Create an API key for ``user``. Returns ``(record, plaintext)``. + + The plaintext is returned exactly once and never persisted. + """ + plaintext, key_hash, prefix = generate_api_key() + api_key = ApiKey( + user_id=user.id, + name=name, + key_hash=key_hash, + key_prefix=prefix, + expires_at=expires_at, + ) + db.add(api_key) + await db.flush() + return api_key, plaintext + + +async def list_api_keys(db: AsyncSession, user: User) -> list[ApiKey]: + result = await db.execute( + select(ApiKey).where(ApiKey.user_id == user.id).order_by(ApiKey.created_at.desc()) + ) + return list(result.scalars().all()) + + +async def revoke_api_key(db: AsyncSession, user: User, key_id: uuid.UUID) -> None: + api_key = await db.get(ApiKey, key_id) + if api_key is None or api_key.user_id != user.id: + raise NotFoundError("ApiKey", str(key_id)) + if api_key.revoked_at is None: + api_key.revoked_at = datetime.now(UTC) + await db.flush() diff --git a/backend/app/services/knowledge_service.py b/backend/app/services/knowledge_service.py index c3e0de6..023526d 100644 --- a/backend/app/services/knowledge_service.py +++ b/backend/app/services/knowledge_service.py @@ -225,6 +225,7 @@ async def import_document( connection_id: uuid.UUID, title: str, content: str, + organization_id: uuid.UUID, source_url: str | None = None, ) -> KnowledgeDocument: """Import text or HTML content as a knowledge document with embedded chunks. @@ -266,6 +267,7 @@ async def import_document( doc = KnowledgeDocument( connection_id=connection_id, + organization_id=organization_id, title=document_title, source_url=source_url, content=content, diff --git a/backend/app/services/query_service.py b/backend/app/services/query_service.py index 9425bf5..ebdeb6f 100644 --- a/backend/app/services/query_service.py +++ b/backend/app/services/query_service.py @@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.connector_registry import get_or_create_connector +from app.core.auth import AuthContext from app.core.exceptions import AppError, SQLSafetyError from app.core.telemetry import start_span from app.db.models.query_history import QueryExecution @@ -22,6 +23,7 @@ async def execute_nl_query( db: AsyncSession, connection_id: uuid.UUID, question: str, + ctx: AuthContext, ) -> dict: """Full pipeline: NL question → SQL → execute → interpret. @@ -36,7 +38,7 @@ async def execute_nl_query( Returns dict with all response fields. """ - conn = await get_connection(db, connection_id) + conn = await get_connection(db, connection_id, ctx) connection_string = get_decrypted_connection_string(conn) # Step 1: Build context @@ -154,7 +156,9 @@ async def execute_nl_query( else: # Save failed execution to history execution = QueryExecution( + organization_id=ctx.organization_id, connection_id=connection_id, + user_id=ctx.user_id, natural_language=question, generated_sql=generated_sql, final_sql=final_sql, @@ -189,7 +193,9 @@ async def execute_nl_query( # Step 7: Save to history execution = QueryExecution( + organization_id=ctx.organization_id, connection_id=connection_id, + user_id=ctx.user_id, natural_language=question, generated_sql=generated_sql, final_sql=final_sql, @@ -229,9 +235,10 @@ async def generate_sql_only( db: AsyncSession, connection_id: uuid.UUID, question: str, + ctx: AuthContext, ) -> dict: """Generate SQL without executing it.""" - conn = await get_connection(db, connection_id) + conn = await get_connection(db, connection_id, ctx) context = await build_context(db, connection_id, question, dialect=conn.connector_type) provider, llm_config = route(question) composer = QueryComposerAgent(provider, llm_config) @@ -250,6 +257,7 @@ async def execute_raw_sql( db: AsyncSession, connection_id: uuid.UUID, sql: str, + ctx: AuthContext, original_question: str | None = None, ) -> dict: """Execute user-provided SQL directly (no LLM generation). @@ -266,7 +274,7 @@ async def execute_raw_sql( if safety_issues: raise SQLSafetyError("; ".join(safety_issues)) - conn = await get_connection(db, connection_id) + conn = await get_connection(db, connection_id, ctx) connection_string = get_decrypted_connection_string(conn) # Step 2: Execute query @@ -283,7 +291,9 @@ async def execute_raw_sql( except Exception as e: # Save failed execution to history execution = QueryExecution( + organization_id=ctx.organization_id, connection_id=connection_id, + user_id=ctx.user_id, natural_language=original_question or "(manual SQL)", generated_sql=None, final_sql=sql, @@ -325,7 +335,9 @@ async def execute_raw_sql( # Step 4: Save to history execution = QueryExecution( + organization_id=ctx.organization_id, connection_id=connection_id, + user_id=ctx.user_id, natural_language=question_text, generated_sql=None, final_sql=sql, diff --git a/backend/app/services/schema_service.py b/backend/app/services/schema_service.py index 95630cf..789f6a5 100644 --- a/backend/app/services/schema_service.py +++ b/backend/app/services/schema_service.py @@ -1,11 +1,12 @@ import uuid -from datetime import datetime, timezone +from datetime import UTC, datetime -from sqlalchemy import select, delete +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.connectors.connector_registry import get_or_create_connector +from app.core.auth import AuthContext from app.core.exceptions import NotFoundError from app.db.models.schema_cache import CachedColumn, CachedRelationship, CachedTable from app.services.connection_service import get_connection, get_decrypted_connection_string @@ -14,9 +15,10 @@ async def introspect_and_cache( db: AsyncSession, connection_id: uuid.UUID, + ctx: AuthContext, ) -> dict[str, int]: """Introspect a target database and cache the schema metadata.""" - conn = await get_connection(db, connection_id) + conn = await get_connection(db, connection_id, ctx, write=True) connection_string = get_decrypted_connection_string(conn) connector = await get_or_create_connector( @@ -99,7 +101,7 @@ async def introspect_and_cache( total_relationships += 1 # Update last_introspected_at - conn.last_introspected_at = datetime.now(timezone.utc) + conn.last_introspected_at = datetime.now(UTC) await db.flush() return { @@ -110,8 +112,10 @@ async def introspect_and_cache( async def get_tables( - db: AsyncSession, connection_id: uuid.UUID + db: AsyncSession, connection_id: uuid.UUID, ctx: AuthContext ) -> list[CachedTable]: + # Access check (raises if the connection isn't in the caller's workspace). + await get_connection(db, connection_id, ctx) result = await db.execute( select(CachedTable) .where(CachedTable.connection_id == connection_id) @@ -122,7 +126,7 @@ async def get_tables( async def get_table_detail( - db: AsyncSession, table_id: uuid.UUID + db: AsyncSession, table_id: uuid.UUID, ctx: AuthContext ) -> CachedTable: result = await db.execute( select(CachedTable) @@ -140,4 +144,6 @@ async def get_table_detail( table = result.scalar_one_or_none() if not table: raise NotFoundError("Table", str(table_id)) + # Ensure the table's connection belongs to the caller's workspace. + await get_connection(db, table.connection_id, ctx) return table diff --git a/backend/app/services/setup_service.py b/backend/app/services/setup_service.py index 03f6a41..a1e81e2 100644 --- a/backend/app/services/setup_service.py +++ b/backend/app/services/setup_service.py @@ -472,12 +472,19 @@ async def auto_setup_sample_db() -> None: try: async with async_session_factory() as db: try: - connection = await _ensure_connection(db) + # Auto-setup acts on behalf of the deployment (no request + # user), so it runs under the default org/workspace admin. + from app.services import identity_service + + ctx = await identity_service.system_context(db) + org_id = ctx.organization_id + + connection = await _ensure_connection(db, ctx) connection_id = connection.id if connection.last_introspected_at is None: logger.info("Auto-setup: introspecting schema...") - await schema_service.introspect_and_cache(db, connection_id) + await schema_service.introspect_and_cache(db, connection_id, ctx) await db.commit() # Refresh to get updated last_introspected_at await db.refresh(connection) @@ -485,10 +492,10 @@ async def auto_setup_sample_db() -> None: else: logger.info("Auto-setup: schema already introspected, skipping") - await _seed_glossary(db, connection_id) - await _seed_metrics(db, connection_id) + await _seed_glossary(db, connection_id, org_id) + await _seed_metrics(db, connection_id, org_id) await _seed_dictionary(db, connection_id) - await _seed_knowledge(db, connection_id) + await _seed_knowledge(db, connection_id, org_id) await db.commit() logger.info( @@ -515,7 +522,7 @@ async def auto_setup_sample_db() -> None: ) -async def _ensure_connection(db): +async def _ensure_connection(db, ctx): """Create or find the sample DB connection.""" result = await db.execute( select(DatabaseConnection).where(DatabaseConnection.name == CONNECTION_NAME) @@ -529,6 +536,7 @@ async def _ensure_connection(db): logger.info("Auto-setup: creating connection '%s'...", CONNECTION_NAME) connection = await connection_service.create_connection( db, + ctx, name=CONNECTION_NAME, connector_type="postgresql", connection_string=settings.sample_db_connection_string, @@ -539,7 +547,7 @@ async def _ensure_connection(db): return connection -async def _seed_glossary(db, connection_id: uuid.UUID) -> None: +async def _seed_glossary(db, connection_id: uuid.UUID, org_id: uuid.UUID) -> None: """Seed glossary terms if none exist.""" count = await db.scalar( select(func.count()).select_from(GlossaryTerm).where( @@ -552,11 +560,11 @@ async def _seed_glossary(db, connection_id: uuid.UUID) -> None: logger.info("Auto-setup: seeding %d glossary terms...", len(GLOSSARY_TERMS)) for term_data in GLOSSARY_TERMS: - db.add(GlossaryTerm(connection_id=connection_id, **term_data)) + db.add(GlossaryTerm(connection_id=connection_id, organization_id=org_id, **term_data)) await db.flush() -async def _seed_metrics(db, connection_id: uuid.UUID) -> None: +async def _seed_metrics(db, connection_id: uuid.UUID, org_id: uuid.UUID) -> None: """Seed metric definitions if none exist.""" count = await db.scalar( select(func.count()).select_from(MetricDefinition).where( @@ -569,7 +577,7 @@ async def _seed_metrics(db, connection_id: uuid.UUID) -> None: logger.info("Auto-setup: seeding %d metrics...", len(METRICS)) for metric_data in METRICS: - db.add(MetricDefinition(connection_id=connection_id, **metric_data)) + db.add(MetricDefinition(connection_id=connection_id, organization_id=org_id, **metric_data)) await db.flush() @@ -617,7 +625,7 @@ async def _seed_dictionary(db, connection_id: uuid.UUID) -> None: logger.info("Auto-setup: seeded %d dictionary entries", total) -async def _seed_knowledge(db, connection_id: uuid.UUID) -> None: +async def _seed_knowledge(db, connection_id: uuid.UUID, org_id: uuid.UUID) -> None: """Seed a sample knowledge document if none exist.""" count = await db.scalar( select(func.count()).select_from(KnowledgeDocument).where( @@ -638,6 +646,7 @@ async def _seed_knowledge(db, connection_id: uuid.UUID) -> None: connection_id=connection_id, title=KNOWLEDGE_DOCUMENT["title"], content=KNOWLEDGE_DOCUMENT["content"], + organization_id=org_id, source_url=KNOWLEDGE_DOCUMENT["source_url"], ) logger.info( diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 1657016..c3c3b5c 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "sse-starlette>=2.0", "sqlparse>=0.5", "cryptography>=43.0", + "pyjwt>=2.9", "mcp>=1.2", ] diff --git a/backend/tests/test_auth_endpoints.py b/backend/tests/test_auth_endpoints.py new file mode 100644 index 0000000..4f62f81 --- /dev/null +++ b/backend/tests/test_auth_endpoints.py @@ -0,0 +1,53 @@ +"""Endpoint-level auth tests via TestClient. + +Like test_health, these do NOT enter the lifespan (no DB setup). They assert +that protected routes reject unauthenticated callers *before* any DB access, +and that the public provider-discovery endpoint works. +""" + +import pytest +from fastapi.testclient import TestClient + +from app.core import auth as auth_module +from app.main import app + +client = TestClient(app) + + +@pytest.fixture(autouse=True) +def _auth_enabled(monkeypatch): + # Ensure the dev escape hatch is off so auth is actually enforced. + monkeypatch.setattr(auth_module.settings, "disable_auth", False) + yield + + +@pytest.mark.parametrize( + "method,path", + [ + ("get", "/api/v1/connections"), + ("post", "/api/v1/connections"), + ("get", "/api/v1/teams"), + ("get", "/api/v1/api-keys"), + ("get", "/api/v1/auth/me"), + ], +) +def test_protected_routes_require_auth(method, path): + kwargs = {"json": {}} if method == "post" else {} + resp = getattr(client, method)(path, **kwargs) + assert resp.status_code == 401 + assert "error" in resp.json() + + +def test_bad_bearer_token_rejected(): + resp = client.get( + "/api/v1/connections", headers={"Authorization": "Bearer not-a-real-jwt"} + ) + assert resp.status_code == 401 + + +def test_auth_providers_is_public(): + resp = client.get("/api/v1/auth/providers") + assert resp.status_code == 200 + body = resp.json() + assert "name" in body + assert {"supports_password", "supports_magic_link", "is_sso", "disable_auth"} <= body.keys() diff --git a/backend/tests/test_auth_providers.py b/backend/tests/test_auth_providers.py new file mode 100644 index 0000000..b47855a --- /dev/null +++ b/backend/tests/test_auth_providers.py @@ -0,0 +1,99 @@ +"""Unit tests for the pluggable auth-provider registry and AuthContext roles.""" + +import uuid + +import pytest + +from app.core import auth_providers +from app.core.auth import AuthContext +from app.core.auth_providers import ( + AuthProvider, + LocalAuthProvider, + OIDCAuthProvider, + get_auth_provider, + register_auth_provider, + reset_auth_provider, +) +from app.db.models.membership import ROLE_ADMIN, ROLE_EDITOR, ROLE_VIEWER +from app.db.models.user import User + + +@pytest.fixture(autouse=True) +def _reset_provider(): + reset_auth_provider() + yield + reset_auth_provider() + + +def test_default_provider_is_local(monkeypatch): + monkeypatch.setattr(auth_providers.settings, "auth_provider", "local") + provider = get_auth_provider() + assert isinstance(provider, LocalAuthProvider) + assert provider.supports_password + assert provider.supports_magic_link + + +def test_describe_shape(monkeypatch): + monkeypatch.setattr(auth_providers.settings, "auth_provider", "magic_link") + info = get_auth_provider().describe() + assert info == { + "name": "magic_link", + "supports_password": False, + "supports_magic_link": True, + "is_sso": False, + } + + +def test_oidc_is_an_unimplemented_seam(): + provider = OIDCAuthProvider() + assert provider.is_sso + with pytest.raises(NotImplementedError): + provider.authorization_url("state") + + +def test_unknown_provider_raises(monkeypatch): + monkeypatch.setattr(auth_providers.settings, "auth_provider", "nope") + with pytest.raises(ValueError): + get_auth_provider() + + +def test_register_custom_provider(monkeypatch): + class CustomSSO(AuthProvider): + name = "custom" + is_sso = True + + register_auth_provider("custom", CustomSSO) + monkeypatch.setattr(auth_providers.settings, "auth_provider", "custom") + assert isinstance(get_auth_provider(), CustomSSO) + + +# --- AuthContext role logic ------------------------------------------------- + + +def _ctx(role: str) -> AuthContext: + return AuthContext( + user=User(id=uuid.uuid4(), email="x@y.z"), + organization_id=uuid.uuid4(), + workspace_id=uuid.uuid4(), + role=role, + ) + + +def test_role_hierarchy(): + admin = _ctx(ROLE_ADMIN) + editor = _ctx(ROLE_EDITOR) + viewer = _ctx(ROLE_VIEWER) + + assert admin.has_role(ROLE_EDITOR) and admin.has_role(ROLE_VIEWER) + assert editor.has_role(ROLE_EDITOR) and not editor.has_role(ROLE_ADMIN) + assert viewer.has_role(ROLE_VIEWER) and not viewer.has_role(ROLE_EDITOR) + + +def test_require_role_raises_for_insufficient(): + from app.core.exceptions import AuthorizationError + + viewer = _ctx(ROLE_VIEWER) + with pytest.raises(AuthorizationError): + viewer.require_role(ROLE_EDITOR) + # Sufficient role does not raise. + viewer.require_role(ROLE_VIEWER) diff --git a/backend/tests/test_security.py b/backend/tests/test_security.py new file mode 100644 index 0000000..c3258e3 --- /dev/null +++ b/backend/tests/test_security.py @@ -0,0 +1,109 @@ +"""Unit tests for the auth primitives (passwords, JWTs, API keys). + +No DB or network — pure crypto/token logic. +""" + +import time + +import jwt +import pytest + +from app.config import settings +from app.core.exceptions import AuthenticationError +from app.core.security import ( + TOKEN_PURPOSE_MAGIC_LINK, + TOKEN_PURPOSE_SESSION, + API_KEY_PREFIX, + create_session_token, + create_token, + decode_token, + generate_api_key, + hash_api_key, + hash_password, + verify_password, +) + + +# --- Passwords -------------------------------------------------------------- + + +def test_password_roundtrip(): + encoded = hash_password("correct horse battery staple") + assert encoded != "correct horse battery staple" + assert verify_password("correct horse battery staple", encoded) + + +def test_password_wrong_rejected(): + encoded = hash_password("s3cret-value") + assert not verify_password("wrong-value", encoded) + + +def test_password_salt_is_random(): + assert hash_password("same") != hash_password("same") + + +@pytest.mark.parametrize("bad", [None, "", "not-the-format", "pbkdf2_sha256$abc"]) +def test_verify_handles_malformed(bad): + assert verify_password("whatever", bad) is False + + +# --- JWTs ------------------------------------------------------------------- + + +def test_session_token_roundtrip(): + token = create_session_token("user-123") + payload = decode_token(token, TOKEN_PURPOSE_SESSION) + assert payload["sub"] == "user-123" + assert payload["purpose"] == TOKEN_PURPOSE_SESSION + + +def test_purpose_mismatch_rejected(): + token = create_session_token("user-123") + with pytest.raises(AuthenticationError): + decode_token(token, TOKEN_PURPOSE_MAGIC_LINK) + + +def test_bad_signature_rejected(): + token = create_session_token("user-123") + tampered = token[:-3] + ("aaa" if not token.endswith("aaa") else "bbb") + with pytest.raises(AuthenticationError): + decode_token(tampered, TOKEN_PURPOSE_SESSION) + + +def test_expired_token_rejected(): + token = create_token("user-123", TOKEN_PURPOSE_SESSION, ttl_minutes=0) + time.sleep(1) + with pytest.raises(AuthenticationError): + decode_token(token, TOKEN_PURPOSE_SESSION) + + +def test_wrong_secret_rejected(): + token = jwt.encode( + {"sub": "x", "purpose": TOKEN_PURPOSE_SESSION}, "some-other-secret", algorithm="HS256" + ) + with pytest.raises(AuthenticationError): + decode_token(token, TOKEN_PURPOSE_SESSION) + + +def test_signed_with_configured_secret(): + token = create_session_token("user-123") + # Decodable with the configured secret directly. + payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm]) + assert payload["sub"] == "user-123" + + +# --- API keys --------------------------------------------------------------- + + +def test_api_key_generation(): + plaintext, key_hash, prefix = generate_api_key() + assert plaintext.startswith(API_KEY_PREFIX) + assert hash_api_key(plaintext) == key_hash + assert prefix == plaintext[:10] + assert len(key_hash) == 64 # sha256 hex + + +def test_api_keys_are_unique(): + a, _, _ = generate_api_key() + b, _, _ = generate_api_key() + assert a != b diff --git a/docker-compose.yml b/docker-compose.yml index 29bf48f..6edef29 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -45,6 +45,10 @@ services: DATABASE_URL: postgresql+asyncpg://querywise:querywise_dev@app-db:5432/querywise ENVIRONMENT: development PYTHONPATH: /app + # Phase 1 auth is enforced once migration 004 runs. The pre-auth frontend + # has no login flow yet, so keep the dev escape hatch on (every request + # acts as the default admin). Remove this once the Phase 1 frontend lands. + DISABLE_AUTH: "true" # OLLAMA_BASE_URL comes from .env: # host.docker.internal:11434 = native Ollama on host (default, GPU-accelerated on macOS) # ollama:11434 = Docker Ollama (CPU-only, use with --profile ollama-docker) diff --git a/planfull.md b/planfull.md index 571cca0..9e0e286 100644 --- a/planfull.md +++ b/planfull.md @@ -377,7 +377,7 @@ managed-SaaS fleet but is **not** used to share one DB across customers today. | Phase | Status | Reference | |---|---|---| | **0** — Production hardening & async foundation | ✅ Implemented | PR #7 (→ v2.0.0) | -| **1** — Identity, teams & ownership | ⬜ Not started | — | +| **1** — Identity, teams & ownership | ✅ Backend implemented (frontend pending) | migration `004`; OIDC is a registered seam (magic-link + local live) | | **2** — Durable analytics artifacts | ⬜ Not started | — | | **3** — Discovery, catalog & trust | ⬜ Not started | — | | **4** — Scheduling, distribution & governance | ⬜ Not started | — |