Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Implemented nested lookups

But there is no support of using lookups outside filtering yet.
  • Loading branch information...
commit 7c8b3a32cc17b4dbca160921d48125f1631e0df4 1 parent 4d219d4
@akaariai akaariai authored
View
5 django/db/models/fields/related.py
@@ -1136,11 +1136,14 @@ def get_reverse_path_info(self):
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
return pathinfos
- def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
+ def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups,
raw_value):
from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
root_constraint = constraint_class()
assert len(targets) == len(sources)
+ if len(lookups) > 1:
+ raise exceptions.FieldError('Relation fields do not support nested lookups')
+ lookup_type = lookups[0]
def get_normalized_value(value):
View
49 django/db/models/lookups.py
@@ -1,27 +1,58 @@
from copy import copy
+from django.core.exceptions import FieldError
from django.conf import settings
from django.utils import timezone
+from django.utils.functional import cached_property
+
+
+class Extract(object):
+ def __init__(self, constraint_class, lhs):
+ self.constraint_class, self.lhs = constraint_class, lhs
+
+ def get_lookup(self, lookup):
+ return self.output_type.get_lookup(lookup)
+
+ def as_sql(self, qn, connection):
+ raise NotImplementedError
+
+ @cached_property
+ def output_type(self):
+ return self.lhs.output_type
+
+ def relabeled_clone(self, relabels):
+ return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels))
class Lookup(object):
+ lookup_name = None
+ extract_class = None
+
def __init__(self, constraint_class, lhs, rhs):
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
- self.rhs = self.get_prep_lookup()
+ if rhs is None:
+ if not self.extract_class:
+ raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name)
+ else:
+ self.rhs = self.get_prep_lookup()
+
+ def get_extract(self):
+ return self.extract_class(self.constraint_class, self.lhs)
+
+ def get_prep_lookup(self):
+ return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
def get_db_prep_lookup(self, value, connection):
return (
'%s', self.lhs.output_type.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True))
- def get_prep_lookup(self):
- return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
-
- def process_lhs(self, qn, connection):
- return qn.compile(self.lhs)
+ def process_lhs(self, qn, connection, lhs=None):
+ lhs = lhs or self.lhs
+ return qn.compile(lhs)
- def process_rhs(self, qn, connection):
- value = self.rhs
+ def process_rhs(self, qn, connection, rhs=None):
+ value = rhs or self.rhs
# Due to historical reasons there are a couple of different
# ways to produce sql here. get_compiler is likely a Query
# instance, _as_sql QuerySet and as_sql just something with
@@ -118,7 +149,7 @@ class In(DjangoLookup):
lookup_name = 'in'
def get_db_prep_lookup(self, value, connection):
- params = self.lhs.field.get_db_prep_lookup(
+ params = self.lhs.output_type.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True)
if not params:
# TODO: check why this leads to circular import
View
3  django/db/models/sql/aggregates.py
@@ -100,6 +100,9 @@ def get_cols(self):
def output_type(self):
return self.field
+ def get_lookup(self, lookup):
+ return self.output_type.get_lookup(lookup)
+
class Avg(Aggregate):
is_computed = True
View
3  django/db/models/sql/datastructures.py
@@ -25,6 +25,9 @@ def relabeled_clone(self, relabels):
def get_cols(self):
return [(self.alias, self.target.column)]
+ def get_lookup(self, name):
+ return self.output_type.get_lookup(name)
+
class EmptyResultSet(Exception):
pass
View
104 django/db/models/sql/query.py
@@ -1027,19 +1027,16 @@ def add_aggregate(self, aggregate, model, alias, is_summary):
# Add the aggregate to the query
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
- def prepare_lookup_value(self, value, lookup_type, can_reuse):
+ def prepare_lookup_value(self, value, lookups, can_reuse):
+ # Default lookup if none given is exact.
+ if len(lookups) == 0:
+ lookups = ['exact']
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value.
- if len(lookup_type) > 1:
- raise FieldError('Nested lookups not allowed')
- elif len(lookup_type) == 0:
- lookup_type = 'exact'
- else:
- lookup_type = lookup_type[0]
if value is None:
- if lookup_type != 'exact':
+ if lookups[-1] != 'exact':
raise ValueError("Cannot use None as a query value")
- lookup_type = 'isnull'
+ lookups[-1] = 'isnull'
value = True
elif callable(value):
value = value()
@@ -1057,10 +1054,10 @@ def prepare_lookup_value(self, value, lookup_type, can_reuse):
# stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we
# can do here. Similar thing is done in is_nullable(), too.
if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
- lookup_type == 'exact' and value == ''):
+ lookups[-1] == 'exact' and value == ''):
value = True
- lookup_type = 'isnull'
- return value, lookup_type
+ lookups[-1] = ['isnull']
+ return value, lookups
def solve_lookup_type(self, lookup):
"""
@@ -1069,36 +1066,37 @@ def solve_lookup_type(self, lookup):
lookup_splitted = lookup.split(LOOKUP_SEP)
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
if aggregate:
- if len(aggregate_lookups) > 1:
- raise FieldError("Nested lookups not allowed.")
return aggregate_lookups, (), aggregate
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
if len(lookup_parts) == 0:
lookup_parts = ['exact']
elif len(lookup_parts) > 1:
- if field_parts:
- raise FieldError(
- 'Only one lookup part allowed (found path "%s" from "%s").' %
- (LOOKUP_SEP.join(field_parts), lookup))
- else:
+ if not field_parts:
raise FieldError(
'Invalid lookup "%s" for model %s".' %
(lookup, self.get_meta().model.__name__))
- else:
- if not hasattr(field, 'get_lookup_constraint'):
- lookup_class = field.get_lookup(lookup_parts[0])
- if lookup_class is None and lookup_parts[0] not in self.query_terms:
- raise FieldError(
- 'Invalid lookup name %s' % lookup_parts[0])
return lookup_parts, field_parts, False
- def build_lookup(self, lookup_type, lhs, rhs):
- if hasattr(lhs.output_type, 'get_lookup'):
- lookup = lhs.output_type.get_lookup(lookup_type)
- if lookup:
- return lookup(self.where_class, lhs, rhs)
- return None
+ def build_lookup(self, lookups, lhs, rhs):
+ lookups = lookups[:]
+ lookups.reverse()
+ while lookups:
+ lookup = lookups.pop()
+ next = lhs.get_lookup(lookup)
+ if next:
+ if not lookups:
+ # This was the last lookup, so return value lookup.
+ return next(self.where_class, lhs, rhs)
+ else:
+ lhs = next(self.where_class, lhs, None).get_extract()
+ # A field's get_lookup() can return None to opt for backwards
+ # compatibility path.
+ elif len(lookups) > 1:
+ raise FieldError(
+ "Unsupported lookup for field '%s'" % lhs.output_type.name)
+ else:
+ return None
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, connector=AND):
@@ -1130,19 +1128,20 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
arg, value = filter_expr
if not arg:
raise FieldError("Cannot parse keyword query %r" % arg)
- lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg)
+ lookups, parts, reffed_aggregate = self.solve_lookup_type(arg)
# Work out the lookup type and remove it from the end of 'parts',
# if necessary.
- value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse)
+ value, lookups = self.prepare_lookup_value(value, lookups, can_reuse)
used_joins = getattr(value, '_used_joins', [])
clause = self.where_class()
if reffed_aggregate:
- condition = self.build_lookup(lookup_type, reffed_aggregate, value)
+ condition = self.build_lookup(lookups, reffed_aggregate, value)
if not condition:
# Backwards compat for custom lookups
- condition = (reffed_aggregate, lookup_type, value)
+ assert len(lookups) == 1
+ condition = (reffed_aggregate, lookups[0], value)
clause.add(condition, AND)
return clause, []
@@ -1169,14 +1168,27 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
# For now foreign keys get special treatment. This should be
# refactored when composite fields lands.
condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
- lookup_type, value)
+ lookups, value)
+ lookup_type = lookups[-1]
else:
assert(len(targets) == 1)
col = Col(alias, targets[0], field)
- condition = self.build_lookup(lookup_type, col, value)
+ condition = self.build_lookup(lookups, col, value)
if not condition:
# Backwards compat for custom lookups
- condition = (Constraint(alias, targets[0].column, field), lookup_type, value)
+ if lookups[0] not in self.query_terms:
+ raise FieldError(
+ "Join on field '%s' not permitted. Did you "
+ "misspell '%s' for the lookup type?" %
+ (col.output_type.name, lookups[0]))
+ if len(lookups) > 1:
+ raise FieldError("Nested lookup '%s' not supported." %
+ LOOKUP_SEP.join(lookups))
+ condition = (Constraint(alias, targets[0].column, field), lookups[0], value)
+ lookup_type = lookups[-1]
+ else:
+ lookup_type = condition.lookup_name
+
clause.add(condition, AND)
require_outer = lookup_type == 'isnull' and value is True and not current_negated
@@ -1296,7 +1308,7 @@ def _add_q(self, q_object, used_aliases, branch_negated=False,
needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner
- def names_to_path(self, names, opts, allow_many=True):
+ def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False):
"""
Walks the names path and turns them PathInfo tuples. Note that a
single name in 'names' can generate multiple PathInfos (m2m for
@@ -1354,10 +1366,15 @@ def names_to_path(self, names, opts, allow_many=True):
final_field = field
targets = (field,)
break
- if pos == -1:
- raise FieldError('Whazaa')
+ if pos == -1 or (fail_on_missing and pos + 1 != len(names)):
+ self.raise_field_error(opts, name)
return path, final_field, targets, names[pos + 1:]
+ def raise_field_error(self, opts, name):
+ available = opts.get_all_field_names() + list(self.aggregate_select)
+ raise FieldError("Cannot resolve keyword %r into field. "
+ "Choices are: %s" % (name, ", ".join(available)))
+
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
Compute the necessary table joins for the passage through the fields
@@ -1386,9 +1403,8 @@ def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
joins = [alias]
# First, generate the path for the names
path, final_field, targets, rest = self.names_to_path(
- names, opts, allow_many)
- if rest:
- raise FieldError('Invalid lookup')
+ names, opts, allow_many, fail_on_missing=True)
+
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
View
6 tests/custom_lookups/models.py
@@ -1,7 +1,13 @@
from django.db import models
+from django.utils.encoding import python_2_unicode_compatible
+@python_2_unicode_compatible
class Author(models.Model):
name = models.CharField(max_length=20)
age = models.IntegerField(null=True)
birthdate = models.DateField(null=True)
+ average_rating = models.FloatField(null=True)
+
+ def __str__(self):
+ return self.name
View
119 tests/custom_lookups/tests.py
@@ -19,6 +19,56 @@ def as_sql(self, qn, connection):
return '%s %%%% 3 = %s' % (lhs, rhs), params
+class Div3Extract(models.lookups.Extract):
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = qn.compile(self.lhs)
+ return '%s %%%% 3' % (lhs,), lhs_params
+
+
+class Div3LookupWithExtract(Div3Lookup):
+ lookup_name = 'div3'
+ extract_class = Div3Extract
+
+
+class YearLte(models.lookups.LessThanOrEqual):
+ """
+ The purpose of this lookup is to efficiently compare the year of the field.
+ """
+
+ def as_sql(self, qn, connection):
+ # Skip the YearExtract above us (no possibility for efficient
+ # lookup otherwise).
+ real_lhs = self.lhs.lhs
+ lhs_sql, params = self.process_lhs(qn, connection, real_lhs)
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ params.extend(rhs_params)
+ # Build SQL where the integer year is concatenated with last month
+ # and day, then convert that to date. (We try to have SQL like:
+ # WHERE somecol <= '2013-12-31')
+ # but also make it work if the rhs_sql is field reference.
+ return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
+
+
+class YearExtract(models.lookups.Extract):
+ def as_sql(self, qn, connection):
+ lhs_sql, params = qn.compile(self.lhs)
+ return connection.ops.date_extract_sql('year', lhs_sql), params
+
+ @property
+ def output_type(self):
+ return models.IntegerField()
+
+ def get_lookup(self, lookup):
+ if lookup == 'lte':
+ return YearLte
+ else:
+ return super(YearExtract, self).get_lookup(lookup)
+
+
+class YearWithExtract(models.lookups.Year):
+ extract_class = YearExtract
+
+
class InMonth(models.lookups.Lookup):
"""
InMonth matches if the column's month is contained in the value's month.
@@ -134,3 +184,72 @@ def get_rhs_op(self, connection, rhs):
)
finally:
models.Field._unregister_lookup(AnotherEqual)
+
+ def test_div3_extract(self):
+ models.IntegerField.register_lookup(Div3LookupWithExtract)
+ try:
+ a1 = Author.objects.create(name='a1', age=1)
+ a2 = Author.objects.create(name='a2', age=2)
+ a3 = Author.objects.create(name='a3', age=3)
+ a4 = Author.objects.create(name='a4', age=4)
+ baseqs = Author.objects.order_by('name')
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__lte=3),
+ [a1, a2, a3, a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(age__div3__in=[0, 2]),
+ [a2, a3], lambda x: x)
+ finally:
+ models.IntegerField._unregister_lookup(Div3LookupWithExtract)
+
+
+class YearLteTests(TestCase):
+ def setUp(self):
+ models.DateField.register_lookup(YearWithExtract)
+ self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
+ self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
+ self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
+ self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
+
+ def tearDown(self):
+ models.DateField._unregister_lookup(YearWithExtract)
+
+ @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ def test_year_lte(self):
+ baseqs = Author.objects.order_by('name')
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lte=2012),
+ [self.a1, self.a2, self.a3, self.a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lte=2011),
+ [self.a1], lambda x: x)
+ # The non-optimized version works, too.
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lt=2012),
+ [self.a1], lambda x: x)
+
+ @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ def test_year_lte_fexpr(self):
+ self.a2.age = 2011
+ self.a2.save()
+ self.a3.age = 2012
+ self.a3.save()
+ self.a4.age = 2013
+ self.a4.save()
+ baseqs = Author.objects.order_by('name')
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lte=models.F('age')),
+ [self.a3, self.a4], lambda x: x)
+ self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year__lt=models.F('age')),
+ [self.a4], lambda x: x)
+
+ def test_year_lte_sql(self):
+ # This test will just check the generated SQL for __lte. This
+ # doesn't require running on PostgreSQL and spots the most likely
+ # error - not running YearLte SQL at all.
+ baseqs = Author.objects.order_by('name')
+ self.assertIn(
+ '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
+ self.assertIn(
+ '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
View
3  tests/null_queries/tests.py
@@ -41,9 +41,6 @@ def test_none_as_null(self):
# Can't use None on anything other than __exact
self.assertRaises(ValueError, Choice.objects.filter, id__gt=None)
- # Can't use None on anything other than __exact
- self.assertRaises(ValueError, Choice.objects.filter, foo__gt=None)
-
# Related managers use __exact=None implicitly if the object hasn't been saved.
p2 = Poll(question="How?")
self.assertEqual(repr(p2.choice_set.all()), '[]')
Please sign in to comment.
Something went wrong with that request. Please try again.