From dbcb2ab589368c2031c7cae2a9cc8cb4203cdc22 Mon Sep 17 00:00:00 2001 From: Matthew Brown Date: Tue, 6 Jul 2021 16:13:36 +0100 Subject: [PATCH] add _span to sqlalchemy execution context instead of maintaining mapping --- .../instrumentation/sqlalchemy/engine.py | 32 ++++--------------- .../tests/sqlalchemy_tests/mixins.py | 1 - .../tests/sqlalchemy_tests/test_sqlite.py | 2 +- 3 files changed, 7 insertions(+), 28 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py index ab8f84bd07..b1297e9bcb 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from threading import local - from sqlalchemy.event import listen # pylint: disable=no-name-in-module from opentelemetry import trace @@ -59,21 +57,11 @@ def __init__(self, tracer, engine): self.tracer = tracer self.engine = engine self.vendor = _normalize_vendor(engine.name) - self.cursor_mapping = {} - self.local = local() listen(engine, "before_cursor_execute", self._before_cur_exec) listen(engine, "after_cursor_execute", self._after_cur_exec) listen(engine, "handle_error", self._handle_error) - @property - def current_thread_span(self): - return getattr(self.local, "current_span", None) - - @current_thread_span.setter - def current_thread_span(self, span): - setattr(self.local, "current_span", span) - def _operation_name(self, db_name, statement): parts = [] if isinstance(statement, str): @@ -90,7 +78,7 @@ def _operation_name(self, db_name, statement): return " ".join(parts) # pylint: disable=unused-argument - def _before_cur_exec(self, conn, cursor, statement, *args): + def _before_cur_exec(self, conn, cursor, statement, params, context, executemany): attrs, found = _get_attributes_from_url(conn.engine.url) if not found: attrs = _get_attributes_from_cursor(self.vendor, cursor, attrs) @@ -100,8 +88,6 @@ def _before_cur_exec(self, conn, cursor, statement, *args): self._operation_name(db_name, statement), kind=trace.SpanKind.CLIENT, ) - self.current_thread_span = span - self.cursor_mapping[cursor] = span with trace.use_span(span, end_on_exit=False): if span.is_recording(): span.set_attribute(SpanAttributes.DB_STATEMENT, statement) @@ -109,18 +95,18 @@ def _before_cur_exec(self, conn, cursor, statement, *args): for key, value in attrs.items(): span.set_attribute(key, value) + context._span = span + # pylint: disable=unused-argument - def _after_cur_exec(self, conn, cursor, statement, *args): - span = self.cursor_mapping.get(cursor, None) + def _after_cur_exec(self, conn, cursor, statement, params, context, executemany): + span = getattr(context, '_span', None) if span is None: return span.end() - self._cleanup(cursor) def _handle_error(self, context): - # span = self.cursor_mapping[context.cursor] - span = self.current_thread_span + span = getattr(context.execution_context, '_span', None) if span is None: return @@ -131,13 +117,7 @@ def _handle_error(self, context): ) finally: span.end() - self._cleanup(context.cursor) - def _cleanup(self, cursor): - try: - del self.cursor_mapping[cursor] - except KeyError: - pass def _get_attributes_from_url(url): diff --git a/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py b/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py index 04db609886..a3bd18ddf7 100644 --- a/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py +++ b/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py @@ -243,5 +243,4 @@ def insert_players(session): close_all_sessions() spans = self.memory_exporter.get_finished_spans() - breakpoint() self.assertEqual(len(spans), 5) diff --git a/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/test_sqlite.py b/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/test_sqlite.py index 981e82d7c2..0acba0fec2 100644 --- a/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/test_sqlite.py +++ b/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/test_sqlite.py @@ -35,7 +35,7 @@ class SQLiteTestCase(SQLAlchemyTestMixin): def test_engine_execute_errors(self): # ensures that SQL errors are reported stmt = "SELECT * FROM a_wrong_table" - with pytest.raises(Exception): + with pytest.raises(OperationalError): with self.connection() as conn: conn.execute(stmt).fetchall()