Skip to content

Commit

Permalink
Support setting sessions explicit to Empty.
Browse files Browse the repository at this point in the history
This was supported in the previous SessionMiddleware and was erroneously removed in #630. Test have been added.
  • Loading branch information
provinzkraut committed Oct 28, 2022
1 parent 3237206 commit be69985
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 10 deletions.
3 changes: 2 additions & 1 deletion starlite/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Generic,
List,
Optional,
Type,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -260,7 +261,7 @@ def cache(self) -> "Cache":
"""
return self.app.cache

def set_session(self, value: Union[Dict[str, Any], "BaseModel"]) -> None:
def set_session(self, value: Union[Dict[str, Any], "BaseModel", Type["Empty"]]) -> None:
"""Helper method to set the session in scope.
If the [Starlite SessionMiddleware][starlite.middleware.session.SessionMiddleware] is
Expand Down
10 changes: 6 additions & 4 deletions starlite/middleware/session/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from starlite import ASGIConnection, Cookie, DefineMiddleware
from starlite.middleware.base import MiddlewareProtocol
from starlite.middleware.util import should_bypass_middleware
from starlite.types import Empty
from starlite.utils import default_serializer, get_serializer_from_scope

if TYPE_CHECKING:
Expand Down Expand Up @@ -276,17 +277,18 @@ async def store_in_message(
headers = MutableHeaders(scope=message)
session_id = connection.cookies.get(self.config.key, self.generate_session_id())

serialised_data = self.serlialize_data(scope_session, scope)
await self.set(session_id=session_id, data=serialised_data)

cookie_params = self.config.dict(
exclude_none=True,
exclude={"secret", "key"} | set(self.config.__fields__) - set(BaseBackendConfig.__fields__),
)

if scope_session:
if scope_session is not Empty:
serialised_data = self.serlialize_data(scope_session, scope)
await self.set(session_id=session_id, data=serialised_data)

headers["Set-Cookie"] = Cookie(value=session_id, key=self.config.key, **cookie_params).to_header(header="")
else:
await self.delete(session_id)
headers.append(
"Set-Cookie",
Cookie(value="null", key=self.config.key, expires=0, **cookie_params).to_header(header=""),
Expand Down
3 changes: 2 additions & 1 deletion starlite/middleware/session/cookie_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from starlite.datastructures.cookie import Cookie
from starlite.exceptions import MissingDependencyException
from starlite.types import Empty

from .base import BaseBackendConfig, BaseSessionBackend

Expand Down Expand Up @@ -130,7 +131,7 @@ async def store_in_message(
headers = MutableHeaders(scope=message)
cookie_keys = self.get_cookie_keys(connection)

if scope_session:
if scope_session and scope_session is not Empty:
data = self.dump_data(scope_session, scope=scope)
cookie_params = self.config.dict(exclude_none=True, exclude={"secret", "key"})
for cookie in self._create_session_cookies(data, cookie_params):
Expand Down
4 changes: 2 additions & 2 deletions starlite/testing/create_test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
TemplateConfig,
WebSocket,
)
from starlite.middleware.session import SessionCookieConfig
from starlite.middleware.session.base import BaseBackendConfig
from starlite.types import (
AfterExceptionHookHandler,
AfterRequestHookHandler,
Expand Down Expand Up @@ -73,7 +73,7 @@ def create_test_client(
request_class: Optional[Type["Request"]] = None,
response_class: Optional["ResponseType"] = None,
root_path: str = "",
session_config: Optional["SessionCookieConfig"] = None,
session_config: Optional["BaseBackendConfig"] = None,
static_files_config: Optional[Union["StaticFilesConfig", List["StaticFilesConfig"]]] = None,
template_config: Optional["TemplateConfig"] = None,
websocket_class: Optional[Type["WebSocket"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions starlite/testing/test_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,6 @@ def _create_session_cookies(backend: CookieBackend, data: Dict[str, Any]) -> Dic

async def _set_session_data_async(self, data: Dict[str, Any]) -> None:
# TODO: Expose this in the async client

if isinstance(self.session_backend, ServerSideBackend):
serialized_data = self.session_backend.serlialize_data(data)
session_id = self.cookies.setdefault(
Expand All @@ -681,7 +680,8 @@ async def _get_session_data_async(self) -> Dict[str, Any]:
session_id = self.cookies.get(self.session_backend.config.key)
if session_id:
data = await self.session_backend.get(session_id)
return self.session_backend.deserialize_data(data)
if data:
return self.session_backend.deserialize_data(data)
elif isinstance(self.session_backend, CookieBackend):
raw_data = [
self.cookies[key].encode("utf-8") for key in self.cookies if self.session_backend.config.key in key
Expand Down
20 changes: 20 additions & 0 deletions tests/middleware/session/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
websocket,
)
from starlite.testing import create_test_client
from starlite.types import Empty

if TYPE_CHECKING:
from starlite.middleware.session.base import BaseBackendConfig
Expand Down Expand Up @@ -57,6 +58,25 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]:
assert response.json() == {"has_session": False}


def test_set_empty(session_backend_config: "BaseBackendConfig") -> None:
@post("/create-session")
def create_session_handler(request: Request) -> None:
request.set_session({"foo": "bar"})

@post("/empty-session")
def empty_session_handler(request: Request) -> None:
request.set_session(Empty)

with create_test_client(
route_handlers=[create_session_handler, empty_session_handler],
middleware=[session_backend_config.middleware],
session_config=session_backend_config,
) as client:
client.post("/create-session")
client.post("/empty-session")
assert not client.get_session_data()


def test_use_of_custom_response_serializer_with_http_handler(session_backend_config: "BaseBackendConfig") -> None:
class Obj:
inner: str
Expand Down

0 comments on commit be69985

Please sign in to comment.