Browse files

Merge 2e31c56 into cd2e429

  • Loading branch information...
2 parents cd2e429 + 2e31c56 commit dc2d529a48f2296bf7ae7f09b4c31eb071338cb2 @rpkilby rpkilby committed on GitHub Jul 27, 2016
Showing with 87 additions and 34 deletions.
  1. +17 −20 django/db/models/fields/related.py
  2. +35 −13 django/db/models/query_utils.py
  3. +4 −0 tests/custom_lookups/models.py
  4. +31 −1 tests/custom_lookups/tests.py
View
37 django/db/models/fields/related.py
@@ -1,5 +1,6 @@
from __future__ import unicode_literals
+import inspect
import warnings
from functools import partial
@@ -11,7 +12,7 @@
from django.db.models import Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import CASCADE, SET_DEFAULT, SET_NULL
-from django.db.models.query_utils import PathInfo
+from django.db.models.query_utils import PathInfo, merge_dicts
from django.db.models.utils import make_model_tuple
from django.utils import six
from django.utils.deprecation import RemovedInDjango20Warning
@@ -731,26 +732,14 @@ def get_reverse_path_info(self):
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)]
return pathinfos
- def get_lookup(self, lookup_name):
- if lookup_name == 'in':
- return RelatedIn
- elif lookup_name == 'exact':
- return RelatedExact
- elif lookup_name == 'gt':
- return RelatedGreaterThan
- elif lookup_name == 'gte':
- return RelatedGreaterThanOrEqual
- elif lookup_name == 'lt':
- return RelatedLessThan
- elif lookup_name == 'lte':
- return RelatedLessThanOrEqual
- elif lookup_name == 'isnull':
- return RelatedIsNull
- else:
- raise TypeError('Related Field got invalid lookup: %s' % lookup_name)
+ def get_lookups(self):
+ if 'cached_lookups' not in self.__class__.__dict__:
+ bases = inspect.getmro(self.__class__)
+ bases = bases[:bases.index(ForeignObject) + 1]
+ class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in bases]
+ self.__class__.cached_lookups = merge_dicts(class_lookups)
- def get_transform(self, *args, **kwargs):
- raise NotImplementedError('Relational fields do not support transforms.')
+ return self.__class__.cached_lookups
def contribute_to_class(self, cls, name, private_only=False, **kwargs):
super(ForeignObject, self).contribute_to_class(cls, name, private_only=private_only, **kwargs)
@@ -767,6 +756,14 @@ def contribute_to_related_class(self, cls, related):
if self.remote_field.limit_choices_to:
cls._meta.related_fkey_lookups.append(self.remote_field.limit_choices_to)
+ForeignObject.register_lookup(RelatedIn)
+ForeignObject.register_lookup(RelatedExact)
+ForeignObject.register_lookup(RelatedLessThan)
+ForeignObject.register_lookup(RelatedGreaterThan)
+ForeignObject.register_lookup(RelatedGreaterThanOrEqual)
+ForeignObject.register_lookup(RelatedLessThanOrEqual)
+ForeignObject.register_lookup(RelatedIsNull)
+
class ForeignKey(ForeignObject):
"""
View
48 django/db/models/query_utils.py
@@ -27,6 +27,23 @@ class InvalidQuery(Exception):
pass
+def merge_dicts(dicts):
+ merged = {}
+
+ # merge in reverse to preference order
+ for d in reversed(dicts):
+ merged.update(d)
+
+ return merged
+
+
+def get_subclasses(cls):
+ for subclass in cls.__subclasses__():
+ for subsubclass in get_subclasses(subclass):
+ yield subsubclass
+ yield subclass
+
+
class QueryWrapper(object):
"""
A type that indicates the contents are an SQL fragment and the associate
@@ -133,19 +150,14 @@ def _check_parent_chain(self, instance, name):
class RegisterLookupMixin(object):
def _get_lookup(self, lookup_name):
- try:
- return self.class_lookups[lookup_name]
- except KeyError:
- # To allow for inheritance, check parent class' class_lookups.
- for parent in inspect.getmro(self.__class__):
- if 'class_lookups' not in parent.__dict__:
- continue
- if lookup_name in parent.class_lookups:
- return parent.class_lookups[lookup_name]
- except AttributeError:
- # This class didn't have any class_lookups
- pass
- return None
+ return self.get_lookups().get(lookup_name, None)
+
+ def get_lookups(self):
+ if 'cached_lookups' not in self.__class__.__dict__:
+ class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in inspect.getmro(self.__class__)]
+ self.__class__.cached_lookups = merge_dicts(class_lookups)
+
+ return self.__class__.cached_lookups
def get_lookup(self, lookup_name):
from django.db.models.lookups import Lookup
@@ -166,12 +178,22 @@ def get_transform(self, lookup_name):
return found
@classmethod
+ def _bust_cached_lookups(cls):
+ for subclass in get_subclasses(cls):
+ if 'cached_lookups' in subclass.__dict__:
+ del subclass.cached_lookups
+
+ if 'cached_lookups' in cls.__dict__:
+ del cls.cached_lookups
+
+ @classmethod
def register_lookup(cls, lookup, lookup_name=None):
if lookup_name is None:
lookup_name = lookup.lookup_name
if 'class_lookups' not in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup_name] = lookup
+ cls._bust_cached_lookups()
return lookup
@classmethod
View
4 tests/custom_lookups/models.py
@@ -13,6 +13,10 @@ def __str__(self):
return self.name
+class Article(models.Model):
+ author = models.ForeignKey(Author, on_delete=models.CASCADE)
+
+
@python_2_unicode_compatible
class MySQLUnixTimestamp(models.Model):
timestamp = models.PositiveIntegerField()
View
32 tests/custom_lookups/tests.py
@@ -10,7 +10,7 @@
from django.test import TestCase, override_settings
from django.utils import timezone
-from .models import Author, MySQLUnixTimestamp
+from .models import Article, Author, MySQLUnixTimestamp
@contextlib.contextmanager
@@ -319,6 +319,36 @@ def test_div3_extract(self):
baseqs.filter(age__div3__range=(1, 2)),
[a1, a2, a4], lambda x: x)
+ def test_foreignobject_lookup_registration(self):
+ field = Article._meta.get_field('author')
+
+ with register_lookup(models.ForeignObject, Exactly):
+ self.assertIs(field.get_lookup('exactly'), Exactly)
+
+ # ForeignObject should ignore regular Field lookups
+ with register_lookup(models.Field, Exactly):
+ self.assertIsNone(field.get_lookup('exactly'))
+
+ def test_lookups_caching(self):
+ field = Article._meta.get_field('author')
+
+ # clear cache
+ del field.__class__.cached_lookups
+ self.assertNotIn('cached_lookups', field.__class__.__dict__)
+
+ # get_lookups() should cache for reuse.
+ field.get_lookups()
+ self.assertIn('cached_lookups', field.__class__.__dict__)
+
+ with register_lookup(models.ForeignObject, Exactly):
+ # registration should bust/remove the cache
+ self.assertNotIn('cached_lookups', field.__class__.__dict__)
+
+ # getting the lookups again should re-cache
+ field.get_lookups()
+ self.assertIn('cached_lookups', field.__class__.__dict__)
+ self.assertIn('exactly', field.__class__.__dict__['cached_lookups'])
+
class BilateralTransformTests(TestCase):

0 comments on commit dc2d529

Please sign in to comment.