Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for SQLAlchemy 1.4 #568

Merged
merged 21 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#563](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/563))
- `opentelemetry-exporter-datadog` Datadog exporter should not use `unknown_service` as fallback resource service name.
([#570](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/570))
- Add support for the async extension of SQLAlchemy (>= 1.4)
([#568](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/568))

### Added
- `opentelemetry-instrumentation-httpx` Add `httpx` instrumentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,32 @@
engine=engine,
)

# of the async variant of SQLAlchemy

from sqlalchemy.ext.asyncio import create_async_engine

from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
import sqlalchemy

engine = create_async_engine("sqlite:///:memory:")
SQLAlchemyInstrumentor().instrument(
engine=engine.sync_engine
)

API
---
"""
from typing import Collection

import sqlalchemy
from packaging.version import parse as parse_version
from wrapt import wrap_function_wrapper as _w

from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.sqlalchemy.engine import (
EngineTracer,
_get_tracer,
_wrap_create_async_engine,
_wrap_create_engine,
)
from opentelemetry.instrumentation.sqlalchemy.package import _instruments
Expand Down Expand Up @@ -76,6 +90,13 @@ def _instrument(self, **kwargs):
"""
_w("sqlalchemy", "create_engine", _wrap_create_engine)
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine)
if parse_version(sqlalchemy.__version__).release >= (1, 4):
_w(
"sqlalchemy.ext.asyncio",
"create_async_engine",
_wrap_create_async_engine,
)

if kwargs.get("engine") is not None:
return EngineTracer(
_get_tracer(
Expand All @@ -88,3 +109,5 @@ def _instrument(self, **kwargs):
def _uninstrument(self, **kwargs):
unwrap(sqlalchemy, "create_engine")
unwrap(sqlalchemy.engine, "create_engine")
if parse_version(sqlalchemy.__version__).release >= (1, 4):
unwrap(sqlalchemy.ext.asyncio, "create_async_engine")
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import local

from sqlalchemy.event import listen # pylint: disable=no-name-in-module

from opentelemetry import trace
Expand Down Expand Up @@ -44,6 +42,16 @@ def _get_tracer(engine, tracer_provider=None):
)


# pylint: disable=unused-argument
def _wrap_create_async_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), engine.sync_engine)
return engine


# pylint: disable=unused-argument
def _wrap_create_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
Expand All @@ -59,20 +67,10 @@ def __init__(self, tracer, engine):
self.tracer = tracer
self.engine = engine
self.vendor = _normalize_vendor(engine.name)
self.cursor_mapping = {}
self.local = local()

listen(engine, "before_cursor_execute", self._before_cur_exec)
listen(engine, "after_cursor_execute", self._after_cur_exec)
listen(engine, "handle_error", self._handle_error)

@property
def current_thread_span(self):
return getattr(self.local, "current_span", None)

@current_thread_span.setter
def current_thread_span(self, span):
setattr(self.local, "current_span", span)
listen(engine, "after_cursor_execute", _after_cur_exec)
listen(engine, "handle_error", _handle_error)

def _operation_name(self, db_name, statement):
parts = []
Expand All @@ -90,7 +88,9 @@ def _operation_name(self, db_name, statement):
return " ".join(parts)

# pylint: disable=unused-argument
def _before_cur_exec(self, conn, cursor, statement, *args):
def _before_cur_exec(
self, conn, cursor, statement, params, context, executemany
):
attrs, found = _get_attributes_from_url(conn.engine.url)
if not found:
attrs = _get_attributes_from_cursor(self.vendor, cursor, attrs)
Expand All @@ -100,42 +100,35 @@ def _before_cur_exec(self, conn, cursor, statement, *args):
self._operation_name(db_name, statement),
kind=trace.SpanKind.CLIENT,
)
self.current_thread_span = self.cursor_mapping[cursor] = span
with trace.use_span(span, end_on_exit=False):
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, statement)
span.set_attribute(SpanAttributes.DB_SYSTEM, self.vendor)
for key, value in attrs.items():
span.set_attribute(key, value)

# pylint: disable=unused-argument
def _after_cur_exec(self, conn, cursor, statement, *args):
span = self.cursor_mapping.get(cursor, None)
if span is None:
return
context._otel_span = span

span.end()
self._cleanup(cursor)

def _handle_error(self, context):
span = self.current_thread_span
if span is None:
return
# pylint: disable=unused-argument
def _after_cur_exec(conn, cursor, statement, params, context, executemany):
span = getattr(context, "_otel_span", None)
if span is None:
return

try:
if span.is_recording():
span.set_status(
Status(StatusCode.ERROR, str(context.original_exception),)
)
finally:
span.end()
self._cleanup(context.cursor)

def _cleanup(self, cursor):
try:
del self.cursor_mapping[cursor]
except KeyError:
pass
span.end()


def _handle_error(context):
span = getattr(context.execution_context, "_otel_span", None)
if span is None:
return

if span.is_recording():
span.set_status(
Status(StatusCode.ERROR, str(context.original_exception),)
)
span.end()


