Skip to content

Commit

Permalink
feat: configure header extraction for ASGI middleware via constructor…
Browse files Browse the repository at this point in the history
… params
  • Loading branch information
adriangb committed Oct 31, 2023
1 parent 3478831 commit 87a619e
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 96 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `opentelemetry-instrumentation-system-metrics` Add support for collecting process metrics
([#1948](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1948))
- Add support for configuring ASGI middleware header extraction via runtime constructor parameters
([#2026](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2026))

### Fixed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,13 @@ def client_response_hook(span: Span, message: dict):
---
"""

from __future__ import annotations

import typing
import urllib
from functools import wraps
from timeit import default_timer
from typing import Tuple
from typing import Any, Awaitable, Callable, Tuple

from asgiref.compatibility import guarantee_single_callable

Expand Down Expand Up @@ -332,55 +334,23 @@ def collect_request_attributes(scope):
return result


def collect_custom_request_headers_attributes(scope):
"""returns custom HTTP request headers to be added into SERVER span as span attributes
Refer specification https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#http-request-and-response-headers
def collect_custom_headers_attributes(scope_or_response_message: dict[str, Any], sanitize: SanitizeValue, header_regexes: list[str], normalize_names: Callable[[str], str]) -> dict[str, str]:
"""
Returns custom HTTP request or response headers to be added into SERVER span as span attributes.
sanitize = SanitizeValue(
get_custom_headers(
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS
)
)

# Decode headers before processing.
headers = {
_key.decode("utf8"): _value.decode("utf8")
for (_key, _value) in scope.get("headers")
}

return sanitize.sanitize_header_values(
headers,
get_custom_headers(
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST
),
normalise_request_header_name,
)


def collect_custom_response_headers_attributes(message):
"""returns custom HTTP response headers to be added into SERVER span as span attributes
Refer specification https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#http-request-and-response-headers
Refer specifications:
- https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#http-request-and-response-headers
- https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#http-request-and-response-headers
"""

sanitize = SanitizeValue(
get_custom_headers(
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS
)
)

# Decode headers before processing.
headers = {
headers: dict[str, str] = {
_key.decode("utf8"): _value.decode("utf8")
for (_key, _value) in message.get("headers")
for (_key, _value) in scope_or_response_message.get("headers") or {}
}

return sanitize.sanitize_header_values(
headers,
get_custom_headers(
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE
),
normalise_response_header_name,
header_regexes,
normalize_names,
)


Expand Down Expand Up @@ -493,6 +463,9 @@ def __init__(
tracer_provider=None,
meter_provider=None,
meter=None,
http_capture_headers_server_request : list[str] | None = None,
http_capture_headers_server_response: list[str] | None = None,
http_capture_headers_sanitize_fields: list[str] | None = None
):
self.app = guarantee_single_callable(app)
self.tracer = trace.get_tracer(__name__, __version__, tracer_provider)
Expand Down Expand Up @@ -530,7 +503,20 @@ def __init__(
self.client_response_hook = client_response_hook
self.content_length_header = None

async def __call__(self, scope, receive, send):
# Environment variables as constructor parameters
self.http_capture_headers_server_request = http_capture_headers_server_request or (
get_custom_headers(OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST)
) or None
self.http_capture_headers_server_response = http_capture_headers_server_response or (
get_custom_headers(OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE)
) or None
self.http_capture_headers_sanitize_fields = SanitizeValue(
http_capture_headers_sanitize_fields or (
get_custom_headers(OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS)
) or []
)

async def __call__(self, scope: dict[str, Any], receive: Callable[[], Awaitable[dict[str, Any]]], send: Callable[[dict[str, Any]], Awaitable[None]]) -> None:
"""The ASGI application
Args:
Expand Down Expand Up @@ -573,7 +559,14 @@ async def __call__(self, scope, receive, send):

if current_span.kind == trace.SpanKind.SERVER:
custom_attributes = (
collect_custom_request_headers_attributes(scope)
collect_custom_headers_attributes(
scope,
self.http_capture_headers_sanitize_fields,
self.http_capture_headers_server_request,
normalise_request_header_name,
)
if self.http_capture_headers_server_request
else {}
)
if len(custom_attributes) > 0:
current_span.set_attributes(custom_attributes)
Expand Down Expand Up @@ -644,7 +637,7 @@ def _get_otel_send(
self, server_span, server_span_name, scope, send, duration_attrs
):
@wraps(send)
async def otel_send(message):
async def otel_send(message: dict[str, Any]) -> None:
with self.tracer.start_as_current_span(
" ".join((server_span_name, scope["type"], "send"))
) as send_span:
Expand All @@ -668,7 +661,14 @@ async def otel_send(message):
and "headers" in message
):
custom_response_attributes = (
collect_custom_response_headers_attributes(message)
collect_custom_headers_attributes(
message,
self.http_capture_headers_sanitize_fields,
self.http_capture_headers_server_response,
normalise_response_header_name,
)
if self.http_capture_headers_server_response
else {}
)
if len(custom_response_attributes) > 0:
server_span.set_attributes(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from unittest import mock

import opentelemetry.instrumentation.asgi as otel_asgi
Expand Down Expand Up @@ -72,21 +73,20 @@ async def websocket_app_with_custom_headers(scope, receive, send):
break


@mock.patch.dict(
"os.environ",
{
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: ".*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,my-custom-regex-header-.*,invalid-regex-header-.*,.*my-secret.*",
},
)
class TestCustomHeaders(AsgiTestBase, TestBase):
constructor_params = {}
__test__ = False

def __init_subclass__(cls) -> None:
if cls is not TestCustomHeaders:
cls.__test__ = True

def setUp(self):
super().setUp()
self.tracer_provider, self.exporter = TestBase.create_tracer_provider()
self.tracer = self.tracer_provider.get_tracer(__name__)
self.app = otel_asgi.OpenTelemetryMiddleware(
simple_asgi, tracer_provider=self.tracer_provider
simple_asgi, tracer_provider=self.tracer_provider, **self.constructor_params,
)

def test_http_custom_request_headers_in_span_attributes(self):
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_http_custom_request_headers_not_in_span_attributes(self):

def test_http_custom_response_headers_in_span_attributes(self):
self.app = otel_asgi.OpenTelemetryMiddleware(
http_app_with_custom_headers, tracer_provider=self.tracer_provider
http_app_with_custom_headers, tracer_provider=self.tracer_provider, **self.constructor_params,
)
self.seed_app(self.app)
self.send_default_request()
Expand All @@ -175,7 +175,7 @@ def test_http_custom_response_headers_in_span_attributes(self):

def test_http_custom_response_headers_not_in_span_attributes(self):
self.app = otel_asgi.OpenTelemetryMiddleware(
http_app_with_custom_headers, tracer_provider=self.tracer_provider
http_app_with_custom_headers, tracer_provider=self.tracer_provider, **self.constructor_params,
)
self.seed_app(self.app)
self.send_default_request()
Expand Down Expand Up @@ -277,6 +277,7 @@ def test_websocket_custom_response_headers_in_span_attributes(self):
self.app = otel_asgi.OpenTelemetryMiddleware(
websocket_app_with_custom_headers,
tracer_provider=self.tracer_provider,
**self.constructor_params,
)
self.seed_app(self.app)
self.send_input({"type": "websocket.connect"})
Expand Down Expand Up @@ -317,6 +318,7 @@ def test_websocket_custom_response_headers_not_in_span_attributes(self):
self.app = otel_asgi.OpenTelemetryMiddleware(
websocket_app_with_custom_headers,
tracer_provider=self.tracer_provider,
**self.constructor_params,
)
self.seed_app(self.app)
self.send_input({"type": "websocket.connect"})
Expand All @@ -333,3 +335,34 @@ def test_websocket_custom_response_headers_not_in_span_attributes(self):
if span.kind == SpanKind.SERVER:
for key, _ in not_expected.items():
self.assertNotIn(key, span.attributes)



SANITIZE_FIELDS_TEST_VALUE = ".*my-secret.*"
SERVER_REQUEST_TEST_VALUE = "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*"
SERVER_RESPONSE_TEST_VALUE = "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,my-custom-regex-header-.*,invalid-regex-header-.*,.*my-secret.*"

class TestCustomHeadersEnv(TestCustomHeaders):
def setUp(self):
os.environ.update(
{
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: SANITIZE_FIELDS_TEST_VALUE,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: SERVER_REQUEST_TEST_VALUE,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE: SERVER_RESPONSE_TEST_VALUE,
}
)
super().setUp()

def tearDown(self):
os.environ.pop(OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS, None)
os.environ.pop(OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST, None)
os.environ.pop(OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE, None)
super().tearDown()


class TestCustomHeadersConstructor(TestCustomHeaders):
constructor_params = {
"http_capture_headers_sanitize_fields": SANITIZE_FIELDS_TEST_VALUE.split(","),
"http_capture_headers_server_request": SERVER_REQUEST_TEST_VALUE.split(","),
"http_capture_headers_server_response": SERVER_RESPONSE_TEST_VALUE.split(","),
}
Original file line number Diff line number Diff line change
Expand Up @@ -801,18 +801,15 @@ class TestAsgiApplicationRaisingError(AsgiTestBase):
def tearDown(self):
pass

@mock.patch(
"opentelemetry.instrumentation.asgi.collect_custom_request_headers_attributes",
side_effect=ValueError("whatever"),
)
def test_asgi_issue_1883(
self, mock_collect_custom_request_headers_attributes
):
def test_asgi_issue_1883(self):
"""
Test that exception UnboundLocalError local variable 'start' referenced before assignment is not raised
See https://github.com/open-telemetry/opentelemetry-python-contrib/issues/1883
"""
app = otel_asgi.OpenTelemetryMiddleware(simple_asgi)
async def bad_app(_scope, _receive, _send):
raise ValueError("whatever")

app = otel_asgi.OpenTelemetryMiddleware(bad_app)
self.seed_app(app)
self.send_default_request()
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,16 @@
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span, SpanKind, use_span
from opentelemetry.util.http import (
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE,
SanitizeValue,
_parse_active_request_count_attrs,
_parse_duration_attrs,
get_custom_headers,
get_excluded_urls,
get_traced_request_attrs,
normalise_request_header_name,
)

try:
Expand Down Expand Up @@ -91,10 +97,7 @@ def __call__(self, request):
try:
from opentelemetry.instrumentation.asgi import asgi_getter, asgi_setter
from opentelemetry.instrumentation.asgi import (
collect_custom_request_headers_attributes as asgi_collect_custom_request_attributes,
)
from opentelemetry.instrumentation.asgi import (
collect_custom_response_headers_attributes as asgi_collect_custom_response_attributes,
collect_custom_headers_attributes as asgi_collect_custom_headers_attributes,
)
from opentelemetry.instrumentation.asgi import (
collect_request_attributes as asgi_collect_request_attributes,
Expand Down Expand Up @@ -249,7 +252,18 @@ def process_request(self, request):
)
if span.is_recording() and span.kind == SpanKind.SERVER:
attributes.update(
asgi_collect_custom_request_attributes(carrier)
asgi_collect_custom_headers_attributes(
carrier,
SanitizeValue(
get_custom_headers(
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS
)
),
get_custom_headers(
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST
),
normalise_request_header_name,
)
)
else:
if span.is_recording() and span.kind == SpanKind.SERVER:
Expand Down Expand Up @@ -337,7 +351,18 @@ def process_response(self, request, response):
asgi_setter.set(custom_headers, key, value)

custom_res_attributes = (
asgi_collect_custom_response_attributes(custom_headers)
asgi_collect_custom_headers_attributes(
custom_headers,
SanitizeValue(
get_custom_headers(
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS
)
),
get_custom_headers(
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE
),
normalise_request_header_name,
)
)
for key, value in custom_res_attributes.items():
span.set_attribute(key, value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,8 @@ async def test_http_custom_response_headers_in_span_attributes(self):
),
"http.response.header.my_secret_header": ("[REDACTED]",),
}
await self.async_client.get("/traced_custom_header/")
resp = await self.async_client.get("/traced_custom_header/")
assert resp.status_code == 200
spans = self.exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

Expand Down
Loading

0 comments on commit 87a619e

Please sign in to comment.