Skip to content

Commit

Permalink
Make redirect an exception in load_models
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed May 17, 2024
1 parent 039dfea commit c1a02f5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 48 deletions.
49 changes: 40 additions & 9 deletions src/coaster/views/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
abort,
jsonify,
make_response,
redirect,
render_template,
request,
url_for,
)
from flask.typing import ResponseReturnValue
from markupsafe import escape as html_escape
from werkzeug.datastructures import Headers, MIMEAccept
from werkzeug.exceptions import BadRequest
from werkzeug.exceptions import BadRequest, HTTPException
from werkzeug.wrappers import Response as WerkzeugResponse

from ..auth import add_auth_attribute, current_auth
Expand All @@ -43,6 +43,7 @@
'ReturnRenderWith',
'RequestTypeError',
'RequestValueError',
'Redirect',
'requestargs',
'requestquery',
'requestform',
Expand Down Expand Up @@ -76,6 +77,40 @@ class RequestValueError(BadRequest, ValueError):
"""Exception that combines ValueError with BadRequest."""


class Redirect(HTTPException):
"""HTTP redirect as an exception, to bypass return type constraints."""

code: int = 302

def __init__(
self, location: str, code: Literal[301, 302, 303, 307, 308] = 302
) -> None:
super().__init__()
self.location = location
self.code = code

def get_headers(
self,
*args: Any,
**kwargs: Any,
) -> list[tuple[str, str]]:
"""Add location header to response."""
headers = super().get_headers(*args, **kwargs)
headers.append(('Location', self.location))
return headers

def get_description(self, *_args: Any, **_kwargs: Any) -> str:
"""Add a HTML description."""
html_location = html_escape(self.location)
return (
"<p>You should be redirected automatically to the target URL: "
f'<a href="{html_location}">{html_location}</a>. If not, click the link.\n'
)

def __str__(self) -> str:
return f"{self.code} {self.name}: {self.location}"


