diff --git a/src/sentry/hybridcloud/apigateway/proxy.py b/src/sentry/hybridcloud/apigateway/proxy.py
index a33c2fe8aafbec..d780ab9355a8de 100644
--- a/src/sentry/hybridcloud/apigateway/proxy.py
+++ b/src/sentry/hybridcloud/apigateway/proxy.py
@@ -5,7 +5,7 @@
 from __future__ import annotations
 
 import logging
-from collections.abc import Generator, Iterator
+from collections.abc import Generator
 from urllib.parse import urljoin, urlparse
 from wsgiref.util import is_hop_by_hop
 
@@ -19,6 +19,7 @@
 from sentry import options
 from sentry.api.exceptions import RequestTimeout
 from sentry.models.organizationmapping import OrganizationMapping
+from sentry.objectstore.endpoints.organization import ChunkedEncodingDecoder, get_raw_body
 from sentry.sentry_apps.models.sentry_app import SentryApp
 from sentry.sentry_apps.models.sentry_app_installation import SentryAppInstallation
 from sentry.silo.util import (
@@ -34,6 +35,7 @@
     get_region_for_organization,
 )
 from sentry.utils import metrics
+from sentry.utils.http import BodyWithLength
 
 logger = logging.getLogger(__name__)
 
@@ -77,22 +79,6 @@ def stream_response() -> Generator[bytes]:
     return streamed_response
 
 
-class _body_with_length:
-    """Wraps an HttpRequest with a __len__ so that the request library does not assume length=0 in all cases"""
-
-    def __init__(self, request: HttpRequest):
-        self.request = request
-
-    def __iter__(self) -> Iterator[bytes]:
-        return iter(self.request)
-
-    def __len__(self) -> int:
-        return int(self.request.headers.get("Content-Length", "0"))
-
-    def read(self, size: int | None = None) -> bytes:
-        return self.request.read(size)
-
-
 def proxy_request(request: HttpRequest, org_id_or_slug: str, url_name: str) -> HttpResponseBase:
     """Take a django request object and proxy it to a remote location given an org_id_or_slug"""
 
@@ -205,6 +191,12 @@ def proxy_region_request(
     if settings.APIGATEWAY_PROXY_SKIP_RELAY and request.path.startswith("/api/0/relays/"):
         return StreamingHttpResponse(streaming_content="relay proxy skipped", status=404)
 
+    data: Generator[bytes] | ChunkedEncodingDecoder | BodyWithLength | None = None
+    if url_name == "sentry-api-0-organization-objectstore":
+        data = get_raw_body(request)
+    else:
+        data = BodyWithLength(request)
+
     try:
         with metrics.timer("apigateway.proxy_request.duration", tags=metric_tags):
             resp = external_request(
@@ -212,7 +204,7 @@
                 url=target_url,
                 headers=header_dict,
                 params=dict(query_params) if query_params is not None else None,
-                data=_body_with_length(request),
+                data=data,
                 stream=True,
                 timeout=timeout,
                 # By default, external_request will resolve any redirects for any verb except for HEAD.
diff --git a/src/sentry/objectstore/endpoints/organization.py b/src/sentry/objectstore/endpoints/organization.py
index da15522dee97ab..3d78bce0a62094 100644
--- a/src/sentry/objectstore/endpoints/organization.py
+++ b/src/sentry/objectstore/endpoints/organization.py
@@ -1,15 +1,18 @@
-import logging
+from __future__ import annotations
+
 from collections.abc import Callable, Generator
 from typing import Any
 from urllib.parse import urlparse
 from wsgiref.util import is_hop_by_hop
 
 import requests
-from django.http import StreamingHttpResponse
+from django.http import HttpRequest, StreamingHttpResponse
 from requests import Response as ExternalResponse
 from rest_framework.request import Request
 from rest_framework.response import Response
 
+from sentry.utils.http import BodyWithLength
+
 # TODO(granian): Remove this and related code paths when we fully switch from uwsgi to granian
 uwsgi: Any = None
 try:
@@ -24,8 +27,6 @@
 from sentry.api.bases import OrganizationEndpoint
 from sentry.models.organization import Organization
 
-logger = logging.getLogger(__name__)
-
 
 @region_silo_endpoint
 class OrganizationObjectstoreEndpoint(OrganizationEndpoint):
@@ -85,60 +86,15 @@ def _proxy(
         target_url = get_target_url(path)
 
         headers = dict(request.headers)
-        headers.pop("Host", None)
         headers.pop("Content-Length", None)
-        transfer_encoding = headers.pop("Transfer-Encoding", "")
-
-        stream: Generator[bytes] | ChunkedEncodingDecoder | None = None
-        wsgi_input = request.META.get("wsgi.input")
-
-        logger.info(
-            "objectstore proxy request",
-            extra={
-                "method": request.method,
-                "path": path,
-                "request_id": request.META.get("HTTP_X_REQUEST_ID"),
-                "content_type": request.META.get("CONTENT_TYPE"),
-                "content_length": request.META.get("CONTENT_LENGTH"),
-                "transfer_encoding": transfer_encoding,
-                "server_software": request.META.get("SERVER_SOFTWARE"),
-                "has_wsgi_input": wsgi_input is not None,
-                "x_forwarded_for": request.META.get("HTTP_X_FORWARDED_FOR"),
-                "x_forwarded_proto": request.META.get("HTTP_X_FORWARDED_PROTO"),
-                "x_forwarded_host": request.META.get("HTTP_X_FORWARDED_HOST"),
-            },
-        )
-
-        if "granian" in request.META.get("SERVER_SOFTWARE", "").lower():
-            stream = wsgi_input
-        # uwsgi and wsgiref will respectively raise an exception and hang when attempting to read wsgi.input while there's no body.
-        # For now, support bodies only on PUT and POST requests.
-        elif request.method in ("PUT", "POST"):
-            if uwsgi:
-                if transfer_encoding.lower() == "chunked":
-
-                    def stream_generator():
-                        while True:
-                            chunk = uwsgi.chunked_read()
-                            if not chunk:
-                                break
-                            yield chunk
-
-                    stream = stream_generator()
-                else:
-                    stream = wsgi_input
-
-            else:
-                # This code path assumes wsgiref, used in dev/test mode.
-                # Note that we don't handle non-chunked transfer encoding here as our client (which we use for tests) always uses chunked encoding.
-                stream = ChunkedEncodingDecoder(wsgi_input._read)  # type: ignore[union-attr]
+        headers.pop("Transfer-Encoding", None)
 
         response = requests.request(
             request.method,
             url=target_url,
             headers=headers,
-            data=stream,
+            data=get_raw_body(request._request),
             params=dict(request.GET) if request.GET else None,
             stream=True,
             allow_redirects=False,
@@ -146,6 +102,43 @@ def stream_generator():
         return stream_response(response)
 
 
+def get_raw_body(
+    request: HttpRequest,
+) -> Generator[bytes] | ChunkedEncodingDecoder | BodyWithLength | None:
+    wsgi_input = request.META.get("wsgi.input")
+    if "granian" in request.META.get("SERVER_SOFTWARE", "").lower():
+        return wsgi_input
+
+    # uwsgi and wsgiref will respectively raise an exception and hang when attempting to read wsgi.input while there's no body.
+    # For now, support bodies only on PUT and POST requests when not using Granian.
+    if request.method not in ("PUT", "POST"):
+        return None
+
+    if uwsgi:
+        if request.headers.get("Transfer-Encoding", "").lower() == "chunked":
+
+            def stream_generator():
+                while True:
+                    chunk = uwsgi.chunked_read()
+                    if not chunk:
+                        break
+                    yield chunk
+
+            return stream_generator()
+
+        return wsgi_input
+
+    # wsgiref (dev/test server)
+    if (
+        hasattr(wsgi_input, "_read")
+        and request.headers.get("Transfer-Encoding", "").lower() == "chunked"
+    ):
+        return ChunkedEncodingDecoder(wsgi_input._read)  # type: ignore[union-attr]
+
+    # wsgiref and the request has been already proxied through control silo
+    return BodyWithLength(request)
+
+
 def get_target_url(path: str) -> str:
     base = options.get("objectstore.config")["base_url"].rstrip("/")
     # `path` should be a relative path, only grab that part
diff --git a/src/sentry/utils/http.py b/src/sentry/utils/http.py
index 3c3f425fe85641..8cf3f48dca9159 100644
--- a/src/sentry/utils/http.py
+++ b/src/sentry/utils/http.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from collections.abc import Collection
+from collections.abc import Collection, Iterator
 from typing import TYPE_CHECKING, NamedTuple, TypeGuard, overload
 from urllib.parse import quote, urljoin, urlparse
 
@@ -224,3 +224,19 @@ class _HttpRequestWithSubdomain(HttpRequest):
 
 def is_using_customer_domain(request: HttpRequest) -> TypeGuard[_HttpRequestWithSubdomain]:
     return bool(hasattr(request, "subdomain") and request.subdomain)
+
+
+class BodyWithLength:
+    """Wraps an HttpRequest with a __len__ so that the requests library does not assume length=0 in all cases"""
+
+    def __init__(self, request: HttpRequest):
+        self.request = request
+
+    def __iter__(self) -> Iterator[bytes]:
+        return iter(self.request)
+
+    def __len__(self) -> int:
+        return int(self.request.headers.get("Content-Length", "0"))
+
+    def read(self, size: int | None = None) -> bytes:
+        return self.request.read(size)
diff --git a/tests/sentry/objectstore/endpoints/test_organization.py b/tests/sentry/objectstore/endpoints/test_organization.py
index 403b5dfb8813b8..bdbdb237e6db5d 100644
--- a/tests/sentry/objectstore/endpoints/test_organization.py
+++ b/tests/sentry/objectstore/endpoints/test_organization.py
@@ -1,13 +1,21 @@
+from dataclasses import asdict
+
 import pytest
 import requests
+from django.db import connections
 from django.urls import reverse
 from objectstore_client import Client, RequestError, Session, Usecase
 from pytest_django.live_server_helper import LiveServer
 
+from sentry.silo.base import SiloMode, SingleProcessSiloModeState
+from sentry.testutils.asserts import assert_status_code
 from sentry.testutils.cases import TransactionTestCase
 from sentry.testutils.helpers.features import with_feature
-from sentry.testutils.silo import region_silo_test
+from sentry.testutils.region import override_regions
+from sentry.testutils.silo import create_test_regions, region_silo_test
 from sentry.testutils.skips import requires_objectstore
+from sentry.types.region import Region
+from sentry.utils import json
 
 
 @pytest.fixture(scope="function")
@@ -101,3 +109,120 @@ def test_large_payload(self):
 
         retrieved = session.get(object_key)
         assert retrieved.payload.read() == data
+
+
+test_region = create_test_regions("us")[0]
+
+
+@region_silo_test(regions=(test_region,))
+@requires_objectstore
+@with_feature("organizations:objectstore-endpoint")
+@pytest.mark.usefixtures("local_live_server")
+class OrganizationObjectstoreEndpointWithControlSiloTest(TransactionTestCase):
+    endpoint = "sentry-api-0-organization-objectstore"
+    live_server: LiveServer
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.login_as(user=self.user)
+        self.organization = self.create_organization(owner=self.user)
+        self.api_key = self.create_api_key(
+            organization=self.organization,
+            scope_list=["org:admin"],
+        )
+
+    def tearDown(self) -> None:
+        for conn in connections.all():
+            conn.close()
+        super().tearDown()
+
+    def get_endpoint_url(self) -> str:
+        path = reverse(
+            self.endpoint,
+            kwargs={
+                "organization_id_or_slug": self.organization.id,
+                "path": "",
+            },
+        )
+        return path
+
+    def test_health(self):
+        config = asdict(test_region)
+        config["address"] = self.live_server.url
+        with override_regions([Region(**config)]):
+            with SingleProcessSiloModeState.enter(SiloMode.CONTROL):
+                response = self.client.get(
+                    self.get_endpoint_url() + "health",
+                    follow=True,
+                )
+                assert response.status_code == 200
+                # consume body to close connection
+                b"".join(response.streaming_content)  # type: ignore[attr-defined]
+
+    def test_full_cycle(self):
+
+        config = asdict(test_region)
+        config["address"] = self.live_server.url
+        auth_header = self.create_basic_auth_header(self.api_key.key).decode()
+
+        with override_regions([Region(**config)]):
+            with SingleProcessSiloModeState.enter(SiloMode.CONTROL):
+                base_url = f"{self.get_endpoint_url()}v1/objects/test/org={self.organization.id}/"
+
+                response = self.client.post(
+                    base_url,
+                    data=b"test data",
+                    HTTP_AUTHORIZATION=auth_header,
+                    content_type="application/octet-stream",
+                    follow=True,
+                )
+                assert_status_code(response, 201)
+                object_key = json.loads(b"".join(response.streaming_content))["key"]  # type: ignore[attr-defined]
+                assert object_key is not None
+
+                response = self.client.get(
+                    f"{base_url}{object_key}",
+                    HTTP_AUTHORIZATION=auth_header,
+                    follow=True,
+                )
+                assert_status_code(response, 200)
+                retrieved_data = b"".join(response.streaming_content)  # type: ignore[attr-defined]
+                assert retrieved_data == b"test data"
+
+                response = self.client.put(
+                    f"{base_url}{object_key}",
+                    data=b"new data",
+                    content_type="application/octet-stream",
+                    HTTP_AUTHORIZATION=auth_header,
+                    follow=True,
+                )
+                assert_status_code(response, 200)
+                new_key = json.loads(b"".join(response.streaming_content))["key"]  # type: ignore[attr-defined]
+                assert new_key == object_key
+
+                response = self.client.get(
+                    f"{base_url}{object_key}",
+                    HTTP_AUTHORIZATION=auth_header,
+                    follow=True,
+                )
+                assert_status_code(response, 200)
+                retrieved = b"".join(response.streaming_content)  # type: ignore[attr-defined]
+                assert retrieved == b"new data"
+
+                response = self.client.delete(
+                    f"{base_url}{object_key}",
+                    HTTP_AUTHORIZATION=auth_header,
+                    follow=True,
+                )
+                assert_status_code(response, 204)
+                # consume body to close connection
+                b"".join(response.streaming_content)  # type: ignore[attr-defined]
+
+                response = self.client.get(
+                    f"{base_url}{object_key}",
+                    HTTP_AUTHORIZATION=auth_header,
+                    follow=True,
+                )
+                assert_status_code(response, 404)
+                # consume body to close connection
+                b"".join(response.streaming_content)  # type: ignore[attr-defined]
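
For context on the BodyWithLength helper introduced above, below is a minimal, standalone sketch (not part of the patch) of the length-reporting wrapper pattern it relies on when a proxied body is handed to the requests library: requests looks for a usable length on the data object (via __len__, through requests.utils.super_len) to set Content-Length, and a plain file-like body without one is otherwise sent with chunked transfer encoding or, depending on how the object is introspected, as an empty body. The SizedBody class and forward helper are assumed, illustrative names; the real wrapper iterates and reads a Django HttpRequest rather than a plain file object.

    # Illustrative sketch only; SizedBody and forward are assumed names, not Sentry code.
    import io
    from collections.abc import Iterator

    import requests


    class SizedBody:
        """File-like wrapper that reports a known length to `requests`."""

        def __init__(self, fileobj: io.BufferedIOBase, content_length: int) -> None:
            self._fileobj = fileobj
            self._length = content_length

        def __iter__(self) -> Iterator[bytes]:
            # Stream the body in fixed-size chunks until EOF.
            return iter(lambda: self._fileobj.read(8192), b"")

        def __len__(self) -> int:
            # `requests` uses this (via requests.utils.super_len) to set Content-Length.
            return self._length

        def read(self, size: int = -1) -> bytes:
            return self._fileobj.read(size)


    def forward(url: str, body: io.BufferedIOBase, content_length: int) -> requests.Response:
        # Because the wrapper reports its length, the upstream server receives a
        # plain, non-chunked request body with an explicit Content-Length header.
        return requests.put(url, data=SizedBody(body, content_length), stream=True)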