From 8f3007a5893ca933773aa11c2acd52db61068aac Mon Sep 17 00:00:00 2001 From: Eugene Morozov Date: Fri, 13 Jan 2023 14:55:13 +0200 Subject: [PATCH] Fixes exception when there are signals connected to abstract models. --- .../management/commands/list_signals.py | 21 ++++++++++++++++--- .../management/commands/test_list_signals.py | 17 ++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/django_extensions/management/commands/list_signals.py b/django_extensions/management/commands/list_signals.py index c41aa6f14..ba6193b79 100644 --- a/django_extensions/management/commands/list_signals.py +++ b/django_extensions/management/commands/list_signals.py @@ -7,8 +7,8 @@ import weakref from collections import defaultdict -from django.apps import apps from django.core.management.base import BaseCommand +from django.db.models import Model from django.db.models.signals import ( ModelSignal, pre_init, post_init, pre_save, post_save, pre_delete, post_delete, m2m_changed, pre_migrate, post_migrate @@ -31,12 +31,27 @@ } +def get_all_models() -> set[Model]: + """ + Returns set of all models defined in all apps. + + This implementation is required because apps.get_models() is an internal API and + doesn't return abstract models. + """ + result = set() + generation = {Model} + while generation: + generation = {sc for c in generation for sc in c.__subclasses__()} + result.update(generation) + + return result + + class Command(BaseCommand): help = 'List all signals by model and signal type' def handle(self, *args, **options): - all_models = apps.get_models(include_auto_created=True, include_swapped=True) - model_lookup = {id(m): m for m in all_models} + model_lookup = {id(m): m for m in get_all_models()} signals = [obj for obj in gc.get_objects() if isinstance(obj, ModelSignal)] models = defaultdict(lambda: defaultdict(list)) diff --git a/tests/management/commands/test_list_signals.py b/tests/management/commands/test_list_signals.py index 82ac9c4cd..a616855a4 100644 --- a/tests/management/commands/test_list_signals.py +++ b/tests/management/commands/test_list_signals.py @@ -1,10 +1,16 @@ -# -*- coding: utf-8 -*- import re from io import StringIO +from django.db.models.signals import post_delete from django.test import TestCase from django.core.management import call_command +from tests.testapp.models import AbstractInheritanceTestModelParent + + +def delete_dummy_handler(sender, instance, **kwargs): + pass + class ListSignalsTests(TestCase): """Tests for list_signals command.""" @@ -12,16 +18,21 @@ class ListSignalsTests(TestCase): def setUp(self): self.out = StringIO() + post_delete.connect(delete_dummy_handler, sender=AbstractInheritanceTestModelParent) + def test_should_print_all_signals(self): - expected_result = '''django.contrib.sites.models.Site (site) + expected_result = """django.contrib.sites.models.Site (site) pre_delete django.contrib.sites.models.clear_site_cache # pre_save django.contrib.sites.models.clear_site_cache # +tests.testapp.models.AbstractInheritanceTestModelParent (abstract inheritance test model parent) + post_delete + tests.management.commands.test_list_signals.delete_dummy_handler # tests.testapp.models.HasOwnerModel (has owner model) pre_save tests.testapp.models.dummy_handler # -''' +""" call_command('list_signals', stdout=self.out)