Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLCommenter semicolon bug fix #1200

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.12.0rc2-0.32b0...HEAD)
- Adding multiple db connections support for django-instrumentation's sqlcommenter
([#1187](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1187))
- SQLCommenter semicolon bug fix
([#1200](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1200/files))

### Added
- `opentelemetry-instrumentation-redis` add support to instrument RedisCluster clients
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@
from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.dbapi.version import __version__
from opentelemetry.instrumentation.utils import (
_generate_opentelemetry_traceparent,
_generate_sql_comment,
_add_sql_comment,
_get_opentelemetry_values,
unwrap,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span, SpanKind, TracerProvider, get_tracer
from opentelemetry.trace import SpanKind, TracerProvider, get_tracer

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -375,15 +375,6 @@ def get_statement(self, cursor, args): # pylint: disable=no-self-use
return statement.decode("utf8", "replace")
return statement

@staticmethod
def _generate_comment(span: Span) -> str:
span_context = span.get_span_context()
meta = {}
if span_context.is_valid:
meta.update(_generate_opentelemetry_traceparent(span))
# TODO(schekuri): revisit to enrich with info such as route, db_driver etc...
return _generate_sql_comment(**meta)

def traced_execution(
self,
cursor,
Expand All @@ -405,11 +396,14 @@ def traced_execution(
self._populate_span(span, cursor, *args)
if args and self._commenter_enabled:
try:
comment = self._generate_comment(span)
if isinstance(args[0], bytes):
comment = comment.encode("utf8")
args_list = list(args)
args_list[0] += comment
commenter_data = {}
commenter_data.update(_get_opentelemetry_values())
statement = _add_sql_comment(
args_list[0], **commenter_data
)

args_list[0] = statement
args = tuple(args_list)
except Exception as exc: # pylint: disable=broad-except
_logger.exception(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,11 @@ def test_executemany_comment(self):
mock_connect, {}, {}
)
cursor = mock_connection.cursor()
cursor.executemany("Test query")
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
comment = dbapi.CursorTracer._generate_comment(span)
self.assertIn(comment, cursor.query)
cursor.executemany("Select 1;")
self.assertRegex(
cursor.query,
r"Select 1 /\*traceparent='\d{1,2}-[a-zA-Z0-9_]{32}-[a-zA-Z0-9_]{16}-\d{1,2}'\*/;",
)

def test_callproc(self):
db_integration = dbapi.DatabaseApiIntegration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from django.db.backends.utils import CursorDebugWrapper

from opentelemetry.instrumentation.utils import (
_generate_sql_comment,
_add_sql_comment,
_get_opentelemetry_values,
)
from opentelemetry.trace.propagation.tracecontext import (
Expand Down Expand Up @@ -84,7 +84,8 @@ def __call__(self, execute: Type[T], sql, params, many, context) -> T:
db_driver = context["connection"].settings_dict.get("ENGINE", "")
resolver_match = self.request.resolver_match

sql_comment = _generate_sql_comment(
sql = _add_sql_comment(
sql,
# Information about the controller.
controller=resolver_match.view_name
if resolver_match and with_controller
Expand Down Expand Up @@ -112,7 +113,6 @@ def __call__(self, execute: Type[T], sql, params, many, context) -> T:
# See:
# * https://github.com/basecamp/marginalia/issues/61
# * https://github.com/basecamp/marginalia/pull/80
sql += sql_comment

# Add the query to the query log if debugging.
if isinstance(context["cursor"], CursorDebugWrapper):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_query_wrapper(self, trace_capture):
execute_mock_obj = MagicMock()
qw_instance(
execute_mock_obj,
"Select 1",
"Select 1;",
MagicMock("test"),
MagicMock("test1"),
MagicMock(),
Expand All @@ -97,7 +97,7 @@ def test_query_wrapper(self, trace_capture):
self.assertEqual(
output_sql,
"Select 1 /*app_name='app',controller='view',route='route',traceparent='%%2Atraceparent%%3D%%2700-0000000"
"00000000000000000deadbeef-000000000000beef-00'*/",
"00000000000000000deadbeef-000000000000beef-00'*/;",
)

@patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
)
from opentelemetry.instrumentation.sqlalchemy.version import __version__
from opentelemetry.instrumentation.utils import (
_generate_opentelemetry_traceparent,
_generate_sql_comment,
_add_sql_comment,
_get_opentelemetry_values,
)
from opentelemetry.semconv.trace import NetTransportValues, SpanAttributes
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode


Expand Down Expand Up @@ -141,21 +140,15 @@ def _before_cur_exec(
span.set_attribute(SpanAttributes.DB_SYSTEM, self.vendor)
for key, value in attrs.items():
span.set_attribute(key, value)
if self.enable_commenter:
commenter_data = {}
commenter_data.update(_get_opentelemetry_values())
statement = _add_sql_comment(statement, **commenter_data)

context._otel_span = span
if self.enable_commenter:
statement = statement + EngineTracer._generate_comment(span=span)

return statement, params

@staticmethod
def _generate_comment(span: Span) -> str:
span_context = span.get_span_context()
meta = {}
if span_context.is_valid:
meta.update(_generate_opentelemetry_traceparent(span))
return _generate_sql_comment(**meta)


# pylint: disable=unused-argument
def _after_cur_exec(conn, cursor, statement, params, context, executemany):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from unittest import mock

import pytest
Expand All @@ -21,7 +20,6 @@

from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.instrumentation.sqlalchemy.engine import EngineTracer
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider, export
from opentelemetry.test.test_base import TestBase
Expand Down Expand Up @@ -217,22 +215,3 @@ async def run():
)

asyncio.get_event_loop().run_until_complete(run())

def test_generate_commenter(self):
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
engine = create_engine("sqlite:///:memory:")
SQLAlchemyInstrumentor().instrument(
engine=engine,
tracer_provider=self.tracer_provider,
enable_commenter=True,
)

cnx = engine.connect()
cnx.execute("SELECT 1 + 1;").fetchall()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 2)
span = spans[1]
self.assertIn(
EngineTracer._generate_comment(span),
self.caplog.records[-2].getMessage(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import pytest
from sqlalchemy import create_engine
Expand All @@ -28,6 +29,17 @@ def tearDown(self):
super().tearDown()
SQLAlchemyInstrumentor().uninstrument()

def test_sqlcommenter_disabled(self):
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
engine = create_engine("sqlite:///:memory:", echo=True)
SQLAlchemyInstrumentor().instrument(
engine=engine, tracer_provider=self.tracer_provider
)
cnx = engine.connect()
cnx.execute("SELECT 1;").fetchall()

self.assertEqual(self.caplog.records[-2].getMessage(), "SELECT 1;")

def test_sqlcommenter_enabled(self):
engine = create_engine("sqlite:///:memory:")
SQLAlchemyInstrumentor().instrument(
Expand All @@ -39,15 +51,5 @@ def test_sqlcommenter_enabled(self):
cnx.execute("SELECT 1;").fetchall()
self.assertRegex(
self.caplog.records[-2].getMessage(),
r"SELECT 1; /\*traceparent='\d{1,2}-[a-zA-Z0-9_]{32}-[a-zA-Z0-9_]{16}-\d{1,2}'\*/",
r"SELECT 1 /\*traceparent='\d{1,2}-[a-zA-Z0-9_]{32}-[a-zA-Z0-9_]{16}-\d{1,2}'\*/;",
)

def test_sqlcommenter_disabled(self):
engine = create_engine("sqlite:///:memory:", echo=True)
SQLAlchemyInstrumentor().instrument(
engine=engine, tracer_provider=self.tracer_provider
)
cnx = engine.connect()
cnx.execute("SELECT 1;").fetchall()

self.assertEqual(self.caplog.records[-2].getMessage(), "SELECT 1;")
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# pylint: disable=E0611
from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY # noqa: F401
from opentelemetry.propagate import extract
from opentelemetry.trace import Span, StatusCode
from opentelemetry.trace import StatusCode
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
Expand Down Expand Up @@ -147,7 +147,7 @@ def _generate_sql_comment(**meta) -> str:
)


def _url_quote(s): # pylint: disable=invalid-name
def _url_quote(s) -> str: # pylint: disable=invalid-name
if not isinstance(s, (str, bytes)):
return s
quoted = urllib.parse.quote(s)
Expand All @@ -158,7 +158,7 @@ def _url_quote(s): # pylint: disable=invalid-name
return quoted.replace("%", "%%")


def _get_opentelemetry_values():
def _get_opentelemetry_values() -> dict:
"""
Return the OpenTelemetry Trace and Span IDs if Span ID is set in the
OpenTelemetry execution context.
Expand All @@ -169,20 +169,22 @@ def _get_opentelemetry_values():
return _headers


def _generate_opentelemetry_traceparent(span: Span) -> str:
meta = {}
_version = "00"
_span_id = trace.format_span_id(span.context.span_id)
_trace_id = trace.format_trace_id(span.context.trace_id)
_flags = str(trace.TraceFlags.SAMPLED)
_traceparent = _version + "-" + _trace_id + "-" + _span_id + "-" + _flags
meta.update({"traceparent": _traceparent})
return meta


def _python_path_without_directory(python_path, directory, path_separator):
return sub(
rf"{escape(directory)}{path_separator}(?!$)",
"",
python_path,
)


def _add_sql_comment(sql, **meta) -> str:
"""
Appends comments to the sql statement and returns it
"""
comment = _generate_sql_comment(**meta)
sql = sql.rstrip()
if sql[-1] == ";":
sql = sql[:-1] + comment + ";"
else:
sql = sql + comment
return sql
34 changes: 34 additions & 0 deletions opentelemetry-instrumentation/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from http import HTTPStatus

from opentelemetry.instrumentation.utils import (
_add_sql_comment,
_python_path_without_directory,
http_status_to_status_code,
)
Expand Down Expand Up @@ -152,3 +153,36 @@ def test_remove_current_directory_from_python_path_linux_only_path(self):
python_path, directory, path_separator
)
self.assertEqual(actual_python_path, python_path)

def test_add_sql_comments_with_semicolon(self):
sql_query_without_semicolon = "Select 1;"
comments = {"comment_1": "value 1", "comment 2": "value 3"}
commented_sql_without_semicolon = _add_sql_comment(
sql_query_without_semicolon, **comments
)

self.assertEqual(
commented_sql_without_semicolon,
"Select 1 /*comment%%202='value%%203',comment_1='value%%201'*/;",
)

def test_add_sql_comments_without_semicolon(self):
sql_query_without_semicolon = "Select 1"
comments = {"comment_1": "value 1", "comment 2": "value 3"}
commented_sql_without_semicolon = _add_sql_comment(
sql_query_without_semicolon, **comments
)

self.assertEqual(
commented_sql_without_semicolon,
"Select 1 /*comment%%202='value%%203',comment_1='value%%201'*/",
)

def test_add_sql_comments_without_comments(self):
sql_query_without_semicolon = "Select 1"
comments = {}
commented_sql_without_semicolon = _add_sql_comment(
sql_query_without_semicolon, **comments
)

self.assertEqual(commented_sql_without_semicolon, "Select 1")
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ def test_commenter_enabled(self):
self._cursor.execute("SELECT 1;")
self.assertRegex(
self._cursor.query.decode("ascii"),
r"SELECT 1; /\*traceparent='\d{1,2}-[a-zA-Z0-9_]{32}-[a-zA-Z0-9_]{16}-\d{1,2}'\*/",
r"SELECT 1 /\*traceparent='\d{1,2}-[a-zA-Z0-9_]{32}-[a-zA-Z0-9_]{16}-\d{1,2}'\*/;",
)