Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for db cursors and connections in context managers #1028

Merged
merged 11 commits into from
Sep 2, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,13 @@ def cursor(self, *args, **kwargs):
self.__wrapped__.cursor(*args, **kwargs), db_api_integration
)

def __enter__(self):
self.__wrapped__.__enter__()
return self

def __exit__(self, *args, **kwargs):
self.__wrapped__.__exit__(*args, **kwargs)

return TracedConnectionProxy(connection, *args, **kwargs)


Expand Down Expand Up @@ -366,4 +373,11 @@ def callproc(self, *args, **kwargs):
self.__wrapped__.callproc, *args, **kwargs
)

def __enter__(self):
self.__wrapped__.__enter__()
return self

def __exit__(self, *args, **kwargs):
self.__wrapped__.__exit__(*args, **kwargs)

return TracedCursorProxy(cursor, *args, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
POSTGRES_USER = os.getenv("POSTGRESQL_HOST ", "testuser")


def _await(coro):
def async_call(coro):
aabmass marked this conversation as resolved.
Show resolved Hide resolved
loop = asyncio.get_event_loop()
return loop.run_until_complete(coro)

Expand All @@ -27,7 +27,7 @@ def setUpClass(cls):
cls._cursor = None
cls._tracer = cls.tracer_provider.get_tracer(__name__)
AsyncPGInstrumentor().instrument(tracer_provider=cls.tracer_provider)
cls._connection = _await(
cls._connection = async_call(
asyncpg.connect(
database=POSTGRES_DB_NAME,
user=POSTGRES_USER,
Expand All @@ -42,7 +42,7 @@ def tearDownClass(cls):
AsyncPGInstrumentor().uninstrument()

def test_instrumented_execute_method_without_arguments(self, *_, **__):
_await(self._connection.execute("SELECT 42;"))
async_call(self._connection.execute("SELECT 42;"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(
Expand All @@ -59,7 +59,7 @@ def test_instrumented_execute_method_without_arguments(self, *_, **__):
)

def test_instrumented_fetch_method_without_arguments(self, *_, **__):
_await(self._connection.fetch("SELECT 42;"))
async_call(self._connection.fetch("SELECT 42;"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(
Expand All @@ -77,7 +77,7 @@ async def _transaction_execute():
async with self._connection.transaction():
await self._connection.execute("SELECT 42;")

_await(_transaction_execute())
async_call(_transaction_execute())

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(3, len(spans))
Expand Down Expand Up @@ -124,7 +124,7 @@ async def _transaction_execute():
await self._connection.execute("SELECT 42::uuid;")

with self.assertRaises(asyncpg.CannotCoerceError):
_await(_transaction_execute())
async_call(_transaction_execute())

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(3, len(spans))
Expand Down Expand Up @@ -167,7 +167,7 @@ async def _transaction_execute():
)

def test_instrumented_method_doesnt_capture_parameters(self, *_, **__):
_await(self._connection.execute("SELECT $1;", "1"))
async_call(self._connection.execute("SELECT $1;", "1"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(
Expand Down Expand Up @@ -198,7 +198,7 @@ def setUpClass(cls):
AsyncPGInstrumentor(capture_parameters=True).instrument(
tracer_provider=cls.tracer_provider
)
cls._connection = _await(
cls._connection = async_call(
asyncpg.connect(
database=POSTGRES_DB_NAME,
user=POSTGRES_USER,
Expand All @@ -213,7 +213,7 @@ def tearDownClass(cls):
AsyncPGInstrumentor().uninstrument()

def test_instrumented_execute_method_with_arguments(self, *_, **__):
_await(self._connection.execute("SELECT $1;", "1"))
async_call(self._connection.execute("SELECT $1;", "1"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(
Expand All @@ -231,7 +231,7 @@ def test_instrumented_execute_method_with_arguments(self, *_, **__):
)

def test_instrumented_fetch_method_with_arguments(self, *_, **__):
_await(self._connection.fetch("SELECT $1;", "1"))
async_call(self._connection.fetch("SELECT $1;", "1"))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(
Expand All @@ -246,7 +246,7 @@ def test_instrumented_fetch_method_with_arguments(self, *_, **__):
)

def test_instrumented_executemany_method_with_arguments(self, *_, **__):
_await(self._connection.executemany("SELECT $1;", [["1"], ["2"]]))
async_call(self._connection.executemany("SELECT $1;", [["1"], ["2"]]))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(
Expand All @@ -262,7 +262,7 @@ def test_instrumented_executemany_method_with_arguments(self, *_, **__):

def test_instrumented_execute_interface_error_method(self, *_, **__):
with self.assertRaises(asyncpg.InterfaceError):
_await(self._connection.execute("SELECT 42;", 1, 2, 3))
async_call(self._connection.execute("SELECT 42;", 1, 2, 3))
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@
# limitations under the License.


import time

import celery
import pytest
from celery import signals
from celery.exceptions import Retry

import opentelemetry.instrumentation.celery
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
import time

import mysql.connector

Expand All @@ -36,21 +35,24 @@ def setUpClass(cls):
cls._cursor = None
cls._tracer = cls.tracer_provider.get_tracer(__name__)
MySQLInstrumentor().instrument()
cls._connection = mysql.connector.connect(
user=MYSQL_USER,
password=MYSQL_PASSWORD,
host=MYSQL_HOST,
port=MYSQL_PORT,
database=MYSQL_DB_NAME,
)
cls._cursor = cls._connection.cursor()

@classmethod
def tearDownClass(cls):
if cls._connection:
cls._connection.close()
MySQLInstrumentor().uninstrument()

def setUp(self):
super().setUp()
self._connection = mysql.connector.connect(
user=MYSQL_USER,
password=MYSQL_PASSWORD,
host=MYSQL_HOST,
port=MYSQL_PORT,
database=MYSQL_DB_NAME,
)
self._cursor = self._connection.cursor()

def validate_spans(self):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 2)
Expand Down Expand Up @@ -79,6 +81,23 @@ def test_execute(self):
self._cursor.execute("CREATE TABLE IF NOT EXISTS test (id INT)")
self.validate_spans()

def test_execute_with_connection_context_manager(self):
"""Should create a child span for execute with connection context
"""
with self._tracer.start_as_current_span("rootSpan"):
with self._connection as conn:
cursor = conn.cursor()
cursor.execute("CREATE TABLE IF NOT EXISTS test (id INT)")
self.validate_spans()

def test_execute_with_cursor_context_manager(self):
"""Should create a child span for execute with cursor context
"""
with self._tracer.start_as_current_span("rootSpan"):
with self._connection.cursor() as cursor:
cursor.execute("CREATE TABLE IF NOT EXISTS test (id INT)")
self.validate_spans()

def test_executemany(self):
"""Should create a child span for executemany
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import asyncio
import os
import time

import aiopg
import psycopg2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
import time

import psycopg2

Expand Down Expand Up @@ -86,6 +85,24 @@ def test_execute(self):
)
self.validate_spans()

def test_execute_with_connection_context_manager(self):
"""Should create a child span for execute with connection context
"""
with self._tracer.start_as_current_span("rootSpan"):
with self._connection as conn:
cursor = conn.cursor()
cursor.execute("CREATE TABLE IF NOT EXISTS test (id INT)")
self.validate_spans()

def test_execute_with_cursor_context_manager(self):
"""Should create a child span for execute with cursor context
"""
with self._tracer.start_as_current_span("rootSpan"):
with self._connection.cursor() as cursor:
cursor.execute("CREATE TABLE IF NOT EXISTS test (id INT)")
self.validate_spans()
self.assertTrue(cursor.closed)

def test_executemany(self):
"""Should create a child span for executemany
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ def test_execute(self):
self._cursor.execute("CREATE TABLE IF NOT EXISTS test (id INT)")
self.validate_spans()

def test_execute_with_cursor_context_manager(self):
"""Should create a child span for execute with cursor context
"""
with self._tracer.start_as_current_span("rootSpan"):
with self._connection.cursor() as cursor:
cursor.execute("CREATE TABLE IF NOT EXISTS test (id INT)")
self.validate_spans()

def test_executemany(self):
"""Should create a child span for executemany
"""
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -430,4 +430,4 @@ commands =
pytest {posargs}

commands_post =
docker-compose down
docker-compose down -v