diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py index 3dd538ce83..50a420d062 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py @@ -129,11 +129,12 @@ async def run(): def test_not_recording(self): mock_tracer = mock.Mock() mock_span = mock.Mock() + mock_context = mock.Mock() mock_span.is_recording.return_value = False - mock_span.__enter__ = mock.Mock(return_value=(mock.Mock(), None)) - mock_span.__exit__ = mock.Mock(return_value=None) - mock_tracer.start_span.return_value = mock_span - mock_tracer.start_as_current_span.return_value = mock_span + mock_context.__enter__ = mock.Mock(return_value=mock_span) + mock_context.__exit__ = mock.Mock(return_value=None) + mock_tracer.start_span.return_value = mock_context + mock_tracer.start_as_current_span.return_value = mock_context with mock.patch("opentelemetry.trace.get_tracer") as tracer: tracer.return_value = mock_tracer engine = create_engine("sqlite:///:memory:")