71 changes: 40 additions & 31 deletions src/databricks/labs/lsql/core.py
@@ -147,10 +147,10 @@ def __init__( # pylint: disable=too-many-arguments
         self._disable_magic = disable_magic
         self._byte_limit = byte_limit
         self._disposition = disposition
-        self._type_converters = {
+        self._type_converters: dict[ColumnInfoTypeName, Callable[[str], Any]] = {
             ColumnInfoTypeName.ARRAY: json.loads,
             ColumnInfoTypeName.BINARY: base64.b64decode,
-            ColumnInfoTypeName.BOOLEAN: bool,
+            ColumnInfoTypeName.BOOLEAN: lambda value: value.lower() == "true",
             ColumnInfoTypeName.CHAR: str,
             ColumnInfoTypeName.DATE: self._parse_date,
             ColumnInfoTypeName.DOUBLE: float,
@@ -318,11 +318,11 @@ def fetch_all(
         result_data = execute_response.result
         if result_data is None:
             return
-        row_factory, col_conv = self._result_schema(execute_response)
+        converter = self._result_converter(execute_response)
         while True:
             if result_data.data_array:
                 for data in result_data.data_array:
-                    yield row_factory(col_conv[i](value) for i, value in enumerate(data))
+                    yield converter(data)
             next_chunk_index = result_data.next_chunk_index
             if result_data.external_links:
                 for external_link in result_data.external_links:
@@ -331,7 +331,7 @@ def fetch_all(
                     response = self._http.get(external_link.external_link)
                     response.raise_for_status()
                     for data in response.json():
-                        yield row_factory(col_conv[i](value) for i, value in enumerate(data))
+                        yield converter(data)
             if not next_chunk_index:
                 return
             result_data = self._ws.statement_execution.get_statement_result_chunk_n(
@@ -373,6 +373,41 @@ def fetch_value(self, statement: str, **kwargs) -> Any | None:
             return v
         return None
 
+    def _result_converter(self, execute_response: ExecuteStatementResponse):
+        """Build a row converter from the result schema in the execute response."""
+        manifest = execute_response.manifest
+        if not manifest:
+            msg = f"missing manifest: {execute_response}"
+            raise ValueError(msg)
+        manifest_schema = manifest.schema
+        if not manifest_schema:
+            msg = f"missing schema: {manifest}"
+            raise ValueError(msg)
+        col_names = []
+        col_conv: list[Callable[[str], Any]] = []
+        columns = manifest_schema.columns
+        if not columns:
+            columns = []
+        for col in columns:
+            assert col.name is not None
+            col_names.append(col.name)
+            type_name = col.type_name
+            if not type_name:
+                type_name = ColumnInfoTypeName.NULL
+            conv = self._type_converters.get(type_name, None)
+            if conv is None:
+                msg = f"{col.name} has no {type_name.value} converter"
+                raise ValueError(msg)
+            col_conv.append(conv)
+        row_factory = Row.factory(col_names)
+
+        def converter(data: list[str | None]) -> Row:
+            # enumerate() + iterator + tuple constructor is more performant in Python
+            # for larger numbers of records, even though the code is far less readable.
+            return row_factory(col_conv[i](value) if value else None for i, value in enumerate(data))
+
+        return converter
+
     def _statement_timeouts(self, timeout) -> tuple[timedelta, str | None]:
         """Set server-side and client-side timeouts for statement execution."""
         if timeout is None:
@@ -481,29 +516,3 @@ def _add_limit(statement: str) -> str:
raise ValueError(f"limit is not 1: {limit.text('expression')}")
return statement_ast.limit(expression=1).sql("databricks")
return statement

def _result_schema(self, execute_response: ExecuteStatementResponse):
"""Get the result schema from the execute response."""
manifest = execute_response.manifest
if not manifest:
msg = f"missing manifest: {execute_response}"
raise ValueError(msg)
manifest_schema = manifest.schema
if not manifest_schema:
msg = f"missing schema: {manifest}"
raise ValueError(msg)
col_names = []
col_conv = []
columns = manifest_schema.columns
if not columns:
columns = []
for col in columns:
assert col.name is not None
assert col.type_name is not None
col_names.append(col.name)
conv = self._type_converters.get(col.type_name, None)
if conv is None:
msg = f"{col.name} has no {col.type_name.value} converter"
raise ValueError(msg)
col_conv.append(conv)
return Row.factory(col_names), col_conv
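
The BOOLEAN converter change in the first hunk is worth spelling out: result cells arrive as strings (the converter map is typed `Callable[[str], Any]`), and Python's built-in `bool()` treats any non-empty string as truthy, so `bool("false")` would yield `True`. A minimal standalone sketch of the difference (the `to_bool` name is just for illustration, not part of the PR):

# Why bool() is the wrong converter for string-encoded booleans: any non-empty
# string is truthy, so "false" would round-trip to True.
assert bool("false") is True
assert bool("true") is True

# The replacement converter compares the lowered text instead.
def to_bool(value: str) -> bool:
    return value.lower() == "true"

assert to_bool("false") is False
assert to_bool("TRUE") is True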
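
For readers outside the diff context, here is a self-contained sketch of the converter-closure pattern that `_result_converter` introduces: column converters are resolved once per statement, then a closure applies them to every row, passing empty or NULL cells through as `None`. The column names and converters below are made up for illustration, and `Row.factory` from `databricks.labs.lsql.core` is replaced with a `namedtuple` purely so the sketch runs on its own; the real method builds both from the result manifest.

import json
from collections import namedtuple
from typing import Any, Callable

# Hypothetical schema: three columns with string-encoded values.
col_names = ["id", "active", "tags"]
col_conv: list[Callable[[str], Any]] = [int, lambda v: v.lower() == "true", json.loads]
Row = namedtuple("Row", col_names)  # stand-in for Row.factory(col_names)

def converter(data: list[str | None]) -> Row:
    # Empty or missing cells become None; everything else goes through its column converter.
    return Row(*(col_conv[i](value) if value else None for i, value in enumerate(data)))

print(converter(["42", "true", '["a", "b"]']))  # Row(id=42, active=True, tags=['a', 'b'])
print(converter(["7", None, ""]))               # Row(id=7, active=None, tags=None)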
2 changes: 1 addition & 1 deletion tests/integration/test_core.py
@@ -3,7 +3,7 @@
 import pytest
 from databricks.sdk.service.sql import Disposition
 
-from databricks.labs.lsql.core import StatementExecutionExt, Row
+from databricks.labs.lsql.core import Row, StatementExecutionExt
 
 logger = logging.getLogger(__name__)
 