From e7575e87dc9e2d2560b87d6fd5a123b9398cbd34 Mon Sep 17 00:00:00 2001 From: Tim Schilling Date: Tue, 9 May 2023 20:28:58 -0500 Subject: [PATCH] Inherit from django.db.backends.utils.CursorWrapper This switches the Debug Toolbar cursor wrappers to inherit from the Django class django.db.backends.utils.CursorWrapper. This reduces some of the code we need. --- debug_toolbar/panels/sql/tracking.py | 46 +++++++++------------------- 1 file changed, 14 insertions(+), 32 deletions(-) diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py index 4add1fce7..425e4e5cc 100644 --- a/debug_toolbar/panels/sql/tracking.py +++ b/debug_toolbar/panels/sql/tracking.py @@ -4,6 +4,7 @@ from time import time import django.test.testcases +from django.db.backends.utils import CursorWrapper from django.utils.encoding import force_str from debug_toolbar import settings as dt_settings @@ -57,54 +58,47 @@ def cursor(*args, **kwargs): wrapper = NormalCursorWrapper else: wrapper = ExceptionCursorWrapper - return wrapper(cursor, connection, logger) + return wrapper(cursor.cursor, connection, logger) def chunked_cursor(*args, **kwargs): # prevent double wrapping # solves https://github.com/jazzband/django-debug-toolbar/issues/1239 logger = connection._djdt_logger cursor = connection._djdt_chunked_cursor(*args, **kwargs) - if logger is not None and not isinstance(cursor, BaseCursorWrapper): + if logger is not None and not isinstance(cursor, DjDTCursorWrapper): if allow_sql.get(): wrapper = NormalCursorWrapper else: wrapper = ExceptionCursorWrapper - return wrapper(cursor, connection, logger) + return wrapper(cursor.cursor, connection, logger) return cursor connection.cursor = cursor connection.chunked_cursor = chunked_cursor -class BaseCursorWrapper: - pass +class DjDTCursorWrapper(CursorWrapper): + def __init__(self, cursor, db, logger): + super().__init__(cursor, db) + # logger must implement a ``record`` method + self.logger = logger -class ExceptionCursorWrapper(BaseCursorWrapper): +class ExceptionCursorWrapper(DjDTCursorWrapper): """ Wraps a cursor and raises an exception on any operation. Used in Templates panel. """ - def __init__(self, cursor, db, logger): - pass - def __getattr__(self, attr): raise SQLQueryTriggered() -class NormalCursorWrapper(BaseCursorWrapper): +class NormalCursorWrapper(DjDTCursorWrapper): """ Wraps a cursor and logs queries. """ - def __init__(self, cursor, db, logger): - self.cursor = cursor - # Instance of a BaseDatabaseWrapper subclass - self.db = db - # logger must implement a ``record`` method - self.logger = logger - def _quote_expr(self, element): if isinstance(element, str): return "'%s'" % element.replace("'", "''") @@ -246,22 +240,10 @@ def _record(self, method, sql, params): self.logger.record(**params) def callproc(self, procname, params=None): - return self._record(self.cursor.callproc, procname, params) + return self._record(super().callproc, procname, params) def execute(self, sql, params=None): - return self._record(self.cursor.execute, sql, params) + return self._record(super().execute, sql, params) def executemany(self, sql, param_list): - return self._record(self.cursor.executemany, sql, param_list) - - def __getattr__(self, attr): - return getattr(self.cursor, attr) - - def __iter__(self): - return iter(self.cursor) - - def __enter__(self): - return self - - def __exit__(self, type, value, traceback): - self.close() + return self._record(super().executemany, sql, param_list)