-
-
Notifications
You must be signed in to change notification settings - Fork 4.1k
/
base.py
750 lines (632 loc) · 28 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
from __future__ import annotations
import abc
import functools
import logging
import time
from collections.abc import Callable, Iterable, Mapping
from datetime import datetime, timedelta, timezone
from typing import Any
from urllib.parse import quote as urlquote
import orjson
import sentry_sdk
from django.conf import settings
from django.http import HttpResponse
from django.http.request import HttpRequest
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status
from rest_framework.authentication import BaseAuthentication, SessionAuthentication
from rest_framework.exceptions import ParseError
from rest_framework.permissions import BasePermission
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from sentry_sdk import Scope
from sentry import analytics, options, tsdb
from sentry.api.api_owners import ApiOwner
from sentry.api.api_publish_status import ApiPublishStatus
from sentry.api.exceptions import StaffRequired, SuperuserRequired
from sentry.apidocs.hooks import HTTP_METHOD_NAME
from sentry.auth import access
from sentry.auth.staff import has_staff_option
from sentry.models.environment import Environment
from sentry.ratelimits.config import DEFAULT_RATE_LIMIT_CONFIG, RateLimitConfig
from sentry.silo.base import SiloLimit, SiloMode
from sentry.types.ratelimit import RateLimit, RateLimitCategory
from sentry.utils.audit import create_audit_entry
from sentry.utils.cursors import Cursor
from sentry.utils.dates import to_datetime
from sentry.utils.http import (
absolute_uri,
is_using_customer_domain,
is_valid_origin,
origin_from_request,
)
from sentry.utils.sdk import capture_exception, merge_context_into_scope
from ..services.hybrid_cloud import rpcmetrics
from ..utils.pagination_factory import (
annotate_span_with_pagination_args,
clamp_pagination_per_page,
get_cursor,
get_paginator,
)
from .authentication import (
ApiKeyAuthentication,
OrgAuthTokenAuthentication,
UserAuthTokenAuthentication,
update_token_access_record,
)
from .paginator import BadPaginationError, MissingPaginationError, Paginator
from .permissions import (
NoPermission,
StaffPermission,
SuperuserOrStaffFeatureFlaggedPermission,
SuperuserPermission,
)
from .utils import generate_organization_url
__all__ = [
"Endpoint",
"EnvironmentMixin",
"StatsMixin",
"control_silo_endpoint",
"region_silo_endpoint",
]
PAGINATION_DEFAULT_PER_PAGE = 100
ONE_MINUTE = 60
ONE_HOUR = ONE_MINUTE * 60
ONE_DAY = ONE_HOUR * 24
CURSOR_LINK_HEADER = (
'<{uri}&cursor={cursor}>; rel="{name}"; results="{has_results}"; cursor="{cursor}"'
)
DEFAULT_AUTHENTICATION = (
UserAuthTokenAuthentication,
OrgAuthTokenAuthentication,
ApiKeyAuthentication,
SessionAuthentication,
)
logger = logging.getLogger(__name__)
audit_logger = logging.getLogger("sentry.audit.api")
api_access_logger = logging.getLogger("sentry.access.api")
def allow_cors_options(func):
"""
Decorator that adds automatic handling of OPTIONS requests for CORS
If the request is OPTIONS (i.e. pre flight CORS) construct a OK (200) response
in which we explicitly enable the caller and add the custom headers that we support
For other requests just add the appropriate CORS headers
:param func: the original request handler
:return: a request handler that shortcuts OPTIONS requests and just returns an OK (CORS allowed)
"""
@functools.wraps(func)
def allow_cors_options_wrapper(self, request: Request, *args, **kwargs):
if request.method == "OPTIONS":
response = HttpResponse(status=200)
response["Access-Control-Max-Age"] = "3600" # don't ask for options again for 1 hour
else:
response = func(self, request, *args, **kwargs)
return apply_cors_headers(
request=request, response=response, allowed_methods=self._allowed_methods()
)
return allow_cors_options_wrapper
def apply_cors_headers(
request: HttpRequest, response: HttpResponse, allowed_methods: list[str] | None = None
) -> HttpResponse:
if allowed_methods is None:
allowed_methods = []
allow = ", ".join(allowed_methods)
response["Allow"] = allow
response["Access-Control-Allow-Methods"] = allow
response["Access-Control-Allow-Headers"] = (
"X-Sentry-Auth, X-Requested-With, Origin, Accept, "
"Content-Type, Authentication, Authorization, Content-Encoding, "
"sentry-trace, baggage, X-CSRFToken"
)
response["Access-Control-Expose-Headers"] = (
"X-Sentry-Error, X-Sentry-Direct-Hit, X-Hits, X-Max-Hits, " "Endpoint, Retry-After, Link"
)
if request.META.get("HTTP_ORIGIN") == "null":
# if ORIGIN header is explicitly specified as 'null' leave it alone
origin: str | None = "null"
else:
origin = origin_from_request(request)
if origin is None or origin == "null":
response["Access-Control-Allow-Origin"] = "*"
else:
response["Access-Control-Allow-Origin"] = origin
# If the requesting origin is a subdomain of
# the application's base-hostname we should allow cookies
# to be sent.
basehost = options.get("system.base-hostname")
if basehost and origin:
if (
origin.endswith(("://" + basehost, "." + basehost))
or origin in settings.ALLOWED_CREDENTIAL_ORIGINS
):
response["Access-Control-Allow-Credentials"] = "true"
return response
class BaseEndpointMixin(abc.ABC):
"""
Inherit from this class when adding mixin classes that call `Endpoint` methods. This allows typing to
work correctly
"""
@abc.abstractmethod
def create_audit_entry(self, request: Request, transaction_id=None, **kwargs):
pass
@abc.abstractmethod
def respond(self, context: object | None = None, **kwargs: Any) -> Response:
pass
@abc.abstractmethod
def paginate(
self,
request,
on_results=None,
paginator=None,
paginator_cls=Paginator,
default_per_page: int | None = None,
max_per_page: int | None = None,
cursor_cls=Cursor,
response_cls=Response,
response_kwargs=None,
count_hits=None,
**paginator_kwargs,
):
pass
class Endpoint(APIView):
# Note: the available renderer and parser classes can be found in conf/server.py.
authentication_classes: tuple[type[BaseAuthentication], ...] = DEFAULT_AUTHENTICATION
permission_classes: tuple[type[BasePermission], ...] = (NoPermission,)
cursor_name = "cursor"
owner: ApiOwner = ApiOwner.UNOWNED
publish_status: dict[HTTP_METHOD_NAME, ApiPublishStatus] = {}
rate_limits: RateLimitConfig | dict[
str, dict[RateLimitCategory, RateLimit]
] = DEFAULT_RATE_LIMIT_CONFIG
enforce_rate_limit: bool = settings.SENTRY_RATELIMITER_ENABLED
def build_link_header(self, request: Request, path: str, rel: str):
# TODO(dcramer): it would be nice to expand this to support params to consolidate `build_cursor_link`
uri = request.build_absolute_uri(urlquote(path))
return f'<{uri}>; rel="{rel}">'
def build_cursor_link(self, request: Request, name: str, cursor: Cursor):
if request.GET.get("cursor") is None:
querystring = request.GET.urlencode()
else:
mutable_query_dict = request.GET.copy()
mutable_query_dict.pop("cursor")
querystring = mutable_query_dict.urlencode()
url_prefix = (
generate_organization_url(request.subdomain)
if is_using_customer_domain(request)
else None
)
base_url = absolute_uri(urlquote(request.path), url_prefix=url_prefix)
if querystring:
base_url = f"{base_url}?{querystring}"
else:
base_url = f"{base_url}?"
return CURSOR_LINK_HEADER.format(
uri=base_url,
cursor=str(cursor),
name=name,
has_results="true" if bool(cursor) else "false",
)
def convert_args(self, request: Request, *args, **kwargs):
return (args, kwargs)
def permission_denied(self, request, message=None, code=None):
"""
Raise a specific superuser exception if the user can become superuser
and the only permission class is SuperuserPermission. Otherwise, raises
the appropriate exception according to parent DRF function.
"""
permissions = self.get_permissions()
if request.user.is_authenticated and len(permissions) == 1:
permission_cls = permissions[0]
enforce_staff_permission = has_staff_option(request.user)
# TODO(schew2381): Remove SuperuserOrStaffFeatureFlaggedPermission
# from isinstance checks once feature flag is removed.
if enforce_staff_permission:
is_staff_user = request.user.is_staff
has_only_staff_permission = isinstance(
permission_cls, (StaffPermission, SuperuserOrStaffFeatureFlaggedPermission)
)
if is_staff_user and has_only_staff_permission:
raise StaffRequired()
else:
is_superuser_user = request.user.is_superuser
has_only_superuser_permission = isinstance(
permission_cls, (SuperuserPermission, SuperuserOrStaffFeatureFlaggedPermission)
)
if is_superuser_user and has_only_superuser_permission:
raise SuperuserRequired()
super().permission_denied(request, message, code)
def handle_exception( # type: ignore[override]
self,
request: Request,
exc: Exception,
handler_context: Mapping[str, Any] | None = None,
scope: Scope | None = None,
) -> Response:
"""
Handle exceptions which arise while processing incoming API requests.
:param request: The incoming request.
:param exc: The exception raised during handling.
:param handler_context: (Optional) Extra data which will be attached to the event sent to
Sentry, under the "Request Handler Data" heading.
:param scope: (Optional) A `Scope` object containing extra data which will be
attached to the event sent to Sentry.
:returns: A 500 response including the event id of the captured Sentry event.
"""
try:
# Django REST Framework's built-in exception handler. If `settings.EXCEPTION_HANDLER`
# exists and returns a response, that's used. Otherwise, `exc` is just re-raised
# and caught below.
response = super().handle_exception(exc)
except Exception as err:
import sys
import traceback
sys.stderr.write(traceback.format_exc())
scope = scope or sentry_sdk.Scope()
if handler_context:
merge_context_into_scope("Request Handler Data", handler_context, scope)
event_id = capture_exception(err, scope=scope)
response_body = {"detail": "Internal Error", "errorId": event_id}
response = Response(response_body, status=500)
response.exception = True
return response
def create_audit_entry(self, request: Request, transaction_id=None, **kwargs):
return create_audit_entry(request, transaction_id, audit_logger, **kwargs)
def load_json_body(self, request: Request):
"""
Attempts to load the request body when it's JSON.
The end result is ``request.json_body`` having a value. When it can't
load the body as JSON, for any reason, ``request.json_body`` is None.
The request flow is unaffected and no exceptions are ever raised.
"""
request.json_body = None
if not request.META.get("CONTENT_TYPE", "").startswith("application/json"):
return
if not len(request.body):
return
try:
request.json_body = orjson.loads(request.body)
except orjson.JSONDecodeError:
return
def initialize_request(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Request:
# XXX: Since DRF 3.x, when the request is passed into
# `initialize_request` it's set as an internal variable on the returned
# request. Then when we call `rv.auth` it attempts to authenticate,
# fails and sets `user` and `auth` to None on the internal request. We
# keep track of these here and reassign them as needed.
orig_auth = getattr(request, "auth", None)
orig_user = getattr(request, "user", None)
rv = super().initialize_request(request, *args, **kwargs)
# If our request is being made via our internal API client, we need to
# stitch back on auth and user information
if getattr(request, "__from_api_client__", False):
if rv.auth is None:
rv.auth = orig_auth
if rv.user is None:
rv.user = orig_user
return rv
def has_pagination(self, response: Response) -> bool:
# If response is paginated, it will have a "Link" header
return response.headers.get("Link") is not None
@csrf_exempt
@allow_cors_options
def dispatch(self, request: Request, *args, **kwargs) -> Response:
"""
Identical to rest framework's dispatch except we add the ability
to convert arguments (for common URL params).
"""
with sentry_sdk.start_span(op="base.dispatch.setup", description=type(self).__name__):
self.args = args
self.kwargs = kwargs
request = self.initialize_request(request, *args, **kwargs)
self.load_json_body(request)
self.request = request
self.headers = self.default_response_headers # deprecate?
# Tags that will ultimately flow into the metrics backend at the end of
# the request (happens via middleware/stats.py).
request._metric_tags = {}
start_time = time.time()
origin = request.META.get("HTTP_ORIGIN", "null")
# A "null" value should be treated as no Origin for us.
# See RFC6454 for more information on this behavior.
if origin == "null":
origin = None
try:
with sentry_sdk.start_span(op="base.dispatch.request", description=type(self).__name__):
if origin:
if request.auth:
allowed_origins = request.auth.get_allowed_origins()
else:
allowed_origins = None
if not is_valid_origin(origin, allowed=allowed_origins):
response = Response(f"Invalid origin: {origin}", status=400)
self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response
if request.auth:
update_token_access_record(request.auth)
self.initial(request, *args, **kwargs)
# Get the appropriate handler method
method = request.method.lower()
if method in self.http_method_names and hasattr(self, method):
handler = getattr(self, method)
# Only convert args when using defined handlers
(args, kwargs) = self.convert_args(request, *args, **kwargs)
self.args = args
self.kwargs = kwargs
else:
handler = self.http_method_not_allowed
if getattr(request, "access", None) is None:
# setup default access
request.access = access.from_request(request)
with sentry_sdk.start_span(
op="base.dispatch.execute",
description=".".join(
getattr(part, "__name__", None) or str(part) for part in (type(self), handler)
),
) as span:
with rpcmetrics.wrap_sdk_span(span):
response = handler(request, *args, **kwargs)
except Exception as exc:
response = self.handle_exception(request, exc)
if origin:
self.add_cors_headers(request, response)
self.response = self.finalize_response(request, response, *args, **kwargs)
if settings.SENTRY_API_RESPONSE_DELAY:
duration = time.time() - start_time
if duration < (settings.SENTRY_API_RESPONSE_DELAY / 1000.0):
with sentry_sdk.start_span(
op="base.dispatch.sleep",
description=type(self).__name__,
) as span:
span.set_data("SENTRY_API_RESPONSE_DELAY", settings.SENTRY_API_RESPONSE_DELAY)
time.sleep(settings.SENTRY_API_RESPONSE_DELAY / 1000.0 - duration)
# Only enforced in dev environment
if settings.ENFORCE_PAGINATION:
if request.method.lower() == "get":
status = getattr(self.response, "status_code", 0)
# Response can either be Response or HttpResponse, check if
# it's value is an array and that it was an OK response
if (
200 <= status < 300
and hasattr(self.response, "data")
and isinstance(self.response.data, list)
):
# if not paginated and not in settings.SENTRY_API_PAGINATION_ALLOWLIST, raise error
if (
handler.__self__.__class__.__name__
not in settings.SENTRY_API_PAGINATION_ALLOWLIST
and not self.has_pagination(self.response)
):
raise MissingPaginationError(handler.__func__.__qualname__)
return self.response
def add_cors_headers(self, request: Request, response):
response["Access-Control-Allow-Origin"] = request.META["HTTP_ORIGIN"]
response["Access-Control-Allow-Methods"] = ", ".join(self.http_method_names)
def add_cursor_headers(self, request: Request, response, cursor_result):
if cursor_result.hits is not None:
response["X-Hits"] = cursor_result.hits
if cursor_result.max_hits is not None:
response["X-Max-Hits"] = cursor_result.max_hits
response["Link"] = ", ".join(
[
self.build_cursor_link(request, "previous", cursor_result.prev),
self.build_cursor_link(request, "next", cursor_result.next),
]
)
def respond(self, context: object | None = None, **kwargs: Any) -> Response:
return Response(context, **kwargs)
def respond_with_text(self, text):
return self.respond({"text": text})
def get_per_page(
self, request: Request, default_per_page: int | None = None, max_per_page: int | None = None
):
default_per_page = default_per_page or PAGINATION_DEFAULT_PER_PAGE
max_per_page = max_per_page or 100
try:
return clamp_pagination_per_page(
request.GET.get("per_page", default_per_page),
default_per_page=default_per_page,
max_per_page=max_per_page,
)
except ValueError as e:
raise ParseError(detail=str(e))
def get_cursor_from_request(self, request: Request, cursor_cls=Cursor):
try:
return get_cursor(request.GET.get(self.cursor_name), cursor_cls)
except ValueError as e:
raise ParseError(detail=str(e))
def paginate(
self,
request,
on_results=None,
paginator=None,
paginator_cls=Paginator,
default_per_page: int | None = None,
max_per_page: int | None = None,
cursor_cls=Cursor,
response_cls=Response,
response_kwargs=None,
count_hits=None,
**paginator_kwargs,
):
try:
per_page = self.get_per_page(request, default_per_page, max_per_page)
cursor = self.get_cursor_from_request(request, cursor_cls)
with sentry_sdk.start_span(
op="base.paginate.get_result",
description=type(self).__name__,
) as span:
annotate_span_with_pagination_args(span, per_page)
paginator = get_paginator(paginator, paginator_cls, paginator_kwargs)
result_args = dict(count_hits=count_hits) if count_hits is not None else dict()
cursor_result = paginator.get_result(
limit=per_page,
cursor=cursor,
**result_args,
)
except BadPaginationError as e:
raise ParseError(detail=str(e))
if response_kwargs is None:
response_kwargs = {}
# map results based on callback
if on_results:
with sentry_sdk.start_span(
op="base.paginate.on_results",
description=type(self).__name__,
):
results = on_results(cursor_result.results)
else:
results = cursor_result.results
response = response_cls(results, **response_kwargs)
self.add_cursor_headers(request, response, cursor_result)
return response
class EnvironmentMixin:
def _get_environment_func(self, request: Request, organization_id):
"""\
Creates a function that when called returns the ``Environment``
associated with a request object, or ``None`` if no environment was
provided. If the environment doesn't exist, an ``Environment.DoesNotExist``
exception will be raised.
This returns as a callable since some objects outside of the API
endpoint need to handle the "environment was provided but does not
exist" state in addition to the two non-exceptional states (the
environment was provided and exists, or the environment was not
provided.)
"""
return functools.partial(self._get_environment_from_request, request, organization_id)
def _get_environment_id_from_request(self, request: Request, organization_id):
environment = self._get_environment_from_request(request, organization_id)
return environment and environment.id
def _get_environment_from_request(self, request: Request, organization_id):
if not hasattr(request, "_cached_environment"):
environment_param = request.GET.get("environment")
if environment_param is None:
environment = None
else:
environment = Environment.get_for_organization_id(
name=environment_param, organization_id=organization_id
)
request._cached_environment = environment
return request._cached_environment
class StatsMixin:
def _parse_args(self, request: Request, environment_id=None, restrict_rollups=True):
"""
Parse common stats parameters from the query string. This includes
`since`, `until`, and `resolution`.
:param boolean restrict_rollups: When False allows any rollup value to
be specified. Be careful using this as this allows for fine grain
rollups that may put strain on the system.
"""
try:
resolution = request.GET.get("resolution")
if resolution:
resolution = self._parse_resolution(resolution)
if restrict_rollups and resolution not in tsdb.backend.get_rollups():
raise ValueError
except ValueError:
raise ParseError(detail="Invalid resolution")
try:
end_s = request.GET.get("until")
if end_s:
end = to_datetime(float(end_s))
else:
end = datetime.now(timezone.utc)
except ValueError:
raise ParseError(detail="until must be a numeric timestamp.")
try:
start_s = request.GET.get("since")
if start_s:
start = to_datetime(float(start_s))
assert start <= end
else:
start = end - timedelta(days=1, seconds=-1)
except ValueError:
raise ParseError(detail="since must be a numeric timestamp")
except AssertionError:
raise ParseError(detail="start must be before or equal to end")
if not resolution:
resolution = tsdb.backend.get_optimal_rollup(start, end)
return {
"start": start,
"end": end,
"rollup": resolution,
"environment_ids": environment_id and [environment_id],
}
def _parse_resolution(self, value):
if value.endswith("h"):
return int(value[:-1]) * ONE_HOUR
elif value.endswith("d"):
return int(value[:-1]) * ONE_DAY
elif value.endswith("m"):
return int(value[:-1]) * ONE_MINUTE
elif value.endswith("s"):
return int(value[:-1])
else:
raise ValueError(value)
class ReleaseAnalyticsMixin:
def track_set_commits_local(self, request: Request, organization_id=None, project_ids=None):
analytics.record(
"release.set_commits_local",
user_id=request.user.id if request.user and request.user.id else None,
organization_id=organization_id,
project_ids=project_ids,
user_agent=request.META.get("HTTP_USER_AGENT", ""),
)
class EndpointSiloLimit(SiloLimit):
def modify_endpoint_class(self, decorated_class: type[Endpoint]) -> type:
dispatch_override = self.create_override(decorated_class.dispatch)
new_class = type(
decorated_class.__name__,
(decorated_class,),
{
"dispatch": dispatch_override,
"silo_limit": self,
},
)
new_class.__module__ = decorated_class.__module__
return new_class
def modify_endpoint_method(self, decorated_method: Callable[..., Any]) -> Callable[..., Any]:
return self.create_override(decorated_method)
def handle_when_unavailable(
self,
original_method: Callable[..., Any],
current_mode: SiloMode,
available_modes: Iterable[SiloMode],
) -> Callable[..., Any]:
def handle(obj: Any, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
mode_str = ", ".join(str(m) for m in available_modes)
message = (
f"Received {request.method} request at {request.path!r} to server in "
f"{current_mode} mode. This endpoint is available only in: {mode_str}"
)
if settings.FAIL_ON_UNAVAILABLE_API_CALL:
raise self.AvailabilityError(message)
else:
logger.warning(message)
return HttpResponse(status=status.HTTP_404_NOT_FOUND)
return handle
def __call__(self, decorated_obj: Any) -> Any:
if isinstance(decorated_obj, type):
if not issubclass(decorated_obj, Endpoint):
raise ValueError("`@EndpointSiloLimit` can decorate only Endpoint subclasses")
return self.modify_endpoint_class(decorated_obj)
if callable(decorated_obj):
return self.modify_endpoint_method(decorated_obj)
raise TypeError("`@EndpointSiloLimit` must decorate a class or method")
control_silo_endpoint = EndpointSiloLimit(SiloMode.CONTROL)
"""
Apply to endpoints that exist in CONTROL silo.
If a request is received and the application is not in CONTROL
mode 404s will be returned.
"""
region_silo_endpoint = EndpointSiloLimit(SiloMode.REGION)
"""
Apply to endpoints that exist in REGION silo.
If a request is received and the application is not in REGION
mode 404s will be returned.
"""
all_silo_endpoint = EndpointSiloLimit(SiloMode.CONTROL, SiloMode.REGION, SiloMode.MONOLITH)
"""
Apply to endpoints that are available in all silo modes.
This should be rarely used, but is relevant for resources like ROBOTS.txt.
"""