Merged
2 changes: 2 additions & 0 deletions .gitignore
@@ -66,3 +66,5 @@ requirements/*
!requirements/example-feature
!requirements/README.md
!.claude/bootstrap.md
.pre-commit-cache
.gh-cache
90 changes: 12 additions & 78 deletions sqlspec/adapters/adbc/driver.py
@@ -14,7 +14,13 @@
from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
from sqlspec.adapters.adbc.type_converter import ADBCTypeConverter
from sqlspec.core.cache import get_cache_config
from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
from sqlspec.core.parameters import (
ParameterProfile,
ParameterStyle,
ParameterStyleConfig,
ParameterValidator,
validate_parameter_alignment,
)
from sqlspec.core.result import create_arrow_result
from sqlspec.core.statement import SQL, StatementConfig
from sqlspec.driver import SyncDriverAdapterBase
@@ -69,87 +75,14 @@
"snowflake": (ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC]),
}


def _count_placeholders(expression: Any) -> int:
"""Count the number of unique parameter placeholders in a SQLGlot expression.

For PostgreSQL ($1, $2) style: counts highest numbered parameter (e.g., $1, $1, $2 = 2)
For QMARK (?) style: counts total occurrences (each ? is a separate parameter)
For named (:name) style: counts unique parameter names

Args:
expression: SQLGlot AST expression

Returns:
Number of unique parameter placeholders expected
"""
numeric_params = set() # For $1, $2 style
qmark_count = 0 # For ? style
named_params = set() # For :name style

def count_node(node: Any) -> Any:
nonlocal qmark_count
if isinstance(node, exp.Parameter):
# PostgreSQL style: $1, $2, etc.
param_str = str(node)
if param_str.startswith("$") and param_str[1:].isdigit():
numeric_params.add(int(param_str[1:]))
elif ":" in param_str:
# Named parameter: :name
named_params.add(param_str)
else:
# Other parameter formats
named_params.add(param_str)
elif isinstance(node, exp.Placeholder):
# QMARK style: ?
qmark_count += 1
return node

expression.transform(count_node)

# Return the appropriate count based on parameter style detected
if numeric_params:
# PostgreSQL style: return highest numbered parameter
return max(numeric_params)
if named_params:
# Named parameters: return count of unique names
return len(named_params)
# QMARK style: return total count
return qmark_count
_AST_PARAMETER_VALIDATOR: "ParameterValidator" = ParameterValidator()


def _is_execute_many_parameters(parameters: Any) -> bool:
"""Check if parameters are in execute_many format (list/tuple of lists/tuples)."""
return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], (list, tuple))


def _validate_parameter_counts(expression: Any, parameters: Any, dialect: str) -> None:
"""Validate parameter count against placeholder count in SQL."""
placeholder_count = _count_placeholders(expression)
is_execute_many = _is_execute_many_parameters(parameters)

if is_execute_many:
# For execute_many, validate each inner parameter set
for i, param_set in enumerate(parameters):
param_count = len(param_set) if isinstance(param_set, (list, tuple)) else 0
if param_count != placeholder_count:
msg = f"Parameter count mismatch in set {i}: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
raise SQLSpecError(msg)
else:
# For single execution, validate the parameter set directly
param_count = (
len(parameters)
if isinstance(parameters, (list, tuple))
else len(parameters)
if isinstance(parameters, dict)
else 0
)

if param_count != placeholder_count:
msg = f"Parameter count mismatch: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
raise SQLSpecError(msg)


def _find_null_positions(parameters: Any) -> set[int]:
"""Find positions of None values in parameters for single execution."""
null_positions = set()
@@ -187,14 +120,15 @@ def _adbc_ast_transformer(expression: Any, parameters: Any, dialect: str = "post
if not parameters:
return expression, parameters

# Validate parameter count before transformation
_validate_parameter_counts(expression, parameters, dialect)

# For execute_many operations, skip AST transformation as different parameter
# sets may have None values in different positions, making transformation complex
if _is_execute_many_parameters(parameters):
return expression, parameters

parameter_info = _AST_PARAMETER_VALIDATOR.extract_parameters(expression.sql(dialect=dialect))
parameter_profile = ParameterProfile(parameter_info)
validate_parameter_alignment(parameter_profile, parameters)

# Find positions of None values for single execution
null_positions = _find_null_positions(parameters)
if not null_positions:
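
The hand-rolled placeholder counting and count checks above are replaced by the shared parameter machinery from sqlspec.core.parameters. A minimal sketch of the new validation path, using the same three calls the transformer now makes; the exact signatures and accepted parameter shapes are assumed from this diff, and the sample statement is illustrative only:

```python
from typing import Any

from sqlspec.core.parameters import ParameterProfile, ParameterValidator, validate_parameter_alignment

_validator = ParameterValidator()


def check_alignment(rendered_sql: str, parameters: Any) -> None:
    # Extract placeholder metadata from the rendered SQL, wrap it in a profile,
    # and let the shared helper verify that the supplied parameters line up.
    parameter_info = _validator.extract_parameters(rendered_sql)
    profile = ParameterProfile(parameter_info)
    validate_parameter_alignment(profile, parameters)


check_alignment("SELECT * FROM users WHERE id = $1 AND name = $2", [42, "alice"])
```
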
11 changes: 10 additions & 1 deletion sqlspec/adapters/asyncmy/config.py
@@ -117,7 +117,8 @@ def __init__(
if "port" not in processed_pool_config:
processed_pool_config["port"] = 3306

if statement_config is None:
using_default_statement_config = statement_config is None
if using_default_statement_config:
statement_config = asyncmy_statement_config

processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
@@ -127,6 +128,14 @@
if "json_deserializer" not in processed_driver_features:
processed_driver_features["json_deserializer"] = from_json

if statement_config is None:
statement_config = asyncmy_statement_config

json_serializer = processed_driver_features.get("json_serializer")
if json_serializer is not None and using_default_statement_config:
parameter_config = statement_config.parameter_config.with_json_serializers(json_serializer)
statement_config = statement_config.replace(parameter_config=parameter_config)

super().__init__(
pool_config=processed_pool_config,
pool_instance=pool_instance,
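
With this change, a json_serializer supplied through driver_features is folded into the default statement configuration via with_json_serializers(), but only when the caller did not pass an explicit statement_config. A hedged usage sketch; the AsyncmyConfig class name and constructor keywords are assumptions, only the driver_features key and the serializer wiring come from the code above:

```python
import json

from sqlspec.adapters.asyncmy.config import AsyncmyConfig  # class name assumed

config = AsyncmyConfig(
    pool_config={"host": "localhost", "user": "app", "password": "secret", "database": "app"},
    driver_features={"json_serializer": lambda value: json.dumps(value, default=str)},
)
# No statement_config was supplied, so the serializer above is applied to the default
# asyncmy_statement_config through parameter_config.with_json_serializers(...).
```
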
72 changes: 0 additions & 72 deletions sqlspec/adapters/asyncmy/driver.py
@@ -30,7 +30,6 @@
from sqlspec.utils.serializers import to_json

if TYPE_CHECKING:
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager

from sqlspec.adapters.asyncmy._types import AsyncmyConnection
@@ -243,82 +242,11 @@ def __init__(
dialect="mysql",
)

final_statement_config = self._apply_json_serializer_feature(final_statement_config, driver_features)

super().__init__(
connection=connection, statement_config=final_statement_config, driver_features=driver_features
)
self._data_dictionary: AsyncDataDictionaryBase | None = None

@staticmethod
def _clone_parameter_config(
parameter_config: ParameterStyleConfig, type_coercion_map: "dict[type[Any], Callable[[Any], Any]]"
) -> ParameterStyleConfig:
"""Create a copy of the parameter configuration with updated coercion map.

Args:
parameter_config: Existing parameter configuration to copy.
type_coercion_map: Updated coercion mapping for parameter serialization.

Returns:
ParameterStyleConfig with the updated type coercion map applied.
"""

supported_execution_styles = (
set(parameter_config.supported_execution_parameter_styles)
if parameter_config.supported_execution_parameter_styles is not None
else None
)

return ParameterStyleConfig(
default_parameter_style=parameter_config.default_parameter_style,
supported_parameter_styles=set(parameter_config.supported_parameter_styles),
supported_execution_parameter_styles=supported_execution_styles,
default_execution_parameter_style=parameter_config.default_execution_parameter_style,
type_coercion_map=type_coercion_map,
has_native_list_expansion=parameter_config.has_native_list_expansion,
needs_static_script_compilation=parameter_config.needs_static_script_compilation,
allow_mixed_parameter_styles=parameter_config.allow_mixed_parameter_styles,
preserve_parameter_format=parameter_config.preserve_parameter_format,
preserve_original_params_for_many=parameter_config.preserve_original_params_for_many,
output_transformer=parameter_config.output_transformer,
ast_transformer=parameter_config.ast_transformer,
)

@staticmethod
def _apply_json_serializer_feature(
statement_config: "StatementConfig", driver_features: "dict[str, Any] | None"
) -> "StatementConfig":
"""Apply driver-level JSON serializer customization to the statement config.

Args:
statement_config: Base statement configuration for the driver.
driver_features: Driver feature mapping provided via configuration.

Returns:
StatementConfig with serializer adjustments applied when configured.
"""

if not driver_features:
return statement_config

serializer = driver_features.get("json_serializer")
if serializer is None:
return statement_config

parameter_config = statement_config.parameter_config
type_coercion_map = dict(parameter_config.type_coercion_map)

def serialize_tuple(value: Any) -> Any:
return serializer(list(value))

type_coercion_map[dict] = serializer
type_coercion_map[list] = serializer
type_coercion_map[tuple] = serialize_tuple

updated_parameter_config = AsyncmyDriver._clone_parameter_config(parameter_config, type_coercion_map)
return statement_config.replace(parameter_config=updated_parameter_config)

def with_cursor(self, connection: "AsyncmyConnection") -> "AsyncmyCursor":
"""Create cursor context manager for the connection.

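
The helpers removed here duplicated work that now happens once at the config level (see config.py above) via ParameterStyleConfig.with_json_serializers(). For reference, what the removed code effectively did to the type coercion map, written as a standalone sketch rather than the library's API:

```python
from collections.abc import Callable
from typing import Any


def apply_json_serializer(
    type_coercion_map: "dict[type, Callable[[Any], Any]]", serializer: "Callable[[Any], Any]"
) -> "dict[type, Callable[[Any], Any]]":
    # Route dict and list parameters through the serializer, and serialize tuples
    # as lists first, matching the mapping the removed helper built by hand.
    updated = dict(type_coercion_map)
    updated[dict] = serializer
    updated[list] = serializer
    updated[tuple] = lambda value: serializer(list(value))
    return updated
```
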
4 changes: 2 additions & 2 deletions sqlspec/adapters/bigquery/driver.py
@@ -279,8 +279,8 @@ def _create_bq_parameters(
raise SQLSpecError(msg)

elif isinstance(parameters, (list, tuple)):
logger.warning("BigQuery received positional parameters instead of named parameters")
return []
msg = "BigQuery driver requires named parameters (e.g., @name); positional parameters are not supported"
raise SQLSpecError(msg)

return bq_parameters

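
Positional parameters are now a hard error rather than a silently dropped warning. A standalone sketch of the new guard for illustration; the SQLSpecError import path is an assumption:

```python
from sqlspec.exceptions import SQLSpecError  # import path assumed


def ensure_named_parameters(parameters: object) -> None:
    # Mirrors the new check in _create_bq_parameters: list/tuple parameters are rejected.
    if isinstance(parameters, (list, tuple)):
        msg = "BigQuery driver requires named parameters (e.g., @name); positional parameters are not supported"
        raise SQLSpecError(msg)


ensure_named_parameters({"user_id": 42})  # named parameters for "... WHERE id = @user_id"
ensure_named_parameters([42])  # raises SQLSpecError instead of returning an empty list
```
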
61 changes: 12 additions & 49 deletions sqlspec/adapters/duckdb/driver.py
@@ -5,7 +5,6 @@
from typing import TYPE_CHECKING, Any, Final

import duckdb
from sqlglot import exp

from sqlspec.adapters.duckdb.data_dictionary import DuckDBSyncDataDictionary
from sqlspec.adapters.duckdb.type_converter import DuckDBTypeConverter
@@ -225,36 +224,20 @@ def __init__(
statement_config = updated_config

if driver_features:
param_config = statement_config.parameter_config
json_serializer = driver_features.get("json_serializer")
enable_uuid_conversion = driver_features.get("enable_uuid_conversion", True)
if json_serializer:
param_config = param_config.with_json_serializers(json_serializer, tuple_strategy="tuple")

if json_serializer or not enable_uuid_conversion:
enable_uuid_conversion = driver_features.get("enable_uuid_conversion", True)
if not enable_uuid_conversion:
type_converter = DuckDBTypeConverter(enable_uuid_conversion=enable_uuid_conversion)
type_coercion_map = dict(statement_config.parameter_config.type_coercion_map)

if json_serializer:
type_coercion_map[dict] = json_serializer
type_coercion_map[list] = json_serializer

if not enable_uuid_conversion:
type_coercion_map[str] = type_converter.convert_if_detected

param_config = statement_config.parameter_config
updated_param_config = ParameterStyleConfig(
default_parameter_style=param_config.default_parameter_style,
supported_parameter_styles=param_config.supported_parameter_styles,
supported_execution_parameter_styles=param_config.supported_execution_parameter_styles,
default_execution_parameter_style=param_config.default_execution_parameter_style,
type_coercion_map=type_coercion_map,
has_native_list_expansion=param_config.has_native_list_expansion,
needs_static_script_compilation=param_config.needs_static_script_compilation,
allow_mixed_parameter_styles=param_config.allow_mixed_parameter_styles,
preserve_parameter_format=param_config.preserve_parameter_format,
preserve_original_params_for_many=param_config.preserve_original_params_for_many,
output_transformer=param_config.output_transformer,
ast_transformer=param_config.ast_transformer,
)
statement_config = statement_config.replace(parameter_config=updated_param_config)
type_coercion_map = dict(param_config.type_coercion_map)
type_coercion_map[str] = type_converter.convert_if_detected
param_config = param_config.replace(type_coercion_map=type_coercion_map)

if param_config is not statement_config.parameter_config:
statement_config = statement_config.replace(parameter_config=param_config)

super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
self._data_dictionary: SyncDataDictionaryBase | None = None
@@ -294,26 +277,6 @@ def _try_special_handling(self, cursor: Any, statement: SQL) -> "SQLResult | Non
_ = (cursor, statement)
return None

def _is_modifying_operation(self, statement: SQL) -> bool:
"""Check if the SQL statement modifies data.

Determines if a statement is an INSERT, UPDATE, or DELETE operation
using AST analysis when available, falling back to text parsing.

Args:
statement: SQL statement to analyze

Returns:
True if the operation modifies data (INSERT/UPDATE/DELETE)
"""

expression = statement.expression
if expression and isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
return True

sql_upper = statement.sql.strip().upper()
return any(sql_upper.startswith(op) for op in MODIFYING_OPERATIONS)

def _execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult":
"""Execute SQL script with statement splitting and parameter handling.

@@ -359,7 +322,7 @@ def _execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult":
if prepared_parameters:
cursor.executemany(sql, prepared_parameters)

if self._is_modifying_operation(statement):
if statement.is_modifying_operation():
row_count = len(prepared_parameters)
else:
try:
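
The DuckDB driver now builds its parameter configuration from small immutable updates, with_json_serializers() for JSON handling and replace(type_coercion_map=...) for the UUID toggle, instead of reconstructing ParameterStyleConfig field by field. A sketch that mirrors the calls shown in this hunk; packaging them as a helper function is illustrative:

```python
from typing import Any

from sqlspec.adapters.duckdb.type_converter import DuckDBTypeConverter


def apply_duckdb_driver_features(statement_config: Any, driver_features: "dict[str, Any]") -> Any:
    param_config = statement_config.parameter_config

    json_serializer = driver_features.get("json_serializer")
    if json_serializer:
        param_config = param_config.with_json_serializers(json_serializer, tuple_strategy="tuple")

    if not driver_features.get("enable_uuid_conversion", True):
        # Route str parameters through a converter created with UUID conversion disabled.
        type_converter = DuckDBTypeConverter(enable_uuid_conversion=False)
        type_coercion_map = dict(param_config.type_coercion_map)
        type_coercion_map[str] = type_converter.convert_if_detected
        param_config = param_config.replace(type_coercion_map=type_coercion_map)

    if param_config is not statement_config.parameter_config:
        statement_config = statement_config.replace(parameter_config=param_config)
    return statement_config
```
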
15 changes: 15 additions & 0 deletions sqlspec/core/cache.py
@@ -19,6 +19,7 @@
from mypy_extensions import mypyc_attr
from typing_extensions import TypeVar

from sqlspec.core.pipeline import get_statement_pipeline_metrics, reset_statement_pipeline_cache
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
@@ -40,6 +41,8 @@
"get_cache",
"get_cache_config",
"get_default_cache",
"get_pipeline_metrics",
"reset_pipeline_registry",
)

T = TypeVar("T")
@@ -768,3 +771,15 @@ def to_canonical(self) -> "tuple[Any, ...]":
filter_objects.append(Filter(f.field_name, f.operation, f.value))

return canonicalize_filters(filter_objects)


def get_pipeline_metrics() -> "list[dict[str, Any]]":
"""Return metrics for the shared statement pipeline cache when enabled."""

return get_statement_pipeline_metrics()


def reset_pipeline_registry() -> None:
"""Clear shared statement pipeline caches and metrics."""

reset_statement_pipeline_cache()
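
The two new helpers re-export the statement pipeline's metrics and reset hooks through sqlspec.core.cache. A minimal usage sketch; the shape of the returned metric dictionaries is assumed:

```python
from sqlspec.core.cache import get_pipeline_metrics, reset_pipeline_registry

# Per the docstring, metrics are only populated when the shared pipeline cache is enabled.
for entry in get_pipeline_metrics():
    print(entry)

# Clear the shared pipeline caches and their metrics, e.g. between benchmark runs.
reset_pipeline_registry()
```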