def requestargs(
*args: Union[str, tuple[str, Callable[[str], Any]]],
source: Literal['values', 'form', 'query', 'body'] = 'values',
Expand Down Expand Up @@ -343,7 +378,7 @@ def show_page(folder: Folder, page: Page) -> ResponseReturnValue:
"""

def decorator(f: Callable[..., _VR]) -> Callable[..., _VR]:
def loader(kwargs: dict[str, Any]) -> Union[dict[str, Any], BaseResponse]:
def loader(kwargs: dict[str, Any]) -> dict[str, Any]:
view_args: Optional[dict[str, Any]]
request_endpoint: str = request.endpoint # type: ignore[assignment]
permissions: Optional[set[str]] = None
Expand Down Expand Up @@ -394,7 +429,7 @@ def loader(kwargs: dict[str, Any]) -> Union[dict[str, Any], BaseResponse]:
location = url_for(request_endpoint, **view_args)
if request.query_string:
location = location + '?' + request.query_string.decode()
return redirect(location, code=307)
raise Redirect(location, code=307)

if permission_required:
permissions = item.permissions(
Expand Down Expand Up @@ -429,7 +464,7 @@ def loader(kwargs: dict[str, Any]) -> Union[dict[str, Any], BaseResponse]:
location = url_for(request_endpoint, **view_args)
if request.query_string:
location = location + '?' + request.query_string.decode()
return redirect(location, code=302)
raise Redirect(location, code=302)
if parameter.startswith('g.') and g:
parameter = parameter[2:]
setattr(g, parameter, item)
Expand All @@ -445,8 +480,6 @@ def loader(kwargs: dict[str, Any]) -> Union[dict[str, Any], BaseResponse]:
@wraps(f)
async def async_wrapper(*args, **kwargs) -> Any:
result = loader(kwargs)
if isinstance(result, BaseResponse):
return result
if config.get('kwargs'):
return await f(*args, kwargs=kwargs, **result)
return await f(*args, **result)
Expand All @@ -458,8 +491,6 @@ async def async_wrapper(*args, **kwargs) -> Any:
@wraps(f)
def wrapper(*args, **kwargs) -> _VR:
result = loader(kwargs)
if isinstance(result, BaseResponse):
return result # type: ignore[return-value]
if config.get('kwargs'):
return f(*args, kwargs=kwargs, **result)
return f(*args, **result)
Expand Down
71 changes: 32 additions & 39 deletions tests/coaster_tests/views_loadmodels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

# pylint: disable=redefined-outer-name,no-value-for-parameter

from __future__ import annotations

from typing import Optional

import pytest
Expand All @@ -10,9 +12,8 @@
from sqlalchemy.orm import Mapped
from werkzeug.exceptions import Forbidden, NotFound

from coaster.compat import BaseResponse
from coaster.sqlalchemy import BaseMixin, BaseNameMixin, BaseScopedIdMixin, relationship
from coaster.views import load_model, load_models
from coaster.views import Redirect, load_model, load_models

from .auth_test import LoginManager
from .conftest import AppTestCase, Model
Expand All @@ -31,13 +32,12 @@

class MiddleContainer(BaseMixin, Model):
__tablename__ = 'middle_container'
children: Mapped[list[ChildDocument]] = relationship(back_populates='parent')


class ParentDocument(BaseNameMixin, Model):
__tablename__ = 'parent_document'
middle_id: Mapped[int] = sa.orm.mapped_column(
sa.ForeignKey('middle_container.id'), nullable=False
)
middle_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey('middle_container.id'))
middle: Mapped[MiddleContainer] = relationship(MiddleContainer)

def __init__(self, **kwargs) -> None:
Expand All @@ -57,10 +57,8 @@ def permissions(

class ChildDocument(BaseScopedIdMixin, Model):
__tablename__ = 'child_document'
parent_id: Mapped[int] = sa.orm.mapped_column(
sa.ForeignKey('middle_container.id'), nullable=False
)
parent: Mapped[MiddleContainer] = relationship(MiddleContainer, backref='children')
parent_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey('middle_container.id'))
parent: Mapped[MiddleContainer] = relationship(back_populates='children')

def permissions(
self, actor: User, inherited: Optional[set[str]] = None
Expand All @@ -73,15 +71,10 @@ def permissions(

class RedirectDocument(BaseNameMixin, Model):
__tablename__ = 'redirect_document'
container_id: Mapped[int] = sa.orm.mapped_column(
sa.ForeignKey('container.id'), nullable=False
)
container: Mapped[Container] = relationship(Container)

target_id: Mapped[int] = sa.orm.mapped_column(
sa.ForeignKey('named_document.id'), nullable=False
)
target: Mapped[NamedDocument] = relationship(NamedDocument)
container_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey('container.id'))
container: Mapped[Container] = relationship()
target_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey('named_document.id'))
target: Mapped[NamedDocument] = relationship()

def redirect_view_args(self) -> dict[str, str]:
return {'document': self.target.name}
Expand Down Expand Up @@ -358,15 +351,15 @@ def test_redirect_document(self) -> None:
== self.nd2
)
with self.app.test_request_context('/c/redirect-document'):
response = t_redirect_document(container='c', document='redirect-document')
assert isinstance(response, BaseResponse)
assert response.status_code == 307
assert response.headers['Location'] == '/c/named-document'
with pytest.raises(Redirect) as exc_info:
t_redirect_document(container='c', document='redirect-document')
assert exc_info.value.code == 307
assert exc_info.value.location == '/c/named-document'
with self.app.test_request_context('/c/redirect-document?preserve=this'):
response = t_redirect_document(container='c', document='redirect-document')
assert isinstance(response, BaseResponse)
assert response.status_code == 307
assert response.headers['Location'] == '/c/named-document?preserve=this'
with pytest.raises(Redirect) as exc_info:
t_redirect_document(container='c', document='redirect-document')
assert exc_info.value.code == 307
assert exc_info.value.location == '/c/named-document?preserve=this'

def test_scoped_named_document(self) -> None:
assert (
Expand All @@ -391,15 +384,15 @@ def test_id_named_document(self) -> None:
== self.ind2
)
with self.app.test_request_context('/c/1-wrong-name'):
r = t_id_named_document(container='c', document='1-wrong-name')
assert isinstance(r, BaseResponse)
assert r.status_code == 302
assert r.location == '/c/1-id-named-document'
with pytest.raises(Redirect) as exc_info:
t_id_named_document(container='c', document='1-wrong-name')
assert exc_info.value.code == 302
assert exc_info.value.location == '/c/1-id-named-document'
with self.app.test_request_context('/c/1-wrong-name?preserve=this'):
r = t_id_named_document(container='c', document='1-wrong-name')
assert isinstance(r, BaseResponse)
assert r.status_code == 302
assert r.location == '/c/1-id-named-document?preserve=this'
with pytest.raises(Redirect) as exc_info:
t_id_named_document(container='c', document='1-wrong-name')
assert exc_info.value.code == 302
assert exc_info.value.location == '/c/1-id-named-document?preserve=this'
with pytest.raises(NotFound):
t_id_named_document(container='c', document='random-non-integer')

Expand All @@ -424,11 +417,11 @@ def test_scoped_id_named_document(self) -> None:
== self.sind2
)
with self.app.test_request_context('/c/1-wrong-name'):
r = t_scoped_id_named_document(container='c', document='1-wrong-name')
assert isinstance(r, BaseResponse)
assert r.status_code == 302
assert r.location == '/c/1-scoped-id-named-document'
with pytest.raises(NotFound):
with pytest.raises(Redirect) as exc_info:
t_scoped_id_named_document(container='c', document='1-wrong-name')
assert exc_info.value.code == 302
assert exc_info.value.location == '/c/1-scoped-id-named-document'
with pytest.raises(NotFound): # type: ignore[unreachable]
t_scoped_id_named_document(container='c', document='random-non-integer')

def test_callable_document(self) -> None:
Expand Down

0 comments on commit c1a02f5

Please sign in to comment.