Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/guides/adapters/parameter-profile-registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The table below summarizes the canonical `DriverParameterProfile` entries define

## Adding or Updating Profiles

1. Define the profile in `_registry.py` using lowercase key naming.
1. Define the profile in `_registry.py` using lowercase key naming. Ensure the adapter package (e.g., `sqlspec.adapters.duckdb.__init__`) imports the driver module so `register_driver_profile` executes during normal adapter imports; the registry does not perform lazy imports.
2. Pick the JSON strategy that matches driver capabilities (`helper`, `driver`, or `none`).
3. Declare extras as an immutable mapping; document each addition in this file and the relevant adapter guide.
4. Add or update regression coverage (see `specs/archive/driver-quality-review/research/testing_deliverables.md`).
Expand All @@ -28,6 +28,8 @@ Refer to [AGENTS.md](../../AGENTS.md) for the full checklist when touching the r
## Example Usage

```python
import sqlspec.adapters.duckdb # Triggers profile registration

from sqlspec.core.parameters import get_driver_profile, build_statement_config_from_profile

profile = get_driver_profile("duckdb")
Expand All @@ -36,4 +38,4 @@ config = build_statement_config_from_profile(profile)
print(config.parameter_config.default_parameter_style)
```

The snippet above retrieves the DuckDB profile, builds a `StatementConfig`, and prints the default parameter style (`?`). Use the same pattern for new adapters after defining their profiles.
The snippet above imports the DuckDB adapter package (which registers its profile), retrieves the profile, builds a `StatementConfig`, and prints the default parameter style (`?`). Use the same pattern for new adapters after defining their profiles.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ ignore = [
"PLR6301", # method could be static or class method
"B903", # class could be a dataclass or named tuple
"PLW0603", # Using the global statement to update is discouraged
"PLW0108", # Replace lambda expression with a def
]
select = ["ALL"]

Expand Down
5 changes: 2 additions & 3 deletions sqlspec/adapters/adbc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import contextlib
import datetime
import decimal
from functools import partial
from typing import TYPE_CHECKING, Any, cast

from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
Expand All @@ -16,10 +15,10 @@
from sqlspec.core.parameters import (
DriverParameterProfile,
ParameterStyle,
build_null_pruning_transform,
build_statement_config_from_profile,
get_driver_profile,
register_driver_profile,
replace_null_parameters_with_literals,
)
from sqlspec.core.result import create_arrow_result
from sqlspec.core.statement import SQL, StatementConfig
Expand Down Expand Up @@ -758,7 +757,7 @@ def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
}

if detected_dialect in {"postgres", "postgresql"}:
parameter_overrides["ast_transformer"] = partial(replace_null_parameters_with_literals, dialect=sqlglot_dialect)
parameter_overrides["ast_transformer"] = build_null_pruning_transform(dialect=sqlglot_dialect)

return build_statement_config_from_profile(
get_driver_profile("adbc"),
Expand Down
21 changes: 6 additions & 15 deletions sqlspec/adapters/aiosqlite/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
UniqueViolationError,
)
from sqlspec.utils.serializers import to_json
from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter

if TYPE_CHECKING:
from contextlib import AbstractAsyncContextManager
Expand All @@ -49,6 +50,8 @@
SQLITE_CANTOPEN_CODE = 14
SQLITE_IOERR_CODE = 10
SQLITE_MISMATCH_CODE = 20
_TIME_TO_ISO = build_time_iso_converter()
_DECIMAL_TO_STRING = build_decimal_converter(mode="string")


class AiosqliteCursor:
Expand Down Expand Up @@ -327,18 +330,6 @@ def _bool_to_int(value: bool) -> int:
return int(value)


def _datetime_to_iso(value: datetime) -> str:
return value.isoformat()


def _date_to_iso(value: date) -> str:
return value.isoformat()


def _decimal_to_str(value: Decimal) -> str:
return str(value)


def _build_aiosqlite_profile() -> DriverParameterProfile:
"""Create the AIOSQLite driver parameter profile."""

Expand All @@ -356,9 +347,9 @@ def _build_aiosqlite_profile() -> DriverParameterProfile:
json_serializer_strategy="helper",
custom_type_coercions={
bool: _bool_to_int,
datetime: _datetime_to_iso,
date: _date_to_iso,
Decimal: _decimal_to_str,
datetime: _TIME_TO_ISO,
date: _TIME_TO_ISO,
Decimal: _DECIMAL_TO_STRING,
},
default_dialect="sqlite",
)
Expand Down
51 changes: 26 additions & 25 deletions sqlspec/adapters/bigquery/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import datetime
import logging
from collections.abc import Callable
from decimal import Decimal
from typing import TYPE_CHECKING, Any, cast

