From d958512b14676397052cb52579adccce4f3e0659 Mon Sep 17 00:00:00 2001 From: Christoph Brand Date: Thu, 4 Feb 2021 23:20:43 +0100 Subject: [PATCH] sqlalchemy: make sqlalchemy thread safe --- CHANGELOG.md | 1 + .../instrumentation/sqlalchemy/engine.py | 42 +++++++++++++------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38e65c98d6..fe872fdd2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove `component` span attribute in instrumentations. `opentelemetry-instrumentation-aiopg`, `opentelemetry-instrumentation-dbapi` Remove unused `database_type` parameter from `trace_integration` function. ([#301](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/301)) +- `opentelemetry-instrumentation-sqlalchemy` Fix multithreading issues in recording spans from SQLAlchemy ([#315](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/315)) ## [0.17b0](https://github.com/open-telemetry/opentelemetry-python-contrib/releases/tag/v0.17b0) - 2021-01-20 diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py index e32d41718f..faa53cee21 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from threading import local +from weakref import WeakKeyDictionary + from sqlalchemy.event import listen # pylint: disable=no-name-in-module from opentelemetry import trace @@ -66,12 +69,21 @@ def __init__(self, tracer, engine): self.tracer = tracer self.engine = engine self.vendor = _normalize_vendor(engine.name) - self.current_span = None + self.cursor_mapping = WeakKeyDictionary() + self.local = local() listen(engine, "before_cursor_execute", self._before_cur_exec) listen(engine, "after_cursor_execute", self._after_cur_exec) listen(engine, "handle_error", self._handle_error) + @property + def current_thread_span(self): + return getattr(self.local, "current_span", None) + + @current_thread_span.setter + def current_thread_span(self, span): + setattr(self.local, "current_span", span) + def _operation_name(self, db_name, statement): parts = [] if isinstance(statement, str): @@ -94,34 +106,38 @@ def _before_cur_exec(self, conn, cursor, statement, *args): attrs = _get_attributes_from_cursor(self.vendor, cursor, attrs) db_name = attrs.get(_DB, "") - self.current_span = self.tracer.start_span( + span = self.tracer.start_span( self._operation_name(db_name, statement), kind=trace.SpanKind.CLIENT, ) - with self.tracer.use_span(self.current_span, end_on_exit=False): - if self.current_span.is_recording(): - self.current_span.set_attribute(_STMT, statement) - self.current_span.set_attribute("db.system", self.vendor) + self.current_thread_span = self.cursor_mapping[cursor] = span + with self.tracer.use_span(span, end_on_exit=False): + if span.is_recording(): + span.set_attribute(_STMT, statement) + span.set_attribute("db.system", self.vendor) for key, value in attrs.items(): - self.current_span.set_attribute(key, value) + span.set_attribute(key, value) # pylint: disable=unused-argument def _after_cur_exec(self, conn, cursor, statement, *args): - if self.current_span is None: + span = self.cursor_mapping.get(cursor, None) + if span is None: return - self.current_span.end() + + span.end() def _handle_error(self, context): - if self.current_span is None: + span = self.current_thread_span + if span is None: return try: - if self.current_span.is_recording(): - self.current_span.set_status( + if span.is_recording(): + span.set_status( Status(StatusCode.ERROR, str(context.original_exception),) ) finally: - self.current_span.end() + span.end() def _get_attributes_from_url(url):