From ef9c2ab4b5a037a9484229935861296ceb8debc3 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 2 Nov 2025 01:37:30 +0000 Subject: [PATCH 1/3] feat: Adds `disable_di` boolean flag to Litestar, Starlette/FastAPI, and Flask extensions. When enabled, disables built-in dependency injection to allow users to manage database lifecycle with their own DI solutions (e.g., Dishka). --- sqlspec/config.py | 21 +++ sqlspec/extensions/flask/_state.py | 1 + sqlspec/extensions/flask/extension.py | 19 ++- sqlspec/extensions/litestar/plugin.py | 21 ++- sqlspec/extensions/starlette/_state.py | 1 + sqlspec/extensions/starlette/extension.py | 5 +- .../test_external_di_provider.py | 137 ++++++++++++++++++ .../test_extensions/test_flask/test_state.py | 8 + .../test_starlette/test_config_state.py | 3 + .../test_starlette/test_utils.py | 5 + 10 files changed, 209 insertions(+), 12 deletions(-) create mode 100644 tests/integration/test_extensions/test_external_di_provider.py diff --git a/sqlspec/config.py b/sqlspec/config.py index 4085d8a5..8039f300 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -142,6 +142,13 @@ class FlaskConfig(TypedDict): extra_rollback_statuses: NotRequired[set[int]] """Additional HTTP status codes that trigger rollback. Default: None.""" + disable_di: NotRequired[bool] + """Disable built-in dependency injection. Default: False. + When True, the Flask extension will not register request hooks for managing + database connections and sessions. Users are responsible for managing the + database lifecycle manually via their own DI solution. + """ + class LitestarConfig(TypedDict): """Configuration options for Litestar SQLSpec plugin. @@ -170,6 +177,13 @@ class LitestarConfig(TypedDict): extra_rollback_statuses: NotRequired[set[int]] """Additional HTTP status codes that trigger rollback. Default: set()""" + disable_di: NotRequired[bool] + """Disable built-in dependency injection. Default: False. + When True, the Litestar plugin will not register dependency providers for managing + database connections, pools, and sessions. Users are responsible for managing the + database lifecycle manually via their own DI solution. + """ + class StarletteConfig(TypedDict): """Configuration options for Starlette and FastAPI extensions. @@ -225,6 +239,13 @@ class StarletteConfig(TypedDict): extra_rollback_statuses={409} """ + disable_di: NotRequired[bool] + """Disable built-in dependency injection. Default: False. + When True, the Starlette/FastAPI extension will not add middleware for managing + database connections and sessions. Users are responsible for managing the + database lifecycle manually via their own DI solution. + """ + class FastAPIConfig(StarletteConfig): """Configuration options for FastAPI SQLSpec extension. diff --git a/sqlspec/extensions/flask/_state.py b/sqlspec/extensions/flask/_state.py index 84f62d76..ec61c2b9 100644 --- a/sqlspec/extensions/flask/_state.py +++ b/sqlspec/extensions/flask/_state.py @@ -27,6 +27,7 @@ class FlaskConfigState: extra_commit_statuses: "set[int] | None" extra_rollback_statuses: "set[int] | None" is_async: bool + disable_di: bool def should_commit(self, status_code: int) -> bool: """Determine if HTTP status code should trigger commit. diff --git a/sqlspec/extensions/flask/extension.py b/sqlspec/extensions/flask/extension.py index 7ffe7d8a..a8fafa4e 100644 --- a/sqlspec/extensions/flask/extension.py +++ b/sqlspec/extensions/flask/extension.py @@ -96,6 +96,7 @@ def _create_config_state(self, config: Any) -> FlaskConfigState: commit_mode = flask_config.get("commit_mode", DEFAULT_COMMIT_MODE) extra_commit_statuses = flask_config.get("extra_commit_statuses") extra_rollback_statuses = flask_config.get("extra_rollback_statuses") + disable_di = flask_config.get("disable_di", False) is_async = isinstance(config, (AsyncDatabaseConfig, NoPoolAsyncConfig)) @@ -107,6 +108,7 @@ def _create_config_state(self, config: Any) -> FlaskConfigState: extra_commit_statuses=extra_commit_statuses, extra_rollback_statuses=extra_rollback_statuses, is_async=is_async, + disable_di=disable_di, ) def init_app(self, app: "Flask") -> None: @@ -143,9 +145,11 @@ def init_app(self, app: "Flask") -> None: app.extensions["sqlspec"] = {"plugin": self, "pools": pools} - app.before_request(self._before_request_handler) - app.after_request(self._after_request_handler) - app.teardown_appcontext(self._teardown_appcontext_handler) + if any(not state.disable_di for state in self._config_states): + app.before_request(self._before_request_handler) + app.after_request(self._after_request_handler) + app.teardown_appcontext(self._teardown_appcontext_handler) + self._register_shutdown_hook() logger.debug("SQLSpec Flask extension initialized") @@ -186,6 +190,9 @@ def _before_request_handler(self) -> None: from flask import current_app, g for config_state in self._config_states: + if config_state.disable_di: + continue + if config_state.config.supports_connection_pooling: pool = current_app.extensions["sqlspec"]["pools"][config_state.session_key] conn_ctx = config_state.config.provide_connection(pool) @@ -215,6 +222,9 @@ def _after_request_handler(self, response: "Response") -> "Response": from flask import g for config_state in self._config_states: + if config_state.disable_di: + continue + if config_state.commit_mode == "manual": continue @@ -242,6 +252,9 @@ def _teardown_appcontext_handler(self, _exc: "BaseException | None" = None) -> N from flask import g for config_state in self._config_states: + if config_state.disable_di: + continue + connection = getattr(g, config_state.connection_key, None) ctx_key = f"{config_state.connection_key}_ctx" conn_ctx = getattr(g, ctx_key, None) diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index cea8c05b..c78db0e2 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -72,6 +72,7 @@ class _PluginConfigState: extra_commit_statuses: "set[int] | None" extra_rollback_statuses: "set[int] | None" enable_correlation_middleware: bool + disable_di: bool connection_provider: "Callable[[State, Scope], AsyncGenerator[Any, None]]" = field(init=False) pool_provider: "Callable[[State, Scope], Any]" = field(init=False) session_provider: "Callable[..., AsyncGenerator[Any, None]]" = field(init=False) @@ -157,6 +158,7 @@ def _extract_litestar_settings( "extra_commit_statuses": litestar_config.get("extra_commit_statuses"), "extra_rollback_statuses": litestar_config.get("extra_rollback_statuses"), "enable_correlation_middleware": litestar_config.get("enable_correlation_middleware", True), + "disable_di": litestar_config.get("disable_di", False), } def _create_config_state( @@ -174,9 +176,11 @@ def _create_config_state( extra_commit_statuses=settings.get("extra_commit_statuses"), extra_rollback_statuses=settings.get("extra_rollback_statuses"), enable_correlation_middleware=settings["enable_correlation_middleware"], + disable_di=settings["disable_di"], ) - self._setup_handlers(state) + if not state.disable_di: + self._setup_handlers(state) return state def _setup_handlers(self, state: _PluginConfigState) -> None: @@ -256,13 +260,14 @@ def store_sqlspec_in_state() -> None: signature_namespace.update(state.config.get_signature_namespace()) # type: ignore[arg-type] - app_config.before_send.append(state.before_send_handler) - app_config.lifespan.append(state.lifespan_handler) - app_config.dependencies.update({ - state.connection_key: Provide(state.connection_provider), - state.pool_key: Provide(state.pool_provider), - state.session_key: Provide(state.session_provider), - }) + if not state.disable_di: + app_config.before_send.append(state.before_send_handler) + app_config.lifespan.append(state.lifespan_handler) + app_config.dependencies.update({ + state.connection_key: Provide(state.connection_provider), + state.pool_key: Provide(state.pool_provider), + state.session_key: Provide(state.session_provider), + }) if signature_namespace: app_config.signature_namespace.update(signature_namespace) diff --git a/sqlspec/extensions/starlette/_state.py b/sqlspec/extensions/starlette/_state.py index 58b2a64d..0a878004 100644 --- a/sqlspec/extensions/starlette/_state.py +++ b/sqlspec/extensions/starlette/_state.py @@ -23,3 +23,4 @@ class SQLSpecConfigState: commit_mode: CommitMode extra_commit_statuses: "set[int] | None" extra_rollback_statuses: "set[int] | None" + disable_di: bool diff --git a/sqlspec/extensions/starlette/extension.py b/sqlspec/extensions/starlette/extension.py index 4a2ad5df..8344e6ba 100644 --- a/sqlspec/extensions/starlette/extension.py +++ b/sqlspec/extensions/starlette/extension.py @@ -104,6 +104,7 @@ def _extract_starlette_settings(self, config: Any) -> "dict[str, Any]": "commit_mode": commit_mode, "extra_commit_statuses": starlette_config.get("extra_commit_statuses"), "extra_rollback_statuses": starlette_config.get("extra_rollback_statuses"), + "disable_di": starlette_config.get("disable_di", False), } def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSpecConfigState: @@ -124,6 +125,7 @@ def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSp commit_mode=settings["commit_mode"], extra_commit_statuses=settings["extra_commit_statuses"], extra_rollback_statuses=settings["extra_rollback_statuses"], + disable_di=settings["disable_di"], ) def init_app(self, app: "Starlette") -> None: @@ -146,7 +148,8 @@ async def combined_lifespan(app: "Starlette") -> "AsyncGenerator[None, None]": app.router.lifespan_context = combined_lifespan for config_state in self._config_states: - self._add_middleware(app, config_state) + if not config_state.disable_di: + self._add_middleware(app, config_state) def _validate_unique_keys(self) -> None: """Validate that all state keys are unique across configs. diff --git a/tests/integration/test_extensions/test_external_di_provider.py b/tests/integration/test_extensions/test_external_di_provider.py new file mode 100644 index 00000000..506ea414 --- /dev/null +++ b/tests/integration/test_extensions/test_external_di_provider.py @@ -0,0 +1,137 @@ +"""Integration tests for disable_di flag across all framework extensions.""" + +import tempfile + +import pytest +from flask import Flask, g +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route +from starlette.testclient import TestClient + +from sqlspec.adapters.aiosqlite import AiosqliteConfig +from sqlspec.adapters.sqlite import SqliteConfig +from sqlspec.base import SQLSpec +from sqlspec.extensions.flask import SQLSpecPlugin as FlaskPlugin +from sqlspec.extensions.starlette import SQLSpecPlugin as StarlettePlugin + +pytestmark = pytest.mark.xdist_group("sqlite") + + +def test_starlette_disable_di_disables_middleware() -> None: + """Test that disable_di disables middleware in Starlette extension.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = AiosqliteConfig( + pool_config={"database": tmp.name}, extension_config={"starlette": {"disable_di": True}} + ) + sql.add_config(config) + db_ext = StarlettePlugin(sql) + + async def test_route(request: Request) -> Response: + pool = await config.create_pool() + async with config.provide_connection(pool) as connection: + session = config.driver_type(connection=connection, statement_config=config.statement_config) + result = await session.execute("SELECT 1 as value") + data = result.get_first() + assert data is not None + await config.close_pool() + return JSONResponse({"value": data["value"]}) + + app = Starlette(routes=[Route("/", test_route)]) + db_ext.init_app(app) + + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"value": 1} + + +def test_flask_disable_di_disables_hooks() -> None: + """Test that disable_di disables request hooks in Flask extension.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = SqliteConfig(pool_config={"database": tmp.name}, extension_config={"flask": {"disable_di": True}}) + sql.add_config(config) + + app = Flask(__name__) + FlaskPlugin(sql, app) + + @app.route("/test") + def test_route(): + pool = config.create_pool() + with config.provide_connection(pool) as connection: + session = config.driver_type(connection=connection, statement_config=config.statement_config) + result = session.execute("SELECT 1 as value") + data = result.get_first() + assert data is not None + config.close_pool() + return {"value": data["value"]} + + @app.route("/check_g") + def check_g(): + return {"has_connection": hasattr(g, "sqlspec_connection_db_session")} + + with app.test_client() as client: + response = client.get("/test") + assert response.status_code == 200 + response_json = response.json + assert response_json is not None + assert response_json == {"value": 1} + + response = client.get("/check_g") + assert response.status_code == 200 + response_json = response.json + assert response_json is not None + assert response_json == {"has_connection": False} + + +def test_starlette_default_di_provider_enabled() -> None: + """Test that default behavior has disable_di=False.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = AiosqliteConfig( + pool_config={"database": tmp.name}, extension_config={"starlette": {"session_key": "db"}} + ) + sql.add_config(config) + db_ext = StarlettePlugin(sql) + + async def test_route(request: Request) -> Response: + session = db_ext.get_session(request, "db") + result = await session.execute("SELECT 1 as value") + data = result.get_first() + return JSONResponse({"value": data["value"]}) + + app = Starlette(routes=[Route("/", test_route)]) + db_ext.init_app(app) + + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"value": 1} + + +def test_flask_default_di_provider_enabled() -> None: + """Test that default behavior has disable_di=False.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = SqliteConfig(pool_config={"database": tmp.name}, extension_config={"flask": {"session_key": "db"}}) + sql.add_config(config) + + app = Flask(__name__) + plugin = FlaskPlugin(sql, app) + + @app.route("/test") + def test_route(): + session = plugin.get_session("db") + result = session.execute("SELECT 1 as value") + data = result.get_first() + return {"value": data["value"]} + + with app.test_client() as client: + response = client.get("/test") + assert response.status_code == 200 + response_json = response.json + assert response_json is not None + assert response_json == {"value": 1} diff --git a/tests/unit/test_extensions/test_flask/test_state.py b/tests/unit/test_extensions/test_flask/test_state.py index 770cbaaf..b9b347a2 100644 --- a/tests/unit/test_extensions/test_flask/test_state.py +++ b/tests/unit/test_extensions/test_flask/test_state.py @@ -13,6 +13,7 @@ def test_should_commit_manual_mode() -> None: extra_commit_statuses=None, extra_rollback_statuses=None, is_async=False, + disable_di=False, ) assert not state.should_commit(200) @@ -32,6 +33,7 @@ def test_should_commit_autocommit_mode() -> None: extra_commit_statuses=None, extra_rollback_statuses=None, is_async=False, + disable_di=False, ) assert state.should_commit(200) @@ -55,6 +57,7 @@ def test_should_commit_autocommit_include_redirect_mode() -> None: extra_commit_statuses=None, extra_rollback_statuses=None, is_async=False, + disable_di=False, ) assert state.should_commit(200) @@ -81,6 +84,7 @@ def test_should_commit_extra_commit_statuses() -> None: extra_commit_statuses={404, 500}, extra_rollback_statuses=None, is_async=False, + disable_di=False, ) assert state.should_commit(200) @@ -98,6 +102,7 @@ def test_should_commit_extra_rollback_statuses() -> None: extra_commit_statuses=None, extra_rollback_statuses={201}, is_async=False, + disable_di=False, ) assert state.should_commit(200) @@ -114,6 +119,7 @@ def test_should_rollback_manual_mode() -> None: extra_commit_statuses=None, extra_rollback_statuses=None, is_async=False, + disable_di=False, ) assert not state.should_rollback(200) @@ -131,6 +137,7 @@ def test_should_rollback_autocommit_mode() -> None: extra_commit_statuses=None, extra_rollback_statuses=None, is_async=False, + disable_di=False, ) assert not state.should_rollback(200) @@ -152,6 +159,7 @@ def test_should_rollback_autocommit_include_redirect_mode() -> None: extra_commit_statuses=None, extra_rollback_statuses=None, is_async=False, + disable_di=False, ) assert not state.should_rollback(200) diff --git a/tests/unit/test_extensions/test_starlette/test_config_state.py b/tests/unit/test_extensions/test_starlette/test_config_state.py index a7c20f6a..70f43efa 100644 --- a/tests/unit/test_extensions/test_starlette/test_config_state.py +++ b/tests/unit/test_extensions/test_starlette/test_config_state.py @@ -17,6 +17,7 @@ def test_config_state_creation() -> None: commit_mode="manual", extra_commit_statuses=None, extra_rollback_statuses=None, + disable_di=False, ) assert state.config is mock_config @@ -42,6 +43,7 @@ def test_config_state_with_extra_statuses() -> None: commit_mode="autocommit", extra_commit_statuses=extra_commit, extra_rollback_statuses=extra_rollback, + disable_di=False, ) assert state.extra_commit_statuses == extra_commit @@ -61,5 +63,6 @@ def test_config_state_commit_modes() -> None: commit_mode=mode, # type: ignore[arg-type] extra_commit_statuses=None, extra_rollback_statuses=None, + disable_di=False, ) assert state.commit_mode == mode diff --git a/tests/unit/test_extensions/test_starlette/test_utils.py b/tests/unit/test_extensions/test_starlette/test_utils.py index 57a60a9b..28d8eed1 100644 --- a/tests/unit/test_extensions/test_starlette/test_utils.py +++ b/tests/unit/test_extensions/test_starlette/test_utils.py @@ -22,6 +22,7 @@ def test_get_connection_from_request() -> None: commit_mode="manual", extra_commit_statuses=None, extra_rollback_statuses=None, + disable_di=False, ) setattr(mock_request.state, "db_connection", mock_connection) @@ -45,6 +46,7 @@ def test_get_connection_from_request_raises_when_missing() -> None: commit_mode="manual", extra_commit_statuses=None, extra_rollback_statuses=None, + disable_di=False, ) with pytest.raises(AttributeError): @@ -70,6 +72,7 @@ def test_get_or_create_session_creates_new_session() -> None: commit_mode="manual", extra_commit_statuses=None, extra_rollback_statuses=None, + disable_di=False, ) setattr(mock_request.state, "db_connection", mock_connection) @@ -103,6 +106,7 @@ def test_get_or_create_session_returns_cached_session() -> None: commit_mode="manual", extra_commit_statuses=None, extra_rollback_statuses=None, + disable_di=False, ) setattr(mock_request.state, "db_connection", mock_connection) @@ -133,6 +137,7 @@ def test_get_or_create_session_uses_unique_cache_key() -> None: commit_mode="manual", extra_commit_statuses=None, extra_rollback_statuses=None, + disable_di=False, ) setattr(mock_request.state, "db_connection", mock_connection) From a1ad861e6bf653aba8b6c268c157f6a3af5050a7 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 2 Nov 2025 01:41:45 +0000 Subject: [PATCH 2/3] feat: add disable_di flag to all framework extensions Adds `disable_di` boolean flag to Litestar, Starlette/FastAPI, and Flask extensions. When enabled, disables built-in dependency injection to allow users to manage database lifecycle with their own DI solutions (e.g., Dishka). ## Changes ### Configuration - Added `disable_di: NotRequired[bool]` to FlaskConfig, LitestarConfig, and StarletteConfig TypedDicts in sqlspec/config.py - Default: False (backward compatible) ### Framework Extensions - **Litestar**: Conditionally skip DI provider registration when disable_di=True - **Starlette/FastAPI**: Conditionally skip middleware when disable_di=True - **Flask**: Conditionally skip request hooks when disable_di=True ### Test Organization - Reorganized extension integration tests into subdirectories: - tests/integration/test_extensions/test_flask/ - tests/integration/test_extensions/test_starlette/ - tests/integration/test_extensions/test_fastapi/ - tests/integration/test_extensions/test_litestar/ - Added comprehensive disable_di tests for all three frameworks - Verified backward compatibility with default behavior tests ## Testing - All extension unit tests pass (104 passed, 1 skipped) - All extension integration tests pass (37 passed, 1 skipped) - Type checking clean (mypy + pyright: 0 errors) --- .../test_external_di_provider.py | 137 ------------------ .../test_fastapi_filters_integration.py | 0 .../test_fastapi_integration.py | 0 .../test_flask/test_flask_disable_di.py | 76 ++++++++++ .../test_flask_integration.py | 0 .../test_litestar/test_litestar_disable_di.py | 69 +++++++++ .../test_starlette_disable_di.py | 70 +++++++++ .../test_starlette_integration.py | 0 8 files changed, 215 insertions(+), 137 deletions(-) delete mode 100644 tests/integration/test_extensions/test_external_di_provider.py rename tests/integration/test_extensions/{ => test_fastapi}/test_fastapi_filters_integration.py (100%) rename tests/integration/test_extensions/{ => test_fastapi}/test_fastapi_integration.py (100%) create mode 100644 tests/integration/test_extensions/test_flask/test_flask_disable_di.py rename tests/integration/test_extensions/{ => test_flask}/test_flask_integration.py (100%) create mode 100644 tests/integration/test_extensions/test_litestar/test_litestar_disable_di.py create mode 100644 tests/integration/test_extensions/test_starlette/test_starlette_disable_di.py rename tests/integration/test_extensions/{ => test_starlette}/test_starlette_integration.py (100%) diff --git a/tests/integration/test_extensions/test_external_di_provider.py b/tests/integration/test_extensions/test_external_di_provider.py deleted file mode 100644 index 506ea414..00000000 --- a/tests/integration/test_extensions/test_external_di_provider.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Integration tests for disable_di flag across all framework extensions.""" - -import tempfile - -import pytest -from flask import Flask, g -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import JSONResponse, Response -from starlette.routing import Route -from starlette.testclient import TestClient - -from sqlspec.adapters.aiosqlite import AiosqliteConfig -from sqlspec.adapters.sqlite import SqliteConfig -from sqlspec.base import SQLSpec -from sqlspec.extensions.flask import SQLSpecPlugin as FlaskPlugin -from sqlspec.extensions.starlette import SQLSpecPlugin as StarlettePlugin - -pytestmark = pytest.mark.xdist_group("sqlite") - - -def test_starlette_disable_di_disables_middleware() -> None: - """Test that disable_di disables middleware in Starlette extension.""" - with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: - sql = SQLSpec() - config = AiosqliteConfig( - pool_config={"database": tmp.name}, extension_config={"starlette": {"disable_di": True}} - ) - sql.add_config(config) - db_ext = StarlettePlugin(sql) - - async def test_route(request: Request) -> Response: - pool = await config.create_pool() - async with config.provide_connection(pool) as connection: - session = config.driver_type(connection=connection, statement_config=config.statement_config) - result = await session.execute("SELECT 1 as value") - data = result.get_first() - assert data is not None - await config.close_pool() - return JSONResponse({"value": data["value"]}) - - app = Starlette(routes=[Route("/", test_route)]) - db_ext.init_app(app) - - with TestClient(app) as client: - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"value": 1} - - -def test_flask_disable_di_disables_hooks() -> None: - """Test that disable_di disables request hooks in Flask extension.""" - with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: - sql = SQLSpec() - config = SqliteConfig(pool_config={"database": tmp.name}, extension_config={"flask": {"disable_di": True}}) - sql.add_config(config) - - app = Flask(__name__) - FlaskPlugin(sql, app) - - @app.route("/test") - def test_route(): - pool = config.create_pool() - with config.provide_connection(pool) as connection: - session = config.driver_type(connection=connection, statement_config=config.statement_config) - result = session.execute("SELECT 1 as value") - data = result.get_first() - assert data is not None - config.close_pool() - return {"value": data["value"]} - - @app.route("/check_g") - def check_g(): - return {"has_connection": hasattr(g, "sqlspec_connection_db_session")} - - with app.test_client() as client: - response = client.get("/test") - assert response.status_code == 200 - response_json = response.json - assert response_json is not None - assert response_json == {"value": 1} - - response = client.get("/check_g") - assert response.status_code == 200 - response_json = response.json - assert response_json is not None - assert response_json == {"has_connection": False} - - -def test_starlette_default_di_provider_enabled() -> None: - """Test that default behavior has disable_di=False.""" - with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: - sql = SQLSpec() - config = AiosqliteConfig( - pool_config={"database": tmp.name}, extension_config={"starlette": {"session_key": "db"}} - ) - sql.add_config(config) - db_ext = StarlettePlugin(sql) - - async def test_route(request: Request) -> Response: - session = db_ext.get_session(request, "db") - result = await session.execute("SELECT 1 as value") - data = result.get_first() - return JSONResponse({"value": data["value"]}) - - app = Starlette(routes=[Route("/", test_route)]) - db_ext.init_app(app) - - with TestClient(app) as client: - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"value": 1} - - -def test_flask_default_di_provider_enabled() -> None: - """Test that default behavior has disable_di=False.""" - with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: - sql = SQLSpec() - config = SqliteConfig(pool_config={"database": tmp.name}, extension_config={"flask": {"session_key": "db"}}) - sql.add_config(config) - - app = Flask(__name__) - plugin = FlaskPlugin(sql, app) - - @app.route("/test") - def test_route(): - session = plugin.get_session("db") - result = session.execute("SELECT 1 as value") - data = result.get_first() - return {"value": data["value"]} - - with app.test_client() as client: - response = client.get("/test") - assert response.status_code == 200 - response_json = response.json - assert response_json is not None - assert response_json == {"value": 1} diff --git a/tests/integration/test_extensions/test_fastapi_filters_integration.py b/tests/integration/test_extensions/test_fastapi/test_fastapi_filters_integration.py similarity index 100% rename from tests/integration/test_extensions/test_fastapi_filters_integration.py rename to tests/integration/test_extensions/test_fastapi/test_fastapi_filters_integration.py diff --git a/tests/integration/test_extensions/test_fastapi_integration.py b/tests/integration/test_extensions/test_fastapi/test_fastapi_integration.py similarity index 100% rename from tests/integration/test_extensions/test_fastapi_integration.py rename to tests/integration/test_extensions/test_fastapi/test_fastapi_integration.py diff --git a/tests/integration/test_extensions/test_flask/test_flask_disable_di.py b/tests/integration/test_extensions/test_flask/test_flask_disable_di.py new file mode 100644 index 00000000..829c699d --- /dev/null +++ b/tests/integration/test_extensions/test_flask/test_flask_disable_di.py @@ -0,0 +1,76 @@ +"""Integration tests for disable_di flag in Flask extension.""" + +import tempfile + +import pytest +from flask import Flask, g + +from sqlspec.adapters.sqlite import SqliteConfig +from sqlspec.base import SQLSpec +from sqlspec.extensions.flask import SQLSpecPlugin + +pytestmark = pytest.mark.xdist_group("sqlite") + + +def test_flask_disable_di_disables_hooks() -> None: + """Test that disable_di disables request hooks in Flask extension.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = SqliteConfig(pool_config={"database": tmp.name}, extension_config={"flask": {"disable_di": True}}) + sql.add_config(config) + + app = Flask(__name__) + SQLSpecPlugin(sql, app) + + @app.route("/test") + def test_route(): + pool = config.create_pool() + with config.provide_connection(pool) as connection: + session = config.driver_type(connection=connection, statement_config=config.statement_config) + result = session.execute("SELECT 1 as value") + data = result.get_first() + assert data is not None + config.close_pool() + return {"value": data["value"]} + + @app.route("/check_g") + def check_g(): + return {"has_connection": hasattr(g, "sqlspec_connection_db_session")} + + with app.test_client() as client: + response = client.get("/test") + assert response.status_code == 200 + response_json = response.json + assert response_json is not None + assert response_json == {"value": 1} + + response = client.get("/check_g") + assert response.status_code == 200 + response_json = response.json + assert response_json is not None + assert response_json == {"has_connection": False} + + +def test_flask_default_di_enabled() -> None: + """Test that default behavior has disable_di=False.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = SqliteConfig(pool_config={"database": tmp.name}, extension_config={"flask": {"session_key": "db"}}) + sql.add_config(config) + + app = Flask(__name__) + plugin = SQLSpecPlugin(sql, app) + + @app.route("/test") + def test_route(): + session = plugin.get_session("db") + result = session.execute("SELECT 1 as value") + data = result.get_first() + return {"value": data["value"]} + + with app.test_client() as client: + response = client.get("/test") + assert response.status_code == 200 + response_json = response.json + assert response_json is not None + assert response_json == {"value": 1} diff --git a/tests/integration/test_extensions/test_flask_integration.py b/tests/integration/test_extensions/test_flask/test_flask_integration.py similarity index 100% rename from tests/integration/test_extensions/test_flask_integration.py rename to tests/integration/test_extensions/test_flask/test_flask_integration.py diff --git a/tests/integration/test_extensions/test_litestar/test_litestar_disable_di.py b/tests/integration/test_extensions/test_litestar/test_litestar_disable_di.py new file mode 100644 index 00000000..c8d2968e --- /dev/null +++ b/tests/integration/test_extensions/test_litestar/test_litestar_disable_di.py @@ -0,0 +1,69 @@ +"""Integration tests for disable_di flag in Litestar extension.""" + +import tempfile + +import pytest +from litestar import Litestar, Request, get +from litestar.testing import TestClient + +from sqlspec.adapters.aiosqlite import AiosqliteConfig +from sqlspec.base import SQLSpec +from sqlspec.extensions.litestar import SQLSpecPlugin + +pytestmark = pytest.mark.xdist_group("sqlite") + + +def test_litestar_disable_di_disables_providers() -> None: + """Test that disable_di disables dependency providers in Litestar extension.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = AiosqliteConfig( + pool_config={"database": tmp.name}, extension_config={"litestar": {"disable_di": True}} + ) + sql.add_config(config) + plugin = SQLSpecPlugin(sqlspec=sql) + + @get("/test") + async def test_route(request: Request) -> dict: + pool = await config.create_pool() + async with config.provide_connection(pool) as connection: + session = config.driver_type(connection=connection, statement_config=config.statement_config) + result = await session.execute("SELECT 1 as value") + data = result.get_first() + assert data is not None + await config.close_pool() + return {"value": data["value"]} + + app = Litestar(route_handlers=[test_route], plugins=[plugin]) + + with TestClient(app=app) as client: + response = client.get("/test") + assert response.status_code == 200 + assert response.json() == {"value": 1} + + +def test_litestar_default_di_enabled() -> None: + """Test that default behavior has disable_di=False.""" + from sqlspec.driver import AsyncDriverAdapterBase + + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = AiosqliteConfig( + pool_config={"database": tmp.name}, extension_config={"litestar": {"session_key": "db"}} + ) + sql.add_config(config) + plugin = SQLSpecPlugin(sqlspec=sql) + + @get("/test") + async def test_route(db: AsyncDriverAdapterBase) -> dict: + result = await db.execute("SELECT 1 as value") + data = result.get_first() + assert data is not None + return {"value": data["value"]} + + app = Litestar(route_handlers=[test_route], plugins=[plugin]) + + with TestClient(app=app) as client: + response = client.get("/test") + assert response.status_code == 200 + assert response.json() == {"value": 1} diff --git a/tests/integration/test_extensions/test_starlette/test_starlette_disable_di.py b/tests/integration/test_extensions/test_starlette/test_starlette_disable_di.py new file mode 100644 index 00000000..fe21bfab --- /dev/null +++ b/tests/integration/test_extensions/test_starlette/test_starlette_disable_di.py @@ -0,0 +1,70 @@ +"""Integration tests for disable_di flag in Starlette extension.""" + +import tempfile + +import pytest +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route +from starlette.testclient import TestClient + +from sqlspec.adapters.aiosqlite import AiosqliteConfig +from sqlspec.base import SQLSpec +from sqlspec.extensions.starlette import SQLSpecPlugin + +pytestmark = pytest.mark.xdist_group("sqlite") + + +def test_starlette_disable_di_disables_middleware() -> None: + """Test that disable_di disables middleware in Starlette extension.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = AiosqliteConfig( + pool_config={"database": tmp.name}, extension_config={"starlette": {"disable_di": True}} + ) + sql.add_config(config) + db_ext = SQLSpecPlugin(sql) + + async def test_route(request: Request) -> Response: + pool = await config.create_pool() + async with config.provide_connection(pool) as connection: + session = config.driver_type(connection=connection, statement_config=config.statement_config) + result = await session.execute("SELECT 1 as value") + data = result.get_first() + assert data is not None + await config.close_pool() + return JSONResponse({"value": data["value"]}) + + app = Starlette(routes=[Route("/", test_route)]) + db_ext.init_app(app) + + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"value": 1} + + +def test_starlette_default_di_enabled() -> None: + """Test that default behavior has disable_di=False.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as tmp: + sql = SQLSpec() + config = AiosqliteConfig( + pool_config={"database": tmp.name}, extension_config={"starlette": {"session_key": "db"}} + ) + sql.add_config(config) + db_ext = SQLSpecPlugin(sql) + + async def test_route(request: Request) -> Response: + session = db_ext.get_session(request, "db") + result = await session.execute("SELECT 1 as value") + data = result.get_first() + return JSONResponse({"value": data["value"]}) + + app = Starlette(routes=[Route("/", test_route)]) + db_ext.init_app(app) + + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"value": 1} diff --git a/tests/integration/test_extensions/test_starlette_integration.py b/tests/integration/test_extensions/test_starlette/test_starlette_integration.py similarity index 100% rename from tests/integration/test_extensions/test_starlette_integration.py rename to tests/integration/test_extensions/test_starlette/test_starlette_integration.py From 6d307c03f62e3b5f3ff65e0c762d395a7558aaec Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 2 Nov 2025 01:47:47 +0000 Subject: [PATCH 3/3] feat: add disable_di flag for custom dependency injection management --- AGENTS.md | 132 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index a1888235..eda9023a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -2616,6 +2616,138 @@ class StarletteConfig(TypedDict): extra_rollback_statuses: NotRequired[set[int]] ``` +### Disabling Built-in Dependency Injection (disable_di Pattern) + +**When to Use**: When users want to integrate SQLSpec with their own dependency injection solution (e.g., Dishka, dependency-injector) and need full control over database lifecycle management. + +**Pattern**: Add a `disable_di` boolean flag to framework extension configuration that conditionally skips the built-in DI setup. + +**Implementation Steps**: + +1. **Add to TypedDict in `sqlspec/config.py`**: + +```python +class StarletteConfig(TypedDict): + # ... existing fields ... + + disable_di: NotRequired[bool] + """Disable built-in dependency injection. Default: False. + When True, the Starlette/FastAPI extension will not add middleware for managing + database connections and sessions. Users are responsible for managing the + database lifecycle manually via their own DI solution. + """ +``` + +2. **Add to Configuration State Dataclass**: + +```python +@dataclass +class SQLSpecConfigState: + config: "DatabaseConfigProtocol[Any, Any, Any]" + connection_key: str + pool_key: str + session_key: str + commit_mode: CommitMode + extra_commit_statuses: "set[int] | None" + extra_rollback_statuses: "set[int] | None" + disable_di: bool # Add this field +``` + +3. **Extract from Config and Default to False**: + +```python +def _extract_starlette_settings(self, config): + starlette_config = config.extension_config.get("starlette", {}) + return { + # ... existing keys ... + "disable_di": starlette_config.get("disable_di", False), # Default False + } +``` + +4. **Conditionally Skip DI Setup**: + +**Middleware-based (Starlette/FastAPI)**: +```python +def init_app(self, app): + # ... lifespan setup ... + + for config_state in self._config_states: + if not config_state.disable_di: # Only add if DI enabled + self._add_middleware(app, config_state) +``` + +**Provider-based (Litestar)**: +```python +def on_app_init(self, app_config): + for state in self._plugin_configs: + # ... signature namespace ... + + if not state.disable_di: # Only register if DI enabled + app_config.before_send.append(state.before_send_handler) + app_config.lifespan.append(state.lifespan_handler) + app_config.dependencies.update({ + state.connection_key: Provide(state.connection_provider), + state.pool_key: Provide(state.pool_provider), + state.session_key: Provide(state.session_provider), + }) +``` + +**Hook-based (Flask)**: +```python +def init_app(self, app): + # ... pool setup ... + + # Only register hooks if at least one config has DI enabled + if any(not state.disable_di for state in self._config_states): + app.before_request(self._before_request_handler) + app.after_request(self._after_request_handler) + app.teardown_appcontext(self._teardown_appcontext_handler) + +def _before_request_handler(self): + for config_state in self._config_states: + if config_state.disable_di: # Skip if DI disabled + continue + # ... connection setup ... +``` + +**Testing Requirements**: + +1. **Test with `disable_di=True`**: Verify DI mechanisms are not active +2. **Test default behavior**: Verify `disable_di=False` preserves existing functionality +3. **Integration tests**: Demonstrate manual DI setup works correctly + +**Example Usage**: + +```python +from sqlspec.adapters.asyncpg import AsyncpgConfig +from sqlspec.base import SQLSpec +from sqlspec.extensions.starlette import SQLSpecPlugin + +sql = SQLSpec() +config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/db"}, + extension_config={"starlette": {"disable_di": True}} # Disable built-in DI +) +sql.add_config(config) +plugin = SQLSpecPlugin(sql) + +# User is now responsible for manual lifecycle management +async def my_route(request): + pool = await config.create_pool() + async with config.provide_connection(pool) as connection: + session = config.driver_type(connection=connection, statement_config=config.statement_config) + result = await session.execute("SELECT 1") + await config.close_pool() + return result +``` + +**Key Principles**: + +- **Backward Compatible**: Default `False` preserves existing behavior +- **Consistent Naming**: Use `disable_di` across all frameworks +- **Clear Documentation**: Warn users they are responsible for lifecycle management +- **Complete Control**: When disabled, extension does zero automatic DI + ### Multi-Database Support **Key validation ensures unique state keys**: