diff --git a/docs/integrations/app-integrations/rest-api.mdx b/docs/integrations/app-integrations/rest-api.mdx new file mode 100644 index 0000000000..1ead2f519a --- /dev/null +++ b/docs/integrations/app-integrations/rest-api.mdx @@ -0,0 +1,148 @@ +--- +title: REST API +sidebarTitle: REST API +--- + +In this section, we present how to connect any REST API to MindsDB using bearer-token authentication. + +The REST API handler is a generic integration that lets you forward HTTP requests to any API through MindsDB using stored credentials. Unlike named integrations (HubSpot, Shopify, etc.), it requires no handler-specific knowledge — just a base URL and a bearer token. + +This is useful for APIs that MindsDB doesn't have a dedicated handler for, or when you only need direct HTTP access without SQL table mapping. + +## Connection + +The required arguments to establish a connection are as follows: + +- `base_url`: the base URL of the REST API (e.g. `https://api.example.com`). All request paths are appended to this URL. +- `bearer_token`: the token used for authentication. Injected as `Authorization: Bearer ` on every request. + +Optional arguments: + +- `default_headers`: a JSON object of static headers added to every request (e.g. `{"Accept": "application/json"}`). +- `allowed_hosts`: a list of allowed hostnames for requests. Defaults to the hostname of `base_url`. Use `["*"]` to disable host containment. +- `test_path`: the path used by the test endpoint to verify connectivity. Defaults to `/`. + +To connect a REST API to MindsDB, create a new database: + +```sql +CREATE DATABASE my_api +WITH ENGINE = 'rest_api', +PARAMETERS = { + "base_url": "https://api.example.com", + "bearer_token": "your_token_here" +}; +``` + +### Example: Connect to HubSpot + +```sql +CREATE DATABASE my_hubspot +WITH ENGINE = 'rest_api', +PARAMETERS = { + "base_url": "https://api.hubapi.com", + "bearer_token": "pat-eu1-..." 
+}; +``` + +### Example: Connect with default headers and a custom test path + +```sql +CREATE DATABASE my_internal_api +WITH ENGINE = 'rest_api', +PARAMETERS = { + "base_url": "https://internal.example.com/api/v2", + "bearer_token": "sk-...", + "default_headers": {"Accept": "application/json"}, + "test_path": "/health" +}; +``` + +### Example: Multiple allowed hosts + +```sql +CREATE DATABASE my_multi_region_api +WITH ENGINE = 'rest_api', +PARAMETERS = { + "base_url": "https://api.example.com", + "bearer_token": "your_token", + "allowed_hosts": ["api.example.com", "api.eu.example.com"] +}; +``` + +## Usage + +This handler is **passthrough-only** — it does not expose SQL tables. All interaction is through the REST passthrough endpoint. + +### Sending requests + +Forward HTTP requests to the upstream API: + +``` +POST /api/integrations/my_api/passthrough +``` + +```json +{ + "method": "GET", + "path": "/v1/users", + "query": {"limit": "10"}, + "headers": {"Accept": "application/json"} +} +``` + +The response wraps the upstream HTTP response: + +```json +{ + "status_code": 200, + "headers": {"content-type": "application/json"}, + "body": {"results": [...]}, + "content_type": "application/json" +} +``` + +Supported HTTP methods: `GET`, `POST`, `PUT`, `PATCH`, `DELETE`. + +### Testing the connection + +Verify that the base URL, token, and host allowlist are configured correctly: + +``` +POST /api/integrations/my_api/passthrough/test +``` + +A successful response: + +```json +{"ok": true, "status_code": 200, "host": "api.example.com", "latency_ms": 140} +``` + +A failed response: + +```json +{"ok": false, "error_code": "auth_failed", "message": "upstream rejected credentials; base URL and allowlist look correct"} +``` + +## Security + +- Credentials are stored in MindsDB and never exposed to the caller. +- Requests are restricted to hostnames in the allowlist. Private and loopback IP addresses are rejected by default. 
+- Callers cannot override `Authorization`, `Host`, `Cookie`, or `Proxy-*` headers. +- If the upstream API echoes the token in responses, it is replaced with `[REDACTED_API_KEY]`. +- Request bodies are capped at 1 MB, response bodies at 10 MB. + + +**`host 'X' is not in the datasource allowlist`** + +The request path resolved to a different hostname than `base_url`. Add the hostname to `allowed_hosts`, or use `["*"]` to disable host containment (not recommended for production). + + + +**`upstream rejected credentials (401/403)`** + +The token is invalid, expired, or missing required scopes. Verify the token with the upstream API provider. + + + +For more information about available actions and development plans, visit [this page](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/handlers/rest_api_handler/README.md). + diff --git a/mindsdb/api/http/initialize.py b/mindsdb/api/http/initialize.py index 66978cb69c..f0d72aef0b 100644 --- a/mindsdb/api/http/initialize.py +++ b/mindsdb/api/http/initialize.py @@ -28,6 +28,7 @@ from mindsdb.api.http.namespaces.default import ns_conf as default_ns, check_session_auth from mindsdb.api.http.namespaces.file import ns_conf as file_ns from mindsdb.api.http.namespaces.handlers import ns_conf as handlers_ns +from mindsdb.api.http.namespaces.integrations import ns_conf as integrations_ns from mindsdb.api.http.namespaces.knowledge_bases import ns_conf as knowledge_bases_ns from mindsdb.api.http.namespaces.models import ns_conf as models_ns from mindsdb.api.http.namespaces.projects import ns_conf as projects_ns @@ -280,6 +281,7 @@ def root_index(path): agents_ns, jobs_ns, knowledge_bases_ns, + integrations_ns, ] for ns in protected_namespaces: diff --git a/mindsdb/api/http/namespaces/configs/integrations.py b/mindsdb/api/http/namespaces/configs/integrations.py new file mode 100644 index 0000000000..14d3fab27f --- /dev/null +++ b/mindsdb/api/http/namespaces/configs/integrations.py @@ -0,0 +1,3 @@ +from flask_restx 
import Namespace + +ns_conf = Namespace("integrations", description="API for integration-level operations (passthrough, capabilities)") diff --git a/mindsdb/api/http/namespaces/integrations.py b/mindsdb/api/http/namespaces/integrations.py new file mode 100644 index 0000000000..7df3e8b490 --- /dev/null +++ b/mindsdb/api/http/namespaces/integrations.py @@ -0,0 +1,197 @@ +from http import HTTPStatus + +from flask import request +from flask_restx import Resource + +from mindsdb.api.http.utils import http_error +from mindsdb.api.http.namespaces.configs.integrations import ns_conf +from mindsdb.api.mysql.mysql_proxy.classes.fake_mysql_proxy import FakeMysqlProxy +from mindsdb.integrations.libs.passthrough import PassthroughProtocol +from mindsdb.integrations.libs.passthrough_types import ( + ALLOWED_METHODS, + FORBIDDEN_REQUEST_HEADERS, + PassthroughError, + PassthroughNotSupportedError, + PassthroughRequest, + PassthroughResponse, + PassthroughValidationError, +) +from mindsdb.interfaces.database.integrations import integration_controller +from mindsdb.metrics.metrics import api_endpoint_metrics +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +def _handler_supports_passthrough(handler_module) -> bool: + handler_cls = getattr(handler_module, "Handler", None) + if handler_cls is None: + return False + # issubclass is the right check for Protocol when classes define the + # methods as real methods (not just dynamic attrs); runtime_checkable + # Protocols support issubclass in that mode. + try: + return issubclass(handler_cls, PassthroughProtocol) + except TypeError: + return False + + +def _resolve_auth_modes(handler_cls) -> list[str]: + """Resolve a handler's advertised auth modes for /capabilities. + + Order of preference: + 1. `_auth_modes` (list[str]) if defined and non-empty — handlers + that support more than one mode (e.g. rest_api: bearer + OAuth + client credentials). + 2. 
`_auth_mode` (str) if defined — single-mode handlers; preserves + the legacy declaration. + 3. ["bearer"] default — protocol-only handlers that don't declare + anything still land in the right bucket. + """ + if handler_cls is None: + return ["bearer"] + modes = getattr(handler_cls, "_auth_modes", None) + if isinstance(modes, (list, tuple)) and modes: + return [str(m) for m in modes] + mode = getattr(handler_cls, "_auth_mode", None) + if isinstance(mode, str) and mode: + return [mode] + return ["bearer"] + + +def _get_passthrough_handler(name: str): + """Look up the datasource's handler and verify it satisfies the contract.""" + proxy = FakeMysqlProxy() + handler = proxy.session.integration_controller.get_data_handler(name) + if not isinstance(handler, PassthroughProtocol): + raise PassthroughNotSupportedError(f"datasource '{name}' does not support REST passthrough") + return handler + + +def _parse_passthrough_request(payload: dict) -> PassthroughRequest: + if not isinstance(payload, dict): + raise PassthroughValidationError("request body must be a JSON object") + + method = payload.get("method") + path = payload.get("path") + if not isinstance(method, str) or method.upper() not in ALLOWED_METHODS: + raise PassthroughValidationError(f"'method' must be one of {sorted(ALLOWED_METHODS)}") + if not isinstance(path, str) or not path.startswith("/"): + raise PassthroughValidationError("'path' must be a string starting with '/'") + + headers = payload.get("headers") or {} + if not isinstance(headers, dict): + raise PassthroughValidationError("'headers' must be an object") + for name in headers: + if not isinstance(name, str): + raise PassthroughValidationError("header names must be strings") + if name.lower() in FORBIDDEN_REQUEST_HEADERS or name.lower().startswith("proxy-"): + raise PassthroughValidationError(f"header '{name}' is not allowed in passthrough requests") + + query = payload.get("query") or {} + if not isinstance(query, dict): + raise 
PassthroughValidationError("'query' must be an object") + + return PassthroughRequest( + method=method.upper(), + path=path, + query={str(k): str(v) for k, v in query.items()}, + headers={str(k): str(v) for k, v in headers.items()}, + body=payload.get("body"), + ) + + +def _serialize_response(resp: PassthroughResponse) -> dict: + return { + "status_code": resp.status_code, + "headers": resp.headers, + "body": resp.body, + "content_type": resp.content_type, + } + + +def _passthrough_error_response(err: PassthroughError): + return { + "error_code": err.error_code, + "message": str(err), + }, err.http_status + + +@ns_conf.route("//passthrough") +@ns_conf.param("name", "Datasource name") +class Passthrough(Resource): + @ns_conf.doc("passthrough") + @api_endpoint_metrics("POST", "/integrations/passthrough") + def post(self, name: str): + payload = request.json or {} + try: + req = _parse_passthrough_request(payload) + handler = _get_passthrough_handler(name) + response = handler.api_passthrough(req) + except PassthroughError as e: + return _passthrough_error_response(e) + except Exception as e: # noqa: BLE001 + logger.exception("passthrough failed for datasource %s", name) + return http_error(HTTPStatus.INTERNAL_SERVER_ERROR, "PassthroughError", str(e)) + + return _serialize_response(response), 200 + + +@ns_conf.route("//passthrough/test") +@ns_conf.param("name", "Datasource name") +class PassthroughTest(Resource): + @ns_conf.doc("passthrough_test") + @api_endpoint_metrics("POST", "/integrations/passthrough/test") + def post(self, name: str): + try: + handler = _get_passthrough_handler(name) + except PassthroughError as e: + return _passthrough_error_response(e) + except Exception as e: # noqa: BLE001 + logger.exception("passthrough test lookup failed for datasource %s", name) + return http_error(HTTPStatus.INTERNAL_SERVER_ERROR, "PassthroughError", str(e)) + + result = handler.test_passthrough() + return result, 200 + + +@ns_conf.route("/capabilities") +class 
Capabilities(Resource): + """Return structured passthrough capabilities per handler. + + The new ``handlers`` dict is the canonical shape callers should migrate + to. The legacy flat ``bearer_passthrough`` list is still populated for + backward compat — Minds can migrate on its own timeline. + """ + + @ns_conf.doc("integration_capabilities") + @api_endpoint_metrics("GET", "/integrations/capabilities") + def get(self): + handlers: dict[str, dict] = {} + bearer_engines: list[str] = [] + handler_modules = getattr(integration_controller, "handler_modules", {}) or {} + for engine, module in handler_modules.items(): + try: + if not _handler_supports_passthrough(module): + continue + handler_cls = getattr(module, "Handler", None) + # Resolve the handler's advertised auth modes — supports + # both the new list-shaped `_auth_modes` and the legacy + # single-mode `_auth_mode` declaration. + auth_modes = _resolve_auth_modes(handler_cls) + handlers[engine] = { + "auth_modes": auth_modes, + "operations": ["passthrough"], + } + if "bearer" in auth_modes: + bearer_engines.append(engine) + except Exception: + # A broken handler module should not break the capabilities endpoint. + logger.debug("skipping handler %s during capability probe", engine, exc_info=True) + bearer_engines.sort() + return { + "handlers": handlers, + # TODO: remove in v2 once Minds has migrated to the `handlers` + # structured shape. Keep backward-compat for now. 
+ "bearer_passthrough": bearer_engines, + }, 200 diff --git a/mindsdb/api/http/namespaces/sql.py b/mindsdb/api/http/namespaces/sql.py index 39e53cc431..fec3f2a848 100644 --- a/mindsdb/api/http/namespaces/sql.py +++ b/mindsdb/api/http/namespaces/sql.py @@ -86,25 +86,37 @@ def post(self): try: handler = mysql_proxy.session.integration_controller.get_data_handler(db) - result = handler.native_query(query) + raw_result = handler.native_query(query) except Exception as e: - query_response = {"type": "error", "error_code": 0, "error_message": str(e)} + error_type = "unexpected" + result = SQLAnswer( + resp_type=SQL_RESPONSE_TYPE.ERROR, + error_code=0, + error_message=str(e), + ) else: - if result.type == SQL_RESPONSE_TYPE.ERROR: - query_response = {"type": "error", "error_code": 0, "error_message": result.error_message} - elif result.type == SQL_RESPONSE_TYPE.OK: - query_response = {"type": "ok"} + if raw_result.type == SQL_RESPONSE_TYPE.ERROR: + # raw_result will be ErrorResponse. + error_type = "expected" + result = SQLAnswer( + resp_type=SQL_RESPONSE_TYPE.ERROR, + error_code=0, + error_message=raw_result.error_message, + ) + elif raw_result.type == SQL_RESPONSE_TYPE.OK: + result = SQLAnswer( + resp_type=SQL_RESPONSE_TYPE.OK, + error_code=0, + error_message=None, + ) else: - df = result.data_frame - result_set = ResultSet.from_df(df) - query_response = { - "type": "table", - "column_names": result_set.get_column_names(), - "data": result_set.to_lists(json_types=True), - } + # raw_result will be TableResponse. 
+ result_set = ResultSet.from_table_response(raw_result) + result = SQLAnswer( + resp_type=SQL_RESPONSE_TYPE.TABLE, + result_set=result_set, + ) - query_response["context"] = mysql_proxy.get_context() - query_response = query_response, 200 else: try: result: SQLAnswer = mysql_proxy.process_query(query) @@ -137,16 +149,16 @@ def post(self): ) logger.exception("Error query processing:") - context = mysql_proxy.get_context() + context = mysql_proxy.get_context() - if response_format == ReponseFormat.JSONLINES: - query_response = result.stream_http_response_jsonlines(context=context) - query_response = Response(query_response, mimetype="application/jsonlines") - elif response_format == ReponseFormat.SSE: - query_response = result.stream_http_response_sse(context=context) - query_response = Response(query_response, mimetype="text/event-stream") - else: - query_response = result.dump_http_response(context=context), 200 + if response_format == ReponseFormat.JSONLINES: + query_response = result.stream_http_response_jsonlines(context=context) + query_response = Response(query_response, mimetype="application/jsonlines") + elif response_format == ReponseFormat.SSE: + query_response = result.stream_http_response_sse(context=context) + query_response = Response(query_response, mimetype="text/event-stream") + else: + query_response = result.dump_http_response(context=context), 200 hooks.after_api_query( company_id=ctx.company_id, diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py index b4adb0d978..19a6c0244d 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py @@ -28,6 +28,8 @@ PRIMARY_ASSOCIATIONS_CONFIG, ) from mindsdb.integrations.libs.api_handler import MetaAPIHandler +from mindsdb.integrations.libs.passthrough import PassthroughMixin +from mindsdb.integrations.libs.passthrough_types 
import PassthroughRequest from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, extract_comparison_conditions from mindsdb.integrations.libs.response import ( @@ -133,16 +135,25 @@ def _map_type(data_type: str) -> MYSQL_DATA_TYPE: return type_map.get(data_type_upper, MYSQL_DATA_TYPE.VARCHAR) -class HubspotHandler(MetaAPIHandler): +class HubspotHandler(MetaAPIHandler, PassthroughMixin): """Hubspot API handler implementation""" name = "hubspot" + # REST passthrough — PAT (Private App Token) only. OAuth2 credentials + # (client_id/client_secret/refresh_token) are NOT supported here yet; + # that path needs an OAuthPassthroughMixin that refreshes tokens on + # demand. Passthrough with OAuth2 fails fast with a config error when + # `access_token` is missing from connection_data. + _bearer_token_arg = "access_token" + _base_url_default = "https://api.hubapi.com" + _test_request = PassthroughRequest(method="GET", path="/crm/v3/owners?limit=1") + def __init__(self, name: str, **kwargs: Any) -> None: """Initialize the handler.""" super().__init__(name) - connection_data = kwargs.get("connection_data", {}) + connection_data = kwargs.get("connection_data") or {} self.connection_data = connection_data self.kwargs = kwargs self.handler_storage = kwargs.get("handler_storage") diff --git a/mindsdb/integrations/handlers/rest_api_handler/README.md b/mindsdb/integrations/handlers/rest_api_handler/README.md new file mode 100644 index 0000000000..037d6ff429 --- /dev/null +++ b/mindsdb/integrations/handlers/rest_api_handler/README.md @@ -0,0 +1,168 @@ +--- +title: REST API +sidebarTitle: REST API +--- + +This documentation describes the integration of MindsDB with generic REST APIs using bearer-token authentication. +The integration allows MindsDB to forward HTTP requests to any REST API using stored credentials via the passthrough endpoint — no SQL table mapping required. 
+ +### Prerequisites + +Before proceeding, ensure the following prerequisites are met: + +1. Install MindsDB locally via [Docker](https://docs.mindsdb.com/setup/self-hosted/docker) or [Docker Desktop](https://docs.mindsdb.com/setup/self-hosted/docker-desktop). +2. Obtain a bearer token (API key, personal access token, etc.) for the target REST API. + +## Connection + +Establish a connection to a REST API from MindsDB by executing the following SQL command: + +```sql +CREATE DATABASE my_api +WITH ENGINE = 'rest_api', +PARAMETERS = { + "base_url": "https://api.example.com", + "bearer_token": "your_token_here" +}; +``` + +Required connection parameters include the following: + +* `base_url`: The base URL of the REST API (e.g. `https://api.example.com`). All passthrough request paths are appended to this URL. +* `bearer_token`: The bearer token used for authentication. Injected as `Authorization: Bearer ` on every request. + +Optional connection parameters include the following: + +* `default_headers`: A JSON object of static headers added to every request (e.g. `{"Accept": "application/json"}`). +* `allowed_hosts`: A list of allowed hostnames for passthrough requests. Defaults to the hostname of `base_url`. Use `["*"]` to disable host containment. +* `test_path`: The path used by the `/passthrough/test` endpoint to verify connectivity. Defaults to `/`. + +### Examples + +Connect to the HubSpot API: + +```sql +CREATE DATABASE my_hubspot +WITH ENGINE = 'rest_api', +PARAMETERS = { + "base_url": "https://api.hubapi.com", + "bearer_token": "pat-eu1-..." 
+}; +``` + +Connect to a custom internal API with default headers: + +```sql +CREATE DATABASE my_internal_api +WITH ENGINE = 'rest_api', +PARAMETERS = { + "base_url": "https://internal.example.com/api/v2", + "bearer_token": "sk-...", + "default_headers": {"Accept": "application/json", "X-Team": "data"}, + "test_path": "/health" +}; +``` + +Connect to an API with multiple allowed hosts: + +```sql +CREATE DATABASE my_multi_region_api +WITH ENGINE = 'rest_api', +PARAMETERS = { + "base_url": "https://api.example.com", + "bearer_token": "your_token", + "allowed_hosts": ["api.example.com", "api.eu.example.com"] +}; +``` + +## Usage + +This handler is **passthrough-only** — it does not expose SQL tables. All interaction is through the REST passthrough endpoint. + +### Passthrough Requests + +Send HTTP requests to the upstream API through MindsDB: + +``` +POST /api/integrations/my_api/passthrough +``` + +```json +{ + "method": "GET", + "path": "/v1/users", + "query": {"limit": "10"}, + "headers": {"Accept": "application/json"} +} +``` + +The response wraps the upstream HTTP response: + +```json +{ + "status_code": 200, + "headers": {"content-type": "application/json"}, + "body": {"results": [...]}, + "content_type": "application/json" +} +``` + +Supported HTTP methods: `GET`, `POST`, `PUT`, `PATCH`, `DELETE`. + +### Testing the Connection + +Verify that the base URL, token, and host allowlist are configured correctly: + +``` +POST /api/integrations/my_api/passthrough/test +``` + +Returns: + +```json +{"ok": true, "status_code": 200, "host": "api.example.com", "latency_ms": 140} +``` + +Or on failure: + +```json +{"ok": false, "error_code": "auth_failed", "message": "upstream rejected credentials; base URL and allowlist look correct"} +``` + +## Security + +- **Credentials are never exposed.** The bearer token is stored in MindsDB and injected at request time. It is never returned to the caller. 
+- **Host containment.** Requests are restricted to hostnames in the allowlist (defaults to the `base_url` host). Private/loopback IP addresses are rejected by default. +- **Header filtering.** Callers cannot override `Authorization`, `Host`, `Cookie`, or `Proxy-*` headers. +- **Response scrubbing.** If the upstream API echoes the token in responses, it is replaced with `[REDACTED_API_KEY]` before returning to the caller. +- **Size limits.** Request bodies are capped at 1 MB, response bodies at 10 MB (configurable via environment variables). + +## Troubleshooting + + +`base_url is not configured` + +* **Symptoms**: Passthrough requests fail with a configuration error. +* **Checklist**: + 1. Ensure `base_url` is provided in the connection parameters. + 2. The URL must include the scheme (`https://`). + + + +`host 'X' is not in the datasource allowlist` + +* **Symptoms**: Passthrough requests to a valid URL are rejected. +* **Checklist**: + 1. The request path may resolve to a different hostname than `base_url`. + 2. Add the hostname to `allowed_hosts` in the connection parameters. + 3. Use `["*"]` to disable host containment (not recommended for production). + + + +`upstream rejected credentials (401/403)` + +* **Symptoms**: The `/passthrough/test` endpoint returns `error_code: "auth_failed"`. +* **Checklist**: + 1. Verify the bearer token is valid and not expired. + 2. Check that the token has the required scopes/permissions for the API endpoints you are calling. 
+ diff --git a/mindsdb/integrations/handlers/rest_api_handler/__about__.py b/mindsdb/integrations/handlers/rest_api_handler/__about__.py new file mode 100644 index 0000000000..b7f131f401 --- /dev/null +++ b/mindsdb/integrations/handlers/rest_api_handler/__about__.py @@ -0,0 +1,9 @@ +__title__ = "MindsDB REST API handler" +__package_name__ = "mindsdb_rest_api_handler" +__version__ = "0.0.1" +__description__ = "MindsDB handler for generic REST APIs with bearer-token passthrough" +__author__ = "MindsDB Inc" +__github__ = "https://github.com/mindsdb/mindsdb" +__pypi__ = "https://pypi.org/project/mindsdb/" +__license__ = "MIT" +__copyright__ = "Copyright 2026 - mindsdb" diff --git a/mindsdb/integrations/handlers/rest_api_handler/__init__.py b/mindsdb/integrations/handlers/rest_api_handler/__init__.py new file mode 100644 index 0000000000..d9f8fcf24e --- /dev/null +++ b/mindsdb/integrations/handlers/rest_api_handler/__init__.py @@ -0,0 +1,32 @@ +from mindsdb.integrations.libs.const import HANDLER_TYPE, HANDLER_SUPPORT_LEVEL + +from .__about__ import __version__ as version, __description__ as description +from .connection_args import connection_args, connection_args_example + +try: + from .rest_api_handler import RestApiHandler as Handler + + import_error = None +except Exception as e: + Handler = None + import_error = e + +title = "REST API" +name = "rest_api" +type = HANDLER_TYPE.DATA +icon_path = "icon.svg" +support_level = HANDLER_SUPPORT_LEVEL.MINDSDB + +__all__ = [ + "Handler", + "version", + "name", + "type", + "support_level", + "title", + "description", + "import_error", + "icon_path", + "connection_args", + "connection_args_example", +] diff --git a/mindsdb/integrations/handlers/rest_api_handler/connection_args.py b/mindsdb/integrations/handlers/rest_api_handler/connection_args.py new file mode 100644 index 0000000000..4e1c92518f --- /dev/null +++ b/mindsdb/integrations/handlers/rest_api_handler/connection_args.py @@ -0,0 +1,25 @@ +"""Aggregator for the rest_api 
handler's connection arguments. + +REST/passthrough fields and authentication fields are defined in separate +modules (rest_connection_args, oauth_connection_args). This module merges +them into the single `connection_args` mapping that MindsDB expects each +handler package to export. +""" + +from collections import OrderedDict + +from .rest_connection_args import rest_connection_args +from .oauth_connection_args import oauth_connection_args + + +connection_args = OrderedDict() +connection_args.update(rest_connection_args) +connection_args.update(oauth_connection_args) + + +connection_args_example = OrderedDict( + base_url="https://api.example.com", + bearer_token="your_token_here", + default_headers={"Accept": "application/json"}, + allowed_hosts=["api.example.com"], +) diff --git a/mindsdb/integrations/handlers/rest_api_handler/icon.svg b/mindsdb/integrations/handlers/rest_api_handler/icon.svg new file mode 100644 index 0000000000..2346f8d4d3 --- /dev/null +++ b/mindsdb/integrations/handlers/rest_api_handler/icon.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/mindsdb/integrations/handlers/rest_api_handler/oauth_connection_args.py b/mindsdb/integrations/handlers/rest_api_handler/oauth_connection_args.py new file mode 100644 index 0000000000..cccca7f2d5 --- /dev/null +++ b/mindsdb/integrations/handlers/rest_api_handler/oauth_connection_args.py @@ -0,0 +1,76 @@ +"""Authentication connection arguments for the rest_api handler. + +These fields describe *how* the handler should authenticate to the upstream +API. The current strategies are static bearer tokens and OAuth2 client +credentials; both share this argument schema. The handler — not the runtime +caller — is responsible for resolving these into an Authorization header. + +Keep this module focused on schema only: do not import the OAuth2 token +provider here, do not perform any HTTP, and do not change passthrough +forwarding behavior. This step only defines the args. 
+""" + +from collections import OrderedDict + +from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE + + +oauth_connection_args = OrderedDict( + auth_type={ + "type": ARG_TYPE.STR, + "description": ( + "Authentication strategy. 'bearer' uses a static bearer_token; " + "'oauth_client_credentials' fetches a token via the OAuth2 client " + "credentials grant. Defaults to 'bearer' for backward compatibility." + ), + "required": False, + "label": "Auth Type", + }, + bearer_token={ + "type": ARG_TYPE.PWD, + "description": "Bearer token injected as Authorization: Bearer . Used when auth_type is 'bearer'.", + "required": False, + "label": "Bearer Token", + "secret": True, + }, + token_url={ + "type": ARG_TYPE.STR, + "description": "OAuth2 token endpoint URL. Used when auth_type is 'oauth_client_credentials'.", + "required": False, + "label": "OAuth Token URL", + }, + client_id={ + "type": ARG_TYPE.STR, + "description": "OAuth2 client identifier. Used when auth_type is 'oauth_client_credentials'.", + "required": False, + "label": "OAuth Client ID", + }, + client_secret={ + "type": ARG_TYPE.PWD, + "description": "OAuth2 client secret. Used when auth_type is 'oauth_client_credentials'.", + "required": False, + "label": "OAuth Client Secret", + "secret": True, + }, + scope={ + "type": ARG_TYPE.STR, + "description": "Optional OAuth2 scope string (space-separated) or list of scopes.", + "required": False, + "label": "OAuth Scope", + }, + audience={ + "type": ARG_TYPE.STR, + "description": "Optional OAuth2 audience parameter (Auth0/Cognito-style extension; not part of RFC 6749).", + "required": False, + "label": "OAuth Audience", + }, + token_auth_method={ + "type": ARG_TYPE.STR, + "description": ( + "How client credentials are sent to the token endpoint: " + "'client_secret_post' (default) or 'client_secret_basic'." 
+ ), + "required": False, + "label": "Token Auth Method", + }, +) diff --git a/mindsdb/integrations/handlers/rest_api_handler/rest_api_handler.py b/mindsdb/integrations/handlers/rest_api_handler/rest_api_handler.py new file mode 100644 index 0000000000..ec8ab33595 --- /dev/null +++ b/mindsdb/integrations/handlers/rest_api_handler/rest_api_handler.py @@ -0,0 +1,317 @@ +from typing import Any, Optional + +from mindsdb.integrations.libs.api_handler import APIHandler +from mindsdb.integrations.libs.passthrough import PassthroughMixin +from mindsdb.integrations.libs.passthrough_types import ( + PassthroughConfigError, + PassthroughRequest, + PassthroughResponse, +) +from mindsdb.integrations.libs.response import ( + HandlerStatusResponse as StatusResponse, + HandlerResponse as Response, + RESPONSE_TYPE, +) +from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + OAuth2ClientCredentialsProvider, +) +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +AUTH_TYPE_BEARER = "bearer" +AUTH_TYPE_OAUTH_CLIENT_CREDENTIALS = "oauth_client_credentials" +SUPPORTED_AUTH_TYPES = (AUTH_TYPE_BEARER, AUTH_TYPE_OAUTH_CLIENT_CREDENTIALS) + +DEFAULT_AUTH_TYPE = AUTH_TYPE_BEARER +DEFAULT_TOKEN_AUTH_METHOD = "client_secret_post" +SUPPORTED_TOKEN_AUTH_METHODS = ("client_secret_post", "client_secret_basic") + +# Fields that only make sense in OAuth mode. Bearer-mode configs that +# populate any of these are rejected so misconfigured datasources fail +# loudly at sync time rather than silently ignoring the OAuth values. +OAUTH_ONLY_FIELDS = ( + "token_url", + "client_id", + "client_secret", + "scope", + "audience", + "token_auth_method", +) + + +class RestApiHandler(APIHandler, PassthroughMixin): + """Generic REST API handler — passthrough only, no SQL tables. + + This is the "bring your own URL" escape hatch for any bearer-token API + that mindsdb doesn't have a named handler for. 
Users supply a base_url + and a bearer_token and get full passthrough access. + """ + + name = "rest_api" + + # Advertised to /capabilities. Both modes are supported by the same + # handler instance — the runtime mode is selected per-datasource via + # connection_data["auth_type"]. _auth_mode is kept as a fallback for + # any caller that still reads the single-mode field. + _auth_modes = ["bearer", "oauth_client_credentials"] + _auth_mode = "bearer" + + def __init__(self, name: str, **kwargs: Any) -> None: + super().__init__(name) + self.connection_data = kwargs.get("connection_data") or {} + self.kwargs = kwargs + self.handler_storage = kwargs.get("handler_storage") + self.is_connected = False + + # PassthroughMixin reads these instance attributes at runtime. + self._bearer_token_arg = "bearer_token" + self._base_url_default = None # user must supply base_url + + # Build the test request from connection_data. Default to GET / + # unless the user provided a custom test_path. + test_path = self.connection_data.get("test_path", "/") + if not test_path.startswith("/"): + test_path = f"/{test_path}" + self._test_request = PassthroughRequest(method="GET", path=test_path) + + # Lazy: instantiated on first token fetch when auth_type is OAuth. + # Bearer-mode handlers never construct a provider. + self._oauth_provider: Optional[OAuth2ClientCredentialsProvider] = None + + def _oauth_storage_key(self) -> str: + """Storage key for the OAuth token cache. + + Namespaced by handler name so two datasources with different client_ids + on the same MindsDB instance never share a cached token. + """ + return f"oauth_client_credentials_tokens:{self.name}" + + def _maybe_init_oauth_provider(self) -> None: + """Idempotent. Builds the OAuth provider only in oauth_client_credentials mode. + + Bearer-mode handlers never reach this path, so the provider is never + constructed for them — that's the contract enforced by the + bearer_mode_does_not_instantiate_oauth_provider test. 
+ """ + if self._oauth_provider is not None: + return + if self._get_auth_type() != AUTH_TYPE_OAUTH_CLIENT_CREDENTIALS: + return + self._oauth_provider = OAuth2ClientCredentialsProvider( + self.connection_data, + handler_storage=self.handler_storage, + storage_key=self._oauth_storage_key(), + ) + + def _get_bearer_token(self) -> str: + """Token resolution dispatcher used by PassthroughMixin. + + - bearer mode: keep the mixin's default (read from connection_data). + - oauth_client_credentials: resolve via the OAuth provider, which + handles token fetch, caching, and refresh. + + Caller-supplied Authorization headers are filtered out by + PassthroughMixin._build_outgoing_headers (Authorization is in + FORBIDDEN_REQUEST_HEADERS) before this method's return value is + written into the outgoing headers, so the generated auth always wins. + """ + auth_type = self._get_auth_type() + if auth_type == AUTH_TYPE_BEARER: + return super()._get_bearer_token() + if auth_type == AUTH_TYPE_OAUTH_CLIENT_CREDENTIALS: + self._maybe_init_oauth_provider() + assert self._oauth_provider is not None # _maybe_init guarantees this + return self._oauth_provider.get_access_token() + raise PassthroughConfigError(f"Unsupported auth_type '{auth_type}'") + + def _secrets_for_scrub(self) -> list[str]: + """Per-request list of values to redact from the upstream response. + + Differs from PassthroughMixin's default in two ways: + 1. In OAuth mode it uses the provider's `current_secrets()` (which + returns the *currently cached* access token, or `[]` when uncached) + instead of `_get_bearer_token()` — the latter would force a token + fetch purely to populate the scrub list, which is wasteful and + could fail (e.g. IdP unreachable) in a code path that's only + supposed to redact strings from the response body. + 2. The list is deduplicated and stripped of empties before return so + the underlying string-replace loop never does redundant work. 
+ + The default-headers heuristic (treat values >= 16 chars as potential + secrets) is preserved verbatim so handlers that move from bearer to + OAuth without changing default_headers see the same scrubbing. + """ + secrets: list[str] = [] + auth_type = self._get_auth_type() + + if auth_type == AUTH_TYPE_BEARER: + token = self.connection_data.get("bearer_token") + if token: + secrets.append(str(token)) + elif auth_type == AUTH_TYPE_OAUTH_CLIENT_CREDENTIALS: + # Only the currently-cached access token, never client_secret. + # current_secrets() returns [] before the first successful token + # fetch, so calling _secrets_for_scrub early is safe. + if self._oauth_provider is not None: + secrets.extend(self._oauth_provider.current_secrets()) + + defaults = self.connection_data.get("default_headers") or {} + if isinstance(defaults, dict): + for value in defaults.values(): + s = str(value) + if len(s) >= 16: + secrets.append(s) + + seen: set[str] = set() + deduped: list[str] = [] + for s in secrets: + if s and s not in seen: + seen.add(s) + deduped.append(s) + return deduped + + def _get_auth_type(self) -> str: + """Return the configured auth_type, defaulting to 'bearer'. + + Empty strings are treated as "not set" so a UI that submits empty + form fields still defaults correctly. + """ + value = self.connection_data.get("auth_type") + if value is None or value == "": + return DEFAULT_AUTH_TYPE + return str(value) + + def _validate_auth_config(self) -> None: + """Validate stored auth-related connection args. + + Runs against the *stored* datasource config (synced from Minds), not + runtime passthrough requests. Raises ValueError on misconfiguration. + """ + auth_type = self._get_auth_type() + if auth_type not in SUPPORTED_AUTH_TYPES: + raise ValueError( + f"Unsupported auth_type '{auth_type}'. 
Supported values: {', '.join(SUPPORTED_AUTH_TYPES)}" + ) + + if auth_type == AUTH_TYPE_BEARER: + self._validate_bearer_config() + else: + self._validate_oauth_client_credentials_config() + + def _validate_bearer_config(self) -> None: + if not self.connection_data.get("bearer_token"): + raise ValueError("bearer_token is required when auth_type is 'bearer'") + + present_oauth_fields = [field for field in OAUTH_ONLY_FIELDS if self.connection_data.get(field)] + if present_oauth_fields: + raise ValueError( + f"OAuth-only fields are not permitted when auth_type is 'bearer': {', '.join(present_oauth_fields)}" + ) + + def _validate_oauth_client_credentials_config(self) -> None: + if self.connection_data.get("bearer_token"): + raise ValueError("bearer_token is not permitted when auth_type is 'oauth_client_credentials'") + + for required in ("token_url", "client_id", "client_secret"): + if not self.connection_data.get(required): + raise ValueError(f"{required} is required when auth_type is 'oauth_client_credentials'") + + method = self.connection_data.get("token_auth_method") or DEFAULT_TOKEN_AUTH_METHOD + if method not in SUPPORTED_TOKEN_AUTH_METHODS: + raise ValueError( + f"Unsupported token_auth_method '{method}'. Supported values: {', '.join(SUPPORTED_TOKEN_AUTH_METHODS)}" + ) + + def connect(self) -> None: + """No persistent connection needed — passthrough is stateless. + + Validation happens in check_connection(), which we + call separately during CREATE DATABASE. + """ + self.is_connected = True + + def check_connection(self) -> StatusResponse: + """Validate that base_url and the configured auth strategy are valid. + + Bearer mode: schema-only validation (backward-compatible). + + OAuth mode: schema validation, then construct the provider (which + runs the SSRF check on token_url), fetch a token to confirm the IdP + accepts the credentials, then call test_passthrough() to confirm the + upstream accepts the token. 
Errors are reported as plain strings — + the underlying layers (provider error sanitization, test_passthrough's + structured result) already redact secrets, so we forward their + messages without re-formatting. + """ + response = StatusResponse(False) + try: + base_url = self._build_base_url() + if not base_url: + response.error_message = "base_url is required" + return response + self._validate_auth_config() + + if self._get_auth_type() == AUTH_TYPE_OAUTH_CLIENT_CREDENTIALS: + # Surface IdP / SSRF / token-shape problems at connect time. + self._maybe_init_oauth_provider() + assert self._oauth_provider is not None + self._oauth_provider.get_access_token() + + # _test_request is always set in __init__ (defaults to GET /), + # so the upstream sanity check always runs in OAuth mode. Its + # return is a structured dict with a safe `message` field. + test_result = self.test_passthrough() + if not test_result.get("ok"): + response.error_message = ( + test_result.get("message") or test_result.get("error_code") or "upstream test request failed" + ) + return response + + response.success = True + self.is_connected = True + except Exception as e: + response.error_message = str(e) + return response + + def api_passthrough(self, req: PassthroughRequest) -> PassthroughResponse: + """Forward to PassthroughMixin, with a single 401 retry in OAuth mode. + + On the first 401 from the upstream, assume the cached access token + was rejected (revoked, rotated by the IdP, or invalidated mid-flight), + clear the token cache, and replay the request once. The replay goes + through the full mixin path again — same SSRF / allowed_hosts checks, + same Authorization-header override protection, same response scrub — + so a second 401 is returned to the caller as-is rather than triggering + another retry. Bearer mode is untouched. 
+ """ + response = super().api_passthrough(req) + if ( + response.status_code == 401 + and self._get_auth_type() == AUTH_TYPE_OAUTH_CLIENT_CREDENTIALS + and self._oauth_provider is not None + ): + self._oauth_provider.clear_cached_token() + response = super().api_passthrough(req) + return response + + def native_query(self, query: str) -> Response: + """Not supported — use passthrough instead.""" + return Response( + RESPONSE_TYPE.ERROR, + error_message="rest_api handler is passthrough-only. Use the /passthrough endpoint.", + ) + + def get_tables(self) -> Response: + """No SQL tables — passthrough only.""" + import pandas as pd + + return Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame()) + + def get_columns(self, table_name: str) -> Response: + """No SQL tables — passthrough only.""" + return Response( + RESPONSE_TYPE.ERROR, + error_message="rest_api handler is passthrough-only. No tables available.", + ) diff --git a/mindsdb/integrations/handlers/rest_api_handler/rest_connection_args.py b/mindsdb/integrations/handlers/rest_api_handler/rest_connection_args.py new file mode 100644 index 0000000000..87903b9cf0 --- /dev/null +++ b/mindsdb/integrations/handlers/rest_api_handler/rest_connection_args.py @@ -0,0 +1,39 @@ +"""REST/passthrough connection arguments for the rest_api handler. + +These fields configure how the handler talks HTTP to the upstream API +(base URL, allowed hosts, default headers, test path). Authentication +fields live in oauth_connection_args.py — keep them separate so the +passthrough plumbing stays independent of the auth strategy. +""" + +from collections import OrderedDict + +from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE + + +rest_connection_args = OrderedDict( + base_url={ + "type": ARG_TYPE.STR, + "description": "Base URL of the REST API (e.g. 
https://api.example.com)", + "required": True, + "label": "Base URL", + }, + default_headers={ + "type": ARG_TYPE.DICT, + "description": 'Static headers added to every request (e.g. {"Accept": "application/json"})', + "required": False, + "label": "Default Headers", + }, + allowed_hosts={ + "type": ARG_TYPE.LIST, + "description": 'Allowed hostnames for passthrough requests. Defaults to the base_url host. Use ["*"] to disable containment.', + "required": False, + "label": "Allowed Hosts", + }, + test_path={ + "type": ARG_TYPE.STR, + "description": "Path used by the /passthrough/test endpoint. Defaults to /", + "required": False, + "label": "Test Path", + }, +) diff --git a/mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py b/mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py index d816e23314..e87b44317e 100644 --- a/mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +++ b/mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py @@ -5,6 +5,8 @@ from salesforce_api.exceptions import AuthenticationError, RestRequestCouldNotBeUnderstoodError from mindsdb.integrations.libs.api_handler import MetaAPIHandler +from mindsdb.integrations.libs.passthrough import PassthroughMixin +from mindsdb.integrations.libs.passthrough_types import PassthroughRequest from mindsdb.integrations.libs.response import ( HandlerResponse as Response, HandlerStatusResponse as StatusResponse, @@ -18,13 +20,29 @@ logger = log.getLogger(__name__) -class SalesforceHandler(MetaAPIHandler): +class SalesforceHandler(MetaAPIHandler, PassthroughMixin): """ This handler handles the connection and execution of SQL statements on Salesforce. """ name = "salesforce" + # REST passthrough configuration. Salesforce's base URL is per-org + # (`instance_url`) and is normally discovered at auth time. 
v1 requires + # the caller to supply both `access_token` and `instance_url` explicitly + # in connection_data; dynamic discovery from the username/password flow + # is deferred to a future refresh-aware mixin. + _bearer_token_arg = "access_token" + _base_url_default = None + _test_request = PassthroughRequest(method="GET", path="/services/data/v60.0/") + + def _build_base_url(self) -> str | None: + data = self._get_connection_data() + instance_url = data.get("instance_url") + if not instance_url: + return None + return str(instance_url).rstrip("/") + def __init__(self, name: Text, connection_data: Dict, **kwargs: Any) -> None: """ Initializes the handler. diff --git a/mindsdb/integrations/handlers/shopify_handler/shopify_handler.py b/mindsdb/integrations/handlers/shopify_handler/shopify_handler.py index 6b789b7845..98bc6602f2 100644 --- a/mindsdb/integrations/handlers/shopify_handler/shopify_handler.py +++ b/mindsdb/integrations/handlers/shopify_handler/shopify_handler.py @@ -15,6 +15,8 @@ GiftCardsTable, ) from mindsdb.integrations.libs.api_handler import MetaAPIHandler +from mindsdb.integrations.libs.passthrough import PassthroughMixin +from mindsdb.integrations.libs.passthrough_types import PassthroughRequest from mindsdb.integrations.libs.response import ( HandlerStatusResponse as StatusResponse, HandlerResponse as Response, @@ -33,13 +35,37 @@ logger = log.getLogger(__name__) -class ShopifyHandler(MetaAPIHandler): +class ShopifyHandler(MetaAPIHandler, PassthroughMixin): """ The Shopify handler implementation. """ name = "shopify" + # REST passthrough configuration. Shopify sends the Admin API token in + # `X-Shopify-Access-Token`, not `Authorization: Bearer`, so we override + # the default auth header. v1 requires the caller to pre-supply the + # access token in connection_data — the existing client_id/client_secret + # OAuth dance runs inside `connect()` and isn't surfaced to the mixin. 
+ _bearer_token_arg = "access_token" + _auth_header_name = "X-Shopify-Access-Token" + _auth_header_format = "{token}" + _auth_mode = "custom" + _base_url_default = None + # Version-less path — Shopify redirects this to the current stable + # Admin API version, so the probe survives quarterly API releases. + _test_request = PassthroughRequest(method="GET", path="/admin/shop.json") + + def _build_base_url(self) -> str | None: + data = self._get_connection_data() + shop = data.get("shop_url") + if not shop: + return None + shop = str(shop) + if not shop.startswith(("http://", "https://")): + shop = f"https://{shop}" + return shop.rstrip("/") + def __init__(self, name: str, **kwargs): """ Initialize the handler. diff --git a/mindsdb/integrations/libs/passthrough.py b/mindsdb/integrations/libs/passthrough.py new file mode 100644 index 0000000000..f535d3dfeb --- /dev/null +++ b/mindsdb/integrations/libs/passthrough.py @@ -0,0 +1,477 @@ +""" +PassthroughMixin — generic HTTP passthrough for authenticated REST APIs. + +A handler opts in by declaring three class attributes: + + class MyHandler(APIHandler, PassthroughMixin): + _bearer_token_arg = "api_key" # key in connection_data + _base_url_default = "https://api.example.com" # fallback if user omits + _test_request = PassthroughRequest("GET", "/me") + +The mixin defaults to ``Authorization: Bearer ``. Handlers using a +different auth scheme (e.g. Shopify's ``X-Shopify-Access-Token``) override +``_auth_header_name`` and ``_auth_header_format`` — see CHANGE 3. + +The mixin reads ``self.connection_data`` (a dict populated from +integration setup) to pull the token, resolve the base URL, and enforce +the host allowlist. Handlers that need custom URL composition (e.g. +``http://{host}:{port}``) override ``_build_base_url``. + +``PassthroughProtocol`` is a structural type describing the two public +methods (``api_passthrough`` and ``test_passthrough``). 
The HTTP layer +checks against the protocol rather than the mixin class, so a handler +can satisfy the contract without inheriting the default implementation. +""" + +import ipaddress +import os +import time +from typing import Any, Protocol, runtime_checkable +from urllib.parse import urlparse + +import requests + +from mindsdb.integrations.libs.passthrough_types import ( + ALLOWED_METHODS, + FORBIDDEN_REQUEST_HEADERS, + HOP_BY_HOP_RESPONSE_HEADERS, + HostNotAllowedError, + PassthroughConfigError, + PassthroughRequest, + PassthroughResponse, + PassthroughValidationError, +) +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +PASSTHROUGH_TIMEOUT_SECONDS = int(os.getenv("MINDSDB_PASSTHROUGH_TIMEOUT_SECONDS", "30")) +PASSTHROUGH_MAX_REQUEST_BYTES = int(os.getenv("MINDSDB_PASSTHROUGH_MAX_REQUEST_BYTES", str(1 * 1024 * 1024))) +PASSTHROUGH_MAX_RESPONSE_BYTES = int(os.getenv("MINDSDB_PASSTHROUGH_MAX_RESPONSE_BYTES", str(10 * 1024 * 1024))) + +REDACTED_SENTINEL = "[REDACTED_API_KEY]" + + +def _is_private_host(hostname: str) -> bool: + """Return True if `hostname` resolves to a private/loopback/link-local IP literal. + + Only IP literals are checked; DNS resolution is intentionally not performed + (handlers may legitimately point at an internal DNS name the operator has + approved via `allowed_hosts`). The IP-literal check prevents a caller from + smuggling `http://127.0.0.1/` or `http://10.0.0.1/` through a typo'd base_url. + """ + try: + ip = ipaddress.ip_address(hostname) + except ValueError: + return False + return ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved + + +def _host_matches(host: str, allowlist: list[str]) -> bool: + if not host: + return False + host = host.lower() + return any(host == entry.lower() for entry in allowlist) + + +@runtime_checkable +class PassthroughProtocol(Protocol): + """Structural contract for handlers that expose HTTP passthrough. 
+ + The HTTP namespace checks against this Protocol rather than the + `PassthroughMixin` class, which lets future handlers satisfy the + contract without inheriting the default implementation. + """ + + def api_passthrough(self, req: PassthroughRequest) -> PassthroughResponse: ... + + def test_passthrough(self) -> dict[str, Any]: ... + + +class PassthroughMixin: + # Required overrides + _bearer_token_arg: str = "" + + # Optional overrides + _base_url_arg: str = "base_url" + _base_url_default: str | None = None + _allowed_hosts_arg: str = "allowed_hosts" + _default_headers_arg: str = "default_headers" + + # Auth header. Defaults to bearer-compatible; handlers using a custom + # scheme (e.g. Shopify's `X-Shopify-Access-Token: `) override + # both attrs. The value from `_get_bearer_token()` is formatted into + # `{token}` — the method name is retained for backwards compat but + # now represents "the auth secret" regardless of scheme. + _auth_header_name: str = "Authorization" + _auth_header_format: str = "Bearer {token}" + + # Declarative auth mode surfaced to /capabilities. One handler instance + # has exactly one auth mode, so this is a single string; the API + # response still wraps it in a list because a future contract may + # surface handlers supporting multiple configurations. Known values: + # "bearer", "custom", "oauth_refresh". Handlers that use a non-bearer + # header scheme or a refresh-aware mixin should set this explicitly — + # don't infer it from _auth_header_format, since OAuth-refresh also + # uses "Bearer {token}" but is a distinct mode. + _auth_mode: str = "bearer" + + # Canonical sanity-check request for `test_passthrough()`. Handlers MUST + # set this if they want the /passthrough/test endpoint to do anything + # useful. `None` means "test endpoint returns 'not implemented'". + _test_request: PassthroughRequest | None = None + + # Stamped on every upstream request so the upstream can identify our + # traffic for support/debugging. 
See design §13 (q3). + _upstream_marker_header: str = "X-Minds-Passthrough" + + # Hook: override when URL composition is more than "take a string" + # (e.g. strapi composes from host+port). + def _build_base_url(self) -> str | None: + data = self._get_connection_data() + value = data.get(self._base_url_arg) if self._base_url_arg else None + if value: + return str(value).rstrip("/") + if self._base_url_default is not None: + return self._base_url_default.rstrip("/") + return None + + def _get_connection_data(self) -> dict[str, Any]: + """Return the handler's stored connection_data dict. + + Handlers store this differently; we check the common attribute names + so most handlers don't need to override. + """ + for attr in ("connection_data", "_connection_data"): + value = getattr(self, attr, None) + if isinstance(value, dict): + return value + return {} + + def _get_bearer_token(self) -> str: + if not self._bearer_token_arg: + raise PassthroughConfigError("handler did not declare _bearer_token_arg") + token = self._get_connection_data().get(self._bearer_token_arg) + if not token: + raise PassthroughConfigError(f"bearer token ('{self._bearer_token_arg}') is missing from connection_data") + return str(token) + + def _resolve_url(self, path: str) -> tuple[str, str]: + """Return ``(url, hostname)`` for the outgoing request. + + `path` is appended to the base URL verbatim. After joining we parse + the result and compare the hostname against the allowlist — path + injection tricks like ``@evil.com`` or ``//evil.com`` are rejected + at the hostname-comparison step, not by string matching. 
+ """ + if not path.startswith("/"): + raise PassthroughValidationError("path must start with '/'") + base = self._build_base_url() + if not base: + raise PassthroughConfigError("base_url is not configured for this datasource") + + url = f"{base}{path}" + parsed = urlparse(url) + if parsed.scheme not in ("http", "https") or not parsed.hostname: + raise PassthroughValidationError(f"resolved URL is not valid: {url}") + return url, parsed.hostname + + def _allowed_hosts(self, default_host: str) -> list[str]: + data = self._get_connection_data() + allowed = data.get(self._allowed_hosts_arg) + if isinstance(allowed, list) and allowed: + return [str(h) for h in allowed] + return [default_host] + + def _check_host_allowed(self, hostname: str) -> None: + allowlist = self._allowed_hosts(hostname) + if allowlist == ["*"]: + return + if not _host_matches(hostname, allowlist): + raise HostNotAllowedError(f"host '{hostname}' is not in the datasource allowlist") + if _is_private_host(hostname): + raise HostNotAllowedError( + f"host '{hostname}' resolves to a private/loopback address; " + "set allowed_hosts=['*'] to bypass this check (explicit " + "listing is ignored for private IPs)" + ) + + def _build_outgoing_headers(self, caller_headers: dict[str, str], bearer: str) -> dict[str, str]: + """Merge caller headers (filtered) + default_headers + Authorization.""" + out: dict[str, str] = {} + data = self._get_connection_data() + defaults = data.get(self._default_headers_arg) or {} + if isinstance(defaults, dict): + out.update({str(k): str(v) for k, v in defaults.items()}) + for name, value in (caller_headers or {}).items(): + if name.lower() in FORBIDDEN_REQUEST_HEADERS: + continue + if name.lower().startswith("proxy-"): + continue + out[name] = value + out[self._auth_header_name] = self._auth_header_format.format(token=bearer) + out[self._upstream_marker_header] = "1" + return out + + def _secrets_for_scrub(self) -> list[str]: + """Values that must not appear in the response 
returned to the caller.""" + secrets: list[str] = [] + try: + secrets.append(self._get_bearer_token()) + except PassthroughConfigError: + pass + data = self._get_connection_data() + defaults = data.get(self._default_headers_arg) or {} + if isinstance(defaults, dict): + for value in defaults.values(): + s = str(value) + if len(s) >= 16: + secrets.append(s) + return secrets + + def _scrub(self, text: str, secrets: list[str]) -> str: + for s in secrets: + if s: + text = text.replace(s, REDACTED_SENTINEL) + return text + + def _scrub_bytes(self, data: bytes, secrets: list[str]) -> bytes: + """Byte-level secret scrub (spec §7.6). + + Replacing on raw bytes before decoding prevents U+FFFD substitutions + from `errors="replace"` from fragmenting a secret and letting part of + it survive the scrub. + """ + sentinel = REDACTED_SENTINEL.encode("utf-8") + for s in secrets: + if s: + data = data.replace(s.encode("utf-8"), sentinel) + return data + + def _filter_response_headers(self, headers: dict[str, str], secrets: list[str]) -> dict[str, str]: + filtered: dict[str, str] = {} + for name, value in headers.items(): + if name.lower() in HOP_BY_HOP_RESPONSE_HEADERS: + continue + filtered[name] = self._scrub(str(value), secrets) + return filtered + + def _read_capped_body(self, response: requests.Response) -> bytes: + """Read the response body in chunks, abort if it exceeds the cap.""" + chunks: list[bytes] = [] + total = 0 + try: + for chunk in response.iter_content(chunk_size=64 * 1024): + if not chunk: + continue + total += len(chunk) + if total > PASSTHROUGH_MAX_RESPONSE_BYTES: + raise PassthroughValidationError(f"response body exceeded {PASSTHROUGH_MAX_RESPONSE_BYTES} bytes") + chunks.append(chunk) + finally: + response.close() + return b"".join(chunks) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def api_passthrough(self, req: PassthroughRequest) -> 
PassthroughResponse: + method = (req.method or "").upper() + if method not in ALLOWED_METHODS: + raise PassthroughValidationError(f"method '{req.method}' is not allowed") + + connection_data = self._get_connection_data() + allowed_methods_cfg = connection_data.get("allowed_methods") + if allowed_methods_cfg is not None: + if not isinstance(allowed_methods_cfg, list): + raise PassthroughConfigError("'allowed_methods' must be a list of HTTP method strings") + if not all(isinstance(m, str) for m in allowed_methods_cfg): + raise PassthroughConfigError("'allowed_methods' must be a list of HTTP method strings") + allowed_upper = {m.upper() for m in allowed_methods_cfg} + unknown = sorted(allowed_upper - ALLOWED_METHODS) + if unknown: + raise PassthroughConfigError( + f"'allowed_methods' contains unsupported verbs: {unknown}. Allowed: {sorted(ALLOWED_METHODS)}" + ) + if method not in allowed_upper: + raise PassthroughValidationError( + f"method '{method}' is not permitted by this datasource", + error_code="method_not_allowed", + http_status=405, + ) + + request_bytes = 0 + if req.body is not None: + # requests will serialize dict bodies to JSON; we cap on the + # serialized length. For raw strings / bytes we cap directly. 
+ import json as _json + + if isinstance(req.body, (dict, list)): + body_bytes_for_size = _json.dumps(req.body).encode("utf-8") + elif isinstance(req.body, (bytes, bytearray)): + body_bytes_for_size = bytes(req.body) + else: + body_bytes_for_size = str(req.body).encode("utf-8") + if len(body_bytes_for_size) > PASSTHROUGH_MAX_REQUEST_BYTES: + raise PassthroughValidationError(f"request body exceeded {PASSTHROUGH_MAX_REQUEST_BYTES} bytes") + request_bytes = len(body_bytes_for_size) + + url, hostname = self._resolve_url(req.path) + self._check_host_allowed(hostname) + bearer = self._get_bearer_token() + headers = self._build_outgoing_headers(req.headers or {}, bearer) + + request_kwargs: dict[str, Any] = { + "headers": headers, + "params": req.query or None, + "timeout": PASSTHROUGH_TIMEOUT_SECONDS, + "stream": True, + } + if req.body is not None: + if isinstance(req.body, (dict, list)): + request_kwargs["json"] = req.body + else: + request_kwargs["data"] = req.body + + datasource_name = getattr(self, "name", None) or "?" 
+ start = time.monotonic() + response = requests.request(method, url, **request_kwargs) + body_bytes = self._read_capped_body(response) + duration_ms = int((time.monotonic() - start) * 1000) + + secrets = self._secrets_for_scrub() + body_bytes = self._scrub_bytes(body_bytes, secrets) + content_type = response.headers.get("Content-Type", "") or "" + out_headers = self._filter_response_headers(dict(response.headers), secrets) + + body: Any + if "application/json" in content_type.lower(): + try: + text = body_bytes.decode("utf-8", errors="replace") + import json as _json + + body = _json.loads(text) if text else None + except ValueError: + body = body_bytes.decode("utf-8", errors="replace") + else: + body = body_bytes.decode("utf-8", errors="replace") + + self._log_passthrough_call( + method=method, + path=req.path, + datasource_name=datasource_name, + upstream_status_code=response.status_code, + request_bytes=request_bytes, + response_bytes=len(body_bytes), + duration_ms=duration_ms, + ) + + return PassthroughResponse( + status_code=response.status_code, + headers=out_headers, + body=body, + content_type=content_type.split(";", 1)[0].strip() or None, + ) + + def _log_passthrough_call( + self, + *, + method: str, + path: str, + datasource_name: str, + upstream_status_code: int, + request_bytes: int, + response_bytes: int, + duration_ms: int, + ) -> None: + """Emit one audit line per passthrough call (spec §7.8). + + Never logs headers or bodies. user_id / org_id are pulled from the + MindsDB request context when available; in test/dev invocations + where the context is not populated, they are omitted. 
+ """ + fields: dict[str, Any] = { + "method": method, + "path": path, + "datasource_name": datasource_name, + "upstream_status_code": upstream_status_code, + "request_bytes": request_bytes, + "response_bytes": response_bytes, + "duration_ms": duration_ms, + } + # TODO: org_id lives in Minds; when the passthrough is called via the + # Minds gateway the org scope should be propagated and logged here. + try: + from mindsdb.utilities.context import context as _ctx + + user_id = getattr(_ctx, "user_id", None) + company_id = getattr(_ctx, "company_id", None) + if user_id is not None: + fields["user_id"] = user_id + if company_id is not None: + fields["company_id"] = company_id + except Exception: + pass + # DEBUG level per team decision: per-request audit logging at + # info level happens in Minds at the HTTP layer. This log is + # intended for mindsdb-side troubleshooting only. + logger.debug("passthrough %s", fields) + + def test_passthrough(self) -> dict[str, Any]: + """Run the handler's canonical sanity-check call (see §6.1a). + + Returns a structured dict the HTTP layer forwards to the caller: + { "ok": bool, "status_code": int?, "host": str?, "latency_ms": int?, + "error_code": str?, "message": str? 
} + """ + if self._test_request is None: + return { + "ok": False, + "error_code": "not_implemented", + "message": "this handler does not define a passthrough test request", + } + + start = time.monotonic() + try: + resp = self.api_passthrough(self._test_request) + except HostNotAllowedError as e: + return {"ok": False, "error_code": e.error_code, "message": str(e)} + except PassthroughConfigError as e: + return {"ok": False, "error_code": e.error_code, "message": str(e)} + except PassthroughValidationError as e: + return {"ok": False, "error_code": e.error_code, "message": str(e)} + except requests.exceptions.Timeout as e: + return {"ok": False, "error_code": "timeout", "message": str(e)} + except requests.exceptions.ConnectionError as e: + return {"ok": False, "error_code": "network", "message": str(e)} + except Exception as e: # noqa: BLE001 + logger.exception("passthrough test failed unexpectedly") + return {"ok": False, "error_code": "unknown", "message": str(e)} + + latency_ms = int((time.monotonic() - start) * 1000) + try: + _, host = self._resolve_url(self._test_request.path) + except Exception: + host = None + + if 200 <= resp.status_code < 300: + return {"ok": True, "status_code": resp.status_code, "host": host, "latency_ms": latency_ms} + if resp.status_code in (401, 403): + return { + "ok": False, + "status_code": resp.status_code, + "host": host, + "latency_ms": latency_ms, + "error_code": "auth_failed", + "message": "upstream rejected credentials; base URL and allowlist look correct", + } + return { + "ok": False, + "status_code": resp.status_code, + "host": host, + "latency_ms": latency_ms, + "error_code": "upstream_error", + "message": f"upstream returned {resp.status_code}", + } diff --git a/mindsdb/integrations/libs/passthrough_types.py b/mindsdb/integrations/libs/passthrough_types.py new file mode 100644 index 0000000000..63d1cb1452 --- /dev/null +++ b/mindsdb/integrations/libs/passthrough_types.py @@ -0,0 +1,94 @@ +""" +Request/response 
dataclasses and error types for the REST passthrough path. + +These are the payloads exchanged between the HTTP layer and +`PassthroughMixin`. They are intentionally framework-agnostic so the +mixin can be unit-tested without Flask. +""" + +from dataclasses import dataclass, field +from typing import Any + + +ALLOWED_METHODS = frozenset({"GET", "POST", "PUT", "PATCH", "DELETE"}) + +# Hop-by-hop and auth-related headers that must never come from the caller. +FORBIDDEN_REQUEST_HEADERS = frozenset( + h.lower() + for h in ( + "authorization", + "host", + "cookie", + "content-length", + "connection", + ) +) + +# Hop-by-hop response headers stripped before returning to the caller. +HOP_BY_HOP_RESPONSE_HEADERS = frozenset( + h.lower() + for h in ( + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + "content-length", + ) +) + + +@dataclass +class PassthroughRequest: + method: str + path: str + query: dict[str, Any] = field(default_factory=dict) + headers: dict[str, str] = field(default_factory=dict) + body: Any = None + + +@dataclass +class PassthroughResponse: + status_code: int + headers: dict[str, str] + body: Any + content_type: str | None = None + + +class PassthroughError(Exception): + """Base class for passthrough failures that should not be leaked as 500s.""" + + error_code: str = "passthrough_error" + http_status: int = 400 + + def __init__(self, message: str, *, error_code: str | None = None, http_status: int | None = None): + super().__init__(message) + if error_code is not None: + self.error_code = error_code + if http_status is not None: + self.http_status = http_status + + +class PassthroughConfigError(PassthroughError): + error_code = "config_error" + http_status = 500 + + +class HostNotAllowedError(PassthroughError): + error_code = "host_not_allowed" + http_status = 400 + + +class PassthroughValidationError(PassthroughError): + error_code = "invalid_request" + 
http_status = 400 + + +class PassthroughNotSupportedError(PassthroughError): + """Raised when a handler does not implement the mixin.""" + + error_code = "passthrough_not_supported" + http_status = 501 diff --git a/mindsdb/integrations/utilities/handlers/auth_utilities/oauth2/__init__.py b/mindsdb/integrations/utilities/handlers/auth_utilities/oauth2/__init__.py new file mode 100644 index 0000000000..03759c1b53 --- /dev/null +++ b/mindsdb/integrations/utilities/handlers/auth_utilities/oauth2/__init__.py @@ -0,0 +1,23 @@ +from .client_credentials import ( + ALLOWED_AUTH_METHODS, + CONNECT_TIMEOUT_SECONDS, + DEFAULT_EXPIRES_IN_SECONDS, + DEFAULT_STORAGE_KEY, + DEFAULT_TOKEN_AUTH_METHOD, + EXPIRY_SKEW_SECONDS, + MAX_RESPONSE_BYTES, + OAuth2ClientCredentialsProvider, + READ_TIMEOUT_SECONDS, +) + +__all__ = [ + "OAuth2ClientCredentialsProvider", + "ALLOWED_AUTH_METHODS", + "DEFAULT_TOKEN_AUTH_METHOD", + "DEFAULT_STORAGE_KEY", + "CONNECT_TIMEOUT_SECONDS", + "READ_TIMEOUT_SECONDS", + "DEFAULT_EXPIRES_IN_SECONDS", + "EXPIRY_SKEW_SECONDS", + "MAX_RESPONSE_BYTES", +] diff --git a/mindsdb/integrations/utilities/handlers/auth_utilities/oauth2/client_credentials.py b/mindsdb/integrations/utilities/handlers/auth_utilities/oauth2/client_credentials.py new file mode 100644 index 0000000000..2c737c14e6 --- /dev/null +++ b/mindsdb/integrations/utilities/handlers/auth_utilities/oauth2/client_credentials.py @@ -0,0 +1,461 @@ +"""OAuth2 client credentials grant token provider. + +A reusable, thread-safe utility that fetches and caches OAuth2 access tokens +using the RFC 6749 client credentials grant. Suitable for server-to-server +flows with no end-user redirect. + +Public surface: + OAuth2ClientCredentialsProvider(connection_data, handler_storage=None, + storage_key=None) + .get_access_token() -> str + .clear_cached_token() -> None + .current_secrets() -> list[str] + +The constructor takes the handler's stored connection_data dict as a single +argument so consumers (e.g. 
the rest_api handler) can pass `self.connection_data` +verbatim. Auth resolution stays inside the provider — callers never see +client_secret values, and cached state never contains credentials or config. +""" + +from __future__ import annotations + +import base64 +import ipaddress +import socket +import threading +import time +from typing import Any, Optional +from urllib.parse import urlparse + +import requests + +from mindsdb.integrations.libs.passthrough import _host_matches, _is_private_host +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +# Public constants — exposed so callers (and tests) can reference them +# without reaching for underscore-prefixed names. +ALLOWED_AUTH_METHODS = ("client_secret_post", "client_secret_basic") +DEFAULT_TOKEN_AUTH_METHOD = "client_secret_post" +DEFAULT_STORAGE_KEY = "oauth_client_credentials_tokens" + +CONNECT_TIMEOUT_SECONDS = 10 +READ_TIMEOUT_SECONDS = 30 +DEFAULT_EXPIRES_IN_SECONDS = 300 +EXPIRY_SKEW_SECONDS = 60 +MAX_RESPONSE_BYTES = 64 * 1024 + + +# Backward-compatible private aliases. Removing these would break any external +# code that imported the underscore names; safe to keep as thin pointers. +_ALLOWED_AUTH_METHODS = ALLOWED_AUTH_METHODS +_CONNECT_TIMEOUT_SECONDS = CONNECT_TIMEOUT_SECONDS +_READ_TIMEOUT_SECONDS = READ_TIMEOUT_SECONDS +_DEFAULT_EXPIRES_IN_SECONDS = DEFAULT_EXPIRES_IN_SECONDS +_SKEW_SECONDS = EXPIRY_SKEW_SECONDS +_MAX_RESPONSE_BYTES = MAX_RESPONSE_BYTES + + +def _is_localhost_name(host: str) -> bool: + h = host.lower().rstrip(".") + if h == "localhost": + return True + if h.endswith(".localhost"): + return True + if h in ("ip6-localhost", "ip6-loopback"): + return True + return False + + +def _is_forbidden_ip_string(addr: str) -> bool: + """True if `addr` is an IP literal in a forbidden range. 
+ + Wraps the passthrough mixin's `_is_private_host` (which covers loopback, + private, link-local, multicast, reserved) and adds the unspecified range + (0.0.0.0, ::) — the unspecified address is meaningless as an outbound + target and a common SSRF foothold. + """ + if _is_private_host(addr): + return True + try: + return ipaddress.ip_address(addr).is_unspecified + except ValueError: + return False + + +def _validate_token_url(token_url: str, allowed_hosts: Optional[list] = None) -> None: + """Raise ValueError if token_url violates SSRF or allowlist rules. + + `allowed_hosts` mirrors the passthrough handler's `connection_data["allowed_hosts"]` + semantics: + - missing / None / empty list → no host allowlist applied (SSRF still runs) + - ["*"] → host allowlist skipped, but baseline SSRF protections still apply + - other list → token_url host must match one of the entries (case-insensitive) + + Baseline SSRF protections always run, regardless of `allowed_hosts`. A + wildcard cannot enable loopback/private/link-local destinations, because + even an operator-curated wildcard is not a license to call internal + infrastructure with the datasource's stored client_secret. + """ + if not isinstance(token_url, str) or not token_url: + raise ValueError("token_url must be a non-empty string") + + parsed = urlparse(token_url) + scheme = parsed.scheme.lower() + if scheme not in ("http", "https"): + raise ValueError(f"token_url scheme '{parsed.scheme}' is not allowed; only http and https are permitted") + + host = parsed.hostname + if not host: + raise ValueError("token_url must include a host component") + + # Host allowlist check — independent of and prior to baseline SSRF. + # A non-wildcard list confines the token endpoint to the operator's + # pre-approved hosts. ["*"] disables only this check. 
+ if isinstance(allowed_hosts, list) and allowed_hosts and allowed_hosts != ["*"]: + normalized = [str(h) for h in allowed_hosts] + if not _host_matches(host, normalized): + raise ValueError(f"token_url host '{host}' is not in the datasource allowed_hosts allowlist") + + # Baseline SSRF — runs unconditionally, including under ["*"]. + if _is_localhost_name(host): + raise ValueError(f"token_url host '{host}' is a localhost alias and is not permitted") + + if _is_forbidden_ip_string(host): + raise ValueError(f"token_url host '{host}' is in a forbidden IP range") + + # If host is a name (not an IP literal), resolve and re-check each address + # to defeat names that point at internal IPs. + is_ip_literal = True + try: + ipaddress.ip_address(host) + except ValueError: + is_ip_literal = False + + if not is_ip_literal: + try: + addrinfo = socket.getaddrinfo(host, None) + except socket.gaierror as exc: + raise ValueError(f"token_url host '{host}' could not be resolved: {exc}") from exc + + for info in addrinfo: + addr = info[4][0] + # Strip IPv6 zone identifier if present + if "%" in addr: + addr = addr.split("%", 1)[0] + if _is_forbidden_ip_string(addr): + raise ValueError(f"token_url host '{host}' resolves to a forbidden IP range") + + if scheme == "http": + logger.warning( + "OAuth2 token_url uses http://; credentials will be transmitted over an unencrypted channel. host=%s", + host, + ) + + +class OAuth2ClientCredentialsProvider: + """Fetches and caches OAuth2 access tokens using the client credentials grant. + + Thread-safe: concurrent callers of get_access_token() during refresh trigger + exactly one HTTP request to the token endpoint via double-checked locking. 
+ """ + + def __init__( + self, + connection_data: dict, + handler_storage: Any = None, + storage_key: Optional[str] = None, + ) -> None: + if not isinstance(connection_data, dict): + raise TypeError("connection_data must be a dict") + + token_url = connection_data.get("token_url") + client_id = connection_data.get("client_id") + client_secret = connection_data.get("client_secret") + scope = connection_data.get("scope") + audience = connection_data.get("audience") + token_auth_method = connection_data.get("token_auth_method") or DEFAULT_TOKEN_AUTH_METHOD + + if token_auth_method not in ALLOWED_AUTH_METHODS: + raise ValueError( + f"token_auth_method '{token_auth_method}' is not supported; " + f"allowed values are: {', '.join(ALLOWED_AUTH_METHODS)}" + ) + + if not token_url: + raise ValueError("connection_data['token_url'] is required") + if not client_id: + raise ValueError("connection_data['client_id'] is required") + if not client_secret: + raise ValueError("connection_data['client_secret'] is required") + + # token_url is validated against the same `allowed_hosts` list the + # passthrough mixin uses for upstream API calls. Operators who restrict + # passthrough to specific hosts must also list the IdP's token host, + # since a different host is permitted but must still be allowlisted. 
+ _validate_token_url(token_url, allowed_hosts=connection_data.get("allowed_hosts")) + + self.token_url = token_url + self.client_id = client_id + self.client_secret = client_secret + self.scope = scope + self.audience = audience + self.token_auth_method = token_auth_method + self.handler_storage = handler_storage + self.storage_key = storage_key or DEFAULT_STORAGE_KEY + + self._lock = threading.Lock() + self._memory_cache: Optional[dict] = None + self._missing_expires_in_logged = False + + def get_access_token(self) -> str: + """Return a valid access token, refreshing if needed.""" + cached = self._read_cache() + if cached and not self._is_expired(cached): + return cached["access_token"] + + with self._lock: + # Re-read inside the lock — another thread may have refreshed while + # we were waiting on the lock. Without this second check, two + # threads that both observed an expired token would both refresh. + cached = self._read_cache() + if cached and not self._is_expired(cached): + return cached["access_token"] + + new_token = self._request_token() + self._write_cache(new_token) + return new_token["access_token"] + + def clear_cached_token(self) -> None: + """Clear the cached token from both in-memory and persistent storage.""" + with self._lock: + self._memory_cache = None + if self.handler_storage is not None: + try: + self.handler_storage.encrypted_json_set(self.storage_key, None) + except Exception as exc: + logger.debug( + "Failed to clear OAuth2 token from handler_storage; cleared in-memory only. host=%s err=%s", + self._safe_host(), + exc, + ) + + def current_secrets(self) -> list: + """Return secrets that response-scrub layers should redact. + + Safe to call per-request. For the client credentials flow this is + currently the cached access token if any. 
+ """ + cached = self._read_cache() + if cached and not self._is_expired(cached): + token = cached.get("access_token") + if token: + return [token] + return [] + + def _request_token(self) -> dict: + body: dict = {"grant_type": "client_credentials"} + + if self.scope is not None: + if isinstance(self.scope, (list, tuple)): + scope_value = " ".join(str(s) for s in self.scope) + else: + scope_value = str(self.scope) + if scope_value: + body["scope"] = scope_value + + if self.audience is not None: + # `audience` is an Auth0-style extension; also accepted by Cognito + # and others. Not part of RFC 6749. + body["audience"] = self.audience + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + if self.token_auth_method == "client_secret_post": + body["client_id"] = self.client_id + body["client_secret"] = self.client_secret + else: + credentials = f"{self.client_id}:{self.client_secret}".encode("utf-8") + headers["Authorization"] = "Basic " + base64.b64encode(credentials).decode("ascii") + + try: + response = requests.post( + self.token_url, + data=body, + headers=headers, + timeout=(CONNECT_TIMEOUT_SECONDS, READ_TIMEOUT_SECONDS), + allow_redirects=False, + stream=True, + ) + except requests.RequestException as exc: + # Avoid leaking request body which contains credentials + raise RuntimeError( + f"OAuth2 token request failed (transport error). host={self._safe_host()} client_id={self.client_id}" + ) from self._sanitize_exception(exc) + + try: + return self._parse_token_response(response) + finally: + try: + response.close() + except Exception: + pass + + def _parse_token_response(self, response: requests.Response) -> dict: + if response.is_redirect or 300 <= response.status_code < 400: + raise RuntimeError( + "OAuth2 token endpoint returned a redirect; redirects are disabled. 
" + f"status={response.status_code} host={self._safe_host()} client_id={self.client_id}" + ) + + body_bytes = self._read_capped(response) + body_text = body_bytes.decode("utf-8", errors="replace") if body_bytes else "" + + parsed_json: Optional[dict] = None + if body_text: + try: + import json as _json + + parsed_json = _json.loads(body_text) + if not isinstance(parsed_json, dict): + parsed_json = None + except ValueError: + parsed_json = None + + if not (200 <= response.status_code < 300): + err_code = parsed_json.get("error") if parsed_json else None + err_desc = parsed_json.get("error_description") if parsed_json else None + details = "" + if err_code: + details = f" error={err_code}" + if err_desc: + details += f" error_description={err_desc}" + raise RuntimeError( + f"OAuth2 token endpoint returned status {response.status_code}.{details} " + f"host={self._safe_host()} client_id={self.client_id}" + ) + + if parsed_json is None: + raise RuntimeError( + f"OAuth2 token endpoint returned non-JSON response. status={response.status_code} " + f"host={self._safe_host()} client_id={self.client_id}" + ) + + access_token = parsed_json.get("access_token") + if not access_token or not isinstance(access_token, str): + raise RuntimeError( + f"OAuth2 token response is missing 'access_token'. host={self._safe_host()} client_id={self.client_id}" + ) + + token_type = parsed_json.get("token_type", "Bearer") + if not isinstance(token_type, str) or token_type.lower() != "bearer": + raise RuntimeError( + f"OAuth2 token response token_type '{token_type}' is not supported; only Bearer is accepted. 
" + f"host={self._safe_host()} client_id={self.client_id}" + ) + + expires_in_raw = parsed_json.get("expires_in") + try: + expires_in = int(expires_in_raw) if expires_in_raw is not None else 0 + except (TypeError, ValueError): + expires_in = 0 + + if expires_in <= 0: + if not self._missing_expires_in_logged: + logger.warning( + "OAuth2 token response omitted or returned invalid 'expires_in'; defaulting to %ss. host=%s", + DEFAULT_EXPIRES_IN_SECONDS, + self._safe_host(), + ) + self._missing_expires_in_logged = True + expires_in = DEFAULT_EXPIRES_IN_SECONDS + + expires_at = time.time() + expires_in - EXPIRY_SKEW_SECONDS + + return { + "access_token": access_token, + "token_type": token_type, + "expires_at": expires_at, + } + + def _read_capped(self, response: requests.Response) -> bytes: + """Read response body up to MAX_RESPONSE_BYTES; abort if exceeded.""" + chunks: list = [] + total = 0 + try: + for chunk in response.iter_content(chunk_size=4096): + if not chunk: + continue + total += len(chunk) + if total > MAX_RESPONSE_BYTES: + raise RuntimeError( + f"OAuth2 token response exceeded {MAX_RESPONSE_BYTES} bytes; aborting. " + f"host={self._safe_host()} client_id={self.client_id}" + ) + chunks.append(chunk) + except requests.RequestException as exc: + raise RuntimeError( + f"OAuth2 token response read error. host={self._safe_host()} client_id={self.client_id}" + ) from self._sanitize_exception(exc) + return b"".join(chunks) + + def _read_cache(self) -> Optional[dict]: + if self.handler_storage is not None: + try: + cached = self.handler_storage.encrypted_json_get(self.storage_key) + if cached: + return cached + except Exception as exc: + logger.debug( + "OAuth2 token cache read failed; falling back to in-memory cache. 
host=%s err=%s", + self._safe_host(), + exc, + ) + return self._memory_cache + + def _write_cache(self, token: dict) -> None: + # Only persist the minimal token shape — never credentials or config + cache_entry = { + "access_token": token["access_token"], + "token_type": token["token_type"], + "expires_at": token["expires_at"], + } + self._memory_cache = cache_entry + if self.handler_storage is not None: + try: + self.handler_storage.encrypted_json_set(self.storage_key, cache_entry) + except Exception as exc: + logger.debug( + "OAuth2 token cache write failed; falling back to in-memory cache. host=%s err=%s", + self._safe_host(), + exc, + ) + + @staticmethod + def _is_expired(cached: dict) -> bool: + expires_at = cached.get("expires_at") + if not isinstance(expires_at, (int, float)): + return True + return time.time() >= expires_at + + def _safe_host(self) -> str: + try: + return urlparse(self.token_url).hostname or "" + except Exception: + return "" + + def _sanitize_exception(self, exc: BaseException) -> BaseException: + """Rebuild an exception with secrets redacted from its message.""" + text = str(exc) + redacted = text + for secret in (self.client_secret,): + if secret and secret in redacted: + redacted = redacted.replace(secret, "***") + if redacted == text: + return exc + return type(exc)(redacted) diff --git a/tests/unit/api/http/test_integrations_passthrough.py b/tests/unit/api/http/test_integrations_passthrough.py new file mode 100644 index 0000000000..1ea419e554 --- /dev/null +++ b/tests/unit/api/http/test_integrations_passthrough.py @@ -0,0 +1,404 @@ +"""HTTP-layer tests for the /api/integrations//passthrough routes. + +Exercises the Flask blueprint in isolation: the session's integration +controller is mocked to return handlers that satisfy +PassthroughProtocol, so these tests do not touch real handlers and do +not make network calls. 
+""" + +from http import HTTPStatus +from unittest.mock import MagicMock, patch + +from mindsdb.integrations.libs.passthrough import PassthroughMixin +from mindsdb.integrations.libs.passthrough_types import PassthroughResponse + + +class _StubPassthroughHandler(PassthroughMixin): + """Handler double: the HTTP layer checks the PassthroughProtocol, then + calls `api_passthrough`. We bypass all mixin internals by overriding + `api_passthrough` directly so the endpoint test does not depend on + connection_data, base_url resolution, or the requests library.""" + + def __init__(self, response: PassthroughResponse): + self._response = response + self.calls: list = [] + + def api_passthrough(self, req): # type: ignore[override] + self.calls.append(req) + return self._response + + def test_passthrough(self): + return {"ok": True, "status_code": self._response.status_code} + + +def _patch_handler(handler): + """Patch FakeMysqlProxy so the endpoint resolves `name` to `handler`.""" + proxy = MagicMock() + proxy.session.integration_controller.get_data_handler.return_value = handler + return patch( + "mindsdb.api.http.namespaces.integrations.FakeMysqlProxy", + return_value=proxy, + ) + + +def test_passthrough_happy_path_returns_200_and_serialized_body(client): + handler = _StubPassthroughHandler( + PassthroughResponse( + status_code=200, + headers={"X-Safe": "1"}, + body={"hello": "world"}, + content_type="application/json", + ) + ) + + with _patch_handler(handler): + response = client.post( + "/api/integrations/any_ds/passthrough", + json={"method": "GET", "path": "/me"}, + ) + + assert response.status_code == HTTPStatus.OK + payload = response.get_json() + assert payload == { + "status_code": 200, + "headers": {"X-Safe": "1"}, + "body": {"hello": "world"}, + "content_type": "application/json", + } + # Request actually reached the mixin with the parsed PassthroughRequest. 
+ assert len(handler.calls) == 1 + assert handler.calls[0].method == "GET" + assert handler.calls[0].path == "/me" + + +def test_passthrough_returns_501_when_handler_does_not_support_mixin(client): + # A bare object does not satisfy PassthroughProtocol, so the endpoint + # should surface passthrough_not_supported (501) instead of a 500. + with _patch_handler(object()): + response = client.post( + "/api/integrations/mysql/passthrough", + json={"method": "GET", "path": "/anything"}, + ) + + assert response.status_code == HTTPStatus.NOT_IMPLEMENTED + payload = response.get_json() + assert payload["error_code"] == "passthrough_not_supported" + assert "mysql" in payload["message"] + + +def test_passthrough_returns_400_on_invalid_method(client): + handler = _StubPassthroughHandler(PassthroughResponse(status_code=200, headers={}, body=None, content_type=None)) + + with _patch_handler(handler): + response = client.post( + "/api/integrations/any_ds/passthrough", + json={"method": "TRACE", "path": "/me"}, + ) + + assert response.status_code == HTTPStatus.BAD_REQUEST + payload = response.get_json() + assert payload["error_code"] == "invalid_request" + # The handler must not have been invoked when validation fails up front. + assert handler.calls == [] + + +def _patch_handler_modules(modules: dict): + return patch( + "mindsdb.api.http.namespaces.integrations.integration_controller.handler_modules", + modules, + create=True, + ) + + +def test_capabilities_returns_handlers_dict_and_legacy_list(client): + # Two opted-in handlers covering both auth modes, one non-opt-in, and + # one broken module that lacks a Handler attribute. auth_modes is + # surfaced from the handler's declarative `_auth_mode` class attr — + # not inferred from header format. 
+ class _BearerHandler(PassthroughMixin): + pass # inherits _auth_mode = "bearer" + + class _CustomHeaderHandler(PassthroughMixin): + _auth_header_name = "X-Shopify-Access-Token" + _auth_header_format = "{token}" + _auth_mode = "custom" + + class _NotOptedIn: + pass + + bearer_mod = MagicMock() + bearer_mod.Handler = _BearerHandler + custom_mod = MagicMock() + custom_mod.Handler = _CustomHeaderHandler + plain_mod = MagicMock() + plain_mod.Handler = _NotOptedIn + no_handler_mod = MagicMock(spec=[]) + + fake_modules = { + "hubspot": bearer_mod, + "shopify": custom_mod, + "mysql": plain_mod, + "broken": no_handler_mod, + } + + with _patch_handler_modules(fake_modules): + response = client.get("/api/integrations/capabilities") + + assert response.status_code == HTTPStatus.OK + payload = response.get_json() + + # New structured shape: every opted-in handler appears with auth_modes + # and operations metadata. + assert payload["handlers"] == { + "hubspot": {"auth_modes": ["bearer"], "operations": ["passthrough"]}, + "shopify": {"auth_modes": ["custom"], "operations": ["passthrough"]}, + } + + # Legacy flat list: only bearer-auth handlers (Minds migration compat). + assert payload["bearer_passthrough"] == ["hubspot"] + + +def test_capabilities_auth_mode_is_declarative_not_format_derived(client): + # Handler keeps the default "Bearer {token}" header format but flags + # itself as oauth_refresh. The old format-matching heuristic would + # have bucketed this as "bearer"; the new declarative path returns + # the explicit mode and correctly omits it from the legacy list. + class _OAuthRefreshHandler(PassthroughMixin): + _auth_mode = "oauth_refresh" + # _auth_header_format intentionally left as the default. 
+ + oauth_mod = MagicMock() + oauth_mod.Handler = _OAuthRefreshHandler + + with _patch_handler_modules({"hubspot_oauth": oauth_mod}): + response = client.get("/api/integrations/capabilities") + + assert response.status_code == HTTPStatus.OK + payload = response.get_json() + assert payload["handlers"] == { + "hubspot_oauth": {"auth_modes": ["oauth_refresh"], "operations": ["passthrough"]}, + } + # oauth_refresh is NOT surfaced in the legacy bearer-only list even + # though the underlying header format is still "Bearer {token}". + assert payload["bearer_passthrough"] == [] + + +def test_capabilities_empty_when_no_handlers_opted_in(client): + class _NotOptedIn: + pass + + plain_mod = MagicMock() + plain_mod.Handler = _NotOptedIn + + with _patch_handler_modules({"mysql": plain_mod}): + response = client.get("/api/integrations/capabilities") + + assert response.status_code == HTTPStatus.OK + payload = response.get_json() + assert payload == {"handlers": {}, "bearer_passthrough": []} + + +# --------------------------------------------------------------------------- +# Multi-mode (`_auth_modes`) capability resolution +# --------------------------------------------------------------------------- + + +def test_capabilities_handler_with_only_auth_mode_string_unchanged(client): + # A handler that declares only the legacy `_auth_mode = "bearer"` + # continues to surface as auth_modes: ["bearer"]. This pins the + # backward-compat path of the new resolver. 
+ class _LegacyBearerHandler(PassthroughMixin): + pass # _auth_mode defaults to "bearer" via the mixin + + mod = MagicMock() + mod.Handler = _LegacyBearerHandler + + with _patch_handler_modules({"legacy": mod}): + response = client.get("/api/integrations/capabilities") + + assert response.status_code == HTTPStatus.OK + payload = response.get_json() + assert payload["handlers"] == { + "legacy": {"auth_modes": ["bearer"], "operations": ["passthrough"]}, + } + assert payload["bearer_passthrough"] == ["legacy"] + + +def test_capabilities_handler_with_auth_modes_list_returns_all_modes(client): + # The rest_api shape: a single handler advertising both bearer and + # oauth_client_credentials. The endpoint must return both modes + # verbatim and include the handler in `bearer_passthrough` because + # "bearer" is among them. + class _MultiAuthHandler(PassthroughMixin): + _auth_modes = ["bearer", "oauth_client_credentials"] + + mod = MagicMock() + mod.Handler = _MultiAuthHandler + + with _patch_handler_modules({"rest_api": mod}): + response = client.get("/api/integrations/capabilities") + + assert response.status_code == HTTPStatus.OK + payload = response.get_json() + assert payload["handlers"] == { + "rest_api": { + "auth_modes": ["bearer", "oauth_client_credentials"], + "operations": ["passthrough"], + }, + } + assert payload["bearer_passthrough"] == ["rest_api"] + + +def test_capabilities_auth_modes_takes_precedence_over_auth_mode(client): + # When both fields are declared (as on RestApiHandler — _auth_modes for + # the new shape, _auth_mode kept as a fallback), the list wins. This is + # the contract the resolver promises. 
+ class _DualDeclaredHandler(PassthroughMixin): + _auth_modes = ["bearer", "oauth_client_credentials"] + _auth_mode = "bearer" + + mod = MagicMock() + mod.Handler = _DualDeclaredHandler + + with _patch_handler_modules({"rest_api": mod}): + response = client.get("/api/integrations/capabilities") + + payload = response.get_json() + assert payload["handlers"]["rest_api"]["auth_modes"] == [ + "bearer", + "oauth_client_credentials", + ] + + +def test_capabilities_bearer_passthrough_membership_per_auth_modes(client): + # `bearer_passthrough` is populated based on whether "bearer" appears + # in the resolved auth_modes, not on the legacy `_auth_mode` field. + class _BearerOnly(PassthroughMixin): + _auth_modes = ["bearer"] + + class _OAuthOnly(PassthroughMixin): + _auth_modes = ["oauth_client_credentials"] + + class _Multi(PassthroughMixin): + _auth_modes = ["bearer", "oauth_client_credentials"] + + bearer_mod = MagicMock() + bearer_mod.Handler = _BearerOnly + oauth_mod = MagicMock() + oauth_mod.Handler = _OAuthOnly + multi_mod = MagicMock() + multi_mod.Handler = _Multi + + with _patch_handler_modules({"a_bearer": bearer_mod, "b_oauth": oauth_mod, "c_multi": multi_mod}): + response = client.get("/api/integrations/capabilities") + + payload = response.get_json() + assert payload["bearer_passthrough"] == ["a_bearer", "c_multi"] + assert "b_oauth" not in payload["bearer_passthrough"] + + +def test_capabilities_handler_with_empty_auth_modes_falls_back_to_auth_mode(client): + # An empty list is treated as "not declared" and the resolver falls + # back to `_auth_mode`, which itself falls back to "bearer". 
+ class _EmptyListHandler(PassthroughMixin): + _auth_modes = [] + _auth_mode = "custom" + + mod = MagicMock() + mod.Handler = _EmptyListHandler + + with _patch_handler_modules({"odd": mod}): + response = client.get("/api/integrations/capabilities") + + payload = response.get_json() + assert payload["handlers"]["odd"]["auth_modes"] == ["custom"] + + +def test_capabilities_endpoint_stable_when_handler_module_is_broken(client): + # If reading attrs off one module raises, the endpoint must still + # return 200 with the other handlers intact. + class _Healthy(PassthroughMixin): + _auth_modes = ["bearer"] + + healthy_mod = MagicMock() + healthy_mod.Handler = _Healthy + + # An "evil" module whose Handler attribute access raises. Use a + # PropertyMock-style trick: set Handler to a property that raises. + class _ExplodingModule: + @property + def Handler(self): # noqa: N802 — matches handler-module API + raise RuntimeError("module import side-effect blew up") + + broken_mod = _ExplodingModule() + + with _patch_handler_modules({"healthy": healthy_mod, "broken": broken_mod}): + response = client.get("/api/integrations/capabilities") + + assert response.status_code == HTTPStatus.OK + payload = response.get_json() + # Healthy handler still surfaces; broken one is silently skipped. 
+ assert payload["handlers"] == { + "healthy": {"auth_modes": ["bearer"], "operations": ["passthrough"]}, + } + assert payload["bearer_passthrough"] == ["healthy"] + + +# --------------------------------------------------------------------------- +# Direct unit tests for the resolver helper +# --------------------------------------------------------------------------- + + +def test_resolve_auth_modes_prefers_list_over_string(): + from mindsdb.api.http.namespaces.integrations import _resolve_auth_modes + + class H: + _auth_modes = ["bearer", "oauth_client_credentials"] + _auth_mode = "bearer" + + assert _resolve_auth_modes(H) == ["bearer", "oauth_client_credentials"] + + +def test_resolve_auth_modes_falls_back_to_auth_mode_string(): + from mindsdb.api.http.namespaces.integrations import _resolve_auth_modes + + class H: + _auth_mode = "custom" + + assert _resolve_auth_modes(H) == ["custom"] + + +def test_resolve_auth_modes_defaults_to_bearer_when_nothing_declared(): + from mindsdb.api.http.namespaces.integrations import _resolve_auth_modes + + class H: + pass + + assert _resolve_auth_modes(H) == ["bearer"] + + +def test_resolve_auth_modes_handles_none_handler_class(): + from mindsdb.api.http.namespaces.integrations import _resolve_auth_modes + + assert _resolve_auth_modes(None) == ["bearer"] + + +def test_resolve_auth_modes_normalizes_tuple_to_list_of_strings(): + from mindsdb.api.http.namespaces.integrations import _resolve_auth_modes + + class H: + _auth_modes = ("bearer", "oauth_client_credentials") + + result = _resolve_auth_modes(H) + assert result == ["bearer", "oauth_client_credentials"] + assert isinstance(result, list) + + +def test_resolve_auth_modes_skips_empty_list_and_falls_through(): + from mindsdb.api.http.namespaces.integrations import _resolve_auth_modes + + class H: + _auth_modes = [] + _auth_mode = "custom" + + assert _resolve_auth_modes(H) == ["custom"] diff --git a/tests/unit/handlers/test_hubspot.py b/tests/unit/handlers/test_hubspot.py index 
5f5ad08e5f..edbb3732ec 100644 --- a/tests/unit/handlers/test_hubspot.py +++ b/tests/unit/handlers/test_hubspot.py @@ -1196,5 +1196,58 @@ def test_multijoin_query_handling(self): self.assertIn("not supported", response.error_message) +class TestHubspotPassthrough(unittest.TestCase): + """Exercise the PassthroughMixin retrofit (PAT path).""" + + def _mock_response(self, status_code=200): + resp = MagicMock() + resp.status_code = status_code + resp.headers = {"Content-Type": "application/json"} + resp.iter_content = MagicMock(return_value=iter([b'{"results":[]}'])) + resp.close = MagicMock() + return resp + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_passthrough_uses_bearer_and_hubspot_base_url(self, mock_request): + mock_request.return_value = self._mock_response() + handler = HubspotHandler( + "hubspot", + connection_data={"access_token": "pat-abc123xyz"}, + ) + from mindsdb.integrations.libs.passthrough_types import PassthroughRequest + + resp = handler.api_passthrough(PassthroughRequest("GET", "/crm/v3/owners")) + + self.assertEqual(resp.status_code, 200) + args, kwargs = mock_request.call_args + self.assertEqual(args[0], "GET") + self.assertEqual(args[1], "https://api.hubapi.com/crm/v3/owners") + self.assertEqual(kwargs["headers"]["Authorization"], "Bearer pat-abc123xyz") + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_test_passthrough_returns_ok_on_200(self, mock_request): + mock_request.return_value = self._mock_response(status_code=200) + handler = HubspotHandler("hubspot", connection_data={"access_token": "pat"}) + + result = handler.test_passthrough() + + self.assertTrue(result["ok"]) + self.assertEqual(result["status_code"], 200) + self.assertEqual(result["host"], "api.hubapi.com") + self.assertIsInstance(result["latency_ms"], int) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_test_passthrough_returns_auth_failed_on_401(self, mock_request): + 
mock_request.return_value = self._mock_response(status_code=401) + handler = HubspotHandler("hubspot", connection_data={"access_token": "pat"}) + + result = handler.test_passthrough() + + self.assertFalse(result["ok"]) + self.assertEqual(result["error_code"], "auth_failed") + self.assertEqual(result["status_code"], 401) + self.assertEqual(result["host"], "api.hubapi.com") + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/handlers/test_rest_api.py b/tests/unit/handlers/test_rest_api.py new file mode 100644 index 0000000000..c220f03c77 --- /dev/null +++ b/tests/unit/handlers/test_rest_api.py @@ -0,0 +1,1433 @@ +"""Unit tests for the generic REST API passthrough handler.""" + +from unittest.mock import patch, MagicMock + +from mindsdb.integrations.handlers.rest_api_handler import ( + connection_args as exported_connection_args, +) +from mindsdb.integrations.handlers.rest_api_handler.connection_args import connection_args +from mindsdb.integrations.handlers.rest_api_handler.oauth_connection_args import ( + oauth_connection_args, +) +from mindsdb.integrations.handlers.rest_api_handler.rest_connection_args import ( + rest_connection_args, +) +from mindsdb.integrations.handlers.rest_api_handler.rest_api_handler import RestApiHandler +from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE +from mindsdb.integrations.libs.passthrough import PassthroughProtocol +from mindsdb.integrations.libs.passthrough_types import PassthroughRequest, PassthroughResponse +from mindsdb.integrations.libs.response import ( + HandlerStatusResponse as StatusResponse, +) + + +VALID_DATA = { + "base_url": "https://api.example.com", + "bearer_token": "test-token-123", +} + + +def _make_handler(connection_data=None): + if connection_data is None: + connection_data = dict(VALID_DATA) + return RestApiHandler("test_rest", connection_data=connection_data) + + +# --------------------------------------------------------------------------- +# Shared stub 
helpers for OAuth tests +# +# Hoisted to the top so any test class can use them. The helpers monkeypatch +# the OAuth provider's DNS / token-endpoint POST, and the passthrough mixin's +# upstream `requests.request` call. +# --------------------------------------------------------------------------- + + +def _stub_oauth_dns(monkeypatch): + """Stub socket.getaddrinfo so the provider's SSRF check passes for hostnames.""" + from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + client_credentials as cc_module, + ) + + def fake_getaddrinfo(host, *args, **kwargs): + return [(2, 1, 6, "", ("1.2.3.4", 0))] + + monkeypatch.setattr(cc_module.socket, "getaddrinfo", fake_getaddrinfo) + + +def _stub_token_endpoint(monkeypatch, access_token="OAUTH-AT", expires_in=3600): + """Patch the token POST to return a synthetic access token. + + Returns a counter dict with key 'n' incremented on each call, so tests + can assert how many times the IdP was hit. + """ + from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + client_credentials as cc_module, + ) + + calls = {"n": 0} + + def fake_post(url, data=None, headers=None, **_): + calls["n"] += 1 + + class _Resp: + status_code = 200 + is_redirect = False + headers = {} + + def iter_content(self, chunk_size=4096): + import json as _json + + payload = _json.dumps({"access_token": access_token, "expires_in": expires_in}).encode("utf-8") + yield payload + + def close(self): + pass + + return _Resp() + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + return calls + + +def _stub_upstream(monkeypatch): + """Patch the upstream request so we can inspect the outgoing call.""" + captured = {} + + def fake_request(method, url, **kwargs): + captured["method"] = method + captured["url"] = url + captured["headers"] = kwargs.get("headers", {}) + + resp = MagicMock() + resp.status_code = 200 + resp.headers = {} + resp.iter_content.return_value = [b""] + resp.close = MagicMock() + return resp + + 
from mindsdb.integrations.libs import passthrough as pt_module + + monkeypatch.setattr(pt_module.requests, "request", fake_request) + return captured + + +class TestRestApiHandlerInit: + def test_satisfies_passthrough_protocol(self): + assert issubclass(RestApiHandler, PassthroughProtocol) + + def test_stores_connection_data(self): + data = {"base_url": "https://x.com", "bearer_token": "tok"} + handler = _make_handler(data) + assert handler.connection_data == data + + def test_default_test_request_path(self): + handler = _make_handler() + assert handler._test_request.method == "GET" + assert handler._test_request.path == "/" + + def test_custom_test_path(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "test_path": "/health", + } + ) + assert handler._test_request.path == "/health" + + def test_custom_test_path_without_slash(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "test_path": "status", + } + ) + assert handler._test_request.path == "/status" + + +class TestCheckConnection: + def test_success(self): + handler = _make_handler() + response = handler.check_connection() + assert isinstance(response, StatusResponse) + assert response.success is True + assert not response.error_message + + def test_missing_base_url(self): + handler = _make_handler({"bearer_token": "tok"}) + response = handler.check_connection() + assert response.success is False + assert "base_url" in response.error_message + + def test_missing_bearer_token(self): + handler = _make_handler({"base_url": "https://api.example.com"}) + response = handler.check_connection() + assert response.success is False + assert "bearer_token" in response.error_message + + def test_empty_connection_data(self): + handler = _make_handler({}) + response = handler.check_connection() + assert response.success is False + + +class TestPassthroughIntegration: + """Test that the mixin methods work correctly 
on RestApiHandler.""" + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_api_passthrough_injects_bearer(self, mock_request): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "application/json"} + mock_resp.iter_content.return_value = [b'{"ok": true}'] + mock_resp.close = MagicMock() + mock_request.return_value = mock_resp + + handler = _make_handler() + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/v1/users")) + + assert isinstance(result, PassthroughResponse) + assert result.status_code == 200 + headers = mock_request.call_args.kwargs["headers"] + assert headers["Authorization"] == "Bearer test-token-123" + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_api_passthrough_uses_base_url(self, mock_request): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {} + mock_resp.iter_content.return_value = [b""] + mock_resp.close = MagicMock() + mock_request.return_value = mock_resp + + handler = _make_handler() + handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + called_url = mock_request.call_args.args[1] + assert called_url == "https://api.example.com/foo" + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_api_passthrough_includes_default_headers(self, mock_request): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {} + mock_resp.iter_content.return_value = [b""] + mock_resp.close = MagicMock() + mock_request.return_value = mock_resp + + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "default_headers": {"Accept": "application/json", "X-Team": "data"}, + } + ) + handler.api_passthrough(PassthroughRequest(method="GET", path="/")) + + headers = mock_request.call_args.kwargs["headers"] + assert headers["Accept"] == "application/json" + assert headers["X-Team"] == "data" + + 
@patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_test_passthrough_success(self, mock_request): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "application/json"} + mock_resp.iter_content.return_value = [b'{"ok": true}'] + mock_resp.close = MagicMock() + mock_request.return_value = mock_resp + + handler = _make_handler() + result = handler.test_passthrough() + + assert isinstance(result, dict) + assert result["ok"] is True + assert result["status_code"] == 200 + + def test_test_passthrough_with_no_network(self): + """test_passthrough catches connection errors gracefully.""" + handler = _make_handler() + result = handler.test_passthrough() + assert isinstance(result, dict) + assert result["ok"] is False + assert result["error_code"] in ("network", "unknown") + + +class TestConnectionArgsSchema: + """The exported connection_args is the union of REST + auth modules.""" + + def test_rest_module_only_holds_passthrough_fields(self): + assert set(rest_connection_args.keys()) == { + "base_url", + "default_headers", + "allowed_hosts", + "test_path", + } + + def test_auth_module_only_holds_auth_fields(self): + assert set(oauth_connection_args.keys()) == { + "auth_type", + "bearer_token", + "token_url", + "client_id", + "client_secret", + "scope", + "audience", + "token_auth_method", + } + + def test_bearer_token_lives_in_auth_module(self): + # bearer_token is an auth strategy, not a REST/passthrough setting. 
+ assert "bearer_token" not in rest_connection_args + assert "bearer_token" in oauth_connection_args + + def test_aggregated_connection_args_includes_rest_fields(self): + for key in ("base_url", "default_headers", "allowed_hosts", "test_path"): + assert key in connection_args + + def test_aggregated_connection_args_includes_bearer_token(self): + assert "bearer_token" in connection_args + + def test_aggregated_connection_args_includes_oauth_fields(self): + for key in ( + "auth_type", + "token_url", + "client_id", + "client_secret", + "scope", + "audience", + "token_auth_method", + ): + assert key in connection_args + + def test_client_secret_marked_secret_and_pwd(self): + spec = connection_args["client_secret"] + assert spec["type"] == ARG_TYPE.PWD + assert spec.get("secret") is True + + def test_bearer_token_marked_secret_and_pwd(self): + spec = connection_args["bearer_token"] + assert spec["type"] == ARG_TYPE.PWD + assert spec.get("secret") is True + + def test_package_exports_aggregated_args(self): + # The handler package re-exports connection_args; make sure the + # aggregator and the package-level export are the same object. 
+ assert exported_connection_args is connection_args + + +class TestBackwardCompatibleBearerInit: + """Existing bearer-only configs must still initialize and validate.""" + + def test_legacy_config_initializes(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "legacy-token", + } + ) + assert handler.connection_data["base_url"] == "https://api.example.com" + assert handler.connection_data["bearer_token"] == "legacy-token" + + def test_legacy_config_check_connection_succeeds(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "legacy-token", + } + ) + response = handler.check_connection() + assert response.success is True + + +class TestAuthTypeResolution: + def test_default_auth_type_is_bearer(self): + handler = _make_handler() + assert handler._get_auth_type() == "bearer" + + def test_empty_auth_type_treated_as_default(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "auth_type": "", + } + ) + assert handler._get_auth_type() == "bearer" + + def test_explicit_auth_type_returned(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "auth_type": "oauth_client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "client_secret": "csecret", + } + ) + assert handler._get_auth_type() == "oauth_client_credentials" + + +class TestBearerAuthValidation: + def test_implicit_bearer_with_token_passes(self): + handler = _make_handler({"base_url": "https://api.example.com", "bearer_token": "tok"}) + handler._validate_auth_config() # should not raise + + def test_explicit_bearer_with_token_passes(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "auth_type": "bearer", + } + ) + handler._validate_auth_config() + + def test_missing_bearer_token_fails(self): + handler = _make_handler({"base_url": "https://api.example.com", 
"auth_type": "bearer"}) + try: + handler._validate_auth_config() + except ValueError as e: + assert "bearer_token" in str(e) + else: + raise AssertionError("expected ValueError for missing bearer_token") + + def test_check_connection_missing_bearer_token_message(self): + handler = _make_handler({"base_url": "https://api.example.com", "auth_type": "bearer"}) + response = handler.check_connection() + assert response.success is False + assert "bearer_token" in response.error_message + + def test_bearer_rejects_token_url(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "token_url": "https://auth.example.com/token", + } + ) + try: + handler._validate_auth_config() + except ValueError as e: + assert "token_url" in str(e) + else: + raise AssertionError("expected ValueError for token_url in bearer mode") + + def test_bearer_rejects_client_id(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "client_id": "cid", + } + ) + try: + handler._validate_auth_config() + except ValueError as e: + assert "client_id" in str(e) + else: + raise AssertionError("expected ValueError for client_id in bearer mode") + + def test_bearer_rejects_client_secret(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "client_secret": "secret", + } + ) + try: + handler._validate_auth_config() + except ValueError as e: + assert "client_secret" in str(e) + else: + raise AssertionError("expected ValueError for client_secret in bearer mode") + + def test_bearer_rejects_scope(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "scope": "read:all", + } + ) + try: + handler._validate_auth_config() + except ValueError as e: + assert "scope" in str(e) + else: + raise AssertionError("expected ValueError for scope in bearer mode") + + def test_bearer_rejects_audience(self): + handler = 
_make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "audience": "https://api.example.com", + } + ) + try: + handler._validate_auth_config() + except ValueError as e: + assert "audience" in str(e) + else: + raise AssertionError("expected ValueError for audience in bearer mode") + + def test_bearer_rejects_token_auth_method(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "token_auth_method": "client_secret_post", + } + ) + try: + handler._validate_auth_config() + except ValueError as e: + assert "token_auth_method" in str(e) + else: + raise AssertionError("expected ValueError for token_auth_method in bearer mode") + + def test_bearer_ignores_empty_oauth_fields(self): + # UIs may submit empty strings for unfilled fields; treat as absent. + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": "tok", + "token_url": "", + "client_id": "", + "client_secret": "", + } + ) + handler._validate_auth_config() + + +class TestOAuthClientCredentialsValidation: + BASE = { + "base_url": "https://api.example.com", + "auth_type": "oauth_client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "client_secret": "csecret", + } + + def test_minimal_valid_config_passes(self): + handler = _make_handler(dict(self.BASE)) + handler._validate_auth_config() + + def test_missing_token_url_fails(self): + data = dict(self.BASE) + del data["token_url"] + handler = _make_handler(data) + try: + handler._validate_auth_config() + except ValueError as e: + assert "token_url" in str(e) + else: + raise AssertionError("expected ValueError for missing token_url") + + def test_missing_client_id_fails(self): + data = dict(self.BASE) + del data["client_id"] + handler = _make_handler(data) + try: + handler._validate_auth_config() + except ValueError as e: + assert "client_id" in str(e) + else: + raise AssertionError("expected ValueError for missing 
client_id") + + def test_missing_client_secret_fails(self): + data = dict(self.BASE) + del data["client_secret"] + handler = _make_handler(data) + try: + handler._validate_auth_config() + except ValueError as e: + assert "client_secret" in str(e) + else: + raise AssertionError("expected ValueError for missing client_secret") + + def test_oauth_rejects_bearer_token(self): + data = dict(self.BASE) + data["bearer_token"] = "tok" + handler = _make_handler(data) + try: + handler._validate_auth_config() + except ValueError as e: + assert "bearer_token" in str(e) + else: + raise AssertionError("expected ValueError for bearer_token in OAuth mode") + + def test_oauth_default_token_auth_method_passes(self): + # token_auth_method omitted → defaults to client_secret_post. + handler = _make_handler(dict(self.BASE)) + handler._validate_auth_config() + + def test_oauth_explicit_client_secret_basic_passes(self): + data = dict(self.BASE) + data["token_auth_method"] = "client_secret_basic" + handler = _make_handler(data) + handler._validate_auth_config() + + def test_oauth_unsupported_token_auth_method_fails(self): + data = dict(self.BASE) + data["token_auth_method"] = "private_key_jwt" + handler = _make_handler(data) + try: + handler._validate_auth_config() + except ValueError as e: + assert "token_auth_method" in str(e) + assert "private_key_jwt" in str(e) + else: + raise AssertionError("expected ValueError for unsupported token_auth_method") + + def test_oauth_check_connection_succeeds(self, monkeypatch): + # OAuth check_connection now actually fetches a token and runs the + # upstream test request, so we need to stub DNS, the token endpoint, + # and the upstream HTTP call. 
+ _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.BASE)) + response = handler.check_connection() + assert response.success is True + + +class TestUnsupportedAuthType: + def test_unsupported_auth_type_fails(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "auth_type": "api_key", + "bearer_token": "tok", + } + ) + try: + handler._validate_auth_config() + except ValueError as e: + assert "auth_type" in str(e) + assert "api_key" in str(e) + else: + raise AssertionError("expected ValueError for unsupported auth_type") + + def test_unsupported_auth_type_via_check_connection(self): + handler = _make_handler( + { + "base_url": "https://api.example.com", + "auth_type": "saml", + } + ) + response = handler.check_connection() + assert response.success is False + assert "auth_type" in response.error_message + + +# --------------------------------------------------------------------------- +# OAuth client credentials integration +# --------------------------------------------------------------------------- + + +class TestOAuthIntegration: + OAUTH_CONFIG = { + "base_url": "https://api.example.com", + "auth_type": "oauth_client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "client_secret": "csec", + } + + def test_oauth_mode_fetches_token_via_provider(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + token_calls = _stub_token_endpoint(monkeypatch) + captured = _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + handler.api_passthrough(PassthroughRequest(method="GET", path="/v1/users")) + + assert token_calls["n"] == 1 + assert handler._oauth_provider is not None + assert captured["headers"]["Authorization"] == "Bearer OAUTH-AT" + + def test_oauth_token_cached_across_calls(self, monkeypatch): + # Two passthrough calls, one token fetch — proves we go through the + # provider's cache rather 
than the static-token path. + _stub_oauth_dns(monkeypatch) + token_calls = _stub_token_endpoint(monkeypatch) + _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + handler.api_passthrough(PassthroughRequest(method="GET", path="/v1/users")) + handler.api_passthrough(PassthroughRequest(method="GET", path="/v1/orgs")) + + assert token_calls["n"] == 1 + + def test_oauth_caller_authorization_cannot_override_generated_auth(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + captured = _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + handler.api_passthrough( + PassthroughRequest( + method="GET", + path="/v1/users", + headers={"Authorization": "Bearer attacker-token"}, + ) + ) + + assert captured["headers"]["Authorization"] == "Bearer OAUTH-AT" + assert "attacker-token" not in captured["headers"]["Authorization"] + + def test_oauth_caller_authorization_lowercase_also_rejected(self, monkeypatch): + # FORBIDDEN_REQUEST_HEADERS check is case-insensitive; verify a + # lowercase header from the caller still loses to the generated auth. + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + captured = _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + handler.api_passthrough( + PassthroughRequest( + method="GET", + path="/foo", + headers={"authorization": "Bearer attacker-token"}, + ) + ) + + # The mixin always writes the canonical-cased "Authorization" header. + assert captured["headers"]["Authorization"] == "Bearer OAUTH-AT" + # The caller-supplied lowercase variant should not have leaked through. + # If anything's in headers under "authorization" lowercase, it must + # not be the attacker token. 
+ if "authorization" in captured["headers"]: + assert "attacker-token" not in captured["headers"]["authorization"] + + def test_oauth_passthrough_works_without_caller_auth_headers(self, monkeypatch): + # Minimal call: no headers at all on the PassthroughRequest. + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + captured = _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + handler.api_passthrough(PassthroughRequest(method="GET", path="/v1/users")) + + assert captured["headers"]["Authorization"] == "Bearer OAUTH-AT" + + def test_bearer_mode_does_not_instantiate_oauth_provider(self, monkeypatch): + _stub_upstream(monkeypatch) + + handler = _make_handler()  # bearer mode (default) + handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert handler._oauth_provider is None + + def test_implicit_bearer_mode_still_injects_static_token(self, monkeypatch): + # Config without auth_type → defaults to bearer → static token used.
+ captured = _stub_upstream(monkeypatch) + + handler = _make_handler({"base_url": "https://api.example.com", "bearer_token": "static-tok"}) + handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert captured["headers"]["Authorization"] == "Bearer static-tok" + assert handler._oauth_provider is None + + def test_explicit_bearer_mode_still_injects_static_token(self, monkeypatch): + captured = _stub_upstream(monkeypatch) + + handler = _make_handler( + { + "base_url": "https://api.example.com", + "auth_type": "bearer", + "bearer_token": "static-tok", + } + ) + handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert captured["headers"]["Authorization"] == "Bearer static-tok" + assert handler._oauth_provider is None + + def test_bearer_caller_authorization_cannot_override_generated_auth(self, monkeypatch): + captured = _stub_upstream(monkeypatch) + + handler = _make_handler({"base_url": "https://api.example.com", "bearer_token": "static-tok"}) + handler.api_passthrough( + PassthroughRequest( + method="GET", + path="/foo", + headers={"Authorization": "Bearer attacker-token"}, + ) + ) + + assert captured["headers"]["Authorization"] == "Bearer static-tok" + assert "attacker-token" not in captured["headers"]["Authorization"] + + def test_oauth_storage_key_includes_handler_name(self, monkeypatch): + # Two providers built for two different handler names must use + # distinct storage keys, so token caches don't collide on a shared + # handler_storage instance. 
+ _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + _stub_upstream(monkeypatch) + + h1 = RestApiHandler("ds_one", connection_data=dict(self.OAUTH_CONFIG)) + h2 = RestApiHandler("ds_two", connection_data=dict(self.OAUTH_CONFIG)) + h1.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + h2.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert h1._oauth_provider.storage_key != h2._oauth_provider.storage_key + assert "ds_one" in h1._oauth_provider.storage_key + assert "ds_two" in h2._oauth_provider.storage_key + + def test_oauth_provider_receives_handler_storage(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + _stub_upstream(monkeypatch) + + sentinel_storage = MagicMock() + sentinel_storage.encrypted_json_get.return_value = None + handler = RestApiHandler( + "test_rest", + connection_data=dict(self.OAUTH_CONFIG), + handler_storage=sentinel_storage, + ) + handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert handler._oauth_provider.handler_storage is sentinel_storage + + +# --------------------------------------------------------------------------- +# Response scrubbing +# --------------------------------------------------------------------------- + + +REDACTED = "[REDACTED_API_KEY]" + + +def _stub_upstream_with_body(monkeypatch, body_bytes, content_type="text/plain"): + """Patch the upstream so its response body contains caller-controlled bytes. + + Returns the captured-headers dict (populated on each call) so tests can + confirm what was sent in addition to what came back. + """ + captured = {} + + def fake_request(method, url, **kwargs): + captured["headers"] = kwargs.get("headers", {}) + resp = MagicMock() + resp.status_code = 200 + resp.headers = {"Content-Type": content_type} + # `body_bytes` may be a callable so each request can return different + # bytes (e.g., echoing whatever token was sent). 
+ payload = body_bytes(captured) if callable(body_bytes) else body_bytes + resp.iter_content.return_value = [payload] + resp.close = MagicMock() + return resp + + from mindsdb.integrations.libs import passthrough as pt_module + + monkeypatch.setattr(pt_module.requests, "request", fake_request) + return captured + + +class TestResponseScrubbing: + """Tokens — static or rotating — must never reach the runtime caller.""" + + OAUTH_CONFIG = { + "base_url": "https://api.example.com", + "auth_type": "oauth_client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "client_secret": "csec", + } + + def test_static_bearer_scrubbed_when_upstream_echoes_it(self, monkeypatch): + _stub_upstream_with_body(monkeypatch, b"upstream said: test-token-123 hi") + + handler = _make_handler() + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert "test-token-123" not in str(result.body) + assert REDACTED in str(result.body) + + def test_oauth_token_scrubbed_when_upstream_echoes_it(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch, access_token="OAUTH-AT") + _stub_upstream_with_body(monkeypatch, b"hi OAUTH-AT here is your data") + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert "OAUTH-AT" not in str(result.body) + assert REDACTED in str(result.body) + + def test_runtime_caller_never_receives_oauth_access_token(self, monkeypatch): + # Belt-and-suspenders for the spec's "runtime caller never receives + # OAuth access token if upstream echoes it" — exercises the full + # api_passthrough path including JSON decoding. 
+ _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch, access_token="SECRET-AT-9000") + _stub_upstream_with_body( + monkeypatch, + b'{"echoed_token": "SECRET-AT-9000", "ok": true}', + content_type="application/json", + ) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + # Body parsed as JSON; serialize it back to verify scrub holds across + # the JSON round-trip. + import json as _json + + rendered = _json.dumps(result.body) + assert "SECRET-AT-9000" not in rendered + assert REDACTED in rendered + + def test_rotated_oauth_token_scrubbed(self, monkeypatch): + # Force the provider to issue a different token after invalidation + # (the manual rotation path). The scrub list should track the + # currently-cached token, not stale ones. + _stub_oauth_dns(monkeypatch) + + from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + client_credentials as cc_module, + ) + + token_iter = iter(["TOKEN-A", "TOKEN-B"]) + + def fake_post(*a, **kw): + class _R: + status_code = 200 + is_redirect = False + headers = {} + + def iter_content(self, chunk_size=4096): + import json as _j + + yield _j.dumps({"access_token": next(token_iter), "expires_in": 3600}).encode() + + def close(self): + pass + + return _R() + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + + # Upstream echoes whichever token was sent. + def body_for(captured): + sent = captured["headers"]["Authorization"].replace("Bearer ", "") + return f"echo {sent}".encode() + + _stub_upstream_with_body(monkeypatch, body_for) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + + # First call — uses TOKEN-A. + r1 = handler.api_passthrough(PassthroughRequest(method="GET", path="/a")) + assert "TOKEN-A" not in str(r1.body) + assert REDACTED in str(r1.body) + + # Rotate. + handler._oauth_provider.clear_cached_token() + + # Second call — uses TOKEN-B; scrub list reflects the new token. 
+ r2 = handler.api_passthrough(PassthroughRequest(method="GET", path="/b")) + assert "TOKEN-B" not in str(r2.body) + assert REDACTED in str(r2.body) + + def test_empty_secrets_dropped_from_scrub_list(self): + # bearer_token is the empty string (e.g. UI submitted blank); the + # scrub list must not contain "" because str.replace("", X) inserts + # the sentinel between every character of the response body. + handler = _make_handler({"base_url": "https://api.example.com", "bearer_token": ""}) + secrets = handler._secrets_for_scrub() + assert "" not in secrets + + def test_duplicate_secrets_deduplicated(self): + # bearer_token equals a default_headers value — both would otherwise + # be appended to the scrub list. The override dedupes. + long_value = "shared-long-value-1234567890" # ≥ 16 chars + handler = _make_handler( + { + "base_url": "https://api.example.com", + "bearer_token": long_value, + "default_headers": {"X-Token": long_value}, + } + ) + secrets = handler._secrets_for_scrub() + assert secrets.count(long_value) == 1 + + def test_oauth_current_secrets_does_not_trigger_token_fetch(self, monkeypatch): + # _secrets_for_scrub must not POST to the IdP just to populate its + # list. current_secrets() returns [] when uncached. + _stub_oauth_dns(monkeypatch) + token_calls = _stub_token_endpoint(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + handler._maybe_init_oauth_provider() + + secrets = handler._secrets_for_scrub() + + assert token_calls["n"] == 0 + assert secrets == [] + + def test_bearer_mode_scrub_list_unchanged_with_default_headers(self): + # Pre-existing bearer behavior: the bearer token plus any + # default_headers values >= 16 chars. Order doesn't matter for + # correctness, but membership does. 
+        long_value = "x" * 32
+        short_value = "shortie"
+        handler = _make_handler(
+            {
+                "base_url": "https://api.example.com",
+                "bearer_token": "tok-12345",
+                "default_headers": {"X-Long": long_value, "X-Short": short_value},
+            }
+        )
+        secrets = handler._secrets_for_scrub()
+        assert "tok-12345" in secrets
+        assert long_value in secrets
+        assert short_value not in secrets  # too short to be treated as secret
+
+    def test_oauth_mode_does_not_include_static_bearer_token(self, monkeypatch):
+        # In OAuth mode, even if a stale bearer_token field somehow survived
+        # in connection_data (it shouldn't — _validate_auth_config rejects
+        # it), the scrub override skips reading it. Belt-and-suspenders to
+        # ensure the static field doesn't sneak into the scrub list.
+        _stub_oauth_dns(monkeypatch)
+        _stub_token_endpoint(monkeypatch, access_token="LIVE-AT")
+        _stub_upstream(monkeypatch)
+
+        cfg = {**self.OAUTH_CONFIG, "bearer_token": "stale-static-token-123"}
+        # Bypass validation to construct the test scenario.
+        handler = RestApiHandler("test_rest", connection_data=cfg)
+        handler.api_passthrough(PassthroughRequest(method="GET", path="/foo"))
+
+        secrets = handler._secrets_for_scrub()
+        assert "LIVE-AT" in secrets
+        assert "stale-static-token-123" not in secrets
+
+    def test_oauth_mode_scrub_does_not_leak_client_secret(self, monkeypatch):
+        # client_secret must never appear in the scrub list (it shouldn't be
+        # in responses either, but if it ever were, we don't want a redaction
+        # path that accidentally implies it's tracked as a "current secret").
+ _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch, access_token="LIVE-AT") + _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + secrets = handler._secrets_for_scrub() + assert "csec" not in secrets + assert "cid" not in secrets + + +# --------------------------------------------------------------------------- +# OAuth 401 retry +# --------------------------------------------------------------------------- + + +def _make_response(status_code=200, body=b"", headers=None): + """Build a minimal MagicMock standing in for requests.Response.""" + resp = MagicMock() + resp.status_code = status_code + resp.headers = headers or {} + resp.iter_content.return_value = [body] + resp.close = MagicMock() + return resp + + +def _stub_upstream_with_status_sequence(monkeypatch, statuses, body_for_status=None): + """Patch upstream to return a different status code on each successive call. + + `statuses` is a list of integers; each call pops the next one. After the + list is exhausted, the last status is repeated. + `body_for_status` (optional callable) returns bytes given (call_idx, status, + captured_headers) so a test can shape the body per call. 
+ """ + state = {"calls": [], "idx": 0} + + def fake_request(method, url, **kwargs): + idx = state["idx"] + status = statuses[min(idx, len(statuses) - 1)] + state["idx"] += 1 + captured = { + "method": method, + "url": url, + "headers": dict(kwargs.get("headers", {})), + } + state["calls"].append(captured) + if body_for_status is not None: + body = body_for_status(idx, status, captured["headers"]) + else: + body = b"" + return _make_response(status_code=status, body=body) + + from mindsdb.integrations.libs import passthrough as pt_module + + monkeypatch.setattr(pt_module.requests, "request", fake_request) + return state + + +class TestOAuthRetryOn401: + OAUTH_CONFIG = { + "base_url": "https://api.example.com", + "auth_type": "oauth_client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "client_secret": "csec", + } + + def test_401_clears_cached_token_then_refetches(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + + # Token endpoint returns a different token on each call so we can + # observe the cache having been cleared. + from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + client_credentials as cc_module, + ) + + token_iter = iter(["TOKEN-A", "TOKEN-B"]) + token_calls = {"n": 0} + + def fake_post(*a, **kw): + token_calls["n"] += 1 + + class _R: + status_code = 200 + is_redirect = False + headers = {} + + def iter_content(self, chunk_size=4096): + import json as _j + + yield _j.dumps({"access_token": next(token_iter), "expires_in": 3600}).encode() + + def close(self): + pass + + return _R() + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + upstream = _stub_upstream_with_status_sequence(monkeypatch, [401, 200]) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + # Two upstream calls (401 then 200) and two token POSTs (initial fetch + # then post-clear refetch). 
+ assert len(upstream["calls"]) == 2 + assert token_calls["n"] == 2 + assert result.status_code == 200 + + def test_retry_uses_new_token(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + + from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + client_credentials as cc_module, + ) + + token_iter = iter(["TOKEN-A", "TOKEN-B"]) + + def fake_post(*a, **kw): + class _R: + status_code = 200 + is_redirect = False + headers = {} + + def iter_content(self, chunk_size=4096): + import json as _j + + yield _j.dumps({"access_token": next(token_iter), "expires_in": 3600}).encode() + + def close(self): + pass + + return _R() + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + upstream = _stub_upstream_with_status_sequence(monkeypatch, [401, 200]) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert upstream["calls"][0]["headers"]["Authorization"] == "Bearer TOKEN-A" + assert upstream["calls"][1]["headers"]["Authorization"] == "Bearer TOKEN-B" + + def test_retried_response_scrubs_new_token(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + + from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + client_credentials as cc_module, + ) + + token_iter = iter(["TOKEN-A", "TOKEN-B"]) + + def fake_post(*a, **kw): + class _R: + status_code = 200 + is_redirect = False + headers = {} + + def iter_content(self, chunk_size=4096): + import json as _j + + yield _j.dumps({"access_token": next(token_iter), "expires_in": 3600}).encode() + + def close(self): + pass + + return _R() + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + + # First call: 401 with body echoing TOKEN-A. Second call: 200 with + # body echoing TOKEN-B. Final response should redact TOKEN-B (the + # "new" token), since current_secrets() reflects the post-retry cache. 
+ def body_for(idx, status, headers): + sent = headers["Authorization"].replace("Bearer ", "") + return f"echoed: {sent}".encode() + + upstream = _stub_upstream_with_status_sequence(monkeypatch, [401, 200], body_for_status=body_for) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert result.status_code == 200 + assert "TOKEN-B" not in str(result.body) + assert "[REDACTED_API_KEY]" in str(result.body) + assert len(upstream["calls"]) == 2 + + def test_repeated_401_does_not_loop(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + token_calls = _stub_token_endpoint(monkeypatch) + # Upstream returns 401 every time. + upstream = _stub_upstream_with_status_sequence(monkeypatch, [401]) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + # Exactly one retry, so two upstream calls and two token POSTs. + # The second 401 is returned to the caller as-is. + assert len(upstream["calls"]) == 2 + assert token_calls["n"] == 2 + assert result.status_code == 401 + + def test_bearer_mode_does_not_retry_on_401(self, monkeypatch): + # Bearer mode: a 401 from the upstream is returned as-is. No second + # request, no retry logic. + upstream = _stub_upstream_with_status_sequence(monkeypatch, [401]) + + handler = _make_handler() # default bearer + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + assert len(upstream["calls"]) == 1 + assert result.status_code == 401 + + def test_non_401_status_does_not_retry(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + upstream = _stub_upstream_with_status_sequence(monkeypatch, [403]) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + result = handler.api_passthrough(PassthroughRequest(method="GET", path="/foo")) + + # 403 (or any non-401) does not trigger the retry path. 
+ assert len(upstream["calls"]) == 1 + assert result.status_code == 403 + + +# --------------------------------------------------------------------------- +# OAuth check_connection +# --------------------------------------------------------------------------- + + +class TestOAuthCheckConnection: + OAUTH_CONFIG = { + "base_url": "https://api.example.com", + "auth_type": "oauth_client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "client_secret": "csec", + } + + def test_oauth_check_connection_fetches_token(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + token_calls = _stub_token_endpoint(monkeypatch) + _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + response = handler.check_connection() + + assert response.success is True + assert token_calls["n"] == 1 + + def test_oauth_check_connection_calls_test_path(self, monkeypatch): + # Default test_path is "/", so the upstream sanity request goes to + # base_url + "/". + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + captured = _stub_upstream(monkeypatch) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + response = handler.check_connection() + + assert response.success is True + assert captured["url"].startswith("https://api.example.com") + assert captured["url"].endswith("/") + + def test_oauth_check_connection_calls_custom_test_path(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + captured = _stub_upstream(monkeypatch) + + cfg = dict(self.OAUTH_CONFIG) + cfg["test_path"] = "/healthz" + handler = _make_handler(cfg) + response = handler.check_connection() + + assert response.success is True + assert captured["url"] == "https://api.example.com/healthz" + + def test_oauth_check_connection_token_endpoint_failure_is_safe(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + # Token endpoint returns 401. 
+ from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + client_credentials as cc_module, + ) + + def fake_post(*a, **kw): + class _R: + status_code = 401 + is_redirect = False + headers = {} + + def iter_content(self, chunk_size=4096): + import json as _j + + yield _j.dumps({"error": "invalid_client"}).encode() + + def close(self): + pass + + return _R() + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + response = handler.check_connection() + + assert response.success is False + # client_secret must not leak into the error message. + assert "csec" not in (response.error_message or "") + # The upstream of the IdP error should be reflected back to the caller + # in some form (status code reference is fine). + assert response.error_message # non-empty + + def test_oauth_check_connection_upstream_test_failure_is_safe(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch, access_token="OAUTH-AT") + # Upstream returns 500 — token was fetched OK but the upstream itself + # rejected the request. + _stub_upstream_with_status_sequence(monkeypatch, [500]) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + response = handler.check_connection() + + assert response.success is False + assert "OAUTH-AT" not in (response.error_message or "") + assert "csec" not in (response.error_message or "") + + def test_oauth_check_connection_401_then_success_via_retry(self, monkeypatch): + # The upstream test goes through api_passthrough, so the 401-retry + # applies during check_connection too: a 401 + subsequent 200 makes + # check_connection succeed. 
+ _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch) + upstream = _stub_upstream_with_status_sequence(monkeypatch, [401, 200]) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + response = handler.check_connection() + + assert response.success is True + assert len(upstream["calls"]) == 2 + + def test_oauth_check_connection_persistent_401_fails_safely(self, monkeypatch): + _stub_oauth_dns(monkeypatch) + _stub_token_endpoint(monkeypatch, access_token="OAUTH-AT") + _stub_upstream_with_status_sequence(monkeypatch, [401]) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + response = handler.check_connection() + + assert response.success is False + # test_passthrough's 401 message is "upstream rejected credentials..." + # which is safe — it doesn't include the token or client_secret. + assert "OAUTH-AT" not in (response.error_message or "") + assert "csec" not in (response.error_message or "") + + def test_bearer_check_connection_does_not_make_http_calls(self, monkeypatch): + # Bearer mode keeps schema-only check_connection. If anything reaches + # the network, this test will catch it. + called = {"n": 0} + + def boom(*a, **kw): + called["n"] += 1 + raise AssertionError("bearer check_connection must not hit the network") + + from mindsdb.integrations.libs import passthrough as pt_module + + monkeypatch.setattr(pt_module.requests, "request", boom) + + handler = _make_handler() + response = handler.check_connection() + + assert response.success is True + assert called["n"] == 0 + + def test_bearer_check_connection_missing_token_unchanged(self): + # Pre-existing bearer error path is preserved. 
+ handler = _make_handler({"base_url": "https://api.example.com"}) + response = handler.check_connection() + assert response.success is False + assert "bearer_token" in response.error_message + + def test_oauth_check_connection_no_secrets_in_error(self, monkeypatch): + # Provoke a connection error at the IdP transport layer and check + # that neither client_secret nor anything resembling an Authorization + # header value leaks into response.error_message. + _stub_oauth_dns(monkeypatch) + + from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + client_credentials as cc_module, + ) + import requests as _requests + + def boom(*a, **kw): + raise _requests.ConnectionError( + f"transport failure carrying client_secret={self.OAUTH_CONFIG['client_secret']}" + ) + + monkeypatch.setattr(cc_module.requests, "post", boom) + + handler = _make_handler(dict(self.OAUTH_CONFIG)) + response = handler.check_connection() + + assert response.success is False + # The provider's _sanitize_exception strips client_secret from + # transport-error chains; verify the final surface is clean. + msg = response.error_message or "" + assert self.OAUTH_CONFIG["client_secret"] not in msg + assert "Bearer " not in msg + + def test_oauth_check_connection_unsupported_auth_type_unchanged(self): + # Pre-existing path — unsupported auth_type still errors out via the + # schema validator without touching the network. 
+ handler = _make_handler({"base_url": "https://api.example.com", "auth_type": "saml"}) + response = handler.check_connection() + assert response.success is False + assert "auth_type" in response.error_message diff --git a/tests/unit/handlers/test_salesforce.py b/tests/unit/handlers/test_salesforce.py index 62be61a18f..6df5580ba2 100644 --- a/tests/unit/handlers/test_salesforce.py +++ b/tests/unit/handlers/test_salesforce.py @@ -680,5 +680,64 @@ def test_meta_get_columns_builds_schema(self): self.assertEqual(columns[0]["data_type"], "string") +class TestSalesforcePassthrough(unittest.TestCase): + """Exercise the PassthroughMixin retrofit (per-instance base URL).""" + + CONNECTION_DATA = { + "username": "u", + "password": "p", + "client_id": "cid", + "client_secret": "csec", + "access_token": "sf_access_tok", + "instance_url": "https://my-org.my.salesforce.com", + } + + def _mock_response(self, status_code=200): + resp = MagicMock() + resp.status_code = status_code + resp.headers = {"Content-Type": "application/json"} + resp.iter_content = MagicMock(return_value=iter([b'{"sobjects":[]}'])) + resp.close = MagicMock() + return resp + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_passthrough_uses_bearer_and_instance_url(self, mock_request): + mock_request.return_value = self._mock_response() + handler = SalesforceHandler("salesforce", connection_data=self.CONNECTION_DATA) + from mindsdb.integrations.libs.passthrough_types import PassthroughRequest + + resp = handler.api_passthrough(PassthroughRequest("GET", "/services/data/v60.0/")) + + self.assertEqual(resp.status_code, 200) + args, kwargs = mock_request.call_args + self.assertEqual(args[0], "GET") + self.assertEqual(args[1], "https://my-org.my.salesforce.com/services/data/v60.0/") + self.assertEqual(kwargs["headers"]["Authorization"], "Bearer sf_access_tok") + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_test_passthrough_returns_ok_on_200(self, 
mock_request): + mock_request.return_value = self._mock_response(status_code=200) + handler = SalesforceHandler("salesforce", connection_data=self.CONNECTION_DATA) + + result = handler.test_passthrough() + + self.assertTrue(result["ok"]) + self.assertEqual(result["status_code"], 200) + self.assertEqual(result["host"], "my-org.my.salesforce.com") + self.assertIsInstance(result["latency_ms"], int) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_test_passthrough_returns_auth_failed_on_401(self, mock_request): + mock_request.return_value = self._mock_response(status_code=401) + handler = SalesforceHandler("salesforce", connection_data=self.CONNECTION_DATA) + + result = handler.test_passthrough() + + self.assertFalse(result["ok"]) + self.assertEqual(result["error_code"], "auth_failed") + self.assertEqual(result["status_code"], 401) + self.assertEqual(result["host"], "my-org.my.salesforce.com") + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/handlers/test_shopify_handler.py b/tests/unit/handlers/test_shopify_handler.py index 277ec09a63..b1ecc6e881 100644 --- a/tests/unit/handlers/test_shopify_handler.py +++ b/tests/unit/handlers/test_shopify_handler.py @@ -817,5 +817,67 @@ def test_limit_large_than_max_page_limit(self, mock_shopify_query): self.assertEqual(len(result), 300) +class TestShopifyPassthrough(unittest.TestCase): + """Exercise the PassthroughMixin retrofit (X-Shopify-Access-Token auth).""" + + CONNECTION_DATA = { + "shop_url": "test-shop.myshopify.com", + "client_id": "cid", + "client_secret": "csec", + "access_token": "shpat_tokenvalue", + } + + def _mock_response(self, status_code=200): + resp = MagicMock() + resp.status_code = status_code + resp.headers = {"Content-Type": "application/json"} + resp.iter_content = MagicMock(return_value=iter([b'{"shop":{"id":1}}'])) + resp.close = MagicMock() + return resp + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def 
test_passthrough_uses_shopify_header_and_per_shop_base_url(self, mock_request): + mock_request.return_value = self._mock_response() + handler = ShopifyHandler("shopify", connection_data=self.CONNECTION_DATA) + from mindsdb.integrations.libs.passthrough_types import PassthroughRequest + + resp = handler.api_passthrough(PassthroughRequest("GET", "/admin/api/2024-01/shop.json")) + + self.assertEqual(resp.status_code, 200) + args, kwargs = mock_request.call_args + self.assertEqual(args[0], "GET") + self.assertEqual(args[1], "https://test-shop.myshopify.com/admin/api/2024-01/shop.json") + # Custom Shopify auth header; no bearer Authorization. + self.assertEqual(kwargs["headers"]["X-Shopify-Access-Token"], "shpat_tokenvalue") + self.assertNotIn("Authorization", kwargs["headers"]) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_test_passthrough_returns_ok_on_200(self, mock_request): + mock_request.return_value = self._mock_response(status_code=200) + handler = ShopifyHandler("shopify", connection_data=self.CONNECTION_DATA) + + result = handler.test_passthrough() + + self.assertTrue(result["ok"]) + self.assertEqual(result["status_code"], 200) + self.assertEqual(result["host"], "test-shop.myshopify.com") + self.assertIsInstance(result["latency_ms"], int) + # The probe should hit the version-less endpoint so it survives + # Shopify's quarterly Admin API version retirements. 
+ self.assertEqual(mock_request.call_args[0][1], "https://test-shop.myshopify.com/admin/shop.json") + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_test_passthrough_returns_auth_failed_on_401(self, mock_request): + mock_request.return_value = self._mock_response(status_code=401) + handler = ShopifyHandler("shopify", connection_data=self.CONNECTION_DATA) + + result = handler.test_passthrough() + + self.assertFalse(result["ok"]) + self.assertEqual(result["error_code"], "auth_failed") + self.assertEqual(result["status_code"], 401) + self.assertEqual(result["host"], "test-shop.myshopify.com") + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_passthrough.py b/tests/unit/test_passthrough.py new file mode 100644 index 0000000000..9b86bc9f7e --- /dev/null +++ b/tests/unit/test_passthrough.py @@ -0,0 +1,400 @@ +"""Unit tests for PassthroughMixin.""" + +import unittest +from unittest.mock import MagicMock, patch + +from mindsdb.integrations.libs.passthrough import ( + PassthroughMixin, + REDACTED_SENTINEL, +) +from mindsdb.integrations.libs.passthrough_types import ( + HostNotAllowedError, + PassthroughConfigError, + PassthroughRequest, + PassthroughValidationError, +) + + +class _FakeHandler(PassthroughMixin): + """Minimal handler stub for exercising the mixin.""" + + _bearer_token_arg = "api_key" + _base_url_default = "https://api.example.com" + _test_request = PassthroughRequest(method="GET", path="/me") + + def __init__(self, connection_data: dict): + self.name = "fake_ds" + self.connection_data = connection_data + + +def _mock_response(status_code=200, body=b'{"ok":true}', headers=None, content_type="application/json"): + """Return a mock requests.Response exposing the bits the mixin uses.""" + resp = MagicMock() + resp.status_code = status_code + resp.headers = {"Content-Type": content_type, **(headers or {})} + resp.iter_content = MagicMock(return_value=iter([body])) + resp.close = MagicMock() + return resp + + 
+class PassthroughHappyPathTests(unittest.TestCase): + def setUp(self): + self.handler = _FakeHandler({"api_key": "secret-token-abcdef1234567890"}) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_injects_bearer_and_uses_default_base_url(self, mock_request): + mock_request.return_value = _mock_response() + resp = self.handler.api_passthrough(PassthroughRequest("GET", "/me")) + + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.body, {"ok": True}) + + args, kwargs = mock_request.call_args + self.assertEqual(args[0], "GET") + self.assertEqual(args[1], "https://api.example.com/me") + self.assertEqual(kwargs["headers"]["Authorization"], "Bearer secret-token-abcdef1234567890") + self.assertEqual(kwargs["headers"]["X-Minds-Passthrough"], "1") + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_user_base_url_overrides_default(self, mock_request): + self.handler.connection_data["base_url"] = "https://api.eu.example.com" + mock_request.return_value = _mock_response() + self.handler.api_passthrough(PassthroughRequest("GET", "/me")) + self.assertEqual(mock_request.call_args[0][1], "https://api.eu.example.com/me") + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_query_params_forwarded(self, mock_request): + mock_request.return_value = _mock_response() + self.handler.api_passthrough(PassthroughRequest("GET", "/x", query={"a": "1"})) + self.assertEqual(mock_request.call_args.kwargs["params"], {"a": "1"}) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_json_body_forwarded(self, mock_request): + mock_request.return_value = _mock_response() + self.handler.api_passthrough(PassthroughRequest("POST", "/x", body={"name": "foo"})) + self.assertEqual(mock_request.call_args.kwargs["json"], {"name": "foo"}) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_default_headers_merged(self, mock_request): + 
self.handler.connection_data["default_headers"] = {"Accept": "application/json"} + mock_request.return_value = _mock_response() + self.handler.api_passthrough(PassthroughRequest("GET", "/x")) + self.assertEqual(mock_request.call_args.kwargs["headers"]["Accept"], "application/json") + + +class PassthroughHeaderFilteringTests(unittest.TestCase): + def setUp(self): + self.handler = _FakeHandler({"api_key": "secret-token-abcdef1234567890"}) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_caller_cannot_override_authorization(self, mock_request): + mock_request.return_value = _mock_response() + self.handler.api_passthrough( + PassthroughRequest("GET", "/x", headers={"Authorization": "Bearer hijack", "Cookie": "s=1"}) + ) + outgoing = mock_request.call_args.kwargs["headers"] + self.assertEqual(outgoing["Authorization"], "Bearer secret-token-abcdef1234567890") + self.assertNotIn("Cookie", outgoing) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_proxy_headers_stripped(self, mock_request): + mock_request.return_value = _mock_response() + self.handler.api_passthrough(PassthroughRequest("GET", "/x", headers={"Proxy-Authorization": "hijack"})) + outgoing = mock_request.call_args.kwargs["headers"] + self.assertNotIn("Proxy-Authorization", outgoing) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_hop_by_hop_response_headers_stripped(self, mock_request): + mock_request.return_value = _mock_response( + headers={"Connection": "close", "X-Safe": "1", "Transfer-Encoding": "chunked"} + ) + resp = self.handler.api_passthrough(PassthroughRequest("GET", "/x")) + self.assertNotIn("Connection", resp.headers) + self.assertNotIn("Transfer-Encoding", resp.headers) + self.assertEqual(resp.headers.get("X-Safe"), "1") + + +class PassthroughHostAllowlistTests(unittest.TestCase): + def test_rejects_host_outside_allowlist(self): + handler = _FakeHandler( + { + "api_key": "t", + "base_url": 
"https://api.example.com",
+                "allowed_hosts": ["api.example.com"],
+            }
+        )
+        # Direct host check using a bad URL
+        with self.assertRaises(HostNotAllowedError):
+            handler._check_host_allowed("evil.com")
+
+    def test_wildcard_allows_any_host(self):
+        handler = _FakeHandler(
+            {
+                "api_key": "t",
+                "base_url": "https://api.example.com",
+                "allowed_hosts": ["*"],
+            }
+        )
+        handler._check_host_allowed("evil.com")  # must not raise
+
+    def test_private_ip_rejected_by_default(self):
+        handler = _FakeHandler({"api_key": "t", "base_url": "http://10.0.0.1"})
+        with self.assertRaises(HostNotAllowedError):
+            handler._check_host_allowed("10.0.0.1")
+
+    def test_private_ip_rejected_even_when_explicitly_listed(self):
+        handler = _FakeHandler(
+            {
+                "api_key": "t",
+                "base_url": "http://10.0.0.1",
+                "allowed_hosts": ["10.0.0.1"],
+            }
+        )
+        # Explicitly allowlisted private IP should still be rejected — the
+        # mixin treats explicit private-IP allowlisting as a foot-gun that
+        # requires the "*" escape hatch. Document this behavior.
+        with self.assertRaises(HostNotAllowedError):
+            handler._check_host_allowed("10.0.0.1")
+
+    def test_loopback_rejected_when_wildcard_not_used(self):
+        handler = _FakeHandler(
+            {
+                "api_key": "t",
+                "base_url": "http://127.0.0.1",
+                "allowed_hosts": ["127.0.0.1"],
+            }
+        )
+        with self.assertRaises(HostNotAllowedError):
+            handler._check_host_allowed("127.0.0.1")
+
+
+class PassthroughValidationTests(unittest.TestCase):
+    def test_missing_bearer_raises(self):
+        handler = _FakeHandler({})  # no api_key
+        with self.assertRaises(PassthroughConfigError):
+            handler.api_passthrough(PassthroughRequest("GET", "/me"))
+
+    def test_missing_base_url_raises(self):
+        class NoDefault(_FakeHandler):
+            _base_url_default = None
+
+        handler = NoDefault({"api_key": "t"})
+        with self.assertRaises(PassthroughConfigError):
+            handler.api_passthrough(PassthroughRequest("GET", "/me"))
+
+    def test_path_must_start_with_slash(self):
+        handler = _FakeHandler({"api_key": "t"})
+        with self.assertRaises(PassthroughValidationError):
+            handler.api_passthrough(PassthroughRequest("GET", "me"))
+
+    def test_method_allowlist(self):
+        handler = _FakeHandler({"api_key": "t"})
+        with self.assertRaises(PassthroughValidationError):
+            handler.api_passthrough(PassthroughRequest("TRACE", "/me"))
+
+
+class PassthroughSecretScrubTests(unittest.TestCase):
+    @patch("mindsdb.integrations.libs.passthrough.requests.request")
+    def test_token_scrubbed_from_json_body(self, mock_request):
+        token = "secret-token-abcdef1234567890"
+        # Non-UTF-8 byte (0xFF) positioned adjacent to the token. Spec §7.6
+        # mandates byte-level scrubbing: if the scrub ran after an
+        # errors="replace" decode, U+FFFD insertions would risk fragmenting
+        # a token mid-match. Byte-level scrub avoids that entirely.
+ body = b'{"error":"Invalid token ' + token.encode("utf-8") + b' \xff trailing"}' + handler = _FakeHandler({"api_key": token}) + # Use plain-text content-type so the non-UTF-8 body survives without + # a json.loads detour; the scrub is still invoked. + mock_request.return_value = _mock_response(body=body, content_type="text/plain") + + resp = handler.api_passthrough(PassthroughRequest("GET", "/x")) + # Token must not survive anywhere in the body. + self.assertNotIn(token, str(resp.body)) + self.assertIn(REDACTED_SENTINEL, str(resp.body)) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_token_scrubbed_from_headers(self, mock_request): + token = "secret-token-abcdef1234567890" + handler = _FakeHandler({"api_key": token}) + mock_request.return_value = _mock_response( + headers={"X-Debug-Auth": f"Bearer {token}"}, + ) + resp = handler.api_passthrough(PassthroughRequest("GET", "/x")) + self.assertIn(REDACTED_SENTINEL, resp.headers["X-Debug-Auth"]) + self.assertNotIn(token, resp.headers["X-Debug-Auth"]) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_long_default_header_values_scrubbed(self, mock_request): + token = "secret-token-abcdef1234567890" + long_secret = "x" * 32 + handler = _FakeHandler( + { + "api_key": token, + "default_headers": {"X-Api-Secondary": long_secret}, + } + ) + mock_request.return_value = _mock_response(body=('{"echoed":"' + long_secret + '"}').encode("utf-8")) + resp = handler.api_passthrough(PassthroughRequest("GET", "/x")) + self.assertEqual(resp.body["echoed"], REDACTED_SENTINEL) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_token_scrubbed_in_nested_json_without_corrupting_structure(self, mock_request): + token = "secret-token-abcdef1234567890" + handler = _FakeHandler({"api_key": token}) + body = ('{"data": {"nested": {"token": "' + token + '"}}}').encode("utf-8") + mock_request.return_value = _mock_response(body=body) + + resp = 
handler.api_passthrough(PassthroughRequest("GET", "/x")) + + # Structure preserved: dict-of-dict-of-dict with the expected keys. + self.assertIsInstance(resp.body, dict) + self.assertIsInstance(resp.body["data"], dict) + self.assertIsInstance(resp.body["data"]["nested"], dict) + self.assertEqual(set(resp.body.keys()), {"data"}) + self.assertEqual(set(resp.body["data"].keys()), {"nested"}) + self.assertEqual(set(resp.body["data"]["nested"].keys()), {"token"}) + # Value redacted at the leaf; token does not survive anywhere. + self.assertEqual(resp.body["data"]["nested"]["token"], REDACTED_SENTINEL) + self.assertNotIn(token, str(resp.body)) + + +class PassthroughTestEndpointTests(unittest.TestCase): + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_returns_ok_on_200(self, mock_request): + handler = _FakeHandler({"api_key": "t"}) + mock_request.return_value = _mock_response(status_code=200) + result = handler.test_passthrough() + self.assertTrue(result["ok"]) + self.assertEqual(result["status_code"], 200) + self.assertEqual(result["host"], "api.example.com") + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_returns_auth_failed_on_401(self, mock_request): + handler = _FakeHandler({"api_key": "t"}) + mock_request.return_value = _mock_response(status_code=401) + result = handler.test_passthrough() + self.assertFalse(result["ok"]) + self.assertEqual(result["error_code"], "auth_failed") + + def test_returns_not_implemented_when_no_test_request(self): + class NoTest(_FakeHandler): + _test_request = None + + handler = NoTest({"api_key": "t"}) + result = handler.test_passthrough() + self.assertFalse(result["ok"]) + self.assertEqual(result["error_code"], "not_implemented") + + +class PassthroughAllowedMethodsTests(unittest.TestCase): + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_rejects_method_not_in_allowed_methods(self, mock_request): + handler = _FakeHandler( + { + "api_key": "t", + 
"allowed_methods": ["GET"], + } + ) + mock_request.return_value = _mock_response() + + with self.assertRaises(PassthroughValidationError) as cm: + handler.api_passthrough(PassthroughRequest("POST", "/x")) + + self.assertEqual(cm.exception.error_code, "method_not_allowed") + self.assertEqual(cm.exception.http_status, 405) + mock_request.assert_not_called() + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_all_methods_allowed_when_config_absent(self, mock_request): + handler = _FakeHandler({"api_key": "t"}) + mock_request.return_value = _mock_response() + + for method in ("GET", "POST", "PUT", "PATCH", "DELETE"): + mock_request.reset_mock() + mock_request.return_value = _mock_response() + handler.api_passthrough(PassthroughRequest(method, "/x")) + self.assertEqual(mock_request.call_args[0][0], method) + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_string_allowed_methods_raises_config_error(self, mock_request): + handler = _FakeHandler({"api_key": "t", "allowed_methods": "GET"}) + + with self.assertRaises(PassthroughConfigError): + handler.api_passthrough(PassthroughRequest("GET", "/x")) + + mock_request.assert_not_called() + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_lowercase_allowed_methods_normalized(self, mock_request): + handler = _FakeHandler({"api_key": "t", "allowed_methods": ["get"]}) + mock_request.return_value = _mock_response() + + # GET passes after uppercase normalization. + handler.api_passthrough(PassthroughRequest("GET", "/x")) + self.assertEqual(mock_request.call_args[0][0], "GET") + + mock_request.reset_mock() + # POST is rejected with method_not_allowed. 
+ with self.assertRaises(PassthroughValidationError) as cm: + handler.api_passthrough(PassthroughRequest("POST", "/x")) + self.assertEqual(cm.exception.error_code, "method_not_allowed") + self.assertEqual(cm.exception.http_status, 405) + mock_request.assert_not_called() + + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_unknown_verb_in_allowed_methods_raises_config_error(self, mock_request): + handler = _FakeHandler({"api_key": "t", "allowed_methods": ["GET", "TRACE"]}) + + with self.assertRaises(PassthroughConfigError) as cm: + handler.api_passthrough(PassthroughRequest("GET", "/x")) + self.assertIn("TRACE", str(cm.exception)) + mock_request.assert_not_called() + + +class PassthroughAuthHeaderOverrideTests(unittest.TestCase): + @patch("mindsdb.integrations.libs.passthrough.requests.request") + def test_custom_auth_header_name_and_format(self, mock_request): + class ShopifyLikeHandler(_FakeHandler): + _auth_header_name = "X-Shopify-Access-Token" + _auth_header_format = "{token}" + + handler = ShopifyLikeHandler({"api_key": "shpat_abc123"}) + mock_request.return_value = _mock_response() + + handler.api_passthrough(PassthroughRequest("GET", "/x")) + + outgoing = mock_request.call_args.kwargs["headers"] + # Custom header present, with raw token (no "Bearer " prefix). + self.assertEqual(outgoing["X-Shopify-Access-Token"], "shpat_abc123") + # Default Authorization header must NOT be added when the handler + # overrides the auth header name. 
+ self.assertNotIn("Authorization", outgoing) + + +class PassthroughProtocolTests(unittest.TestCase): + def test_non_mixin_class_satisfies_protocol(self): + from mindsdb.integrations.libs.passthrough import PassthroughProtocol + from mindsdb.integrations.libs.passthrough_types import PassthroughResponse + + class ManualHandler: + def api_passthrough(self, req: PassthroughRequest) -> PassthroughResponse: + return PassthroughResponse(status_code=200, headers={}, body=None, content_type=None) + + def test_passthrough(self) -> dict: + return {"ok": True} + + self.assertIsInstance(ManualHandler(), PassthroughProtocol) + + def test_class_missing_methods_fails_protocol(self): + from mindsdb.integrations.libs.passthrough import PassthroughProtocol + + class Incomplete: + def api_passthrough(self, req): ... + + # missing test_passthrough + + self.assertNotIsInstance(Incomplete(), PassthroughProtocol) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/utilities/handlers/auth_utilities/oauth2/test_client_credentials.py b/tests/unit/utilities/handlers/auth_utilities/oauth2/test_client_credentials.py new file mode 100644 index 0000000000..edf48aa9f4 --- /dev/null +++ b/tests/unit/utilities/handlers/auth_utilities/oauth2/test_client_credentials.py @@ -0,0 +1,889 @@ +"""Tests for OAuth2ClientCredentialsProvider.""" + +from __future__ import annotations + +import base64 +import logging +import threading +import time +from typing import Optional + +import pytest +import requests + +from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import ( + OAuth2ClientCredentialsProvider, +) +from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import client_credentials as cc_module + + +VALID_TOKEN_URL = "https://example.com/oauth/token" +CLIENT_ID = "client-id-abc" +CLIENT_SECRET = "client-secret-xyz" + + +def _conn(**overrides): + """Build a connection_data dict with valid defaults.""" + data = { + "token_url": VALID_TOKEN_URL, + 
"client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + } + data.update(overrides) + return data + + +class FakeResponse: + """Minimal stand-in for requests.Response used by token-request tests.""" + + def __init__( + self, + status_code: int = 200, + json_body: Optional[dict] = None, + raw_body: Optional[bytes] = None, + is_redirect: bool = False, + ) -> None: + self.status_code = status_code + if raw_body is not None: + self._body = raw_body + elif json_body is not None: + import json as _json + + self._body = _json.dumps(json_body).encode("utf-8") + else: + self._body = b"" + self.is_redirect = is_redirect + + def iter_content(self, chunk_size: int = 4096): + # Yield in chunks of chunk_size + for i in range(0, len(self._body), chunk_size): + yield self._body[i : i + chunk_size] + + def close(self): + pass + + +class FakeStorage: + """Minimal in-memory stand-in for handler_storage.encrypted_json_*.""" + + def __init__(self) -> None: + self._data: dict = {} + self.set_calls = 0 + self.get_calls = 0 + + def encrypted_json_get(self, key: str): + self.get_calls += 1 + return self._data.get(key) + + def encrypted_json_set(self, key: str, value) -> None: + self.set_calls += 1 + if value is None: + self._data.pop(key, None) + else: + self._data[key] = value + + +class FailingStorage: + """A handler_storage stand-in whose set always raises.""" + + def __init__(self) -> None: + self.set_calls = 0 + self.get_calls = 0 + + def encrypted_json_get(self, key: str): + self.get_calls += 1 + return None + + def encrypted_json_set(self, key: str, value) -> None: + self.set_calls += 1 + raise RuntimeError("storage offline") + + +def _bypass_dns(monkeypatch): + """Stub socket.getaddrinfo so the public-host SSRF check does not hit DNS. + + Returns a public IP (1.2.3.4) which passes all forbidden-range checks. 
+ """ + + def fake_getaddrinfo(host, *args, **kwargs): + return [(2, 1, 6, "", ("1.2.3.4", 0))] + + monkeypatch.setattr(cc_module.socket, "getaddrinfo", fake_getaddrinfo) + + +class TestConstructionValidation: + def test_unsupported_auth_method_raises(self, monkeypatch): + _bypass_dns(monkeypatch) + with pytest.raises(ValueError) as excinfo: + OAuth2ClientCredentialsProvider(_conn(token_auth_method="client_secret_jwt")) + msg = str(excinfo.value) + assert "client_secret_post" in msg + assert "client_secret_basic" in msg + + def test_client_secret_post_accepted(self, monkeypatch): + _bypass_dns(monkeypatch) + OAuth2ClientCredentialsProvider(_conn(token_auth_method="client_secret_post")) + + def test_client_secret_basic_accepted(self, monkeypatch): + _bypass_dns(monkeypatch) + OAuth2ClientCredentialsProvider(_conn(token_auth_method="client_secret_basic")) + + def test_default_token_auth_method_is_client_secret_post(self, monkeypatch): + _bypass_dns(monkeypatch) + provider = OAuth2ClientCredentialsProvider(_conn()) + assert provider.token_auth_method == "client_secret_post" + + def test_missing_token_url_raises(self): + with pytest.raises(ValueError) as excinfo: + OAuth2ClientCredentialsProvider({"client_id": CLIENT_ID, "client_secret": CLIENT_SECRET}) + assert "token_url" in str(excinfo.value) + + def test_missing_client_id_raises(self, monkeypatch): + _bypass_dns(monkeypatch) + with pytest.raises(ValueError) as excinfo: + OAuth2ClientCredentialsProvider({"token_url": VALID_TOKEN_URL, "client_secret": CLIENT_SECRET}) + assert "client_id" in str(excinfo.value) + + def test_missing_client_secret_raises(self, monkeypatch): + _bypass_dns(monkeypatch) + with pytest.raises(ValueError) as excinfo: + OAuth2ClientCredentialsProvider({"token_url": VALID_TOKEN_URL, "client_id": CLIENT_ID}) + assert "client_secret" in str(excinfo.value) + + def test_non_dict_connection_data_raises(self): + with pytest.raises(TypeError): + OAuth2ClientCredentialsProvider("not-a-dict") + + 
@pytest.mark.parametrize( + "url", + [ + "http://localhost/oauth/token", + "http://127.0.0.1/oauth/token", + "http://10.0.0.1/oauth/token", + "http://169.254.169.254/oauth/token", + "http://192.168.1.1/oauth/token", + "http://[::1]/oauth/token", + "file:///etc/passwd", + "ftp://example.com/oauth", + ], + ) + def test_ssrf_rules_reject_url(self, url): + with pytest.raises(ValueError): + OAuth2ClientCredentialsProvider(_conn(token_url=url)) + + def test_http_url_accepted_with_warning(self, monkeypatch, caplog): + _bypass_dns(monkeypatch) + with caplog.at_level(logging.WARNING, logger=cc_module.logger.name): + OAuth2ClientCredentialsProvider(_conn(token_url="http://example.com/oauth")) + assert any("http://" in r.message or "unencrypted" in r.message for r in caplog.records) + + def test_https_url_accepted_no_warning(self, monkeypatch, caplog): + _bypass_dns(monkeypatch) + with caplog.at_level(logging.WARNING, logger=cc_module.logger.name): + OAuth2ClientCredentialsProvider(_conn(token_url="https://example.com/oauth")) + assert not any(r.levelno == logging.WARNING for r in caplog.records) + + +class TestTokenUrlAllowedHosts: + """allowed_hosts gating for token_url, mirroring the passthrough mixin. + + Operator semantics: + - missing/empty → no host allowlist applied (SSRF still runs) + - ["*"] → host allowlist skipped, baseline SSRF still runs + - other list → token host must match one entry (case-insensitive) + """ + + def test_no_allowed_hosts_does_not_restrict_token_host(self, monkeypatch): + _bypass_dns(monkeypatch) + # No allowed_hosts configured — provider constructs successfully even + # though the token host differs from any base_url the caller might use. + OAuth2ClientCredentialsProvider(_conn(token_url="https://idp.example.com/oauth")) + + def test_distinct_api_and_token_hosts_both_in_allowlist(self, monkeypatch): + _bypass_dns(monkeypatch) + # token_url host differs from a notional base_url host; both listed. 
+ OAuth2ClientCredentialsProvider( + _conn( + token_url="https://idp.example.com/oauth/token", + allowed_hosts=["api.example.com", "idp.example.com"], + ) + ) + + def test_token_host_not_in_allowlist_rejected(self, monkeypatch): + _bypass_dns(monkeypatch) + with pytest.raises(ValueError) as excinfo: + OAuth2ClientCredentialsProvider( + _conn( + token_url="https://idp.example.com/oauth/token", + allowed_hosts=["api.example.com"], # IdP host omitted + ) + ) + msg = str(excinfo.value) + assert "idp.example.com" in msg + assert "allowed_hosts" in msg or "allowlist" in msg + + def test_allowlist_match_is_case_insensitive(self, monkeypatch): + _bypass_dns(monkeypatch) + OAuth2ClientCredentialsProvider( + _conn( + token_url="https://IdP.Example.Com/oauth/token", + allowed_hosts=["idp.example.com"], + ) + ) + + @pytest.mark.parametrize( + "url", + [ + "http://localhost/oauth/token", + "http://127.0.0.1/oauth/token", + "http://10.0.0.1/oauth/token", + "http://169.254.169.254/oauth/token", + "http://[::1]/oauth/token", + ], + ) + def test_wildcard_does_not_bypass_ssrf(self, url): + # Operator wrote allowed_hosts=["*"] (skip allowlist), but the token + # endpoint must still pass baseline SSRF checks. A bypass here would + # let a misconfigured datasource POST the client_secret to internal + # infrastructure (cloud metadata, in-cluster services, etc.). + with pytest.raises(ValueError): + OAuth2ClientCredentialsProvider(_conn(token_url=url, allowed_hosts=["*"])) + + def test_wildcard_skips_allowlist_for_public_token_host(self, monkeypatch): + _bypass_dns(monkeypatch) + # ["*"] disables only the allowlist — a public host still passes SSRF + # and constructs successfully without being explicitly listed. 
+ OAuth2ClientCredentialsProvider(_conn(token_url="https://idp.example.com/oauth/token", allowed_hosts=["*"])) + + def test_empty_allowed_hosts_list_treated_as_unset(self, monkeypatch): + _bypass_dns(monkeypatch) + # An empty list is interpreted as "no allowlist configured" rather + # than "no host is allowed", to match passthrough's _allowed_hosts + # which falls back to [default_host] when allowed is empty. + OAuth2ClientCredentialsProvider(_conn(token_url="https://idp.example.com/oauth/token", allowed_hosts=[])) + + +class TestTokenEndpointTransportSafety: + """Pin the transport-layer guarantees: redirects, timeouts, response cap.""" + + def _provider(self, monkeypatch, **overrides): + _bypass_dns(monkeypatch) + return OAuth2ClientCredentialsProvider(_conn(**overrides)) + + def test_allow_redirects_disabled_on_token_post(self, monkeypatch): + provider = self._provider(monkeypatch) + captured = {} + + def fake_post(url, data=None, headers=None, allow_redirects=None, **_): + captured["allow_redirects"] = allow_redirects + return FakeResponse(json_body={"access_token": "AT", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + assert captured["allow_redirects"] is False + + def test_connect_and_read_timeouts_match_constants(self, monkeypatch): + provider = self._provider(monkeypatch) + captured = {} + + def fake_post(url, data=None, headers=None, timeout=None, **_): + captured["timeout"] = timeout + return FakeResponse(json_body={"access_token": "AT", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + assert captured["timeout"] == ( + cc_module.CONNECT_TIMEOUT_SECONDS, + cc_module.READ_TIMEOUT_SECONDS, + ) + # Pin the actual values so a future bump is a deliberate decision. 
+ assert cc_module.CONNECT_TIMEOUT_SECONDS == 10 + assert cc_module.READ_TIMEOUT_SECONDS == 30 + + def test_response_body_over_64kb_rejected(self, monkeypatch): + provider = self._provider(monkeypatch) + oversize = b"x" * (cc_module.MAX_RESPONSE_BYTES + 1) + monkeypatch.setattr( + cc_module.requests, + "post", + lambda *a, **kw: FakeResponse(status_code=200, raw_body=oversize), + ) + with pytest.raises(RuntimeError) as excinfo: + provider.get_access_token() + assert "exceeded" in str(excinfo.value).lower() + assert cc_module.MAX_RESPONSE_BYTES == 64 * 1024 + + +class TestRequestShape: + def _provider(self, monkeypatch, **overrides): + _bypass_dns(monkeypatch) + return OAuth2ClientCredentialsProvider(_conn(**overrides)) + + def test_client_secret_post_places_credentials_in_body(self, monkeypatch): + provider = self._provider(monkeypatch, token_auth_method="client_secret_post") + captured = {} + + def fake_post(url, data=None, headers=None, **_): + captured["url"] = url + captured["data"] = data + captured["headers"] = headers + return FakeResponse(json_body={"access_token": "T", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + + assert captured["data"]["client_id"] == CLIENT_ID + assert captured["data"]["client_secret"] == CLIENT_SECRET + assert "Authorization" not in captured["headers"] + + def test_client_secret_basic_uses_auth_header(self, monkeypatch): + provider = self._provider(monkeypatch, token_auth_method="client_secret_basic") + captured = {} + + def fake_post(url, data=None, headers=None, **_): + captured["data"] = data + captured["headers"] = headers + return FakeResponse(json_body={"access_token": "T", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + + expected = "Basic " + base64.b64encode(f"{CLIENT_ID}:{CLIENT_SECRET}".encode("utf-8")).decode("ascii") + assert captured["headers"]["Authorization"] == expected + assert 
"client_id" not in captured["data"] + assert "client_secret" not in captured["data"] + + def test_grant_type_always_present(self, monkeypatch): + provider = self._provider(monkeypatch) + captured = {} + + def fake_post(url, data=None, headers=None, **_): + captured["data"] = data + return FakeResponse(json_body={"access_token": "T", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + assert captured["data"]["grant_type"] == "client_credentials" + + def test_scope_string_included(self, monkeypatch): + provider = self._provider(monkeypatch, scope="read:foo write:bar") + captured = {} + + def fake_post(url, data=None, headers=None, **_): + captured["data"] = data + return FakeResponse(json_body={"access_token": "T", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + assert captured["data"]["scope"] == "read:foo write:bar" + + def test_scope_list_joined_with_space(self, monkeypatch): + provider = self._provider(monkeypatch, scope=["read:foo", "write:bar"]) + captured = {} + + def fake_post(url, data=None, headers=None, **_): + captured["data"] = data + return FakeResponse(json_body={"access_token": "T", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + assert captured["data"]["scope"] == "read:foo write:bar" + + def test_audience_included_when_configured(self, monkeypatch): + provider = self._provider(monkeypatch, audience="https://api.example.com") + captured = {} + + def fake_post(url, data=None, headers=None, **_): + captured["data"] = data + return FakeResponse(json_body={"access_token": "T", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + assert captured["data"]["audience"] == "https://api.example.com" + + def test_audience_omitted_when_none(self, monkeypatch): + provider = self._provider(monkeypatch) + 
captured = {} + + def fake_post(url, data=None, headers=None, **_): + captured["data"] = data + return FakeResponse(json_body={"access_token": "T", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + assert "audience" not in captured["data"] + + def test_redirects_disabled_and_timeouts_set(self, monkeypatch): + provider = self._provider(monkeypatch) + captured = {} + + def fake_post(url, data=None, headers=None, timeout=None, allow_redirects=None, **_): + captured["timeout"] = timeout + captured["allow_redirects"] = allow_redirects + return FakeResponse(json_body={"access_token": "T", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + assert captured["allow_redirects"] is False + assert captured["timeout"] == (10, 30) + + +# --------------------------------------------------------------------------- +# Response handling +# --------------------------------------------------------------------------- + + +class TestResponseHandling: + def _provider(self, monkeypatch, **overrides): + _bypass_dns(monkeypatch) + return OAuth2ClientCredentialsProvider(_conn(**overrides)) + + def test_success_caches_with_correct_expires_at(self, monkeypatch): + provider = self._provider(monkeypatch) + monkeypatch.setattr( + cc_module.requests, + "post", + lambda *a, **kw: FakeResponse(json_body={"access_token": "AT", "expires_in": 3600}), + ) + before = time.time() + token = provider.get_access_token() + after = time.time() + assert token == "AT" + cached = provider._read_cache() + # expires_at = now + 3600 - 60 (skew) + assert before + 3600 - 60 - 1 <= cached["expires_at"] <= after + 3600 - 60 + 1 + + def test_missing_access_token_fails(self, monkeypatch): + provider = self._provider(monkeypatch) + monkeypatch.setattr( + cc_module.requests, + "post", + lambda *a, **kw: FakeResponse(json_body={"expires_in": 3600}), + ) + with pytest.raises(RuntimeError) as 
excinfo: + provider.get_access_token() + # Safe message: no client_secret leak + assert CLIENT_SECRET not in str(excinfo.value) + + @pytest.mark.parametrize("token_type", ["Bearer", "bearer", "BEARER"]) + def test_token_type_case_insensitive(self, monkeypatch, token_type): + provider = self._provider(monkeypatch) + monkeypatch.setattr( + cc_module.requests, + "post", + lambda *a, **kw: FakeResponse( + json_body={"access_token": "AT", "token_type": token_type, "expires_in": 3600} + ), + ) + assert provider.get_access_token() == "AT" + + def test_missing_token_type_defaults_to_bearer(self, monkeypatch): + # No token_type field at all in the response — provider must accept it + # and treat it as Bearer per RFC 6749 §5.1. + provider = self._provider(monkeypatch) + monkeypatch.setattr( + cc_module.requests, + "post", + lambda *a, **kw: FakeResponse(json_body={"access_token": "AT", "expires_in": 3600}), + ) + assert provider.get_access_token() == "AT" + + @pytest.mark.parametrize("token_type", ["MAC", "DPoP", "Token"]) + def test_unsupported_token_type_fails(self, monkeypatch, token_type): + provider = self._provider(monkeypatch) + monkeypatch.setattr( + cc_module.requests, + "post", + lambda *a, **kw: FakeResponse( + json_body={"access_token": "AT", "token_type": token_type, "expires_in": 3600} + ), + ) + with pytest.raises(RuntimeError): + provider.get_access_token() + + @pytest.mark.parametrize("expires_in", [None, 0, -100, "abc"]) + def test_invalid_expires_in_defaults_to_300(self, monkeypatch, caplog, expires_in): + provider = self._provider(monkeypatch) + body = {"access_token": "AT"} + if expires_in is not None: + body["expires_in"] = expires_in + monkeypatch.setattr( + cc_module.requests, + "post", + lambda *a, **kw: FakeResponse(json_body=body), + ) + before = time.time() + with caplog.at_level(logging.WARNING, logger=cc_module.logger.name): + provider.get_access_token() + after = time.time() + cached = provider._read_cache() + # expires_at = now + 300 - 60 = 
now + 240 + assert before + 240 - 1 <= cached["expires_at"] <= after + 240 + 1 + if expires_in is None: + # WARNING log emitted at least once when explicitly missing + warnings = [r for r in caplog.records if r.levelno == logging.WARNING and "expires_in" in r.message] + assert warnings, "expected a WARNING log mentioning expires_in" + + +class TestCaching: + def _provider(self, monkeypatch, **overrides): + _bypass_dns(monkeypatch) + return OAuth2ClientCredentialsProvider(_conn(**overrides)) + + def test_cached_token_reused(self, monkeypatch): + provider = self._provider(monkeypatch) + calls = {"n": 0} + + def fake_post(*a, **kw): + calls["n"] += 1 + return FakeResponse(json_body={"access_token": "AT", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + provider.get_access_token() + provider.get_access_token() + provider.get_access_token() + assert calls["n"] == 1 + + def test_token_within_skew_triggers_refresh(self, monkeypatch): + provider = self._provider(monkeypatch) + # Manually inject a token that is "expired" (skew already elapsed) + provider._memory_cache = { + "access_token": "OLD", + "token_type": "Bearer", + "expires_at": time.time() - 1, + } + calls = {"n": 0} + + def fake_post(*a, **kw): + calls["n"] += 1 + return FakeResponse(json_body={"access_token": "NEW", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + assert provider.get_access_token() == "NEW" + assert calls["n"] == 1 + + def test_expired_token_triggers_refresh(self, monkeypatch): + provider = self._provider(monkeypatch) + provider._memory_cache = { + "access_token": "OLD", + "token_type": "Bearer", + "expires_at": time.time() - 3600, + } + monkeypatch.setattr( + cc_module.requests, + "post", + lambda *a, **kw: FakeResponse(json_body={"access_token": "NEW", "expires_in": 3600}), + ) + assert provider.get_access_token() == "NEW" + + def test_clear_cached_token_clears_cache_and_refetches(self, monkeypatch): + provider = 
self._provider(monkeypatch) + calls = {"n": 0} + tokens = iter(["FIRST", "SECOND"]) + + def fake_post(*a, **kw): + calls["n"] += 1 + return FakeResponse(json_body={"access_token": next(tokens), "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + assert provider.get_access_token() == "FIRST" + provider.clear_cached_token() + assert provider.get_access_token() == "SECOND" + assert calls["n"] == 2 + + +class TestConcurrency: + def test_concurrent_get_token_makes_single_http_call(self, monkeypatch): + """Two threads call get_access_token() with empty cache simultaneously. + + Forces the double-checked-lock scenario: both threads must complete + their FIRST cache read (and observe empty) before either acquires the + lock. With the second read inside the lock, the second thread observes + the populated cache and skips the HTTP call. Without it, both threads + would issue the HTTP request — exactly the bug the lock prevents. + """ + _bypass_dns(monkeypatch) + provider = OAuth2ClientCredentialsProvider(_conn()) + + call_count = {"n": 0} + sync_counter = {"value": 0} + sync_lock = threading.Lock() + first_read_barrier = threading.Barrier(2) + + original_read_cache = provider._read_cache + + def synced_read_cache(): + result = original_read_cache() + with sync_lock: + sync_counter["value"] += 1 + n = sync_counter["value"] + # Only the first read from each thread (the pre-lock read) blocks + # on the barrier. Inside-lock reads (calls 3 and 4) pass through. + if n <= 2: + first_read_barrier.wait(timeout=5) + return result + + monkeypatch.setattr(provider, "_read_cache", synced_read_cache) + + def fake_post(*a, **kw): + call_count["n"] += 1 + # Hold the lock briefly so the second thread is queued behind it. 
+ time.sleep(0.05) + return FakeResponse(json_body={"access_token": "SHARED", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + + results = {} + + def worker(idx): + results[idx] = provider.get_access_token() + + t1 = threading.Thread(target=worker, args=(1,)) + t2 = threading.Thread(target=worker, args=(2,)) + t1.start() + t2.start() + t1.join(timeout=5) + t2.join(timeout=5) + + assert results[1] == "SHARED" + assert results[2] == "SHARED" + assert call_count["n"] == 1, ( + "Double-checked locking failed: both threads issued an HTTP call. " + "The second read inside the lock is missing or broken." + ) + + +class TestStorage: + def test_token_persists_across_provider_instances(self, monkeypatch): + _bypass_dns(monkeypatch) + storage = FakeStorage() + calls = {"n": 0} + + def fake_post(*a, **kw): + calls["n"] += 1 + return FakeResponse(json_body={"access_token": "PERSIST", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + + p1 = OAuth2ClientCredentialsProvider( + _conn(), + handler_storage=storage, + storage_key="oauth_test", + ) + assert p1.get_access_token() == "PERSIST" + + p2 = OAuth2ClientCredentialsProvider( + _conn(), + handler_storage=storage, + storage_key="oauth_test", + ) + assert p2.get_access_token() == "PERSIST" + assert calls["n"] == 1 + + def test_in_memory_cache_when_storage_none(self, monkeypatch): + _bypass_dns(monkeypatch) + provider = OAuth2ClientCredentialsProvider(_conn()) + calls = {"n": 0} + + def fake_post(*a, **kw): + calls["n"] += 1 + return FakeResponse(json_body={"access_token": "MEM", "expires_in": 3600}) + + monkeypatch.setattr(cc_module.requests, "post", fake_post) + assert provider.get_access_token() == "MEM" + assert provider.get_access_token() == "MEM" + assert calls["n"] == 1 + + def test_storage_set_failure_falls_back_to_memory(self, monkeypatch, caplog): + _bypass_dns(monkeypatch) + storage = FailingStorage() + provider = OAuth2ClientCredentialsProvider( + 
+            _conn(),
+            handler_storage=storage,
+        )
+
+        def fake_post(*a, **kw):
+            return FakeResponse(json_body={"access_token": "FB", "expires_in": 3600})
+
+        monkeypatch.setattr(cc_module.requests, "post", fake_post)
+        with caplog.at_level(logging.DEBUG, logger=cc_module.logger.name):
+            assert provider.get_access_token() == "FB"
+            # Subsequent get_access_token uses in-memory fallback
+            assert provider.get_access_token() == "FB"
+        assert provider._memory_cache is not None
+        assert any("in-memory" in r.message or "fallback" in r.message.lower() for r in caplog.records)
+
+    def test_cache_does_not_contain_credentials(self, monkeypatch):
+        _bypass_dns(monkeypatch)
+        storage = FakeStorage()
+        provider = OAuth2ClientCredentialsProvider(
+            _conn(scope="read:foo", audience="https://api.example.com"),
+            handler_storage=storage,
+            storage_key="oauth_test",
+        )
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(json_body={"access_token": "AT", "expires_in": 3600}),
+        )
+        provider.get_access_token()
+        cached = storage._data["oauth_test"]
+        as_text = repr(cached)
+        assert CLIENT_SECRET not in as_text
+        assert CLIENT_ID not in as_text
+        assert VALID_TOKEN_URL not in as_text
+        assert "read:foo" not in as_text
+        assert "https://api.example.com" not in as_text
+
+
+class TestCurrentSecrets:
+    def _provider(self, monkeypatch):
+        _bypass_dns(monkeypatch)
+        return OAuth2ClientCredentialsProvider(_conn())
+
+    def test_no_cached_token_returns_empty(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+        assert provider.current_secrets() == []
+
+    def test_cached_valid_token_returned(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(json_body={"access_token": "S3CR3T", "expires_in": 3600}),
+        )
+        provider.get_access_token()
+        assert provider.current_secrets() == ["S3CR3T"]
+
+    def test_after_clear_cached_token_returns_empty(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(json_body={"access_token": "S3CR3T", "expires_in": 3600}),
+        )
+        provider.get_access_token()
+        provider.clear_cached_token()
+        assert provider.current_secrets() == []
+
+    def test_current_secrets_never_returns_client_secret(self, monkeypatch):
+        # Across uncached, cached, and post-clear states, the client_secret
+        # must never appear in current_secrets() — that list feeds the
+        # response-scrub layer and exposing the static credential there
+        # would risk redacting upstream payloads that legitimately contain
+        # the same string, but more importantly it would imply the secret
+        # is in some live cache, which it must not be.
+        provider = self._provider(monkeypatch)
+
+        # Uncached
+        assert CLIENT_SECRET not in provider.current_secrets()
+
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(json_body={"access_token": "AT", "expires_in": 3600}),
+        )
+        provider.get_access_token()
+
+        # Cached — only the access_token should show up
+        secrets = provider.current_secrets()
+        assert secrets == ["AT"]
+        assert CLIENT_SECRET not in secrets
+        assert CLIENT_ID not in secrets
+
+        provider.clear_cached_token()
+        assert CLIENT_SECRET not in provider.current_secrets()
+
+
+class TestErrorSanitization:
+    def _provider(self, monkeypatch):
+        _bypass_dns(monkeypatch)
+        return OAuth2ClientCredentialsProvider(_conn())
+
+    def test_failure_message_does_not_leak_client_secret(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(status_code=500, json_body={"error": "boom"}),
+        )
+        with pytest.raises(RuntimeError) as excinfo:
+            provider.get_access_token()
+        assert CLIENT_SECRET not in str(excinfo.value)
+
+    def test_failure_message_does_not_leak_access_token(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+        # First, succeed and cache
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(json_body={"access_token": "very-secret-token", "expires_in": 1}),
+        )
+        provider.get_access_token()
+        # Now force expiry and make next call fail
+        provider._memory_cache["expires_at"] = time.time() - 1
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(status_code=401, json_body={"error": "invalid_grant"}),
+        )
+        with pytest.raises(RuntimeError) as excinfo:
+            provider.get_access_token()
+        assert "very-secret-token" not in str(excinfo.value)
+
+    def test_401_includes_provider_error_fields(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(
+                status_code=401,
+                json_body={
+                    "error": "invalid_client",
+                    "error_description": "Client authentication failed",
+                },
+            ),
+        )
+        with pytest.raises(RuntimeError) as excinfo:
+            provider.get_access_token()
+        msg = str(excinfo.value)
+        assert "invalid_client" in msg
+        assert "Client authentication failed" in msg
+        assert CLIENT_SECRET not in msg
+        assert CLIENT_ID in msg or "client_id" in msg
+
+    def test_redirect_response_treated_as_error(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(status_code=302, is_redirect=True),
+        )
+        with pytest.raises(RuntimeError) as excinfo:
+            provider.get_access_token()
+        assert "redirect" in str(excinfo.value).lower()
+
+    def test_response_size_cap_aborts(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+        oversize = b"x" * (cc_module.MAX_RESPONSE_BYTES + 100)
+        monkeypatch.setattr(
+            cc_module.requests,
+            "post",
+            lambda *a, **kw: FakeResponse(status_code=200, raw_body=oversize),
+        )
+        with pytest.raises(RuntimeError) as excinfo:
+            provider.get_access_token()
+        assert "exceeded" in str(excinfo.value).lower()
+
+    def test_transport_error_message_redacts_secret(self, monkeypatch):
+        provider = self._provider(monkeypatch)
+
+        def fake_post(*a, **kw):
+            raise requests.ConnectionError(f"oops while sending client_secret={CLIENT_SECRET}")
+
+        monkeypatch.setattr(cc_module.requests, "post", fake_post)
+        with pytest.raises(RuntimeError) as excinfo:
+            provider.get_access_token()
+        chained = excinfo.value.__cause__
+        assert CLIENT_SECRET not in str(excinfo.value)
+        assert CLIENT_SECRET not in str(chained)
diff --git a/tests/unit/utilities/handlers/auth_utilities/oauth2/test_package.py b/tests/unit/utilities/handlers/auth_utilities/oauth2/test_package.py
new file mode 100644
index 0000000000..0c8952dbbc
--- /dev/null
+++ b/tests/unit/utilities/handlers/auth_utilities/oauth2/test_package.py
@@ -0,0 +1,77 @@
+"""Sanity tests for the auth_utilities.oauth2 package surface.
+
+These tests intentionally do not exercise behavior — they verify only that
+the public import path, provider class, and exposed constants exist as the
+rest of the codebase expects to import them.
+"""
+
+from __future__ import annotations
+
+
+def test_provider_imports_from_package_root():
+    from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import (
+        OAuth2ClientCredentialsProvider,
+    )
+
+    assert isinstance(OAuth2ClientCredentialsProvider, type)
+
+
+def test_provider_exposes_required_methods():
+    from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import (
+        OAuth2ClientCredentialsProvider,
+    )
+
+    for method in ("get_access_token", "clear_cached_token", "current_secrets"):
+        assert callable(getattr(OAuth2ClientCredentialsProvider, method, None)), (
+            f"OAuth2ClientCredentialsProvider is missing public method {method!r}"
+        )
+
+
+def test_constructor_accepts_connection_data_keyword():
+    import inspect
+    from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import (
+        OAuth2ClientCredentialsProvider,
+    )
+
+    sig = inspect.signature(OAuth2ClientCredentialsProvider.__init__)
+    params = sig.parameters
+    assert "connection_data" in params
+    assert "handler_storage" in params
+    assert "storage_key" in params
+    # handler_storage and storage_key default to None so the provider can be
+    # constructed without persistent storage.
+    assert params["handler_storage"].default is None
+    assert params["storage_key"].default is None
+
+
+def test_public_constants_exposed_at_package_root():
+    from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import (
+        ALLOWED_AUTH_METHODS,
+        CONNECT_TIMEOUT_SECONDS,
+        DEFAULT_EXPIRES_IN_SECONDS,
+        DEFAULT_STORAGE_KEY,
+        DEFAULT_TOKEN_AUTH_METHOD,
+        EXPIRY_SKEW_SECONDS,
+        MAX_RESPONSE_BYTES,
+        READ_TIMEOUT_SECONDS,
+    )
+
+    assert EXPIRY_SKEW_SECONDS == 60
+    assert DEFAULT_EXPIRES_IN_SECONDS == 300
+    assert MAX_RESPONSE_BYTES == 64 * 1024
+    assert CONNECT_TIMEOUT_SECONDS == 10
+    assert READ_TIMEOUT_SECONDS == 30
+    assert DEFAULT_TOKEN_AUTH_METHOD == "client_secret_post"
+    assert "client_secret_post" in ALLOWED_AUTH_METHODS
+    assert "client_secret_basic" in ALLOWED_AUTH_METHODS
+    assert isinstance(DEFAULT_STORAGE_KEY, str) and DEFAULT_STORAGE_KEY
+
+
+def test_client_credentials_module_importable():
+    # The submodule path itself should resolve, in case callers want
+    # access to internal helpers (e.g. the SSRF validator) for tests.
+    from mindsdb.integrations.utilities.handlers.auth_utilities.oauth2 import (
+        client_credentials,
+    )
+
+    assert hasattr(client_credentials, "OAuth2ClientCredentialsProvider")