Expand All @@ -18,9 +19,9 @@
from sqlspec.core import ParameterStyle, StatementConfig, create_arrow_result, get_cache_config
from sqlspec.core.parameters import (
DriverParameterProfile,
build_literal_inlining_transform,
build_statement_config_from_profile,
register_driver_profile,
replace_placeholders_with_literals,
)
from sqlspec.driver import ExecutionResult, SyncDriverAdapterBase
from sqlspec.exceptions import (
Expand Down Expand Up @@ -339,7 +340,13 @@ class BigQueryDriver(SyncDriverAdapterBase):
type coercion, error handling, and query job management.
"""

__slots__ = ("_data_dictionary", "_default_query_job_config", "_json_serializer", "_type_converter")
__slots__ = (
"_data_dictionary",
"_default_query_job_config",
"_json_serializer",
"_literal_inliner",
"_type_converter",
)
dialect = "bigquery"

def __init__(
Expand All @@ -362,6 +369,7 @@ def __init__(
parameter_json_serializer = features.get("json_serializer", to_json)

self._json_serializer: Callable[[Any], str] = parameter_json_serializer
self._literal_inliner = build_literal_inlining_transform(json_serializer=self._json_serializer)

super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
self._default_query_job_config: QueryJobConfig | None = (driver_features or {}).get("default_query_job_config")
Expand Down Expand Up @@ -485,30 +493,14 @@ def _try_special_handling(self, cursor: "Any", statement: "SQL") -> "SQLResult |
_ = (cursor, statement)
return None

def _transform_ast_with_literals(self, sql: str, parameters: Any) -> str:
"""Transform SQL AST by replacing placeholders with literal values.

Used for BigQuery script execution and execute_many operations where
parameter binding is not supported. Safely embeds values as SQL literals.

Args:
sql: SQL string to transform.
parameters: Parameters to embed as literals.
def _inline_literals(self, expression: "sqlglot.Expression", parameters: Any) -> str:
"""Inline literal values into a parsed SQLGlot expression."""

Returns:
Transformed SQL string with literals embedded.
"""
if not parameters:
return sql
return expression.sql(dialect="bigquery")

try:
ast = sqlglot.parse_one(sql, dialect="bigquery")
except sqlglot.ParseError:
return sql

transformed_ast = replace_placeholders_with_literals(ast, parameters, json_serializer=self._json_serializer)

return cast("str", transformed_ast.sql(dialect="bigquery"))
transformed_expression, _ = self._literal_inliner(expression, parameters)
return cast("str", transformed_expression.sql(dialect="bigquery"))

def _execute_script(self, cursor: Any, statement: "SQL") -> ExecutionResult:
"""Execute SQL script with statement splitting and parameter handling.
Expand Down Expand Up @@ -562,10 +554,19 @@ def _execute_many(self, cursor: Any, statement: "SQL") -> ExecutionResult:

base_sql = statement.sql

try:
parsed_expression = sqlglot.parse_one(base_sql, dialect="bigquery")
except sqlglot.ParseError:
parsed_expression = None

script_statements = []
for param_set in parameters_list:
transformed_sql = self._transform_ast_with_literals(base_sql, param_set)
script_statements.append(transformed_sql)
if parsed_expression is None:
script_statements.append(base_sql)
continue

expression_copy = parsed_expression.copy()
script_statements.append(self._inline_literals(expression_copy, param_set))

script_sql = ";\n".join(script_statements)

Expand Down
22 changes: 7 additions & 15 deletions sqlspec/adapters/duckdb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import to_json
from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter

if TYPE_CHECKING:
from contextlib import AbstractContextManager
Expand All @@ -48,6 +49,9 @@

logger = get_logger("adapters.duckdb")

_TIME_TO_ISO = build_time_iso_converter()
_DECIMAL_TO_STRING = build_decimal_converter(mode="string")

_type_converter = DuckDBTypeConverter()


Expand Down Expand Up @@ -471,18 +475,6 @@ def _bool_to_int(value: bool) -> int:
return int(value)


def _datetime_to_iso(value: datetime) -> str:
return value.isoformat()


def _date_to_iso(value: date) -> str:
return value.isoformat()


def _decimal_to_str(value: Decimal) -> str:
return str(value)


def _build_duckdb_profile() -> DriverParameterProfile:
"""Create the DuckDB driver parameter profile."""

Expand All @@ -500,9 +492,9 @@ def _build_duckdb_profile() -> DriverParameterProfile:
json_serializer_strategy="helper",
custom_type_coercions={
bool: _bool_to_int,
datetime: _datetime_to_iso,
date: _date_to_iso,
Decimal: _decimal_to_str,
datetime: _TIME_TO_ISO,
date: _TIME_TO_ISO,
Decimal: _DECIMAL_TO_STRING,
},
default_dialect="duckdb",
)
Expand Down
20 changes: 5 additions & 15 deletions sqlspec/adapters/psqlpy/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from sqlspec.typing import Empty
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import to_json
from sqlspec.utils.type_converters import build_nested_decimal_normalizer

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -73,6 +74,7 @@
"TIMESTAMP WITHOUT TIME ZONE",
})
_UUID_CASTS: Final[frozenset[str]] = frozenset({"UUID"})
_DECIMAL_NORMALIZER = build_nested_decimal_normalizer(mode="float")


