From 1929a84229fc175706f70b50b065bb84a7b7422b Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 10 Nov 2025 17:58:36 +0000 Subject: [PATCH 1/3] feat(duckdb): add support for community extension flags in DuckDB configuration --- docs/reference/adapters.rst | 9 +++++ sqlspec/adapters/duckdb/config.py | 40 ++++++++++++++++++--- sqlspec/adapters/duckdb/pool.py | 33 +++++++++++++++++ tests/unit/adapters/test_duckdb_config.py | 44 +++++++++++++++++++++++ 4 files changed, 121 insertions(+), 5 deletions(-) create mode 100644 tests/unit/adapters/test_duckdb_config.py diff --git a/docs/reference/adapters.rst b/docs/reference/adapters.rst index 9148017ce..2328dad84 100644 --- a/docs/reference/adapters.rst +++ b/docs/reference/adapters.rst @@ -570,6 +570,15 @@ duckdb "SELECT * FROM 'https://example.com/data.parquet' LIMIT 10" ) +**Community Extensions**: + +DuckDBConfig accepts the runtime flags DuckDB expects for community/unsigned extensions via +``pool_config`` (for example ``allow_community_extensions=True``, +``allow_unsigned_extensions=True``, ``enable_external_access=True``). SQLSpec applies those +options with ``SET`` statements immediately after establishing each connection, so even older +DuckDB builds that do not recognize the options during ``duckdb.connect()`` will still enable the +required permissions before extensions are installed. + **API Reference**: .. autoclass:: DuckDBConfig diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index d5472ddef..c9c682b9a 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -22,7 +22,6 @@ from collections.abc import Callable, Generator from sqlspec.core import StatementConfig - __all__ = ( "DuckDBConfig", "DuckDBConnectionParams", @@ -31,10 +30,21 @@ "DuckDBPoolParams", "DuckDBSecretConfig", ) +EXTENSION_FLAG_KEYS: "tuple[str, ...]" = ( + "allow_community_extensions", + "allow_unsigned_extensions", + "enable_external_access", +) class DuckDBConnectionParams(TypedDict): - """DuckDB connection parameters.""" + """DuckDB connection parameters. + + Mirrors the keyword arguments accepted by duckdb.connect so callers can drive every DuckDB + configuration switch directly through SQLSpec. All keys are optional and forwarded verbatim + to DuckDB, either as top-level parameters or via the nested ``config`` dictionary when DuckDB + expects them there. + """ database: NotRequired[str] read_only: NotRequired[bool] @@ -75,7 +85,8 @@ class DuckDBConnectionParams(TypedDict): class DuckDBPoolParams(DuckDBConnectionParams): """Complete pool configuration for DuckDB adapter. - Combines standardized pool parameters with DuckDB-specific connection parameters. + Extends DuckDBConnectionParams with pool sizing and lifecycle settings so SQLSpec can manage + per-thread DuckDB connections safely while honoring DuckDB's thread-safety constraints. """ pool_min_size: NotRequired[int] @@ -128,6 +139,8 @@ class DuckDBDriverFeatures(TypedDict): enable_uuid_conversion: Enable automatic UUID string conversion. When True (default), UUID strings are automatically converted to UUID objects. When False, UUID strings are treated as regular strings. + extension_flags: Connection-level flags (e.g., allow_community_extensions) applied + via SET statements immediately after connection creation. """ extensions: NotRequired[Sequence[DuckDBExtensionConfig]] @@ -135,6 +148,7 @@ class DuckDBDriverFeatures(TypedDict): on_connection_create: NotRequired["Callable[[DuckDBConnection], DuckDBConnection | None]"] json_serializer: NotRequired["Callable[[Any], str]"] enable_uuid_conversion: NotRequired[bool] + extension_flags: NotRequired[dict[str, Any]] class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, DuckDBDriver]): @@ -222,13 +236,23 @@ def __init__( if pool_config.get("database") in {":memory:", ""}: pool_config["database"] = ":memory:shared_db" - processed_features = dict(driver_features) if driver_features else {} + extension_flags: dict[str, Any] = {} + for key in tuple(pool_config.keys()): + if key in EXTENSION_FLAG_KEYS: + extension_flags[key] = pool_config.pop(key) + + processed_features: dict[str, Any] = dict(driver_features) if driver_features else {} user_connection_hook = cast( "Callable[[Any], None] | None", processed_features.pop("on_connection_create", None) ) processed_features.setdefault("enable_uuid_conversion", True) serializer = processed_features.setdefault("json_serializer", to_json) + if extension_flags: + existing_flags = cast("dict[str, Any]", processed_features.get("extension_flags", {})) + merged_flags = {**existing_flags, **extension_flags} + processed_features["extension_flags"] = merged_flags + local_observability = observability_config if user_connection_hook is not None: @@ -271,11 +295,17 @@ def _create_pool(self) -> DuckDBConnectionPool: extensions = self.driver_features.get("extensions", None) secrets = self.driver_features.get("secrets", None) + extension_flags = self.driver_features.get("extension_flags", None) extensions_dicts = [dict(ext) for ext in extensions] if extensions else None secrets_dicts = [dict(secret) for secret in secrets] if secrets else None + extension_flags_dict = dict(extension_flags) if extension_flags else None return DuckDBConnectionPool( - connection_config=connection_config, extensions=extensions_dicts, secrets=secrets_dicts, **self.pool_config + connection_config=connection_config, + extensions=extensions_dicts, + extension_flags=extension_flags_dict, + secrets=secrets_dicts, + **self.pool_config, ) def _close_pool(self) -> None: diff --git a/sqlspec/adapters/duckdb/pool.py b/sqlspec/adapters/duckdb/pool.py index 6926c38a7..5a40c874e 100644 --- a/sqlspec/adapters/duckdb/pool.py +++ b/sqlspec/adapters/duckdb/pool.py @@ -39,6 +39,7 @@ class DuckDBConnectionPool: "_connection_config", "_connection_times", "_created_connections", + "_extension_flags", "_extensions", "_lock", "_on_connection_create", @@ -52,6 +53,7 @@ def __init__( connection_config: "dict[str, Any]", pool_recycle_seconds: int = POOL_RECYCLE, extensions: "list[dict[str, Any]] | None" = None, + extension_flags: "dict[str, Any] | None" = None, secrets: "list[dict[str, Any]] | None" = None, on_connection_create: "Callable[[DuckDBConnection], None] | None" = None, **kwargs: Any, @@ -62,6 +64,7 @@ def __init__( connection_config: DuckDB connection configuration pool_recycle_seconds: Connection recycle time in seconds extensions: List of extensions to install/load + extension_flags: Connection-level SET statements applied after creation secrets: List of secrets to create on_connection_create: Callback executed when connection is created **kwargs: Additional parameters ignored for compatibility @@ -69,6 +72,7 @@ def __init__( self._connection_config = connection_config self._recycle = pool_recycle_seconds self._extensions = extensions or [] + self._extension_flags = extension_flags or {} self._secrets = secrets or [] self._on_connection_create = on_connection_create self._thread_local = threading.local() @@ -92,6 +96,8 @@ def _create_connection(self) -> DuckDBConnection: connection = duckdb.connect(**connect_parameters) + self._apply_extension_flags(connection) + for ext_config in self._extensions: ext_name = ext_config.get("name") if not ext_name: @@ -149,6 +155,33 @@ def _create_connection(self) -> DuckDBConnection: return connection + def _apply_extension_flags(self, connection: DuckDBConnection) -> None: + """Apply connection-level extension flags via SET statements.""" + + if not self._extension_flags: + return + + for key, value in self._extension_flags.items(): + if not key or not key.replace("_", "").isalnum(): + continue + + normalized = self._normalize_flag_value(value) + try: + connection.execute(f"SET {key} = {normalized}") + except Exception as exc: # pragma: no cover - best-effort guard + logger.debug("Failed to set DuckDB flag %s: %s", key, exc) + + @staticmethod + def _normalize_flag_value(value: Any) -> str: + """Convert Python value to DuckDB SET literal.""" + + if isinstance(value, bool): + return "TRUE" if value else "FALSE" + if isinstance(value, (int, float)): + return str(value) + escaped = str(value).replace("'", "''") + return f"'{escaped}'" + def _get_thread_connection(self) -> DuckDBConnection: """Get or create a connection for the current thread. diff --git a/tests/unit/adapters/test_duckdb_config.py b/tests/unit/adapters/test_duckdb_config.py new file mode 100644 index 000000000..1ecf06ac6 --- /dev/null +++ b/tests/unit/adapters/test_duckdb_config.py @@ -0,0 +1,44 @@ +import pytest + +pytest.importorskip("duckdb", reason="DuckDB adapter requires duckdb package") + +from sqlspec.adapters.duckdb import DuckDBConfig + + +def test_duckdb_config_promotes_security_flags() -> None: + config = DuckDBConfig( + pool_config={ + "database": ":memory:", + "allow_community_extensions": True, + "allow_unsigned_extensions": False, + "enable_external_access": True, + } + ) + + flags = config.driver_features.get("extension_flags") + assert flags == { + "allow_community_extensions": True, + "allow_unsigned_extensions": False, + "enable_external_access": True, + } + assert "allow_community_extensions" not in config.pool_config + assert "allow_unsigned_extensions" not in config.pool_config + assert "enable_external_access" not in config.pool_config + + +def test_duckdb_config_merges_existing_extension_flags() -> None: + config = DuckDBConfig( + pool_config={ + "database": ":memory:", + "allow_community_extensions": True, + }, + driver_features={ + "extension_flags": {"custom": "value"}, + }, + ) + + flags = config.driver_features.get("extension_flags") + assert flags == { + "custom": "value", + "allow_community_extensions": True, + } From da7b7158b2b371bd9d18a039e6404a62077ed765 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 10 Nov 2025 17:59:54 +0000 Subject: [PATCH 2/3] fix: suppress type checking error for extension flag key pop operation --- sqlspec/adapters/duckdb/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index c9c682b9a..59c1ad492 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -239,7 +239,7 @@ def __init__( extension_flags: dict[str, Any] = {} for key in tuple(pool_config.keys()): if key in EXTENSION_FLAG_KEYS: - extension_flags[key] = pool_config.pop(key) + extension_flags[key] = pool_config.pop(key) # type: ignore[misc] processed_features: dict[str, Any] = dict(driver_features) if driver_features else {} user_connection_hook = cast( From ac7c850502214bedbea1eeed07da1e91232126d0 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 10 Nov 2025 18:05:47 +0000 Subject: [PATCH 3/3] refactor(tests): simplify test_duckdb_config_merges_existing_extension_flags assertions --- tests/unit/adapters/test_duckdb_config.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/unit/adapters/test_duckdb_config.py b/tests/unit/adapters/test_duckdb_config.py index 1ecf06ac6..5ebc9c34c 100644 --- a/tests/unit/adapters/test_duckdb_config.py +++ b/tests/unit/adapters/test_duckdb_config.py @@ -28,17 +28,9 @@ def test_duckdb_config_promotes_security_flags() -> None: def test_duckdb_config_merges_existing_extension_flags() -> None: config = DuckDBConfig( - pool_config={ - "database": ":memory:", - "allow_community_extensions": True, - }, - driver_features={ - "extension_flags": {"custom": "value"}, - }, + pool_config={"database": ":memory:", "allow_community_extensions": True}, + driver_features={"extension_flags": {"custom": "value"}}, ) flags = config.driver_features.get("extension_flags") - assert flags == { - "custom": "value", - "allow_community_extensions": True, - } + assert flags == {"custom": "value", "allow_community_extensions": True}