diff --git a/hotdata_runtime/client.py b/hotdata_runtime/client.py index 4665f78..66ddc0b 100644 --- a/hotdata_runtime/client.py +++ b/hotdata_runtime/client.py @@ -475,11 +475,20 @@ def _wait_result_ready( f"(last status: {getattr(last, 'status', None)})" ) - def execute_sql(self, sql: str) -> QueryResult: + def execute_sql(self, sql: str, *, database: str | None = None) -> QueryResult: + """Execute SQL and return a :class:`QueryResult`. + + Pass ``database`` to scope the query to a managed database. The name + is resolved to a database ID once before the retry loop, and the + ``X-Database-Id`` header is sent with every attempt. Inside a managed + database the built-in catalog is always ``"default"``, so table + references should use ``"default"."".""``. + """ + database_id = self.resolve_managed_database(database).id if database else None last_err: BaseException | None = None for attempt in range(3): try: - return self._execute_sql_once(sql) + return self._execute_sql_once(sql, database_id=database_id) except (ProtocolError, ConnectionResetError, Urllib3HTTPError) as e: last_err = e if attempt == 2: @@ -487,10 +496,13 @@ def execute_sql(self, sql: str) -> QueryResult: time.sleep(0.2 * (2**attempt)) raise last_err # pragma: no cover - def _execute_sql_once(self, sql: str) -> QueryResult: + def _execute_sql_once(self, sql: str, *, database_id: str | None = None) -> QueryResult: q = self._query_api() try: - raw = q.query(QueryRequest(sql=sql)) + if database_id: + raw = q.query(QueryRequest(sql=sql), x_database_id=database_id) + else: + raw = q.query(QueryRequest(sql=sql)) except ApiException as e: raise RuntimeError(e.reason or str(e)) from e diff --git a/tests/test_client.py b/tests/test_client.py index 9d3a025..d70f22b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -200,6 +200,67 @@ def list_results(self, *, limit: int, offset: int): assert out[0].to_dict()["created_at"] == "2026-01-01T00:00:00Z" +def test_execute_sql_sends_no_database_id_by_default(): + from hotdata.models.query_response import QueryResponse as _QR + + client = HotdataClient("k", "ws", host="https://api.hotdata.dev") + + class FakeQueryApi: + def __init__(self): + self.calls: list[dict] = [] + + def query(self, request, **kwargs): + self.calls.append(kwargs) + return _QR( + columns=["n"], + rows=[[1]], + row_count=1, + nullable=[False], + result_id="res_1", + query_run_id="qrun_1", + execution_time_ms=1, + ) + + fake_q = FakeQueryApi() + with patch.object(client, "_query_api", return_value=fake_q): + client.execute_sql("SELECT 1") + + assert fake_q.calls == [{}] + + +def test_execute_sql_resolves_database_and_sends_x_database_id(): + from hotdata.models.query_response import QueryResponse as _QR + from types import SimpleNamespace + + client = HotdataClient("k", "ws", host="https://api.hotdata.dev") + + class FakeQueryApi: + def __init__(self): + self.calls: list[dict] = [] + + def query(self, request, **kwargs): + self.calls.append(kwargs) + return _QR( + columns=["n"], + rows=[[1]], + row_count=1, + nullable=[False], + result_id="res_1", + query_run_id="qrun_1", + execution_time_ms=1, + ) + + fake_q = FakeQueryApi() + fake_db = SimpleNamespace(id="db_abc") + + with patch.object(client, "_query_api", return_value=fake_q), \ + patch.object(client, "resolve_managed_database", return_value=fake_db) as resolve: + client.execute_sql('SELECT * FROM "default"."public"."orders"', database="my_db") + + resolve.assert_called_once_with("my_db") + assert fake_q.calls == [{"x_database_id": "db_abc"}] + + def test_list_run_history_returns_normalized_items(): client = HotdataClient("k", "ws", host="https://api.hotdata.dev") listing = SimpleNamespace(