diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py
index 7cdb2b1c3..315553959 100644
--- a/debug_toolbar/panels/sql/tracking.py
+++ b/debug_toolbar/panels/sql/tracking.py
@@ -3,6 +3,7 @@
 import json
 from time import time
 
+from django.db.backends.utils import CursorWrapper
 from django.utils.encoding import force_str
 
 from debug_toolbar import settings as dt_settings
@@ -43,21 +44,16 @@ def cursor(*args, **kwargs):
         # See:
         # https://github.com/jazzband/django-debug-toolbar/pull/615
         # https://github.com/jazzband/django-debug-toolbar/pull/896
-        cursor = connection._djdt_cursor(*args, **kwargs)
-        if connection._djdt_in_record:
-            return cursor
         if allow_sql.get():
             wrapper = NormalCursorWrapper
         else:
             wrapper = ExceptionCursorWrapper
-        return wrapper(cursor, connection, panel)
+        return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel)
 
     def chunked_cursor(*args, **kwargs):
         # prevent double wrapping
         # solves https://github.com/jazzband/django-debug-toolbar/issues/1239
         cursor = connection._djdt_chunked_cursor(*args, **kwargs)
-        if connection._djdt_in_record:
-            return cursor
         if not isinstance(cursor, BaseCursorWrapper):
             if allow_sql.get():
                 wrapper = NormalCursorWrapper
@@ -68,7 +64,6 @@ def chunked_cursor(*args, **kwargs):
 
     connection.cursor = cursor
     connection.chunked_cursor = chunked_cursor
-    connection._djdt_in_record = False
 
 
 def unwrap_cursor(connection):
@@ -91,8 +86,11 @@ def unwrap_cursor(connection):
         del connection._djdt_chunked_cursor
 
 
-class BaseCursorWrapper:
-    pass
+class BaseCursorWrapper(CursorWrapper):
+    def __init__(self, cursor, db, logger):
+        super().__init__(cursor, db)
+        # logger must implement a ``record`` method
+        self.logger = logger
 
 
 class ExceptionCursorWrapper(BaseCursorWrapper):
@@ -101,11 +99,21 @@
     """
     Wraps a cursor and raises an exception on any operation.
    Used in Templates panel.
     """
 
-    def __init__(self, cursor, db, logger):
-        pass
-
     def __getattr__(self, attr):
-        raise SQLQueryTriggered()
+        # This allows the cursor to access connection and close which
+        # are needed in psycopg to determine the last_executed_query via
+        # the mogrify function.
+        if attr in (
+            "callproc",
+            "execute",
+            "executemany",
+            "fetchone",
+            "fetchmany",
+            "fetchall",
+            "nextset",
+        ):
+            raise SQLQueryTriggered(f"Attr: {attr} was accessed")
+        return super().__getattr__(attr)
 
 
 class NormalCursorWrapper(BaseCursorWrapper):
@@ -113,13 +121,6 @@
     """
     Wraps a cursor and logs queries.
     """
 
-    def __init__(self, cursor, db, logger):
-        self.cursor = cursor
-        # Instance of a BaseDatabaseWrapper subclass
-        self.db = db
-        # logger must implement a ``record`` method
-        self.logger = logger
-
     def _quote_expr(self, element):
         if isinstance(element, str):
             return "'%s'" % element.replace("'", "''")
@@ -159,115 +160,108 @@ def _decode(self, param):
         except UnicodeDecodeError:
             return "(encoded string)"
 
+    def _get_last_executed_query(self, sql, params):
+        """Get the last executed query from the connection."""
+        # The psycopg3 backend uses a mogrify function which creates a new cursor.
+        # We need to avoid hooking into that cursor.
+        reset_token = allow_sql.set(False)
+        sql_query = self.db.ops.last_executed_query(
+            self.cursor, sql, self._quote_params(params)
+        )
+        allow_sql.reset(reset_token)
+        return sql_query
+
     def _record(self, method, sql, params):
-        self.db._djdt_in_record = True
-        try:
-            alias = self.db.alias
-            vendor = self.db.vendor
+        alias = self.db.alias
+        vendor = self.db.vendor
 
-            if vendor == "postgresql":
-                # The underlying DB connection (as opposed to Django's wrapper)
-                conn = self.db.connection
-                initial_conn_status = conn.info.transaction_status
+        if vendor == "postgresql":
+            # The underlying DB connection (as opposed to Django's wrapper)
+            conn = self.db.connection
+            initial_conn_status = conn.info.transaction_status
 
-            start_time = time()
+        start_time = time()
+        try:
+            return method(sql, params)
+        finally:
+            stop_time = time()
+            duration = (stop_time - start_time) * 1000
+            _params = ""
             try:
-                return method(sql, params)
-            finally:
-                stop_time = time()
-                duration = (stop_time - start_time) * 1000
-                _params = ""
+                _params = json.dumps(self._decode(params))
+            except TypeError:
+                pass  # object not JSON serializable
+            template_info = get_template_info()
+
+            # Sql might be an object (such as psycopg Composed).
+            # For logging purposes, make sure it's str.
+            if vendor == "postgresql" and not isinstance(sql, str):
+                sql = sql.as_string(conn)
+            else:
+                sql = str(sql)
+
+            params = {
+                "vendor": vendor,
+                "alias": alias,
+                "sql": self._get_last_executed_query(sql, params),
+                "duration": duration,
+                "raw_sql": sql,
+                "params": _params,
+                "raw_params": params,
+                "stacktrace": get_stack_trace(skip=2),
+                "start_time": start_time,
+                "stop_time": stop_time,
+                "is_slow": (
+                    duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
+                ),
+                "is_select": sql.lower().strip().startswith("select"),
+                "template_info": template_info,
+            }
+
+            if vendor == "postgresql":
+                # If an erroneous query was run on the connection, it might
+                # be in a state where checking isolation_level raises an
+                # exception.
                 try:
-                    _params = json.dumps(self._decode(params))
-                except TypeError:
-                    pass  # object not JSON serializable
-                template_info = get_template_info()
-
-                # Sql might be an object (such as psycopg Composed).
-                # For logging purposes, make sure it's str.
-                if vendor == "postgresql" and not isinstance(sql, str):
-                    sql = sql.as_string(conn)
-                else:
-                    sql = str(sql)
-
-                params = {
-                    "vendor": vendor,
-                    "alias": alias,
-                    "sql": self.db.ops.last_executed_query(
-                        self.cursor, sql, self._quote_params(params)
-                    ),
-                    "duration": duration,
-                    "raw_sql": sql,
-                    "params": _params,
-                    "raw_params": params,
-                    "stacktrace": get_stack_trace(skip=2),
-                    "start_time": start_time,
-                    "stop_time": stop_time,
-                    "is_slow": (
-                        duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
-                    ),
-                    "is_select": sql.lower().strip().startswith("select"),
-                    "template_info": template_info,
-                }
-
-                if vendor == "postgresql":
-                    # If an erroneous query was ran on the connection, it might
-                    # be in a state where checking isolation_level raises an
-                    # exception.
-                    try:
-                        iso_level = conn.isolation_level
-                    except conn.InternalError:
-                        iso_level = "unknown"
-                    # PostgreSQL does not expose any sort of transaction ID, so it is
-                    # necessary to generate synthetic transaction IDs here. If the
-                    # connection was not in a transaction when the query started, and was
-                    # after the query finished, a new transaction definitely started, so get
-                    # a new transaction ID from logger.new_transaction_id(). If the query
-                    # was in a transaction both before and after executing, make the
-                    # assumption that it is the same transaction and get the current
-                    # transaction ID from logger.current_transaction_id(). There is an edge
-                    # case where Django can start a transaction before the first query
-                    # executes, so in that case logger.current_transaction_id() will
-                    # generate a new transaction ID since one does not already exist.
-                    final_conn_status = conn.info.transaction_status
-                    if final_conn_status == STATUS_IN_TRANSACTION:
-                        if initial_conn_status == STATUS_IN_TRANSACTION:
-                            trans_id = self.logger.current_transaction_id(alias)
-                        else:
-                            trans_id = self.logger.new_transaction_id(alias)
+                    iso_level = conn.isolation_level
+                except conn.InternalError:
+                    iso_level = "unknown"
+                # PostgreSQL does not expose any sort of transaction ID, so it is
+                # necessary to generate synthetic transaction IDs here. If the
+                # connection was not in a transaction when the query started, and was
+                # after the query finished, a new transaction definitely started, so get
+                # a new transaction ID from logger.new_transaction_id(). If the query
+                # was in a transaction both before and after executing, make the
+                # assumption that it is the same transaction and get the current
+                # transaction ID from logger.current_transaction_id(). There is an edge
+                # case where Django can start a transaction before the first query
+                # executes, so in that case logger.current_transaction_id() will
+                # generate a new transaction ID since one does not already exist.
+                final_conn_status = conn.info.transaction_status
+                if final_conn_status == STATUS_IN_TRANSACTION:
+                    if initial_conn_status == STATUS_IN_TRANSACTION:
+                        trans_id = self.logger.current_transaction_id(alias)
                     else:
-                        trans_id = None
-
-                params.update(
-                    {
-                        "trans_id": trans_id,
-                        "trans_status": conn.info.transaction_status,
-                        "iso_level": iso_level,
-                    }
-                )
-
-                # We keep `sql` to maintain backwards compatibility
-                self.logger.record(**params)
-        finally:
-            self.db._djdt_in_record = False
+                        trans_id = self.logger.new_transaction_id(alias)
+                else:
+                    trans_id = None
+
+            params.update(
+                {
+                    "trans_id": trans_id,
+                    "trans_status": conn.info.transaction_status,
+                    "iso_level": iso_level,
+                }
+            )
+
+            # We keep `sql` to maintain backwards compatibility
+            self.logger.record(**params)
 
     def callproc(self, procname, params=None):
-        return self._record(self.cursor.callproc, procname, params)
+        return self._record(super().callproc, procname, params)
 
     def execute(self, sql, params=None):
-        return self._record(self.cursor.execute, sql, params)
+        return self._record(super().execute, sql, params)
 
     def executemany(self, sql, param_list):
-        return self._record(self.cursor.executemany, sql, param_list)
-
-    def __getattr__(self, attr):
-        return getattr(self.cursor, attr)
-
-    def __iter__(self):
-        return iter(self.cursor)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, type, value, traceback):
-        self.close()
+        return self._record(super().executemany, sql, param_list)
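
Reviewer note, not part of the patch: `_get_last_executed_query` relies on the token-based set/reset idiom of `contextvars`, assuming `allow_sql` is the module-level `contextvars.ContextVar` with `default=True` that the rest of the file already consults via `allow_sql.get()`. While the flag is `False`, any cursor the patched `wrap_cursor` hands out is an `ExceptionCursorWrapper`, so the extra cursor that psycopg's mogrify path creates is never recorded as a query. A minimal sketch of the idiom follows; `untracked` is a hypothetical helper name, and the try/finally (which the patch itself inlines without) additionally restores the flag if the wrapped call raises:

    from contextvars import ContextVar

    allow_sql = ContextVar("allow_sql", default=True)

    def untracked(func, *args, **kwargs):
        # Disable SQL tracking for the duration of the call. In the patched
        # tracking.py, cursors opened while allow_sql is False get the
        # non-recording ExceptionCursorWrapper; here we only demonstrate the
        # flag handling itself.
        token = allow_sql.set(False)
        try:
            return func(*args, **kwargs)
        finally:
            # reset() restores the exact prior value, so nested calls and an
            # already-False flag both unwind correctly.
            allow_sql.reset(token)

With such a helper, the new method could be expressed as a single call, e.g. `untracked(self.db.ops.last_executed_query, self.cursor, sql, self._quote_params(params))`; the patch performs the same three steps inline instead.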