Skip to content

Commit

Permalink
add _span to sqlalchemy execution context instead of maintaining mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
mnbbrown committed Jul 6, 2021
1 parent 67b6a42 commit dbcb2ab
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import local

from sqlalchemy.event import listen # pylint: disable=no-name-in-module

from opentelemetry import trace
Expand Down Expand Up @@ -59,21 +57,11 @@ def __init__(self, tracer, engine):
self.tracer = tracer
self.engine = engine
self.vendor = _normalize_vendor(engine.name)
self.cursor_mapping = {}
self.local = local()

listen(engine, "before_cursor_execute", self._before_cur_exec)
listen(engine, "after_cursor_execute", self._after_cur_exec)
listen(engine, "handle_error", self._handle_error)

@property
def current_thread_span(self):
return getattr(self.local, "current_span", None)

@current_thread_span.setter
def current_thread_span(self, span):
setattr(self.local, "current_span", span)

def _operation_name(self, db_name, statement):
parts = []
if isinstance(statement, str):
Expand All @@ -90,7 +78,7 @@ def _operation_name(self, db_name, statement):
return " ".join(parts)

# pylint: disable=unused-argument
def _before_cur_exec(self, conn, cursor, statement, *args):
def _before_cur_exec(self, conn, cursor, statement, params, context, executemany):
attrs, found = _get_attributes_from_url(conn.engine.url)
if not found:
attrs = _get_attributes_from_cursor(self.vendor, cursor, attrs)
Expand All @@ -100,27 +88,25 @@ def _before_cur_exec(self, conn, cursor, statement, *args):
self._operation_name(db_name, statement),
kind=trace.SpanKind.CLIENT,
)
self.current_thread_span = span
self.cursor_mapping[cursor] = span
with trace.use_span(span, end_on_exit=False):
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, statement)
span.set_attribute(SpanAttributes.DB_SYSTEM, self.vendor)
for key, value in attrs.items():
span.set_attribute(key, value)

context._span = span

# pylint: disable=unused-argument
def _after_cur_exec(self, conn, cursor, statement, *args):
span = self.cursor_mapping.get(cursor, None)
def _after_cur_exec(self, conn, cursor, statement, params, context, executemany):
span = getattr(context, '_span', None)
if span is None:
return

span.end()
self._cleanup(cursor)

def _handle_error(self, context):
# span = self.cursor_mapping[context.cursor]
span = self.current_thread_span
span = getattr(context.execution_context, '_span', None)
if span is None:
return

Expand All @@ -131,13 +117,7 @@ def _handle_error(self, context):
)
finally:
span.end()
self._cleanup(context.cursor)

def _cleanup(self, cursor):
try:
del self.cursor_mapping[cursor]
except KeyError:
pass


def _get_attributes_from_url(url):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,4 @@ def insert_players(session):
close_all_sessions()

spans = self.memory_exporter.get_finished_spans()
breakpoint()
self.assertEqual(len(spans), 5)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SQLiteTestCase(SQLAlchemyTestMixin):
def test_engine_execute_errors(self):
# ensures that SQL errors are reported
stmt = "SELECT * FROM a_wrong_table"
with pytest.raises(Exception):
with pytest.raises(OperationalError):
with self.connection() as conn:
conn.execute(stmt).fetchall()

Expand Down

0 comments on commit dbcb2ab

Please sign in to comment.