Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion KERNEL_REV
Original file line number Diff line number Diff line change
@@ -1 +1 @@
b4d88220cdfad8dba1cfa89892269342ae26feeb
101aa465e71991eec98102bba77aad2f7ad8faed
82 changes: 63 additions & 19 deletions src/databricks/sql/backend/kernel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@

logger = logging.getLogger(__name__)

# Headers the kernel manages itself and that the connector must NOT
# forward via ``http_headers`` (lower-cased for case-insensitive match):
# ``authorization`` (the kernel applies the auth provider's token) and
# ``x-databricks-org-id`` (the kernel re-derives it from the ``?o=`` in
# http_path). Forwarding either is redundant and trips the kernel's
# per-request skip-and-warn.
_KERNEL_MANAGED_HEADERS = frozenset({"authorization", "x-databricks-org-id"})


# ─── Client ─────────────────────────────────────────────────────────────────

Expand Down Expand Up @@ -91,13 +99,19 @@ def __init__(
):
# ``ssl_options`` is translated to the kernel's ``tls_*``
# Session kwargs in ``open_session`` (custom CA, verify
# toggles, mTLS client cert/key). ``http_headers`` /
# ``http_client`` / ``port`` are still accept-and-ignore — the
# kernel manages its own HTTP stack.
# toggles, mTLS client cert/key). ``http_headers`` is forwarded
# to the kernel as custom request headers (it carries the
# connector's composed ``User-Agent`` + any caller headers + the
# SPOG ``x-databricks-org-id``). ``http_client`` / ``port`` are
# still accept-and-ignore — the kernel manages its own HTTP
# stack.
self._server_hostname = server_hostname
self._http_path = http_path
self._auth_provider = auth_provider
self._ssl_options = ssl_options
# Caller / connector HTTP headers (list of (name, value) pairs).
# Forwarded to the kernel Session in ``open_session``.
self._http_headers = http_headers or []
# Raw auth-relevant connect() kwargs (auth_type,
# oauth_client_id/secret, redirect port, credentials_provider).
# The kernel auth bridge needs these to build OAuth kwargs — the
Expand Down Expand Up @@ -175,19 +189,45 @@ def open_session(
session_conf: Optional[Dict[str, str]] = None
if session_configuration:
session_conf = {k: str(v) for k, v in session_configuration.items()}
# Build auth kwargs here (not in ``__init__``) so the bearer
# token has the shortest possible in-process lifetime: a
# local kwargs dict is GC-eligible the moment this method
# returns, regardless of whether the kernel ``Session()``
# call succeeded or raised.
auth_kwargs = kernel_auth_kwargs(self._auth_provider, self._auth_options)
# Translate the connector's SSLOptions into the kernel's
# ``tls_*`` Session kwargs. Empty when TLS is left at defaults.
tls_kwargs = _kernel_tls_kwargs(self._ssl_options)
# Translate the connector's ``_retry_*`` kwargs into the kernel's
# ``retry_*`` Session kwargs. Empty when retry is left at defaults.
retry_kwargs = _kernel_retry_kwargs(self._retry_options)
# The kwarg builds run INSIDE the try so the ``finally`` scrub
# below always fires — including when ``kernel_auth_kwargs``
# itself raises mid-build (e.g. an OAuth token-exchange failure
# while the M2M secret is in hand). Pre-declared empty so the
# ``finally`` can reference them unconditionally even on an early
# raise. Building here (not in ``__init__``) keeps the bearer
# token's in-process lifetime as short as possible.
auth_kwargs: Dict[str, Any] = {}
tls_kwargs: Dict[str, Any] = {}
try:
auth_kwargs = kernel_auth_kwargs(self._auth_provider, self._auth_options)
# Translate the connector's SSLOptions into the kernel's
# ``tls_*`` Session kwargs. Empty when TLS is at defaults.
tls_kwargs = _kernel_tls_kwargs(self._ssl_options)
# Translate the connector's ``_retry_*`` kwargs into the
# kernel's ``retry_*`` kwargs. Empty when at defaults.
retry_kwargs = _kernel_retry_kwargs(self._retry_options)
# Forward caller / connector HTTP headers. The kernel applies
# them on every request; a caller ``User-Agent`` is appended
# to the kernel's base UA. Only pass the kwarg when there's
# something to send.
#
# We drop ``Authorization`` and ``x-databricks-org-id`` here,
# before they reach the kernel, for two reasons: (1) the
# kernel manages both itself (auth from the provider; org-id
# re-derived from the ``?o=`` in http_path), so forwarding
# them is redundant; (2) the kernel skips-and-warns those two
# names on every request, so forwarding the SPOG org-id the
# connector always injects would spam a warning per request.
# This double-walls the kernel's own reserved-name skip.
http_headers_kwargs: Dict[str, Any] = {}
if self._http_headers:
forwarded = [
(str(k), str(v))
for k, v in self._http_headers
if str(k).lower() not in _KERNEL_MANAGED_HEADERS
]
if forwarded:
http_headers_kwargs["http_headers"] = forwarded
self._kernel_session = _kernel.Session(
host=self._server_hostname,
http_path=self._http_path,
Expand All @@ -208,6 +248,7 @@ def open_session(
**auth_kwargs,
**tls_kwargs,
**retry_kwargs,
**http_headers_kwargs,
)
except Exception as exc:
raise _wrap_kernel_exception("open_session", exc) from exc
Expand Down Expand Up @@ -304,10 +345,6 @@ def execute_command(
) -> Union["ResultSet", None]:
if self._kernel_session is None:
raise InterfaceError("Cannot execute_command without an open session.")
if query_tags:
raise NotSupportedError(
"Statement-level query_tags are not yet supported on the kernel backend."
)

try:
stmt = self._kernel_session.statement()
Expand All @@ -321,6 +358,13 @@ def execute_command(
try:
try:
stmt.set_sql(operation)
if query_tags:
# Per-statement query tags. The kernel serialises the
# dict (None value -> bare key) into the SEA
# `query_tags` statement conf. ``query_tags`` is
# already ``Dict[str, Optional[str]]`` from the
# connector, which the kernel accepts directly.
stmt.set_query_tags(query_tags)
if parameters:
bind_tspark_params(stmt, parameters)
if async_op:
Expand Down
29 changes: 29 additions & 0 deletions tests/e2e/test_kernel_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,32 @@ def test_parameterized_query_decimal(conn):
rows = cur.fetchall()
# Server echoes back as decimal.Decimal.
assert str(rows[0][0]) == "-123.45"


def test_query_tags_round_trip(kernel_conn_params):
"""Per-statement query_tags are forwarded to the kernel and accepted
by the server. Smoke-level: a malformed query_tags conf would fail
the execute. (query.history ingestion lag makes a sync tag-readback
assertion infeasible.)"""
with sql.connect(**kernel_conn_params) as c:
with c.cursor() as cur:
cur.execute(
"SELECT 1 AS n",
query_tags={"team": "platform", "production": None},
)
assert cur.fetchall()[0][0] == 1


def test_user_agent_entry_and_http_headers_round_trip(kernel_conn_params):
"""A connection with user_agent_entry (folded into the connector's
User-Agent, then appended to the kernel base UA) and a custom HTTP
header opens and queries cleanly. Replacing the kernel base UA would
break the SEA result disposition (HTTP 400); appending preserves it
— this exercises that end-to-end."""
params = dict(kernel_conn_params)
params["user_agent_entry"] = "kernel-e2e-app"
params["http_headers"] = [("X-Kernel-E2E", "yes")]
with sql.connect(**params) as c:
with c.cursor() as cur:
cur.execute("SELECT 1 AS n")
assert cur.fetchall()[0][0] == 1
119 changes: 104 additions & 15 deletions tests/unit/test_kernel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,26 +332,43 @@ def test_execute_command_forwards_parameters_to_bind_param():
assert stmt.execute.called


def test_execute_command_rejects_query_tags():
def test_execute_command_forwards_query_tags():
"""Statement-level query_tags are forwarded to the kernel statement
via set_query_tags (the kernel serialises them into the SEA
query_tags conf). Previously rejected with NotSupportedError; now
wired (kernel PR adding Statement.set_query_tags)."""
c = _make_client()
c._kernel_session = MagicMock()
cursor = MagicMock()
cursor.arraysize = 100
cursor.buffer_size_bytes = 1024
with pytest.raises(NotSupportedError, match="query_tags"):
c.execute_command(
operation="SELECT 1",
session_id=MagicMock(),
max_rows=1,
max_bytes=1,
lz4_compression=False,
cursor=cursor,
use_cloud_fetch=False,
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
query_tags={"team": "x"},
)

stmt = MagicMock()
stmt.set_sql = MagicMock()
stmt.set_query_tags = MagicMock()
stmt.execute.return_value = MagicMock(
statement_id="stmt-id",
arrow_schema=MagicMock(return_value=pa.schema([("x", pa.int64())])),
)
c._kernel_session.statement.return_value = stmt

tags = {"team": "platform", "production": None}
c.execute_command(
operation="SELECT 1",
session_id=MagicMock(),
max_rows=1,
max_bytes=1,
lz4_compression=False,
cursor=cursor,
use_cloud_fetch=False,
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
query_tags=tags,
)

stmt.set_query_tags.assert_called_once_with(tags)
assert stmt.execute.called


def test_get_columns_accepts_none_catalog():
Expand Down Expand Up @@ -1015,3 +1032,75 @@ def test_retry_delay_default_has_no_mapping(self):
# recognised key here — it has no kernel equivalent.
out = kernel_client._kernel_retry_kwargs({"retry_delay_default": 5.0})
assert out == {}


class TestKernelHttpHeadersForwarding:
"""http_headers (the connector's caller headers + composed
User-Agent + SPOG org-id) are forwarded to the kernel Session as the
``http_headers`` kwarg. The kernel applies them per request (its own
Authorization / org-id win; a caller User-Agent is appended to the
kernel base UA)."""

def _open_capturing(self, monkeypatch, http_headers):
captured = {}

def fake_session(**kw):
captured.update(kw)
sess = MagicMock()
sess.session_id = "sess-id"
return sess

monkeypatch.setattr(kernel_client._kernel, "Session", fake_session)
c = kernel_client.KernelDatabricksClient(
server_hostname="example.cloud.databricks.com",
http_path="/sql/1.0/warehouses/abc",
auth_provider=AccessTokenAuthProvider("dapi-test"),
ssl_options=None,
http_headers=http_headers,
)
c.open_session(session_configuration=None, catalog=None, schema=None)
return captured

def test_http_headers_forwarded_to_kernel_session(self, monkeypatch):
headers = [
("User-Agent", "PyDatabricksSqlConnector/4.0 (myentry)"),
("X-Custom", "v1"),
]
captured = self._open_capturing(monkeypatch, headers)
assert captured.get("http_headers") == [
("User-Agent", "PyDatabricksSqlConnector/4.0 (myentry)"),
("X-Custom", "v1"),
]

def test_no_http_headers_omits_kwarg(self, monkeypatch):
# Empty/none headers → the kwarg isn't passed at all (kernel
# keeps its defaults).
captured = self._open_capturing(monkeypatch, [])
assert "http_headers" not in captured

def test_authorization_and_org_id_dropped_before_forwarding(self, monkeypatch):
# The connector must NOT forward Authorization / x-databricks-org-id
# to the kernel — the kernel manages both (and warns per request
# if it sees them). Double-walls the kernel's own skip.
headers = [
("Authorization", "Bearer should-not-forward"),
("X-Databricks-Org-Id", "12345"),
("User-Agent", "PyDatabricksSqlConnector/4.0 (e)"),
("X-Keep", "yes"),
]
captured = self._open_capturing(monkeypatch, headers)
fwd = captured.get("http_headers")
names = {n.lower() for n, _ in fwd}
assert "authorization" not in names
assert "x-databricks-org-id" not in names
# Non-reserved headers (incl. User-Agent) still forwarded.
assert ("User-Agent", "PyDatabricksSqlConnector/4.0 (e)") in fwd
assert ("X-Keep", "yes") in fwd

def test_only_reserved_headers_omits_kwarg(self, monkeypatch):
# If the only headers are reserved ones, nothing is forwarded
# and the kwarg is omitted entirely.
captured = self._open_capturing(
monkeypatch, [("Authorization", "Bearer x"), ("x-databricks-org-id", "1")]
)
assert "http_headers" not in captured
49 changes: 49 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,52 @@ def test_retry_kwargs_threaded_into_kernel_client(self):
assert opts["retry_stop_after_attempts_duration"] == 600.0
finally:
conn.close()


class TestKernelUserAgentForwarding:
"""user_agent_entry must reach the kernel on the use_kernel path —
session.py folds it into the composed User-Agent and includes it in
all_headers, which is passed to the kernel client as http_headers.
Guards against a regression where session.py stops folding it under
use_kernel=True (which would silently drop partner attribution)."""

PACKAGE = "databricks.sql"

def test_user_agent_entry_reaches_kernel_client_http_headers(self):
import sys
import types

pytest.importorskip(
"pyarrow", reason="kernel client module imports pyarrow at load"
)

fake = types.ModuleType("databricks_sql_kernel")
fake.KernelError = type("KernelError", (Exception,), {})
fake.Session = MagicMock()

with patch.dict(sys.modules, {"databricks_sql_kernel": fake}), patch(
"databricks.sql.backend.kernel.client.KernelDatabricksClient"
) as mock_kernel_client, patch(
"%s.session.get_python_sql_connector_auth_provider" % self.PACKAGE
):
instance = mock_kernel_client.return_value
instance.open_session.return_value = SessionId(
BackendType.SEA, "sess-id", None
)

conn = databricks.sql.connect(
server_hostname="foo",
http_path="/sql/1.0/warehouses/abc",
use_kernel=True,
access_token="dapi-xyz",
enable_telemetry=False,
user_agent_entry="my-partner-app",
)
try:
_, kwargs = mock_kernel_client.call_args
# http_headers carries a User-Agent that embeds the entry.
headers = dict(kwargs["http_headers"])
ua = headers.get("User-Agent", "")
assert "my-partner-app" in ua, f"UA was {ua!r}"
finally:
conn.close()
Loading