From be699855e81a770b4f148c7f9843e917d0eeeff5 Mon Sep 17 00:00:00 2001 From: provinzkraut <25355197+provinzkraut@users.noreply.github.com> Date: Fri, 28 Oct 2022 12:39:39 +0200 Subject: [PATCH] Support setting sessions explicit to `Empty`. This was supported in the previous SessionMiddleware and was erroneously removed in #630. Test have been added. --- starlite/connection/base.py | 3 ++- starlite/middleware/session/base.py | 10 ++++++---- starlite/middleware/session/cookie_backend.py | 3 ++- starlite/testing/create_test_client.py | 4 ++-- starlite/testing/test_client/client.py | 4 ++-- tests/middleware/session/test_middleware.py | 20 +++++++++++++++++++ 6 files changed, 34 insertions(+), 10 deletions(-) diff --git a/starlite/connection/base.py b/starlite/connection/base.py index e7523a2dd5..f32c72bf50 100644 --- a/starlite/connection/base.py +++ b/starlite/connection/base.py @@ -5,6 +5,7 @@ Generic, List, Optional, + Type, TypeVar, Union, cast, @@ -260,7 +261,7 @@ def cache(self) -> "Cache": """ return self.app.cache - def set_session(self, value: Union[Dict[str, Any], "BaseModel"]) -> None: + def set_session(self, value: Union[Dict[str, Any], "BaseModel", Type["Empty"]]) -> None: """Helper method to set the session in scope. If the [Starlite SessionMiddleware][starlite.middleware.session.SessionMiddleware] is diff --git a/starlite/middleware/session/base.py b/starlite/middleware/session/base.py index 4d440ab016..dd18264a35 100644 --- a/starlite/middleware/session/base.py +++ b/starlite/middleware/session/base.py @@ -25,6 +25,7 @@ from starlite import ASGIConnection, Cookie, DefineMiddleware from starlite.middleware.base import MiddlewareProtocol from starlite.middleware.util import should_bypass_middleware +from starlite.types import Empty from starlite.utils import default_serializer, get_serializer_from_scope if TYPE_CHECKING: @@ -276,17 +277,18 @@ async def store_in_message( headers = MutableHeaders(scope=message) session_id = connection.cookies.get(self.config.key, self.generate_session_id()) - serialised_data = self.serlialize_data(scope_session, scope) - await self.set(session_id=session_id, data=serialised_data) - cookie_params = self.config.dict( exclude_none=True, exclude={"secret", "key"} | set(self.config.__fields__) - set(BaseBackendConfig.__fields__), ) - if scope_session: + if scope_session is not Empty: + serialised_data = self.serlialize_data(scope_session, scope) + await self.set(session_id=session_id, data=serialised_data) + headers["Set-Cookie"] = Cookie(value=session_id, key=self.config.key, **cookie_params).to_header(header="") else: + await self.delete(session_id) headers.append( "Set-Cookie", Cookie(value="null", key=self.config.key, expires=0, **cookie_params).to_header(header=""), diff --git a/starlite/middleware/session/cookie_backend.py b/starlite/middleware/session/cookie_backend.py index f4e9abf8e6..342c3e5b29 100644 --- a/starlite/middleware/session/cookie_backend.py +++ b/starlite/middleware/session/cookie_backend.py @@ -12,6 +12,7 @@ from starlite.datastructures.cookie import Cookie from starlite.exceptions import MissingDependencyException +from starlite.types import Empty from .base import BaseBackendConfig, BaseSessionBackend @@ -130,7 +131,7 @@ async def store_in_message( headers = MutableHeaders(scope=message) cookie_keys = self.get_cookie_keys(connection) - if scope_session: + if scope_session and scope_session is not Empty: data = self.dump_data(scope_session, scope=scope) cookie_params = self.config.dict(exclude_none=True, exclude={"secret", "key"}) for cookie in self._create_session_cookies(data, cookie_params): diff --git a/starlite/testing/create_test_client.py b/starlite/testing/create_test_client.py index 24dd7c66c0..ab24c42e94 100644 --- a/starlite/testing/create_test_client.py +++ b/starlite/testing/create_test_client.py @@ -19,7 +19,7 @@ TemplateConfig, WebSocket, ) - from starlite.middleware.session import SessionCookieConfig + from starlite.middleware.session.base import BaseBackendConfig from starlite.types import ( AfterExceptionHookHandler, AfterRequestHookHandler, @@ -73,7 +73,7 @@ def create_test_client( request_class: Optional[Type["Request"]] = None, response_class: Optional["ResponseType"] = None, root_path: str = "", - session_config: Optional["SessionCookieConfig"] = None, + session_config: Optional["BaseBackendConfig"] = None, static_files_config: Optional[Union["StaticFilesConfig", List["StaticFilesConfig"]]] = None, template_config: Optional["TemplateConfig"] = None, websocket_class: Optional[Type["WebSocket"]] = None, diff --git a/starlite/testing/test_client/client.py b/starlite/testing/test_client/client.py index ef632605b6..3c67730148 100644 --- a/starlite/testing/test_client/client.py +++ b/starlite/testing/test_client/client.py @@ -661,7 +661,6 @@ def _create_session_cookies(backend: CookieBackend, data: Dict[str, Any]) -> Dic async def _set_session_data_async(self, data: Dict[str, Any]) -> None: # TODO: Expose this in the async client - if isinstance(self.session_backend, ServerSideBackend): serialized_data = self.session_backend.serlialize_data(data) session_id = self.cookies.setdefault( @@ -681,7 +680,8 @@ async def _get_session_data_async(self) -> Dict[str, Any]: session_id = self.cookies.get(self.session_backend.config.key) if session_id: data = await self.session_backend.get(session_id) - return self.session_backend.deserialize_data(data) + if data: + return self.session_backend.deserialize_data(data) elif isinstance(self.session_backend, CookieBackend): raw_data = [ self.cookies[key].encode("utf-8") for key in self.cookies if self.session_backend.config.key in key diff --git a/tests/middleware/session/test_middleware.py b/tests/middleware/session/test_middleware.py index 71d554c0a2..1a52e45384 100644 --- a/tests/middleware/session/test_middleware.py +++ b/tests/middleware/session/test_middleware.py @@ -13,6 +13,7 @@ websocket, ) from starlite.testing import create_test_client +from starlite.types import Empty if TYPE_CHECKING: from starlite.middleware.session.base import BaseBackendConfig @@ -57,6 +58,25 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]: assert response.json() == {"has_session": False} +def test_set_empty(session_backend_config: "BaseBackendConfig") -> None: + @post("/create-session") + def create_session_handler(request: Request) -> None: + request.set_session({"foo": "bar"}) + + @post("/empty-session") + def empty_session_handler(request: Request) -> None: + request.set_session(Empty) + + with create_test_client( + route_handlers=[create_session_handler, empty_session_handler], + middleware=[session_backend_config.middleware], + session_config=session_backend_config, + ) as client: + client.post("/create-session") + client.post("/empty-session") + assert not client.get_session_data() + + def test_use_of_custom_response_serializer_with_http_handler(session_backend_config: "BaseBackendConfig") -> None: class Obj: inner: str