diff --git a/AGENTS.md b/AGENTS.md index eda9023ad..e4ca5e33a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -185,6 +185,19 @@ SQLSpec is a type-safe SQL query mapper designed for minimal abstraction between - **Single-Pass Processing**: Parse once → transform once → validate once - SQL object is single source of truth - **Abstract Methods with Concrete Implementations**: Protocol defines abstract methods, base classes provide concrete sync/async implementations +### Driver Parameter Profile Registry + +- All adapter parameter defaults live in `DriverParameterProfile` entries inside `sqlspec/core/parameters.py`. +- Use lowercase adapter keys (e.g., `"asyncpg"`, `"duckdb"`) and populate every required field: default style, supported styles, execution style, native list expansion flags, JSON strategy, and optional extras. +- JSON behaviour is controlled through `json_serializer_strategy`: + - `"helper"`: call `ParameterStyleConfig.with_json_serializers()` (dict/list/tuple auto-encode) + - `"driver"`: defer to driver codecs while surfacing serializer references for later registration + - `"none"`: skip JSON helpers entirely (reserve for adapters that must not touch JSON) +- Extras should encapsulate adapter-specific tweaks (e.g., `type_coercion_overrides`, `json_tuple_strategy`). Document new extras inline and keep them immutable. +- Always build `StatementConfig` via `build_statement_config_from_profile()` and pass adapter-specific overrides through the helper instead of instantiating configs manually in drivers. +- When introducing a new adapter, add its profile, update relevant guides, and extend unit coverage so each JSON strategy path is exercised. +- Record the canonical adapter key, JSON strategy, and extras in the corresponding adapter guide so contributors can verify behaviour without reading the registry source. 
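The registry bullets above can be illustrated with a standalone sketch. Note these are simplified stand-in types, not the real API: the actual `DriverParameterProfile` in `sqlspec/core/parameters.py` carries many more fields (execution styles, list-expansion flags, custom type coercions, and so on), and the registry helpers shown here are hypothetical reductions of `register_driver_profile()` / `get_driver_profile()`.

```python
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Mapping

@dataclass(frozen=True)
class Profile:
    """Simplified stand-in for DriverParameterProfile (illustrative only)."""

    name: str
    default_style: str
    json_serializer_strategy: str  # "helper" | "driver" | "none"
    extras: Mapping[str, Any] = field(default_factory=dict)  # keep immutable

_REGISTRY: dict[str, Profile] = {}

def register_profile(key: str, profile: Profile) -> None:
    # Registry keys are lowercase adapter names, e.g. "asyncpg", "duckdb".
    if key != key.lower():
        raise ValueError(f"adapter key must be lowercase: {key!r}")
    _REGISTRY[key] = profile

def get_profile(key: str) -> Profile:
    return _REGISTRY[key]

register_profile("duckdb", Profile("DuckDB", "qmark", "helper", MappingProxyType({})))
print(get_profile("duckdb").json_serializer_strategy)  # helper
```

The frozen dataclass plus read-only `extras` mapping mirrors the "keep them immutable" rule above; drivers then build a `StatementConfig` from the registered profile rather than instantiating configs by hand.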
+ ### Protocol Abstract Methods Pattern When adding methods that need to support both sync and async configurations, use this pattern: @@ -335,6 +348,13 @@ def test_starlette_autocommit_mode() -> None: - Disabling pooling - Tests don't reflect production configuration - Running tests serially - Slows down CI significantly +### CLI Config Loader Isolation Pattern + +- When exercising CLI migration commands, generate a unique module namespace for each test (for example `cli_test_config_`). +- Place temporary config modules inside `tmp_path` and register them via `sys.modules` within the test, then delete them during teardown to prevent bleed-through. +- Always patch `Path.cwd()` or provide explicit path arguments so helper functions resolve the test-local module rather than cached global fixtures. +- Add regression tests ensuring the helper cleaning logic runs even if CLI commands raise exceptions to avoid polluting later suites. + ### Performance Optimizations - **Mypyc Compilation**: Core modules can be compiled with mypyc for performance @@ -2667,6 +2687,7 @@ def _extract_starlette_settings(self, config): 4. **Conditionally Skip DI Setup**: **Middleware-based (Starlette/FastAPI)**: + ```python def init_app(self, app): # ... lifespan setup ... @@ -2677,6 +2698,7 @@ def init_app(self, app): ``` **Provider-based (Litestar)**: + ```python def on_app_init(self, app_config): for state in self._plugin_configs: @@ -2693,6 +2715,7 @@ def on_app_init(self, app_config): ``` **Hook-based (Flask)**: + ```python def init_app(self, app): # ... pool setup ... diff --git a/docs/guides/adapters/adbc.md b/docs/guides/adapters/adbc.md index da692bb50..92ce47173 100644 --- a/docs/guides/adapters/adbc.md +++ b/docs/guides/adapters/adbc.md @@ -11,6 +11,12 @@ This guide provides specific instructions for the `adbc` adapter. - **Driver:** Arrow Database Connectivity (ADBC) drivers (e.g., `adbc_driver_postgresql`, `adbc_driver_sqlite`). 
- **Parameter Style:** Varies by underlying database (e.g., `numeric` for PostgreSQL, `qmark` for SQLite). +## Parameter Profile + +- **Registry Key:** `"adbc"` +- **JSON Strategy:** `helper` (shared serializers wrap dict/list/tuple values) +- **Extras:** `type_coercion_overrides` ensure Arrow arrays map to Python lists; PostgreSQL dialects attach a NULL-handling AST transformer + ## Best Practices - **Arrow-Native:** The primary benefit of ADBC is its direct integration with Apache Arrow. Use it when you need to move large amounts of data efficiently between the database and data science tools like Pandas or Polars. diff --git a/docs/guides/adapters/aiosqlite.md b/docs/guides/adapters/aiosqlite.md index b725a091c..6643e8227 100644 --- a/docs/guides/adapters/aiosqlite.md +++ b/docs/guides/adapters/aiosqlite.md @@ -11,6 +11,12 @@ This guide provides specific instructions for the `aiosqlite` adapter. - **Driver:** `aiosqlite` - **Parameter Style:** `qmark` (e.g., `?`) +## Parameter Profile + +- **Registry Key:** `"aiosqlite"` +- **JSON Strategy:** `helper` (shared serializer handles dict/list/tuple inputs) +- **Extras:** None (profile applies bool→int and ISO datetime coercions automatically) + ## Best Practices - **Async Only:** This is an asynchronous driver for SQLite. Use it in `asyncio` applications. diff --git a/docs/guides/adapters/asyncmy.md b/docs/guides/adapters/asyncmy.md index 2556e199a..1a98861f9 100644 --- a/docs/guides/adapters/asyncmy.md +++ b/docs/guides/adapters/asyncmy.md @@ -11,6 +11,12 @@ This guide covers `asyncmy`. - **Driver:** `asyncmy` - **Parameter Style:** `pyformat` (e.g., `%s`) +## Parameter Profile + +- **Registry Key:** `"asyncmy"` +- **JSON Strategy:** `helper` (uses shared JSON serializers for dict/list/tuple) +- **Extras:** None (native list expansion remains disabled) + ## Best Practices - **Character Set:** Always ensure the connection character set is `utf8mb4` to support a full range of Unicode characters, including emojis. 
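The `helper` JSON strategy referenced in the aiosqlite and asyncmy profiles above boils down to a type-coercion map that auto-encodes dict/list/tuple parameters before they reach the driver. The sketch below is illustrative only: the real wiring lives inside `ParameterStyleConfig.with_json_serializers()`, and the function names here are made up for the example.

```python
import json
from typing import Any, Callable

def json_coercions(serialize: Callable[[Any], str]) -> dict[type, Callable[[Any], Any]]:
    # dict/list serialize directly; tuples bind as JSON arrays via list().
    return {
        dict: serialize,
        list: serialize,
        tuple: lambda v: serialize(list(v)),
    }

def coerce(value: Any, coercions: dict[type, Callable[[Any], Any]]) -> Any:
    handler = coercions.get(type(value))
    return handler(value) if handler else value

coercions = json_coercions(json.dumps)
print(coerce({"a": 1}, coercions))  # {"a": 1}
print(coerce((1, 2), coercions))    # [1, 2]
print(coerce(42, coercions))        # 42 (non-JSON types pass through untouched)
```

This is why the guides describe the strategy as "shared serializer handles dict/list/tuple inputs": scalar parameters are untouched, while container parameters arrive at the driver as JSON strings.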
diff --git a/docs/guides/adapters/asyncpg.md b/docs/guides/adapters/asyncpg.md index 0a9867805..ea8bc5a5b 100644 --- a/docs/guides/adapters/asyncpg.md +++ b/docs/guides/adapters/asyncpg.md @@ -11,6 +11,12 @@ This guide provides specific instructions and best practices for working with th - **Driver:** `asyncpg` - **Parameter Style:** `numeric` (e.g., `$1, $2`) +## Parameter Profile + +- **Registry Key:** `"asyncpg"` +- **JSON Strategy:** `driver` (delegates JSON binding to asyncpg codecs) +- **Extras:** None (codecs registered through config init hook) + ## Best Practices - **High-Performance:** `asyncpg` is often chosen for high-performance applications due to its speed. It's a good choice for applications with a high volume of database traffic. diff --git a/docs/guides/adapters/bigquery.md b/docs/guides/adapters/bigquery.md index 59b883493..c9d53e2be 100644 --- a/docs/guides/adapters/bigquery.md +++ b/docs/guides/adapters/bigquery.md @@ -11,6 +11,12 @@ This guide provides specific instructions for the `bigquery` adapter. - **Driver:** `google-cloud-bigquery` - **Parameter Style:** `named` (e.g., `@name`) +## Parameter Profile + +- **Registry Key:** `"bigquery"` +- **JSON Strategy:** `helper` with `json_tuple_strategy="tuple"` +- **Extras:** `type_coercion_overrides` keep list values intact while converting tuples to lists during binding + ## Best Practices - **Authentication:** BigQuery requires authentication with Google Cloud. For local development, the easiest way is to use the Google Cloud CLI and run `gcloud auth application-default login`. diff --git a/docs/guides/adapters/duckdb.md b/docs/guides/adapters/duckdb.md index b125679fb..884364899 100644 --- a/docs/guides/adapters/duckdb.md +++ b/docs/guides/adapters/duckdb.md @@ -11,6 +11,12 @@ This guide provides specific instructions for the `duckdb` adapter. 
- **Driver:** `duckdb` - **Parameter Style:** `qmark` (e.g., `?`) +## Parameter Profile + +- **Registry Key:** `"duckdb"` +- **JSON Strategy:** `helper` (shared serializer covers dict/list/tuple) +- **Extras:** None (profile preserves existing `allow_mixed_parameter_styles=False`) + ## Best Practices - **In-Memory vs. File:** DuckDB can run entirely in-memory (`:memory:`) or with a file-based database. In-memory is great for fast, temporary analytics. File-based is for persistence. diff --git a/docs/guides/adapters/oracledb.md b/docs/guides/adapters/oracledb.md index 7d0c225ad..9f6050de2 100644 --- a/docs/guides/adapters/oracledb.md +++ b/docs/guides/adapters/oracledb.md @@ -11,6 +11,12 @@ This guide provides specific instructions and best practices for working with th - **Driver:** `oracledb` - **Parameter Style:** `named` (e.g., `:name`) +## Parameter Profile + +- **Registry Key:** `"oracledb"` +- **JSON Strategy:** `helper` (shared JSON serializer applied through the profile) +- **Extras:** None (uses defaults with native list expansion disabled) + ## Thick vs. Thin Client The `oracledb` driver supports two modes: diff --git a/docs/guides/adapters/parameter-profile-registry.md b/docs/guides/adapters/parameter-profile-registry.md new file mode 100644 index 000000000..4a6f2e68a --- /dev/null +++ b/docs/guides/adapters/parameter-profile-registry.md @@ -0,0 +1,39 @@ +# Driver Parameter Profile Registry + +The table below summarizes the canonical `DriverParameterProfile` entries defined in `sqlspec/core/parameters/_registry.py`. Use it as a quick reference when updating adapters or adding new ones. + +| Adapter | Registry Key | JSON Strategy | Extras | Default Dialect | Notes | +| --- | --- | --- | --- | --- | --- | +| ADBC | `"adbc"` | `helper` | `type_coercion_overrides` (list/tuple array coercion) | dynamic (per detected dialect) | Shares AST transformer metadata with BigQuery dialect helpers. 
| +| AioSQLite | `"aiosqlite"` | `helper` | None | `sqlite` | Mirrors SQLite defaults; bools coerced to ints for driver parity. | +| AsyncMy | `"asyncmy"` | `helper` | None | `mysql` | Native list expansion currently disabled until connector parity confirmed. | +| AsyncPG | `"asyncpg"` | `driver` | None | `postgres` | Relies on asyncpg codecs; JSON serializers referenced for later registration. | +| BigQuery | `"bigquery"` | `helper` | `json_tuple_strategy="tuple"`, `type_coercion_overrides` | `bigquery` | Enforces named parameters; tuple JSON payloads preserved as tuples. | +| DuckDB | `"duckdb"` | `helper` | None | `duckdb` | Mixed-style binding disabled; aligns bool/datetime coercion with SQLite. | +| OracleDB | `"oracledb"` | `helper` | None | `oracle` | List expansion disabled; LOB handling delegated to adapter converters. | +| PSQLPy | `"psqlpy"` | `helper` | None | `postgres` | Decimal values currently downcast to float for driver compatibility. | +| Psycopg | `"psycopg"` | `helper` | None | `postgres` | Array coercion delegated to psycopg adapters; JSON handled by shared converters. | +| SQLite | `"sqlite"` | `helper` | None | `sqlite` | Shares bool/datetime handling with DuckDB and CLI defaults. | + +## Adding or Updating Profiles + +1. Define the profile in `_registry.py` using lowercase key naming. +2. Pick the JSON strategy that matches driver capabilities (`helper`, `driver`, or `none`). +3. Declare extras as an immutable mapping; document each addition in this file and the relevant adapter guide. +4. Add or update regression coverage (see `specs/archive/driver-quality-review/research/testing_deliverables.md`). +5. If behaviour changes, update changelog entries and adapter guides accordingly. + +Refer to [AGENTS.md](../../AGENTS.md) for the full checklist when touching the registry. 
+ +## Example Usage + +```python +from sqlspec.core.parameters import get_driver_profile, build_statement_config_from_profile + +profile = get_driver_profile("duckdb") +config = build_statement_config_from_profile(profile) + +print(config.parameter_config.default_parameter_style) +``` + +The snippet above retrieves the DuckDB profile, builds a `StatementConfig`, and prints the default parameter style (`?`). Use the same pattern for new adapters after defining their profiles. diff --git a/docs/guides/adapters/psqlpy.md b/docs/guides/adapters/psqlpy.md index 39dedb661..b6965c016 100644 --- a/docs/guides/adapters/psqlpy.md +++ b/docs/guides/adapters/psqlpy.md @@ -12,6 +12,12 @@ This guide provides specific instructions for the `psqlpy` adapter for PostgreSQ - **Parameter Style:** `numeric` (e.g., `$1, $2`) - **Type System:** Rust-level type conversion (not Python-level) +## Parameter Profile + +- **Registry Key:** `"psqlpy"` +- **JSON Strategy:** `helper` (shared JSON serializer applied before Rust-side codecs) +- **Extras:** Decimal writes coerce through `_decimal_to_float` to match Rust numeric expectations + ## Architecture Psqlpy handles type conversion differently than other PostgreSQL drivers: diff --git a/docs/guides/adapters/psycopg.md b/docs/guides/adapters/psycopg.md index 82bb7c7d5..4c5ca67c0 100644 --- a/docs/guides/adapters/psycopg.md +++ b/docs/guides/adapters/psycopg.md @@ -11,6 +11,12 @@ This guide provides specific instructions and best practices for working with th - **Driver:** `psycopg` - **Parameter Style:** `pyformat` (e.g., `%s`) +## Parameter Profile + +- **Registry Key:** `"psycopg"` +- **JSON Strategy:** `helper` (shared JSON serializer wraps dict/list/tuple values before psycopg adapters run) +- **Extras:** None (adapter-specific list/tuple converters remain in driver to preserve array semantics) + ## Best Practices - **General Purpose:** `psycopg` is a robust, general-purpose PostgreSQL adapter. 
It has excellent type handling and is a good choice for a wide variety of applications. diff --git a/docs/guides/adapters/sqlite.md b/docs/guides/adapters/sqlite.md index e0baad2a9..6d67461c7 100644 --- a/docs/guides/adapters/sqlite.md +++ b/docs/guides/adapters/sqlite.md @@ -11,6 +11,12 @@ This guide covers `sqlite3` (sync) and `aiosqlite` (async). - **Driver:** `sqlite3` (built-in), `aiosqlite` - **Parameter Style:** `qmark` (e.g., `?`) +## Parameter Profile + +- **Registry Keys:** `"sqlite"` (sync), `"aiosqlite"` (async) +- **JSON Strategy:** `helper` for both drivers (shared serializer handles dict/list/tuple parameters) +- **Extras:** None (profiles apply ISO formatting for datetime/date and convert Decimal to string) + ## Best Practices - **Use Cases:** Ideal for testing, local development, and embedded applications. Not suitable for high-concurrency production workloads. diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 783409e0f..d019ee3a9 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -148,6 +148,21 @@ def __init__( if "arrow_extension_types" not in driver_features: driver_features["arrow_extension_types"] = True + json_serializer = driver_features.get("json_serializer") + if json_serializer is not None: + parameter_config = statement_config.parameter_config + previous_list_converter = parameter_config.type_coercion_map.get(list) + previous_tuple_converter = parameter_config.type_coercion_map.get(tuple) + updated_parameter_config = parameter_config.with_json_serializers(json_serializer) + updated_map = dict(updated_parameter_config.type_coercion_map) + if previous_list_converter is not None: + updated_map[list] = previous_list_converter + if previous_tuple_converter is not None: + updated_map[tuple] = previous_tuple_converter + statement_config = statement_config.replace( + parameter_config=updated_parameter_config.replace(type_coercion_map=updated_map) + ) + super().__init__( 
connection_config=self.connection_config, migration_config=migration_config, diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 985b12896..b3ff6503f 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -7,19 +7,19 @@ import contextlib import datetime import decimal +from functools import partial from typing import TYPE_CHECKING, Any, cast -from sqlglot import exp - from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary from sqlspec.adapters.adbc.type_converter import ADBCTypeConverter from sqlspec.core.cache import get_cache_config from sqlspec.core.parameters import ( - ParameterProfile, + DriverParameterProfile, ParameterStyle, - ParameterStyleConfig, - ParameterValidator, - validate_parameter_alignment, + build_statement_config_from_profile, + get_driver_profile, + register_driver_profile, + replace_null_parameters_with_literals, ) from sqlspec.core.result import create_arrow_result from sqlspec.core.statement import SQL, StatementConfig @@ -39,6 +39,7 @@ from sqlspec.typing import Empty from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_pyarrow +from sqlspec.utils.serializers import to_json if TYPE_CHECKING: from contextlib import AbstractContextManager @@ -75,198 +76,19 @@ "snowflake": (ParameterStyle.QMARK, [ParameterStyle.QMARK, ParameterStyle.NUMERIC]), } -_AST_PARAMETER_VALIDATOR: "ParameterValidator" = ParameterValidator() - - -def _is_execute_many_parameters(parameters: Any) -> bool: - """Check if parameters are in execute_many format (list/tuple of lists/tuples).""" - return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], (list, tuple)) - - -def _find_null_positions(parameters: Any) -> set[int]: - """Find positions of None values in parameters for single execution.""" - null_positions = set() - if isinstance(parameters, (list, tuple)): - for i, param in enumerate(parameters): - if param is 
None: - null_positions.add(i) - elif isinstance(parameters, dict): - for key, param in parameters.items(): - if param is None: - try: - if isinstance(key, str) and key.lstrip("$").isdigit(): - param_num = int(key.lstrip("$")) - null_positions.add(param_num - 1) - except ValueError: - pass - return null_positions - - -def _adbc_ast_transformer(expression: Any, parameters: Any, dialect: str = "postgres") -> tuple[Any, Any]: - """Transform AST to handle NULL parameters. - - Replaces NULL parameter placeholders with NULL literals in the AST - to prevent Arrow from inferring 'na' types which cause binding errors. - Validates parameter count before transformation. - - Args: - expression: SQLGlot AST expression parsed with proper dialect - parameters: Parameter values that may contain None - dialect: SQLGlot dialect used for parsing (default: "postgres") - - Returns: - Tuple of (modified_expression, cleaned_parameters) - """ - if not parameters: - return expression, parameters - - # For execute_many operations, skip AST transformation as different parameter - # sets may have None values in different positions, making transformation complex - if _is_execute_many_parameters(parameters): - return expression, parameters - - parameter_info = _AST_PARAMETER_VALIDATOR.extract_parameters(expression.sql(dialect=dialect)) - parameter_profile = ParameterProfile(parameter_info) - validate_parameter_alignment(parameter_profile, parameters) - - # Find positions of None values for single execution - null_positions = _find_null_positions(parameters) - if not null_positions: - return expression, parameters - - qmark_position = [0] - - def transform_node(node: Any) -> Any: - if isinstance(node, exp.Placeholder) and (not hasattr(node, "this") or node.this is None): - current_pos = qmark_position[0] - qmark_position[0] += 1 - - if current_pos in null_positions: - return exp.Null() - - return node - - if isinstance(node, exp.Placeholder) and hasattr(node, "this") and node.this is not None: - 
try: - param_str = str(node.this).lstrip("$") - param_num = int(param_str) - param_index = param_num - 1 - - if param_index in null_positions: - return exp.Null() - - nulls_before = sum(1 for idx in null_positions if idx < param_index) - new_param_num = param_num - nulls_before - return exp.Placeholder(this=f"${new_param_num}") - except (ValueError, AttributeError): - pass - - if isinstance(node, exp.Parameter) and hasattr(node, "this"): - try: - param_str = str(node.this) - param_num = int(param_str) - param_index = param_num - 1 - - if param_index in null_positions: - return exp.Null() - - nulls_before = sum(1 for idx in null_positions if idx < param_index) - new_param_num = param_num - nulls_before - return exp.Parameter(this=str(new_param_num)) - except (ValueError, AttributeError): - pass - - return node - - modified_expression = expression.transform(transform_node) - - cleaned_params: Any - if isinstance(parameters, (list, tuple)): - cleaned_params = [p for i, p in enumerate(parameters) if i not in null_positions] - elif isinstance(parameters, dict): - cleaned_params_dict = {} - new_num = 1 - for val in parameters.values(): - if val is not None: - cleaned_params_dict[str(new_num)] = val - new_num += 1 - cleaned_params = cleaned_params_dict - else: - cleaned_params = parameters - - return modified_expression, cleaned_params - - -def get_adbc_statement_config(detected_dialect: str) -> StatementConfig: - """Create statement configuration for the specified dialect.""" - default_style, supported_styles = DIALECT_PARAMETER_STYLES.get( - detected_dialect, (ParameterStyle.QMARK, [ParameterStyle.QMARK]) - ) - - type_map = get_type_coercion_map(detected_dialect) - - sqlglot_dialect = "postgres" if detected_dialect == "postgresql" else detected_dialect - - parameter_config = ParameterStyleConfig( - default_parameter_style=default_style, - supported_parameter_styles=set(supported_styles), - default_execution_parameter_style=default_style, - 
supported_execution_parameter_styles=set(supported_styles), - type_coercion_map=type_map, - has_native_list_expansion=True, - needs_static_script_compilation=False, - preserve_parameter_format=True, - ast_transformer=_adbc_ast_transformer if detected_dialect in {"postgres", "postgresql"} else None, - ) - return StatementConfig( - dialect=sqlglot_dialect, - parameter_config=parameter_config, - enable_parsing=True, - enable_validation=True, - enable_caching=True, - enable_parameter_type_wrapping=True, - ) +def _identity(value: Any) -> Any: + return value def _convert_array_for_postgres_adbc(value: Any) -> Any: - """Convert array values for PostgreSQL compatibility. - - Args: - value: Value to convert + """Convert array values for PostgreSQL compatibility.""" - Returns: - Converted value (tuples become lists) - """ if isinstance(value, tuple): return list(value) return value -def get_type_coercion_map(dialect: str) -> "dict[type, Any]": - """Get type coercion map for Arrow type handling with dialect-aware type conversion. 
- - Args: - dialect: Database dialect name - - Returns: - Mapping of Python types to conversion functions - """ - return { - datetime.datetime: lambda x: x, - datetime.date: lambda x: x, - datetime.time: lambda x: x, - decimal.Decimal: float, - bool: lambda x: x, - int: lambda x: x, - float: lambda x: x, - bytes: lambda x: x, - tuple: _convert_array_for_postgres_adbc, - list: _convert_array_for_postgres_adbc, - dict: lambda x: x, - } - - class AdbcCursor: """Context manager for cursor management.""" @@ -552,9 +374,9 @@ def _prepare_parameters_with_casts( Returns: Parameters with cast-aware type coercion applied """ - from sqlspec.utils.serializers import to_json - - json_encoder = self.driver_features.get("json_serializer", to_json) + json_encoder = statement_config.parameter_config.json_serializer or self.driver_features.get( + "json_serializer", to_json + ) if isinstance(parameters, (list, tuple)): result: list[Any] = [] @@ -816,10 +638,6 @@ def select_to_arrow( Returns: ArrowResult with native Arrow data - Raises: - MissingDependencyError: If pyarrow not installed - SQLExecutionError: If query execution fails - Example: >>> result = driver.select_to_arrow( ... 
"SELECT * FROM users WHERE age > $1", 18 @@ -862,3 +680,89 @@ def select_to_arrow( # Create ArrowResult return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows) + + +def get_type_coercion_map(dialect: str) -> "dict[type, Any]": + """Return dialect-aware type coercion mapping for Arrow parameter handling.""" + + return { + datetime.datetime: lambda x: x, + datetime.date: lambda x: x, + datetime.time: lambda x: x, + decimal.Decimal: float, + bool: lambda x: x, + int: lambda x: x, + float: lambda x: x, + bytes: lambda x: x, + tuple: _convert_array_for_postgres_adbc, + list: _convert_array_for_postgres_adbc, + dict: lambda x: x, + } + + +def _build_adbc_profile() -> DriverParameterProfile: + """Create the ADBC driver parameter profile.""" + + return DriverParameterProfile( + name="ADBC", + default_style=ParameterStyle.QMARK, + supported_styles={ParameterStyle.QMARK}, + default_execution_style=ParameterStyle.QMARK, + supported_execution_styles={ParameterStyle.QMARK}, + has_native_list_expansion=True, + preserve_parameter_format=True, + needs_static_script_compilation=False, + allow_mixed_parameter_styles=False, + preserve_original_params_for_many=False, + json_serializer_strategy="helper", + custom_type_coercions={ + datetime.datetime: _identity, + datetime.date: _identity, + datetime.time: _identity, + decimal.Decimal: float, + bool: _identity, + int: _identity, + float: _identity, + bytes: _identity, + tuple: _convert_array_for_postgres_adbc, + list: _convert_array_for_postgres_adbc, + dict: _identity, + }, + extras={ + "type_coercion_overrides": {list: _convert_array_for_postgres_adbc, tuple: _convert_array_for_postgres_adbc} + }, + ) + + +_ADBC_PROFILE = _build_adbc_profile() + +register_driver_profile("adbc", _ADBC_PROFILE) + + +def get_adbc_statement_config(detected_dialect: str) -> StatementConfig: + """Create statement configuration for the specified dialect.""" + default_style, supported_styles = 
DIALECT_PARAMETER_STYLES.get( + detected_dialect, (ParameterStyle.QMARK, [ParameterStyle.QMARK]) + ) + + type_map = get_type_coercion_map(detected_dialect) + + sqlglot_dialect = "postgres" if detected_dialect == "postgresql" else detected_dialect + + parameter_overrides: dict[str, Any] = { + "default_parameter_style": default_style, + "supported_parameter_styles": set(supported_styles), + "default_execution_parameter_style": default_style, + "supported_execution_parameter_styles": set(supported_styles), + "type_coercion_map": type_map, + } + + if detected_dialect in {"postgres", "postgresql"}: + parameter_overrides["ast_transformer"] = partial(replace_null_parameters_with_literals, dialect=sqlglot_dialect) + + return build_statement_config_from_profile( + get_driver_profile("adbc"), + parameter_overrides=parameter_overrides, + statement_overrides={"dialect": sqlglot_dialect}, + json_serializer=to_json, + ) diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 351203951..7ecb1fbe8 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -123,11 +123,21 @@ def __init__( if "json_deserializer" not in processed_driver_features: processed_driver_features["json_deserializer"] = from_json + base_statement_config = statement_config or aiosqlite_statement_config + + json_serializer = processed_driver_features.get("json_serializer") + json_deserializer = processed_driver_features.get("json_deserializer") + if json_serializer is not None: + parameter_config = base_statement_config.parameter_config.with_json_serializers( + json_serializer, deserializer=json_deserializer + ) + base_statement_config = base_statement_config.replace(parameter_config=parameter_config) + super().__init__( pool_config=config_dict, pool_instance=pool_instance, migration_config=migration_config, - statement_config=statement_config or aiosqlite_statement_config, + statement_config=base_statement_config, 
driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index f2ee6ca78..e38c5236a 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -2,15 +2,19 @@ import asyncio import contextlib -import datetime +from datetime import date, datetime from decimal import Decimal from typing import TYPE_CHECKING, Any import aiosqlite from sqlspec.core.cache import get_cache_config -from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig -from sqlspec.core.statement import StatementConfig +from sqlspec.core.parameters import ( + DriverParameterProfile, + ParameterStyle, + build_statement_config_from_profile, + register_driver_profile, +) from sqlspec.driver import AsyncDriverAdapterBase from sqlspec.exceptions import ( CheckViolationError, @@ -31,7 +35,7 @@ from sqlspec.adapters.aiosqlite._types import AiosqliteConnection from sqlspec.core.result import SQLResult - from sqlspec.core.statement import SQL + from sqlspec.core.statement import SQL, StatementConfig from sqlspec.driver import ExecutionResult from sqlspec.driver._async import AsyncDataDictionaryBase @@ -47,33 +51,6 @@ SQLITE_MISMATCH_CODE = 20 -aiosqlite_statement_config = StatementConfig( - dialect="sqlite", - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK}, - default_execution_parameter_style=ParameterStyle.QMARK, - supported_execution_parameter_styles={ParameterStyle.QMARK}, - type_coercion_map={ - bool: int, - datetime.datetime: lambda v: v.isoformat(), - datetime.date: lambda v: v.isoformat(), - Decimal: str, - dict: to_json, - list: to_json, - tuple: lambda v: to_json(list(v)), - }, - has_native_list_expansion=False, - needs_static_script_compilation=False, - preserve_parameter_format=True, - ), - enable_parsing=True, - 
-    enable_validation=True,
-    enable_caching=True,
-    enable_parameter_type_wrapping=True,
-)
-
-
 class AiosqliteCursor:
     """Async context manager for AIOSQLite cursors."""

@@ -344,3 +321,53 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase":
             self._data_dictionary = AiosqliteAsyncDataDictionary()
         return self._data_dictionary
+
+
+def _bool_to_int(value: bool) -> int:
+    return int(value)
+
+
+def _datetime_to_iso(value: datetime) -> str:
+    return value.isoformat()
+
+
+def _date_to_iso(value: date) -> str:
+    return value.isoformat()
+
+
+def _decimal_to_str(value: Decimal) -> str:
+    return str(value)
+
+
+def _build_aiosqlite_profile() -> DriverParameterProfile:
+    """Create the AIOSQLite driver parameter profile."""
+
+    return DriverParameterProfile(
+        name="AIOSQLite",
+        default_style=ParameterStyle.QMARK,
+        supported_styles={ParameterStyle.QMARK},
+        default_execution_style=ParameterStyle.QMARK,
+        supported_execution_styles={ParameterStyle.QMARK},
+        has_native_list_expansion=False,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=False,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=False,
+        json_serializer_strategy="helper",
+        custom_type_coercions={
+            bool: _bool_to_int,
+            datetime: _datetime_to_iso,
+            date: _date_to_iso,
+            Decimal: _decimal_to_str,
+        },
+        default_dialect="sqlite",
+    )
+
+
+_AIOSQLITE_PROFILE = _build_aiosqlite_profile()
+
+register_driver_profile("aiosqlite", _AIOSQLITE_PROFILE)
+
+aiosqlite_statement_config = build_statement_config_from_profile(
+    _AIOSQLITE_PROFILE, statement_overrides={"dialect": "sqlite"}, json_serializer=to_json
+)
diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py
index ec5254bab..4e9ca109f 100644
--- a/sqlspec/adapters/asyncmy/config.py
+++ b/sqlspec/adapters/asyncmy/config.py
@@ -11,7 +11,12 @@
 from typing_extensions import NotRequired

 from sqlspec.adapters.asyncmy._types import AsyncmyConnection
-from sqlspec.adapters.asyncmy.driver import AsyncmyCursor, AsyncmyDriver, asyncmy_statement_config
+from sqlspec.adapters.asyncmy.driver import (
+    AsyncmyCursor,
+    AsyncmyDriver,
+    asyncmy_statement_config,
+    build_asyncmy_statement_config,
+)
 from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig
 from sqlspec.utils.serializers import from_json, to_json
@@ -117,30 +122,19 @@ def __init__(
         if "port" not in processed_pool_config:
             processed_pool_config["port"] = 3306

-        using_default_statement_config = statement_config is None
-        if using_default_statement_config:
-            statement_config = asyncmy_statement_config
-
         processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
+        serializer = processed_driver_features.setdefault("json_serializer", to_json)
+        deserializer = processed_driver_features.setdefault("json_deserializer", from_json)

-        if "json_serializer" not in processed_driver_features:
-            processed_driver_features["json_serializer"] = to_json
-        if "json_deserializer" not in processed_driver_features:
-            processed_driver_features["json_deserializer"] = from_json
-
-        if statement_config is None:
-            statement_config = asyncmy_statement_config
-
-        json_serializer = processed_driver_features.get("json_serializer")
-        if json_serializer is not None and using_default_statement_config:
-            parameter_config = statement_config.parameter_config.with_json_serializers(json_serializer)
-            statement_config = statement_config.replace(parameter_config=parameter_config)
+        base_statement_config = statement_config or build_asyncmy_statement_config(
+            json_serializer=serializer, json_deserializer=deserializer
+        )

         super().__init__(
             pool_config=processed_pool_config,
             pool_instance=pool_instance,
             migration_config=migration_config,
-            statement_config=statement_config,
+            statement_config=base_statement_config,
             driver_features=processed_driver_features,
             bind_key=bind_key,
             extension_config=extension_config,
diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py
index 1053531c8..d06260a57 100644
--- a/sqlspec/adapters/asyncmy/driver.py
+++ b/sqlspec/adapters/asyncmy/driver.py
@@ -12,8 +12,12 @@
 from asyncmy.cursors import Cursor, DictCursor  # pyright: ignore

 from sqlspec.core.cache import get_cache_config
-from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
-from sqlspec.core.statement import StatementConfig
+from sqlspec.core.parameters import (
+    DriverParameterProfile,
+    ParameterStyle,
+    build_statement_config_from_profile,
+    register_driver_profile,
+)
 from sqlspec.driver import AsyncDriverAdapterBase
 from sqlspec.exceptions import (
     CheckViolationError,
@@ -27,17 +31,24 @@
     TransactionError,
     UniqueViolationError,
 )
-from sqlspec.utils.serializers import to_json
+from sqlspec.utils.serializers import from_json, to_json

 if TYPE_CHECKING:
+    from collections.abc import Callable
     from contextlib import AbstractAsyncContextManager

     from sqlspec.adapters.asyncmy._types import AsyncmyConnection
     from sqlspec.core.result import SQLResult
-    from sqlspec.core.statement import SQL
+    from sqlspec.core.statement import SQL, StatementConfig
     from sqlspec.driver import ExecutionResult
     from sqlspec.driver._async import AsyncDataDictionaryBase

-__all__ = ("AsyncmyCursor", "AsyncmyDriver", "AsyncmyExceptionHandler", "asyncmy_statement_config")
+__all__ = (
+    "AsyncmyCursor",
+    "AsyncmyDriver",
+    "AsyncmyExceptionHandler",
+    "asyncmy_statement_config",
+    "build_asyncmy_statement_config",
+)

 logger = logging.getLogger(__name__)
@@ -49,24 +60,6 @@
 MYSQL_ER_NO_DEFAULT_FOR_FIELD = 1364
 MYSQL_ER_CHECK_CONSTRAINT_VIOLATED = 3819

-asyncmy_statement_config = StatementConfig(
-    dialect="mysql",
-    parameter_config=ParameterStyleConfig(
-        default_parameter_style=ParameterStyle.QMARK,
-        supported_parameter_styles={ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT},
-        default_execution_parameter_style=ParameterStyle.POSITIONAL_PYFORMAT,
-        supported_execution_parameter_styles={ParameterStyle.POSITIONAL_PYFORMAT},
-        type_coercion_map={dict: to_json, list: to_json, tuple: lambda v: to_json(list(v)), bool: int},
-        has_native_list_expansion=False,
-        needs_static_script_compilation=True,
-        preserve_parameter_format=True,
-    ),
-    enable_parsing=True,
-    enable_validation=True,
-    enable_caching=True,
-    enable_parameter_type_wrapping=True,
-)
-

 class AsyncmyCursor:
     """Context manager for AsyncMy cursor operations.
@@ -491,3 +484,50 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase":
             self._data_dictionary = MySQLAsyncDataDictionary()
         return self._data_dictionary
+
+
+def _bool_to_int(value: bool) -> int:
+    return int(value)
+
+
+def _build_asyncmy_profile() -> DriverParameterProfile:
+    """Create the AsyncMy driver parameter profile."""
+
+    return DriverParameterProfile(
+        name="AsyncMy",
+        default_style=ParameterStyle.QMARK,
+        supported_styles={ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT},
+        default_execution_style=ParameterStyle.POSITIONAL_PYFORMAT,
+        supported_execution_styles={ParameterStyle.POSITIONAL_PYFORMAT},
+        has_native_list_expansion=False,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=True,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=False,
+        json_serializer_strategy="helper",
+        custom_type_coercions={bool: _bool_to_int},
+        default_dialect="mysql",
+    )
+
+
+_ASYNCMY_PROFILE = _build_asyncmy_profile()
+
+register_driver_profile("asyncmy", _ASYNCMY_PROFILE)
+
+
+def build_asyncmy_statement_config(
+    *, json_serializer: "Callable[[Any], str] | None" = None, json_deserializer: "Callable[[str], Any] | None" = None
+) -> "StatementConfig":
+    """Construct the AsyncMy statement configuration with optional JSON codecs."""
+
+    serializer = json_serializer or to_json
+    deserializer = json_deserializer or from_json
+    return build_statement_config_from_profile(
+        _ASYNCMY_PROFILE,
+        statement_overrides={"dialect": "mysql"},
+        json_serializer=serializer,
+        json_deserializer=deserializer,
+    )
+
+
+asyncmy_statement_config = build_asyncmy_statement_config()
diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py
index fa78a8966..5aad19e4f 100644
--- a/sqlspec/adapters/asyncpg/config.py
+++ b/sqlspec/adapters/asyncpg/config.py
@@ -12,7 +12,12 @@
 from typing_extensions import NotRequired

 from sqlspec.adapters.asyncpg._types import AsyncpgConnection
-from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, asyncpg_statement_config
+from sqlspec.adapters.asyncpg.driver import (
+    AsyncpgCursor,
+    AsyncpgDriver,
+    asyncpg_statement_config,
+    build_asyncpg_statement_config,
+)
 from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig
 from sqlspec.typing import PGVECTOR_INSTALLED
 from sqlspec.utils.serializers import from_json, to_json
@@ -122,20 +127,20 @@ def __init__(
         """
         features_dict: dict[str, Any] = dict(driver_features) if driver_features else {}

-        if "json_serializer" not in features_dict:
-            features_dict["json_serializer"] = to_json
-        if "json_deserializer" not in features_dict:
-            features_dict["json_deserializer"] = from_json
-        if "enable_json_codecs" not in features_dict:
-            features_dict["enable_json_codecs"] = True
-        if "enable_pgvector" not in features_dict:
-            features_dict["enable_pgvector"] = PGVECTOR_INSTALLED
+        serializer = features_dict.setdefault("json_serializer", to_json)
+        deserializer = features_dict.setdefault("json_deserializer", from_json)
+        features_dict.setdefault("enable_json_codecs", True)
+        features_dict.setdefault("enable_pgvector", PGVECTOR_INSTALLED)
+
+        base_statement_config = statement_config or build_asyncpg_statement_config(
+            json_serializer=serializer, json_deserializer=deserializer
+        )

         super().__init__(
             pool_config=dict(pool_config) if pool_config else {},
             pool_instance=pool_instance,
             migration_config=migration_config,
-            statement_config=statement_config or asyncpg_statement_config,
+            statement_config=base_statement_config,
             driver_features=features_dict,
             bind_key=bind_key,
             extension_config=extension_config,
diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py
index fa54b64c5..897068142 100644
--- a/sqlspec/adapters/asyncpg/driver.py
+++ b/sqlspec/adapters/asyncpg/driver.py
@@ -1,8 +1,4 @@
-"""AsyncPG PostgreSQL driver implementation for async PostgreSQL operations.
-
-Provides async PostgreSQL connectivity with parameter processing, resource management,
-PostgreSQL COPY operation support, and transaction management.
-"""
+"""AsyncPG PostgreSQL driver implementation for async PostgreSQL operations."""

 import datetime
 import re
@@ -10,9 +6,13 @@

 import asyncpg

-from sqlspec.core.cache import get_cache_config
-from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
-from sqlspec.core.statement import StatementConfig
+from sqlspec.core import (
+    DriverParameterProfile,
+    ParameterStyle,
+    build_statement_config_from_profile,
+    get_cache_config,
+    register_driver_profile,
+)
 from sqlspec.driver import AsyncDriverAdapterBase
 from sqlspec.exceptions import (
     CheckViolationError,
@@ -28,86 +28,28 @@
     UniqueViolationError,
 )
 from sqlspec.utils.logging import get_logger
+from sqlspec.utils.serializers import from_json, to_json

 if TYPE_CHECKING:
+    from collections.abc import Callable
     from contextlib import AbstractAsyncContextManager

     from sqlspec.adapters.asyncpg._types import AsyncpgConnection
-    from sqlspec.core.result import SQLResult
-    from sqlspec.core.statement import SQL
-    from sqlspec.driver import ExecutionResult
-    from sqlspec.driver._async import AsyncDataDictionaryBase
-
-__all__ = ("AsyncpgCursor", "AsyncpgDriver", "AsyncpgExceptionHandler", "asyncpg_statement_config")
+    from sqlspec.core import SQL, ParameterStyleConfig, SQLResult, StatementConfig
+    from sqlspec.driver import AsyncDataDictionaryBase, ExecutionResult
+
+__all__ = (
+    "AsyncpgCursor",
+    "AsyncpgDriver",
+    "AsyncpgExceptionHandler",
+    "_configure_asyncpg_parameter_serializers",
+    "asyncpg_statement_config",
+    "build_asyncpg_statement_config",
+)

 logger = get_logger("adapters.asyncpg")

-def _convert_datetime_param(value: Any) -> Any:
-    """Convert datetime parameter, handling ISO strings.
-
-    Args:
-        value: datetime object or ISO format string
-
-    Returns:
-        datetime object for asyncpg
-    """
-    if isinstance(value, str):
-        return datetime.datetime.fromisoformat(value)
-    return value
-
-
-def _convert_date_param(value: Any) -> Any:
-    """Convert date parameter, handling ISO strings.
-
-    Args:
-        value: date object or ISO format string
-
-    Returns:
-        date object for asyncpg
-    """
-    if isinstance(value, str):
-        return datetime.date.fromisoformat(value)
-    return value
-
-
-def _convert_time_param(value: Any) -> Any:
-    """Convert time parameter, handling ISO strings.
-
-    Args:
-        value: time object or ISO format string
-
-    Returns:
-        time object for asyncpg
-    """
-    if isinstance(value, str):
-        return datetime.time.fromisoformat(value)
-    return value
-
-
-asyncpg_statement_config = StatementConfig(
-    dialect="postgres",
-    parameter_config=ParameterStyleConfig(
-        default_parameter_style=ParameterStyle.NUMERIC,
-        supported_parameter_styles={ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_PYFORMAT},
-        default_execution_parameter_style=ParameterStyle.NUMERIC,
-        supported_execution_parameter_styles={ParameterStyle.NUMERIC},
-        type_coercion_map={
-            datetime.datetime: _convert_datetime_param,
-            datetime.date: _convert_date_param,
-            datetime.time: _convert_time_param,
-        },
-        has_native_list_expansion=True,
-        needs_static_script_compilation=False,
-        preserve_parameter_format=True,
-    ),
-    enable_parsing=True,
-    enable_validation=True,
-    enable_caching=True,
-    enable_parameter_type_wrapping=True,
-)
-
-
 ASYNC_PG_STATUS_REGEX: Final[re.Pattern[str]] = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE)
 EXPECTED_REGEX_GROUPS: Final[int] = 3
@@ -456,3 +398,99 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase":
             self._data_dictionary = PostgresAsyncDataDictionary()
         return self._data_dictionary
+
+
+def _convert_datetime_param(value: Any) -> Any:
+    """Convert datetime parameter, handling ISO strings."""
+
+    if isinstance(value, str):
+        return datetime.datetime.fromisoformat(value)
+    return value
+
+
+def _convert_date_param(value: Any) -> Any:
+    """Convert date parameter, handling ISO strings."""
+
+    if isinstance(value, str):
+        return datetime.date.fromisoformat(value)
+    return value
+
+
+def _convert_time_param(value: Any) -> Any:
+    """Convert time parameter, handling ISO strings."""
+
+    if isinstance(value, str):
+        return datetime.time.fromisoformat(value)
+    return value
+
+
+def _build_asyncpg_custom_type_coercions() -> dict[type, "Callable[[Any], Any]"]:
+    """Return custom type coercions for AsyncPG."""
+
+    return {
+        datetime.datetime: _convert_datetime_param,
+        datetime.date: _convert_date_param,
+        datetime.time: _convert_time_param,
+    }
+
+
+def _build_asyncpg_profile() -> DriverParameterProfile:
+    """Create the AsyncPG driver parameter profile."""
+
+    return DriverParameterProfile(
+        name="AsyncPG",
+        default_style=ParameterStyle.NUMERIC,
+        supported_styles={ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_PYFORMAT},
+        default_execution_style=ParameterStyle.NUMERIC,
+        supported_execution_styles={ParameterStyle.NUMERIC},
+        has_native_list_expansion=True,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=False,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=False,
+        json_serializer_strategy="driver",
+        custom_type_coercions=_build_asyncpg_custom_type_coercions(),
+        default_dialect="postgres",
+    )
+
+
+_ASYNC_PG_PROFILE = _build_asyncpg_profile()
+
+register_driver_profile("asyncpg", _ASYNC_PG_PROFILE)
+
+
+def _configure_asyncpg_parameter_serializers(
+    parameter_config: "ParameterStyleConfig",
+    serializer: "Callable[[Any], str]",
+    *,
+    deserializer: "Callable[[str], Any] | None" = None,
+) -> "ParameterStyleConfig":
+    """Return a parameter configuration updated with AsyncPG JSON codecs."""
+
+    effective_deserializer = deserializer or parameter_config.json_deserializer or from_json
+    return parameter_config.replace(json_serializer=serializer, json_deserializer=effective_deserializer)
+
+
+def build_asyncpg_statement_config(
+    *, json_serializer: "Callable[[Any], str] | None" = None, json_deserializer: "Callable[[str], Any] | None" = None
+) -> "StatementConfig":
+    """Construct the AsyncPG statement configuration with optional JSON codecs."""
+
+    effective_serializer = json_serializer or to_json
+    effective_deserializer = json_deserializer or from_json
+
+    base_config = build_statement_config_from_profile(
+        _ASYNC_PG_PROFILE,
+        statement_overrides={"dialect": "postgres"},
+        json_serializer=effective_serializer,
+        json_deserializer=effective_deserializer,
+    )
+
+    parameter_config = _configure_asyncpg_parameter_serializers(
+        base_config.parameter_config, effective_serializer, deserializer=effective_deserializer
+    )
+
+    return base_config.replace(parameter_config=parameter_config)
+
+
+asyncpg_statement_config = build_asyncpg_statement_config()
diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py
index f59d2437b..965bb886e 100644
--- a/sqlspec/adapters/bigquery/config.py
+++ b/sqlspec/adapters/bigquery/config.py
@@ -8,10 +8,11 @@
 from typing_extensions import NotRequired

 from sqlspec.adapters.bigquery._types import BigQueryConnection
-from sqlspec.adapters.bigquery.driver import BigQueryCursor, BigQueryDriver, bigquery_statement_config
+from sqlspec.adapters.bigquery.driver import BigQueryCursor, BigQueryDriver, build_bigquery_statement_config
 from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, NoPoolSyncConfig, StarletteConfig
 from sqlspec.exceptions import ImproperConfigurationError
 from sqlspec.typing import Empty
+from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
     from collections.abc import Callable, Generator
@@ -132,23 +133,19 @@ def __init__(
         if "enable_uuid_conversion" not in self.driver_features:
             self.driver_features["enable_uuid_conversion"] = True

-        if "json_serializer" not in self.driver_features:
-            from sqlspec.utils.serializers import to_json
-
-            self.driver_features["json_serializer"] = to_json
+        serializer = self.driver_features.setdefault("json_serializer", to_json)

         self._connection_instance: BigQueryConnection | None = self.driver_features.get("connection_instance")

         if "default_query_job_config" not in self.connection_config:
             self._setup_default_job_config()

-        if statement_config is None:
-            statement_config = bigquery_statement_config
+        base_statement_config = statement_config or build_bigquery_statement_config(json_serializer=serializer)

         super().__init__(
             connection_config=self.connection_config,
             migration_config=migration_config,
-            statement_config=statement_config,
+            statement_config=base_statement_config,
             driver_features=self.driver_features,
             bind_key=bind_key,
             extension_config=extension_config,
diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py
index 4c2110523..9d0f63147 100644
--- a/sqlspec/adapters/bigquery/driver.py
+++ b/sqlspec/adapters/bigquery/driver.py
@@ -7,16 +7,21 @@
 import datetime
 import logging
 from decimal import Decimal
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast

 import sqlglot
-import sqlglot.expressions as exp
 from google.cloud.bigquery import ArrayQueryParameter, QueryJob, QueryJobConfig, ScalarQueryParameter
 from google.cloud.exceptions import GoogleCloudError

 from sqlspec.adapters.bigquery._types import BigQueryConnection
 from sqlspec.adapters.bigquery.type_converter import BigQueryTypeConverter
-from sqlspec.core import ParameterStyle, ParameterStyleConfig, StatementConfig, create_arrow_result, get_cache_config
+from sqlspec.core import ParameterStyle, StatementConfig, create_arrow_result, get_cache_config
+from sqlspec.core.parameters import (
+    DriverParameterProfile,
+    build_statement_config_from_profile,
+    register_driver_profile,
+    replace_placeholders_with_literals,
+)
 from sqlspec.driver import ExecutionResult, SyncDriverAdapterBase
 from sqlspec.exceptions import (
     DatabaseConnectionError,
@@ -40,7 +45,13 @@

 logger = logging.getLogger(__name__)

-__all__ = ("BigQueryCursor", "BigQueryDriver", "BigQueryExceptionHandler", "bigquery_statement_config")
+__all__ = (
+    "BigQueryCursor",
+    "BigQueryDriver",
+    "BigQueryExceptionHandler",
+    "bigquery_statement_config",
+    "build_bigquery_statement_config",
+)

 HTTP_CONFLICT = 409
 HTTP_NOT_FOUND = 404
@@ -49,7 +60,12 @@
 HTTP_SERVER_ERROR = 500

-_default_type_converter = BigQueryTypeConverter()
+def _identity(value: Any) -> Any:
+    return value
+
+
+def _tuple_to_list(value: tuple[Any, ...]) -> list[Any]:
+    return list(value)

 _BQ_TYPE_MAP: dict[type, tuple[str, str | None]] = {
@@ -107,91 +123,6 @@ def _create_scalar_parameter(name: str, value: Any, param_type: str) -> ScalarQu
     return ScalarQueryParameter(name, param_type, value)

-def _create_literal_node(value: Any, json_serializer: "Callable[[Any], str]") -> "exp.Expression":
-    """Create a SQLGlot literal expression from a Python value.
-
-    Args:
-        value: Python value to convert to SQLGlot literal.
-        json_serializer: Function to serialize dict/list to JSON string.
-
-    Returns:
-        SQLGlot expression representing the literal value.
-    """
-    if value is None:
-        return exp.Null()
-    if isinstance(value, bool):
-        return exp.Boolean(this=value)
-    if isinstance(value, (int, float)):
-        return exp.Literal.number(str(value))
-    if isinstance(value, str):
-        return exp.Literal.string(value)
-    if isinstance(value, (list, tuple)):
-        items = [_create_literal_node(item, json_serializer) for item in value]
-        return exp.Array(expressions=items)
-    if isinstance(value, dict):
-        json_str = json_serializer(value)
-        return exp.Literal.string(json_str)
-
-    return exp.Literal.string(str(value))
-
-
-def _replace_placeholder_node(
-    node: "exp.Expression",
-    parameters: Any,
-    placeholder_counter: dict[str, int],
-    json_serializer: "Callable[[Any], str]",
-) -> "exp.Expression":
-    """Replace placeholder or parameter nodes with literal values.
-
-    Handles both positional placeholders (?) and named parameters (@name, :name).
-    Converts values to SQLGlot literal expressions for safe embedding in SQL.
-
-    Args:
-        node: SQLGlot expression node to check and potentially replace.
-        parameters: Parameter values (dict, list, or tuple).
-        placeholder_counter: Mutable counter dict for positional placeholders.
-        json_serializer: Function to serialize dict/list to JSON string.
-
-    Returns:
-        Literal expression if replacement made, otherwise original node.
-    """
-    if isinstance(node, exp.Placeholder):
-        if isinstance(parameters, (list, tuple)):
-            current_index = placeholder_counter["index"]
-            placeholder_counter["index"] += 1
-            if current_index < len(parameters):
-                return _create_literal_node(parameters[current_index], json_serializer)
-        return node
-
-    if isinstance(node, exp.Parameter):
-        param_name = str(node.this) if hasattr(node.this, "__str__") else node.this
-
-        if isinstance(parameters, dict):
-            possible_names = [param_name, f"@{param_name}", f":{param_name}", f"param_{param_name}"]
-            for name in possible_names:
-                if name in parameters:
-                    actual_value = getattr(parameters[name], "value", parameters[name])
-                    return _create_literal_node(actual_value, json_serializer)
-            return node
-
-        if isinstance(parameters, (list, tuple)):
-            try:
-                if param_name.startswith("param_"):
-                    param_index = int(param_name[6:])
-                    if param_index < len(parameters):
-                        return _create_literal_node(parameters[param_index], json_serializer)
-
-                if param_name.isdigit():
-                    param_index = int(param_name)
-                    if param_index < len(parameters):
-                        return _create_literal_node(parameters[param_index], json_serializer)
-            except (ValueError, IndexError, AttributeError):
-                pass
-        return node
-
-    return node
-
-
 def _get_bq_param_type(value: Any) -> tuple[str | None, str | None]:
     """Determine BigQuery parameter type from Python value.
@@ -285,53 +216,6 @@ def _create_bq_parameters(
     return bq_parameters

-def _get_bigquery_type_coercion_map(type_converter: BigQueryTypeConverter) -> dict[type, Any]:
-    """Get BigQuery type coercion map with configurable type converter.
-
-    Args:
-        type_converter: BigQuery type converter instance
-
-    Returns:
-        Type coercion map for BigQuery
-    """
-    return {
-        tuple: list,
-        bool: lambda x: x,
-        int: lambda x: x,
-        float: lambda x: x,
-        bytes: lambda x: x,
-        datetime.datetime: lambda x: x,
-        datetime.date: lambda x: x,
-        datetime.time: lambda x: x,
-        Decimal: lambda x: x,
-        dict: lambda x: x,
-        list: lambda x: x,
-        type(None): lambda _: None,
-    }
-
-
-bigquery_type_coercion_map = _get_bigquery_type_coercion_map(_default_type_converter)
-
-
-bigquery_statement_config = StatementConfig(
-    dialect="bigquery",
-    parameter_config=ParameterStyleConfig(
-        default_parameter_style=ParameterStyle.NAMED_AT,
-        supported_parameter_styles={ParameterStyle.NAMED_AT, ParameterStyle.QMARK},
-        default_execution_parameter_style=ParameterStyle.NAMED_AT,
-        supported_execution_parameter_styles={ParameterStyle.NAMED_AT},
-        type_coercion_map=bigquery_type_coercion_map,
-        has_native_list_expansion=True,
-        needs_static_script_compilation=False,
-        preserve_original_params_for_many=True,
-    ),
-    enable_parsing=True,
-    enable_validation=True,
-    enable_caching=True,
-    enable_parameter_type_wrapping=True,
-)
-
-
 class BigQueryCursor:
     """BigQuery cursor with resource management."""
@@ -466,12 +350,6 @@ def __init__(
     ) -> None:
         features = driver_features or {}

-        json_serializer = features.get("json_serializer")
-        if json_serializer is None:
-            json_serializer = to_json
-
-        self._json_serializer: Callable[[Any], str] = json_serializer
-
         enable_uuid_conversion = features.get("enable_uuid_conversion", True)
         self._type_converter = BigQueryTypeConverter(enable_uuid_conversion=enable_uuid_conversion)
@@ -479,6 +357,12 @@
             cache_config = get_cache_config()
             statement_config = bigquery_statement_config.replace(cache_config=cache_config)

+        parameter_json_serializer = statement_config.parameter_config.json_serializer
+        if parameter_json_serializer is None:
+            parameter_json_serializer = features.get("json_serializer", to_json)
+
+        self._json_serializer: Callable[[Any], str] = parameter_json_serializer
+
         super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
         self._default_query_job_config: QueryJobConfig | None = (driver_features or {}).get("default_query_job_config")
         self._data_dictionary: SyncDataDictionaryBase | None = None
@@ -622,13 +506,9 @@ def _transform_ast_with_literals(self, sql: str, parameters: Any) -> str:
         except sqlglot.ParseError:
             return sql

-        placeholder_counter = {"index": 0}
+        transformed_ast = replace_placeholders_with_literals(ast, parameters, json_serializer=self._json_serializer)

-        transformed_ast = ast.transform(
-            lambda node: _replace_placeholder_node(node, parameters, placeholder_counter, self._json_serializer)
-        )
-
-        return transformed_ast.sql(dialect="bigquery")
+        return cast("str", transformed_ast.sql(dialect="bigquery"))

     def _execute_script(self, cursor: Any, statement: "SQL") -> ExecutionResult:
         """Execute SQL script with statement splitting and parameter handling.
@@ -870,3 +750,52 @@ def select_to_arrow(

         # Create ArrowResult
         return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows)
+
+
+def _build_bigquery_profile() -> DriverParameterProfile:
+    """Create the BigQuery driver parameter profile."""
+
+    return DriverParameterProfile(
+        name="BigQuery",
+        default_style=ParameterStyle.NAMED_AT,
+        supported_styles={ParameterStyle.NAMED_AT, ParameterStyle.QMARK},
+        default_execution_style=ParameterStyle.NAMED_AT,
+        supported_execution_styles={ParameterStyle.NAMED_AT},
+        has_native_list_expansion=True,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=False,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=True,
+        json_serializer_strategy="helper",
+        custom_type_coercions={
+            int: _identity,
+            float: _identity,
+            bytes: _identity,
+            datetime.datetime: _identity,
+            datetime.date: _identity,
+            datetime.time: _identity,
+            Decimal: _identity,
+            dict: _identity,
+            list: _identity,
+            type(None): lambda _: None,
+        },
+        extras={"json_tuple_strategy": "tuple", "type_coercion_overrides": {list: _identity, tuple: _tuple_to_list}},
+        default_dialect="bigquery",
+    )
+
+
+_BIGQUERY_PROFILE = _build_bigquery_profile()
+
+register_driver_profile("bigquery", _BIGQUERY_PROFILE)
+
+
+def build_bigquery_statement_config(*, json_serializer: "Callable[[Any], str] | None" = None) -> StatementConfig:
+    """Construct the BigQuery statement configuration with optional JSON serializer."""
+
+    serializer = json_serializer or to_json
+    return build_statement_config_from_profile(
+        _BIGQUERY_PROFILE, statement_overrides={"dialect": "bigquery"}, json_serializer=serializer
+    )
+
+
+bigquery_statement_config = build_bigquery_statement_config()
diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py
index 8831ef9c1..600a844f4 100644
--- a/sqlspec/adapters/duckdb/config.py
+++ b/sqlspec/adapters/duckdb/config.py
@@ -2,14 +2,15 @@

 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
+from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

 from typing_extensions import NotRequired

 from sqlspec.adapters.duckdb._types import DuckDBConnection
-from sqlspec.adapters.duckdb.driver import DuckDBCursor, DuckDBDriver, duckdb_statement_config
+from sqlspec.adapters.duckdb.driver import DuckDBCursor, DuckDBDriver, build_duckdb_statement_config
 from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool
 from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig, SyncDatabaseConfig
+from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
     from collections.abc import Callable, Generator
@@ -210,15 +211,19 @@ def __init__(
             pool_config["database"] = ":memory:shared_db"

         processed_features = dict(driver_features) if driver_features else {}
-        if "enable_uuid_conversion" not in processed_features:
-            processed_features["enable_uuid_conversion"] = True
+        processed_features.setdefault("enable_uuid_conversion", True)
+        serializer = processed_features.setdefault("json_serializer", to_json)
+
+        base_statement_config = statement_config or build_duckdb_statement_config(
+            json_serializer=cast("Callable[[Any], str]", serializer)
+        )

         super().__init__(
             bind_key=bind_key,
             pool_config=dict(pool_config),
             pool_instance=pool_instance,
             migration_config=migration_config,
-            statement_config=statement_config or duckdb_statement_config,
+            statement_config=base_statement_config,
             driver_features=processed_features,
             extension_config=extension_config,
         )
diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py
index 6ccd2b358..e2936dae9 100644
--- a/sqlspec/adapters/duckdb/driver.py
+++ b/sqlspec/adapters/duckdb/driver.py
@@ -1,6 +1,7 @@
 """DuckDB driver implementation."""

-import datetime
+import typing
+from datetime import date, datetime
 from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Final
@@ -8,7 +9,8 @@

 from sqlspec.adapters.duckdb.data_dictionary import DuckDBSyncDataDictionary
 from sqlspec.adapters.duckdb.type_converter import DuckDBTypeConverter
-from sqlspec.core import SQL, ParameterStyle, ParameterStyleConfig, StatementConfig, get_cache_config
+from sqlspec.core import SQL, ParameterStyle, StatementConfig, get_cache_config
+from sqlspec.core.parameters import DriverParameterProfile, build_statement_config_from_profile, register_driver_profile
 from sqlspec.driver import SyncDriverAdapterBase
 from sqlspec.exceptions import (
     CheckViolationError,
@@ -36,43 +38,19 @@
     from sqlspec.driver._sync import SyncDataDictionaryBase
     from sqlspec.typing import ArrowReturnFormat, StatementParameters

-__all__ = ("DuckDBCursor", "DuckDBDriver", "DuckDBExceptionHandler", "duckdb_statement_config")
+__all__ = (
+    "DuckDBCursor",
+    "DuckDBDriver",
+    "DuckDBExceptionHandler",
+    "build_duckdb_statement_config",
+    "duckdb_statement_config",
+)

 logger = get_logger("adapters.duckdb")

 _type_converter = DuckDBTypeConverter()

-duckdb_statement_config = StatementConfig(
-    dialect="duckdb",
-    parameter_config=ParameterStyleConfig(
-        default_parameter_style=ParameterStyle.QMARK,
-        supported_parameter_styles={ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR},
-        default_execution_parameter_style=ParameterStyle.QMARK,
-        supported_execution_parameter_styles={ParameterStyle.QMARK, ParameterStyle.NUMERIC},
-        type_coercion_map={
-            bool: int,
-            datetime.datetime: lambda v: v.isoformat(),
-            datetime.date: lambda v: v.isoformat(),
-            Decimal: str,
-            dict: to_json,
-            list: to_json,
-        },
-        has_native_list_expansion=True,
-        needs_static_script_compilation=False,
-        preserve_parameter_format=True,
-        allow_mixed_parameter_styles=False,
-    ),
-    enable_parsing=True,
-    enable_validation=True,
-    enable_caching=True,
-    enable_parameter_type_wrapping=True,
-)
-
-
-MODIFYING_OPERATIONS: Final[tuple[str, ...]] = ("INSERT", "UPDATE", "DELETE")
-
-
 class DuckDBCursor:
     """Context manager for DuckDB cursor management."""
@@ -440,11 +418,6 @@ def select_to_arrow(

         Returns:
             ArrowResult with native Arrow data
-
-        Raises:
-            MissingDependencyError: If pyarrow not installed
-            SQLExecutionError: If query execution fails
-
         Example:
             >>> result = driver.select_to_arrow(
             ...     "SELECT * FROM users WHERE age > ?", 18
@@ -492,3 +465,64 @@

         # Create ArrowResult
         return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows)
+
+
+def _bool_to_int(value: bool) -> int:
+    return int(value)
+
+
+def _datetime_to_iso(value: datetime) -> str:
+    return value.isoformat()
+
+
+def _date_to_iso(value: date) -> str:
+    return value.isoformat()
+
+
+def _decimal_to_str(value: Decimal) -> str:
+    return str(value)
+
+
+def _build_duckdb_profile() -> DriverParameterProfile:
+    """Create the DuckDB driver parameter profile."""
+
+    return DriverParameterProfile(
+        name="DuckDB",
+        default_style=ParameterStyle.QMARK,
+        supported_styles={ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR},
+        default_execution_style=ParameterStyle.QMARK,
+        supported_execution_styles={ParameterStyle.QMARK},
+        has_native_list_expansion=True,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=False,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=False,
+        json_serializer_strategy="helper",
+        custom_type_coercions={
+            bool: _bool_to_int,
+            datetime: _datetime_to_iso,
+            date: _date_to_iso,
+            Decimal: _decimal_to_str,
+        },
+        default_dialect="duckdb",
+    )
+
+
+_DUCKDB_PROFILE = _build_duckdb_profile()
+
+register_driver_profile("duckdb", _DUCKDB_PROFILE)
+
+
+def build_duckdb_statement_config(*, json_serializer: "typing.Callable[[Any], str] | None" = None) -> StatementConfig:
+    """Construct the DuckDB statement configuration with optional JSON serializer."""
+
+    serializer = json_serializer or to_json
+    return build_statement_config_from_profile(
+        _DUCKDB_PROFILE, statement_overrides={"dialect": "duckdb"}, json_serializer=serializer
+    )
+
+
+duckdb_statement_config = build_duckdb_statement_config()
+
+
+MODIFYING_OPERATIONS: Final[tuple[str, ...]] = ("INSERT", "UPDATE", "DELETE")
diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py
index 49426a192..6e7c2010c 100644
--- a/sqlspec/adapters/oracledb/driver.py
+++ b/sqlspec/adapters/oracledb/driver.py
@@ -11,14 +11,8 @@
 from sqlspec.adapters.oracledb._types import OracleAsyncConnection, OracleSyncConnection
 from sqlspec.adapters.oracledb.data_dictionary import OracleAsyncDataDictionary, OracleSyncDataDictionary
 from sqlspec.adapters.oracledb.type_converter import OracleTypeConverter
-from sqlspec.core import (
-    SQL,
-    ParameterStyle,
-    ParameterStyleConfig,
-    StatementConfig,
-    create_arrow_result,
-    get_cache_config,
-)
+from sqlspec.core import SQL, ParameterStyle, StatementConfig, create_arrow_result, get_cache_config
+from sqlspec.core.parameters import DriverParameterProfile, build_statement_config_from_profile, register_driver_profile
 from sqlspec.driver import (
     AsyncDataDictionaryBase,
     AsyncDriverAdapterBase,
@@ -145,25 +139,6 @@ async def _coerce_async_row_values(row: "tuple[Any, ...]") -> "list[Any]":

 ORA_TABLESPACE_FULL = 1652

-oracledb_statement_config = StatementConfig(
-    dialect="oracle",
-    parameter_config=ParameterStyleConfig(
-        default_parameter_style=ParameterStyle.POSITIONAL_COLON,
-        supported_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.POSITIONAL_COLON, ParameterStyle.QMARK},
-        default_execution_parameter_style=ParameterStyle.NAMED_COLON,
-        supported_execution_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.POSITIONAL_COLON},
-        type_coercion_map={dict: to_json, list: to_json},
-        has_native_list_expansion=False,
-        needs_static_script_compilation=False,
-        preserve_parameter_format=True,
-    ),
-    enable_parsing=True,
-    enable_validation=True,
-    enable_caching=True,
-    enable_parameter_type_wrapping=True,
-)
-
-
 class OracleSyncCursor:
     """Sync context manager for Oracle cursor management."""
@@ -904,3 +879,31 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase":
         if self._data_dictionary is None:
             self._data_dictionary = OracleAsyncDataDictionary()
         return self._data_dictionary
+
+
+def _build_oracledb_profile() -> DriverParameterProfile:
+    """Create the OracleDB driver parameter profile."""
+
+    return DriverParameterProfile(
+        name="OracleDB",
+        default_style=ParameterStyle.POSITIONAL_COLON,
+        supported_styles={ParameterStyle.NAMED_COLON, ParameterStyle.POSITIONAL_COLON, ParameterStyle.QMARK},
+        default_execution_style=ParameterStyle.NAMED_COLON,
+        supported_execution_styles={ParameterStyle.NAMED_COLON, ParameterStyle.POSITIONAL_COLON},
+        has_native_list_expansion=False,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=False,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=False,
+        json_serializer_strategy="helper",
+        default_dialect="oracle",
+    )
+
+
+_ORACLE_PROFILE = _build_oracledb_profile()
+
+register_driver_profile("oracledb", _ORACLE_PROFILE)
+
+oracledb_statement_config = build_statement_config_from_profile(
+    _ORACLE_PROFILE, statement_overrides={"dialect": "oracle"}, json_serializer=to_json
+)
diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py
index 8dee4bfd9..e4858a280 100644
--- a/sqlspec/adapters/psqlpy/config.py
+++ b/sqlspec/adapters/psqlpy/config.py
@@ -3,16 +3,17 @@
 import logging
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
-from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
+from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

 from psqlpy import ConnectionPool
 from typing_extensions import NotRequired

 from sqlspec.adapters.psqlpy._types import PsqlpyConnection
-from sqlspec.adapters.psqlpy.driver import PsqlpyCursor, PsqlpyDriver, psqlpy_statement_config
+from sqlspec.adapters.psqlpy.driver import PsqlpyCursor, PsqlpyDriver, build_psqlpy_statement_config
 from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig
 from sqlspec.core.statement import StatementConfig
 from sqlspec.typing import PGVECTOR_INSTALLED
+from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -81,9 +82,13 @@ class PsqlpyDriverFeatures(TypedDict):
         Requires pgvector-python package installed.
             Defaults to True when pgvector is installed.
             Provides automatic conversion between NumPy arrays and PostgreSQL vector types.
+        json_serializer: Custom JSON serializer applied to the statement configuration.
+        json_deserializer: Custom JSON deserializer retained alongside the serializer for parity with asyncpg.
     """

     enable_pgvector: NotRequired[bool]
+    json_serializer: NotRequired["Callable[[Any], str]"]
+    json_deserializer: NotRequired["Callable[[str], Any]"]

 __all__ = ("PsqlpyConfig", "PsqlpyConnectionParams", "PsqlpyCursor", "PsqlpyDriverFeatures", "PsqlpyPoolParams")
@@ -124,14 +129,16 @@ def __init__(
             processed_pool_config.update(extras)

         processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
-        if "enable_pgvector" not in processed_driver_features:
-            processed_driver_features["enable_pgvector"] = PGVECTOR_INSTALLED
+        serializer = processed_driver_features.get("json_serializer")
+        serializer_callable = to_json if serializer is None else cast("Callable[[Any], str]", serializer)
+        processed_driver_features.setdefault("json_serializer", serializer_callable)
+        processed_driver_features.setdefault("enable_pgvector", PGVECTOR_INSTALLED)

         super().__init__(
             pool_config=processed_pool_config,
             pool_instance=pool_instance,
             migration_config=migration_config,
-            statement_config=statement_config or 
build_psqlpy_statement_config(json_serializer=serializer_callable), driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index e97176a60..3aae48cbc 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -4,8 +4,10 @@ and transaction management. """ +import datetime import decimal import re +import uuid from typing import TYPE_CHECKING, Any, Final import psqlpy.exceptions @@ -14,7 +16,13 @@ from sqlspec.adapters.psqlpy.data_dictionary import PsqlpyAsyncDataDictionary from sqlspec.adapters.psqlpy.type_converter import PostgreSQLTypeConverter from sqlspec.core.cache import get_cache_config -from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig +from sqlspec.core.parameters import ( + DriverParameterProfile, + ParameterStyle, + ParameterStyleConfig, + build_statement_config_from_profile, + register_driver_profile, +) from sqlspec.core.statement import SQL, StatementConfig from sqlspec.driver import AsyncDriverAdapterBase from sqlspec.exceptions import ( @@ -32,8 +40,10 @@ ) from sqlspec.typing import Empty from sqlspec.utils.logging import get_logger +from sqlspec.utils.serializers import to_json if TYPE_CHECKING: + from collections.abc import Callable from contextlib import AbstractAsyncContextManager from sqlspec.adapters.psqlpy._types import PsqlpyConnection @@ -41,100 +51,29 @@ from sqlspec.driver import ExecutionResult from sqlspec.driver._async import AsyncDataDictionaryBase -__all__ = ("PsqlpyCursor", "PsqlpyDriver", "PsqlpyExceptionHandler", "psqlpy_statement_config") +__all__ = ( + "PsqlpyCursor", + "PsqlpyDriver", + "PsqlpyExceptionHandler", + "build_psqlpy_statement_config", + "psqlpy_statement_config", +) logger = get_logger("adapters.psqlpy") _type_converter = PostgreSQLTypeConverter() - -def _convert_decimals_in_structure(obj: Any) -> Any: - """Recursively convert Decimal 
values to float in nested structures. - - Psqlpy's Rust layer expects native Python dict/list for JSONB parameters - (when using CAST(... AS JSONB) in SQL), but cannot handle Decimal objects. - This function walks through dict/list structures and converts any Decimal - values to float while preserving the native Python structure. - - Args: - obj: Object to process (dict, list, or scalar value). - - Returns: - Object with all Decimal values converted to float, preserving structure. - """ - if isinstance(obj, decimal.Decimal): - return float(obj) - if isinstance(obj, dict): - return {k: _convert_decimals_in_structure(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - return [_convert_decimals_in_structure(item) for item in obj] - return obj - - -def _prepare_dict_parameter(value: "dict[str, Any]") -> dict[str, Any]: - """Normalize dict parameters while preserving native structures.""" - normalized = _convert_decimals_in_structure(value) - return normalized if isinstance(normalized, dict) else value - - -def _prepare_list_parameter(value: "list[Any]") -> list[Any]: - """Normalize list parameters while preserving native list semantics.""" - return [_convert_decimals_in_structure(item) for item in value] - - -def _prepare_tuple_parameter(value: "tuple[Any, ...]") -> tuple[Any, ...]: - """Normalize tuple parameters by deferring to list handling.""" - return tuple(_convert_decimals_in_structure(item) for item in value) - - -def _normalize_scalar_parameter(value: Any) -> Any: - """Return scalar value without additional coercion.""" - return value - - -def _coerce_numeric_for_write(value: Any) -> Any: - """Convert write parameters to driver-compatible numeric types.""" - if isinstance(value, float): - return decimal.Decimal(str(value)) - if isinstance(value, decimal.Decimal): - return value - if isinstance(value, list): - return [_coerce_numeric_for_write(item) for item in value] - if isinstance(value, tuple): - coerced = [_coerce_numeric_for_write(item) for 
item in value] - return tuple(coerced) - if isinstance(value, dict): - return {key: _coerce_numeric_for_write(item) for key, item in value.items()} - return value - - -psqlpy_statement_config = StatementConfig( - dialect="postgres", - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.NUMERIC, - supported_parameter_styles={ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR, ParameterStyle.QMARK}, - default_execution_parameter_style=ParameterStyle.NUMERIC, - supported_execution_parameter_styles={ParameterStyle.NUMERIC}, - type_coercion_map={ - dict: _prepare_dict_parameter, - list: _prepare_list_parameter, - tuple: _prepare_tuple_parameter, - decimal.Decimal: float, - str: _type_converter.convert_if_detected, - }, - has_native_list_expansion=False, - needs_static_script_compilation=False, - allow_mixed_parameter_styles=False, - preserve_parameter_format=True, - ), - enable_parsing=True, - enable_validation=True, - enable_caching=True, - enable_parameter_type_wrapping=True, -) - PSQLPY_STATUS_REGEX: Final[re.Pattern[str]] = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE) +_JSON_CASTS: Final[frozenset[str]] = frozenset({"JSON", "JSONB"}) +_TIMESTAMP_CASTS: Final[frozenset[str]] = frozenset({ + "TIMESTAMP", + "TIMESTAMPTZ", + "TIMESTAMP WITH TIME ZONE", + "TIMESTAMP WITHOUT TIME ZONE", +}) +_UUID_CASTS: Final[frozenset[str]] = frozenset({"UUID"}) + class PsqlpyCursor: """Context manager for psqlpy cursor management.""" @@ -355,17 +294,19 @@ def _prepare_parameters_with_casts( """ if isinstance(parameters, (list, tuple)): result: list[Any] = [] + serializer = statement_config.parameter_config.json_serializer or to_json + type_map = statement_config.parameter_config.type_coercion_map for idx, param in enumerate(parameters, start=1): - cast_type = parameter_casts.get(idx, "").upper() - if cast_type in {"JSON", "JSONB"} and isinstance(param, list): - result.append(JSONB(param)) - else: - if 
statement_config.parameter_config.type_coercion_map: - for type_check, converter in statement_config.parameter_config.type_coercion_map.items(): - if isinstance(param, type_check): - param = converter(param) - break - result.append(param) + cast_type = parameter_casts.get(idx, "") + prepared_value = param + if type_map: + for type_check, converter in type_map.items(): + if isinstance(prepared_value, type_check): + prepared_value = converter(prepared_value) + break + if cast_type: + prepared_value = _coerce_parameter_for_cast(prepared_value, cast_type, serializer) + result.append(prepared_value) return tuple(result) if isinstance(parameters, tuple) else result return parameters @@ -577,3 +518,205 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase": if self._data_dictionary is None: self._data_dictionary = PsqlpyAsyncDataDictionary() return self._data_dictionary + + +def _convert_decimals_in_structure(obj: Any) -> Any: + """Recursively convert Decimal values to float in nested structures.""" + + if isinstance(obj, decimal.Decimal): + return float(obj) + if isinstance(obj, dict): + return {k: _convert_decimals_in_structure(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_convert_decimals_in_structure(item) for item in obj] + return obj + + +def _coerce_json_parameter(value: Any, cast_type: str, serializer: "Callable[[Any], str]") -> Any: + """Serialize JSON parameters according to the detected cast type. + + Args: + value: Parameter value supplied by the caller. + cast_type: Uppercase cast identifier detected in SQL. + serializer: JSON serialization callable from statement config. + + Returns: + Serialized parameter suitable for driver execution. + + Raises: + SQLSpecError: If serialization fails for JSON payloads. 
+ """ + + if value is None: + return None + if cast_type == "JSONB": + if isinstance(value, JSONB): + return value + if isinstance(value, dict): + return JSONB(value) + if isinstance(value, (list, tuple)): + return JSONB(list(value)) + if isinstance(value, tuple): + return list(value) + if isinstance(value, (dict, list, str, JSONB)): + return value + try: + serialized_value = serializer(value) + except Exception as error: + msg = "Failed to serialize JSON parameter for psqlpy." + raise SQLSpecError(msg) from error + return serialized_value + + +def _coerce_uuid_parameter(value: Any) -> Any: + """Convert UUID-compatible parameters to ``uuid.UUID`` instances. + + Args: + value: Parameter value supplied by the caller. + + Returns: + ``uuid.UUID`` instance when input is coercible, otherwise original value. + + Raises: + SQLSpecError: If the value cannot be converted to ``uuid.UUID``. + """ + + if isinstance(value, uuid.UUID): + return value + if isinstance(value, str): + try: + return uuid.UUID(value) + except ValueError as error: + msg = "Invalid UUID parameter for psqlpy." + raise SQLSpecError(msg) from error + return value + + +def _coerce_timestamp_parameter(value: Any) -> Any: + """Convert ISO-formatted timestamp strings to ``datetime.datetime``. + + Args: + value: Parameter value supplied by the caller. + + Returns: + ``datetime.datetime`` instance when conversion succeeds, otherwise original value. + + Raises: + SQLSpecError: If the value cannot be parsed as an ISO timestamp. + """ + + if isinstance(value, datetime.datetime): + return value + if isinstance(value, str): + normalized_value = value[:-1] + "+00:00" if value.endswith("Z") else value + try: + return datetime.datetime.fromisoformat(normalized_value) + except ValueError as error: + msg = "Invalid ISO timestamp parameter for psqlpy." 
+ raise SQLSpecError(msg) from error + return value + + +def _coerce_parameter_for_cast(value: Any, cast_type: str, serializer: "Callable[[Any], str]") -> Any: + """Apply cast-aware coercion for psqlpy parameters. + + Args: + value: Parameter value supplied by the caller. + cast_type: Uppercase cast identifier detected in SQL. + serializer: JSON serialization callable from statement config. + + Returns: + Coerced value appropriate for the specified cast, or the original value. + """ + + upper_cast = cast_type.upper() + if upper_cast in _JSON_CASTS: + return _coerce_json_parameter(value, upper_cast, serializer) + if upper_cast in _UUID_CASTS: + return _coerce_uuid_parameter(value) + if upper_cast in _TIMESTAMP_CASTS: + return _coerce_timestamp_parameter(value) + return value + + +def _prepare_dict_parameter(value: "dict[str, Any]") -> dict[str, Any]: + normalized = _convert_decimals_in_structure(value) + return normalized if isinstance(normalized, dict) else value + + +def _prepare_list_parameter(value: "list[Any]") -> list[Any]: + return [_convert_decimals_in_structure(item) for item in value] + + +def _prepare_tuple_parameter(value: "tuple[Any, ...]") -> tuple[Any, ...]: + return tuple(_convert_decimals_in_structure(item) for item in value) + + +def _normalize_scalar_parameter(value: Any) -> Any: + return value + + +def _coerce_numeric_for_write(value: Any) -> Any: + if isinstance(value, float): + return decimal.Decimal(str(value)) + if isinstance(value, decimal.Decimal): + return value + if isinstance(value, list): + return [_coerce_numeric_for_write(item) for item in value] + if isinstance(value, tuple): + coerced = [_coerce_numeric_for_write(item) for item in value] + return tuple(coerced) + if isinstance(value, dict): + return {key: _coerce_numeric_for_write(item) for key, item in value.items()} + return value + + +def _build_psqlpy_profile() -> DriverParameterProfile: + """Create the psqlpy driver parameter profile.""" + + return DriverParameterProfile( + 
name="Psqlpy", + default_style=ParameterStyle.NUMERIC, + supported_styles={ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR, ParameterStyle.QMARK}, + default_execution_style=ParameterStyle.NUMERIC, + supported_execution_styles={ParameterStyle.NUMERIC}, + has_native_list_expansion=False, + preserve_parameter_format=True, + needs_static_script_compilation=False, + allow_mixed_parameter_styles=False, + preserve_original_params_for_many=False, + json_serializer_strategy="helper", + custom_type_coercions={decimal.Decimal: float}, + default_dialect="postgres", + ) + + +_PSQLPY_PROFILE = _build_psqlpy_profile() + +register_driver_profile("psqlpy", _PSQLPY_PROFILE) + + +def _create_psqlpy_parameter_config(serializer: "Callable[[Any], str]") -> ParameterStyleConfig: + base_config = build_statement_config_from_profile(_PSQLPY_PROFILE, json_serializer=serializer).parameter_config + + updated_type_map = dict(base_config.type_coercion_map) + updated_type_map[dict] = _prepare_dict_parameter + updated_type_map[list] = _prepare_list_parameter + updated_type_map[tuple] = _prepare_tuple_parameter + + return base_config.replace(type_coercion_map=updated_type_map) + + +def build_psqlpy_statement_config(*, json_serializer: "Callable[[Any], str]" = to_json) -> StatementConfig: + parameter_config = _create_psqlpy_parameter_config(json_serializer) + return StatementConfig( + dialect="postgres", + parameter_config=parameter_config, + enable_parsing=True, + enable_validation=True, + enable_caching=True, + enable_parameter_type_wrapping=True, + ) + + +psqlpy_statement_config = build_psqlpy_statement_config() diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index b33a23067..69f6b68e2 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -15,6 +15,7 @@ PsycopgAsyncDriver, PsycopgSyncCursor, PsycopgSyncDriver, + build_psycopg_statement_config, psycopg_statement_config, ) from sqlspec.config import ( @@ -27,6 +28,7 @@ 
SyncDatabaseConfig, ) from sqlspec.typing import PGVECTOR_INSTALLED +from sqlspec.utils.serializers import to_json if TYPE_CHECKING: from collections.abc import AsyncGenerator, Callable, Generator @@ -82,9 +84,13 @@ class PsycopgDriverFeatures(TypedDict): Provides automatic conversion between Python objects and PostgreSQL vector types. Enables vector similarity operations and index support. Set to False to disable pgvector support even when package is available. + json_serializer: Custom JSON serializer for StatementConfig parameter handling. + json_deserializer: Custom JSON deserializer reference stored alongside the serializer for parity with asyncpg. """ enable_pgvector: NotRequired[bool] + json_serializer: NotRequired["Callable[[Any], str]"] + json_deserializer: NotRequired["Callable[[str], Any]"] __all__ = ( @@ -132,17 +138,17 @@ def __init__( extras = processed_pool_config.pop("extra") processed_pool_config.update(extras) - if driver_features is None: - driver_features = {} - if "enable_pgvector" not in driver_features: - driver_features["enable_pgvector"] = PGVECTOR_INSTALLED + processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} + serializer = cast("Callable[[Any], str]", processed_driver_features.get("json_serializer", to_json)) + processed_driver_features.setdefault("json_serializer", serializer) + processed_driver_features.setdefault("enable_pgvector", PGVECTOR_INSTALLED) super().__init__( pool_config=processed_pool_config, pool_instance=pool_instance, migration_config=migration_config, - statement_config=statement_config or psycopg_statement_config, - driver_features=driver_features, + statement_config=statement_config or build_psycopg_statement_config(json_serializer=serializer), + driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, ) @@ -323,17 +329,17 @@ def __init__( extras = processed_pool_config.pop("extra") processed_pool_config.update(extras) - if 
driver_features is None: - driver_features = {} - if "enable_pgvector" not in driver_features: - driver_features["enable_pgvector"] = PGVECTOR_INSTALLED + processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} + serializer = cast("Callable[[Any], str]", processed_driver_features.get("json_serializer", to_json)) + processed_driver_features.setdefault("json_serializer", serializer) + processed_driver_features.setdefault("enable_pgvector", PGVECTOR_INSTALLED) super().__init__( pool_config=processed_pool_config, pool_instance=pool_instance, migration_config=migration_config, - statement_config=statement_config or psycopg_statement_config, - driver_features=driver_features, + statement_config=statement_config or build_psycopg_statement_config(json_serializer=serializer), + driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, ) diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index fa4195072..f80b06b1b 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -22,7 +22,13 @@ from sqlspec.adapters.psycopg._types import PsycopgAsyncConnection, PsycopgSyncConnection from sqlspec.core.cache import get_cache_config -from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig +from sqlspec.core.parameters import ( + DriverParameterProfile, + ParameterStyle, + ParameterStyleConfig, + build_statement_config_from_profile, + register_driver_profile, +) from sqlspec.core.result import SQLResult from sqlspec.core.statement import SQL, StatementConfig from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase @@ -43,105 +49,13 @@ from sqlspec.utils.serializers import to_json if TYPE_CHECKING: + from collections.abc import Callable from contextlib import AbstractAsyncContextManager, AbstractContextManager from sqlspec.driver._async import AsyncDataDictionaryBase from sqlspec.driver._common import 
ExecutionResult from sqlspec.driver._sync import SyncDataDictionaryBase -logger = get_logger("adapters.psycopg") - - -TRANSACTION_STATUS_IDLE = 0 -TRANSACTION_STATUS_ACTIVE = 1 -TRANSACTION_STATUS_INTRANS = 2 -TRANSACTION_STATUS_INERROR = 3 -TRANSACTION_STATUS_UNKNOWN = 4 - - -def _convert_list_to_postgres_array(value: Any) -> str: - """Convert Python list to PostgreSQL array literal format. - - Args: - value: Python list to convert - - Returns: - PostgreSQL array literal string - """ - if not isinstance(value, list): - return str(value) - - elements = [] - for item in value: - if isinstance(item, list): - elements.append(_convert_list_to_postgres_array(item)) - elif isinstance(item, str): - escaped = item.replace("'", "''") - elements.append(f"'{escaped}'") - elif item is None: - elements.append("NULL") - else: - elements.append(str(item)) - - return f"{{{','.join(elements)}}}" - - -def _should_serialize_list(value: "list[Any]") -> bool: - """Detect whether list should be serialized to JSON.""" - return any(isinstance(item, (dict, list, tuple)) for item in value) - - -def _prepare_list_parameter(value: "list[Any]") -> Any: - """Convert complex lists to JSON strings while keeping primitive arrays.""" - if not value: - return value - if _should_serialize_list(value): - return to_json(value) - return value - - -def _prepare_tuple_parameter(value: "tuple[Any, ...]") -> Any: - """Normalize tuple parameters for psycopg binding.""" - return _prepare_list_parameter(list(value)) - - -psycopg_statement_config = StatementConfig( - dialect="postgres", - pre_process_steps=None, - post_process_steps=None, - enable_parsing=True, - enable_transformations=True, - enable_validation=True, - enable_caching=True, - enable_parameter_type_wrapping=True, - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.POSITIONAL_PYFORMAT, - supported_parameter_styles={ - ParameterStyle.POSITIONAL_PYFORMAT, - ParameterStyle.NAMED_PYFORMAT, - ParameterStyle.NUMERIC, - 
ParameterStyle.QMARK, - }, - default_execution_parameter_style=ParameterStyle.POSITIONAL_PYFORMAT, - supported_execution_parameter_styles={ - ParameterStyle.POSITIONAL_PYFORMAT, - ParameterStyle.NAMED_PYFORMAT, - ParameterStyle.NUMERIC, - }, - type_coercion_map={ - dict: to_json, - list: _prepare_list_parameter, - tuple: _prepare_tuple_parameter, - datetime.datetime: lambda x: x, - datetime.date: lambda x: x, - datetime.time: lambda x: x, - }, - has_native_list_expansion=True, - needs_static_script_compilation=False, - preserve_parameter_format=True, - ), -) - __all__ = ( "PsycopgAsyncCursor", "PsycopgAsyncDriver", @@ -149,9 +63,19 @@ def _prepare_tuple_parameter(value: "tuple[Any, ...]") -> Any: "PsycopgSyncCursor", "PsycopgSyncDriver", "PsycopgSyncExceptionHandler", + "build_psycopg_statement_config", "psycopg_statement_config", ) +logger = get_logger("adapters.psycopg") + + +TRANSACTION_STATUS_IDLE = 0 +TRANSACTION_STATUS_ACTIVE = 1 +TRANSACTION_STATUS_INTRANS = 2 +TRANSACTION_STATUS_INERROR = 3 +TRANSACTION_STATUS_UNKNOWN = 4 + class PsycopgSyncCursor: """Context manager for PostgreSQL psycopg cursor management.""" @@ -892,3 +816,130 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase": self._data_dictionary = PostgresAsyncDataDictionary() return self._data_dictionary + + +def _identity(value: Any) -> Any: + return value + + +def _convert_list_to_postgres_array(value: Any) -> str: + """Convert a Python list to PostgreSQL array literal format.""" + + if not isinstance(value, list): + return str(value) + + elements: list[str] = [] + for item in value: + if isinstance(item, list): + elements.append(_convert_list_to_postgres_array(item)) + elif isinstance(item, str): + escaped = item.replace("'", "''") + elements.append(f"'{escaped}'") + elif item is None: + elements.append("NULL") + else: + elements.append(str(item)) + + return "{" + ",".join(elements) + "}" + + +def _should_serialize_list(value: "list[Any]") -> bool: + """Detect whether a list should be 
serialized to JSON.""" + + return any(isinstance(item, (dict, list, tuple)) for item in value) + + +def _build_list_parameter_converter(serializer: "Callable[[Any], str]") -> "Callable[[list[Any]], Any]": + """Create converter that serializes nested lists while preserving arrays.""" + + def convert(value: "list[Any]") -> Any: + if not value: + return value + if _should_serialize_list(value): + return serializer(value) + return value + + return convert + + +def _build_tuple_parameter_converter(serializer: "Callable[[Any], str]") -> "Callable[[tuple[Any, ...]], Any]": + """Create converter mirroring list handling for tuple parameters.""" + + list_converter = _build_list_parameter_converter(serializer) + + def convert(value: "tuple[Any, ...]") -> Any: + return list_converter(list(value)) + + return convert + + +def _build_psycopg_custom_type_coercions() -> dict[type, "Callable[[Any], Any]"]: + """Return custom type coercions for psycopg.""" + + return {datetime.datetime: _identity, datetime.date: _identity, datetime.time: _identity} + + +def _build_psycopg_profile() -> DriverParameterProfile: + """Create the psycopg driver parameter profile.""" + + return DriverParameterProfile( + name="Psycopg", + default_style=ParameterStyle.POSITIONAL_PYFORMAT, + supported_styles={ + ParameterStyle.POSITIONAL_PYFORMAT, + ParameterStyle.NAMED_PYFORMAT, + ParameterStyle.NUMERIC, + ParameterStyle.QMARK, + }, + default_execution_style=ParameterStyle.POSITIONAL_PYFORMAT, + supported_execution_styles={ + ParameterStyle.POSITIONAL_PYFORMAT, + ParameterStyle.NAMED_PYFORMAT, + ParameterStyle.NUMERIC, + }, + has_native_list_expansion=True, + preserve_parameter_format=True, + needs_static_script_compilation=False, + allow_mixed_parameter_styles=False, + preserve_original_params_for_many=False, + json_serializer_strategy="helper", + custom_type_coercions=_build_psycopg_custom_type_coercions(), + default_dialect="postgres", + ) + + +_PSYCOPG_PROFILE = _build_psycopg_profile() + 
+register_driver_profile("psycopg", _PSYCOPG_PROFILE) + + +def _create_psycopg_parameter_config(serializer: "Callable[[Any], str]") -> ParameterStyleConfig: + """Construct parameter configuration with shared JSON serializer support.""" + + base_config = build_statement_config_from_profile(_PSYCOPG_PROFILE, json_serializer=serializer).parameter_config + + updated_type_map = dict(base_config.type_coercion_map) + updated_type_map[list] = _build_list_parameter_converter(serializer) + updated_type_map[tuple] = _build_tuple_parameter_converter(serializer) + + return base_config.replace(type_coercion_map=updated_type_map) + + +def build_psycopg_statement_config(*, json_serializer: "Callable[[Any], str]" = to_json) -> StatementConfig: + """Construct the psycopg statement configuration with optional JSON codecs.""" + + parameter_config = _create_psycopg_parameter_config(json_serializer) + return StatementConfig( + dialect="postgres", + pre_process_steps=None, + post_process_steps=None, + enable_parsing=True, + enable_transformations=True, + enable_validation=True, + enable_caching=True, + enable_parameter_type_wrapping=True, + parameter_config=parameter_config, + ) + + +psycopg_statement_config = build_psycopg_statement_config() diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index abc80c7fd..b2b3f389d 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -112,12 +112,22 @@ def __init__( if "json_deserializer" not in processed_driver_features: processed_driver_features["json_deserializer"] = from_json + base_statement_config = statement_config or sqlite_statement_config + + json_serializer = processed_driver_features.get("json_serializer") + json_deserializer = processed_driver_features.get("json_deserializer") + if json_serializer is not None: + parameter_config = base_statement_config.parameter_config.with_json_serializers( + json_serializer, deserializer=json_deserializer + ) + base_statement_config = 
base_statement_config.replace(parameter_config=parameter_config) + super().__init__( bind_key=bind_key, pool_instance=pool_instance, pool_config=cast("dict[str, Any]", pool_config), migration_config=migration_config, - statement_config=statement_config or sqlite_statement_config, + statement_config=base_statement_config, driver_features=processed_driver_features, extension_config=extension_config, ) diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index f7dcb70f0..76c1921be 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -1,14 +1,18 @@ """SQLite driver implementation.""" import contextlib -import datetime import sqlite3 +from datetime import date, datetime from decimal import Decimal from typing import TYPE_CHECKING, Any from sqlspec.core.cache import get_cache_config -from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig -from sqlspec.core.statement import StatementConfig +from sqlspec.core.parameters import ( + DriverParameterProfile, + ParameterStyle, + build_statement_config_from_profile, + register_driver_profile, +) from sqlspec.driver import SyncDriverAdapterBase from sqlspec.exceptions import ( CheckViolationError, @@ -29,7 +33,7 @@ from sqlspec.adapters.sqlite._types import SqliteConnection from sqlspec.core.result import SQLResult - from sqlspec.core.statement import SQL + from sqlspec.core.statement import SQL, StatementConfig from sqlspec.driver import ExecutionResult from sqlspec.driver._sync import SyncDataDictionaryBase @@ -44,31 +48,6 @@ SQLITE_IOERR_CODE = 10 SQLITE_MISMATCH_CODE = 20 -sqlite_statement_config = StatementConfig( - dialect="sqlite", - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK, ParameterStyle.NAMED_COLON}, - default_execution_parameter_style=ParameterStyle.QMARK, - supported_execution_parameter_styles={ParameterStyle.QMARK, 
ParameterStyle.NAMED_COLON},
-        type_coercion_map={
-            bool: int,
-            datetime.datetime: lambda v: v.isoformat(),
-            datetime.date: lambda v: v.isoformat(),
-            Decimal: str,
-            dict: to_json,
-            list: to_json,
-        },
-        has_native_list_expansion=False,
-        needs_static_script_compilation=False,
-        preserve_parameter_format=True,
-    ),
-    enable_parsing=True,
-    enable_validation=True,
-    enable_caching=True,
-    enable_parameter_type_wrapping=True,
-)
-
 class SqliteCursor:
     """Context manager for SQLite cursor management.
@@ -408,3 +387,53 @@ def data_dictionary(self) -> "SyncDataDictionaryBase":
             self._data_dictionary = SqliteSyncDataDictionary()
         return self._data_dictionary
+
+
+def _bool_to_int(value: bool) -> int:
+    return int(value)
+
+
+def _datetime_to_iso(value: datetime) -> str:
+    return value.isoformat()
+
+
+def _date_to_iso(value: date) -> str:
+    return value.isoformat()
+
+
+def _decimal_to_str(value: Decimal) -> str:
+    return str(value)
+
+
+def _build_sqlite_profile() -> DriverParameterProfile:
+    """Create the SQLite driver parameter profile."""
+
+    return DriverParameterProfile(
+        name="SQLite",
+        default_style=ParameterStyle.QMARK,
+        supported_styles={ParameterStyle.QMARK, ParameterStyle.NAMED_COLON},
+        default_execution_style=ParameterStyle.QMARK,
+        supported_execution_styles={ParameterStyle.QMARK, ParameterStyle.NAMED_COLON},
+        has_native_list_expansion=False,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=False,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=False,
+        json_serializer_strategy="helper",
+        custom_type_coercions={
+            bool: _bool_to_int,
+            datetime: _datetime_to_iso,
+            date: _date_to_iso,
+            Decimal: _decimal_to_str,
+        },
+        default_dialect="sqlite",
+    )
+
+
+_SQLITE_PROFILE = _build_sqlite_profile()
+
+register_driver_profile("sqlite", _SQLITE_PROFILE)
+
+sqlite_statement_config = build_statement_config_from_profile(
+    _SQLITE_PROFILE, statement_overrides={"dialect": "sqlite"}, json_serializer=to_json
+)
diff --git a/sqlspec/core/__init__.py b/sqlspec/core/__init__.py
index 7f1eb5704..e0d0b32c7 100644
--- a/sqlspec/core/__init__.py
+++ b/sqlspec/core/__init__.py
@@ -101,11 +101,14 @@
     hash_sql_statement,
 )
 from sqlspec.core.parameters import (
+    DriverParameterProfile,
     ParameterConverter,
     ParameterProcessor,
     ParameterStyle,
     ParameterStyleConfig,
     TypedParameter,
+    build_statement_config_from_profile,
+    register_driver_profile,
 )
 from sqlspec.core.result import ArrowResult, SQLResult, StatementResult, create_arrow_result, create_sql_result
 from sqlspec.core.statement import SQL, Statement, StatementConfig
@@ -115,6 +118,7 @@
     "ArrowResult",
     "CacheConfig",
     "CacheStats",
+    "DriverParameterProfile",
     "MultiLevelCache",
     "OperationType",
     "ParameterConverter",
@@ -129,6 +133,7 @@
     "StatementResult",
     "TypedParameter",
     "UnifiedCache",
+    "build_statement_config_from_profile",
     "create_arrow_result",
     "create_sql_result",
     "filters",
@@ -139,4 +144,5 @@
     "hash_optimized_expression",
     "hash_parameters",
     "hash_sql_statement",
+    "register_driver_profile",
 )
diff --git a/sqlspec/core/parameters.py b/sqlspec/core/parameters.py
deleted file mode 100644
index ed4f15d5a..000000000
--- a/sqlspec/core/parameters.py
+++ /dev/null
@@ -1,1636 +0,0 @@
-"""Parameter processing system for SQL statements.
-
-This module implements parameter processing including type conversion,
-style conversion, and validation for SQL statements.
-
-Components:
-- ParameterStyle enum: Supported parameter styles
-- TypedParameter: Preserves type information through processing
-- ParameterInfo: Tracks parameter metadata
-- ParameterValidator: Extracts and validates parameters
-- ParameterConverter: Handles parameter style conversions
-- ParameterProcessor: Parameter processing coordinator
-- ParameterStyleConfig: Configuration for parameter processing
-
-Processing:
-- Two-phase processing: compatibility and execution format
-- Type-specific parameter wrapping
-- Parameter style conversions
-- Support for multiple parameter styles and database adapters
-"""
-
-import hashlib
-import re
-from collections import OrderedDict
-from collections.abc import Callable, Generator, Mapping, Sequence
-from datetime import date, datetime
-from decimal import Decimal
-from enum import Enum
-from functools import singledispatch
-from typing import Any, Final, Literal, cast
-
-from mypy_extensions import mypyc_attr
-
-import sqlspec.exceptions
-
-__all__ = (
-    "ParameterConverter",
-    "ParameterInfo",
-    "ParameterProcessingResult",
-    "ParameterProcessor",
-    "ParameterProfile",
-    "ParameterStyle",
-    "ParameterStyleConfig",
-    "ParameterValidator",
-    "TypedParameter",
-    "is_iterable_parameters",
-    "validate_parameter_alignment",
-    "wrap_with_type",
-)
-
-
-_PARAMETER_REGEX = re.compile(
-    r"""
-    (?P<dquote>"(?:[^"\\]|\\.)*") |
-    (?P<squote>'(?:[^'\\]|\\.)*') |
-    (?P<dollar_quoted_string>\$(?P<dollar_tag>\w*)?\$[\s\S]*?\$\4\$) |
-    (?P<line_comment>--[^\r\n]*) |
-    (?P<block_comment>/\*(?:[^*]|\*(?!/))*\*/) |
-    (?P<pg_q_operator>\?\?|\?\||\?&) |
-    (?P<pg_cast>::(?P<cast_type>\w+)) |
-    (?P<pyformat_named>%\((?P<pyformat_name>\w+)\)s) |
-    (?P<pyformat_pos>%s) |
-    (?P<positional_colon>:(?P<colon_num>\d+)) |
-    (?P<named_colon>:(?P<colon_name>\w+)) |
-    (?P<named_at>@(?P<at_name>\w+)) |
-    (?P<numeric>\$(?P<numeric_num>\d+)) |
-    (?P<named_dollar_param>\$(?P<dollar_param_name>\w+)) |
-    (?P<qmark>\?)
-    """,
-    re.VERBOSE | re.IGNORECASE | re.MULTILINE | re.DOTALL,
-)
-
-
-class ParameterStyle(str, Enum):
-    """Parameter style enumeration.
-
-    Supported parameter styles:
-    - QMARK: ? placeholders
-    - NUMERIC: $1, $2 placeholders
-    - POSITIONAL_PYFORMAT: %s placeholders
-    - NAMED_PYFORMAT: %(name)s placeholders
-    - NAMED_COLON: :name placeholders
-    - NAMED_AT: @name placeholders
-    - NAMED_DOLLAR: $name placeholders
-    - POSITIONAL_COLON: :1, :2 placeholders
-    - STATIC: Direct embedding of values in SQL
-    - NONE: No parameters supported
-    """
-
-    NONE = "none"
-    STATIC = "static"
-    QMARK = "qmark"
-    NUMERIC = "numeric"
-    NAMED_COLON = "named_colon"
-    POSITIONAL_COLON = "positional_colon"
-    NAMED_AT = "named_at"
-    NAMED_DOLLAR = "named_dollar"
-    NAMED_PYFORMAT = "pyformat_named"
-    POSITIONAL_PYFORMAT = "pyformat_positional"
-
-
-@mypyc_attr(allow_interpreted_subclasses=False)
-class TypedParameter:
-    """Parameter wrapper that preserves type information.
-
-    Maintains type information through parsing and execution
-    format conversion.
-
-    Attributes:
-        value: The parameter value
-        original_type: The original Python type of the value
-        semantic_name: Optional name for debugging purposes
-    """
-
-    __slots__ = ("_hash", "original_type", "semantic_name", "value")
-
-    def __init__(self, value: Any, original_type: type | None = None, semantic_name: str | None = None) -> None:
-        """Initialize typed parameter wrapper.
-
-        Args:
-            value: The parameter value
-            original_type: Original type (defaults to type(value))
-            semantic_name: Optional semantic name for debugging
-        """
-        self.value = value
-        self.original_type = original_type or type(value)
-        self.semantic_name = semantic_name
-        self._hash: int | None = None
-
-    def __hash__(self) -> int:
-        """Cached hash value."""
-        if self._hash is None:
-            value_id = id(self.value)
-            self._hash = hash((value_id, self.original_type, self.semantic_name))
-        return self._hash
-
-    def __eq__(self, other: object) -> bool:
-        """Equality comparison for TypedParameter instances."""
-        if not isinstance(other, TypedParameter):
-            return False
-        return (
-            self.value == other.value
-            and self.original_type == other.original_type
-            and self.semantic_name == other.semantic_name
-        )
-
-    def __repr__(self) -> str:
-        """String representation for debugging."""
-        name_part = f", semantic_name='{self.semantic_name}'" if self.semantic_name else ""
-        return f"TypedParameter({self.value!r}, original_type={self.original_type.__name__}{name_part})"
-
-
-@singledispatch
-def _wrap_parameter_by_type(value: Any, semantic_name: str | None = None) -> Any:
-    """Type-specific parameter wrapping using singledispatch.
-
-    Args:
-        value: Parameter value to potentially wrap
-        semantic_name: Optional semantic name for debugging
-
-    Returns:
-        Either the original value or TypedParameter wrapper
-    """
-    return value
-
-
-@_wrap_parameter_by_type.register
-def _(value: bool, semantic_name: str | None = None) -> TypedParameter:
-    """Wrap boolean values to prevent SQLGlot parsing issues."""
-    return TypedParameter(value, bool, semantic_name)
-
-
-@_wrap_parameter_by_type.register
-def _(value: Decimal, semantic_name: str | None = None) -> TypedParameter:
-    """Wrap Decimal values to preserve precision."""
-    return TypedParameter(value, Decimal, semantic_name)
-
-
-@_wrap_parameter_by_type.register
-def _(value: datetime, semantic_name: str | None = None) -> TypedParameter:
-    """Wrap datetime values for database-specific formatting."""
-    return TypedParameter(value, datetime, semantic_name)
-
-
-@_wrap_parameter_by_type.register
-def _(value: date, semantic_name: str | None = None) -> TypedParameter:
-    """Wrap date values for database-specific formatting."""
-    return TypedParameter(value, date, semantic_name)
-
-
-@_wrap_parameter_by_type.register
-def _(value: bytes, semantic_name: str | None = None) -> TypedParameter:
-    """Wrap bytes values to prevent string conversion issues in ADBC/Arrow."""
-    return TypedParameter(value, bytes, semantic_name)
-
-
-@mypyc_attr(allow_interpreted_subclasses=False)
-class ParameterInfo:
-    """Information about a detected parameter in SQL.
-
-    Tracks parameter metadata for conversion operations.
-
-    Attributes:
-        name: Parameter name (for named styles)
-        style: Parameter style
-        position: Character position in SQL string
-        ordinal: Order of appearance (0-indexed)
-        placeholder_text: Original text in SQL
-    """
-
-    __slots__ = ("name", "ordinal", "placeholder_text", "position", "style")
-
-    def __init__(
-        self, name: str | None, style: ParameterStyle, position: int, ordinal: int, placeholder_text: str
-    ) -> None:
-        """Initialize parameter information.
-
-        Args:
-            name: Parameter name (None for positional styles)
-            style: Parameter style enum
-            position: Character position in SQL
-            ordinal: Order of appearance (0-indexed)
-            placeholder_text: Original placeholder text
-        """
-        self.name = name
-        self.style = style
-        self.position = position
-        self.ordinal = ordinal
-        self.placeholder_text = placeholder_text
-
-    def __repr__(self) -> str:
-        """String representation for debugging."""
-        return (
-            f"ParameterInfo(name={self.name!r}, style={self.style!r}, "
-            f"position={self.position}, ordinal={self.ordinal}, "
-            f"placeholder_text={self.placeholder_text!r})"
-        )
-
-
-@mypyc_attr(allow_interpreted_subclasses=False)
-class ParameterStyleConfig:
-    """Configuration for parameter style processing.
-
-    Provides configuration for parameter processing operations including
-    style conversion, type coercion, and parameter format preservation.
-    """
-
-    __slots__ = (
-        "allow_mixed_parameter_styles",
-        "ast_transformer",
-        "default_execution_parameter_style",
-        "default_parameter_style",
-        "has_native_list_expansion",
-        "json_deserializer",
-        "json_serializer",
-        "needs_static_script_compilation",
-        "output_transformer",
-        "preserve_original_params_for_many",
-        "preserve_parameter_format",
-        "supported_execution_parameter_styles",
-        "supported_parameter_styles",
-        "type_coercion_map",
-    )
-
-    def __init__(
-        self,
-        default_parameter_style: ParameterStyle,
-        supported_parameter_styles: set[ParameterStyle] | None = None,
-        supported_execution_parameter_styles: set[ParameterStyle] | None = None,
-        default_execution_parameter_style: ParameterStyle | None = None,
-        type_coercion_map: dict[type, Callable[[Any], Any]] | None = None,
-        has_native_list_expansion: bool = False,
-        needs_static_script_compilation: bool = False,
-        allow_mixed_parameter_styles: bool = False,
-        preserve_parameter_format: bool = True,
-        preserve_original_params_for_many: bool = False,
-        output_transformer: Callable[[str, Any], tuple[str, Any]] | None = None,
-        ast_transformer: Callable[[Any, Any], tuple[Any, Any]] | None = None,
-        json_serializer: Callable[[Any], str] | None = None,
-        json_deserializer: Callable[[str], Any] | None = None,
-    ) -> None:
-        """Initialize parameter style configuration.
-
-        Args:
-            default_parameter_style: Primary parameter style for parsing
-            supported_parameter_styles: All input styles this config supports
-            supported_execution_parameter_styles: Styles driver can execute
-            default_execution_parameter_style: Target format for execution
-            type_coercion_map: Driver-specific type conversions
-            has_native_list_expansion: Driver supports native array parameters
-            output_transformer: Final transformation hook
-            needs_static_script_compilation: Embed parameters directly in SQL
-            allow_mixed_parameter_styles: Support mixed styles in single query
-            preserve_parameter_format: Maintain original parameter structure
-            preserve_original_params_for_many: Return original list of tuples for execute_many
-            ast_transformer: AST-based transformation hook for SQL/parameter manipulation
-            json_serializer: Optional JSON serializer to apply to dict/list/tuple parameters
-            json_deserializer: Optional JSON deserializer retained for driver use
-        """
-        self.default_parameter_style = default_parameter_style
-        self.supported_parameter_styles = (
-            supported_parameter_styles if supported_parameter_styles is not None else {default_parameter_style}
-        )
-        self.supported_execution_parameter_styles = supported_execution_parameter_styles
-        self.default_execution_parameter_style = default_execution_parameter_style or default_parameter_style
-        self.type_coercion_map = type_coercion_map or {}
-        self.has_native_list_expansion = has_native_list_expansion
-        self.output_transformer = output_transformer
-        self.ast_transformer = ast_transformer
-        self.needs_static_script_compilation = needs_static_script_compilation
-        self.allow_mixed_parameter_styles = allow_mixed_parameter_styles
-        self.preserve_parameter_format = preserve_parameter_format
-        self.preserve_original_params_for_many = preserve_original_params_for_many
-        self.json_serializer = json_serializer
-        self.json_deserializer = json_deserializer
-
-    def hash(self) -> int:
-        """Generate hash for cache key generation.
-
-        Returns:
-            Hash value for cache key generation
-        """
-        hash_components = (
-            self.default_parameter_style.value,
-            frozenset(s.value for s in self.supported_parameter_styles),
-            (
-                frozenset(s.value for s in self.supported_execution_parameter_styles)
-                if self.supported_execution_parameter_styles
-                else None
-            ),
-            self.default_execution_parameter_style.value,
-            tuple(sorted(self.type_coercion_map.keys(), key=str)) if self.type_coercion_map else None,
-            self.has_native_list_expansion,
-            self.preserve_original_params_for_many,
-            bool(self.output_transformer),
-            self.needs_static_script_compilation,
-            self.allow_mixed_parameter_styles,
-            self.preserve_parameter_format,
-            bool(self.ast_transformer),
-            self.json_serializer,
-            self.json_deserializer,
-        )
-        return hash(hash_components)
-
-    def replace(self, **overrides: Any) -> "ParameterStyleConfig":
-        data: dict[str, Any] = {
-            "default_parameter_style": self.default_parameter_style,
-            "supported_parameter_styles": set(self.supported_parameter_styles),
-            "supported_execution_parameter_styles": (
-                set(self.supported_execution_parameter_styles)
-                if self.supported_execution_parameter_styles is not None
-                else None
-            ),
-            "default_execution_parameter_style": self.default_execution_parameter_style,
-            "type_coercion_map": dict(self.type_coercion_map),
-            "has_native_list_expansion": self.has_native_list_expansion,
-            "needs_static_script_compilation": self.needs_static_script_compilation,
-            "allow_mixed_parameter_styles": self.allow_mixed_parameter_styles,
-            "preserve_parameter_format": self.preserve_parameter_format,
-            "preserve_original_params_for_many": self.preserve_original_params_for_many,
-            "output_transformer": self.output_transformer,
-            "ast_transformer": self.ast_transformer,
-            "json_serializer": self.json_serializer,
-            "json_deserializer": self.json_deserializer,
-        }
-        data.update(overrides)
-        return ParameterStyleConfig(**data)
-
-    def with_json_serializers(
-        self,
-        serializer: "Callable[[Any], str]",
-        *,
-        tuple_strategy: Literal["list", "tuple"] = "list",
-        deserializer: "Callable[[str], Any] | None" = None,
-    ) -> "ParameterStyleConfig":
-        """Return a copy configured to serialize dict/list/tuple parameters with a custom JSON encoder."""
-
-        if tuple_strategy == "list":
-
-            def tuple_adapter(value: Any) -> Any:
-                return serializer(list(value))
-
-        elif tuple_strategy == "tuple":
-
-            def tuple_adapter(value: Any) -> Any:
-                return serializer(value)
-
-        else:
-            msg = f"Unsupported tuple_strategy: {tuple_strategy}"
-            raise ValueError(msg)
-
-        updated_type_map = dict(self.type_coercion_map)
-        updated_type_map[dict] = serializer
-        updated_type_map[list] = serializer
-        updated_type_map[tuple] = tuple_adapter
-
-        return self.replace(
-            type_coercion_map=updated_type_map,
-            json_serializer=serializer,
-            json_deserializer=deserializer or self.json_deserializer,
-        )
-
-
-@mypyc_attr(allow_interpreted_subclasses=False)
-class ParameterValidator:
-    """Parameter validation and extraction.
-
-    Extracts parameter information from SQL strings and determines
-    compatibility with target dialects.
-    """
-
-    __slots__ = ("_cache_max_size", "_parameter_cache")
-
-    def __init__(self, cache_max_size: int = 5000) -> None:
-        """Initialize validator with bounded LRU cache.
-
-        Args:
-            cache_max_size: Maximum number of SQL strings to cache (default: 5000)
-        """
-        self._parameter_cache: OrderedDict[str, list[ParameterInfo]] = OrderedDict()
-        self._cache_max_size = cache_max_size
-
-    def _extract_parameter_style(self, match: "re.Match[str]") -> "tuple[ParameterStyle | None, str | None]":
-        """Extract parameter style and name from regex match.
-
-        Args:
-            match: Regex match object
-
-        Returns:
-            Tuple of (style, name) or (None, None) if not a parameter
-        """
-
-        if match.group("qmark"):
-            return ParameterStyle.QMARK, None
-
-        if match.group("named_colon"):
-            return ParameterStyle.NAMED_COLON, match.group("colon_name")
-
-        if match.group("numeric"):
-            return ParameterStyle.NUMERIC, match.group("numeric_num")
-
-        if match.group("named_at"):
-            return ParameterStyle.NAMED_AT, match.group("at_name")
-
-        if match.group("pyformat_named"):
-            return ParameterStyle.NAMED_PYFORMAT, match.group("pyformat_name")
-
-        if match.group("pyformat_pos"):
-            return ParameterStyle.POSITIONAL_PYFORMAT, None
-
-        if match.group("positional_colon"):
-            return ParameterStyle.POSITIONAL_COLON, match.group("colon_num")
-
-        if match.group("named_dollar_param"):
-            return ParameterStyle.NAMED_DOLLAR, match.group("dollar_param_name")
-
-        return None, None
-
-    def extract_parameters(self, sql: str) -> "list[ParameterInfo]":
-        """Extract all parameters from SQL.
-
-        Args:
-            sql: SQL string to analyze
-
-        Returns:
-            List of ParameterInfo objects for each detected parameter
-        """
-        cached_result = self._parameter_cache.get(sql)
-        if cached_result is not None:
-            self._parameter_cache.move_to_end(sql)
-            return cached_result
-
-        if not any(c in sql for c in ("?", "%", ":", "@", "$")):
-            if len(self._parameter_cache) >= self._cache_max_size:
-                self._parameter_cache.popitem(last=False)
-            self._parameter_cache[sql] = []
-            return []
-
-        parameters: list[ParameterInfo] = []
-        ordinal = 0
-
-        skip_groups = (
-            "dquote",
-            "squote",
-            "dollar_quoted_string",
-            "line_comment",
-            "block_comment",
-            "pg_q_operator",
-            "pg_cast",
-        )
-
-        for match in _PARAMETER_REGEX.finditer(sql):
-            if any(match.group(g) for g in skip_groups):
-                continue
-
-            style, name = self._extract_parameter_style(match)
-
-            if style is ParameterStyle.QMARK:
-                tail = sql[match.end() :]
-                next_non_space = tail.lstrip()
-                if next_non_space.startswith(("'", '"')):
-                    continue
-
-            if style is not None:
-                parameters.append(
-                    ParameterInfo(
-                        name=name, style=style, position=match.start(), ordinal=ordinal, placeholder_text=match.group(0)
-                    )
-                )
-                ordinal += 1
-
-        if len(self._parameter_cache) >= self._cache_max_size:
-            self._parameter_cache.popitem(last=False)
-
-        self._parameter_cache[sql] = parameters
-        return parameters
-
-    def get_sqlglot_incompatible_styles(self, dialect: str | None = None) -> "set[ParameterStyle]":
-        """Get parameter styles incompatible with SQLGlot for dialect.
-
-        Args:
-            dialect: SQL dialect for compatibility checking
-
-        Returns:
-            Set of parameter styles incompatible with SQLGlot
-        """
-        base_incompatible = {
-            ParameterStyle.POSITIONAL_PYFORMAT,
-            ParameterStyle.NAMED_PYFORMAT,
-            ParameterStyle.POSITIONAL_COLON,
-        }
-
-        if dialect and dialect.lower() in {"mysql", "mariadb"}:
-            return base_incompatible
-        if dialect and dialect.lower() in {"postgres", "postgresql"}:
-            return {ParameterStyle.POSITIONAL_COLON}
-        if dialect and dialect.lower() == "sqlite":
-            return {ParameterStyle.POSITIONAL_COLON}
-        if dialect and dialect.lower() in {"oracle", "bigquery"}:
-            return base_incompatible
-        return base_incompatible
-
-
-@mypyc_attr(allow_interpreted_subclasses=False)
-class ParameterConverter:
-    """Parameter style conversion.
-
-    Handles two-phase parameter processing:
-    - Phase 1: Compatibility normalization
-    - Phase 2: Execution format conversion
-    """
-
-    __slots__ = ("_format_converters", "_placeholder_generators", "validator")
-
-    def __init__(self) -> None:
-        """Initialize converter with lookup tables."""
-        self.validator = ParameterValidator()
-
-        self._format_converters = {
-            ParameterStyle.POSITIONAL_COLON: self._convert_to_positional_colon_format,
-            ParameterStyle.NAMED_COLON: self._convert_to_named_colon_format,
-            ParameterStyle.NAMED_PYFORMAT: self._convert_to_named_pyformat_format,
-            ParameterStyle.QMARK: self._convert_to_positional_format,
-            ParameterStyle.NUMERIC: self._convert_to_positional_format,
-            ParameterStyle.POSITIONAL_PYFORMAT: self._convert_to_positional_format,
-            ParameterStyle.NAMED_AT: self._convert_to_named_colon_format,
-            ParameterStyle.NAMED_DOLLAR: self._convert_to_named_colon_format,
-        }
-
-        self._placeholder_generators: dict[ParameterStyle, Callable[[Any], str]] = {
-            ParameterStyle.QMARK: lambda _: "?",
-            ParameterStyle.NUMERIC: lambda i: f"${int(i) + 1}",
-            ParameterStyle.NAMED_COLON: lambda name: f":{name}",
-            ParameterStyle.POSITIONAL_COLON: lambda i: f":{int(i) + 1}",
-            ParameterStyle.NAMED_AT: lambda name: f"@{name}",
-            ParameterStyle.NAMED_DOLLAR: lambda name: f"${name}",
-            ParameterStyle.NAMED_PYFORMAT: lambda name: f"%({name})s",
-            ParameterStyle.POSITIONAL_PYFORMAT: lambda _: "%s",
-        }
-
-    def normalize_sql_for_parsing(self, sql: str, dialect: str | None = None) -> "tuple[str, list[ParameterInfo]]":
-        """Convert SQL to parsable format.
-
-        Takes raw SQL with potentially incompatible parameter styles and converts
-        them to a canonical format for parsing.
-
-        Args:
-            sql: Raw SQL string with any parameter style
-            dialect: Target SQL dialect for compatibility checking
-
-        Returns:
-            Tuple of (parsable_sql, original_parameter_info)
-        """
-        param_info = self.validator.extract_parameters(sql)
-
-        incompatible_styles = self.validator.get_sqlglot_incompatible_styles(dialect)
-        needs_conversion = any(p.style in incompatible_styles for p in param_info)
-
-        if not needs_conversion:
-            return sql, param_info
-
-        converted_sql = self._convert_to_sqlglot_compatible(sql, param_info, incompatible_styles)
-        return converted_sql, param_info
-
-    def _convert_to_sqlglot_compatible(
-        self, sql: str, param_info: "list[ParameterInfo]", incompatible_styles: "set[ParameterStyle]"
-    ) -> str:
-        """Convert SQL to SQLGlot-compatible format."""
-        converted_sql = sql
-        for param in reversed(param_info):
-            if param.style in incompatible_styles:
-                canonical_placeholder = f":param_{param.ordinal}"
-                converted_sql = (
-                    converted_sql[: param.position]
-                    + canonical_placeholder
-                    + converted_sql[param.position + len(param.placeholder_text) :]
-                )
-
-        return converted_sql
-
-    def convert_placeholder_style(
-        self, sql: str, parameters: Any, target_style: ParameterStyle, is_many: bool = False
-    ) -> "tuple[str, Any]":
-        """Convert SQL and parameters to execution format.
-
-        Args:
-            sql: SQL string (possibly from Phase 1 normalization)
-            parameters: Parameter values in any format
-            target_style: Target parameter style for execution
-            is_many: Whether this is for executemany() operation
-
-        Returns:
-            Tuple of (final_sql, execution_parameters)
-        """
-        param_info = self.validator.extract_parameters(sql)
-
-        if target_style == ParameterStyle.STATIC:
-            return self._embed_static_parameters(sql, parameters, param_info)
-
-        current_styles = {p.style for p in param_info}
-        if len(current_styles) == 1 and target_style in current_styles:
-            converted_parameters = self._convert_parameter_format(
-                parameters, param_info, target_style, parameters, preserve_parameter_format=True
-            )
-            return sql, converted_parameters
-
-        converted_sql = self._convert_placeholders_to_style(sql, param_info, target_style)
-        converted_parameters = self._convert_parameter_format(
-            parameters, param_info, target_style, parameters, preserve_parameter_format=True
-        )
-
-        return converted_sql, converted_parameters
-
-    def _convert_placeholders_to_style(
-        self, sql: str, param_info: "list[ParameterInfo]", target_style: ParameterStyle
-    ) -> str:
-        """Convert SQL placeholders to target style."""
-        generator = self._placeholder_generators.get(target_style)
-        if not generator:
-            msg = f"Unsupported target parameter style: {target_style}"
-            raise ValueError(msg)
-
-        param_styles = {p.style for p in param_info}
-        use_sequential_for_qmark = (
-            len(param_styles) == 1 and ParameterStyle.QMARK in param_styles and target_style == ParameterStyle.NUMERIC
-        )
-
-        unique_params: dict[str, int] = {}
-        for param in param_info:
-            param_key = (
-                f"{param.placeholder_text}_{param.ordinal}"
-                if use_sequential_for_qmark and param.style == ParameterStyle.QMARK
-                else param.placeholder_text
-            )
-
-            if param_key not in unique_params:
-                unique_params[param_key] = len(unique_params)
-
-        converted_sql = sql
-        placeholder_text_len_cache: dict[str, int] = {}
-
-        for param in reversed(param_info):
-            if param.placeholder_text not in placeholder_text_len_cache:
-                placeholder_text_len_cache[param.placeholder_text] = len(param.placeholder_text)
-            text_len = placeholder_text_len_cache[param.placeholder_text]
-
-            if target_style in {
-                ParameterStyle.QMARK,
-                ParameterStyle.NUMERIC,
-                ParameterStyle.POSITIONAL_PYFORMAT,
-                ParameterStyle.POSITIONAL_COLON,
-            }:
-                param_key = (
-                    f"{param.placeholder_text}_{param.ordinal}"
-                    if use_sequential_for_qmark and param.style == ParameterStyle.QMARK
-                    else param.placeholder_text
-                )
-                new_placeholder = generator(unique_params[param_key])
-            else:
-                param_name = param.name or f"param_{param.ordinal}"
-                new_placeholder = generator(param_name)
-
-            converted_sql = (
-                converted_sql[: param.position] + new_placeholder + converted_sql[param.position + text_len :]
-            )
-
-        return converted_sql
-
-    def _convert_sequence_to_dict(
-        self, parameters: "Sequence[Any]", param_info: "list[ParameterInfo]"
-    ) -> "dict[str, Any]":
-        """Convert sequence parameters to dictionary for named styles.
-
-        Args:
-            parameters: Sequence of parameter values
-            param_info: Parameter information from SQL
-
-        Returns:
-            Dictionary mapping parameter names to values
-        """
-        param_dict = {}
-        for i, param in enumerate(param_info):
-            if i < len(parameters):
-                name = param.name or f"param_{param.ordinal}"
-                param_dict[name] = parameters[i]
-        return param_dict
-
-    def _extract_param_value_mixed_styles(
-        self, param: ParameterInfo, parameters: "Mapping[str, Any]", param_keys: "list[str]"
-    ) -> "tuple[Any, bool]":
-        """Extract parameter value for mixed style parameters.
-
-        Args:
-            param: Parameter information
-            parameters: Parameter mapping
-            param_keys: List of parameter keys
-
-        Returns:
-            Tuple of (value, found_flag)
-        """
-        if param.name and param.name in parameters:
-            return parameters[param.name], True
-
-        if (
-            param.style == ParameterStyle.NUMERIC
-            and param.name
-            and param.name.isdigit()
-            and param.ordinal < len(param_keys)
-        ):
-            key_to_use = param_keys[param.ordinal]
-            return parameters[key_to_use], True
-
-        if f"param_{param.ordinal}" in parameters:
-            return parameters[f"param_{param.ordinal}"], True
-
-        ordinal_key = str(param.ordinal + 1)
-        if ordinal_key in parameters:
-            return parameters[ordinal_key], True
-
-        # Fallback: rely on insertion order when placeholders were normalized to positional names
-        if isinstance(parameters, Mapping):
-            try:
-                ordered_keys = list(parameters.keys())
-            except AttributeError:
-                ordered_keys = []
-            if ordered_keys and param.ordinal < len(ordered_keys):
-                key = ordered_keys[param.ordinal]
-                return parameters[key], True
-
-        return None, False
-
-    def _extract_param_value_single_style(
-        self, param: ParameterInfo, parameters: "Mapping[str, Any]"
-    ) -> "tuple[Any, bool]":
-        """Extract parameter value for single style parameters.
-
-        Args:
-            param: Parameter information
-            parameters: Parameter mapping
-
-        Returns:
-            Tuple of (value, found_flag) where found_flag indicates if parameter was found
-        """
-        if param.name and param.name in parameters:
-            return parameters[param.name], True
-        if f"param_{param.ordinal}" in parameters:
-            return parameters[f"param_{param.ordinal}"], True
-
-        ordinal_key = str(param.ordinal + 1)
-        if ordinal_key in parameters:
-            return parameters[ordinal_key], True
-
-        try:
-            ordered_keys = list(parameters.keys())
-        except AttributeError:
-            ordered_keys = []
-        if ordered_keys and param.ordinal < len(ordered_keys):
-            key = ordered_keys[param.ordinal]
-            if key in parameters:
-                return parameters[key], True
-
-        return None, False
-
-    def _preserve_original_format(self, param_values: "list[Any]", original_parameters: Any) -> Any:
-        """Preserve the original parameter container format.
-
-        Args:
-            param_values: List of parameter values
-            original_parameters: Original parameter container
-
-        Returns:
-            Parameters in original format
-        """
-        if isinstance(original_parameters, tuple):
-            return tuple(param_values)
-        if isinstance(original_parameters, list):
-            return param_values
-        if isinstance(original_parameters, Mapping):
-            return tuple(param_values)
-
-        if hasattr(original_parameters, "__class__") and callable(original_parameters.__class__):
-            try:
-                return original_parameters.__class__(param_values)
-            except (TypeError, ValueError):
-                return tuple(param_values)
-
-        return param_values
-
-    def _convert_parameter_format(
-        self,
-        parameters: Any,
-        param_info: "list[ParameterInfo]",
-        target_style: ParameterStyle,
-        original_parameters: Any = None,
-        preserve_parameter_format: bool = False,
-    ) -> Any:
-        """Convert parameter format to match target style requirements.
-
-        Args:
-            parameters: Current parameter values
-            param_info: Parameter information extracted from SQL
-            target_style: Target parameter style for conversion
-            original_parameters: Original parameter container for type preservation
-            preserve_parameter_format: Whether to preserve the original parameter format
-        """
-        if not parameters or not param_info:
-            return parameters
-
-        is_named_style = target_style in {
-            ParameterStyle.NAMED_COLON,
-            ParameterStyle.NAMED_AT,
-            ParameterStyle.NAMED_DOLLAR,
-            ParameterStyle.NAMED_PYFORMAT,
-        }
-
-        if is_named_style:
-            if isinstance(parameters, Mapping):
-                return parameters
-            if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
-                return self._convert_sequence_to_dict(parameters, param_info)
-
-        elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
-            return parameters
-
-        elif isinstance(parameters, Mapping):
-            param_values = []
-            parameter_styles = {p.style for p in param_info}
-            has_mixed_styles = len(parameter_styles) > 1
-
-            # Build unique parameter mapping to avoid duplicates when same parameter appears multiple times
-            unique_params: dict[str, Any] = {}
-            param_order: list[str] = []
-
-            if has_mixed_styles:
-                param_keys = list(parameters.keys())
-                for param in param_info:
-                    param_key = param.placeholder_text
-                    if param_key not in unique_params:
-                        value, found = self._extract_param_value_mixed_styles(param, parameters, param_keys)
-                        if found:
-                            unique_params[param_key] = value
-                            param_order.append(param_key)
-            else:
-                for param in param_info:
-                    param_key = param.placeholder_text
-                    if param_key not in unique_params:
-                        value, found = self._extract_param_value_single_style(param, parameters)
-                        if found:
-                            unique_params[param_key] = value
-                            param_order.append(param_key)
-
-            # Build parameter values list from unique parameters in order
-            param_values = [unique_params[param_key] for param_key in param_order]
-
-            if preserve_parameter_format and original_parameters is not None:
-                return self._preserve_original_format(param_values, original_parameters)
-
-            return param_values
-
-        return parameters
-
-    def _embed_static_parameters(
-        self, sql: str, parameters: Any, param_info: "list[ParameterInfo]"
-    ) -> "tuple[str, Any]":
-        """Embed parameters directly into SQL for STATIC style."""
-        if not param_info:
-            return sql, None
-
-        unique_params: dict[str, int] = {}
-        for param in param_info:
-            if param.style in {ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT}:
-                param_key = f"{param.placeholder_text}_{param.ordinal}"
-            elif (param.style == ParameterStyle.NUMERIC and param.name) or param.name:
-                param_key = param.placeholder_text
-            else:
-                param_key = f"{param.placeholder_text}_{param.ordinal}"
-
-            if param_key not in unique_params:
-                unique_params[param_key] = len(unique_params)
-
-        static_sql = sql
-        for param in reversed(param_info):
-            param_value = self._get_parameter_value_with_reuse(parameters, param, unique_params)
-
-            if param_value is None:
-                literal = "NULL"
-            elif isinstance(param_value, str):
-                escaped = param_value.replace("'", "''")
-                literal = f"'{escaped}'"
-            elif isinstance(param_value, bool):
-                literal = "TRUE" if param_value else "FALSE"
-            elif isinstance(param_value, (int, float)):
-                literal = str(param_value)
-            else:
-                literal = f"'{param_value!s}'"
-
-            static_sql = (
-                static_sql[: param.position] + literal + static_sql[param.position + len(param.placeholder_text) :]
-            )
-
-        return static_sql, None
-
-    def _get_parameter_value(self, parameters: Any, param: ParameterInfo) -> Any:
-        """Extract parameter value based on parameter info and format."""
-        if isinstance(parameters, Mapping):
-            if param.name and param.name in parameters:
-                return parameters[param.name]
-            if f"param_{param.ordinal}" in parameters:
-                return parameters[f"param_{param.ordinal}"]
-            if str(param.ordinal + 1) in parameters:
-                return parameters[str(param.ordinal + 1)]
-        elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
-            if param.ordinal < len(parameters):
-                return parameters[param.ordinal]
-
-        return None
-
-    def _get_parameter_value_with_reuse(
-        self, parameters: Any, param: ParameterInfo, unique_params: "dict[str, int]"
-    ) -> Any:
-        """Extract parameter value handling parameter reuse correctly.
-
-        Args:
-            parameters: Parameter values in any format
-            param: Parameter information
-            unique_params: Mapping of unique placeholders to their ordinal positions
-
-        Returns:
-            Parameter value, correctly handling reused parameters
-        """
-
-        if param.style in {ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT}:
-            param_key = f"{param.placeholder_text}_{param.ordinal}"
-        elif (param.style == ParameterStyle.NUMERIC and param.name) or param.name:
-            param_key = param.placeholder_text
-        else:
-            param_key = f"{param.placeholder_text}_{param.ordinal}"
-
-        unique_ordinal = unique_params.get(param_key)
-        if unique_ordinal is None:
-            return None
-
-        if isinstance(parameters, Mapping):
-            if param.name and param.name in parameters:
-                return parameters[param.name]
-            if f"param_{unique_ordinal}" in parameters:
-                return parameters[f"param_{unique_ordinal}"]
-            if str(unique_ordinal + 1) in parameters:
-                return parameters[str(unique_ordinal + 1)]
-        elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
-            if unique_ordinal < len(parameters):
-                return parameters[unique_ordinal]
-
-        return None
-
-    def _convert_to_positional_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any:
-        """Convert parameters to positional format (list/tuple)."""
-        return self._convert_parameter_format(
-            parameters, param_info, ParameterStyle.QMARK, parameters, preserve_parameter_format=False
-        )
-
-    def _convert_to_named_colon_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any:
-        """Convert parameters to named colon format (dict)."""
-        return self._convert_parameter_format(
-            parameters, param_info, ParameterStyle.NAMED_COLON, parameters,
preserve_parameter_format=False - ) - - def _convert_to_positional_colon_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any: - """Convert parameters to positional colon format with 1-based keys.""" - if isinstance(parameters, Mapping): - return parameters - - param_dict = {} - if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): - for i, value in enumerate(parameters): - param_dict[str(i + 1)] = value - - return param_dict - - def _convert_to_named_pyformat_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any: - """Convert parameters to named pyformat format (dict).""" - return self._convert_parameter_format( - parameters, param_info, ParameterStyle.NAMED_PYFORMAT, parameters, preserve_parameter_format=False - ) - - -@mypyc_attr(allow_interpreted_subclasses=False) -class ParameterProfile: - """Aggregate metadata describing detected parameters.""" - - __slots__ = ("_parameters", "_placeholder_counts", "named_parameters", "reused_ordinals", "styles") - - def __init__(self, parameters: "Sequence[ParameterInfo] | None" = None) -> None: - param_tuple: tuple[ParameterInfo, ...] 
= tuple(parameters) if parameters else () - self._parameters = param_tuple - self.styles = tuple(sorted({param.style.value for param in param_tuple})) if param_tuple else () - placeholder_counts: dict[str, int] = {} - reused_ordinals: list[int] = [] - named_parameters: list[str] = [] - - for param in param_tuple: - placeholder = param.placeholder_text - current_count = placeholder_counts.get(placeholder, 0) - placeholder_counts[placeholder] = current_count + 1 - if current_count: - reused_ordinals.append(param.ordinal) - if param.name is not None: - named_parameters.append(param.name) - - self._placeholder_counts = placeholder_counts - self.reused_ordinals = tuple(reused_ordinals) - self.named_parameters = tuple(named_parameters) - - @classmethod - def empty(cls) -> "ParameterProfile": - return cls(()) - - @property - def parameters(self) -> "tuple[ParameterInfo, ...]": - return self._parameters - - @property - def total_count(self) -> int: - return len(self._parameters) - - def placeholder_count(self, placeholder: str) -> int: - return self._placeholder_counts.get(placeholder, 0) - - def is_empty(self) -> bool: - return not self._parameters - - -@mypyc_attr(allow_interpreted_subclasses=False) -class ParameterProcessingResult: - """Return container for parameter processing output.""" - - __slots__ = ("parameter_profile", "parameters", "sql") - - def __init__(self, sql: str, parameters: Any, parameter_profile: "ParameterProfile") -> None: - self.sql = sql - self.parameters = parameters - self.parameter_profile = parameter_profile - - def __iter__(self) -> "Generator[str | Any, Any, None]": - yield self.sql - yield self.parameters - - def __len__(self) -> int: - return 2 - - def __getitem__(self, index: int) -> Any: - if index == 0: - return self.sql - if index == 1: - return self.parameters - msg = "ParameterProcessingResult exposes exactly two positional items" - raise IndexError(msg) - - -EXECUTE_MANY_MIN_ROWS: Final[int] = 2 - - -def _is_sequence_like(value: Any) 
-> bool: - return isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)) - - -def _looks_like_execute_many(parameters: Any) -> bool: - if not _is_sequence_like(parameters) or len(parameters) < EXECUTE_MANY_MIN_ROWS: - return False - return all(_is_sequence_like(entry) or isinstance(entry, Mapping) for entry in parameters) - - -def _normalize_parameter_key(key: Any) -> "tuple[str, int | str]": - if isinstance(key, str): - stripped_numeric = key.lstrip("$") - if stripped_numeric.isdigit(): - return ("index", int(stripped_numeric) - 1) - if key.isdigit(): - return ("index", int(key) - 1) - return ("named", key) - if isinstance(key, int): - if key > 0: - return ("index", key - 1) - return ("index", key) - return ("named", str(key)) - - -def _collect_expected_identifiers(parameter_profile: "ParameterProfile") -> "set[tuple[str, int | str]]": - identifiers: set[tuple[str, int | str]] = set() - for parameter in parameter_profile.parameters: - style = parameter.style - name = parameter.name - if style in { - ParameterStyle.NAMED_COLON, - ParameterStyle.NAMED_AT, - ParameterStyle.NAMED_DOLLAR, - ParameterStyle.NAMED_PYFORMAT, - }: - identifiers.add(("named", name or f"param_{parameter.ordinal}")) - elif style in {ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_COLON}: - if name and name.isdigit(): - identifiers.add(("index", int(name) - 1)) - else: - identifiers.add(("index", parameter.ordinal)) - else: - identifiers.add(("index", parameter.ordinal)) - return identifiers - - -def _collect_actual_identifiers(parameters: Any) -> "tuple[set[tuple[str, int | str]], int]": - if parameters is None: - return set(), 0 - if isinstance(parameters, Mapping): - mapping_identifiers = {_normalize_parameter_key(key) for key in parameters} - return mapping_identifiers, len(parameters) - if _looks_like_execute_many(parameters): - aggregated_identifiers: set[tuple[str, int | str]] = set() - for entry in parameters: - entry_identifiers, _ = 
_collect_actual_identifiers(entry) - aggregated_identifiers.update(entry_identifiers) - return aggregated_identifiers, len(aggregated_identifiers) - if _is_sequence_like(parameters): - identifiers = {("index", cast("int | str", index)) for index in range(len(parameters))} - return identifiers, len(parameters) - identifiers = {("index", cast("int | str", 0))} - return identifiers, 1 - - -def _format_identifiers(identifiers: "set[tuple[str, int | str]]") -> str: - if not identifiers: - return "[]" - formatted: list[str] = [] - for identifier in sorted(identifiers, key=lambda item: (item[0], str(item[1]))): - kind, value = identifier - if kind == "named": - formatted.append(str(value)) - elif isinstance(value, int): - formatted.append(str(value + 1)) - else: - formatted.append(str(value)) - return "[" + ", ".join(formatted) + "]" - - -def _validate_single_parameter_set( - parameter_profile: "ParameterProfile", parameters: Any, batch_index: "int | None" = None -) -> None: - expected_identifiers = _collect_expected_identifiers(parameter_profile) - actual_identifiers, actual_count = _collect_actual_identifiers(parameters) - expected_count = len(expected_identifiers) - - if expected_count == 0 and actual_count == 0: - return - - prefix = "Parameter count mismatch" - if batch_index is not None: - prefix = f"{prefix} in batch {batch_index}" - - if expected_count == 0 and actual_count > 0: - msg = f"{prefix}: statement does not accept parameters." - raise sqlspec.exceptions.SQLSpecError(msg) - - if expected_count > 0 and actual_count == 0: - msg = f"{prefix}: expected {expected_count} parameters, received 0." - raise sqlspec.exceptions.SQLSpecError(msg) - - if expected_count != actual_count: - msg = f"{prefix}: {actual_count} parameters provided but {expected_count} placeholders detected." 
-        raise sqlspec.exceptions.SQLSpecError(msg)
-
-    if expected_identifiers != actual_identifiers:
-        msg = (
-            f"{prefix}: expected identifiers {_format_identifiers(expected_identifiers)}, "
-            f"received {_format_identifiers(actual_identifiers)}."
-        )
-        raise sqlspec.exceptions.SQLSpecError(msg)
-
-
-def validate_parameter_alignment(
-    parameter_profile: "ParameterProfile | None", parameters: Any, *, is_many: bool = False
-) -> None:
-    """Validate that provided parameters align with detected placeholders.
-
-    Args:
-        parameter_profile: Placeholder metadata produced by parameter processing.
-        parameters: Parameters that will be bound for execution.
-        is_many: Whether parameters represent execute_many payload.
-
-    Raises:
-        SQLSpecError: If parameter counts or identifiers do not align.
-    """
-    profile = parameter_profile or ParameterProfile.empty()
-    if profile.total_count == 0:
-        return
-
-    effective_is_many = is_many or _looks_like_execute_many(parameters)
-
-    if effective_is_many:
-        if parameters is None:
-            if profile.total_count == 0:
-                return
-            msg = "Parameter count mismatch: expected parameter sets for execute_many."
-            raise sqlspec.exceptions.SQLSpecError(msg)
-        if not _is_sequence_like(parameters):
-            msg = "Parameter count mismatch: expected sequence of parameter sets for execute_many."
-            raise sqlspec.exceptions.SQLSpecError(msg)
-        if len(parameters) == 0:
-            return
-        for index, param_set in enumerate(parameters):
-            _validate_single_parameter_set(profile, param_set, batch_index=index)
-        return
-
-    _validate_single_parameter_set(profile, parameters)
-
-
-@mypyc_attr(allow_interpreted_subclasses=False)
-class ParameterProcessor:
-    """Parameter processing engine.
-
-    Main entry point for the parameter processing system that coordinates
-    Phase 1 (compatibility) and Phase 2 (execution format).
-
-    Processing Pipeline:
-    1. Type wrapping for compatibility (TypedParameter)
-    2. Driver-specific type coercions (type_coercion_map)
-    3. Phase 1: Normalization if needed
-    4. Phase 2: Execution format conversion if needed
-    5. Final output transformation (output_transformer)
-    """
-
-    __slots__ = ("_cache", "_cache_size", "_converter", "_validator")
-
-    DEFAULT_CACHE_SIZE = 1000
-
-    def __init__(self) -> None:
-        """Initialize processor with component coordination."""
-        self._cache: dict[str, ParameterProcessingResult] = {}
-        self._cache_size = 0
-        self._validator = ParameterValidator()
-        self._converter = ParameterConverter()
-
-    def _handle_static_embedding(
-        self, sql: str, parameters: Any, config: ParameterStyleConfig, is_many: bool, cache_key: str
-    ) -> "ParameterProcessingResult":
-        """Handle static parameter embedding for script compilation.
-
-        Args:
-            sql: SQL string
-            parameters: Parameter values
-            config: Parameter configuration
-            is_many: Whether this is for execute_many
-            cache_key: Cache key for result
-
-        Returns:
-            Tuple of (static_sql, static_params)
-        """
-        coerced_params = parameters
-        if config.type_coercion_map and parameters:
-            coerced_params = self._apply_type_coercions(parameters, config.type_coercion_map, is_many)
-
-        static_sql, static_params = self._converter.convert_placeholder_style(
-            sql, coerced_params, ParameterStyle.STATIC, is_many
-        )
-        result = ParameterProcessingResult(static_sql, static_params, ParameterProfile.empty())
-        if self._cache_size < self.DEFAULT_CACHE_SIZE:
-            self._cache[cache_key] = result
-            self._cache_size += 1
-        return result
-
-    def _process_parameters_conversion(
-        self,
-        sql: str,
-        parameters: Any,
-        config: ParameterStyleConfig,
-        original_styles: "set[ParameterStyle]",
-        needs_execution_conversion: bool,
-        needs_sqlglot_normalization: bool,
-        is_many: bool,
-    ) -> "tuple[str, Any]":
-        """Process parameter conversion phase.
-
-        Args:
-            sql: Processed SQL string
-            parameters: Processed parameters
-            config: Parameter configuration
-            original_styles: Original parameter styles detected
-            needs_execution_conversion: Whether execution conversion is needed
-            needs_sqlglot_normalization: Whether SQLGlot normalization is needed
-            is_many: Whether this is for execute_many
-
-        Returns:
-            Tuple of (processed_sql, processed_parameters)
-        """
-        if not (needs_execution_conversion or needs_sqlglot_normalization):
-            return sql, parameters
-
-        if is_many and config.preserve_original_params_for_many and isinstance(parameters, (list, tuple)):
-            target_style = self._determine_target_execution_style(original_styles, config)
-            processed_sql, _ = self._converter.convert_placeholder_style(sql, parameters, target_style, is_many)
-            return processed_sql, parameters
-
-        target_style = self._determine_target_execution_style(original_styles, config)
-        return self._converter.convert_placeholder_style(sql, parameters, target_style, is_many)
-
-    def _fingerprint_parameters(self, parameters: Any) -> str:
-        """Create deterministic fingerprint for parameter values."""
-        if parameters is None:
-            return "none"
-
-        if isinstance(parameters, Mapping):
-            try:
-                items = sorted(parameters.items(), key=lambda item: repr(item[0]))
-            except Exception:
-                items = list(parameters.items())
-            data = repr(tuple(items))
-        elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes, bytearray)):
-            data = repr(tuple(parameters))
-        else:
-            data = repr(parameters)
-
-        digest = hashlib.sha256(data.encode("utf-8")).hexdigest()
-        return f"{type(parameters).__name__}:{digest}"
-
-    def _generate_processor_cache_key(
-        self, sql: str, parameters: Any, config: ParameterStyleConfig, is_many: bool, dialect: "str | None"
-    ) -> str:
-        """Generate cache key for parameter processing."""
-        param_fingerprint = self._fingerprint_parameters(parameters)
-        dialect_marker = dialect or "default"
-        default_style = config.default_parameter_style.value if config.default_parameter_style else "unknown"
-        return f"{sql}:{param_fingerprint}:{default_style}:{is_many}:{dialect_marker}"
-
-    def process(
-        self, sql: str, parameters: Any, config: ParameterStyleConfig, dialect: str | None = None, is_many: bool = False
-    ) -> "ParameterProcessingResult":
-        """Process parameters through the complete pipeline.
-
-        Coordinates the entire parameter processing workflow:
-        1. Type wrapping for compatibility
-        2. Phase 1: Normalization if needed
-        3. Phase 2: Execution format conversion
-        4. Driver-specific type coercions
-        5. Final output transformation
-
-        Args:
-            sql: Raw SQL string
-            parameters: Parameter values in any format
-            config: Parameter style configuration
-            dialect: SQL dialect for compatibility
-            is_many: Whether this is for execute_many operation
-
-        Returns:
-            Tuple of (final_sql, execution_parameters)
-        """
-        cache_key = self._generate_processor_cache_key(sql, parameters, config, is_many, dialect)
-        cached_result = self._cache.get(cache_key)
-        if cached_result is not None:
-            return cached_result
-
-        param_info = self._validator.extract_parameters(sql)
-        original_styles = {p.style for p in param_info} if param_info else set()
-        needs_sqlglot_normalization = self._needs_sqlglot_normalization(param_info, dialect)
-        needs_execution_conversion = self._needs_execution_conversion(param_info, config)
-
-        needs_static_embedding = config.needs_static_script_compilation and param_info and parameters and not is_many
-
-        if needs_static_embedding:
-            return self._handle_static_embedding(sql, parameters, config, is_many, cache_key)
-
-        if (
-            not needs_sqlglot_normalization
-            and not needs_execution_conversion
-            and not config.type_coercion_map
-            and not config.output_transformer
-        ):
-            result = ParameterProcessingResult(sql, parameters, ParameterProfile(param_info))
-            if self._cache_size < self.DEFAULT_CACHE_SIZE:
-                self._cache[cache_key] = result
-                self._cache_size += 1
-            return result
-
-        processed_sql, processed_parameters = sql, parameters
-
-        if processed_parameters:
-            processed_parameters = self._apply_type_wrapping(processed_parameters)
-
-        if needs_sqlglot_normalization:
-            processed_sql, _ = self._converter.normalize_sql_for_parsing(processed_sql, dialect)
-
-        if config.type_coercion_map and processed_parameters:
-            processed_parameters = self._apply_type_coercions(processed_parameters, config.type_coercion_map, is_many)
-
-        processed_sql, processed_parameters = self._process_parameters_conversion(
-            processed_sql,
-            processed_parameters,
-            config,
-            original_styles,
-            needs_execution_conversion,
-            needs_sqlglot_normalization,
-            is_many,
-        )
-
-        if config.output_transformer:
-            processed_sql, processed_parameters = config.output_transformer(processed_sql, processed_parameters)
-
-        final_param_info = self._validator.extract_parameters(processed_sql)
-        final_profile = ParameterProfile(final_param_info)
-        result = ParameterProcessingResult(processed_sql, processed_parameters, final_profile)
-
-        if self._cache_size < self.DEFAULT_CACHE_SIZE:
-            self._cache[cache_key] = result
-            self._cache_size += 1
-
-        return result
-
-    def get_sqlglot_compatible_sql(
-        self, sql: str, parameters: Any, config: ParameterStyleConfig, dialect: str | None = None
-    ) -> "tuple[str, Any]":
-        """Get SQL normalized for parsing only (Phase 1 only).
-
-        Performs only Phase 1 normalization to make SQL compatible
-        with parsing, without converting to execution format.
-
-        Args:
-            sql: Raw SQL string
-            parameters: Parameter values
-            config: Parameter style configuration
-            dialect: SQL dialect for compatibility
-
-        Returns:
-            Tuple of (compatible_sql, parameters)
-        """
-
-        param_info = self._validator.extract_parameters(sql)
-
-        if self._needs_sqlglot_normalization(param_info, dialect):
-            normalized_sql, _ = self._converter.normalize_sql_for_parsing(sql, dialect)
-            return normalized_sql, parameters
-
-        return sql, parameters
-
-    def _needs_execution_conversion(self, param_info: "list[ParameterInfo]", config: ParameterStyleConfig) -> bool:
-        """Determine if execution format conversion is needed.
-
-        Preserves the original parameter style if it's supported by the execution
-        environment, otherwise converts to the default execution style.
-        """
-        if not param_info:
-            return False
-
-        current_styles = {p.style for p in param_info}
-
-        if (
-            config.allow_mixed_parameter_styles
-            and len(current_styles) > 1
-            and config.supported_execution_parameter_styles is not None
-            and len(config.supported_execution_parameter_styles) > 1
-            and all(style in config.supported_execution_parameter_styles for style in current_styles)
-        ):
-            return False
-
-        if len(current_styles) > 1:
-            return True
-
-        if len(current_styles) == 1:
-            current_style = next(iter(current_styles))
-            supported_styles = config.supported_execution_parameter_styles
-            if supported_styles is None:
-                return True
-            return current_style not in supported_styles
-
-        return True
-
-    def _needs_sqlglot_normalization(self, param_info: "list[ParameterInfo]", dialect: str | None = None) -> bool:
-        """Check if SQLGlot normalization is needed for this SQL."""
-        incompatible_styles = self._validator.get_sqlglot_incompatible_styles(dialect)
-        return any(p.style in incompatible_styles for p in param_info)
-
-    def _determine_target_execution_style(
-        self, original_styles: "set[ParameterStyle]", config: ParameterStyleConfig
-    ) -> ParameterStyle:
-        """Determine the target execution style based on original styles and config.
-
-        Logic:
-        1. If there's a single original style and it's in supported execution styles, use it
-        2. Otherwise, use the default execution style
-        3. If no default execution style, use the default parameter style
-
-        Preserves the original parameter style when possible, only converting
-        when necessary for execution compatibility.
-        """
-
-        if len(original_styles) == 1 and config.supported_execution_parameter_styles is not None:
-            original_style = next(iter(original_styles))
-            if original_style in config.supported_execution_parameter_styles:
-                return original_style
-
-        return config.default_execution_parameter_style or config.default_parameter_style
-
-    def _apply_type_wrapping(self, parameters: Any) -> Any:
-        """Apply type wrapping using singledispatch for performance."""
-        if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
-            return [_wrap_parameter_by_type(p) for p in parameters]
-        if isinstance(parameters, Mapping):
-            wrapped_dict = {}
-            for k, v in parameters.items():
-                wrapped_dict[k] = _wrap_parameter_by_type(v)
-            return wrapped_dict
-        return _wrap_parameter_by_type(parameters)
-
-    def _apply_type_coercions(
-        self, parameters: Any, type_coercion_map: "dict[type, Callable[[Any], Any]]", is_many: bool = False
-    ) -> Any:
-        """Apply database-specific type coercions.
-
-        Args:
-            parameters: Parameter values to coerce
-            type_coercion_map: Type coercion mappings
-            is_many: If True, parameters is a list of parameter sets for execute_many
-        """
-
-        def coerce_value(value: Any) -> Any:
-            # Skip coercion for None values to preserve NULL semantics
-            if value is None:
-                return value
-
-            if isinstance(value, TypedParameter):
-                wrapped_value: Any = value.value
-                # Skip coercion for None values even when wrapped
-                if wrapped_value is None:
-                    return wrapped_value
-
-                original_type = value.original_type
-                if original_type in type_coercion_map:
-                    coerced = type_coercion_map[original_type](wrapped_value)
-
-                    if isinstance(coerced, (list, tuple)) and not isinstance(coerced, (str, bytes)):
-                        coerced = [coerce_value(item) for item in coerced]
-                    elif isinstance(coerced, dict):
-                        coerced = {k: coerce_value(v) for k, v in coerced.items()}
-                    return coerced
-                return wrapped_value
-
-            value_type = type(value)
-            if value_type in type_coercion_map:
-                coerced = type_coercion_map[value_type](value)
-
-                if isinstance(coerced, (list, tuple)) and not isinstance(coerced, (str, bytes)):
-                    coerced = [coerce_value(item) for item in coerced]
-                elif isinstance(coerced, dict):
-                    coerced = {k: coerce_value(v) for k, v in coerced.items()}
-                return coerced
-            return value
-
-        def coerce_parameter_set(param_set: Any) -> Any:
-            """Coerce a single parameter set (dict, list, tuple, or scalar)."""
-            if isinstance(param_set, Sequence) and not isinstance(param_set, (str, bytes)):
-                return [coerce_value(p) for p in param_set]
-            if isinstance(param_set, Mapping):
-                return {k: coerce_value(v) for k, v in param_set.items()}
-            return coerce_value(param_set)
-
-        if is_many and isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
-            return [coerce_parameter_set(param_set) for param_set in parameters]
-
-        if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
-            return [coerce_value(p) for p in parameters]
-        if isinstance(parameters, Mapping):
-            return {k: coerce_value(v) for k, v in parameters.items()}
-        return coerce_value(parameters)
-
-
-def is_iterable_parameters(obj: Any) -> bool:
-    """Check if object is iterable parameters (not string/bytes).
-
-    Args:
-        obj: Object to check
-
-    Returns:
-        True if object is iterable parameters
-    """
-    return isinstance(obj, (list, tuple, set)) or (
-        hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, Mapping))
-    )
-
-
-def wrap_with_type(value: Any, semantic_name: str | None = None) -> Any:
-    """Public API for type wrapping.
-
-    Args:
-        value: Value to potentially wrap
-        semantic_name: Optional semantic name
-
-    Returns:
-        Original value or TypedParameter wrapper
-    """
-    return _wrap_parameter_by_type(value, semantic_name)
diff --git a/sqlspec/core/parameters/__init__.py b/sqlspec/core/parameters/__init__.py
new file mode 100644
index 000000000..9cc581fa0
--- /dev/null
+++ b/sqlspec/core/parameters/__init__.py
@@ -0,0 +1,60 @@
+"""Parameter processing public API."""
+
+from sqlspec.core.parameters._alignment import (
+    EXECUTE_MANY_MIN_ROWS,
+    collect_null_parameter_ordinals,
+    looks_like_execute_many,
+    normalize_parameter_key,
+    validate_parameter_alignment,
+)
+from sqlspec.core.parameters._converter import ParameterConverter
+from sqlspec.core.parameters._processor import ParameterProcessor
+from sqlspec.core.parameters._registry import (
+    DRIVER_PARAMETER_PROFILES,
+    build_statement_config_from_profile,
+    get_driver_profile,
+    register_driver_profile,
+)
+from sqlspec.core.parameters._transformers import (
+    replace_null_parameters_with_literals,
+    replace_placeholders_with_literals,
+)
+from sqlspec.core.parameters._types import (
+    DriverParameterProfile,
+    ParameterInfo,
+    ParameterProcessingResult,
+    ParameterProfile,
+    ParameterStyle,
+    ParameterStyleConfig,
+    TypedParameter,
+    is_iterable_parameters,
+    wrap_with_type,
+)
+from sqlspec.core.parameters._validator import PARAMETER_REGEX, ParameterValidator
+
+__all__ = (
+    "DRIVER_PARAMETER_PROFILES",
+    "EXECUTE_MANY_MIN_ROWS",
+    "PARAMETER_REGEX",
+    "DriverParameterProfile",
+    "ParameterConverter",
+    "ParameterInfo",
+    "ParameterProcessingResult",
+    "ParameterProcessor",
+    "ParameterProfile",
+    "ParameterStyle",
+    "ParameterStyleConfig",
+    "ParameterValidator",
+    "TypedParameter",
+    "build_statement_config_from_profile",
+    "collect_null_parameter_ordinals",
+    "get_driver_profile",
+    "is_iterable_parameters",
+    "looks_like_execute_many",
+    "normalize_parameter_key",
+    "register_driver_profile",
+    "replace_null_parameters_with_literals",
+    "replace_placeholders_with_literals",
+    "validate_parameter_alignment",
+    "wrap_with_type",
+)
diff --git a/sqlspec/core/parameters/_alignment.py b/sqlspec/core/parameters/_alignment.py
new file mode 100644
index 000000000..e4ca41647
--- /dev/null
+++ b/sqlspec/core/parameters/_alignment.py
@@ -0,0 +1,231 @@
+"""Parameter alignment and validation helpers."""
+
+from collections.abc import Mapping, Sequence
+from typing import Any, cast
+
+import sqlspec.exceptions
+from sqlspec.core.parameters._types import ParameterProfile, ParameterStyle
+
+__all__ = (
+    "EXECUTE_MANY_MIN_ROWS",
+    "collect_null_parameter_ordinals",
+    "looks_like_execute_many",
+    "normalize_parameter_key",
+    "validate_parameter_alignment",
+)
+
+EXECUTE_MANY_MIN_ROWS: int = 2
+
+
+def _is_sequence_like(value: Any) -> bool:
+    return isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray))
+
+
+def normalize_parameter_key(key: Any) -> "tuple[str, int | str]":
+    """Normalize a parameter key into an ``(kind, value)`` tuple.
+
+    Args:
+        key: Key supplied by the caller (index, name, or adapter-specific token).
+
+    Returns:
+        Tuple identifying the key type and canonical value for alignment checks.
+ """ + if isinstance(key, str): + stripped_numeric = key.lstrip("$") + if stripped_numeric.isdigit(): + return ("index", int(stripped_numeric) - 1) + if key.isdigit(): + return ("index", int(key) - 1) + return ("named", key) + if isinstance(key, int): + if key > 0: + return ("index", key - 1) + return ("index", key) + return ("named", str(key)) + + +def looks_like_execute_many(parameters: Any) -> bool: + """Return ``True`` when the payload resembles an ``execute_many`` structure. + + Args: + parameters: Potential parameter payload to inspect. + + Returns: + ``True`` if the payload appears to be a sequence of parameter sets. + """ + if not _is_sequence_like(parameters) or len(parameters) < EXECUTE_MANY_MIN_ROWS: + return False + return all(_is_sequence_like(entry) or isinstance(entry, Mapping) for entry in parameters) + + +def collect_null_parameter_ordinals(parameters: Any, profile: "ParameterProfile") -> "set[int]": + """Identify placeholder ordinals whose provided values are ``None``. + + Args: + parameters: Parameter payload supplied by the caller. + profile: Metadata describing detected placeholders. + + Returns: + Set of ordinal indices corresponding to ``None`` values. 
+ """ + if parameters is None: + return set() + + null_positions: set[int] = set() + + if isinstance(parameters, Mapping): + name_lookup: dict[str, int] = {} + for parameter in profile.parameters: + if parameter.name: + name_lookup[parameter.name] = parameter.ordinal + stripped_name = parameter.name.lstrip("@") + name_lookup.setdefault(stripped_name, parameter.ordinal) + name_lookup.setdefault(f"@{stripped_name}", parameter.ordinal) + + for key, value in parameters.items(): + if value is not None: + continue + key_kind, normalized_key = normalize_parameter_key(key) + if key_kind == "index" and isinstance(normalized_key, int): + null_positions.add(normalized_key) + continue + if key_kind == "named": + ordinal = name_lookup.get(str(normalized_key)) + if ordinal is not None: + null_positions.add(ordinal) + return null_positions + + if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes, bytearray)): + for index, value in enumerate(parameters): + if value is None: + null_positions.add(index) + return null_positions + + return null_positions + + +def _collect_expected_identifiers(parameter_profile: "ParameterProfile") -> "set[tuple[str, int | str]]": + identifiers: set[tuple[str, int | str]] = set() + for parameter in parameter_profile.parameters: + style = parameter.style + name = parameter.name + if style in { + ParameterStyle.NAMED_COLON, + ParameterStyle.NAMED_AT, + ParameterStyle.NAMED_DOLLAR, + ParameterStyle.NAMED_PYFORMAT, + }: + identifiers.add(("named", name or f"param_{parameter.ordinal}")) + elif style in {ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_COLON}: + if name and name.isdigit(): + identifiers.add(("index", int(name) - 1)) + else: + identifiers.add(("index", parameter.ordinal)) + else: + identifiers.add(("index", parameter.ordinal)) + return identifiers + + +def _collect_actual_identifiers(parameters: Any) -> "tuple[set[tuple[str, int | str]], int]": + if parameters is None: + return set(), 0 + if isinstance(parameters, 
Mapping): + mapping_identifiers = {normalize_parameter_key(key) for key in parameters} + return mapping_identifiers, len(parameters) + if looks_like_execute_many(parameters): + aggregated_identifiers: set[tuple[str, int | str]] = set() + for entry in parameters: + entry_identifiers, _ = _collect_actual_identifiers(entry) + aggregated_identifiers.update(entry_identifiers) + return aggregated_identifiers, len(aggregated_identifiers) + if _is_sequence_like(parameters): + identifiers = {("index", cast("int | str", index)) for index in range(len(parameters))} + return identifiers, len(parameters) + identifiers = {("index", cast("int | str", 0))} + return identifiers, 1 + + +def _format_identifiers(identifiers: "set[tuple[str, int | str]]") -> str: + if not identifiers: + return "[]" + formatted: list[str] = [] + for identifier in sorted(identifiers, key=lambda item: (item[0], str(item[1]))): + kind, value = identifier + if kind == "named": + formatted.append(str(value)) + elif isinstance(value, int): + formatted.append(str(value + 1)) + else: + formatted.append(str(value)) + return "[" + ", ".join(formatted) + "]" + + +def _validate_single_parameter_set( + parameter_profile: "ParameterProfile", parameters: Any, batch_index: "int | None" = None +) -> None: + expected_identifiers = _collect_expected_identifiers(parameter_profile) + actual_identifiers, actual_count = _collect_actual_identifiers(parameters) + expected_count = len(expected_identifiers) + + if expected_count == 0 and actual_count == 0: + return + + prefix = "Parameter count mismatch" + if batch_index is not None: + prefix = f"{prefix} in batch {batch_index}" + + if expected_count == 0 and actual_count > 0: + msg = f"{prefix}: statement does not accept parameters." + raise sqlspec.exceptions.SQLSpecError(msg) + + if expected_count > 0 and actual_count == 0: + msg = f"{prefix}: expected {expected_count} parameters, received 0." 
+ raise sqlspec.exceptions.SQLSpecError(msg) + + if expected_count != actual_count: + msg = f"{prefix}: {actual_count} parameters provided but {expected_count} placeholders detected." + raise sqlspec.exceptions.SQLSpecError(msg) + + if expected_identifiers != actual_identifiers: + msg = ( + f"{prefix}: expected identifiers {_format_identifiers(expected_identifiers)}, " + f"received {_format_identifiers(actual_identifiers)}." + ) + raise sqlspec.exceptions.SQLSpecError(msg) + + +def validate_parameter_alignment( + parameter_profile: "ParameterProfile | None", parameters: Any, *, is_many: bool = False +) -> None: + """Ensure provided parameters align with detected placeholders. + + Args: + parameter_profile: Placeholder metadata extracted from the statement. + parameters: Parameter payload the adapter will execute with. + is_many: Whether the call explicitly targets ``execute_many``. + + Raises: + SQLSpecError: If counts or identifiers differ between placeholders and payload. + """ + profile = parameter_profile or ParameterProfile.empty() + if profile.total_count == 0: + return + + effective_is_many = is_many or looks_like_execute_many(parameters) + + if effective_is_many: + if parameters is None: + if profile.total_count == 0: + return + msg = "Parameter count mismatch: expected parameter sets for execute_many." + raise sqlspec.exceptions.SQLSpecError(msg) + if not _is_sequence_like(parameters): + msg = "Parameter count mismatch: expected sequence of parameter sets for execute_many." 
+ raise sqlspec.exceptions.SQLSpecError(msg) + if len(parameters) == 0: + return + for index, param_set in enumerate(parameters): + _validate_single_parameter_set(profile, param_set, batch_index=index) + return + + _validate_single_parameter_set(profile, parameters) diff --git a/sqlspec/core/parameters/_converter.py b/sqlspec/core/parameters/_converter.py new file mode 100644 index 000000000..e9e66b8ec --- /dev/null +++ b/sqlspec/core/parameters/_converter.py @@ -0,0 +1,393 @@ +"""Parameter style conversion utilities.""" + +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +from mypy_extensions import mypyc_attr + +from sqlspec.core.parameters._types import ParameterInfo, ParameterStyle +from sqlspec.core.parameters._validator import ParameterValidator + +__all__ = ("ParameterConverter",) + + +@mypyc_attr(allow_interpreted_subclasses=False) +class ParameterConverter: + """Parameter style conversion helper.""" + + __slots__ = ("_format_converters", "_placeholder_generators", "validator") + + def __init__(self) -> None: + self.validator = ParameterValidator() + + self._format_converters = { + ParameterStyle.POSITIONAL_COLON: self._convert_to_positional_colon_format, + ParameterStyle.NAMED_COLON: self._convert_to_named_colon_format, + ParameterStyle.NAMED_PYFORMAT: self._convert_to_named_pyformat_format, + ParameterStyle.QMARK: self._convert_to_positional_format, + ParameterStyle.NUMERIC: self._convert_to_positional_format, + ParameterStyle.POSITIONAL_PYFORMAT: self._convert_to_positional_format, + ParameterStyle.NAMED_AT: self._convert_to_named_colon_format, + ParameterStyle.NAMED_DOLLAR: self._convert_to_named_colon_format, + } + + self._placeholder_generators: dict[ParameterStyle, Callable[[Any], str]] = { + ParameterStyle.QMARK: lambda _: "?", + ParameterStyle.NUMERIC: lambda i: f"${int(i) + 1}", + ParameterStyle.NAMED_COLON: lambda name: f":{name}", + ParameterStyle.POSITIONAL_COLON: lambda i: f":{int(i) + 1}", + 
ParameterStyle.NAMED_AT: lambda name: f"@{name}", + ParameterStyle.NAMED_DOLLAR: lambda name: f"${name}", + ParameterStyle.NAMED_PYFORMAT: lambda name: f"%({name})s", + ParameterStyle.POSITIONAL_PYFORMAT: lambda _: "%s", + } + + def normalize_sql_for_parsing(self, sql: str, dialect: str | None = None) -> "tuple[str, list[ParameterInfo]]": + param_info = self.validator.extract_parameters(sql) + + incompatible_styles = self.validator.get_sqlglot_incompatible_styles(dialect) + needs_conversion = any(p.style in incompatible_styles for p in param_info) + + if not needs_conversion: + return sql, param_info + + converted_sql = self._convert_to_sqlglot_compatible(sql, param_info, incompatible_styles) + return converted_sql, param_info + + def _convert_to_sqlglot_compatible( + self, sql: str, param_info: "list[ParameterInfo]", incompatible_styles: "set[ParameterStyle]" + ) -> str: + converted_sql = sql + for param in reversed(param_info): + if param.style in incompatible_styles: + canonical_placeholder = f":param_{param.ordinal}" + converted_sql = ( + converted_sql[: param.position] + + canonical_placeholder + + converted_sql[param.position + len(param.placeholder_text) :] + ) + return converted_sql + + def convert_placeholder_style( + self, sql: str, parameters: Any, target_style: ParameterStyle, is_many: bool = False + ) -> tuple[str, Any]: + param_info = self.validator.extract_parameters(sql) + + if target_style == ParameterStyle.STATIC: + return self._embed_static_parameters(sql, parameters, param_info) + + current_styles = {p.style for p in param_info} + if len(current_styles) == 1 and target_style in current_styles: + converted_parameters = self._convert_parameter_format( + parameters, param_info, target_style, parameters, preserve_parameter_format=True + ) + return sql, converted_parameters + + converted_sql = self._convert_placeholders_to_style(sql, param_info, target_style) + converted_parameters = self._convert_parameter_format( + parameters, param_info, 
target_style, parameters, preserve_parameter_format=True + ) + return converted_sql, converted_parameters + + def _convert_placeholders_to_style( + self, sql: str, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" + ) -> str: + generator = self._placeholder_generators.get(target_style) + if generator is None: + msg = f"Unsupported target parameter style: {target_style}" + raise ValueError(msg) + + param_styles = {p.style for p in param_info} + use_sequential_for_qmark = ( + len(param_styles) == 1 and ParameterStyle.QMARK in param_styles and target_style == ParameterStyle.NUMERIC + ) + + unique_params: dict[str, int] = {} + for param in param_info: + param_key = ( + f"{param.placeholder_text}_{param.ordinal}" + if use_sequential_for_qmark and param.style == ParameterStyle.QMARK + else param.placeholder_text + ) + if param_key not in unique_params: + unique_params[param_key] = len(unique_params) + + converted_sql = sql + placeholder_text_len_cache: dict[str, int] = {} + for param in reversed(param_info): + if param.placeholder_text not in placeholder_text_len_cache: + placeholder_text_len_cache[param.placeholder_text] = len(param.placeholder_text) + text_len = placeholder_text_len_cache[param.placeholder_text] + + if target_style in { + ParameterStyle.QMARK, + ParameterStyle.NUMERIC, + ParameterStyle.POSITIONAL_PYFORMAT, + ParameterStyle.POSITIONAL_COLON, + }: + param_key = ( + f"{param.placeholder_text}_{param.ordinal}" + if use_sequential_for_qmark and param.style == ParameterStyle.QMARK + else param.placeholder_text + ) + new_placeholder = generator(unique_params[param_key]) + else: + param_name = param.name or f"param_{param.ordinal}" + new_placeholder = generator(param_name) + + converted_sql = ( + converted_sql[: param.position] + new_placeholder + converted_sql[param.position + text_len :] + ) + + return converted_sql + + def _convert_sequence_to_dict( + self, parameters: Sequence[Any], param_info: "list[ParameterInfo]" + ) -> "dict[str, Any]": 
+ param_dict: dict[str, Any] = {} + for i, param in enumerate(param_info): + if i < len(parameters): + name = param.name or f"param_{param.ordinal}" + param_dict[name] = parameters[i] + return param_dict + + def _extract_param_value_mixed_styles( + self, param: "ParameterInfo", parameters: Mapping[str, Any], param_keys: "list[str]" + ) -> "tuple[Any, bool]": + if param.name and param.name in parameters: + return parameters[param.name], True + + if ( + param.style == ParameterStyle.NUMERIC + and param.name + and param.name.isdigit() + and param.ordinal < len(param_keys) + ): + key_to_use = param_keys[param.ordinal] + return parameters[key_to_use], True + + if f"param_{param.ordinal}" in parameters: + return parameters[f"param_{param.ordinal}"], True + + ordinal_key = str(param.ordinal + 1) + if ordinal_key in parameters: + return parameters[ordinal_key], True + + try: + ordered_keys = list(parameters.keys()) + except AttributeError: + ordered_keys = [] + if ordered_keys and param.ordinal < len(ordered_keys): + key = ordered_keys[param.ordinal] + if key in parameters: + return parameters[key], True + + return None, False + + def _extract_param_value_single_style( + self, param: "ParameterInfo", parameters: Mapping[str, Any] + ) -> "tuple[Any, bool]": + if param.name and param.name in parameters: + return parameters[param.name], True + if f"param_{param.ordinal}" in parameters: + return parameters[f"param_{param.ordinal}"], True + + ordinal_key = str(param.ordinal + 1) + if ordinal_key in parameters: + return parameters[ordinal_key], True + + try: + ordered_keys = list(parameters.keys()) + except AttributeError: + ordered_keys = [] + if ordered_keys and param.ordinal < len(ordered_keys): + key = ordered_keys[param.ordinal] + if key in parameters: + return parameters[key], True + + return None, False + + def _preserve_original_format(self, param_values: list[Any], original_parameters: Any) -> Any: + if isinstance(original_parameters, tuple): + return 
tuple(param_values) + if isinstance(original_parameters, list): + return param_values + if isinstance(original_parameters, Mapping): + return tuple(param_values) + + if hasattr(original_parameters, "__class__") and callable(original_parameters.__class__): + try: + return original_parameters.__class__(param_values) + except (TypeError, ValueError): + return tuple(param_values) + + return param_values + + def _convert_parameter_format( + self, + parameters: Any, + param_info: "list[ParameterInfo]", + target_style: "ParameterStyle", + original_parameters: Any = None, + preserve_parameter_format: bool = False, + ) -> Any: + if not parameters or not param_info: + return parameters + + is_named_style = target_style in { + ParameterStyle.NAMED_COLON, + ParameterStyle.NAMED_AT, + ParameterStyle.NAMED_DOLLAR, + ParameterStyle.NAMED_PYFORMAT, + } + + if is_named_style: + if isinstance(parameters, Mapping): + return parameters + if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): + return self._convert_sequence_to_dict(parameters, param_info) + + elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): + return parameters + + elif isinstance(parameters, Mapping): + param_values: list[Any] = [] + parameter_styles = {p.style for p in param_info} + has_mixed_styles = len(parameter_styles) > 1 + + unique_params: dict[str, Any] = {} + param_order: list[str] = [] + + if has_mixed_styles: + param_keys = list(parameters.keys()) + for param in param_info: + param_key = param.placeholder_text + if param_key not in unique_params: + value, found = self._extract_param_value_mixed_styles(param, parameters, param_keys) + if found: + unique_params[param_key] = value + param_order.append(param_key) + else: + for param in param_info: + param_key = param.placeholder_text + if param_key not in unique_params: + value, found = self._extract_param_value_single_style(param, parameters) + if found: + unique_params[param_key] = value + 
param_order.append(param_key) + + param_values = [unique_params[param_key] for param_key in param_order] + + if preserve_parameter_format and original_parameters is not None: + return self._preserve_original_format(param_values, original_parameters) + + return param_values + + return parameters + + def _embed_static_parameters( + self, sql: str, parameters: Any, param_info: "list[ParameterInfo]" + ) -> "tuple[str, Any]": + if not param_info: + return sql, None + + unique_params: dict[str, int] = {} + for param in param_info: + if param.style in {ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT}: + param_key = f"{param.placeholder_text}_{param.ordinal}" + elif (param.style == ParameterStyle.NUMERIC and param.name) or param.name: + param_key = param.placeholder_text + else: + param_key = f"{param.placeholder_text}_{param.ordinal}" + + if param_key not in unique_params: + unique_params[param_key] = len(unique_params) + + static_sql = sql + for param in reversed(param_info): + param_value = self._get_parameter_value_with_reuse(parameters, param, unique_params) + + if param_value is None: + literal = "NULL" + elif isinstance(param_value, str): + escaped = param_value.replace("'", "''") + literal = f"'{escaped}'" + elif isinstance(param_value, bool): + literal = "TRUE" if param_value else "FALSE" + elif isinstance(param_value, (int, float)): + literal = str(param_value) + else: + literal = f"'{param_value!s}'" + + static_sql = ( + static_sql[: param.position] + literal + static_sql[param.position + len(param.placeholder_text) :] + ) + + return static_sql, None + + def _get_parameter_value(self, parameters: Any, param: "ParameterInfo") -> Any: + if isinstance(parameters, Mapping): + if param.name and param.name in parameters: + return parameters[param.name] + if f"param_{param.ordinal}" in parameters: + return parameters[f"param_{param.ordinal}"] + if str(param.ordinal + 1) in parameters: + return parameters[str(param.ordinal + 1)] + elif isinstance(parameters, 
Sequence) and not isinstance(parameters, (str, bytes)): + if param.ordinal < len(parameters): + return parameters[param.ordinal] + + return None + + def _get_parameter_value_with_reuse( + self, parameters: Any, param: "ParameterInfo", unique_params: "dict[str, int]" + ) -> Any: + if param.style in {ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT}: + param_key = f"{param.placeholder_text}_{param.ordinal}" + elif (param.style == ParameterStyle.NUMERIC and param.name) or param.name: + param_key = param.placeholder_text + else: + param_key = f"{param.placeholder_text}_{param.ordinal}" + + unique_ordinal = unique_params.get(param_key) + if unique_ordinal is None: + return None + + if isinstance(parameters, Mapping): + if param.name and param.name in parameters: + return parameters[param.name] + if f"param_{unique_ordinal}" in parameters: + return parameters[f"param_{unique_ordinal}"] + if str(unique_ordinal + 1) in parameters: + return parameters[str(unique_ordinal + 1)] + elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): + if unique_ordinal < len(parameters): + return parameters[unique_ordinal] + + return None + + def _convert_to_positional_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any: + return self._convert_parameter_format( + parameters, param_info, ParameterStyle.QMARK, parameters, preserve_parameter_format=False + ) + + def _convert_to_named_colon_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any: + return self._convert_parameter_format( + parameters, param_info, ParameterStyle.NAMED_COLON, parameters, preserve_parameter_format=False + ) + + def _convert_to_positional_colon_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any: + if isinstance(parameters, Mapping): + return parameters + + param_dict: dict[str, Any] = {} + if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): + for index, value in enumerate(parameters): + 
param_dict[str(index + 1)] = value + + return param_dict + + def _convert_to_named_pyformat_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any: + return self._convert_parameter_format( + parameters, param_info, ParameterStyle.NAMED_PYFORMAT, parameters, preserve_parameter_format=False + ) diff --git a/sqlspec/core/parameters/_processor.py b/sqlspec/core/parameters/_processor.py new file mode 100644 index 000000000..c12e475be --- /dev/null +++ b/sqlspec/core/parameters/_processor.py @@ -0,0 +1,304 @@ +"""Parameter processing pipeline orchestrator.""" + +import hashlib +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +from mypy_extensions import mypyc_attr + +from sqlspec.core.parameters._converter import ParameterConverter +from sqlspec.core.parameters._types import ( + ParameterInfo, + ParameterProcessingResult, + ParameterProfile, + ParameterStyle, + ParameterStyleConfig, + TypedParameter, + wrap_with_type, +) +from sqlspec.core.parameters._validator import ParameterValidator + +__all__ = ("ParameterProcessor",) + + +def _fingerprint_parameters(parameters: Any) -> str: + """Return a stable fingerprint for caching parameter payloads. + + Args: + parameters: Original parameter payload supplied by the caller. + + Returns: + Deterministic fingerprint string derived from the parameter payload. 
+ """ + if parameters is None: + return "none" + + if isinstance(parameters, Mapping): + try: + items = sorted(parameters.items(), key=lambda item: repr(item[0])) + except Exception: + items = list(parameters.items()) + data = repr(tuple(items)) + elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes, bytearray)): + data = repr(tuple(parameters)) + else: + data = repr(parameters) + + digest = hashlib.sha256(data.encode("utf-8")).hexdigest() + return f"{type(parameters).__name__}:{digest}" + + +@mypyc_attr(allow_interpreted_subclasses=False) +class ParameterProcessor: + """Parameter processing engine coordinating conversion phases.""" + + __slots__ = ("_cache", "_cache_size", "_converter", "_validator") + + DEFAULT_CACHE_SIZE = 1000 + + def __init__(self) -> None: + self._cache: dict[str, ParameterProcessingResult] = {} + self._cache_size = 0 + self._validator = ParameterValidator() + self._converter = ParameterConverter() + + def _handle_static_embedding( + self, sql: str, parameters: Any, config: "ParameterStyleConfig", is_many: bool, cache_key: str + ) -> "ParameterProcessingResult": + coerced_params = parameters + if config.type_coercion_map and parameters: + coerced_params = self._apply_type_coercions(parameters, config.type_coercion_map, is_many) + + static_sql, static_params = self._converter.convert_placeholder_style( + sql, coerced_params, ParameterStyle.STATIC, is_many + ) + result = ParameterProcessingResult(static_sql, static_params, ParameterProfile.empty()) + if self._cache_size < self.DEFAULT_CACHE_SIZE: + self._cache[cache_key] = result + self._cache_size += 1 + return result + + def _determine_target_execution_style( + self, original_styles: "set[ParameterStyle]", config: "ParameterStyleConfig" + ) -> "ParameterStyle": + if len(original_styles) == 1 and config.supported_execution_parameter_styles is not None: + original_style = next(iter(original_styles)) + if original_style in config.supported_execution_parameter_styles: 
+ return original_style + return config.default_execution_parameter_style or config.default_parameter_style + + def _apply_type_wrapping(self, parameters: Any) -> Any: + if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): + return [wrap_with_type(p) for p in parameters] + if isinstance(parameters, Mapping): + return {k: wrap_with_type(v) for k, v in parameters.items()} + return wrap_with_type(parameters) + + def _apply_type_coercions( + self, parameters: Any, type_coercion_map: "dict[type, Callable[[Any], Any]]", is_many: bool = False + ) -> Any: + def coerce_value(value: Any) -> Any: + if value is None: + return value + + if isinstance(value, TypedParameter): + wrapped_value: Any = value.value + if wrapped_value is None: + return wrapped_value + original_type = value.original_type + if original_type in type_coercion_map: + coerced = type_coercion_map[original_type](wrapped_value) + if isinstance(coerced, (list, tuple)) and not isinstance(coerced, (str, bytes)): + coerced = [coerce_value(item) for item in coerced] + elif isinstance(coerced, dict): + coerced = {k: coerce_value(v) for k, v in coerced.items()} + return coerced + return wrapped_value + + value_type = type(value) + if value_type in type_coercion_map: + coerced = type_coercion_map[value_type](value) + if isinstance(coerced, (list, tuple)) and not isinstance(coerced, (str, bytes)): + coerced = [coerce_value(item) for item in coerced] + elif isinstance(coerced, dict): + coerced = {k: coerce_value(v) for k, v in coerced.items()} + return coerced + return value + + def coerce_parameter_set(param_set: Any) -> Any: + if isinstance(param_set, Sequence) and not isinstance(param_set, (str, bytes)): + return [coerce_value(p) for p in param_set] + if isinstance(param_set, Mapping): + return {k: coerce_value(v) for k, v in param_set.items()} + return coerce_value(param_set) + + if is_many and isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): + return 
[coerce_parameter_set(param_set) for param_set in parameters] + + if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): + return [coerce_value(p) for p in parameters] + if isinstance(parameters, Mapping): + return {k: coerce_value(v) for k, v in parameters.items()} + return coerce_value(parameters) + + def _generate_processor_cache_key( + self, sql: str, parameters: Any, config: "ParameterStyleConfig", is_many: bool, dialect: str | None + ) -> str: + param_fingerprint = _fingerprint_parameters(parameters) + dialect_marker = dialect or "default" + default_style = config.default_parameter_style.value if config.default_parameter_style else "unknown" + return f"{sql}:{param_fingerprint}:{default_style}:{is_many}:{dialect_marker}" + + def process( + self, + sql: str, + parameters: Any, + config: "ParameterStyleConfig", + dialect: str | None = None, + is_many: bool = False, + ) -> "ParameterProcessingResult": + cache_key = self._generate_processor_cache_key(sql, parameters, config, is_many, dialect) + cached_result = self._cache.get(cache_key) + if cached_result is not None: + return cached_result + + param_info = self._validator.extract_parameters(sql) + original_styles = {p.style for p in param_info} if param_info else set() + needs_sqlglot_normalization = self._needs_sqlglot_normalization(param_info, dialect) + needs_execution_conversion = self._needs_execution_conversion(param_info, config) + + needs_static_embedding = config.needs_static_script_compilation and param_info and parameters and not is_many + + if needs_static_embedding: + return self._handle_static_embedding(sql, parameters, config, is_many, cache_key) + + if ( + not needs_sqlglot_normalization + and not needs_execution_conversion + and not config.type_coercion_map + and not config.output_transformer + ): + result = ParameterProcessingResult(sql, parameters, ParameterProfile(param_info)) + if self._cache_size < self.DEFAULT_CACHE_SIZE: + self._cache[cache_key] = result + 
self._cache_size += 1 + return result + + processed_sql, processed_parameters = sql, parameters + + if processed_parameters: + processed_parameters = self._apply_type_wrapping(processed_parameters) + + if needs_sqlglot_normalization: + processed_sql, _ = self._converter.normalize_sql_for_parsing(processed_sql, dialect) + + if config.type_coercion_map and processed_parameters: + processed_parameters = self._apply_type_coercions(processed_parameters, config.type_coercion_map, is_many) + + processed_sql, processed_parameters = self._process_parameters_conversion( + processed_sql, + processed_parameters, + config, + original_styles, + needs_execution_conversion, + needs_sqlglot_normalization, + is_many, + ) + + if config.output_transformer: + processed_sql, processed_parameters = config.output_transformer(processed_sql, processed_parameters) + + final_param_info = self._validator.extract_parameters(processed_sql) + final_profile = ParameterProfile(final_param_info) + result = ParameterProcessingResult(processed_sql, processed_parameters, final_profile) + + if self._cache_size < self.DEFAULT_CACHE_SIZE: + self._cache[cache_key] = result + self._cache_size += 1 + return result + + def get_sqlglot_compatible_sql( + self, sql: str, parameters: Any, config: "ParameterStyleConfig", dialect: str | None = None + ) -> "tuple[str, Any]": + """Normalize SQL for parsing without altering execution format. + + Args: + sql: Raw SQL text. + parameters: Parameter payload supplied by the caller. + config: Parameter style configuration. + dialect: Optional SQL dialect for compatibility checks. + + Returns: + Tuple of normalized SQL and the original parameter payload. 
+ """ + + param_info = self._validator.extract_parameters(sql) + + if self._needs_sqlglot_normalization(param_info, dialect): + normalized_sql, _ = self._converter.normalize_sql_for_parsing(sql, dialect) + return normalized_sql, parameters + + return sql, parameters + + def _needs_execution_conversion(self, param_info: "list[ParameterInfo]", config: "ParameterStyleConfig") -> bool: + """Determine whether execution placeholder conversion is required.""" + if config.needs_static_script_compilation: + return True + + if not param_info: + return False + + current_styles = {param.style for param in param_info} + + if ( + config.allow_mixed_parameter_styles + and len(current_styles) > 1 + and config.supported_execution_parameter_styles is not None + and len(config.supported_execution_parameter_styles) > 1 + and all(style in config.supported_execution_parameter_styles for style in current_styles) + ): + return False + + if ( + config.supported_execution_parameter_styles is not None + and len(config.supported_execution_parameter_styles) > 1 + and all(style in config.supported_execution_parameter_styles for style in current_styles) + ): + return False + + if len(current_styles) > 1: + return True + + if len(current_styles) == 1: + current_style = next(iter(current_styles)) + supported_styles = config.supported_execution_parameter_styles + if supported_styles is None: + return True + return current_style not in supported_styles + + return True + + def _needs_sqlglot_normalization(self, param_info: "list[ParameterInfo]", dialect: str | None = None) -> bool: + incompatible_styles = self._validator.get_sqlglot_incompatible_styles(dialect) + return any(p.style in incompatible_styles for p in param_info) + + def _process_parameters_conversion( + self, + sql: str, + parameters: Any, + config: "ParameterStyleConfig", + original_styles: "set[ParameterStyle]", + needs_execution_conversion: bool, + needs_sqlglot_normalization: bool, + is_many: bool, + ) -> "tuple[str, Any]": + if not 
(needs_execution_conversion or needs_sqlglot_normalization): + return sql, parameters + + if is_many and config.preserve_original_params_for_many and isinstance(parameters, (list, tuple)): + target_style = self._determine_target_execution_style(original_styles, config) + processed_sql, _ = self._converter.convert_placeholder_style(sql, parameters, target_style, is_many) + return processed_sql, parameters + + target_style = self._determine_target_execution_style(original_styles, config) + return self._converter.convert_placeholder_style(sql, parameters, target_style, is_many) diff --git a/sqlspec/core/parameters/_registry.py b/sqlspec/core/parameters/_registry.py new file mode 100644 index 000000000..fda05731a --- /dev/null +++ b/sqlspec/core/parameters/_registry.py @@ -0,0 +1,201 @@ +"""Driver parameter profile registry and StatementConfig factory.""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal, cast + +import sqlspec.exceptions +from sqlspec.core.parameters._types import DriverParameterProfile, ParameterStyleConfig +from sqlspec.utils.serializers import from_json, to_json + +if TYPE_CHECKING: + from sqlspec.core.statement import StatementConfig + +__all__ = ( + "DRIVER_PARAMETER_PROFILES", + "build_statement_config_from_profile", + "get_driver_profile", + "register_driver_profile", +) + +_DEFAULT_JSON_SERIALIZER: Callable[[Any], str] = to_json +_DEFAULT_JSON_DESERIALIZER: Callable[[str], Any] = from_json + +DRIVER_PARAMETER_PROFILES: dict[str, DriverParameterProfile] = {} + + +def get_driver_profile(adapter_key: str) -> "DriverParameterProfile": + """Return the registered parameter profile for the specified adapter. + + Args: + adapter_key: Adapter identifier (case-insensitive). + + Returns: + Registered :class:`DriverParameterProfile` instance. + + Raises: + ImproperConfigurationError: If the adapter does not have a profile. 
+ """ + key = adapter_key.lower() + try: + return DRIVER_PARAMETER_PROFILES[key] + except KeyError as error: + msg = f"No driver parameter profile registered for adapter '{adapter_key}'." + raise sqlspec.exceptions.ImproperConfigurationError(msg) from error + + +def register_driver_profile( + adapter_key: str, profile: "DriverParameterProfile", *, allow_override: bool = False +) -> None: + """Register a driver profile under the canonical adapter key. + + Args: + adapter_key: Adapter identifier (case-insensitive). + profile: Profile describing parameter behaviour. + allow_override: Whether to replace an existing entry. + + Raises: + ImproperConfigurationError: If attempting to register a duplicate profile. + """ + + key = adapter_key.lower() + if not allow_override and key in DRIVER_PARAMETER_PROFILES: + msg = f"Profile already registered for adapter '{adapter_key}'." + raise sqlspec.exceptions.ImproperConfigurationError(msg) + DRIVER_PARAMETER_PROFILES[key] = profile + + +def _build_parameter_style_config_from_profile( + profile: "DriverParameterProfile", + parameter_overrides: "dict[str, Any] | None", + json_serializer: "Callable[[Any], str] | None", + json_deserializer: "Callable[[str], Any] | None", +) -> "ParameterStyleConfig": + """Build a :class:`ParameterStyleConfig` instance from a driver profile. + + Args: + profile: Source driver profile. + parameter_overrides: Optional overrides applied before instantiation. + json_serializer: Adapter-provided JSON serializer. + json_deserializer: Adapter-provided JSON deserializer. + + Returns: + Configured :class:`ParameterStyleConfig` ready for statement construction. 
+ """ + overrides = dict(parameter_overrides or {}) + supported_styles_override = overrides.pop("supported_parameter_styles", None) + execution_styles_override = overrides.pop("supported_execution_parameter_styles", None) + type_coercion_override = overrides.pop("type_coercion_map", None) + json_serializer_override = overrides.pop("json_serializer", None) + json_deserializer_override = overrides.pop("json_deserializer", None) + tuple_strategy_override = overrides.pop("json_tuple_strategy", None) + + supported_styles = ( + set(supported_styles_override) if supported_styles_override is not None else set(profile.supported_styles) + ) + if execution_styles_override is None: + execution_supported = ( + set(profile.supported_execution_styles) if profile.supported_execution_styles is not None else None + ) + else: + execution_supported = set(execution_styles_override) if execution_styles_override is not None else None + + type_map = ( + dict(type_coercion_override) if type_coercion_override is not None else dict(profile.custom_type_coercions) + ) + + parameter_kwargs: dict[str, Any] = { + "default_parameter_style": overrides.pop("default_parameter_style", profile.default_style), + "supported_parameter_styles": supported_styles, + "supported_execution_parameter_styles": execution_supported, + "default_execution_parameter_style": overrides.pop( + "default_execution_parameter_style", profile.default_execution_style + ), + "type_coercion_map": type_map, + "has_native_list_expansion": overrides.pop("has_native_list_expansion", profile.has_native_list_expansion), + "needs_static_script_compilation": overrides.pop( + "needs_static_script_compilation", profile.needs_static_script_compilation + ), + "allow_mixed_parameter_styles": overrides.pop( + "allow_mixed_parameter_styles", profile.allow_mixed_parameter_styles + ), + "preserve_parameter_format": overrides.pop("preserve_parameter_format", profile.preserve_parameter_format), + "preserve_original_params_for_many": overrides.pop( 
+ "preserve_original_params_for_many", profile.preserve_original_params_for_many + ), + "output_transformer": overrides.pop("output_transformer", profile.default_output_transformer), + "ast_transformer": overrides.pop("ast_transformer", profile.default_ast_transformer), + } + + parameter_kwargs = {k: v for k, v in parameter_kwargs.items() if v is not None} + + strategy = profile.json_serializer_strategy + serializer_value = json_serializer_override or json_serializer + deserializer_value = json_deserializer_override or json_deserializer + + if serializer_value is None: + serializer_value = profile.extras.get("default_json_serializer", _DEFAULT_JSON_SERIALIZER) + if deserializer_value is None: + deserializer_value = profile.extras.get("default_json_deserializer", _DEFAULT_JSON_DESERIALIZER) + + if strategy == "driver": + parameter_kwargs["json_serializer"] = serializer_value + parameter_kwargs["json_deserializer"] = deserializer_value + + parameter_kwargs.update(overrides) + parameter_config = ParameterStyleConfig(**parameter_kwargs) + + if strategy == "helper": + tuple_strategy = tuple_strategy_override or profile.extras.get("json_tuple_strategy", "list") + tuple_strategy_literal = cast("Literal['list', 'tuple']", tuple_strategy) + parameter_config = parameter_config.with_json_serializers( + serializer_value, tuple_strategy=tuple_strategy_literal, deserializer=deserializer_value + ) + elif strategy == "driver": + parameter_config = parameter_config.replace( + json_serializer=serializer_value, json_deserializer=deserializer_value + ) + + type_overrides = profile.extras.get("type_coercion_overrides") + if type_overrides: + updated_map = {**parameter_config.type_coercion_map, **dict(type_overrides)} + parameter_config = parameter_config.replace(type_coercion_map=updated_map) + + return parameter_config + + +def build_statement_config_from_profile( + profile: "DriverParameterProfile", + *, + parameter_overrides: "dict[str, Any] | None" = None, + statement_overrides: 
"dict[str, Any] | None" = None, + json_serializer: "Callable[[Any], str] | None" = None, + json_deserializer: "Callable[[str], Any] | None" = None, +) -> "StatementConfig": + """Construct a :class:`StatementConfig` seeded from a driver profile. + + Args: + profile: Driver profile providing default parameter behaviour. + parameter_overrides: Optional overrides for parameter config fields. + statement_overrides: Optional overrides for resulting statement config. + json_serializer: Optional JSON serializer supplied by the adapter. + json_deserializer: Optional JSON deserializer supplied by the adapter. + + Returns: + New :class:`StatementConfig` instance with merged configuration. + """ + parameter_config = _build_parameter_style_config_from_profile( + profile, parameter_overrides, json_serializer, json_deserializer + ) + + from sqlspec.core.statement import StatementConfig as _StatementConfig + + statement_kwargs: dict[str, Any] = {} + if profile.default_dialect is not None: + statement_kwargs["dialect"] = profile.default_dialect + if profile.statement_kwargs: + statement_kwargs.update(profile.statement_kwargs) + if statement_overrides: + statement_kwargs.update(statement_overrides) + + filtered_statement_kwargs = {k: v for k, v in statement_kwargs.items() if v is not None} + return _StatementConfig(parameter_config=parameter_config, **filtered_statement_kwargs) diff --git a/sqlspec/core/parameters/_transformers.py b/sqlspec/core/parameters/_transformers.py new file mode 100644 index 000000000..eca9c9c67 --- /dev/null +++ b/sqlspec/core/parameters/_transformers.py @@ -0,0 +1,198 @@ +"""AST transformer helpers for parameter processing.""" + +import bisect +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +from sqlspec.core.parameters._alignment import ( + collect_null_parameter_ordinals, + looks_like_execute_many, + normalize_parameter_key, + validate_parameter_alignment, +) +from sqlspec.core.parameters._types import ParameterProfile 
+from sqlspec.core.parameters._validator import ParameterValidator + +__all__ = ("replace_null_parameters_with_literals", "replace_placeholders_with_literals") + +_AST_TRANSFORMER_VALIDATOR: "ParameterValidator" = ParameterValidator() + + +def replace_null_parameters_with_literals( + expression: Any, parameters: Any, *, dialect: str = "postgres", validator: "ParameterValidator | None" = None +) -> "tuple[Any, Any]": + """Rewrite placeholders representing ``NULL`` values and prune parameters. + + Args: + expression: SQLGlot expression tree to transform. + parameters: Parameter payload provided by the caller. + dialect: SQLGlot dialect for serializing the expression. + validator: Optional validator instance for parameter extraction. + + Returns: + Tuple containing the transformed expression and updated parameters. + """ + if not parameters: + return expression, parameters + + if looks_like_execute_many(parameters): + return expression, parameters + + validator_instance = validator or _AST_TRANSFORMER_VALIDATOR + parameter_info = validator_instance.extract_parameters(expression.sql(dialect=dialect)) + parameter_profile = ParameterProfile(parameter_info) + validate_parameter_alignment(parameter_profile, parameters) + + null_positions = collect_null_parameter_ordinals(parameters, parameter_profile) + if not null_positions: + return expression, parameters + + sorted_null_positions = sorted(null_positions) + + from sqlglot import exp as _exp # Imported lazily to avoid module-level dependency + + qmark_position = 0 + + def transform_node(node: Any) -> Any: + nonlocal qmark_position + + if isinstance(node, _exp.Placeholder) and getattr(node, "this", None) is None: + current_position = qmark_position + qmark_position += 1 + if current_position in null_positions: + return _exp.Null() + return node + + if isinstance(node, _exp.Placeholder) and getattr(node, "this", None) is not None: + placeholder_text = str(node.this) + normalized_text = placeholder_text.lstrip("$") + if 
normalized_text.isdigit(): + param_index = int(normalized_text) - 1 + if param_index in null_positions: + return _exp.Null() + shift = bisect.bisect_left(sorted_null_positions, param_index) + new_param_num = param_index - shift + 1 + return _exp.Placeholder(this=f"${new_param_num}") + return node + + if isinstance(node, _exp.Parameter) and getattr(node, "this", None) is not None: + parameter_text = str(node.this) + if parameter_text.isdigit(): + param_index = int(parameter_text) - 1 + if param_index in null_positions: + return _exp.Null() + shift = bisect.bisect_left(sorted_null_positions, param_index) + new_param_num = param_index - shift + 1 + return _exp.Parameter(this=str(new_param_num)) + return node + + return node + + transformed_expression = expression.transform(transform_node) + + if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes, bytearray)): + cleaned_parameters = [value for index, value in enumerate(parameters) if index not in null_positions] + elif isinstance(parameters, Mapping): + cleaned_dict: dict[str, Any] = {} + next_numeric_index = 1 + + for key, value in parameters.items(): + if value is None: + continue + key_kind, normalized_key = normalize_parameter_key(key) + if key_kind == "index" and isinstance(normalized_key, int): + cleaned_dict[str(next_numeric_index)] = value + next_numeric_index += 1 + else: + cleaned_dict[str(normalized_key)] = value + cleaned_parameters = cleaned_dict # type: ignore[assignment] + else: + cleaned_parameters = parameters + + return transformed_expression, cleaned_parameters + + +def _create_literal_expression(value: Any, json_serializer: "Callable[[Any], str]") -> Any: + """Create a SQLGlot literal expression for the given value.""" + from sqlglot import exp as _exp + + if value is None: + return _exp.Null() + if isinstance(value, bool): + return _exp.Boolean(this=value) + if isinstance(value, (int, float)): + return _exp.Literal.number(str(value)) + if isinstance(value, str): + return 
_exp.Literal.string(value) + if isinstance(value, (list, tuple)): + items = [_create_literal_expression(item, json_serializer) for item in value] + return _exp.Array(expressions=items) + if isinstance(value, dict): + json_value = json_serializer(value) + return _exp.Literal.string(json_value) + return _exp.Literal.string(str(value)) + + +def replace_placeholders_with_literals( + expression: Any, parameters: Any, *, json_serializer: "Callable[[Any], str]" +) -> Any: + """Replace placeholders in an expression tree with literal values.""" + if not parameters: + return expression + + from sqlglot import exp as _exp + + placeholder_counter = {"index": 0} + + def resolve_mapping_value(param_name: str, payload: Mapping[str, Any]) -> Any | None: + candidate_names = (param_name, f"@{param_name}", f":{param_name}", f"${param_name}", f"param_{param_name}") + for candidate in candidate_names: + if candidate in payload: + return getattr(payload[candidate], "value", payload[candidate]) + normalized = param_name.lstrip("@:$") + if normalized in payload: + return getattr(payload[normalized], "value", payload[normalized]) + return None + + def transform(node: Any) -> Any: + if ( + isinstance(node, _exp.Placeholder) + and isinstance(parameters, Sequence) + and not isinstance(parameters, (str, bytes, bytearray)) + ): + current_index = placeholder_counter["index"] + placeholder_counter["index"] += 1 + if current_index < len(parameters): + literal_value = getattr(parameters[current_index], "value", parameters[current_index]) + return _create_literal_expression(literal_value, json_serializer) + return node + + if isinstance(node, _exp.Parameter): + param_name = str(node.this) if getattr(node, "this", None) is not None else "" + + if isinstance(parameters, Mapping): + resolved_value = resolve_mapping_value(param_name, parameters) + if resolved_value is not None: + return _create_literal_expression(resolved_value, json_serializer) + return node + + if isinstance(parameters, Sequence) and 
not isinstance(parameters, (str, bytes, bytearray)): + name = param_name + try: + if name.startswith("param_"): + index_value = int(name[6:]) + if 0 <= index_value < len(parameters): + literal_value = getattr(parameters[index_value], "value", parameters[index_value]) + return _create_literal_expression(literal_value, json_serializer) + if name.isdigit(): + index_value = int(name) + if 0 <= index_value < len(parameters): + literal_value = getattr(parameters[index_value], "value", parameters[index_value]) + return _create_literal_expression(literal_value, json_serializer) + except (ValueError, AttributeError): + return node + return node + + return node + + return expression.transform(transform) diff --git a/sqlspec/core/parameters/_types.py b/sqlspec/core/parameters/_types.py new file mode 100644 index 000000000..9eedca6ba --- /dev/null +++ b/sqlspec/core/parameters/_types.py @@ -0,0 +1,395 @@ +"""Core parameter data structures and utilities.""" + +from collections.abc import Callable, Collection, Generator, Mapping, Sequence +from dataclasses import dataclass, field +from datetime import date, datetime, time +from decimal import Decimal +from enum import Enum +from functools import singledispatch +from types import MappingProxyType +from typing import Any, Literal + +from mypy_extensions import mypyc_attr + +__all__ = ( + "DriverParameterProfile", + "ParameterInfo", + "ParameterProcessingResult", + "ParameterProfile", + "ParameterStyle", + "ParameterStyleConfig", + "TypedParameter", + "is_iterable_parameters", + "wrap_with_type", +) + + +@mypyc_attr(allow_interpreted_subclasses=False) +class ParameterStyle(str, Enum): + """Enumeration of supported SQL parameter placeholder styles.""" + + NONE = "none" + STATIC = "static" + QMARK = "qmark" + NUMERIC = "numeric" + NAMED_COLON = "named_colon" + POSITIONAL_COLON = "positional_colon" + NAMED_AT = "named_at" + NAMED_DOLLAR = "named_dollar" + NAMED_PYFORMAT = "pyformat_named" + POSITIONAL_PYFORMAT = "pyformat_positional" 
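Because `ParameterStyle` mixes in `str`, enum members compare equal to their raw string values, which is what lets profile registries and config lookups key off plain strings like `"qmark"`. A minimal standalone sketch (stdlib only, mirroring a few of the members above rather than importing sqlspec):

```python
from enum import Enum


class ParameterStyle(str, Enum):
    """Trimmed stand-in for the sqlspec enum above (illustration only)."""

    QMARK = "qmark"
    NUMERIC = "numeric"
    NAMED_COLON = "named_colon"


# str subclassing makes members interchangeable with their raw values:
assert ParameterStyle.QMARK == "qmark"
# and the value constructor round-trips back to the singleton member:
assert ParameterStyle("numeric") is ParameterStyle.NUMERIC
```

This is why the registry guidance above can say "use lowercase adapter keys" and compare style values as strings without explicit `.value` access.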
+ + +@mypyc_attr(allow_interpreted_subclasses=False) +class TypedParameter: + """Wrapper that preserves original parameter type information.""" + + __slots__ = ("_hash", "original_type", "semantic_name", "value") + + def __init__(self, value: Any, original_type: "type | None" = None, semantic_name: "str | None" = None) -> None: + self.value = value + self.original_type = original_type or type(value) + self.semantic_name = semantic_name + self._hash: int | None = None + + def __hash__(self) -> int: + if self._hash is None: + value_id = id(self.value) + self._hash = hash((value_id, self.original_type, self.semantic_name)) + return self._hash + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypedParameter): + return False + return ( + self.value == other.value + and self.original_type == other.original_type + and self.semantic_name == other.semantic_name + ) + + def __repr__(self) -> str: + name_part = f", semantic_name='{self.semantic_name}'" if self.semantic_name else "" + return f"TypedParameter({self.value!r}, original_type={self.original_type.__name__}{name_part})" + + +@singledispatch +def _wrap_parameter_by_type(value: Any, semantic_name: "str | None" = None) -> Any: + return value + + +@_wrap_parameter_by_type.register +def _(value: bool, semantic_name: "str | None" = None) -> "TypedParameter": + return TypedParameter(value, bool, semantic_name) + + +@_wrap_parameter_by_type.register +def _(value: Decimal, semantic_name: "str | None" = None) -> "TypedParameter": + return TypedParameter(value, Decimal, semantic_name) + + +@_wrap_parameter_by_type.register +def _(value: datetime, semantic_name: "str | None" = None) -> "TypedParameter": + return TypedParameter(value, datetime, semantic_name) + + +@_wrap_parameter_by_type.register +def _(value: date, semantic_name: "str | None" = None) -> "TypedParameter": + return TypedParameter(value, date, semantic_name) + + +@_wrap_parameter_by_type.register +def _(value: time, semantic_name: "str | 
None" = None) -> "TypedParameter": + return TypedParameter(value, time, semantic_name) + + +@_wrap_parameter_by_type.register +def _(value: bytes, semantic_name: "str | None" = None) -> "TypedParameter": + return TypedParameter(value, bytes, semantic_name) + + +@mypyc_attr(allow_interpreted_subclasses=False) +class ParameterInfo: + """Metadata describing a single detected SQL parameter.""" + + __slots__ = ("name", "ordinal", "placeholder_text", "position", "style") + + def __init__( + self, name: "str | None", style: "ParameterStyle", position: int, ordinal: int, placeholder_text: str + ) -> None: + self.name = name + self.style = style + self.position = position + self.ordinal = ordinal + self.placeholder_text = placeholder_text + + def __repr__(self) -> str: + return ( + "ParameterInfo(" + f"name={self.name!r}, style={self.style!r}, position={self.position}, " + f"ordinal={self.ordinal}, placeholder_text={self.placeholder_text!r}" + ")" + ) + + +class ParameterStyleConfig: + """Configuration describing parameter behaviour for a statement.""" + + __slots__ = ( + "allow_mixed_parameter_styles", + "ast_transformer", + "default_execution_parameter_style", + "default_parameter_style", + "has_native_list_expansion", + "json_deserializer", + "json_serializer", + "needs_static_script_compilation", + "output_transformer", + "preserve_original_params_for_many", + "preserve_parameter_format", + "supported_execution_parameter_styles", + "supported_parameter_styles", + "type_coercion_map", + ) + + def __init__( + self, + default_parameter_style: "ParameterStyle", + supported_parameter_styles: "Collection[ParameterStyle] | None" = None, + supported_execution_parameter_styles: "Collection[ParameterStyle] | None" = None, + default_execution_parameter_style: "ParameterStyle | None" = None, + type_coercion_map: "Mapping[type, Callable[[Any], Any]] | None" = None, + has_native_list_expansion: bool = False, + needs_static_script_compilation: bool = False, + 
allow_mixed_parameter_styles: bool = False, + preserve_parameter_format: bool = True, + preserve_original_params_for_many: bool = False, + output_transformer: "Callable[[str, Any], tuple[str, Any]] | None" = None, + ast_transformer: "Callable[[Any, Any], tuple[Any, Any]] | None" = None, + json_serializer: "Callable[[Any], str] | None" = None, + json_deserializer: "Callable[[str], Any] | None" = None, + ) -> None: + self.default_parameter_style = default_parameter_style + self.supported_parameter_styles = frozenset(supported_parameter_styles or (default_parameter_style,)) + self.supported_execution_parameter_styles = ( + frozenset(supported_execution_parameter_styles) if supported_execution_parameter_styles else None + ) + self.default_execution_parameter_style = default_execution_parameter_style or default_parameter_style + self.type_coercion_map = dict(type_coercion_map or {}) + self.has_native_list_expansion = has_native_list_expansion + self.output_transformer = output_transformer + self.ast_transformer = ast_transformer + self.needs_static_script_compilation = needs_static_script_compilation + self.allow_mixed_parameter_styles = allow_mixed_parameter_styles + self.preserve_parameter_format = preserve_parameter_format + self.preserve_original_params_for_many = preserve_original_params_for_many + self.json_serializer = json_serializer + self.json_deserializer = json_deserializer + + def __hash__(self) -> int: + hash_components = ( + self.default_parameter_style.value, + frozenset(style.value for style in self.supported_parameter_styles), + ( + frozenset(style.value for style in self.supported_execution_parameter_styles) + if self.supported_execution_parameter_styles is not None + else None + ), + self.default_execution_parameter_style.value, + tuple(sorted(self.type_coercion_map.keys(), key=str)) if self.type_coercion_map else None, + self.has_native_list_expansion, + self.preserve_original_params_for_many, + bool(self.output_transformer), + 
self.needs_static_script_compilation, + self.allow_mixed_parameter_styles, + self.preserve_parameter_format, + bool(self.ast_transformer), + self.json_serializer, + self.json_deserializer, + ) + return hash(hash_components) + + def hash(self) -> int: + """Return the hash value for caching compatibility. + + Returns: + Hash value matching :func:`hash` output for this config. + """ + + return hash(self) + + def replace(self, **overrides: Any) -> "ParameterStyleConfig": + data: dict[str, Any] = { + "default_parameter_style": self.default_parameter_style, + "supported_parameter_styles": set(self.supported_parameter_styles), + "supported_execution_parameter_styles": ( + set(self.supported_execution_parameter_styles) + if self.supported_execution_parameter_styles is not None + else None + ), + "default_execution_parameter_style": self.default_execution_parameter_style, + "type_coercion_map": dict(self.type_coercion_map), + "has_native_list_expansion": self.has_native_list_expansion, + "needs_static_script_compilation": self.needs_static_script_compilation, + "allow_mixed_parameter_styles": self.allow_mixed_parameter_styles, + "preserve_parameter_format": self.preserve_parameter_format, + "preserve_original_params_for_many": self.preserve_original_params_for_many, + "output_transformer": self.output_transformer, + "ast_transformer": self.ast_transformer, + "json_serializer": self.json_serializer, + "json_deserializer": self.json_deserializer, + } + data.update(overrides) + return ParameterStyleConfig(**data) + + def with_json_serializers( + self, + serializer: "Callable[[Any], str]", + *, + tuple_strategy: "Literal['list', 'tuple']" = "list", + deserializer: "Callable[[str], Any] | None" = None, + ) -> "ParameterStyleConfig": + """Return a copy configured with JSON serializers for complex parameters.""" + + if tuple_strategy == "list": + + def tuple_adapter(value: Any) -> Any: + return serializer(list(value)) + + elif tuple_strategy == "tuple": + + def 
tuple_adapter(value: Any) -> Any: + return serializer(value) + + else: + msg = f"Unsupported tuple_strategy: {tuple_strategy}" + raise ValueError(msg) + + updated_type_map = dict(self.type_coercion_map) + updated_type_map[dict] = serializer + updated_type_map[list] = serializer + updated_type_map[tuple] = tuple_adapter + + return self.replace( + type_coercion_map=updated_type_map, + json_serializer=serializer, + json_deserializer=deserializer or self.json_deserializer, + ) + + +@dataclass(slots=True) +class DriverParameterProfile: + """Immutable adapter profile describing parameter defaults.""" + + name: str + default_style: "ParameterStyle" + supported_styles: "Collection[ParameterStyle]" + default_execution_style: "ParameterStyle" + supported_execution_styles: "Collection[ParameterStyle] | None" + has_native_list_expansion: bool + preserve_parameter_format: bool + needs_static_script_compilation: bool + allow_mixed_parameter_styles: bool + preserve_original_params_for_many: bool + json_serializer_strategy: "Literal['driver', 'helper', 'none']" + custom_type_coercions: "Mapping[type, Callable[[Any], Any]]" = field(default_factory=dict) + default_output_transformer: "Callable[[str, Any], tuple[str, Any]] | None" = None + default_ast_transformer: "Callable[[Any, Any], tuple[Any, Any]] | None" = None + extras: "Mapping[str, Any]" = field(default_factory=dict) + default_dialect: "str | None" = None + statement_kwargs: "Mapping[str, Any]" = field(default_factory=dict) + + def __post_init__(self) -> None: + self.supported_styles = frozenset(self.supported_styles) + self.supported_execution_styles = ( + frozenset(self.supported_execution_styles) if self.supported_execution_styles is not None else None + ) + self.custom_type_coercions = MappingProxyType(dict(self.custom_type_coercions)) + self.extras = MappingProxyType(dict(self.extras)) + self.statement_kwargs = MappingProxyType(dict(self.statement_kwargs)) + + +@mypyc_attr(allow_interpreted_subclasses=False) +class 
ParameterProfile: + """Aggregate metadata describing detected parameters.""" + + __slots__ = ("_parameters", "_placeholder_counts", "named_parameters", "reused_ordinals", "styles") + + def __init__(self, parameters: "Sequence[ParameterInfo] | None" = None) -> None: + param_tuple: tuple[ParameterInfo, ...] = tuple(parameters) if parameters else () + self._parameters = param_tuple + self.styles = tuple(sorted({param.style.value for param in param_tuple})) if param_tuple else () + placeholder_counts: dict[str, int] = {} + reused_ordinals: list[int] = [] + named_parameters: list[str] = [] + + for param in param_tuple: + placeholder = param.placeholder_text + current_count = placeholder_counts.get(placeholder, 0) + placeholder_counts[placeholder] = current_count + 1 + if current_count: + reused_ordinals.append(param.ordinal) + if param.name is not None: + named_parameters.append(param.name) + + self._placeholder_counts = placeholder_counts + self.reused_ordinals = tuple(reused_ordinals) + self.named_parameters = tuple(named_parameters) + + @classmethod + def empty(cls) -> "ParameterProfile": + return cls(()) + + @property + def parameters(self) -> "tuple[ParameterInfo, ...]": + return self._parameters + + @property + def total_count(self) -> int: + return len(self._parameters) + + def placeholder_count(self, placeholder: str) -> int: + return self._placeholder_counts.get(placeholder, 0) + + def is_empty(self) -> bool: + return not self._parameters + + +@mypyc_attr(allow_interpreted_subclasses=False) +class ParameterProcessingResult: + """Return container for parameter processing output.""" + + __slots__ = ("parameter_profile", "parameters", "sql") + + def __init__(self, sql: str, parameters: Any, parameter_profile: "ParameterProfile") -> None: + self.sql = sql + self.parameters = parameters + self.parameter_profile = parameter_profile + + def __iter__(self) -> "Generator[str | Any, Any, None]": + yield self.sql + yield self.parameters + + def __len__(self) -> int: + 
return 2 + + def __getitem__(self, index: int) -> Any: + if index == 0: + return self.sql + if index == 1: + return self.parameters + msg = "ParameterProcessingResult exposes exactly two positional items" + raise IndexError(msg) + + +def is_iterable_parameters(obj: Any) -> bool: + """Return True when the object behaves like an iterable parameter payload.""" + + return isinstance(obj, (list, tuple, set)) or ( + hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, Mapping)) + ) + + +def wrap_with_type(value: Any, semantic_name: "str | None" = None) -> Any: + """Wrap value with :class:`TypedParameter` if it benefits downstream processing.""" + + return _wrap_parameter_by_type(value, semantic_name) diff --git a/sqlspec/core/parameters/_validator.py b/sqlspec/core/parameters/_validator.py new file mode 100644 index 000000000..af4c069e6 --- /dev/null +++ b/sqlspec/core/parameters/_validator.py @@ -0,0 +1,121 @@ +"""Parameter extraction utilities.""" + +import re +from collections import OrderedDict + +from mypy_extensions import mypyc_attr + +from sqlspec.core.parameters._types import ParameterInfo, ParameterStyle + +__all__ = ("PARAMETER_REGEX", "ParameterValidator") + +PARAMETER_REGEX = re.compile( + r""" + (?P<dquote>"(?:[^"\\]|\\.)*") | + (?P<squote>'(?:[^'\\]|\\.)*') | + (?P<dollar_quoted_string>\$(?P<dollar_quote_tag>\w*)?\$[\s\S]*?\$\4\$) | + (?P<line_comment>--[^\r\n]*) | + (?P<block_comment>/\*(?:[^*]|\*(?!/))*\*/) | + (?P<pg_q_operator>\?\?|\?\||\?&) | + (?P<pg_cast>::(?P<cast_type>\w+)) | + (?P<pyformat_named>%\((?P<pyformat_name>\w+)\)s) | + (?P<pyformat_pos>%s) | + (?P<positional_colon>:(?P<colon_num>\d+)) | + (?P<named_colon>:(?P<colon_name>\w+)) | + (?P<named_at>@(?P<at_name>\w+)) | + (?P<numeric>\$(?P<numeric_num>\d+)) | + (?P<named_dollar_param>\$(?P<dollar_param_name>\w+)) | + (?P<qmark>\?)
+ """, + re.VERBOSE | re.IGNORECASE | re.MULTILINE | re.DOTALL, +) + + +@mypyc_attr(allow_interpreted_subclasses=False) +class ParameterValidator: + """Extracts placeholder metadata and dialect compatibility information.""" + + __slots__ = ("_cache_max_size", "_parameter_cache") + + def __init__(self, cache_max_size: int = 5000) -> None: + self._parameter_cache: OrderedDict[str, list[ParameterInfo]] = OrderedDict() + self._cache_max_size = cache_max_size + + def _extract_parameter_style(self, match: re.Match[str]) -> "tuple[ParameterStyle | None, str | None]": + """Map a regex match to a placeholder style and optional name.""" + if match.group("qmark"): + return ParameterStyle.QMARK, None + if match.group("named_colon"): + return ParameterStyle.NAMED_COLON, match.group("colon_name") + if match.group("numeric"): + return ParameterStyle.NUMERIC, match.group("numeric_num") + if match.group("named_at"): + return ParameterStyle.NAMED_AT, match.group("at_name") + if match.group("pyformat_named"): + return ParameterStyle.NAMED_PYFORMAT, match.group("pyformat_name") + if match.group("pyformat_pos"): + return ParameterStyle.POSITIONAL_PYFORMAT, None + if match.group("positional_colon"): + return ParameterStyle.POSITIONAL_COLON, match.group("colon_num") + if match.group("named_dollar_param"): + return ParameterStyle.NAMED_DOLLAR, match.group("dollar_param_name") + return None, None + + def extract_parameters(self, sql: str) -> "list[ParameterInfo]": + """Extract ordered parameter metadata from SQL text.""" + cached_result = self._parameter_cache.get(sql) + if cached_result is not None: + self._parameter_cache.move_to_end(sql) + return cached_result + + if not any(c in sql for c in ("?", "%", ":", "@", "$")): + if len(self._parameter_cache) >= self._cache_max_size: + self._parameter_cache.popitem(last=False) + self._parameter_cache[sql] = [] + return [] + + parameters: list[ParameterInfo] = [] + ordinal = 0 + + skip_groups = ( + "dquote", + "squote", + "dollar_quoted_string", 
+ "line_comment", + "block_comment", + "pg_q_operator", + "pg_cast", + ) + + for match in PARAMETER_REGEX.finditer(sql): + if any(match.group(group) for group in skip_groups): + continue + style, name = self._extract_parameter_style(match) + if style is None: + continue + placeholder_text = match.group(0) + parameters.append(ParameterInfo(name, style, match.start(), ordinal, placeholder_text)) + ordinal += 1 + + if len(self._parameter_cache) >= self._cache_max_size: + self._parameter_cache.popitem(last=False) + self._parameter_cache[sql] = parameters + return parameters + + def get_sqlglot_incompatible_styles(self, dialect: str | None = None) -> "set[ParameterStyle]": + """Return placeholder styles incompatible with SQLGlot for the dialect.""" + base_incompatible = { + ParameterStyle.NAMED_PYFORMAT, + ParameterStyle.POSITIONAL_PYFORMAT, + ParameterStyle.POSITIONAL_COLON, + } + + if dialect and dialect.lower() in {"mysql", "mariadb"}: + return base_incompatible + if dialect and dialect.lower() in {"postgres", "postgresql"}: + return {ParameterStyle.POSITIONAL_COLON} + if dialect and dialect.lower() == "sqlite": + return {ParameterStyle.POSITIONAL_COLON} + if dialect and dialect.lower() in {"oracle", "bigquery"}: + return base_incompatible + return base_incompatible diff --git a/tests/integration/test_adapters/test_adbc/test_parameter_styles.py b/tests/integration/test_adapters/test_adbc/test_parameter_styles.py index 494a9d615..c1509052e 100644 --- a/tests/integration/test_adapters/test_adbc/test_parameter_styles.py +++ b/tests/integration/test_adapters/test_adbc/test_parameter_styles.py @@ -579,7 +579,7 @@ def test_adbc_ast_transformer_validation_fixed(adbc_postgresql_session: AdbcDriv """ from sqlglot import parse_one - from sqlspec.adapters.adbc.driver import _adbc_ast_transformer + from sqlspec.core.parameters import replace_null_parameters_with_literals # Create a test case with parameter count mismatch original_sql = "INSERT INTO bug_test (id, col1) VALUES 
($1, $2)" @@ -590,7 +590,7 @@ def test_adbc_ast_transformer_validation_fixed(adbc_postgresql_session: AdbcDriv # FIXED: AST transformer now validates parameter count and rejects mismatches with pytest.raises(SQLSpecError) as exc_info: - _adbc_ast_transformer(parsed, original_params) + replace_null_parameters_with_literals(parsed, original_params, dialect="postgres") # Verify we get the correct error message error_msg = str(exc_info.value).lower() @@ -599,7 +599,7 @@ def test_adbc_ast_transformer_validation_fixed(adbc_postgresql_session: AdbcDriv # Verify that correct parameter count works fine correct_params = (200, None) # 2 params for 2 placeholders - modified_ast, cleaned_params = _adbc_ast_transformer(parsed, correct_params) + modified_ast, cleaned_params = replace_null_parameters_with_literals(parsed, correct_params, dialect="postgres") # Convert back to SQL to see the transformation transformed_sql = modified_ast.sql(dialect="postgres") diff --git a/tests/unit/test_adapters/test_asyncmy/test_config.py b/tests/unit/test_adapters/test_asyncmy/test_config.py new file mode 100644 index 000000000..aba0b3a85 --- /dev/null +++ b/tests/unit/test_adapters/test_asyncmy/test_config.py @@ -0,0 +1,36 @@ +"""Asyncmy configuration tests covering statement config builders.""" + +from sqlspec.adapters.asyncmy.config import AsyncmyConfig +from sqlspec.adapters.asyncmy.driver import build_asyncmy_statement_config + + +def test_build_asyncmy_statement_config_custom_serializers() -> None: + """Custom serializers should propagate into the parameter configuration.""" + + def serializer(_: object) -> str: + return "serialized" + + def deserializer(_: str) -> object: + return {"value": "deserialized"} + + statement_config = build_asyncmy_statement_config(json_serializer=serializer, json_deserializer=deserializer) + + parameter_config = statement_config.parameter_config + assert parameter_config.json_serializer is serializer + assert parameter_config.json_deserializer is deserializer 
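The builder tests above verify that explicitly supplied serializers take precedence over profile defaults. The precedence rule used by `build_statement_config_from_profile` — explicit override wins, profile value fills the gap, `None` means "not supplied" — can be sketched with a hypothetical standalone helper (not part of sqlspec's API):

```python
def merge_profile_overrides(profile_defaults: dict, overrides: dict) -> dict:
    """Hypothetical sketch of the merge precedence: explicit overrides win,
    profile defaults fill remaining fields, and None is treated as absent."""
    merged = dict(profile_defaults)
    merged.update({key: value for key, value in overrides.items() if value is not None})
    return merged


profile = {"default_parameter_style": "numeric", "has_native_list_expansion": True}
overrides = {"default_parameter_style": "qmark", "has_native_list_expansion": None}
merged = merge_profile_overrides(profile, overrides)
# merged keeps the profile's list-expansion flag but takes the overridden style
```

The same shape underlies the real builder: it `pop()`s each known override, falls back to the profile field, and filters out `None` before constructing `ParameterStyleConfig`.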
+
+
+def test_asyncmy_config_applies_driver_feature_serializers() -> None:
+    """Driver features should mutate the Asyncmy statement configuration."""
+
+    def serializer(_: object) -> str:
+        return "feature"
+
+    def deserializer(_: str) -> object:
+        return {"feature": True}
+
+    config = AsyncmyConfig(driver_features={"json_serializer": serializer, "json_deserializer": deserializer})
+
+    parameter_config = config.statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
+    assert parameter_config.json_deserializer is deserializer
diff --git a/tests/unit/test_adapters/test_asyncpg/test_config.py b/tests/unit/test_adapters/test_asyncpg/test_config.py
index e69de29bb..4505c6903 100644
--- a/tests/unit/test_adapters/test_asyncpg/test_config.py
+++ b/tests/unit/test_adapters/test_asyncpg/test_config.py
@@ -0,0 +1,36 @@
+"""AsyncPG configuration tests covering statement config builders."""
+
+from sqlspec.adapters.asyncpg.config import AsyncpgConfig
+from sqlspec.adapters.asyncpg.driver import build_asyncpg_statement_config
+
+
+def test_build_asyncpg_statement_config_custom_serializers() -> None:
+    """Custom serializers should propagate into the parameter configuration."""
+
+    def serializer(_: object) -> str:
+        return "serialized"
+
+    def deserializer(_: str) -> object:
+        return {"value": "deserialized"}
+
+    statement_config = build_asyncpg_statement_config(json_serializer=serializer, json_deserializer=deserializer)
+
+    parameter_config = statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
+    assert parameter_config.json_deserializer is deserializer
+
+
+def test_asyncpg_config_applies_driver_feature_serializers() -> None:
+    """Driver features should mutate the AsyncPG statement configuration."""
+
+    def serializer(_: object) -> str:
+        return "feature"
+
+    def deserializer(_: str) -> object:
+        return {"feature": True}
+
+    config = AsyncpgConfig(driver_features={"json_serializer": serializer, "json_deserializer": deserializer})
+
+    parameter_config = config.statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
+    assert parameter_config.json_deserializer is deserializer
diff --git a/tests/unit/test_adapters/test_bigquery/test_config.py b/tests/unit/test_adapters/test_bigquery/test_config.py
new file mode 100644
index 000000000..b73c93909
--- /dev/null
+++ b/tests/unit/test_adapters/test_bigquery/test_config.py
@@ -0,0 +1,28 @@
+"""BigQuery configuration tests covering statement config builders."""
+
+from sqlspec.adapters.bigquery.config import BigQueryConfig
+from sqlspec.adapters.bigquery.driver import build_bigquery_statement_config
+
+
+def test_build_bigquery_statement_config_custom_serializer() -> None:
+    """Custom serializer should propagate into the parameter configuration."""
+
+    def serializer(_: object) -> str:
+        return "serialized"
+
+    statement_config = build_bigquery_statement_config(json_serializer=serializer)
+
+    parameter_config = statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
+
+
+def test_bigquery_config_applies_driver_feature_serializer() -> None:
+    """Driver features should mutate the BigQuery statement configuration."""
+
+    def serializer(_: object) -> str:
+        return "feature"
+
+    config = BigQueryConfig(driver_features={"json_serializer": serializer})
+
+    parameter_config = config.statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
diff --git a/tests/unit/test_adapters/test_duckdb/test_config.py b/tests/unit/test_adapters/test_duckdb/test_config.py
new file mode 100644
index 000000000..7a36b762f
--- /dev/null
+++ b/tests/unit/test_adapters/test_duckdb/test_config.py
@@ -0,0 +1,28 @@
+"""DuckDB configuration tests covering statement config builders."""
+
+from sqlspec.adapters.duckdb.config import DuckDBConfig
+from sqlspec.adapters.duckdb.driver import build_duckdb_statement_config
+
+
+def test_build_duckdb_statement_config_custom_serializer() -> None:
+    """Custom serializer should propagate into the parameter configuration."""
+
+    def serializer(_: object) -> str:
+        return "serialized"
+
+    statement_config = build_duckdb_statement_config(json_serializer=serializer)
+
+    parameter_config = statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
+
+
+def test_duckdb_config_applies_driver_feature_serializer() -> None:
+    """Driver features should mutate the DuckDB statement configuration."""
+
+    def serializer(_: object) -> str:
+        return "feature"
+
+    config = DuckDBConfig(driver_features={"json_serializer": serializer})
+
+    parameter_config = config.statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
diff --git a/tests/unit/test_adapters/test_psqlpy/__init__.py b/tests/unit/test_adapters/test_psqlpy/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/test_adapters/test_psqlpy/test_config.py b/tests/unit/test_adapters/test_psqlpy/test_config.py
new file mode 100644
index 000000000..0ecde7253
--- /dev/null
+++ b/tests/unit/test_adapters/test_psqlpy/test_config.py
@@ -0,0 +1,28 @@
+"""Psqlpy configuration tests covering statement config builders."""
+
+from sqlspec.adapters.psqlpy.config import PsqlpyConfig
+from sqlspec.adapters.psqlpy.driver import build_psqlpy_statement_config
+
+
+def test_build_psqlpy_statement_config_custom_serializer() -> None:
+    """Custom serializer should propagate into the parameter configuration."""
+
+    def serializer(_: object) -> str:
+        return "serialized"
+
+    statement_config = build_psqlpy_statement_config(json_serializer=serializer)
+
+    parameter_config = statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
+
+
+def test_psqlpy_config_applies_driver_feature_serializer() -> None:
+    """Driver features should mutate the Psqlpy statement configuration."""
+
+    def serializer(_: object) -> str:
+        return "feature"
+
+    config = PsqlpyConfig(driver_features={"json_serializer": serializer})
+
+    parameter_config = config.statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
diff --git a/tests/unit/test_adapters/test_psycopg/test_config.py b/tests/unit/test_adapters/test_psycopg/test_config.py
new file mode 100644
index 000000000..8ee13ec78
--- /dev/null
+++ b/tests/unit/test_adapters/test_psycopg/test_config.py
@@ -0,0 +1,40 @@
+"""Psycopg configuration tests covering statement config builders."""
+
+from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig
+from sqlspec.adapters.psycopg.driver import build_psycopg_statement_config
+
+
+def test_build_psycopg_statement_config_custom_serializer() -> None:
+    """Custom serializer should propagate into the parameter configuration."""
+
+    def serializer(_: object) -> str:
+        return "serialized"
+
+    statement_config = build_psycopg_statement_config(json_serializer=serializer)
+
+    parameter_config = statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
+
+
+def test_psycopg_sync_config_applies_driver_feature_serializer() -> None:
+    """Driver features should mutate the sync Psycopg statement configuration."""
+
+    def serializer(_: object) -> str:
+        return "sync"
+
+    config = PsycopgSyncConfig(driver_features={"json_serializer": serializer})
+
+    parameter_config = config.statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
+
+
+def test_psycopg_async_config_applies_driver_feature_serializer() -> None:
+    """Driver features should mutate the async Psycopg statement configuration."""
+
+    def serializer(_: object) -> str:
+        return "async"
+
+    config = PsycopgAsyncConfig(driver_features={"json_serializer": serializer})
+
+    parameter_config = config.statement_config.parameter_config
+    assert parameter_config.json_serializer is serializer
diff --git a/tests/unit/test_cli/test_config_loading.py b/tests/unit/test_cli/test_config_loading.py
index f5942a6be..d440a3047 100644
--- a/tests/unit/test_cli/test_config_loading.py
+++ b/tests/unit/test_cli/test_config_loading.py
@@ -3,6 +3,7 @@
 import os
 import sys
 import tempfile
+import uuid
 from collections.abc import Iterator
 from pathlib import Path
@@ -11,6 +12,8 @@
 from sqlspec.cli import add_migration_commands

+MODULE_PREFIX = "cli_test_config_"
+

 @pytest.fixture
 def cleanup_test_modules() -> Iterator[None]:
@@ -19,12 +22,18 @@ def cleanup_test_modules() -> Iterator[None]:
     yield
     # Remove any test modules that were imported during the test
     modules_after = set(sys.modules.keys())
-    test_modules = {m for m in modules_after - modules_before if m.startswith("test_config")}
+    test_modules = {m for m in modules_after - modules_before if m.startswith(MODULE_PREFIX)}
     for module in test_modules:
         if module in sys.modules:
             del sys.modules[module]


+def _create_module(path: "Path", content: str) -> str:
+    module_name = f"{MODULE_PREFIX}{uuid.uuid4().hex}"
+    (path / f"{module_name}.py").write_text(content)
+    return module_name
+
+
 def test_direct_config_instance_loading(cleanup_test_modules: None) -> None:
     """Test loading a direct config instance through CLI."""
     runner = CliRunner()
@@ -41,13 +50,15 @@ def test_direct_config_instance_loading(cleanup_test_modules: None) -> None:
 )
 database_config = config
 """
-        (Path(temp_dir) / "test_config.py").write_text(config_module)
+        module_name = _create_module(Path(temp_dir), config_module)

         # Change to the temp directory
         original_cwd = os.getcwd()
         try:
             os.chdir(temp_dir)
-            result = runner.invoke(add_migration_commands(), ["--config", "test_config.database_config", "show-config"])
+            result = runner.invoke(
+                add_migration_commands(), ["--config", f"{module_name}.database_config", "show-config"]
+            )
         finally:
             os.chdir(original_cwd)
@@ -73,14 +84,14 @@ def get_database_config():
 )
     return config
 """
-        (Path(temp_dir) / "test_config.py").write_text(config_module)
+        module_name = _create_module(Path(temp_dir), config_module)

         # Change to the temp directory
         original_cwd = os.getcwd()
         try:
             os.chdir(temp_dir)
             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_database_config", "show-config"]
+                add_migration_commands(), ["--config", f"{module_name}.get_database_config", "show-config"]
             )
         finally:
             os.chdir(original_cwd)
@@ -110,14 +121,14 @@ async def get_database_config():
 )
     return config
 """
-        (Path(temp_dir) / "test_config.py").write_text(config_module)
+        module_name = _create_module(Path(temp_dir), config_module)

         # Change to the temp directory
         original_cwd = os.getcwd()
         try:
             os.chdir(temp_dir)
             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_database_config", "show-config"]
+                add_migration_commands(), ["--config", f"{module_name}.get_database_config", "show-config"]
             )
         finally:
             os.chdir(original_cwd)
@@ -146,13 +157,15 @@ def test_show_config_with_path_object(cleanup_test_modules: None) -> None:
 )
 database_config = config
 """
-        (Path(temp_dir) / "test_config.py").write_text(config_module)
+        module_name = _create_module(Path(temp_dir), config_module)

         # Change to the temp directory
         original_cwd = os.getcwd()
         try:
             os.chdir(temp_dir)
-            result = runner.invoke(add_migration_commands(), ["--config", "test_config.database_config", "show-config"])
+            result = runner.invoke(
+                add_migration_commands(), ["--config", f"{module_name}.database_config", "show-config"]
+            )
         finally:
             os.chdir(original_cwd)
diff --git a/tests/unit/test_cli/test_migration_commands.py b/tests/unit/test_cli/test_migration_commands.py
index 36bfd6463..c66095d09 100644
--- a/tests/unit/test_cli/test_migration_commands.py
+++ b/tests/unit/test_cli/test_migration_commands.py
@@ -3,6 +3,7 @@
 import os
 import sys
 import tempfile
+import uuid
 from collections.abc import Iterator
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -13,6 +14,8 @@
 from sqlspec.cli import add_migration_commands

+MODULE_PREFIX = "cli_test_config_"
+
 if TYPE_CHECKING:
     from unittest.mock import Mock
@@ -24,12 +27,18 @@ def cleanup_test_modules() -> Iterator[None]:
     yield
     # Remove any test modules that were imported during the test
     modules_after = set(sys.modules.keys())
-    test_modules = {m for m in modules_after - modules_before if m.startswith("test_config")}
+    test_modules = {m for m in modules_after - modules_before if m.startswith(MODULE_PREFIX)}
     for module in test_modules:
         if module in sys.modules:
             del sys.modules[module]


+def _create_module(content: str, directory: "Path") -> str:
+    module_name = f"{MODULE_PREFIX}{uuid.uuid4().hex}"
+    (directory / f"{module_name}.py").write_text(content)
+    return module_name
+
+
 def test_show_config_command(cleanup_test_modules: None) -> None:
     """Test show-config command displays migration configurations."""
     runner = CliRunner()
@@ -52,9 +61,9 @@ def get_config():
 )
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

-            result = runner.invoke(add_migration_commands(), ["--config", "test_config.get_config", "show-config"])
+            result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "show-config"])
         finally:
             os.chdir(original_dir)
@@ -91,9 +100,9 @@ def get_configs():
    return [sqlite_config, duckdb_config]
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

-            result = runner.invoke(add_migration_commands(), ["--config", "test_config.get_configs", "show-config"])
+            result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_configs", "show-config"])
         finally:
             os.chdir(original_dir)
@@ -123,9 +132,9 @@ def get_config():
 )
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

-            result = runner.invoke(add_migration_commands(), ["--config", "test_config.get_config", "show-config"])
+            result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "show-config"])
         finally:
             os.chdir(original_dir)
@@ -161,10 +170,10 @@ def get_config():
 )
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_config", "show-current-revision"]
+                add_migration_commands(), ["--config", f"{module_name}.get_config", "show-current-revision"]
             )

         finally:
@@ -198,10 +207,11 @@ def get_config():
 )
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_config", "show-current-revision", "--verbose"]
+                add_migration_commands(),
+                ["--config", f"{module_name}.get_config", "show-current-revision", "--verbose"],
             )

         finally:
@@ -233,10 +243,10 @@ def get_config():
     config.migration_config = {"script_location": "test_migrations"}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_config", "init", "--no-prompt"]
+                add_migration_commands(), ["--config", f"{module_name}.get_config", "init", "--no-prompt"]
             )

         finally:
@@ -268,11 +278,11 @@ def get_config():
     config.migration_config = {"script_location": "migrations"}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
                 add_migration_commands(),
-                ["--config", "test_config.get_config", "init", "custom_migrations", "--no-prompt"],
+                ["--config", f"{module_name}.get_config", "init", "custom_migrations", "--no-prompt"],
             )

         finally:
@@ -304,11 +314,11 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
                 add_migration_commands(),
-                ["--config", "test_config.get_config", "create-migration", "-m", "test migration", "--no-prompt"],
+                ["--config", f"{module_name}.get_config", "create-migration", "-m", "test migration", "--no-prompt"],
             )

         finally:
@@ -340,11 +350,11 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
                 add_migration_commands(),
-                ["--config", "test_config.get_config", "make-migration", "-m", "test migration", "--no-prompt"],
+                ["--config", f"{module_name}.get_config", "make-migration", "-m", "test migration", "--no-prompt"],
             )

         finally:
@@ -376,10 +386,10 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_config", "upgrade", "--no-prompt"]
+                add_migration_commands(), ["--config", f"{module_name}.get_config", "upgrade", "--no-prompt"]
             )

         finally:
@@ -411,10 +421,10 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_config", "upgrade", "abc123", "--no-prompt"]
+                add_migration_commands(), ["--config", f"{module_name}.get_config", "upgrade", "abc123", "--no-prompt"]
             )

         finally:
@@ -446,10 +456,10 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_config", "downgrade", "--no-prompt"]
+                add_migration_commands(), ["--config", f"{module_name}.get_config", "downgrade", "--no-prompt"]
             )

         finally:
@@ -481,9 +491,11 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

-            result = runner.invoke(add_migration_commands(), ["--config", "test_config.get_config", "stamp", "abc123"])
+            result = runner.invoke(
+                add_migration_commands(), ["--config", f"{module_name}.get_config", "stamp", "abc123"]
+            )
         finally:
             os.chdir(original_dir)
@@ -520,11 +532,11 @@ def get_configs():
    return [sqlite_config, duckdb_config]
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
                 add_migration_commands(),
-                ["--config", "test_config.get_configs", "show-current-revision", "--include", "sqlite_multi"],
+                ["--config", f"{module_name}.get_configs", "show-current-revision", "--include", "sqlite_multi"],
             )

         finally:
@@ -562,10 +574,10 @@ def get_configs():
    return [config1, config2]
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
-                add_migration_commands(), ["--config", "test_config.get_configs", "upgrade", "--dry-run"]
+                add_migration_commands(), ["--config", f"{module_name}.get_configs", "upgrade", "--dry-run"]
             )

         finally:
@@ -595,7 +607,7 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             with patch("sqlspec.migrations.commands.create_migration_commands") as mock_create:
                 mock_commands = Mock()
@@ -604,7 +616,7 @@ def get_config():

                 result = runner.invoke(
                     add_migration_commands(),
-                    ["--config", "test_config.get_config", "upgrade", "--execution-mode", "sync", "--no-prompt"],
+                    ["--config", f"{module_name}.get_config", "upgrade", "--execution-mode", "sync", "--no-prompt"],
                 )

         finally:
@@ -632,11 +644,11 @@ def get_config():
     migration_config={"enabled": True, "script_location": "migrations"}
 )
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
                 add_migration_commands(),
-                ["--config", "test_config.get_config", "show-config", "--bind-key", "target_config"],
+                ["--config", f"{module_name}.get_config", "show-config", "--bind-key", "target_config"],
             )

         finally:
@@ -679,12 +691,12 @@ def get_configs():
    return [sqlite_config, duckdb_config, postgres_config]
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             # Test filtering for sqlite_db only
             result = runner.invoke(
                 add_migration_commands(),
-                ["--config", "test_config.get_configs", "show-config", "--bind-key", "sqlite_db"],
+                ["--config", f"{module_name}.get_configs", "show-config", "--bind-key", "sqlite_db"],
             )

         finally:
@@ -718,11 +730,11 @@ def get_configs():
     )
 ]
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
                 add_migration_commands(),
-                ["--config", "test_config.get_configs", "show-config", "--bind-key", "nonexistent"],
+                ["--config", f"{module_name}.get_configs", "show-config", "--bind-key", "nonexistent"],
             )

         finally:
@@ -763,11 +775,18 @@ def get_multi_configs():
     )
 ]
 """
-            Path("test_config.py").write_text(config_module)
+            module_name = _create_module(config_module, Path(temp_dir))

             result = runner.invoke(
                 add_migration_commands(),
-                ["--config", "test_config.get_multi_configs", "upgrade", "--bind-key", "analytics_db", "--no-prompt"],
+                [
+                    "--config",
+                    f"{module_name}.get_multi_configs",
+                    "upgrade",
+                    "--bind-key",
+                    "analytics_db",
+                    "--no-prompt",
+                ],
             )

         finally:
diff --git a/tests/unit/test_core/test_parameters.py b/tests/unit/test_core/test_parameters.py
index a85a5f696..f5bdffad2 100644
--- a/tests/unit/test_core/test_parameters.py
+++ b/tests/unit/test_core/test_parameters.py
@@ -8,14 +8,18 @@
 - Performance and edge cases
 """

+import json
 import math
 from datetime import date, datetime
 from decimal import Decimal
 from typing import Any

 import pytest
+import sqlglot

 from sqlspec.core.parameters import (
+    DRIVER_PARAMETER_PROFILES,
+    DriverParameterProfile,
     ParameterConverter,
     ParameterInfo,
     ParameterProcessor,
@@ -23,9 +27,16 @@
     ParameterStyleConfig,
     ParameterValidator,
     TypedParameter,
+    build_statement_config_from_profile,
+    get_driver_profile,
     is_iterable_parameters,
+    register_driver_profile,
+    replace_null_parameters_with_literals,
+    replace_placeholders_with_literals,
     wrap_with_type,
 )
+from sqlspec.exceptions import ImproperConfigurationError
+from sqlspec.utils.serializers import from_json, to_json

 pytestmark = pytest.mark.xdist_group("core")
@@ -198,7 +209,224 @@ def test_mixed_named_and_numeric_parameters() -> None:

     assert converted_sql == "SELECT $1::text as name, $2::int as age"
     assert converted_params == ("Mixed", 25)
-    assert len(converted_params) == 2
+
+
+def test_build_statement_config_helper_strategy_applies_serializer() -> None:
+    """Helper strategy profiles should inject JSON serializers and tuple adapters."""
+
+    calls: list[Any] = []
+
+    def custom_serializer(value: Any) -> str:
+        calls.append(value)
+        return f"encoded:{value}"
+
+    profile = get_driver_profile("sqlite")
+    config = build_statement_config_from_profile(profile, json_serializer=custom_serializer)
+
+    parameter_config = config.parameter_config
+    assert parameter_config.json_serializer is custom_serializer
+
+    dict_encoder = parameter_config.type_coercion_map[dict]
+    encoded_dict = dict_encoder({"a": 1})
+    assert encoded_dict == "encoded:{'a': 1}"
+
+    tuple_encoder = parameter_config.type_coercion_map[tuple]
+    encoded_tuple = tuple_encoder((1, 2))
+    assert encoded_tuple == "encoded:[1, 2]"
+    assert calls == [{"a": 1}, [1, 2]]
+
+
+def test_build_statement_config_helper_strategy_defaults_to_json() -> None:
+    """Helper strategy should fall back to module JSON helpers when none provided."""
+
+    profile = get_driver_profile("sqlite")
+    config = build_statement_config_from_profile(profile)
+
+    parameter_config = config.parameter_config
+    assert parameter_config.json_serializer is to_json
+
+    dict_encoder = parameter_config.type_coercion_map[dict]
+    encoded_dict = dict_encoder({"a": 1})
+    assert isinstance(encoded_dict, str)
+    assert json.loads(encoded_dict) == {"a": 1}
+
+    tuple_encoder = parameter_config.type_coercion_map[tuple]
+    encoded_tuple = tuple_encoder((1, 2))
+    assert isinstance(encoded_tuple, str)
+    assert json.loads(encoded_tuple) == [1, 2]
+
+
+def test_build_statement_config_driver_strategy_preserves_type_map() -> None:
+    """Driver strategy should leave type coercion map unmodified except JSON slots."""
+
+    def dummy_serializer(value: Any) -> str:
+        return f"json:{value}"
+
+    profile = get_driver_profile("asyncpg")
+    config = build_statement_config_from_profile(profile, json_serializer=dummy_serializer)
+
+    parameter_config = config.parameter_config
+    assert parameter_config.json_serializer is dummy_serializer
+    assert dict not in parameter_config.type_coercion_map
+    assert tuple not in parameter_config.type_coercion_map
+
+
+def test_build_statement_config_driver_strategy_defaults_to_json() -> None:
+    """Driver strategy should wire default JSON helpers when overrides absent."""
+
+    profile = get_driver_profile("asyncpg")
+    config = build_statement_config_from_profile(profile)
+
+    parameter_config = config.parameter_config
+    assert parameter_config.json_serializer is to_json
+    assert parameter_config.json_deserializer is from_json
+    assert parameter_config.type_coercion_map == dict(profile.custom_type_coercions)
+
+
+def test_build_statement_config_helper_tuple_strategy_override() -> None:
+    """Overriding tuple strategy to tuple should preserve tuple payload."""
+
+    captured: list[Any] = []
+
+    def recorder(value: Any) -> str:
+        captured.append(value)
+        return f"encoded:{value}"
+
+    profile = get_driver_profile("sqlite")
+    config = build_statement_config_from_profile(
+        profile, parameter_overrides={"json_tuple_strategy": "tuple"}, json_serializer=recorder
+    )
+
+    tuple_encoder = config.parameter_config.type_coercion_map[tuple]
+    encoded_value = tuple_encoder((1, 2))
+
+    assert encoded_value == "encoded:(1, 2)"
+    assert captured[-1] == (1, 2)
+
+
+def test_replace_null_parameters_with_literals_numeric_dialect() -> None:
+    """Null parameters should render as literals and shrink parameter list."""
+
+    expression = sqlglot.parse_one("INSERT INTO test VALUES ($1, $2)", dialect="postgres")
+    modified_expression, cleaned_params = replace_null_parameters_with_literals(
+        expression, (42, None), dialect="postgres"
+    )
+
+    assert modified_expression.sql(dialect="postgres") == "INSERT INTO test VALUES ($1, NULL)"
+    assert cleaned_params == [42]
+
+
+def test_replace_placeholders_with_literals_basic_sequence() -> None:
+    """Placeholders are replaced by literals when provided with positional parameters."""
+
+    expression = sqlglot.parse_one("SELECT ? AS value", dialect="bigquery")
+    transformed = replace_placeholders_with_literals(expression, [123], json_serializer=json.dumps)
+
+    assert transformed.sql(dialect="bigquery") == "SELECT 123 AS value"
+
+
+def test_replace_placeholders_with_literals_named_mapping() -> None:
+    """Named parameters in mappings are embedded as string literals."""
+
+    expression = sqlglot.parse_one("SELECT @name AS user", dialect="bigquery")
+    transformed = replace_placeholders_with_literals(expression, {"@name": "bob"}, json_serializer=json.dumps)
+
+    assert transformed.sql(dialect="bigquery") == "SELECT 'bob' AS user"
+
+
+def test_build_statement_config_applies_overrides_and_extras() -> None:
+    """Overrides and extras should both be reflected in the resulting config."""
+
+    def uppercase(value: str) -> str:
+        return value.upper()
+
+    overrides = {"type_coercion_map": {str: uppercase}, "supported_parameter_styles": {ParameterStyle.NAMED_AT}}
+
+    profile = get_driver_profile("bigquery")
+    config = build_statement_config_from_profile(profile, parameter_overrides=overrides)
+
+    parameter_config = config.parameter_config
+    assert parameter_config.supported_parameter_styles == {ParameterStyle.NAMED_AT}
+    assert parameter_config.type_coercion_map[str]("value") == "VALUE"
+    assert parameter_config.type_coercion_map[tuple]([1, 2]) == [1, 2]
+
+
+def test_get_driver_profile_missing_raises() -> None:
+    """Unknown adapter keys should raise ImproperConfigurationError."""
+
+    with pytest.raises(ImproperConfigurationError):
+        get_driver_profile("does-not-exist")
+
+
+def test_register_driver_profile_duplicate_guard() -> None:
+    """Registering the same adapter twice without override should fail."""
+
+    profile = DriverParameterProfile(
+        name="TestAdapter",
+        default_style=ParameterStyle.QMARK,
+        supported_styles={ParameterStyle.QMARK},
+        default_execution_style=ParameterStyle.QMARK,
+        supported_execution_styles={ParameterStyle.QMARK},
+        has_native_list_expansion=False,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=False,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=False,
+        json_serializer_strategy="helper",
+    )
+
+    key = "test-duplicate"
+    DRIVER_PARAMETER_PROFILES.pop(key, None)
+    register_driver_profile(key, profile)
+    try:
+        with pytest.raises(ImproperConfigurationError):
+            register_driver_profile(key, profile)
+    finally:
+        DRIVER_PARAMETER_PROFILES.pop(key, None)
+
+
+def test_register_driver_profile_allows_override() -> None:
+    """allow_override should replace an existing driver profile."""
+
+    base_profile = DriverParameterProfile(
+        name="TestAdapter",
+        default_style=ParameterStyle.QMARK,
+        supported_styles={ParameterStyle.QMARK},
+        default_execution_style=ParameterStyle.QMARK,
+        supported_execution_styles={ParameterStyle.QMARK},
+        has_native_list_expansion=False,
+        preserve_parameter_format=True,
+        needs_static_script_compilation=False,
+        allow_mixed_parameter_styles=False,
+        preserve_original_params_for_many=False,
+        json_serializer_strategy="helper",
+    )
+    replacement_profile = DriverParameterProfile(
+        name="TestAdapterReplacement",
+        default_style=ParameterStyle.NUMERIC,
+        supported_styles={ParameterStyle.NUMERIC},
+        default_execution_style=ParameterStyle.NUMERIC,
+        supported_execution_styles={ParameterStyle.NUMERIC},
+        has_native_list_expansion=True,
+        preserve_parameter_format=False,
+        needs_static_script_compilation=True,
+        allow_mixed_parameter_styles=True,
+        preserve_original_params_for_many=True,
+        json_serializer_strategy="driver",
+    )
+
+    key = "test-override"
+    DRIVER_PARAMETER_PROFILES.pop(key, None)
+    register_driver_profile(key, base_profile)
+    try:
+        register_driver_profile(key, replacement_profile, allow_override=True)
+        resolved_profile = get_driver_profile(key)
+
+        assert resolved_profile is replacement_profile
+        assert resolved_profile.default_style is ParameterStyle.NUMERIC
+        assert resolved_profile.has_native_list_expansion is True
+    finally:
+        DRIVER_PARAMETER_PROFILES.pop(key, None)


 def test_mixed_parameter_style_with_processor() -> None:
@@ -354,7 +582,7 @@ def output_transformer(sql: str, params: Any) -> tuple[str, Any]:

     assert config.default_parameter_style == ParameterStyle.NAMED_COLON
     assert config.supported_parameter_styles == {ParameterStyle.NAMED_COLON, ParameterStyle.QMARK}
-    assert config.supported_execution_parameter_styles == {ParameterStyle.QMARK}
+    assert config.supported_execution_parameter_styles == {ParameterStyle.QMARK}  # type: ignore[comparison-overlap]
     assert config.default_execution_parameter_style == ParameterStyle.QMARK
     assert config.type_coercion_map == coercion_map
     assert config.has_native_list_expansion is True
diff --git a/uv.lock b/uv.lock
index b12e3a116..6759942b6 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3969,28 +3969,28 @@ wheels = [

 [[package]]
 name = "psutil"
-version = "7.1.2"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/cd/ec/7b8e6b9b1d22708138630ef34c53ab2b61032c04f16adfdbb96791c8c70c/psutil-7.1.2.tar.gz", hash = "sha256:aa225cdde1335ff9684708ee8c72650f6598d5ed2114b9a7c5802030b1785018", size = 487424, upload-time = "2025-10-25T10:46:34.931Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/b8/d9/b56cc9f883140ac10021a8c9b0f4e16eed1ba675c22513cdcbce3ba64014/psutil-7.1.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0cc5c6889b9871f231ed5455a9a02149e388fffcb30b607fb7a8896a6d95f22e", size = 238575, upload-time = "2025-10-25T10:46:38.728Z" },
-    { url = "https://files.pythonhosted.org/packages/36/eb/28d22de383888deb252c818622196e709da98816e296ef95afda33f1c0a2/psutil-7.1.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8e9e77a977208d84aa363a4a12e0f72189d58bbf4e46b49aae29a2c6e93ef206", size = 239297, upload-time = "2025-10-25T10:46:41.347Z" },
-    { url = "https://files.pythonhosted.org/packages/89/5d/220039e2f28cc129626e54d63892ab05c0d56a29818bfe7268dcb5008932/psutil-7.1.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7d9623a5e4164d2220ecceb071f4b333b3c78866141e8887c072129185f41278", size = 280420, upload-time = "2025-10-25T10:46:44.122Z" },
-    { url = "https://files.pythonhosted.org/packages/ba/7a/286f0e1c167445b2ef4a6cbdfc8c59fdb45a5a493788950cf8467201dc73/psutil-7.1.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:364b1c10fe4ed59c89ec49e5f1a70da353b27986fa8233b4b999df4742a5ee2f", size = 283049, upload-time = "2025-10-25T10:46:47.095Z" },
-    { url = "https://files.pythonhosted.org/packages/aa/cc/7eb93260794a42e39b976f3a4dde89725800b9f573b014fac142002a5c98/psutil-7.1.2-cp313-cp313t-win_amd64.whl", hash = "sha256:f101ef84de7e05d41310e3ccbdd65a6dd1d9eed85e8aaf0758405d022308e204", size = 248713, upload-time = "2025-10-25T10:46:49.573Z" },
-    { url = "https://files.pythonhosted.org/packages/ab/1a/0681a92b53366e01f0a099f5237d0c8a2f79d322ac589cccde5e30c8a4e2/psutil-7.1.2-cp313-cp313t-win_arm64.whl", hash = "sha256:20c00824048a95de67f00afedc7b08b282aa08638585b0206a9fb51f28f1a165", size = 244644, upload-time = "2025-10-25T10:46:51.924Z" },
-    { url = "https://files.pythonhosted.org/packages/56/9e/f1c5c746b4ed5320952acd3002d3962fe36f30524c00ea79fdf954cc6779/psutil-7.1.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:e09cfe92aa8e22b1ec5e2d394820cf86c5dff6367ac3242366485dfa874d43bc", size = 238640, upload-time = "2025-10-25T10:46:54.089Z" },
-    { url = "https://files.pythonhosted.org/packages/32/ee/fd26216a735395cc25c3899634e34aeb41fb1f3dbb44acc67d9e594be562/psutil-7.1.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:fa6342cf859c48b19df3e4aa170e4cfb64aadc50b11e06bb569c6c777b089c9e", size = 239303, upload-time = "2025-10-25T10:46:56.932Z" },
-    { url = "https://files.pythonhosted.org/packages/3c/cd/7d96eaec4ef7742b845a9ce2759a2769ecce4ab7a99133da24abacbc9e41/psutil-7.1.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:625977443498ee7d6c1e63e93bacca893fd759a66c5f635d05e05811d23fb5ee", size = 281717, upload-time = "2025-10-25T10:46:59.116Z" },
-    { url = "https://files.pythonhosted.org/packages/bc/1a/7f0b84bdb067d35fe7fade5fff888408688caf989806ce2d6dae08c72dd5/psutil-7.1.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a24bcd7b7f2918d934af0fb91859f621b873d6aa81267575e3655cd387572a7", size = 284575, upload-time = "2025-10-25T10:47:00.944Z" },
-    { url = "https://files.pythonhosted.org/packages/de/05/7820ef8f7b275268917e0c750eada5834581206d9024ca88edce93c4b762/psutil-7.1.2-cp314-cp314t-win_amd64.whl", hash = "sha256:329f05610da6380982e6078b9d0881d9ab1e9a7eb7c02d833bfb7340aa634e31", size = 249491, upload-time = "2025-10-25T10:47:03.174Z" },
-    { url = "https://files.pythonhosted.org/packages/db/9a/58de399c7cb58489f08498459ff096cd76b3f1ddc4f224ec2c5ef729c7d0/psutil-7.1.2-cp314-cp314t-win_arm64.whl", hash = "sha256:7b04c29e3c0c888e83ed4762b70f31e65c42673ea956cefa8ced0e31e185f582", size = 244880, upload-time = "2025-10-25T10:47:05.228Z" },
-    { url = "https://files.pythonhosted.org/packages/ae/89/b9f8d47ddbc52d7301fc868e8224e5f44ed3c7f55e6d0f54ecaf5dd9ff5e/psutil-7.1.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c9ba5c19f2d46203ee8c152c7b01df6eec87d883cfd8ee1af2ef2727f6b0f814", size = 237244, upload-time = "2025-10-25T10:47:07.086Z" },
-    { url = "https://files.pythonhosted.org/packages/c8/7a/8628c2f6b240680a67d73d8742bb9ff39b1820a693740e43096d5dcb01e5/psutil-7.1.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:2a486030d2fe81bec023f703d3d155f4823a10a47c36784c84f1cc7f8d39bedb", size = 238101, upload-time = "2025-10-25T10:47:09.523Z" },
-    { url = "https://files.pythonhosted.org/packages/30/28/5e27f4d5a0e347f8e3cc16cd7d35533dbce086c95807f1f0e9cd77e26c10/psutil-7.1.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3efd8fc791492e7808a51cb2b94889db7578bfaea22df931424f874468e389e3", size = 258675, upload-time = "2025-10-25T10:47:11.082Z" },
-    { url = "https://files.pythonhosted.org/packages/e5/5c/79cf60c9acf36d087f0db0f82066fca4a780e97e5b3a2e4c38209c03d170/psutil-7.1.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2aeb9b64f481b8eabfc633bd39e0016d4d8bbcd590d984af764d80bf0851b8a", size = 260203, upload-time = "2025-10-25T10:47:13.226Z" },
-    { url = "https://files.pythonhosted.org/packages/f7/03/0a464404c51685dcb9329fdd660b1721e076ccd7b3d97dee066bcc9ffb15/psutil-7.1.2-cp37-abi3-win_amd64.whl", hash = "sha256:8e17852114c4e7996fe9da4745c2bdef001ebbf2f260dec406290e66628bdb91", size = 246714, upload-time = "2025-10-25T10:47:15.093Z" },
-    { url = "https://files.pythonhosted.org/packages/6a/32/97ca2090f2f1b45b01b6aa7ae161cfe50671de097311975ca6eea3e7aabc/psutil-7.1.2-cp37-abi3-win_arm64.whl", hash = "sha256:3e988455e61c240cc879cb62a008c2699231bf3e3d061d7fce4234463fd2abb4", size = 243742, upload-time = "2025-10-25T10:47:17.302Z" },
+version = "7.1.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e1/88/bdd0a41e5857d5d703287598cbf08dad90aed56774ea52ae071bae9071b6/psutil-7.1.3.tar.gz", hash = "sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74", size = 489059, upload-time = "2025-11-02T12:25:54.619Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/bd/93/0c49e776b8734fef56ec9c5c57f923922f2cf0497d62e0f419465f28f3d0/psutil-7.1.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc", size = 239751, upload-time = "2025-11-02T12:25:58.161Z" },
+    { url =
"https://files.pythonhosted.org/packages/6f/8d/b31e39c769e70780f007969815195a55c81a63efebdd4dbe9e7a113adb2f/psutil-7.1.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0", size = 240368, upload-time = "2025-11-02T12:26:00.491Z" }, + { url = "https://files.pythonhosted.org/packages/62/61/23fd4acc3c9eebbf6b6c78bcd89e5d020cfde4acf0a9233e9d4e3fa698b4/psutil-7.1.3-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7", size = 287134, upload-time = "2025-11-02T12:26:02.613Z" }, + { url = "https://files.pythonhosted.org/packages/30/1c/f921a009ea9ceb51aa355cb0cc118f68d354db36eae18174bab63affb3e6/psutil-7.1.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251", size = 289904, upload-time = "2025-11-02T12:26:05.207Z" }, + { url = "https://files.pythonhosted.org/packages/a6/82/62d68066e13e46a5116df187d319d1724b3f437ddd0f958756fc052677f4/psutil-7.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa", size = 249642, upload-time = "2025-11-02T12:26:07.447Z" }, + { url = "https://files.pythonhosted.org/packages/df/ad/c1cd5fe965c14a0392112f68362cfceb5230819dbb5b1888950d18a11d9f/psutil-7.1.3-cp313-cp313t-win_arm64.whl", hash = "sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee", size = 245518, upload-time = "2025-11-02T12:26:09.719Z" }, + { url = "https://files.pythonhosted.org/packages/2e/bb/6670bded3e3236eb4287c7bcdc167e9fae6e1e9286e437f7111caed2f909/psutil-7.1.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353", size = 239843, upload-time = "2025-11-02T12:26:11.968Z" }, + { url = 
"https://files.pythonhosted.org/packages/b8/66/853d50e75a38c9a7370ddbeefabdd3d3116b9c31ef94dc92c6729bc36bec/psutil-7.1.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b", size = 240369, upload-time = "2025-11-02T12:26:14.358Z" }, + { url = "https://files.pythonhosted.org/packages/41/bd/313aba97cb5bfb26916dc29cf0646cbe4dd6a89ca69e8c6edce654876d39/psutil-7.1.3-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9", size = 288210, upload-time = "2025-11-02T12:26:16.699Z" }, + { url = "https://files.pythonhosted.org/packages/c2/fa/76e3c06e760927a0cfb5705eb38164254de34e9bd86db656d4dbaa228b04/psutil-7.1.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f", size = 291182, upload-time = "2025-11-02T12:26:18.848Z" }, + { url = "https://files.pythonhosted.org/packages/0f/1d/5774a91607035ee5078b8fd747686ebec28a962f178712de100d00b78a32/psutil-7.1.3-cp314-cp314t-win_amd64.whl", hash = "sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7", size = 250466, upload-time = "2025-11-02T12:26:21.183Z" }, + { url = "https://files.pythonhosted.org/packages/00/ca/e426584bacb43a5cb1ac91fae1937f478cd8fbe5e4ff96574e698a2c77cd/psutil-7.1.3-cp314-cp314t-win_arm64.whl", hash = "sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264", size = 245756, upload-time = "2025-11-02T12:26:23.148Z" }, + { url = "https://files.pythonhosted.org/packages/ef/94/46b9154a800253e7ecff5aaacdf8ebf43db99de4a2dfa18575b02548654e/psutil-7.1.3-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab", size = 238359, upload-time = "2025-11-02T12:26:25.284Z" }, + { url = 
"https://files.pythonhosted.org/packages/68/3a/9f93cff5c025029a36d9a92fef47220ab4692ee7f2be0fba9f92813d0cb8/psutil-7.1.3-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880", size = 239171, upload-time = "2025-11-02T12:26:27.23Z" }, + { url = "https://files.pythonhosted.org/packages/ce/b1/5f49af514f76431ba4eea935b8ad3725cdeb397e9245ab919dbc1d1dc20f/psutil-7.1.3-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3", size = 263261, upload-time = "2025-11-02T12:26:29.48Z" }, + { url = "https://files.pythonhosted.org/packages/e0/95/992c8816a74016eb095e73585d747e0a8ea21a061ed3689474fabb29a395/psutil-7.1.3-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b", size = 264635, upload-time = "2025-11-02T12:26:31.74Z" }, + { url = "https://files.pythonhosted.org/packages/55/4c/c3ed1a622b6ae2fd3c945a366e64eb35247a31e4db16cf5095e269e8eb3c/psutil-7.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd", size = 247633, upload-time = "2025-11-02T12:26:33.887Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ad/33b2ccec09bf96c2b2ef3f9a6f66baac8253d7565d8839e024a6b905d45d/psutil-7.1.3-cp37-abi3-win_arm64.whl", hash = "sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1", size = 244608, upload-time = "2025-11-02T12:26:36.136Z" }, ] [[package]] @@ -6265,15 +6265,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.49.1" +version = "0.49.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/1b/3f/507c21db33b66fb027a332f2cb3abbbe924cc3a79ced12f01ed8645955c9/starlette-0.49.1.tar.gz", hash = "sha256:481a43b71e24ed8c43b11ea02f5353d77840e01480881b8cb5a26b8cae64a8cb", size = 2654703, upload-time = "2025-10-28T17:34:10.928Z" } +sdist = { url = "https://files.pythonhosted.org/packages/de/1a/608df0b10b53b0beb96a37854ee05864d182ddd4b1156a22f1ad3860425a/starlette-0.49.3.tar.gz", hash = "sha256:1c14546f299b5901a1ea0e34410575bc33bbd741377a10484a54445588d00284", size = 2655031, upload-time = "2025-11-01T15:12:26.13Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" }, + { url = "https://files.pythonhosted.org/packages/a3/e0/021c772d6a662f43b63044ab481dc6ac7592447605b5b35a957785363122/starlette-0.49.3-py3-none-any.whl", hash = "sha256:b579b99715fdc2980cf88c8ec96d3bf1ce16f5a8051a7c2b84ef9b1cdecaea2f", size = 74340, upload-time = "2025-11-01T15:12:24.387Z" }, ] [[package]]