Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2616,6 +2616,138 @@ class StarletteConfig(TypedDict):
extra_rollback_statuses: NotRequired[set[int]]
```

### Disabling Built-in Dependency Injection (disable_di Pattern)

**When to Use**: When users want to integrate SQLSpec with their own dependency injection solution (e.g., Dishka, dependency-injector) and need full control over database lifecycle management.

**Pattern**: Add a `disable_di` boolean flag to framework extension configuration that conditionally skips the built-in DI setup.

**Implementation Steps**:

1. **Add to TypedDict in `sqlspec/config.py`**:

```python
class StarletteConfig(TypedDict):
# ... existing fields ...

disable_di: NotRequired[bool]
"""Disable built-in dependency injection. Default: False.
When True, the Starlette/FastAPI extension will not add middleware for managing
database connections and sessions. Users are responsible for managing the
database lifecycle manually via their own DI solution.
"""
```

2. **Add to Configuration State Dataclass**:

```python
@dataclass
class SQLSpecConfigState:
config: "DatabaseConfigProtocol[Any, Any, Any]"
connection_key: str
pool_key: str
session_key: str
commit_mode: CommitMode
extra_commit_statuses: "set[int] | None"
extra_rollback_statuses: "set[int] | None"
disable_di: bool # Add this field
```

3. **Extract from Config and Default to False**:

```python
def _extract_starlette_settings(self, config):
starlette_config = config.extension_config.get("starlette", {})
return {
# ... existing keys ...
"disable_di": starlette_config.get("disable_di", False), # Default False
}
```

4. **Conditionally Skip DI Setup**:

**Middleware-based (Starlette/FastAPI)**:
```python
def init_app(self, app):
# ... lifespan setup ...

for config_state in self._config_states:
if not config_state.disable_di: # Only add if DI enabled
self._add_middleware(app, config_state)
```

**Provider-based (Litestar)**:
```python
def on_app_init(self, app_config):
for state in self._plugin_configs:
# ... signature namespace ...

if not state.disable_di: # Only register if DI enabled
app_config.before_send.append(state.before_send_handler)
app_config.lifespan.append(state.lifespan_handler)
app_config.dependencies.update({
state.connection_key: Provide(state.connection_provider),
state.pool_key: Provide(state.pool_provider),
state.session_key: Provide(state.session_provider),
})
```

**Hook-based (Flask)**:
```python
def init_app(self, app):
# ... pool setup ...

# Only register hooks if at least one config has DI enabled
if any(not state.disable_di for state in self._config_states):
app.before_request(self._before_request_handler)
app.after_request(self._after_request_handler)
app.teardown_appcontext(self._teardown_appcontext_handler)

def _before_request_handler(self):
for config_state in self._config_states:
if config_state.disable_di: # Skip if DI disabled
continue
# ... connection setup ...
```

**Testing Requirements**:

1. **Test with `disable_di=True`**: Verify DI mechanisms are not active
2. **Test default behavior**: Verify `disable_di=False` preserves existing functionality
3. **Integration tests**: Demonstrate manual DI setup works correctly

**Example Usage**:

```python
from sqlspec.adapters.asyncpg import AsyncpgConfig
from sqlspec.base import SQLSpec
from sqlspec.extensions.starlette import SQLSpecPlugin

sql = SQLSpec()
config = AsyncpgConfig(
pool_config={"dsn": "postgresql://localhost/db"},
extension_config={"starlette": {"disable_di": True}} # Disable built-in DI
)
sql.add_config(config)
plugin = SQLSpecPlugin(sql)

