Skip to content

Commit

Permalink
Prepare the manager-using code for Django 1.10
Browse files Browse the repository at this point in the history
This change simply requires the programmer to ensure that
_default_manager returns a tree manager. Custom manager classes
must extend TreeManager.

Refs django-mptt#469.
  • Loading branch information
matthiask committed Jul 12, 2016
1 parent 9e131f9 commit 7914c39
Show file tree
Hide file tree
Showing 11 changed files with 42 additions and 101 deletions.
11 changes: 4 additions & 7 deletions mptt/admin.py
Expand Up @@ -75,7 +75,7 @@ def delete_selected_tree(self, modeladmin, request, queryset):
# If this is True, the confirmation page has been displayed
if request.POST.get('post'):
n = 0
with queryset.model._tree_manager.delay_mptt_updates():
with queryset.model._default_manager.delay_mptt_updates():
for obj in queryset:
if self.has_delete_permission(request, obj):
obj.delete()
Expand Down Expand Up @@ -225,7 +225,7 @@ def _move_node(self, request):
return http.HttpResponse('FAIL, no permission.')

try:
self.model._tree_manager.move_node(cut_item, pasted_on, position)
self.model._default_manager.move_node(cut_item, pasted_on, position)
except InvalidMove as e:
self.message_user(request, '%s' % e)
return http.HttpResponse('FAIL, invalid move.')
Expand Down Expand Up @@ -303,10 +303,7 @@ class MyModelAdmin(admin.ModelAdmin):