def _get_attributes_from_url(url):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from unittest import mock

import pytest
import sqlalchemy
from sqlalchemy import create_engine

from opentelemetry import trace
Expand All @@ -38,6 +41,29 @@ def test_trace_integration(self):
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)

@pytest.mark.skipif(
not sqlalchemy.__version__.startswith("1.4"),
reason="only run async tests for 1.4",
)
def test_async_trace_integration(self):
async def run():
from sqlalchemy.ext.asyncio import ( # pylint: disable-all
create_async_engine,
)

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
SQLAlchemyInstrumentor().instrument(
engine=engine.sync_engine, tracer_provider=self.tracer_provider
)
async with engine.connect() as cnx:
await cnx.execute(sqlalchemy.text("SELECT 1 + 1;"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)

asyncio.get_event_loop().run_until_complete(run())

def test_not_recording(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
Expand Down Expand Up @@ -68,3 +94,24 @@ def test_create_engine_wrapper(self):
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)

@pytest.mark.skipif(
not sqlalchemy.__version__.startswith("1.4"),
reason="only run async tests for 1.4",
)
def test_create_async_engine_wrapper(self):
async def run():
SQLAlchemyInstrumentor().instrument()
from sqlalchemy.ext.asyncio import ( # pylint: disable-all
create_async_engine,
)

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async with engine.connect() as cnx:
await cnx.execute(sqlalchemy.text("SELECT 1 + 1;"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)

asyncio.get_event_loop().run_until_complete(run())
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import contextlib
import logging
import threading
import unittest

from sqlalchemy import Column, Integer, String, create_engine, insert
from sqlalchemy.ext.declarative import declarative_base
Expand Down Expand Up @@ -242,4 +243,10 @@ def insert_players(session):
close_all_sessions()

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 5)

# SQLAlchemy 1.4 uses the `execute_values` extension of the psycopg2 dialect to
# batch inserts together which means `insert_players` only generates one span.
# See https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases
self.assertEqual(
len(spans), 5 if self.VENDOR not in ["postgresql"] else 3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did this change in this PR? Can we document the difference between vendors as a comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
15 changes: 9 additions & 6 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ envlist =
py3{6,7,8,9}-test-instrumentation-grpc

; opentelemetry-instrumentation-sqlalchemy
py3{6,7,8,9}-test-instrumentation-sqlalchemy
pypy3-test-instrumentation-sqlalchemy
py3{6,7,8,9}-test-instrumentation-sqlalchemy{11,14}
pypy3-test-instrumentation-sqlalchemy{11,14}

; opentelemetry-instrumentation-redis
py3{6,7,8,9}-test-instrumentation-redis
Expand Down Expand Up @@ -173,6 +173,9 @@ deps =
elasticsearch6: elasticsearch>=6.0,<7.0
elasticsearch7: elasticsearch-dsl>=7.0,<8.0
elasticsearch7: elasticsearch>=7.0,<8.0
sqlalchemy11: sqlalchemy>=1.1,<1.2
sqlalchemy14: aiosqlite
sqlalchemy14: sqlalchemy~=1.4

; FIXME: add coverage testing
; FIXME: add mypy testing
Expand Down Expand Up @@ -205,7 +208,7 @@ changedir =
test-instrumentation-redis: instrumentation/opentelemetry-instrumentation-redis/tests
test-instrumentation-requests: instrumentation/opentelemetry-instrumentation-requests/tests
test-instrumentation-sklearn: instrumentation/opentelemetry-instrumentation-sklearn/tests
test-instrumentation-sqlalchemy: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests
test-instrumentation-sqlalchemy{11,14}: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests
test-instrumentation-sqlite3: instrumentation/opentelemetry-instrumentation-sqlite3/tests
test-instrumentation-starlette: instrumentation/opentelemetry-instrumentation-starlette/tests
test-instrumentation-tornado: instrumentation/opentelemetry-instrumentation-tornado/tests
Expand Down Expand Up @@ -290,7 +293,7 @@ commands_pre =

sklearn: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sklearn[test]

sqlalchemy: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sqlalchemy[test]
sqlalchemy{11,14}: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sqlalchemy[test]

elasticsearch{2,5,6,7}: pip install {toxinidir}/opentelemetry-python-core/opentelemetry-instrumentation {toxinidir}/instrumentation/opentelemetry-instrumentation-elasticsearch[test]

Expand Down Expand Up @@ -329,7 +332,7 @@ commands =

[testenv:lint]
basepython: python3.9
recreate = False
recreate = False
deps =
-c dev-requirements.txt
flaky
Expand Down Expand Up @@ -399,7 +402,7 @@ deps =
PyMySQL ~= 0.10.1
psycopg2 ~= 2.8.4
aiopg >= 0.13.0, < 1.3.0
sqlalchemy ~= 1.3.16
sqlalchemy ~= 1.4
redis ~= 3.3.11
celery[pytest] >= 4.0, < 6.0
protobuf>=3.13.0
Expand Down