diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index 1d61e8cfd3..14b2ce7d2f 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -104,7 +104,10 @@ def response_hook(span, instance, response): _format_command_args, ) from opentelemetry.instrumentation.redis.version import __version__ -from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.instrumentation.utils import ( + is_instrumentation_enabled, + unwrap, +) from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import Span @@ -179,9 +182,12 @@ def _instrument( response_hook: _ResponseHookT = None, ): def _traced_execute_command(func, instance, args, kwargs): + + if not is_instrumentation_enabled(): + return func(*args, **kwargs) + query = _format_command_args(args) name = _build_span_name(instance, args) - with tracer.start_as_current_span( name, kind=trace.SpanKind.CLIENT ) as span: @@ -197,6 +203,10 @@ def _traced_execute_command(func, instance, args, kwargs): return response def _traced_execute_pipeline(func, instance, args, kwargs): + + if not is_instrumentation_enabled(): + return func(*args, **kwargs) + ( command_stack, resource, @@ -248,6 +258,10 @@ def _traced_execute_pipeline(func, instance, args, kwargs): ) async def _async_traced_execute_command(func, instance, args, kwargs): + + if not is_instrumentation_enabled(): + return await func(*args, **kwargs) + query = _format_command_args(args) name = _build_span_name(instance, args) @@ -266,6 +280,10 @@ async def _async_traced_execute_command(func, instance, args, kwargs): return response async def _async_traced_execute_pipeline(func, instance, args, kwargs): + + if not is_instrumentation_enabled(): + return await func(*args, **kwargs) + ( command_stack, resource, diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 4a2fce5026..e2578277cf 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -20,6 +20,7 @@ from opentelemetry import trace from opentelemetry.instrumentation.redis import RedisInstrumentor +from opentelemetry.instrumentation.utils import suppress_instrumentation from opentelemetry.semconv.trace import ( DbSystemValues, NetTransportValues, @@ -66,6 +67,40 @@ def test_not_recording(self): self.assertFalse(mock_span.set_attribute.called) self.assertFalse(mock_span.set_status.called) + def test_suppress_instrumentation_no_span(self): + redis_client = redis.Redis() + + with mock.patch.object(redis_client, "connection"): + redis_client.get("key") + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 1) + self.memory_exporter.clear() + + with suppress_instrumentation(): + with mock.patch.object(redis_client, "connection"): + redis_client.get("key") + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 0) + + def test_suppress_async_instrumentation_no_span(self): + redis_client = redis.asyncio.Redis() + + with mock.patch.object(redis_client, "connection", AsyncMock()): + asyncio.run(redis_client.get("key")) + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 1) + self.memory_exporter.clear() + + with suppress_instrumentation(): + with mock.patch.object(redis_client, "connection", AsyncMock()): + asyncio.run(redis_client.get("key")) + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 0) + def test_instrument_uninstrument(self): redis_client = redis.Redis()