# User is now responsible for manual lifecycle management
async def my_route(request):
pool = await config.create_pool()
async with config.provide_connection(pool) as connection:
session = config.driver_type(connection=connection, statement_config=config.statement_config)
result = await session.execute("SELECT 1")
await config.close_pool()
return result
```

**Key Principles**:

- **Backward Compatible**: Default `False` preserves existing behavior
- **Consistent Naming**: Use `disable_di` across all frameworks
- **Clear Documentation**: Warn users they are responsible for lifecycle management
- **Complete Control**: When disabled, extension does zero automatic DI

### Multi-Database Support

**Key validation ensures unique state keys**:
Expand Down
21 changes: 21 additions & 0 deletions sqlspec/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ class FlaskConfig(TypedDict):
extra_rollback_statuses: NotRequired[set[int]]
"""Additional HTTP status codes that trigger rollback. Default: None."""

disable_di: NotRequired[bool]
"""Disable built-in dependency injection. Default: False.
When True, the Flask extension will not register request hooks for managing
database connections and sessions. Users are responsible for managing the
database lifecycle manually via their own DI solution.
"""


class LitestarConfig(TypedDict):
"""Configuration options for Litestar SQLSpec plugin.
Expand Down Expand Up @@ -170,6 +177,13 @@ class LitestarConfig(TypedDict):
extra_rollback_statuses: NotRequired[set[int]]
"""Additional HTTP status codes that trigger rollback. Default: set()"""

disable_di: NotRequired[bool]
"""Disable built-in dependency injection. Default: False.
When True, the Litestar plugin will not register dependency providers for managing
database connections, pools, and sessions. Users are responsible for managing the
database lifecycle manually via their own DI solution.
"""


class StarletteConfig(TypedDict):
"""Configuration options for Starlette and FastAPI extensions.
Expand Down Expand Up @@ -225,6 +239,13 @@ class StarletteConfig(TypedDict):
extra_rollback_statuses={409}
"""

disable_di: NotRequired[bool]
"""Disable built-in dependency injection. Default: False.
When True, the Starlette/FastAPI extension will not add middleware for managing
database connections and sessions. Users are responsible for managing the
database lifecycle manually via their own DI solution.
"""


class FastAPIConfig(StarletteConfig):
"""Configuration options for FastAPI SQLSpec extension.
Expand Down
1 change: 1 addition & 0 deletions sqlspec/extensions/flask/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class FlaskConfigState:
extra_commit_statuses: "set[int] | None"
extra_rollback_statuses: "set[int] | None"
is_async: bool
disable_di: bool

def should_commit(self, status_code: int) -> bool:
"""Determine if HTTP status code should trigger commit.
Expand Down
19 changes: 16 additions & 3 deletions sqlspec/extensions/flask/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _create_config_state(self, config: Any) -> FlaskConfigState:
commit_mode = flask_config.get("commit_mode", DEFAULT_COMMIT_MODE)
extra_commit_statuses = flask_config.get("extra_commit_statuses")
extra_rollback_statuses = flask_config.get("extra_rollback_statuses")
disable_di = flask_config.get("disable_di", False)

is_async = isinstance(config, (AsyncDatabaseConfig, NoPoolAsyncConfig))

Expand All @@ -107,6 +108,7 @@ def _create_config_state(self, config: Any) -> FlaskConfigState:
extra_commit_statuses=extra_commit_statuses,
extra_rollback_statuses=extra_rollback_statuses,
is_async=is_async,
disable_di=disable_di,
)

def init_app(self, app: "Flask") -> None:
Expand Down Expand Up @@ -143,9 +145,11 @@ def init_app(self, app: "Flask") -> None:

app.extensions["sqlspec"] = {"plugin": self, "pools": pools}

app.before_request(self._before_request_handler)
app.after_request(self._after_request_handler)
app.teardown_appcontext(self._teardown_appcontext_handler)
if any(not state.disable_di for state in self._config_states):
app.before_request(self._before_request_handler)
app.after_request(self._after_request_handler)
app.teardown_appcontext(self._teardown_appcontext_handler)

self._register_shutdown_hook()

logger.debug("SQLSpec Flask extension initialized")
Expand Down Expand Up @@ -186,6 +190,9 @@ def _before_request_handler(self) -> None:
from flask import current_app, g

for config_state in self._config_states:
if config_state.disable_di:
continue

if config_state.config.supports_connection_pooling:
pool = current_app.extensions["sqlspec"]["pools"][config_state.session_key]
conn_ctx = config_state.config.provide_connection(pool)
Expand Down Expand Up @@ -215,6 +222,9 @@ def _after_request_handler(self, response: "Response") -> "Response":
from flask import g

for config_state in self._config_states:
if config_state.disable_di:
continue

if config_state.commit_mode == "manual":
continue

Expand Down Expand Up @@ -242,6 +252,9 @@ def _teardown_appcontext_handler(self, _exc: "BaseException | None" = None) -> N
from flask import g

for config_state in self._config_states:
if config_state.disable_di:
continue

connection = getattr(g, config_state.connection_key, None)
ctx_key = f"{config_state.connection_key}_ctx"
conn_ctx = getattr(g, ctx_key, None)
Expand Down
21 changes: 13 additions & 8 deletions sqlspec/extensions/litestar/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class _PluginConfigState:
extra_commit_statuses: "set[int] | None"
extra_rollback_statuses: "set[int] | None"
enable_correlation_middleware: bool
disable_di: bool
connection_provider: "Callable[[State, Scope], AsyncGenerator[Any, None]]" = field(init=False)
pool_provider: "Callable[[State, Scope], Any]" = field(init=False)
session_provider: "Callable[..., AsyncGenerator[Any, None]]" = field(init=False)
Expand Down Expand Up @@ -157,6 +158,7 @@ def _extract_litestar_settings(
"extra_commit_statuses": litestar_config.get("extra_commit_statuses"),
"extra_rollback_statuses": litestar_config.get("extra_rollback_statuses"),
"enable_correlation_middleware": litestar_config.get("enable_correlation_middleware", True),
"disable_di": litestar_config.get("disable_di", False),
}

def _create_config_state(
Expand All @@ -174,9 +176,11 @@ def _create_config_state(
extra_commit_statuses=settings.get("extra_commit_statuses"),
extra_rollback_statuses=settings.get("extra_rollback_statuses"),
enable_correlation_middleware=settings["enable_correlation_middleware"],
disable_di=settings["disable_di"],
)

self._setup_handlers(state)
if not state.disable_di:
self._setup_handlers(state)
return state

def _setup_handlers(self, state: _PluginConfigState) -> None:
Expand Down Expand Up @@ -256,13 +260,14 @@ def store_sqlspec_in_state() -> None:

signature_namespace.update(state.config.get_signature_namespace()) # type: ignore[arg-type]

app_config.before_send.append(state.before_send_handler)
app_config.lifespan.append(state.lifespan_handler)
app_config.dependencies.update({
state.connection_key: Provide(state.connection_provider),
state.pool_key: Provide(state.pool_provider),
state.session_key: Provide(state.session_provider),
})
if not state.disable_di:
app_config.before_send.append(state.before_send_handler)
app_config.lifespan.append(state.lifespan_handler)
app_config.dependencies.update({
state.connection_key: Provide(state.connection_provider),
state.pool_key: Provide(state.pool_provider),
state.session_key: Provide(state.session_provider),
})

if signature_namespace:
app_config.signature_namespace.update(signature_namespace)
Expand Down
1 change: 1 addition & 0 deletions sqlspec/extensions/starlette/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ class SQLSpecConfigState:
commit_mode: CommitMode
extra_commit_statuses: "set[int] | None"
extra_rollback_statuses: "set[int] | None"
disable_di: bool
5 changes: 4 additions & 1 deletion sqlspec/extensions/starlette/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _extract_starlette_settings(self, config: Any) -> "dict[str, Any]":
"commit_mode": commit_mode,
"extra_commit_statuses": starlette_config.get("extra_commit_statuses"),
"extra_rollback_statuses": starlette_config.get("extra_rollback_statuses"),
"disable_di": starlette_config.get("disable_di", False),
}

def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSpecConfigState:
Expand All @@ -124,6 +125,7 @@ def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSp
commit_mode=settings["commit_mode"],
extra_commit_statuses=settings["extra_commit_statuses"],
extra_rollback_statuses=settings["extra_rollback_statuses"],
disable_di=settings["disable_di"],
)

def init_app(self, app: "Starlette") -> None:
Expand All @@ -146,7 +148,8 @@ async def combined_lifespan(app: "Starlette") -> "AsyncGenerator[None, None]":
app.router.lifespan_context = combined_lifespan

for config_state in self._config_states:
self._add_middleware(app, config_state)
if not config_state.disable_di:
self._add_middleware(app, config_state)

def _validate_unique_keys(self) -> None:
"""Validate that all state keys are unique across configs.
Expand Down
Loading