diff --git a/src/sentry/web/frontend/base.py b/src/sentry/web/frontend/base.py index 499e2f2655dd7e..2e4a7c35f950d4 100644 --- a/src/sentry/web/frontend/base.py +++ b/src/sentry/web/frontend/base.py @@ -22,7 +22,6 @@ from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt from django.views.generic import View -from rest_framework.request import Request from sentry import options from sentry.api.exceptions import DataSecrecyError @@ -80,10 +79,11 @@ def handle_when_unavailable( current_mode: SiloMode, available_modes: Iterable[SiloMode], ) -> Callable[..., Any]: - def handle(obj: Any, request: Request, *args: Any, **kwargs: Any) -> HttpResponse: + def handle(*args: Any, **kwargs: Any) -> HttpResponse: + method, path = self._request_attrs(args, kwargs) mode_str = ", ".join(str(m) for m in available_modes) message = ( - f"Received {request.method} request at {request.path!r} to server in " + f"Received {method} request at {path!r} to server in " f"{current_mode} mode. This view is available only in: {mode_str}" ) if settings.FAIL_ON_UNAVAILABLE_API_CALL: @@ -94,6 +94,15 @@ def handle(obj: Any, request: Request, *args: Any, **kwargs: Any) -> HttpRespons return handle + def _request_attrs(self, args: Iterable[Any], kwargs: Mapping[str, Any]) -> tuple[str, str]: + for arg in args: + if isinstance(arg, HttpRequest): + return (arg.method or "unknown", arg.path) + for value in kwargs.values(): + if isinstance(value, HttpRequest): + return (value.method or "unknown", value.path) + return ("unknown", "unknown") + def __call__(self, decorated_obj: Any) -> Any: if isinstance(decorated_obj, type): if not issubclass(decorated_obj, View): diff --git a/tests/sentry/web/frontend/test_base.py b/tests/sentry/web/frontend/test_base.py new file mode 100644 index 00000000000000..0e7e8dcd904bdb --- /dev/null +++ b/tests/sentry/web/frontend/test_base.py @@ -0,0 +1,45 @@ +import pytest +from rest_framework.response import Response + +from sentry.silo.base import SiloMode +from sentry.testutils.cases import APITestCase +from sentry.testutils.silo import assume_test_silo_mode +from sentry.web.frontend.base import BaseView, ViewSiloLimit + + +class ViewSiloLimitTest(APITestCase): + def _test_active_on(self, endpoint_mode, active_mode, expect_to_be_active): + @ViewSiloLimit(endpoint_mode) + def view_func(request): + pass + + @ViewSiloLimit(endpoint_mode) + class DummyView(BaseView): + def get(self, request): + return Response("dummy-view", 200) + + view_class_func = DummyView.as_view() + with assume_test_silo_mode(active_mode): + request = self.make_request(method="GET", path="/dummy/") + setattr(request, "subdomain", "acme") + + if expect_to_be_active: + view_func(request) + view_class_func(request) + else: + with pytest.raises(ViewSiloLimit.AvailabilityError): + view_func(request) + with pytest.raises(ViewSiloLimit.AvailabilityError): + view_class_func(request) + + def test_with_active_mode(self) -> None: + self._test_active_on(SiloMode.REGION, SiloMode.REGION, True) + self._test_active_on(SiloMode.CONTROL, SiloMode.CONTROL, True) + + def test_with_inactive_mode(self) -> None: + self._test_active_on(SiloMode.REGION, SiloMode.CONTROL, False) + self._test_active_on(SiloMode.CONTROL, SiloMode.REGION, False) + + def test_with_monolith_mode(self) -> None: + self._test_active_on(SiloMode.REGION, SiloMode.MONOLITH, True) + self._test_active_on(SiloMode.CONTROL, SiloMode.MONOLITH, True)