class PsqlpyCursor:
Expand Down Expand Up @@ -520,18 +522,6 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase":
return self._data_dictionary


def _convert_decimals_in_structure(obj: Any) -> Any:
"""Recursively convert Decimal values to float in nested structures."""

if isinstance(obj, decimal.Decimal):
return float(obj)
if isinstance(obj, dict):
return {k: _convert_decimals_in_structure(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_convert_decimals_in_structure(item) for item in obj]
return obj


def _coerce_json_parameter(value: Any, cast_type: str, serializer: "Callable[[Any], str]") -> Any:
"""Serialize JSON parameters according to the detected cast type.

Expand Down Expand Up @@ -640,16 +630,16 @@ def _coerce_parameter_for_cast(value: Any, cast_type: str, serializer: "Callable


def _prepare_dict_parameter(value: "dict[str, Any]") -> dict[str, Any]:
normalized = _convert_decimals_in_structure(value)
normalized = _DECIMAL_NORMALIZER(value)
return normalized if isinstance(normalized, dict) else value


def _prepare_list_parameter(value: "list[Any]") -> list[Any]:
return [_convert_decimals_in_structure(item) for item in value]
return [_DECIMAL_NORMALIZER(item) for item in value]


def _prepare_tuple_parameter(value: "tuple[Any, ...]") -> tuple[Any, ...]:
return tuple(_convert_decimals_in_structure(item) for item in value)
return tuple(_DECIMAL_NORMALIZER(item) for item in value)


def _normalize_scalar_parameter(value: Any) -> Any:
Expand Down
56 changes: 3 additions & 53 deletions sqlspec/adapters/psycopg/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import to_json
from sqlspec.utils.type_converters import build_json_list_converter, build_json_tuple_converter

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -822,57 +823,6 @@ def _identity(value: Any) -> Any:
return value


def _convert_list_to_postgres_array(value: Any) -> str:
"""Convert a Python list to PostgreSQL array literal format."""

if not isinstance(value, list):
return str(value)

elements: list[str] = []
for item in value:
if isinstance(item, list):
elements.append(_convert_list_to_postgres_array(item))
elif isinstance(item, str):
escaped = item.replace("'", "''")
elements.append(f"'{escaped}'")
elif item is None:
elements.append("NULL")
else:
elements.append(str(item))

return "{" + ",".join(elements) + "}"


def _should_serialize_list(value: "list[Any]") -> bool:
"""Detect whether a list should be serialized to JSON."""

return any(isinstance(item, (dict, list, tuple)) for item in value)


def _build_list_parameter_converter(serializer: "Callable[[Any], str]") -> "Callable[[list[Any]], Any]":
"""Create converter that serializes nested lists while preserving arrays."""

def convert(value: "list[Any]") -> Any:
if not value:
return value
if _should_serialize_list(value):
return serializer(value)
return value

return convert


def _build_tuple_parameter_converter(serializer: "Callable[[Any], str]") -> "Callable[[tuple[Any, ...]], Any]":
"""Create converter mirroring list handling for tuple parameters."""

list_converter = _build_list_parameter_converter(serializer)

def convert(value: "tuple[Any, ...]") -> Any:
return list_converter(list(value))

return convert


def _build_psycopg_custom_type_coercions() -> dict[type, "Callable[[Any], Any]"]:
"""Return custom type coercions for psycopg."""

Expand Down Expand Up @@ -919,8 +869,8 @@ def _create_psycopg_parameter_config(serializer: "Callable[[Any], str]") -> Para
base_config = build_statement_config_from_profile(_PSYCOPG_PROFILE, json_serializer=serializer).parameter_config

updated_type_map = dict(base_config.type_coercion_map)
updated_type_map[list] = _build_list_parameter_converter(serializer)
updated_type_map[tuple] = _build_tuple_parameter_converter(serializer)
updated_type_map[list] = build_json_list_converter(serializer)
updated_type_map[tuple] = build_json_tuple_converter(serializer)

return base_config.replace(type_coercion_map=updated_type_map)

Expand Down
Loading