def __init__(self, field, request, params, model, model_admin, field_path):
self.other_model = get_model_from_relation(field)
if hasattr(field, 'rel'):
self.rel_name = field.rel.get_related_field().name
else:
self.rel_name = self.other_model._meta.pk.name
self.rel_name = self.other_model._meta.pk.name
self.changed_lookup_kwarg = '%s__%s__inhierarchy' % (field_path, self.rel_name)
super(TreeRelatedFieldListFilter, self).__init__(field, request, params,
model, model_admin, field_path)
Expand Down Expand Up @@ -376,7 +373,7 @@ def choices(self, cl):
}
if (isinstance(self.field, ForeignObjectRel) and
(self.field.field.null or isinstance(self.field.field, ManyToManyField)) or
hasattr(self.field, 'rel') and
hasattr(self.field, 'remote_field') and
(self.field.null or isinstance(self.field, ManyToManyField))):
yield {
'selected': bool(self.lookup_val_isnull),
Expand Down
2 changes: 1 addition & 1 deletion mptt/forms.py
Expand Up @@ -121,7 +121,7 @@ def __init__(self, node, *args, **kwargs):
super(MoveNodeForm, self).__init__(*args, **kwargs)
opts = node._mptt_meta
if valid_targets is None:
valid_targets = node._tree_manager.exclude(**{
valid_targets = node.__class__._default_manager.exclude(**{
opts.tree_id_attr: getattr(node, opts.tree_id_attr),
opts.left_attr + '__gte': getattr(node, opts.left_attr),
opts.right_attr + '__lte': getattr(node, opts.right_attr),
Expand Down
18 changes: 6 additions & 12 deletions mptt/managers.py
Expand Up @@ -66,6 +66,7 @@ def delegate_manager(method):
"""
@functools.wraps(method)
def wrapped(self, *args, **kwargs):
return method(self, *args, **kwargs) # FIXME what should this really do?
if self._base_manager:
return getattr(self._base_manager, method.__name__)(*args, **kwargs)
return method(self, *args, **kwargs)
Expand All @@ -78,16 +79,9 @@ class TreeManager(models.Manager.from_queryset(TreeQuerySet)):
A manager for working with trees of objects.
"""

def contribute_to_class(self, model, name):
super(TreeManager, self).contribute_to_class(model, name)

if not model._meta.abstract:
self.tree_model = _get_tree_model(model)

self._base_manager = None
if self.tree_model is not model:
# _base_manager is the treemanager on tree_model
self._base_manager = self.tree_model._tree_manager
@property
def tree_model(self):
return _get_tree_model(self.model)

def get_queryset(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -482,7 +476,7 @@ def add_related_count(self, queryset, rel_model, rel_field, count_attr,
'rel_table': qn(rel_model._meta.db_table),
'mptt_fk': qn(rel_model._meta.get_field(rel_field).column),
'mptt_table': qn(self.tree_model._meta.db_table),
'mptt_rel_to': qn(mptt_field.rel.field_name),
'mptt_rel_to': qn(mptt_field.remote_field.field_name),
'tree_id': qn(meta.get_field(self.tree_id_attr).column),
'left': qn(meta.get_field(self.left_attr).column),
'right': qn(meta.get_field(self.right_attr).column),
Expand All @@ -492,7 +486,7 @@ def add_related_count(self, queryset, rel_model, rel_field, count_attr,
'rel_table': qn(rel_model._meta.db_table),
'mptt_fk': qn(rel_model._meta.get_field(rel_field).column),
'mptt_table': qn(self.tree_model._meta.db_table),
'mptt_rel_to': qn(mptt_field.rel.field_name),
'mptt_rel_to': qn(mptt_field.remote_field.field_name),
}
return queryset.extra(select={count_attr: subquery})

Expand Down
52 changes: 10 additions & 42 deletions mptt/models.py
Expand Up @@ -212,7 +212,7 @@ def get_ordered_insertion_target(self, node, parent):
# Fall back on tree id ordering if multiple root nodes have
# the same values.
order_by.append(opts.tree_id_attr)
queryset = node.__class__._tree_manager.db_manager(node._state.db).filter(filters).order_by(*order_by)
queryset = node._tree_manager.db_manager(node._state.db).filter(filters).order_by(*order_by)
if node.pk:
queryset = queryset.exclude(pk=node.pk)
try:
Expand Down Expand Up @@ -271,7 +271,7 @@ class MPTTMeta:
bases = [base for base in cls.mro() if issubclass(base, MPTTModel)]
for base in bases:
if (not (base._meta.abstract or base._meta.proxy) and
base._tree_manager.tree_model is base):
getattr(base._default_manager, 'tree_model', None) is base):
cls._mptt_tracking_base = base
break
if cls is cls._mptt_tracking_base:
Expand All @@ -294,8 +294,6 @@ def register(meta, cls, **kwargs):
if not hasattr(cls, '_mptt_meta'):
cls._mptt_meta = MPTTOptions(**kwargs)

abstract = getattr(cls._meta, 'abstract', False)

try:
MPTTModel
except NameError:
Expand Down Expand Up @@ -333,40 +331,6 @@ def register(meta, cls, **kwargs):
field = models.PositiveIntegerField(db_index=True, editable=False)
field.contribute_to_class(cls, field_name)

# Add a tree manager, if there isn't one already
if not abstract:
manager = getattr(cls, 'objects', None)
if manager is None:
manager = cls._default_manager._copy_to_model(cls)
manager.contribute_to_class(cls, 'objects')
elif manager.model != cls:
# manager was inherited
manager = manager._copy_to_model(cls)
manager.contribute_to_class(cls, 'objects')

# make sure we have a tree manager somewhere
tree_manager = None
if hasattr(cls._meta, 'concrete_managers'): # Django < 1.10
cls_managers = cls._meta.concrete_managers + cls._meta.abstract_managers
cls_managers = [r[2] for r in cls_managers]
else:
cls_managers = cls._meta.managers

for cls_manager in cls_managers:
if isinstance(cls_manager, TreeManager):
# prefer any locally defined manager (i.e. keep going if not local)
if cls_manager.model is cls:
tree_manager = cls_manager
break

if tree_manager and tree_manager.model is not cls:
tree_manager = tree_manager._copy_to_model(cls)
elif tree_manager is None:
tree_manager = TreeManager()
tree_manager.contribute_to_class(cls, '_tree_manager')

# avoid using ManagerDescriptor, so instances can refer to self._tree_manager
setattr(cls, '_tree_manager', tree_manager)
return cls


Expand All @@ -386,15 +350,19 @@ class MPTTModel(six.with_metaclass(MPTTModelBase, models.Model)):
"""
Base class for tree models.
"""
_default_manager = TreeManager()

objects = TreeManager()

class Meta:
abstract = True

def __init__(self, *args, **kwargs):
super(MPTTModel, self).__init__(*args, **kwargs)
self._mptt_meta.update_mptt_cached_fields(self)
self._tree_manager = self._tree_manager.db_manager(self._state.db)

@property
def _tree_manager(self):
return _get_tree_model(self.__class__)._default_manager

def _mpttfield(self, fieldname):
translated_fieldname = getattr(self._mptt_meta, fieldname + '_attr')
Expand Down Expand Up @@ -789,7 +757,7 @@ def _is_saved(self, using=None):
if not self.pk or self._mpttfield('tree_id') is None:
return False
opts = self._meta
if opts.pk.rel is None:
if opts.pk.remote_field is None:
return True
else:
if not hasattr(self, '_mptt_saved'):
Expand Down Expand Up @@ -1027,7 +995,7 @@ def delete(self, *args, **kwargs):
def _mptt_refresh(self):
if not self.pk:
return
manager = type(self)._tree_manager
manager = type(self)._default_manager
opts = self._mptt_meta
values = manager.filter(pk=self.pk).values(
opts.left_attr,
Expand Down
4 changes: 2 additions & 2 deletions mptt/querysets.py
Expand Up @@ -8,14 +8,14 @@ def get_descendants(self, *args, **kwargs):
"""
Alias to `mptt.managers.TreeManager.get_queryset_descendants`.
"""
return self.model._tree_manager.get_queryset_descendants(self, *args, **kwargs)
return self.model._default_manager.get_queryset_descendants(self, *args, **kwargs)
get_descendants.queryset_only = True

def get_ancestors(self, *args, **kwargs):
"""
Alias to `mptt.managers.TreeManager.get_queryset_ancestors`.
"""
return self.model._tree_manager.get_queryset_ancestors(self, *args, **kwargs)
return self.model._default_manager.get_queryset_ancestors(self, *args, **kwargs)
get_ancestors.queryset_only = True

def get_cached_trees(self):
Expand Down
2 changes: 1 addition & 1 deletion mptt/templatetags/mptt_tags.py
Expand Up @@ -30,7 +30,7 @@ def render(self, context):
raise template.TemplateSyntaxError(
_('full_tree_for_model tag was given an invalid model: %s') % self.model
)
context[self.context_var] = cls._tree_manager.all()
context[self.context_var] = cls._default_manager.all()
return ''


Expand Down
2 changes: 1 addition & 1 deletion mptt/utils.py
Expand Up @@ -144,7 +144,7 @@ def drilldown_tree_for_node(node, rel_cls=None, rel_field=None, count_attr=None,
descendants, otherwise it will be for each child itself.
"""
if rel_cls and rel_field and count_attr:
children = node._tree_manager.add_related_count(
children = node.__class__._default_manager.add_related_count(
node.get_children(), rel_cls, rel_field, count_attr, cumulative)
else:
children = node.get_children()
Expand Down
20 changes: 10 additions & 10 deletions tests/myapp/doctests.txt
Expand Up @@ -208,23 +208,23 @@

# TreeManager Methods #########################################################
# check that tree manager is the explicitly defined tree manager for Person
>>> Person._tree_manager == Person.objects
>>> Person._default_manager == Person.objects
True

# managers of non-abstract bases don't get inherited, so:
>>> Student._tree_manager == Student.objects
False
# managers of non-abstract bases get inherited, so:
>>> Student._default_manager == Student.objects
True

>>> Student._tree_manager == Person._tree_manager
False
>>> Student._default_manager == Person._default_manager
True

>>> Student._tree_manager.model
>>> Student._default_manager.model
<class 'myapp.models.Student'>
>>> Student._tree_manager.tree_model
>>> Student._default_manager.tree_model
<class 'myapp.models.Person'>
>>> Person._tree_manager.model
>>> Person._default_manager.model
<class 'myapp.models.Person'>
>>> Person._tree_manager.tree_model
>>> Person._default_manager.tree_model
<class 'myapp.models.Person'>

>>> Genre.objects.root_node(action.tree_id)
Expand Down
3 changes: 0 additions & 3 deletions tests/myapp/models.py
Expand Up @@ -164,9 +164,6 @@ class Person(MPTTModel):
# just testing it's actually possible to override the tree manager
objects = CustomTreeManager()

# This line is set because of https://github.com/django-mptt/django-mptt/issues/369
_default_manager = objects

def __str__(self):
return self.name

Expand Down
25 changes: 5 additions & 20 deletions tests/myapp/tests.py
Expand Up @@ -9,7 +9,7 @@

from django import forms
from django.contrib.auth.models import Group, User
from django.db.models import Q
from django.db.models import Q, Manager
from django.db.models.query_utils import DeferredAttribute
from django.apps import apps
from django.forms.models import modelform_factory
Expand Down Expand Up @@ -1152,7 +1152,7 @@ def test_all_managers_are_different(self):
for model in apps.get_models():
if not issubclass(model, MPTTModel):
continue
tm = model._tree_manager
tm = model._default_manager
if id(tm) in seen:
self.fail(
"Tree managers for %s and %s are the same manager"
Expand All @@ -1164,27 +1164,12 @@ def test_all_managers_have_correct_model(self):
for model in apps.get_models():
if not issubclass(model, MPTTModel):
continue
self.assertEqual(model._tree_manager.model, model)

def test_base_manager_infinite_recursion(self):
# repeatedly calling _base_manager should eventually return None
for model in apps.get_models():
if not issubclass(model, MPTTModel):
continue
manager = model._tree_manager
for i in range(20):
manager = manager._base_manager
if manager is None:
break
else:
self.fail("Detected infinite recursion in %s._tree_manager._base_manager" % model)
self.assertEqual(model._default_manager.model, model)

def test_proxy_custom_manager(self):
self.assertIsInstance(SingleProxyModel._tree_manager, CustomTreeManager)
self.assertIsInstance(SingleProxyModel._tree_manager._base_manager, TreeManager)

self.assertIsInstance(SingleProxyModel._default_manager, CustomTreeManager)
self.assertIsInstance(SingleProxyModel.objects, CustomTreeManager)
self.assertIsInstance(SingleProxyModel.objects._base_manager, TreeManager)
self.assertIsInstance(SingleProxyModel._base_manager, Manager)

def test_get_queryset_descendants(self):
def get_desc_names(qs, include_self=False):
Expand Down
4 changes: 2 additions & 2 deletions tests/myapp/urls.py
@@ -1,4 +1,4 @@
from django.conf.urls import include, url
from django.conf.urls import url

from django.contrib import admin

Expand All @@ -7,5 +7,5 @@


urlpatterns = [
url(r'^admin/', include(admin.site.urls)),
url(r'^admin/', admin.site.urls),
]

0 comments on commit 7914c39

Please sign in to comment.