Skip to content

Commit

Permalink
Fix transient cryptography dependency in testing.test_client.client (#…
Browse files Browse the repository at this point in the history
…711)

* Fix transient cryptography dependency in testing.test_client.client.

This changes the implementations of the session-related methods in `TestClient` to use the actual middleware by constructing fake "Scope" and "ASGIConnection" objects

* Drop cryptography from application test dependencies
  • Loading branch information
provinzkraut committed Oct 30, 2022
1 parent d1b2c2f commit a06084f
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: Install App Dependencies
run: poetry install --no-interaction --no-root --only main
- name: Install Test Dependencies
run: poetry run python -m pip install pytest pytest-asyncio httpx cryptography
run: poetry run python -m pip install pytest pytest-asyncio httpx
- name: Set pythonpath
run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV
- name: Test
Expand Down
2 changes: 1 addition & 1 deletion starlite/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def clear_session(self) -> None:
Returns:
None.
"""
self.scope["session"] = {}
self.scope["session"] = Empty

def url_for(self, name: str, **path_parameters: Dict[str, Any]) -> str:
"""
Expand Down
19 changes: 11 additions & 8 deletions starlite/middleware/session/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,27 +272,30 @@ async def store_in_message(
Returns:
None
"""

scope = connection.scope
headers = MutableHeaders(scope=message)
session_id = connection.cookies.get(self.config.key, self.generate_session_id())
session_id = connection.cookies.get(self.config.key)
if session_id == "null":
session_id = None
if not session_id:
session_id = self.generate_session_id()

cookie_params = self.config.dict(
exclude_none=True,
exclude={"secret", "key"} | set(self.config.__fields__) - set(BaseBackendConfig.__fields__),
)

if scope_session and scope_session is not Empty:
serialised_data = self.serlialize_data(scope_session, scope)
await self.set(session_id=session_id, data=serialised_data)

headers["Set-Cookie"] = Cookie(value=session_id, key=self.config.key, **cookie_params).to_header(header="")
else:
if scope_session is Empty:
await self.delete(session_id)
headers.append(
"Set-Cookie",
Cookie(value="null", key=self.config.key, expires=0, **cookie_params).to_header(header=""),
)
else:
serialised_data = self.serlialize_data(scope_session, scope)
await self.set(session_id=session_id, data=serialised_data)

headers["Set-Cookie"] = Cookie(value=session_id, key=self.config.key, **cookie_params).to_header(header="")

async def load_from_connection(self, connection: ASGIConnection) -> Dict[str, Any]:
"""Load session data from a connection and return it as a dictionary to
Expand Down
100 changes: 58 additions & 42 deletions starlite/testing/test_client/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from contextlib import ExitStack, contextmanager
from http.cookiejar import CookieJar
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -15,23 +16,19 @@
from urllib.parse import urljoin

from anyio.from_thread import BlockingPortal, start_blocking_portal
from starlette.datastructures import MutableHeaders

from starlite import HttpMethod, ImproperlyConfiguredException
from starlite import ASGIConnection, HttpMethod, ImproperlyConfiguredException
from starlite.exceptions import MissingDependencyException
from starlite.middleware.session.base import ServerSideBackend, ServerSideSessionConfig
from starlite.middleware.session.cookie_backend import (
CookieBackend,
CookieBackendConfig,
)
from starlite.testing.test_client.life_span_handler import LifeSpanHandler
from starlite.testing.test_client.transport import (
ConnectionUpgradeException,
TestClientTransport,
)
from starlite.types import AnyIOBackend, ASGIApp
from starlite.types import AnyIOBackend, ASGIApp, HTTPResponseStartEvent

try:
from httpx import USE_CLIENT_DEFAULT, Client, Response
from httpx import USE_CLIENT_DEFAULT, Client, Cookies, Request, Response
except ImportError as e:
raise MissingDependencyException(
"To use starlite.testing, install starlite with 'testing' extra, e.g. `pip install starlite[testing]`"
Expand All @@ -53,16 +50,40 @@
)

from starlite.middleware.session.base import BaseBackendConfig, BaseSessionBackend
from starlite.middleware.session.cookie_backend import CookieBackend
from starlite.testing.test_client.websocket_test_session import WebSocketTestSession


T = TypeVar("T", bound=ASGIApp)
AnySessionBackend = Union[CookieBackend, ServerSideBackend]
AnySessionConfig = Union["ServerSideSessionConfig", "CookieBackendConfig"]


def raise_for_unsupported_session_backend(backend: "BaseSessionBackend") -> None:
raise ImproperlyConfiguredException(f"Backend of type {type(backend)!r} is currently not supported")
def fake_http_send_message(headers: MutableHeaders) -> HTTPResponseStartEvent:
headers.setdefault("content-type", "application/text")
return HTTPResponseStartEvent(type="http.response.start", status=200, headers=headers.raw)


def fake_asgi_connection(app: ASGIApp, cookies: Dict[str, str]) -> ASGIConnection[Any, Any, Any]:
scope = {
"type": "http",
"path": "/",
"raw_path": b"/",
"root_path": "",
"scheme": "http",
"query_string": b"",
"client": ("testclient", 50000),
"server": ("testserver", 80),
"method": "GET",
"http_version": "1.1",
"extensions": {"http.response.template": {}},
"app": app,
"state": {},
"path_params": {},
"route_handler": None,
"_cookies": cookies,
}
return ASGIConnection[Any, Any, Any](
scope=scope, # type: ignore[arg-type]
)


class TestClient(Client, Generic[T]):
Expand Down Expand Up @@ -124,11 +145,13 @@ def session(self) -> "CookieBackend":
"To access the session backend directly, use the session_backend attribute",
PendingDeprecationWarning,
)
if not isinstance(self._session_backend, CookieBackend):
from starlite.middleware.session.cookie_backend import CookieBackend

