diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py
index 2910576f8..33e02a231 100644
--- a/src/databricks/sql/session.py
+++ b/src/databricks/sql/session.py
@@ -1,4 +1,5 @@
 import logging
+import re
 from typing import Dict, Tuple, List, Optional, Any, Type
 
 from databricks.sql.thrift_api.TCLIService import ttypes
@@ -170,37 +171,70 @@ def _create_backend(
         }
         return databricks_client_class(**common_args)
 
+    # All-purpose-compute Thrift http_path:
+    #   [/]sql/protocolv1/o/<workspace-id>/<cluster-id>[/...][?...]
+    _CLUSTER_PATH_ORG_ID_RE = re.compile(r"(?:^|/)sql/protocolv1/o/(\d+)/[^/?]+")
+
     @staticmethod
     def _extract_spog_headers(http_path, existing_headers):
-        """Extract ?o= from http_path and return as a header dict for SPOG routing."""
-        if not http_path or "?" not in http_path:
+        """Extract the workspace ID from http_path for SPOG routing and return it
+        as an ``x-databricks-org-id`` header dict.
+
+        Two sources are inspected, in priority order:
+          1. ``?o=<workspace-id>`` query parameter in http_path (warehouse paths
+             typically encode the workspace this way on SPOG).
+          2. ``/sql/protocolv1/o/<workspace-id>/<cluster-id>`` path segment
+             (all-purpose compute paths embed the workspace in the path itself).
+
+        An explicit ``x-databricks-org-id`` already set by the caller wins over
+        both. Returns an empty dict when no workspace ID can be determined.
+
+        On SPOG (Custom URL) hosts this header is required for non-Thrift
+        endpoints — telemetry, feature flags, SEA — to be routed to the right
+        workspace. Without it, PoPP falls back to default routing and
+        workspace-scoped requests are redirected to ``/login``.
+        """
+        if not http_path:
             return {}
 
-        from urllib.parse import parse_qs
-
-        query_string = http_path.split("?", 1)[1]
-        params = parse_qs(query_string)
-        org_id = params.get("o", [None])[0]
-        if not org_id:
+        # Caller already set the header — never override.
+        if any(k == "x-databricks-org-id" for k, _ in existing_headers):
             logger.debug(
-                "SPOG header extraction: http_path has query string but no ?o= param, "
-                "skipping x-databricks-org-id injection"
+                "SPOG header extraction: x-databricks-org-id already set by caller, "
+                "not extracting from http_path"
             )
             return {}
 
-        # Don't override if explicitly set
-        if any(k == "x-databricks-org-id" for k, _ in existing_headers):
+        org_id = None
+        source = None
+
+        if "?" in http_path:
+            from urllib.parse import parse_qs
+
+            query_string = http_path.split("?", 1)[1]
+            params = parse_qs(query_string)
+            value = params.get("o", [None])[0]
+            if value:
+                org_id = value
+                source = "?o= in http_path"
+
+        if org_id is None:
+            cluster_match = Session._CLUSTER_PATH_ORG_ID_RE.search(http_path)
+            if cluster_match:
+                org_id = cluster_match.group(1)
+                source = "cluster path segment"
+
+        if org_id is None:
             logger.debug(
-                "SPOG header extraction: x-databricks-org-id already set by caller, "
-                "not overriding with ?o=%s from http_path",
-                org_id,
+                "SPOG header extraction: no workspace ID found in http_path, "
+                "skipping x-databricks-org-id injection"
             )
             return {}
 
         logger.debug(
-            "SPOG header extraction: injecting x-databricks-org-id=%s "
-            "(extracted from ?o= in http_path)",
+            "SPOG header extraction: injecting x-databricks-org-id=%s (extracted from %s)",
             org_id,
+            source,
         )
         return {"x-databricks-org-id": org_id}
 
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 5d37cd9a5..f31c785db 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -273,3 +273,41 @@ def test_multiple_query_params(self):
             "/sql/1.0/warehouses/abc123?o=12345&extra=val", []
         )
         assert result == {"x-databricks-org-id": "12345"}
+
+    def test_extracts_org_id_from_cluster_path_segment(self):
+        # All-purpose-compute path embeds workspace ID in /o/<workspace-id>/<cluster-id>.
+        # Without ?o=, the driver must still set x-databricks-org-id so that
+        # telemetry and other non-Thrift requests route to the right workspace
+        # on SPOG hosts.
+        result = Session._extract_spog_headers(
+            "sql/protocolv1/o/6051921418418893/0528-220959-uzmcn1qt", []
+        )
+        assert result == {"x-databricks-org-id": "6051921418418893"}
+
+    def test_extracts_org_id_from_cluster_path_with_leading_slash(self):
+        result = Session._extract_spog_headers(
+            "/sql/protocolv1/o/6051921418418893/0528-220959-uzmcn1qt", []
+        )
+        assert result == {"x-databricks-org-id": "6051921418418893"}
+
+    def test_query_param_wins_over_cluster_path_segment(self):
+        # When both forms are present, ?o= takes precedence.
+        result = Session._extract_spog_headers(
+            "sql/protocolv1/o/111/0528-220959-uzmcn1qt?o=222", []
+        )
+        assert result == {"x-databricks-org-id": "222"}
+
+    def test_explicit_header_wins_over_cluster_path_segment(self):
+        existing = [("x-databricks-org-id", "from-caller")]
+        result = Session._extract_spog_headers(
+            "sql/protocolv1/o/111/0528-220959-uzmcn1qt", existing
+        )
+        assert result == {}
+
+    def test_warehouse_path_without_query_param_returns_empty(self):
+        # Regression guard: the new cluster-path regex must not accidentally
+        # match warehouse paths (which never embed the workspace ID).
+        result = Session._extract_spog_headers(
+            "/sql/1.0/warehouses/abc123", []
+        )
+        assert result == {}