diff --git a/docs/guides/adapters/parameter-profile-registry.md b/docs/guides/adapters/parameter-profile-registry.md
index 4a6f2e68..8d031d76 100644
--- a/docs/guides/adapters/parameter-profile-registry.md
+++ b/docs/guides/adapters/parameter-profile-registry.md
@@ -17,7 +17,7 @@ The table below summarizes the canonical `DriverParameterProfile` entries define
 
 ## Adding or Updating Profiles
 
-1. Define the profile in `_registry.py` using lowercase key naming.
+1. Define the profile in `_registry.py` using lowercase key naming. Ensure the adapter package (e.g., `sqlspec.adapters.duckdb.__init__`) imports the driver module so `register_driver_profile` executes during normal adapter imports; the registry does not perform lazy imports.
 2. Pick the JSON strategy that matches driver capabilities (`helper`, `driver`, or `none`).
 3. Declare extras as an immutable mapping; document each addition in this file and the relevant adapter guide.
 4. Add or update regression coverage (see `specs/archive/driver-quality-review/research/testing_deliverables.md`).
@@ -28,6 +28,8 @@ Refer to [AGENTS.md](../../AGENTS.md) for the full checklist when touching the r
 ## Example Usage
 
 ```python
+import sqlspec.adapters.duckdb # Triggers profile registration
+
 from sqlspec.core.parameters import get_driver_profile, build_statement_config_from_profile
 
 profile = get_driver_profile("duckdb")
@@ -36,4 +38,4 @@ config = build_statement_config_from_profile(profile)
 print(config.parameter_config.default_parameter_style)
 ```
 
-The snippet above retrieves the DuckDB profile, builds a `StatementConfig`, and prints the default parameter style (`?`). Use the same pattern for new adapters after defining their profiles.
+The snippet above imports the DuckDB adapter package (which registers its profile), retrieves the profile, builds a `StatementConfig`, and prints the default parameter style (`?`). Use the same pattern for new adapters after defining their profiles.
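To make the registration contract in the guide above concrete, here is a minimal sketch assuming only what this change set shows: adapter packages register their profiles as an import side effect, and `get_driver_profile` raises `ImproperConfigurationError` for unknown keys (see the `_registry.py` hunk further down). The `sqlite` key is only an example, the qmark (`?`) expectation for its profile is an assumption, and `unregistered-adapter` is a deliberately made-up key.

```python
# Sketch only: exercises the eager-registration behaviour documented above.
import sqlspec.adapters.sqlite  # noqa: F401  # import side effect registers the "sqlite" profile

from sqlspec.core.parameters import build_statement_config_from_profile, get_driver_profile
from sqlspec.exceptions import ImproperConfigurationError

profile = get_driver_profile("sqlite")  # lookup lowercases the key before consulting the registry
config = build_statement_config_from_profile(profile)
print(config.parameter_config.default_parameter_style)  # expected to be the qmark style ("?") for sqlite

try:
    get_driver_profile("unregistered-adapter")  # hypothetical key that was never registered
except ImproperConfigurationError as exc:
    print(exc)  # "No driver parameter profile registered for adapter 'unregistered-adapter'."
```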
diff --git a/pyproject.toml b/pyproject.toml index 4eddb588..e41c8faa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -462,6 +462,7 @@ ignore = [ "PLR6301", # method could be static or class method "B903", # class could be a dataclass or named tuple "PLW0603", # Using the global statement to update is discouraged + "PLW0108", # Replace lambda expression with a def ] select = ["ALL"] diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index b3ff6503..7e9b8ba4 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -7,7 +7,6 @@ import contextlib import datetime import decimal -from functools import partial from typing import TYPE_CHECKING, Any, cast from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary @@ -16,10 +15,10 @@ from sqlspec.core.parameters import ( DriverParameterProfile, ParameterStyle, + build_null_pruning_transform, build_statement_config_from_profile, get_driver_profile, register_driver_profile, - replace_null_parameters_with_literals, ) from sqlspec.core.result import create_arrow_result from sqlspec.core.statement import SQL, StatementConfig @@ -758,7 +757,7 @@ def get_adbc_statement_config(detected_dialect: str) -> StatementConfig: } if detected_dialect in {"postgres", "postgresql"}: - parameter_overrides["ast_transformer"] = partial(replace_null_parameters_with_literals, dialect=sqlglot_dialect) + parameter_overrides["ast_transformer"] = build_null_pruning_transform(dialect=sqlglot_dialect) return build_statement_config_from_profile( get_driver_profile("adbc"), diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index e38c5236..8fcbc114 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -29,6 +29,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import to_json +from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter if TYPE_CHECKING: from contextlib import AbstractAsyncContextManager @@ -49,6 +50,8 @@ SQLITE_CANTOPEN_CODE = 14 SQLITE_IOERR_CODE = 10 SQLITE_MISMATCH_CODE = 20 +_TIME_TO_ISO = build_time_iso_converter() +_DECIMAL_TO_STRING = build_decimal_converter(mode="string") class AiosqliteCursor: @@ -327,18 +330,6 @@ def _bool_to_int(value: bool) -> int: return int(value) -def _datetime_to_iso(value: datetime) -> str: - return value.isoformat() - - -def _date_to_iso(value: date) -> str: - return value.isoformat() - - -def _decimal_to_str(value: Decimal) -> str: - return str(value) - - def _build_aiosqlite_profile() -> DriverParameterProfile: """Create the AIOSQLite driver parameter profile.""" @@ -356,9 +347,9 @@ def _build_aiosqlite_profile() -> DriverParameterProfile: json_serializer_strategy="helper", custom_type_coercions={ bool: _bool_to_int, - datetime: _datetime_to_iso, - date: _date_to_iso, - Decimal: _decimal_to_str, + datetime: _TIME_TO_ISO, + date: _TIME_TO_ISO, + Decimal: _DECIMAL_TO_STRING, }, default_dialect="sqlite", ) diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index 9d0f6314..483a90c4 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -6,6 +6,7 @@ import datetime import logging +from collections.abc import Callable from decimal import Decimal from typing import TYPE_CHECKING, Any, cast @@ -18,9 +19,9 @@ from sqlspec.core import ParameterStyle, StatementConfig, create_arrow_result, get_cache_config from sqlspec.core.parameters import ( DriverParameterProfile, + 
build_literal_inlining_transform, build_statement_config_from_profile, register_driver_profile, - replace_placeholders_with_literals, ) from sqlspec.driver import ExecutionResult, SyncDriverAdapterBase from sqlspec.exceptions import ( @@ -339,7 +340,13 @@ class BigQueryDriver(SyncDriverAdapterBase): type coercion, error handling, and query job management. """ - __slots__ = ("_data_dictionary", "_default_query_job_config", "_json_serializer", "_type_converter") + __slots__ = ( + "_data_dictionary", + "_default_query_job_config", + "_json_serializer", + "_literal_inliner", + "_type_converter", + ) dialect = "bigquery" def __init__( @@ -362,6 +369,7 @@ def __init__( parameter_json_serializer = features.get("json_serializer", to_json) self._json_serializer: Callable[[Any], str] = parameter_json_serializer + self._literal_inliner = build_literal_inlining_transform(json_serializer=self._json_serializer) super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._default_query_job_config: QueryJobConfig | None = (driver_features or {}).get("default_query_job_config") @@ -485,30 +493,14 @@ def _try_special_handling(self, cursor: "Any", statement: "SQL") -> "SQLResult | _ = (cursor, statement) return None - def _transform_ast_with_literals(self, sql: str, parameters: Any) -> str: - """Transform SQL AST by replacing placeholders with literal values. - - Used for BigQuery script execution and execute_many operations where - parameter binding is not supported. Safely embeds values as SQL literals. - - Args: - sql: SQL string to transform. - parameters: Parameters to embed as literals. + def _inline_literals(self, expression: "sqlglot.Expression", parameters: Any) -> str: + """Inline literal values into a parsed SQLGlot expression.""" - Returns: - Transformed SQL string with literals embedded. - """ if not parameters: - return sql + return expression.sql(dialect="bigquery") - try: - ast = sqlglot.parse_one(sql, dialect="bigquery") - except sqlglot.ParseError: - return sql - - transformed_ast = replace_placeholders_with_literals(ast, parameters, json_serializer=self._json_serializer) - - return cast("str", transformed_ast.sql(dialect="bigquery")) + transformed_expression, _ = self._literal_inliner(expression, parameters) + return cast("str", transformed_expression.sql(dialect="bigquery")) def _execute_script(self, cursor: Any, statement: "SQL") -> ExecutionResult: """Execute SQL script with statement splitting and parameter handling. 
@@ -562,10 +554,19 @@ def _execute_many(self, cursor: Any, statement: "SQL") -> ExecutionResult: base_sql = statement.sql + try: + parsed_expression = sqlglot.parse_one(base_sql, dialect="bigquery") + except sqlglot.ParseError: + parsed_expression = None + script_statements = [] for param_set in parameters_list: - transformed_sql = self._transform_ast_with_literals(base_sql, param_set) - script_statements.append(transformed_sql) + if parsed_expression is None: + script_statements.append(base_sql) + continue + + expression_copy = parsed_expression.copy() + script_statements.append(self._inline_literals(expression_copy, param_set)) script_sql = ";\n".join(script_statements) diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index e2936dae..ab03cb88 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -27,6 +27,7 @@ ) from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json +from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter if TYPE_CHECKING: from contextlib import AbstractContextManager @@ -48,6 +49,9 @@ logger = get_logger("adapters.duckdb") +_TIME_TO_ISO = build_time_iso_converter() +_DECIMAL_TO_STRING = build_decimal_converter(mode="string") + _type_converter = DuckDBTypeConverter() @@ -471,18 +475,6 @@ def _bool_to_int(value: bool) -> int: return int(value) -def _datetime_to_iso(value: datetime) -> str: - return value.isoformat() - - -def _date_to_iso(value: date) -> str: - return value.isoformat() - - -def _decimal_to_str(value: Decimal) -> str: - return str(value) - - def _build_duckdb_profile() -> DriverParameterProfile: """Create the DuckDB driver parameter profile.""" @@ -500,9 +492,9 @@ def _build_duckdb_profile() -> DriverParameterProfile: json_serializer_strategy="helper", custom_type_coercions={ bool: _bool_to_int, - datetime: _datetime_to_iso, - date: _date_to_iso, - Decimal: _decimal_to_str, + datetime: _TIME_TO_ISO, + date: _TIME_TO_ISO, + Decimal: _DECIMAL_TO_STRING, }, default_dialect="duckdb", ) diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index 3aae48cb..981ee19d 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -41,6 +41,7 @@ from sqlspec.typing import Empty from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json +from sqlspec.utils.type_converters import build_nested_decimal_normalizer if TYPE_CHECKING: from collections.abc import Callable @@ -73,6 +74,7 @@ "TIMESTAMP WITHOUT TIME ZONE", }) _UUID_CASTS: Final[frozenset[str]] = frozenset({"UUID"}) +_DECIMAL_NORMALIZER = build_nested_decimal_normalizer(mode="float") class PsqlpyCursor: @@ -520,18 +522,6 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase": return self._data_dictionary -def _convert_decimals_in_structure(obj: Any) -> Any: - """Recursively convert Decimal values to float in nested structures.""" - - if isinstance(obj, decimal.Decimal): - return float(obj) - if isinstance(obj, dict): - return {k: _convert_decimals_in_structure(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - return [_convert_decimals_in_structure(item) for item in obj] - return obj - - def _coerce_json_parameter(value: Any, cast_type: str, serializer: "Callable[[Any], str]") -> Any: """Serialize JSON parameters according to the detected cast type. 
@@ -640,16 +630,16 @@ def _coerce_parameter_for_cast(value: Any, cast_type: str, serializer: "Callable def _prepare_dict_parameter(value: "dict[str, Any]") -> dict[str, Any]: - normalized = _convert_decimals_in_structure(value) + normalized = _DECIMAL_NORMALIZER(value) return normalized if isinstance(normalized, dict) else value def _prepare_list_parameter(value: "list[Any]") -> list[Any]: - return [_convert_decimals_in_structure(item) for item in value] + return [_DECIMAL_NORMALIZER(item) for item in value] def _prepare_tuple_parameter(value: "tuple[Any, ...]") -> tuple[Any, ...]: - return tuple(_convert_decimals_in_structure(item) for item in value) + return tuple(_DECIMAL_NORMALIZER(item) for item in value) def _normalize_scalar_parameter(value: Any) -> Any: diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index f80b06b1..ae1b2ba9 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -47,6 +47,7 @@ ) from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json +from sqlspec.utils.type_converters import build_json_list_converter, build_json_tuple_converter if TYPE_CHECKING: from collections.abc import Callable @@ -822,57 +823,6 @@ def _identity(value: Any) -> Any: return value -def _convert_list_to_postgres_array(value: Any) -> str: - """Convert a Python list to PostgreSQL array literal format.""" - - if not isinstance(value, list): - return str(value) - - elements: list[str] = [] - for item in value: - if isinstance(item, list): - elements.append(_convert_list_to_postgres_array(item)) - elif isinstance(item, str): - escaped = item.replace("'", "''") - elements.append(f"'{escaped}'") - elif item is None: - elements.append("NULL") - else: - elements.append(str(item)) - - return "{" + ",".join(elements) + "}" - - -def _should_serialize_list(value: "list[Any]") -> bool: - """Detect whether a list should be serialized to JSON.""" - - return any(isinstance(item, (dict, list, tuple)) for item in value) - - -def _build_list_parameter_converter(serializer: "Callable[[Any], str]") -> "Callable[[list[Any]], Any]": - """Create converter that serializes nested lists while preserving arrays.""" - - def convert(value: "list[Any]") -> Any: - if not value: - return value - if _should_serialize_list(value): - return serializer(value) - return value - - return convert - - -def _build_tuple_parameter_converter(serializer: "Callable[[Any], str]") -> "Callable[[tuple[Any, ...]], Any]": - """Create converter mirroring list handling for tuple parameters.""" - - list_converter = _build_list_parameter_converter(serializer) - - def convert(value: "tuple[Any, ...]") -> Any: - return list_converter(list(value)) - - return convert - - def _build_psycopg_custom_type_coercions() -> dict[type, "Callable[[Any], Any]"]: """Return custom type coercions for psycopg.""" @@ -919,8 +869,8 @@ def _create_psycopg_parameter_config(serializer: "Callable[[Any], str]") -> Para base_config = build_statement_config_from_profile(_PSYCOPG_PROFILE, json_serializer=serializer).parameter_config updated_type_map = dict(base_config.type_coercion_map) - updated_type_map[list] = _build_list_parameter_converter(serializer) - updated_type_map[tuple] = _build_tuple_parameter_converter(serializer) + updated_type_map[list] = build_json_list_converter(serializer) + updated_type_map[tuple] = build_json_tuple_converter(serializer) return base_config.replace(type_coercion_map=updated_type_map) diff --git a/sqlspec/adapters/sqlite/driver.py 
b/sqlspec/adapters/sqlite/driver.py index 76c1921b..4ba3d25b 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -27,6 +27,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import to_json +from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter if TYPE_CHECKING: from contextlib import AbstractContextManager @@ -47,6 +48,8 @@ SQLITE_CANTOPEN_CODE = 14 SQLITE_IOERR_CODE = 10 SQLITE_MISMATCH_CODE = 20 +_TIME_TO_ISO = build_time_iso_converter() +_DECIMAL_TO_STRING = build_decimal_converter(mode="string") class SqliteCursor: @@ -393,18 +396,6 @@ def _bool_to_int(value: bool) -> int: return int(value) -def _datetime_to_iso(value: datetime) -> str: - return value.isoformat() - - -def _date_to_iso(value: date) -> str: - return value.isoformat() - - -def _decimal_to_str(value: Decimal) -> str: - return str(value) - - def _build_sqlite_profile() -> DriverParameterProfile: """Create the SQLite driver parameter profile.""" @@ -422,9 +413,9 @@ def _build_sqlite_profile() -> DriverParameterProfile: json_serializer_strategy="helper", custom_type_coercions={ bool: _bool_to_int, - datetime: _datetime_to_iso, - date: _date_to_iso, - Decimal: _decimal_to_str, + datetime: _TIME_TO_ISO, + date: _TIME_TO_ISO, + Decimal: _DECIMAL_TO_STRING, }, default_dialect="sqlite", ) diff --git a/sqlspec/core/parameters/__init__.py b/sqlspec/core/parameters/__init__.py index 9cc581fa..df965eb6 100644 --- a/sqlspec/core/parameters/__init__.py +++ b/sqlspec/core/parameters/__init__.py @@ -16,6 +16,8 @@ register_driver_profile, ) from sqlspec.core.parameters._transformers import ( + build_literal_inlining_transform, + build_null_pruning_transform, replace_null_parameters_with_literals, replace_placeholders_with_literals, ) @@ -46,6 +48,8 @@ "ParameterStyleConfig", "ParameterValidator", "TypedParameter", + "build_literal_inlining_transform", + "build_null_pruning_transform", "build_statement_config_from_profile", "collect_null_parameter_ordinals", "get_driver_profile", diff --git a/sqlspec/core/parameters/_registry.py b/sqlspec/core/parameters/_registry.py index fda05731..79e54da2 100644 --- a/sqlspec/core/parameters/_registry.py +++ b/sqlspec/core/parameters/_registry.py @@ -38,9 +38,9 @@ def get_driver_profile(adapter_key: str) -> "DriverParameterProfile": key = adapter_key.lower() try: return DRIVER_PARAMETER_PROFILES[key] - except KeyError as error: + except KeyError as exc: msg = f"No driver parameter profile registered for adapter '{adapter_key}'." 
- raise sqlspec.exceptions.ImproperConfigurationError(msg) from error + raise sqlspec.exceptions.ImproperConfigurationError(msg) from exc def register_driver_profile( diff --git a/sqlspec/core/parameters/_transformers.py b/sqlspec/core/parameters/_transformers.py index eca9c9c6..01cd1bc9 100644 --- a/sqlspec/core/parameters/_transformers.py +++ b/sqlspec/core/parameters/_transformers.py @@ -13,11 +13,39 @@ from sqlspec.core.parameters._types import ParameterProfile from sqlspec.core.parameters._validator import ParameterValidator -__all__ = ("replace_null_parameters_with_literals", "replace_placeholders_with_literals") +__all__ = ( + "build_literal_inlining_transform", + "build_null_pruning_transform", + "replace_null_parameters_with_literals", + "replace_placeholders_with_literals", +) _AST_TRANSFORMER_VALIDATOR: "ParameterValidator" = ParameterValidator() +def build_null_pruning_transform( + *, dialect: str = "postgres", validator: "ParameterValidator | None" = None +) -> "Callable[[Any, Any], tuple[Any, Any]]": + """Return a callable that prunes NULL placeholders from an expression.""" + + def transform(expression: Any, parameters: Any) -> "tuple[Any, Any]": + return replace_null_parameters_with_literals(expression, parameters, dialect=dialect, validator=validator) + + return transform + + +def build_literal_inlining_transform( + *, json_serializer: "Callable[[Any], str]" +) -> "Callable[[Any, Any], tuple[Any, Any]]": + """Return a callable that replaces placeholders with SQL literals.""" + + def transform(expression: Any, parameters: Any) -> "tuple[Any, Any]": + literal_expression = replace_placeholders_with_literals(expression, parameters, json_serializer=json_serializer) + return literal_expression, parameters + + return transform + + def replace_null_parameters_with_literals( expression: Any, parameters: Any, *, dialect: str = "postgres", validator: "ParameterValidator | None" = None ) -> "tuple[Any, Any]": diff --git a/sqlspec/core/parameters/_types.py b/sqlspec/core/parameters/_types.py index 9eedca6b..6bea30a7 100644 --- a/sqlspec/core/parameters/_types.py +++ b/sqlspec/core/parameters/_types.py @@ -1,7 +1,6 @@ """Core parameter data structures and utilities.""" from collections.abc import Callable, Collection, Generator, Mapping, Sequence -from dataclasses import dataclass, field from datetime import date, datetime, time from decimal import Decimal from enum import Enum @@ -131,6 +130,7 @@ def __repr__(self) -> str: ) +@mypyc_attr(allow_interpreted_subclasses=False) class ParameterStyleConfig: """Configuration describing parameter behaviour for a statement.""" @@ -276,36 +276,71 @@ def tuple_adapter(value: Any) -> Any: ) -@dataclass(slots=True) +@mypyc_attr(allow_interpreted_subclasses=False) class DriverParameterProfile: """Immutable adapter profile describing parameter defaults.""" - name: str - default_style: "ParameterStyle" - supported_styles: "Collection[ParameterStyle]" - default_execution_style: "ParameterStyle" - supported_execution_styles: "Collection[ParameterStyle] | None" - has_native_list_expansion: bool - preserve_parameter_format: bool - needs_static_script_compilation: bool - allow_mixed_parameter_styles: bool - preserve_original_params_for_many: bool - json_serializer_strategy: "Literal['driver', 'helper', 'none']" - custom_type_coercions: "Mapping[type, Callable[[Any], Any]]" = field(default_factory=dict) - default_output_transformer: "Callable[[str, Any], tuple[str, Any]] | None" = None - default_ast_transformer: "Callable[[Any, Any], tuple[Any, Any]] | None" 
= None - extras: "Mapping[str, Any]" = field(default_factory=dict) - default_dialect: "str | None" = None - statement_kwargs: "Mapping[str, Any]" = field(default_factory=dict) - - def __post_init__(self) -> None: - self.supported_styles = frozenset(self.supported_styles) + __slots__ = ( + "allow_mixed_parameter_styles", + "custom_type_coercions", + "default_ast_transformer", + "default_dialect", + "default_execution_style", + "default_output_transformer", + "default_style", + "extras", + "has_native_list_expansion", + "json_serializer_strategy", + "name", + "needs_static_script_compilation", + "preserve_original_params_for_many", + "preserve_parameter_format", + "statement_kwargs", + "supported_execution_styles", + "supported_styles", + ) + + def __init__( + self, + name: str, + default_style: "ParameterStyle", + supported_styles: "Collection[ParameterStyle]", + default_execution_style: "ParameterStyle", + supported_execution_styles: "Collection[ParameterStyle] | None", + has_native_list_expansion: bool, + preserve_parameter_format: bool, + needs_static_script_compilation: bool, + allow_mixed_parameter_styles: bool, + preserve_original_params_for_many: bool, + json_serializer_strategy: "Literal['driver', 'helper', 'none']", + custom_type_coercions: "Mapping[type, Callable[[Any], Any]] | None" = None, + default_output_transformer: "Callable[[str, Any], tuple[str, Any]] | None" = None, + default_ast_transformer: "Callable[[Any, Any], tuple[Any, Any]] | None" = None, + extras: "Mapping[str, Any] | None" = None, + default_dialect: "str | None" = None, + statement_kwargs: "Mapping[str, Any] | None" = None, + ) -> None: + self.name = name + self.default_style = default_style + self.supported_styles = frozenset(supported_styles) + self.default_execution_style = default_execution_style self.supported_execution_styles = ( - frozenset(self.supported_execution_styles) if self.supported_execution_styles is not None else None + frozenset(supported_execution_styles) if supported_execution_styles is not None else None + ) + self.has_native_list_expansion = has_native_list_expansion + self.preserve_parameter_format = preserve_parameter_format + self.needs_static_script_compilation = needs_static_script_compilation + self.allow_mixed_parameter_styles = allow_mixed_parameter_styles + self.preserve_original_params_for_many = preserve_original_params_for_many + self.json_serializer_strategy = json_serializer_strategy + self.custom_type_coercions = ( + MappingProxyType(dict(custom_type_coercions)) if custom_type_coercions else MappingProxyType({}) ) - self.custom_type_coercions = MappingProxyType(dict(self.custom_type_coercions)) - self.extras = MappingProxyType(dict(self.extras)) - self.statement_kwargs = MappingProxyType(dict(self.statement_kwargs)) + self.default_output_transformer = default_output_transformer + self.default_ast_transformer = default_ast_transformer + self.extras = MappingProxyType(dict(extras)) if extras else MappingProxyType({}) + self.default_dialect = default_dialect + self.statement_kwargs = MappingProxyType(dict(statement_kwargs)) if statement_kwargs else MappingProxyType({}) @mypyc_attr(allow_interpreted_subclasses=False) diff --git a/sqlspec/utils/type_converters.py b/sqlspec/utils/type_converters.py new file mode 100644 index 00000000..8b751abf --- /dev/null +++ b/sqlspec/utils/type_converters.py @@ -0,0 +1,99 @@ +"""Reusable converter builders for parameter configuration.""" + +import decimal +from typing import TYPE_CHECKING, Any, Final + +if TYPE_CHECKING: + import datetime + 
from collections.abc import Callable, Sequence + +__all__ = ( + "DEFAULT_DECIMAL_MODE", + "build_decimal_converter", + "build_json_list_converter", + "build_json_tuple_converter", + "build_nested_decimal_normalizer", + "build_time_iso_converter", + "should_json_encode_sequence", +) + +JSON_NESTED_TYPES: Final[tuple[type[Any], ...]] = (dict, list, tuple) +DEFAULT_DECIMAL_MODE: Final[str] = "preserve" + + +def should_json_encode_sequence(sequence: "Sequence[Any]") -> bool: + """Return ``True`` when a sequence should be JSON serialized.""" + + return any(isinstance(item, JSON_NESTED_TYPES) for item in sequence if item is not None) + + +def build_json_list_converter( + serializer: "Callable[[Any], str]", *, preserve_arrays: bool = True +) -> "Callable[[list[Any]], Any]": + """Create a converter that serializes lists containing nested structures.""" + + def convert(value: "list[Any]") -> Any: + if not value: + return value + if preserve_arrays and not should_json_encode_sequence(value): + return value + return serializer(value) + + return convert + + +def build_json_tuple_converter( + serializer: "Callable[[Any], str]", *, preserve_arrays: bool = True +) -> "Callable[[tuple[Any, ...]], Any]": + """Create a converter that mirrors list handling for tuples.""" + + list_converter = build_json_list_converter(serializer, preserve_arrays=preserve_arrays) + + def convert(value: "tuple[Any, ...]") -> Any: + if not value: + return value + return list_converter(list(value)) + + return convert + + +def build_decimal_converter(*, mode: str = DEFAULT_DECIMAL_MODE) -> "Callable[[decimal.Decimal], Any]": + """Create a Decimal converter according to the desired mode.""" + + if mode == "preserve": + return lambda value: value + if mode == "string": + return lambda value: str(value) + if mode == "float": + return lambda value: float(value) + + msg = f"Unsupported decimal converter mode: {mode}" + raise ValueError(msg) + + +def build_nested_decimal_normalizer(*, mode: str = DEFAULT_DECIMAL_MODE) -> "Callable[[Any], Any]": + """Return a callable that coerces ``Decimal`` values within nested structures.""" + + decimal_converter = build_decimal_converter(mode=mode) + + def normalize(value: Any) -> Any: + if isinstance(value, decimal.Decimal): + return decimal_converter(value) + if isinstance(value, list): + return [normalize(item) for item in value] + if isinstance(value, tuple): + return tuple(normalize(item) for item in value) + if isinstance(value, dict): + return {key: normalize(item) for key, item in value.items()} + return value + + return normalize + + +def build_time_iso_converter() -> "Callable[[datetime.date | datetime.datetime | datetime.time], str]": + """Return a converter that formats temporal values using ISO 8601.""" + + def convert(value: "datetime.date | datetime.datetime | datetime.time") -> str: + return value.isoformat() + + return convert diff --git a/tests/unit/test_core/test_parameters.py b/tests/unit/test_core/test_parameters.py index f5bdffad..83701989 100644 --- a/tests/unit/test_core/test_parameters.py +++ b/tests/unit/test_core/test_parameters.py @@ -12,6 +12,7 @@ import math from datetime import date, datetime from decimal import Decimal +from importlib import import_module from typing import Any import pytest @@ -38,6 +39,22 @@ from sqlspec.exceptions import ImproperConfigurationError from sqlspec.utils.serializers import from_json, to_json +_ADAPTER_MODULE_NAMES: "tuple[str, ...]" = ( + "sqlspec.adapters.adbc", + "sqlspec.adapters.aiosqlite", + "sqlspec.adapters.asyncmy", + 
"sqlspec.adapters.asyncpg", + "sqlspec.adapters.bigquery", + "sqlspec.adapters.duckdb", + "sqlspec.adapters.oracledb", + "sqlspec.adapters.psqlpy", + "sqlspec.adapters.psycopg", + "sqlspec.adapters.sqlite", +) + +for _module_name in _ADAPTER_MODULE_NAMES: + import_module(_module_name) + pytestmark = pytest.mark.xdist_group("core")