diff --git a/elasticapm/contrib/starlette/__init__.py b/elasticapm/contrib/starlette/__init__.py index 62925c42d..83369d83f 100644 --- a/elasticapm/contrib/starlette/__init__.py +++ b/elasticapm/contrib/starlette/__init__.py @@ -35,6 +35,7 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import Response +from starlette.routing import Match from starlette.types import ASGIApp import elasticapm @@ -178,7 +179,8 @@ async def _request_started(self, request: Request): self.client.begin_transaction("request", trace_parent=trace_parent) await set_context(lambda: get_data_from_request(request, self.client.config, constants.TRANSACTION), "request") - elasticapm.set_transaction_name("{} {}".format(request.method, request.url.path), override=False) + transaction_name = self.get_route_name(request) or request.url.path + elasticapm.set_transaction_name("{} {}".format(request.method, transaction_name), override=False) async def _request_finished(self, response: Response): """Captures the end of the request processing to APM. @@ -192,3 +194,15 @@ async def _request_finished(self, response: Response): result = "HTTP {}xx".format(response.status_code // 100) elasticapm.set_transaction_result(result, override=False) + + def get_route_name(self, request: Request) -> str: + path = None + routes = request.scope["app"].routes + for route in routes: + match, _ = route.matches(request.scope) + if match == Match.FULL: + path = route.path + break + elif match == Match.PARTIAL and path is None: + path = route.path + return path diff --git a/tests/contrib/asyncio/starlette_tests.py b/tests/contrib/asyncio/starlette_tests.py index 6e6907a6c..160b93674 100644 --- a/tests/contrib/asyncio/starlette_tests.py +++ b/tests/contrib/asyncio/starlette_tests.py @@ -55,6 +55,11 @@ async def hi(request): pass return PlainTextResponse("ok") + @app.route("/hi/{name}", methods=["GET"]) + async def hi_name(request): + name = request.path_params["name"] + return PlainTextResponse("Hi {}".format(name)) + @app.route("/raise-exception", methods=["GET", "POST"]) async def raise_exception(request): raise ValueError() @@ -210,3 +215,23 @@ def test_capture_headers_body_is_dynamic(app, elasticapm_client): assert elasticapm_client.events[constants.TRANSACTION][2]["context"]["request"]["body"] == "[REDACTED]" assert "headers" not in elasticapm_client.events[constants.ERROR][1]["context"]["request"] assert elasticapm_client.events[constants.ERROR][1]["context"]["request"]["body"] == "[REDACTED]" + + +def test_transaction_name_is_route(app, elasticapm_client): + client = TestClient(app) + + response = client.get( + "/hi/shay", + headers={ + constants.TRACEPARENT_HEADER_NAME: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-03", + constants.TRACESTATE_HEADER_NAME: "foo=bar,bar=baz", + "REMOTE_ADDR": "127.0.0.1", + }, + ) + + assert response.status_code == 200 + + assert len(elasticapm_client.events[constants.TRANSACTION]) == 1 + transaction = elasticapm_client.events[constants.TRANSACTION][0] + assert transaction["name"] == "GET /hi/{name}" + assert transaction["context"]["request"]["url"]["pathname"] == "/hi/shay"