Skip to content

Commit

Permalink
Patch CursorWrapper dynamically to allow multiple base classes. (#1820)
Browse files Browse the repository at this point in the history
* Patch CursorWrapper dynamically to allow multiple base classes.

When the debug SQL logs are enabled, the wrapper class is CursorDebugWrapper
not CursorWrapper. Since we have inspections based on that specific class
they are removing the CursorDebugWrapper causing the SQL logs to not appear.

This attempts to dynamically patch the CursorWrapper or CursorDebugWrapper
with the functionality we need.

This doesn't do a full regression test, but it may be possible to get
it to work with:

    TEST_ARGS='--debug-sql' make test

Which causes the current set of tests to fail since they are keyed to
CursorWrapper.

* Allow mixin as a valid word in our docs.

* Support tests with --debug-sql
  • Loading branch information
tim-schilling committed Aug 11, 2023
1 parent 6e55663 commit 7677183
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 30 deletions.
31 changes: 19 additions & 12 deletions debug_toolbar/panels/sql/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from time import perf_counter

import django.test.testcases
from django.db.backends.utils import CursorWrapper
from django.utils.encoding import force_str

from debug_toolbar.utils import get_stack_trace, get_template_info
Expand Down Expand Up @@ -60,34 +59,42 @@ def cursor(*args, **kwargs):
cursor = connection._djdt_cursor(*args, **kwargs)
if logger is None:
return cursor
wrapper = NormalCursorWrapper if allow_sql.get() else ExceptionCursorWrapper
return wrapper(cursor.cursor, connection, logger)
mixin = NormalCursorMixin if allow_sql.get() else ExceptionCursorMixin
return patch_cursor_wrapper_with_mixin(cursor.__class__, mixin)(
cursor.cursor, connection, logger
)

def chunked_cursor(*args, **kwargs):
# prevent double wrapping
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
logger = connection._djdt_logger
cursor = connection._djdt_chunked_cursor(*args, **kwargs)
if logger is not None and not isinstance(cursor, DjDTCursorWrapper):
if allow_sql.get():
wrapper = NormalCursorWrapper
else:
wrapper = ExceptionCursorWrapper
return wrapper(cursor.cursor, connection, logger)
if logger is not None and not isinstance(cursor, DjDTCursorWrapperMixin):
mixin = NormalCursorMixin if allow_sql.get() else ExceptionCursorMixin
return patch_cursor_wrapper_with_mixin(cursor.__class__, mixin)(
cursor.cursor, connection, logger
)
return cursor

connection.cursor = cursor
connection.chunked_cursor = chunked_cursor


class DjDTCursorWrapper(CursorWrapper):
def patch_cursor_wrapper_with_mixin(base_wrapper, mixin):
class DjDTCursorWrapper(mixin, base_wrapper):
pass

return DjDTCursorWrapper


class DjDTCursorWrapperMixin:
def __init__(self, cursor, db, logger):
super().__init__(cursor, db)
# logger must implement a ``record`` method
self.logger = logger


class ExceptionCursorWrapper(DjDTCursorWrapper):
class ExceptionCursorMixin(DjDTCursorWrapperMixin):
"""
Wraps a cursor and raises an exception on any operation.
Used in Templates panel.
Expand All @@ -97,7 +104,7 @@ def __getattr__(self, attr):
raise SQLQueryTriggered()


class NormalCursorWrapper(DjDTCursorWrapper):
class NormalCursorMixin(DjDTCursorWrapperMixin):
"""
Wraps a cursor and logs queries.
"""
Expand Down
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Pending
resolving to the wrong content type.
* Fixed SQL statement recording under PostgreSQL for queries encoded as byte
strings.
* Patch the ``CursorWrapper`` class with a mixin class to support multiple
base wrapper classes.

4.1.0 (2023-05-15)
------------------
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ memcache
memcached
middleware
middlewares
mixin
mousedown
mouseup
multi
Expand Down
71 changes: 53 additions & 18 deletions tests/panels/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import datetime
import os
import unittest
from unittest.mock import patch
from unittest.mock import call, patch

import django
from asgiref.sync import sync_to_async
from django.contrib.auth.models import User
from django.db import connection, transaction
from django.db.backends.utils import CursorDebugWrapper, CursorWrapper
from django.db.models import Count
from django.db.utils import DatabaseError
from django.shortcuts import render
Expand Down Expand Up @@ -68,39 +69,59 @@ def test_recording_chunked_cursor(self):
self.assertEqual(len(self.panel._queries), 1)

@patch(
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
wraps=sql_tracking.NormalCursorWrapper,
"debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin",
wraps=sql_tracking.patch_cursor_wrapper_with_mixin,
)
def test_cursor_wrapper_singleton(self, mock_wrapper):
def test_cursor_wrapper_singleton(self, mock_patch_cursor_wrapper):
sql_call()

# ensure that cursor wrapping is applied only once
self.assertEqual(mock_wrapper.call_count, 1)
self.assertIn(
mock_patch_cursor_wrapper.mock_calls,
[
[call(CursorWrapper, sql_tracking.NormalCursorMixin)],
# CursorDebugWrapper is used if the test is called with `--debug-sql`
[call(CursorDebugWrapper, sql_tracking.NormalCursorMixin)],
],
)

@patch(
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
wraps=sql_tracking.NormalCursorWrapper,
"debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin",
wraps=sql_tracking.patch_cursor_wrapper_with_mixin,
)
def test_chunked_cursor_wrapper_singleton(self, mock_wrapper):
def test_chunked_cursor_wrapper_singleton(self, mock_patch_cursor_wrapper):
sql_call(use_iterator=True)

# ensure that cursor wrapping is applied only once
self.assertEqual(mock_wrapper.call_count, 1)
self.assertIn(
mock_patch_cursor_wrapper.mock_calls,
[
[call(CursorWrapper, sql_tracking.NormalCursorMixin)],
# CursorDebugWrapper is used if the test is called with `--debug-sql`
[call(CursorDebugWrapper, sql_tracking.NormalCursorMixin)],
],
)

@patch(
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
wraps=sql_tracking.NormalCursorWrapper,
"debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin",
wraps=sql_tracking.patch_cursor_wrapper_with_mixin,
)
async def test_cursor_wrapper_async(self, mock_wrapper):
async def test_cursor_wrapper_async(self, mock_patch_cursor_wrapper):
await sync_to_async(sql_call)()

self.assertEqual(mock_wrapper.call_count, 1)
self.assertIn(
mock_patch_cursor_wrapper.mock_calls,
[
[call(CursorWrapper, sql_tracking.NormalCursorMixin)],
# CursorDebugWrapper is used if the test is called with `--debug-sql`
[call(CursorDebugWrapper, sql_tracking.NormalCursorMixin)],
],
)

@patch(
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
wraps=sql_tracking.NormalCursorWrapper,
"debug_toolbar.panels.sql.tracking.patch_cursor_wrapper_with_mixin",
wraps=sql_tracking.patch_cursor_wrapper_with_mixin,
)
async def test_cursor_wrapper_asyncio_ctx(self, mock_wrapper):
async def test_cursor_wrapper_asyncio_ctx(self, mock_patch_cursor_wrapper):
self.assertTrue(sql_tracking.allow_sql.get())
await sync_to_async(sql_call)()

Expand All @@ -116,7 +137,21 @@ async def task():
await asyncio.create_task(task())
# Because it was called in another context, it should not have affected ours
self.assertTrue(sql_tracking.allow_sql.get())
self.assertEqual(mock_wrapper.call_count, 1)

self.assertIn(
mock_patch_cursor_wrapper.mock_calls,
[
[
call(CursorWrapper, sql_tracking.NormalCursorMixin),
call(CursorWrapper, sql_tracking.ExceptionCursorMixin),
],
# CursorDebugWrapper is used if the test is called with `--debug-sql`
[
call(CursorDebugWrapper, sql_tracking.NormalCursorMixin),
call(CursorDebugWrapper, sql_tracking.ExceptionCursorMixin),
],
],
)

def test_generate_server_timing(self):
self.assertEqual(len(self.panel._queries), 0)
Expand Down

0 comments on commit 7677183

Please sign in to comment.