Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions hotdata_runtime/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,22 +475,34 @@ def _wait_result_ready(
f"(last status: {getattr(last, 'status', None)})"
)

def execute_sql(self, sql: str) -> QueryResult:
def execute_sql(self, sql: str, *, database: str | None = None) -> QueryResult:
"""Execute SQL and return a :class:`QueryResult`.

Pass ``database`` to scope the query to a managed database. The name
is resolved to a database ID once before the retry loop, and the
``X-Database-Id`` header is sent with every attempt. Inside a managed
database the built-in catalog is always ``"default"``, so table
references should use ``"default"."<schema>"."<table>"``.
"""
database_id = self.resolve_managed_database(database).id if database else None
last_err: BaseException | None = None
for attempt in range(3):
try:
return self._execute_sql_once(sql)
return self._execute_sql_once(sql, database_id=database_id)
except (ProtocolError, ConnectionResetError, Urllib3HTTPError) as e:
last_err = e
if attempt == 2:
raise
time.sleep(0.2 * (2**attempt))
raise last_err # pragma: no cover

def _execute_sql_once(self, sql: str) -> QueryResult:
def _execute_sql_once(self, sql: str, *, database_id: str | None = None) -> QueryResult:
q = self._query_api()
try:
raw = q.query(QueryRequest(sql=sql))
if database_id:
raw = q.query(QueryRequest(sql=sql), x_database_id=database_id)
else:
raw = q.query(QueryRequest(sql=sql))
except ApiException as e:
raise RuntimeError(e.reason or str(e)) from e

Expand Down
61 changes: 61 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,67 @@ def list_results(self, *, limit: int, offset: int):
assert out[0].to_dict()["created_at"] == "2026-01-01T00:00:00Z"


def test_execute_sql_sends_no_database_id_by_default():
from hotdata.models.query_response import QueryResponse as _QR

client = HotdataClient("k", "ws", host="https://api.hotdata.dev")

class FakeQueryApi:
def __init__(self):
self.calls: list[dict] = []

def query(self, request, **kwargs):
self.calls.append(kwargs)
return _QR(
columns=["n"],
rows=[[1]],
row_count=1,
nullable=[False],
result_id="res_1",
query_run_id="qrun_1",
execution_time_ms=1,
)

fake_q = FakeQueryApi()
with patch.object(client, "_query_api", return_value=fake_q):
client.execute_sql("SELECT 1")

assert fake_q.calls == [{}]


def test_execute_sql_resolves_database_and_sends_x_database_id():
from hotdata.models.query_response import QueryResponse as _QR
from types import SimpleNamespace

client = HotdataClient("k", "ws", host="https://api.hotdata.dev")

class FakeQueryApi:
def __init__(self):
self.calls: list[dict] = []

def query(self, request, **kwargs):
self.calls.append(kwargs)
return _QR(
columns=["n"],
rows=[[1]],
row_count=1,
nullable=[False],
result_id="res_1",
query_run_id="qrun_1",
execution_time_ms=1,
)

fake_q = FakeQueryApi()
fake_db = SimpleNamespace(id="db_abc")

with patch.object(client, "_query_api", return_value=fake_q), \
patch.object(client, "resolve_managed_database", return_value=fake_db) as resolve:
client.execute_sql('SELECT * FROM "default"."public"."orders"', database="my_db")

resolve.assert_called_once_with("my_db")
assert fake_q.calls == [{"x_database_id": "db_abc"}]


def test_list_run_history_returns_normalized_items():
client = HotdataClient("k", "ws", host="https://api.hotdata.dev")
listing = SimpleNamespace(
Expand Down
Loading