if not isinstance(self.session_backend, CookieBackend):
raise ImproperlyConfiguredException(
f"Invalid session backend: {type(self._session_backend)!r}. Expected 'CookieBackend'"
)
return self._session_backend
return self.session_backend

@property
def session_backend(self) -> "BaseSessionBackend":
Expand Down Expand Up @@ -653,42 +676,35 @@ def test_something(self, test_client: TestClient) -> None:
return self.get_session_data()

@staticmethod
def _create_session_cookies(backend: CookieBackend, data: Dict[str, Any]) -> Dict[str, str]:
def _create_session_cookies(backend: "CookieBackend", data: Dict[str, Any]) -> Dict[str, str]:
encoded_data = backend.dump_data(data=data)
return {cookie.key: cast("str", cookie.value) for cookie in backend._create_session_cookies(encoded_data)}

async def _set_session_data_async(self, data: Dict[str, Any]) -> None:
# TODO: Expose this in the async client
if isinstance(self.session_backend, ServerSideBackend):
serialized_data = self.session_backend.serlialize_data(data)
session_id = self.cookies.setdefault(
self.session_backend.config.key, self.session_backend.generate_session_id()
)
await self.session_backend.set(session_id, serialized_data)
elif isinstance(self.session_backend, CookieBackend):
for key, value in self._create_session_cookies(self.session_backend, data).items():
self.cookies.set(key, value)
else:
raise_for_unsupported_session_backend(self.session_backend)
mutable_headers = MutableHeaders({})
await self.session_backend.store_in_message(
scope_session=data,
message=fake_http_send_message(mutable_headers),
connection=fake_asgi_connection(
app=self.app,
cookies=dict(self.cookies),
),
)
response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.raw)

cookies = Cookies(CookieJar())
cookies.extract_cookies(response)
self.cookies.update(cookies)

async def _get_session_data_async(self) -> Dict[str, Any]:
# TODO: Expose this in the async client

if isinstance(self.session_backend, ServerSideBackend):
session_id = self.cookies.get(self.session_backend.config.key)
if session_id:
data = await self.session_backend.get(session_id)
if data:
return self.session_backend.deserialize_data(data)
elif isinstance(self.session_backend, CookieBackend):
raw_data = [
self.cookies[key].encode("utf-8") for key in self.cookies if self.session_backend.config.key in key
]
if raw_data:
return self.session_backend.load_data(data=raw_data)
else:
raise_for_unsupported_session_backend(self.session_backend)
return {}
return await self.session_backend.load_from_connection(
connection=fake_asgi_connection(
app=self.app,
cookies=dict(self.cookies),
),
)

def set_session_data(self, data: Dict[str, Any]) -> None:
"""Set session data.
Expand Down
7 changes: 6 additions & 1 deletion tests/middleware/session/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]:
if request.method == HttpMethod.DELETE:
request.clear_session()
else:
request.set_session({"username": "moishezuchmir"})
request.session["username"] = "moishezuchmir"
return None

with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client:
Expand All @@ -57,6 +57,11 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]:
response = client.get("/session")
assert response.json() == {"has_session": False}

client.post("/session")

response = client.get("/session")
assert response.json() == {"has_session": True}


def test_set_empty(session_backend_config_async_safe: "BaseBackendConfig") -> None:
@post("/create-session")
Expand Down
6 changes: 4 additions & 2 deletions tests/testing/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,17 @@ def test_test_client_set_session_data(
session_data = {"foo": "bar"}

if with_domain:
session_config.domain = "testserver"
session_config.domain = "testserver.local"

@get(path="/test")
def get_session_data(request: Request) -> Dict[str, Any]:
return request.session

app = Starlite(route_handlers=[get_session_data], middleware=[session_config.middleware])

with TestClient(app=app, session_config=session_config, backend=test_client_backend) as client:
with TestClient(
app=app, session_config=session_config, backend=test_client_backend, base_url="http://testserver.local"
) as client:
client.set_session_data(session_data)
assert session_data == client.get("/test").json()

Expand Down

0 comments on commit a06084f

Please sign in to comment.