diff --git a/docs/third-party.rst b/docs/third-party.rst index 7b274f50..33c94a6d 100644 --- a/docs/third-party.rst +++ b/docs/third-party.rst @@ -97,7 +97,16 @@ This doesn't work, since it needs to look for revisions of the child model. Usin the view of the actual child model is used, similar to the way the regular change and delete views are redirected. +django-guardian support +----------------------- + +You can enable the content type of the base model to be used for the object levels permissions by setting the +django-guardian_ option `GUARDIAN_GET_CONTENT_TYPE` to `polymorphic.contrib.get_polymorphic_base_content_type`. Read +more about this option in the `django-guardian documentation `_. + + .. _django-reversion: https://github.com/etianen/django-reversion .. _django-reversion-compare: https://github.com/jedie/django-reversion-compare .. _django-mptt: https://github.com/django-mptt/django-mptt .. _django-polymorphic-tree: https://github.com/django-polymorphic/django-polymorphic-tree +.. _django-guardian: https://github.com/django-guardian/django-guardian diff --git a/polymorphic/contrib/__init__.py b/polymorphic/contrib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/polymorphic/contrib/guardian.py b/polymorphic/contrib/guardian.py new file mode 100644 index 00000000..55b16012 --- /dev/null +++ b/polymorphic/contrib/guardian.py @@ -0,0 +1,35 @@ +from django.contrib.contenttypes.models import ContentType + + +def get_polymorphic_base_content_type(obj): + """ + Helper function to return the base polymorphic content type id. This should used with django-guardian and the + GUARDIAN_GET_CONTENT_TYPE option. + + See the django-guardian documentation for more information: + + https://django-guardian.readthedocs.io/en/latest/configuration.html#guardian-get-content-type + """ + if hasattr(obj, 'polymorphic_model_marker'): + try: + superclasses = list(obj.__class__.mro()) + except TypeError: + # obj is an object so mro() need to be called with the obj. + superclasses = list(obj.__class__.mro(obj)) + + polymorphic_superclasses = list() + for sclass in superclasses: + if hasattr(sclass, 'polymorphic_model_marker'): + polymorphic_superclasses.append(sclass) + + # PolymorphicMPTT adds an additional class between polymorphic and base class. + if hasattr(obj, 'can_have_children'): + root_polymorphic_class = polymorphic_superclasses[-3] + else: + root_polymorphic_class = polymorphic_superclasses[-2] + ctype = ContentType.objects.get_for_model(root_polymorphic_class) + + else: + ctype = ContentType.objects.get_for_model(obj) + + return ctype diff --git a/polymorphic/tests.py b/polymorphic/tests.py index d0483139..9c2cb3d4 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -21,6 +21,7 @@ from django.contrib.contenttypes.models import ContentType from django.utils import six +from polymorphic.contrib.guardian import get_polymorphic_base_content_type from polymorphic.models import PolymorphicModel from polymorphic.managers import PolymorphicManager from polymorphic.query import PolymorphicQuerySet @@ -195,6 +196,7 @@ class ModelWithMyManagerNoDefault(ShowFieldTypeAndContent, Model2A): my_objects = MyManager() field4 = models.CharField(max_length=10) + class ModelWithMyManagerDefault(ShowFieldTypeAndContent, Model2A): my_objects = MyManager() objects = PolymorphicManager() @@ -1194,6 +1196,24 @@ def test_polymorphic__expressions(self): result = Model2B.objects.annotate(val=Concat('field1', 'field2')) self.assertEqual(list(result), []) + def test_contrib_guardian(self): + # Regular Django inheritance should return the child model content type. + obj = PlainC() + ctype = get_polymorphic_base_content_type(obj) + self.assertEqual(ctype.name, 'plain c') + + ctype = get_polymorphic_base_content_type(PlainC) + self.assertEqual(ctype.name, 'plain c') + + # Polymorphic inheritance should return the parent model content type. + obj = Model2D() + ctype = get_polymorphic_base_content_type(obj) + self.assertEqual(ctype.name, 'model2a') + + ctype = get_polymorphic_base_content_type(Model2D) + self.assertEqual(ctype.name, 'model2a') + + class RegressionTests(TestCase): def test_for_query_result_incomplete_with_inheritance(self): @@ -1215,6 +1235,7 @@ def test_for_query_result_incomplete_with_inheritance(self): expected_queryset = [bottom] self.assertQuerysetEqual(Bottom.objects.all(), [repr(r) for r in expected_queryset]) + class MultipleDatabasesTests(TestCase): multi_db = True