Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/reference/adapters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,15 @@ duckdb
"SELECT * FROM 'https://example.com/data.parquet' LIMIT 10"
)

**Community Extensions**:

DuckDBConfig accepts the runtime flags DuckDB expects for community/unsigned extensions via
``pool_config`` (for example ``allow_community_extensions=True``,
``allow_unsigned_extensions=True``, ``enable_external_access=True``). SQLSpec applies those
options with ``SET`` statements immediately after establishing each connection, so even older
DuckDB builds that do not recognize the options during ``duckdb.connect()`` will still enable the
required permissions before extensions are installed.

**API Reference**:

.. autoclass:: DuckDBConfig
Expand Down
40 changes: 35 additions & 5 deletions sqlspec/adapters/duckdb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from collections.abc import Callable, Generator

from sqlspec.core import StatementConfig

__all__ = (
"DuckDBConfig",
"DuckDBConnectionParams",
Expand All @@ -31,10 +30,21 @@
"DuckDBPoolParams",
"DuckDBSecretConfig",
)
EXTENSION_FLAG_KEYS: "tuple[str, ...]" = (
"allow_community_extensions",
"allow_unsigned_extensions",
"enable_external_access",
)


class DuckDBConnectionParams(TypedDict):
"""DuckDB connection parameters."""
"""DuckDB connection parameters.

Mirrors the keyword arguments accepted by duckdb.connect so callers can drive every DuckDB
configuration switch directly through SQLSpec. All keys are optional and forwarded verbatim
to DuckDB, either as top-level parameters or via the nested ``config`` dictionary when DuckDB
expects them there.
"""

database: NotRequired[str]
read_only: NotRequired[bool]
Expand Down Expand Up @@ -75,7 +85,8 @@ class DuckDBConnectionParams(TypedDict):
class DuckDBPoolParams(DuckDBConnectionParams):
"""Complete pool configuration for DuckDB adapter.

Combines standardized pool parameters with DuckDB-specific connection parameters.
Extends DuckDBConnectionParams with pool sizing and lifecycle settings so SQLSpec can manage
per-thread DuckDB connections safely while honoring DuckDB's thread-safety constraints.
"""

pool_min_size: NotRequired[int]
Expand Down Expand Up @@ -128,13 +139,16 @@ class DuckDBDriverFeatures(TypedDict):
enable_uuid_conversion: Enable automatic UUID string conversion.
When True (default), UUID strings are automatically converted to UUID objects.
When False, UUID strings are treated as regular strings.
extension_flags: Connection-level flags (e.g., allow_community_extensions) applied
via SET statements immediately after connection creation.
"""

extensions: NotRequired[Sequence[DuckDBExtensionConfig]]
secrets: NotRequired[Sequence[DuckDBSecretConfig]]
on_connection_create: NotRequired["Callable[[DuckDBConnection], DuckDBConnection | None]"]
json_serializer: NotRequired["Callable[[Any], str]"]
enable_uuid_conversion: NotRequired[bool]
extension_flags: NotRequired[dict[str, Any]]


class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, DuckDBDriver]):
Expand Down Expand Up @@ -222,13 +236,23 @@ def __init__(
if pool_config.get("database") in {":memory:", ""}:
pool_config["database"] = ":memory:shared_db"

processed_features = dict(driver_features) if driver_features else {}
extension_flags: dict[str, Any] = {}
for key in tuple(pool_config.keys()):
if key in EXTENSION_FLAG_KEYS:
extension_flags[key] = pool_config.pop(key) # type: ignore[misc]

processed_features: dict[str, Any] = dict(driver_features) if driver_features else {}
user_connection_hook = cast(
"Callable[[Any], None] | None", processed_features.pop("on_connection_create", None)
)
processed_features.setdefault("enable_uuid_conversion", True)
serializer = processed_features.setdefault("json_serializer", to_json)

if extension_flags:
existing_flags = cast("dict[str, Any]", processed_features.get("extension_flags", {}))
merged_flags = {**existing_flags, **extension_flags}
processed_features["extension_flags"] = merged_flags

local_observability = observability_config
if user_connection_hook is not None:

Expand Down Expand Up @@ -271,11 +295,17 @@ def _create_pool(self) -> DuckDBConnectionPool:

extensions = self.driver_features.get("extensions", None)
secrets = self.driver_features.get("secrets", None)
extension_flags = self.driver_features.get("extension_flags", None)
extensions_dicts = [dict(ext) for ext in extensions] if extensions else None
secrets_dicts = [dict(secret) for secret in secrets] if secrets else None
extension_flags_dict = dict(extension_flags) if extension_flags else None

return DuckDBConnectionPool(
connection_config=connection_config, extensions=extensions_dicts, secrets=secrets_dicts, **self.pool_config
connection_config=connection_config,
extensions=extensions_dicts,
extension_flags=extension_flags_dict,
secrets=secrets_dicts,
**self.pool_config,
)

def _close_pool(self) -> None:
Expand Down
33 changes: 33 additions & 0 deletions sqlspec/adapters/duckdb/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class DuckDBConnectionPool:
"_connection_config",
"_connection_times",
"_created_connections",
"_extension_flags",
"_extensions",
"_lock",
"_on_connection_create",
Expand All @@ -52,6 +53,7 @@ def __init__(
connection_config: "dict[str, Any]",
pool_recycle_seconds: int = POOL_RECYCLE,
extensions: "list[dict[str, Any]] | None" = None,
extension_flags: "dict[str, Any] | None" = None,
secrets: "list[dict[str, Any]] | None" = None,
on_connection_create: "Callable[[DuckDBConnection], None] | None" = None,
**kwargs: Any,
Expand All @@ -62,13 +64,15 @@ def __init__(
connection_config: DuckDB connection configuration
pool_recycle_seconds: Connection recycle time in seconds
extensions: List of extensions to install/load
extension_flags: Connection-level SET statements applied after creation
secrets: List of secrets to create
on_connection_create: Callback executed when connection is created
**kwargs: Additional parameters ignored for compatibility
"""
self._connection_config = connection_config
self._recycle = pool_recycle_seconds
self._extensions = extensions or []
self._extension_flags = extension_flags or {}
self._secrets = secrets or []
self._on_connection_create = on_connection_create
self._thread_local = threading.local()
Expand All @@ -92,6 +96,8 @@ def _create_connection(self) -> DuckDBConnection:

connection = duckdb.connect(**connect_parameters)

self._apply_extension_flags(connection)

for ext_config in self._extensions:
ext_name = ext_config.get("name")
if not ext_name:
Expand Down Expand Up @@ -149,6 +155,33 @@ def _create_connection(self) -> DuckDBConnection:

return connection

def _apply_extension_flags(self, connection: DuckDBConnection) -> None:
"""Apply connection-level extension flags via SET statements."""

if not self._extension_flags:
return

for key, value in self._extension_flags.items():
if not key or not key.replace("_", "").isalnum():
continue

normalized = self._normalize_flag_value(value)
try:
connection.execute(f"SET {key} = {normalized}")
except Exception as exc: # pragma: no cover - best-effort guard
logger.debug("Failed to set DuckDB flag %s: %s", key, exc)

@staticmethod
def _normalize_flag_value(value: Any) -> str:
"""Convert Python value to DuckDB SET literal."""

if isinstance(value, bool):
return "TRUE" if value else "FALSE"
if isinstance(value, (int, float)):
return str(value)
escaped = str(value).replace("'", "''")
return f"'{escaped}'"

def _get_thread_connection(self) -> DuckDBConnection:
"""Get or create a connection for the current thread.

Expand Down
36 changes: 36 additions & 0 deletions tests/unit/adapters/test_duckdb_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

pytest.importorskip("duckdb", reason="DuckDB adapter requires duckdb package")

from sqlspec.adapters.duckdb import DuckDBConfig


def test_duckdb_config_promotes_security_flags() -> None:
config = DuckDBConfig(
pool_config={
"database": ":memory:",
"allow_community_extensions": True,
"allow_unsigned_extensions": False,
"enable_external_access": True,
}
)

flags = config.driver_features.get("extension_flags")
assert flags == {
"allow_community_extensions": True,
"allow_unsigned_extensions": False,
"enable_external_access": True,
}
assert "allow_community_extensions" not in config.pool_config
assert "allow_unsigned_extensions" not in config.pool_config
assert "enable_external_access" not in config.pool_config


def test_duckdb_config_merges_existing_extension_flags() -> None:
config = DuckDBConfig(
pool_config={"database": ":memory:", "allow_community_extensions": True},
driver_features={"extension_flags": {"custom": "value"}},
)

flags = config.driver_features.get("extension_flags")
assert flags == {"custom": "value", "allow_community_extensions": True}