Skip to content

Commit

Permalink
- make context propagation robust to unavailability of root tracer
Browse files Browse the repository at this point in the history
  • Loading branch information
valentindreismann authored and vdreismann committed Feb 20, 2023
1 parent b85e476 commit f768bab
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 2 deletions.
18 changes: 16 additions & 2 deletions contrib/opencensus-ext-threading/opencensus/ext/threading/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import threading
from concurrent import futures
Expand Down Expand Up @@ -88,8 +89,14 @@ def wrap_apply_async(apply_async_func):
that will be called and wrap it then add the opencensus context."""

def call(self, func, args=(), kwds={}, **kwargs):
wrapped_func = wrap_task_func(func)
_tracer = execution_context.get_opencensus_tracer()

from opencensus.trace.tracers.noop_tracer import NoopTracer

if isinstance(_tracer, NoopTracer):
return apply_async_func(self, func, args=args, kwds={}, **kwargs)

wrapped_func = wrap_task_func(func)
propagator = binary_format.BinaryFormatPropagator()

wrapped_kwargs = {}
Expand All @@ -113,14 +120,21 @@ def wrap_submit(submit_func):
that will be called and wrap it then add the opencensus context."""

def call(self, func, *args, **kwargs):
wrapped_func = wrap_task_func(func)
_tracer = execution_context.get_opencensus_tracer()

from opencensus.trace.tracers.noop_tracer import NoopTracer

if isinstance(_tracer, NoopTracer):
return submit_func(self, func, *args, **kwargs)

wrapped_func = wrap_task_func(func)
propagator = binary_format.BinaryFormatPropagator()

wrapped_kwargs = {}
wrapped_kwargs["span_context_binary"] = propagator.to_header(
_tracer.span_context
)

wrapped_kwargs["kwds"] = kwargs
wrapped_kwargs["sampler"] = _tracer.sampler
wrapped_kwargs["exporter"] = _tracer.exporter
Expand Down
50 changes: 50 additions & 0 deletions contrib/opencensus-ext-threading/tests/test_noop_tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
from unittest.mock import patch, MagicMock

from opencensus.trace.tracers.noop_tracer import NoopTracer
from opencensus.ext.threading.trace import wrap_submit, wrap_apply_async


class TestNoopTracer(unittest.TestCase):
"""
In case no OpenCensus context is present (i.e. we have a NoopTracer), do _not_ pass down tracer in apply_async
and submit; instead invoke function directly.
"""

@patch("opencensus.ext.threading.trace.wrap_task_func")
@patch("opencensus.trace.execution_context.get_opencensus_tracer")
def test_noop_tracer_apply_async(
self, get_opencensus_tracer_mock: MagicMock, wrap_task_func_mock: MagicMock
):
mock_tracer = NoopTracer()
get_opencensus_tracer_mock.return_value = mock_tracer
submission_function_mock = MagicMock()
original_function_mock = MagicMock()

wrap_apply_async(submission_function_mock)(None, original_function_mock)

# check whether invocation of original function _has_ happened
submission_function_mock.assert_called_once_with(
None, original_function_mock, args=(), kwds={}
)

# ensure that the function has _not_ been wrapped
wrap_task_func_mock.assert_not_called()

@patch("opencensus.ext.threading.trace.wrap_task_func")
@patch("opencensus.trace.execution_context.get_opencensus_tracer")
def test_noop_tracer_wrap_submit(
self, get_opencensus_tracer_mock: MagicMock, wrap_task_func_mock: MagicMock
):
mock_tracer = NoopTracer()
get_opencensus_tracer_mock.return_value = mock_tracer
submission_function_mock = MagicMock()
original_function_mock = MagicMock()

wrap_submit(submission_function_mock)(None, original_function_mock)

# check whether invocation of original function _has_ happened
submission_function_mock.assert_called_once_with(None, original_function_mock)

# ensure that the function has _not_ been wrapped
wrap_task_func_mock.assert_not_called()
72 changes: 72 additions & 0 deletions contrib/opencensus-ext-threading/tests/test_tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import unittest
from unittest.mock import patch, MagicMock
from opencensus.ext.threading.trace import wrap_submit, wrap_apply_async


class TestTracer(unittest.TestCase):
"""
Ensures that sampler, exporter, propagator are passed through
in case global tracer is present.
"""

@patch("opencensus.trace.propagation.binary_format.BinaryFormatPropagator")
@patch("opencensus.ext.threading.trace.wrap_task_func")
@patch("opencensus.trace.execution_context.get_opencensus_tracer")
def test_apply_async_context_passed(
self,
get_opencensus_tracer_mock: MagicMock,
wrap_task_func_mock: MagicMock,
binary_format_propagator_mock: MagicMock,
):
mock_tracer = NoNoopTracerMock()
# ensure that unique object is generated
mock_tracer.sampler = MagicMock()
mock_tracer.exporter = MagicMock()
mock_tracer.propagator = MagicMock()

get_opencensus_tracer_mock.return_value = mock_tracer

submission_function_mock = MagicMock()
original_function_mock = MagicMock()

wrap_apply_async(submission_function_mock)(None, original_function_mock)

# check whether invocation of original function _has_ happened
call = submission_function_mock.call_args_list[0].kwargs

self.assertEqual(id(call["kwds"]["sampler"]), id(mock_tracer.sampler))
self.assertEqual(id(call["kwds"]["exporter"]), id(mock_tracer.exporter))
self.assertEqual(id(call["kwds"]["propagator"]), id(mock_tracer.propagator))

@patch("opencensus.trace.propagation.binary_format.BinaryFormatPropagator")
@patch("opencensus.ext.threading.trace.wrap_task_func")
@patch("opencensus.trace.execution_context.get_opencensus_tracer")
def test_wrap_submit_context_passed(
self,
get_opencensus_tracer_mock: MagicMock,
wrap_task_func_mock: MagicMock,
binary_format_propagator_mock: MagicMock,
):
mock_tracer = NoNoopTracerMock()
# ensure that unique object is generated
mock_tracer.sampler = MagicMock()
mock_tracer.exporter = MagicMock()
mock_tracer.propagator = MagicMock()

get_opencensus_tracer_mock.return_value = mock_tracer

submission_function_mock = MagicMock()
original_function_mock = MagicMock()

wrap_submit(submission_function_mock)(None, original_function_mock)

# check whether invocation of original function _has_ happened
call = submission_function_mock.call_args_list[0].kwargs

self.assertEqual(id(call["sampler"]), id(mock_tracer.sampler))
self.assertEqual(id(call["exporter"]), id(mock_tracer.exporter))
self.assertEqual(id(call["propagator"]), id(mock_tracer.propagator))


class NoNoopTracerMock(MagicMock):
pass

0 comments on commit f768bab

Please sign in to comment.