From dee8dd9707b5c46fcfc377c97756eadc24fd1c49 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Wed, 13 Mar 2024 15:59:37 +0100 Subject: [PATCH] Fixed row converter to properly handle nullable values --- src/databricks/labs/lsql/core.py | 71 ++++++++++++++++++-------------- tests/integration/test_core.py | 2 +- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/databricks/labs/lsql/core.py b/src/databricks/labs/lsql/core.py index e4feca02..0c008b90 100644 --- a/src/databricks/labs/lsql/core.py +++ b/src/databricks/labs/lsql/core.py @@ -147,10 +147,10 @@ def __init__( # pylint: disable=too-many-arguments self._disable_magic = disable_magic self._byte_limit = byte_limit self._disposition = disposition - self._type_converters = { + self._type_converters: dict[ColumnInfoTypeName, Callable[[str], Any]] = { ColumnInfoTypeName.ARRAY: json.loads, ColumnInfoTypeName.BINARY: base64.b64decode, - ColumnInfoTypeName.BOOLEAN: bool, + ColumnInfoTypeName.BOOLEAN: lambda value: value.lower() == "true", ColumnInfoTypeName.CHAR: str, ColumnInfoTypeName.DATE: self._parse_date, ColumnInfoTypeName.DOUBLE: float, @@ -318,11 +318,11 @@ def fetch_all( result_data = execute_response.result if result_data is None: return - row_factory, col_conv = self._result_schema(execute_response) + converter = self._result_converter(execute_response) while True: if result_data.data_array: for data in result_data.data_array: - yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + yield converter(data) next_chunk_index = result_data.next_chunk_index if result_data.external_links: for external_link in result_data.external_links: @@ -331,7 +331,7 @@ def fetch_all( response = self._http.get(external_link.external_link) response.raise_for_status() for data in response.json(): - yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + yield converter(data) if not next_chunk_index: return result_data = self._ws.statement_execution.get_statement_result_chunk_n( 
@@ -373,6 +373,41 @@ def fetch_value(self, statement: str, **kwargs) -> Any | None: return v return None + def _result_converter(self, execute_response: ExecuteStatementResponse): + """Build a row converter from the result schema in the execute response.""" + manifest = execute_response.manifest + if not manifest: + msg = f"missing manifest: {execute_response}" + raise ValueError(msg) + manifest_schema = manifest.schema + if not manifest_schema: + msg = f"missing schema: {manifest}" + raise ValueError(msg) + col_names = [] + col_conv: list[Callable[[str], Any]] = [] + columns = manifest_schema.columns + if not columns: + columns = [] + for col in columns: + assert col.name is not None + col_names.append(col.name) + type_name = col.type_name + if not type_name: + type_name = ColumnInfoTypeName.NULL + conv = self._type_converters.get(type_name, None) + if conv is None: + msg = f"{col.name} has no {type_name.value} converter" + raise ValueError(msg) + col_conv.append(conv) + row_factory = Row.factory(col_names) + + def converter(data: list[str | None]) -> Row: + # enumerate() + iterator + tuple constructor makes it more performant on a larger number of records + # for Python, even though it's far less readable code.
+ return row_factory(col_conv[i](value) if value is not None else None for i, value in enumerate(data)) + + return converter + def _statement_timeouts(self, timeout) -> tuple[timedelta, str | None]: """Set server-side and client-side timeouts for statement execution.""" if timeout is None: @@ -481,29 +516,3 @@ def _add_limit(statement: str) -> str: raise ValueError(f"limit is not 1: {limit.text('expression')}") return statement_ast.limit(expression=1).sql("databricks") return statement - - def _result_schema(self, execute_response: ExecuteStatementResponse): - """Get the result schema from the execute response.""" - manifest = execute_response.manifest - if not manifest: - msg = f"missing manifest: {execute_response}" - raise ValueError(msg) - manifest_schema = manifest.schema - if not manifest_schema: - msg = f"missing schema: {manifest}" - raise ValueError(msg) - col_names = [] - col_conv = [] - columns = manifest_schema.columns - if not columns: - columns = [] - for col in columns: - assert col.name is not None - assert col.type_name is not None - col_names.append(col.name) - conv = self._type_converters.get(col.type_name, None) - if conv is None: - msg = f"{col.name} has no {col.type_name.value} converter" - raise ValueError(msg) - col_conv.append(conv) - return Row.factory(col_names), col_conv diff --git a/tests/integration/test_core.py b/tests/integration/test_core.py index 38093a33..375a03b7 100644 --- a/tests/integration/test_core.py +++ b/tests/integration/test_core.py @@ -3,7 +3,7 @@ import pytest from databricks.sdk.service.sql import Disposition -from databricks.labs.lsql.core import StatementExecutionExt, Row +from databricks.labs.lsql.core import Row, StatementExecutionExt logger = logging.getLogger(__name__)