Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Initial implementation of custom lookups

  • Loading branch information...
commit 4d219d4cdef21d9c14e5d6b9299d583d1975fcba 1 parent 01e8ac4
@akaariai akaariai authored
View
3  django/db/backends/__init__.py
@@ -67,6 +67,9 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
self.allow_thread_sharing = allow_thread_sharing
self._thread_ident = thread.get_ident()
+ # Compile implementations, used by compiler.compile(someelem)
+ self.compile_implementations = utils.get_implementations(self.vendor)
+
def __eq__(self, other):
if isinstance(other, BaseDatabaseWrapper):
return self.alias == other.alias
View
24 django/db/backends/utils.py
@@ -194,3 +194,27 @@ def format_number(value, max_digits, decimal_places):
return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context))
else:
return "%.*f" % (decimal_places, value)
+
+# Map of vendor name -> map of query element class -> implementation function
+compile_implementations = {}
+
+
+def get_implementations(vendor):
+ try:
+ implementation = compile_implementations[vendor]
+ except KeyError:
+ # TODO: do we need thread safety here? We could easily use an lock...
+ implementation = {}
+ compile_implementations[vendor] = implementation
+ return implementation
+
+
+class add_implementation(object):
+ def __init__(self, klass, vendor):
+ self.klass = klass
+ self.vendor = vendor
+
+ def __call__(self, func):
+ implementations = get_implementations(self.vendor)
+ implementations[self.klass] = func
+ return func
View
4 django/db/models/aggregates.py
@@ -17,8 +17,8 @@ def refs_aggregate(lookup_parts, aggregates):
"""
for i in range(len(lookup_parts) + 1):
if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
- return True
- return False
+ return aggregates[LOOKUP_SEP.join(lookup_parts[0:i])], lookup_parts[i:]
+ return False, ()
class Aggregate(object):
View
32 django/db/models/fields/__init__.py
@@ -4,6 +4,7 @@
import copy
import datetime
import decimal
+import inspect
import math
import warnings
from base64 import b64decode, b64encode
@@ -11,6 +12,7 @@
from django.db import connection
from django.db.models.loading import get_model
+from django.db.models.lookups import default_lookups
from django.db.models.query_utils import QueryWrapper
from django.conf import settings
from django import forms
@@ -101,6 +103,7 @@ class Field(object):
'unique': _('%(model_name)s with this %(field_label)s '
'already exists.'),
}
+ class_lookups = default_lookups.copy()
# Generic field type description, usually overridden by subclasses
def _description(self):
@@ -446,6 +449,30 @@ def get_cache_name(self):
def get_internal_type(self):
return self.__class__.__name__
+ def get_lookup(self, lookup_name):
+ try:
+ return self.class_lookups[lookup_name]
+ except KeyError:
+ for parent in inspect.getmro(self.__class__):
+ if not 'class_lookups' in parent.__dict__:
+ continue
+ if lookup_name in parent.class_lookups:
+ return parent.class_lookups[lookup_name]
+
+ @classmethod
+ def register_lookup(cls, lookup):
+ if not 'class_lookups' in cls.__dict__:
+ cls.class_lookups = {}
+ cls.class_lookups[lookup.lookup_name] = lookup
+
+ @classmethod
+ def _unregister_lookup(cls, lookup):
+ """
+ Removes given lookup from cls lookups. Meant to be used in
+ tests only.
+ """
+ del cls.class_lookups[lookup.lookup_name]
+
def pre_save(self, model_instance, add):
"""
Returns field's value just before saving.
@@ -504,8 +531,7 @@ def get_prep_lookup(self, lookup_type, value):
except ValueError:
raise ValueError("The __year lookup type requires an integer "
"argument")
-
- raise TypeError("Field has invalid lookup: %s" % lookup_type)
+ return self.get_prep_value(value)
def get_db_prep_lookup(self, lookup_type, value, connection,
prepared=False):
@@ -554,6 +580,8 @@ def get_db_prep_lookup(self, lookup_type, value, connection,
return connection.ops.year_lookup_bounds_for_date_field(value)
else:
return [value] # this isn't supposed to happen
+ else:
+ return [value]
def has_default(self):
"""
View
5 django/db/models/fields/related.py
@@ -934,6 +934,11 @@ def set_field_name(self):
# example custom multicolumn joins currently have no remote field).
self.field_name = None
+ def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
+ raw_value):
+ return self.field.get_lookup_constraint(constraint_class, alias, targets, sources,
+ lookup_type, raw_value)
+
class ManyToOneRel(ForeignObjectRel):
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
View
242 django/db/models/lookups.py
@@ -0,0 +1,242 @@
+from copy import copy
+
+from django.conf import settings
+from django.utils import timezone
+
+
+class Lookup(object):
+ def __init__(self, constraint_class, lhs, rhs):
+ self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
+ self.rhs = self.get_prep_lookup()
+
+ def get_db_prep_lookup(self, value, connection):
+ return (
+ '%s', self.lhs.output_type.get_db_prep_lookup(
+ self.lookup_name, value, connection, prepared=True))
+
+ def get_prep_lookup(self):
+ return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
+
+ def process_lhs(self, qn, connection):
+ return qn.compile(self.lhs)
+
+ def process_rhs(self, qn, connection):
+ value = self.rhs
+ # Due to historical reasons there are a couple of different
+ # ways to produce sql here. get_compiler is likely a Query
+ # instance, _as_sql QuerySet and as_sql just something with
+ # as_sql. Finally the value can of course be just plain
+ # Python value.
+ if hasattr(value, 'get_compiler'):
+ value = value.get_compiler(connection=connection)
+ if hasattr(value, 'as_sql'):
+ sql, params = qn.compile(value)
+ return '(' + sql + ')', params
+ if hasattr(value, '_as_sql'):
+ sql, params = value._as_sql(connection=connection)
+ return '(' + sql + ')', params
+ else:
+ return self.get_db_prep_lookup(value, connection)
+
+ def relabeled_clone(self, relabels):
+ new = copy(self)
+ new.lhs = new.lhs.relabeled_clone(relabels)
+ if hasattr(new.rhs, 'relabeled_clone'):
+ new.rhs = new.rhs.relabeled_clone(relabels)
+ return new
+
+ def get_cols(self):
+ cols = self.lhs.get_cols()
+ if hasattr(self.rhs, 'get_cols'):
+ cols.extend(self.rhs.get_cols())
+ return cols
+
+ def as_sql(self, qn, connection):
+ raise NotImplementedError
+
+
+class DjangoLookup(Lookup):
+ def as_sql(self, qn, connection):
+ lhs_sql, params = self.process_lhs(qn, connection)
+ field_internal_type = self.lhs.output_type.get_internal_type()
+ db_type = self.lhs.output_type
+ lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
+ lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ params.extend(rhs_params)
+ operator_plus_rhs = self.get_rhs_op(connection, rhs_sql)
+ return '%s %s' % (lhs_sql, operator_plus_rhs), params
+
+ def get_rhs_op(self, connection, rhs):
+ return connection.operators[self.lookup_name] % rhs
+
+
+default_lookups = {}
+
+
+class Exact(DjangoLookup):
+ lookup_name = 'exact'
+default_lookups['exact'] = Exact
+
+
+class IExact(DjangoLookup):
+ lookup_name = 'iexact'
+default_lookups['iexact'] = IExact
+
+
+class Contains(DjangoLookup):
+ lookup_name = 'contains'
+default_lookups['contains'] = Contains
+
+
+class IContains(DjangoLookup):
+ lookup_name = 'icontains'
+default_lookups['icontains'] = IContains
+
+
+class GreaterThan(DjangoLookup):
+ lookup_name = 'gt'
+default_lookups['gt'] = GreaterThan
+
+
+class GreaterThanOrEqual(DjangoLookup):
+ lookup_name = 'gte'
+default_lookups['gte'] = GreaterThanOrEqual
+
+
+class LessThan(DjangoLookup):
+ lookup_name = 'lt'
+default_lookups['lt'] = LessThan
+
+
+class LessThanOrEqual(DjangoLookup):
+ lookup_name = 'lte'
+default_lookups['lte'] = LessThanOrEqual
+
+
+class In(DjangoLookup):
+ lookup_name = 'in'
+
+ def get_db_prep_lookup(self, value, connection):
+ params = self.lhs.field.get_db_prep_lookup(
+ self.lookup_name, value, connection, prepared=True)
+ if not params:
+ # TODO: check why this leads to circular import
+ from django.db.models.sql.datastructures import EmptyResultSet
+ raise EmptyResultSet
+ placeholder = '(' + ', '.join('%s' for p in params) + ')'
+ return (placeholder, params)
+
+ def get_rhs_op(self, connection, rhs):
+ return 'IN %s' % rhs
+default_lookups['in'] = In
+
+
+class StartsWith(DjangoLookup):
+ lookup_name = 'startswith'
+default_lookups['startswith'] = StartsWith
+
+
+class IStartsWith(DjangoLookup):
+ lookup_name = 'istartswith'
+default_lookups['istartswith'] = IStartsWith
+
+
+class EndsWith(DjangoLookup):
+ lookup_name = 'endswith'
+default_lookups['endswith'] = EndsWith
+
+
+class IEndsWith(DjangoLookup):
+ lookup_name = 'iendswith'
+default_lookups['iendswith'] = IEndsWith
+
+
+class Between(DjangoLookup):
+ def get_rhs_op(self, connection, rhs):
+ return "BETWEEN %s AND %s" % (rhs, rhs)
+
+
+class Year(Between):
+ lookup_name = 'year'
+default_lookups['year'] = Year
+
+
+class Range(Between):
+ lookup_name = 'range'
+default_lookups['range'] = Range
+
+
+class DateLookup(DjangoLookup):
+
+ def process_lhs(self, qn, connection):
+ lhs, params = super(DateLookup, self).process_lhs(qn, connection)
+ tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
+ sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname)
+ return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params
+
+ def get_rhs_op(self, connection, rhs):
+ return '= %s' % rhs
+
+
+class Month(DateLookup):
+ lookup_name = 'month'
+ extract_type = 'month'
+default_lookups['month'] = Month
+
+
+class Day(DateLookup):
+ lookup_name = 'day'
+ extract_type = 'day'
+default_lookups['day'] = Day
+
+
+class WeekDay(DateLookup):
+ lookup_name = 'week_day'
+ extract_type = 'week_day'
+default_lookups['week_day'] = WeekDay
+
+
+class Hour(DateLookup):
+ lookup_name = 'hour'
+ extract_type = 'hour'
+default_lookups['hour'] = Hour
+
+
+class Minute(DateLookup):
+ lookup_name = 'minute'
+ extract_type = 'minute'
+default_lookups['minute'] = Minute
+
+
+class Second(DateLookup):
+ lookup_name = 'second'
+ extract_type = 'second'
+default_lookups['second'] = Second
+
+
+class IsNull(DjangoLookup):
+ lookup_name = 'isnull'
+
+ def as_sql(self, qn, connection):
+ sql, params = qn.compile(self.lhs)
+ if self.rhs:
+ return "%s IS NULL" % sql, params
+ else:
+ return "%s IS NOT NULL" % sql, params
+default_lookups['isnull'] = IsNull
+
+
+class Search(DjangoLookup):
+ lookup_name = 'search'
+default_lookups['search'] = Search
+
+
+class Regex(DjangoLookup):
+ lookup_name = 'regex'
+default_lookups['regex'] = Regex
+
+
+class IRegex(DjangoLookup):
+ lookup_name = 'iregex'
+default_lookups['iregex'] = IRegex
View
7 django/db/models/sql/aggregates.py
@@ -93,6 +93,13 @@ def as_sql(self, qn, connection):
return self.sql_template % substitutions, params
+ def get_cols(self):
+ return []
+
+ @property
+ def output_type(self):
+ return self.field
+
class Avg(Aggregate):
is_computed = True
View
59 django/db/models/sql/compiler.py
@@ -45,7 +45,7 @@ def pre_sql_setup(self):
if self.query.select_related and not self.query.related_select_cols:
self.fill_related_selections()
- def quote_name_unless_alias(self, name):
+ def __call__(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
@@ -61,6 +61,20 @@ def quote_name_unless_alias(self, name):
self.quote_cache[name] = r
return r
+ def quote_name_unless_alias(self, name):
+ """
+ A wrapper around connection.ops.quote_name that doesn't quote aliases
+ for table names. This avoids problems with some SQL dialects that treat
+ quoted strings specially (e.g. PostgreSQL).
+ """
+ return self(name)
+
+ def compile(self, node):
+ if node.__class__ in self.connection.compile_implementations:
+ return self.connection.compile_implementations[node.__class__](node, self)
+ else:
+ return node.as_sql(self, self.connection)
+
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
@@ -88,10 +102,8 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
# docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()
- qn = self.quote_name_unless_alias
-
- where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
- having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
+ where, w_params = self.compile(self.query.where)
+ having, h_params = self.compile(self.query.having)
having_group_by = self.query.having.get_cols()
params = []
for val in six.itervalues(self.query.extra_select):
@@ -180,7 +192,7 @@ def get_columns(self, with_aliases=False):
(without the table names) are given unique aliases. This is needed in
some cases to avoid ambiguity with nested queries.
"""
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
params = []
@@ -213,7 +225,7 @@ def get_columns(self, with_aliases=False):
aliases.add(r)
col_aliases.add(col[1])
else:
- col_sql, col_params = col.as_sql(qn, self.connection)
+ col_sql, col_params = self.compile(col)
result.append(col_sql)
params.extend(col_params)
@@ -229,7 +241,7 @@ def get_columns(self, with_aliases=False):
max_name_length = self.connection.ops.max_name_length()
for alias, aggregate in self.query.aggregate_select.items():
- agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ agg_sql, agg_params = self.compile(aggregate)
if alias is None:
result.append(agg_sql)
else:
@@ -267,7 +279,7 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
result = []
if opts is None:
opts = self.query.get_meta()
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
aliases = set()
only_load = self.deferred_to_columns()
@@ -312,7 +324,7 @@ def get_distinct(self):
Note that this method can alter the tables in the query, and thus it
must be called before get_from_clause().
"""
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
result = []
opts = self.query.get_meta()
@@ -345,7 +357,7 @@ def get_ordering(self):
ordering = (self.query.order_by
or self.query.get_meta().ordering
or [])
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
distinct = self.query.distinct
select_aliases = self._select_aliases
@@ -483,7 +495,7 @@ def get_from_clause(self):
ordering and distinct must be done first.
"""
result = []
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
first = True
from_params = []
@@ -501,8 +513,7 @@ def get_from_clause(self):
extra_cond = join_field.get_extra_restriction(
self.query.where_class, alias, lhs)
if extra_cond:
- extra_sql, extra_params = extra_cond.as_sql(
- qn, self.connection)
+ extra_sql, extra_params = self.compile(extra_cond)
extra_sql = 'AND (%s)' % extra_sql
from_params.extend(extra_params)
else:
@@ -534,7 +545,7 @@ def get_grouping(self, having_group_by, ordering_group_by):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
- qn = self.quote_name_unless_alias
+ qn = self
result, params = [], []
if self.query.group_by is not None:
select_cols = self.query.select + self.query.related_select_cols
@@ -553,7 +564,7 @@ def get_grouping(self, having_group_by, ordering_group_by):
if isinstance(col, (list, tuple)):
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
elif hasattr(col, 'as_sql'):
- sql, col_params = col.as_sql(qn, self.connection)
+ self.compile(col)
else:
sql = '(%s)' % str(col)
if sql not in seen:
@@ -776,7 +787,7 @@ def execute_sql(self, result_type=MULTI):
return result
def as_subquery_condition(self, alias, columns, qn):
- inner_qn = self.quote_name_unless_alias
+ inner_qn = self
qn2 = self.connection.ops.quote_name
if len(columns) == 1:
sql, params = self.as_sql()
@@ -887,9 +898,9 @@ def as_sql(self):
"""
assert len(self.query.tables) == 1, \
"Can only delete from one table at a time."
- qn = self.quote_name_unless_alias
+ qn = self
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
- where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
+ where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(params)
@@ -905,7 +916,7 @@ def as_sql(self):
if not self.query.values:
return '', ()
table = self.query.tables[0]
- qn = self.quote_name_unless_alias
+ qn = self
result = ['UPDATE %s' % qn(table)]
result.append('SET')
values, update_params = [], []
@@ -925,7 +936,7 @@ def as_sql(self):
val = SQLEvaluator(val, self.query, allow_joins=False)
name = field.column
if hasattr(val, 'as_sql'):
- sql, params = val.as_sql(qn, self.connection)
+ sql, params = self.compile(val)
values.append('%s = %s' % (qn(name), sql))
update_params.extend(params)
elif val is not None:
@@ -936,7 +947,7 @@ def as_sql(self):
if not values:
return '', ()
result.append(', '.join(values))
- where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
+ where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
@@ -1016,11 +1027,11 @@ def as_sql(self, qn=None):
parameters.
"""
if qn is None:
- qn = self.quote_name_unless_alias
+ qn = self
sql, params = [], []
for aggregate in self.query.aggregate_select.values():
- agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ agg_sql, agg_params = self.compile(aggregate)
sql.append(agg_sql)
params.extend(agg_params)
sql = ', '.join(sql)
View
21 django/db/models/sql/datastructures.py
@@ -5,18 +5,25 @@
class Col(object):
- def __init__(self, alias, col):
- self.alias = alias
- self.col = col
+ def __init__(self, alias, target, source):
+ self.alias, self.target, self.source = alias, target, source
def as_sql(self, qn, connection):
- return '%s.%s' % (qn(self.alias), self.col), []
+ return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
+
+ @property
+ def output_type(self):
+ return self.source
- def prepare(self):
- return self
+ @property
+ def field(self):
+ return self.source
def relabeled_clone(self, relabels):
- return self.__class__(relabels.get(self.alias, self.alias), self.col)
+ return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
+
+ def get_cols(self):
+ return [(self.alias, self.target.column)]
class EmptyResultSet(Exception):
View
126 django/db/models/sql/query.py
@@ -1030,6 +1030,12 @@ def add_aggregate(self, aggregate, model, alias, is_summary):
def prepare_lookup_value(self, value, lookup_type, can_reuse):
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value.
+ if len(lookup_type) > 1:
+ raise FieldError('Nested lookups not allowed')
+ elif len(lookup_type) == 0:
+ lookup_type = 'exact'
+ else:
+ lookup_type = lookup_type[0]
if value is None:
if lookup_type != 'exact':
raise ValueError("Cannot use None as a query value")
@@ -1060,31 +1066,39 @@ def solve_lookup_type(self, lookup):
"""
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
"""
- lookup_type = 'exact' # Default lookup type
- lookup_parts = lookup.split(LOOKUP_SEP)
- num_parts = len(lookup_parts)
- if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms
- and (not self._aggregates or lookup not in self._aggregates)):
- # Traverse the lookup query to distinguish related fields from
- # lookup types.
- lookup_model = self.model
- for counter, field_name in enumerate(lookup_parts):
- try:
- lookup_field = lookup_model._meta.get_field(field_name)
- except FieldDoesNotExist:
- # Not a field. Bail out.
- lookup_type = lookup_parts.pop()
- break
- # Unless we're at the end of the list of lookups, let's attempt
- # to continue traversing relations.
- if (counter + 1) < num_parts:
- try:
- lookup_model = lookup_field.rel.to
- except AttributeError:
- # Not a related field. Bail out.
- lookup_type = lookup_parts.pop()
- break
- return lookup_type, lookup_parts
+ lookup_splitted = lookup.split(LOOKUP_SEP)
+ aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
+ if aggregate:
+ if len(aggregate_lookups) > 1:
+ raise FieldError("Nested lookups not allowed.")
+ return aggregate_lookups, (), aggregate
+ _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
+ field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
+ if len(lookup_parts) == 0:
+ lookup_parts = ['exact']
+ elif len(lookup_parts) > 1:
+ if field_parts:
+ raise FieldError(
+ 'Only one lookup part allowed (found path "%s" from "%s").' %
+ (LOOKUP_SEP.join(field_parts), lookup))
+ else:
+ raise FieldError(
+ 'Invalid lookup "%s" for model %s".' %
+ (lookup, self.get_meta().model.__name__))
+ else:
+ if not hasattr(field, 'get_lookup_constraint'):
+ lookup_class = field.get_lookup(lookup_parts[0])
+ if lookup_class is None and lookup_parts[0] not in self.query_terms:
+ raise FieldError(
+ 'Invalid lookup name %s' % lookup_parts[0])
+ return lookup_parts, field_parts, False
+
+ def build_lookup(self, lookup_type, lhs, rhs):
+ if hasattr(lhs.output_type, 'get_lookup'):
+ lookup = lhs.output_type.get_lookup(lookup_type)
+ if lookup:
+ return lookup(self.where_class, lhs, rhs)
+ return None
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, connector=AND):
@@ -1114,9 +1128,9 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
is responsible for unreffing the joins used.
"""
arg, value = filter_expr
- lookup_type, parts = self.solve_lookup_type(arg)
- if not parts:
+ if not arg:
raise FieldError("Cannot parse keyword query %r" % arg)
+ lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg)
# Work out the lookup type and remove it from the end of 'parts',
# if necessary.
@@ -1124,11 +1138,13 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
used_joins = getattr(value, '_used_joins', [])
clause = self.where_class()
- if self._aggregates:
- for alias, aggregate in self.aggregates.items():
- if alias in (parts[0], LOOKUP_SEP.join(parts)):
- clause.add((aggregate, lookup_type, value), AND)
- return clause, []
+ if reffed_aggregate:
+ condition = self.build_lookup(lookup_type, reffed_aggregate, value)
+ if not condition:
+ # Backwards compat for custom lookups
+ condition = (reffed_aggregate, lookup_type, value)
+ clause.add(condition, AND)
+ return clause, []
opts = self.get_meta()
alias = self.get_initial_alias()
@@ -1150,11 +1166,18 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
targets, alias, join_list = self.trim_joins(sources, join_list, path)
if hasattr(field, 'get_lookup_constraint'):
- constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
- lookup_type, value)
+ # For now foreign keys get special treatment. This should be
+ # refactored when composite fields lands.
+ condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
+ lookup_type, value)
else:
- constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
- clause.add(constraint, AND)
+ assert(len(targets) == 1)
+ col = Col(alias, targets[0], field)
+ condition = self.build_lookup(lookup_type, col, value)
+ if not condition:
+ # Backwards compat for custom lookups
+ condition = (Constraint(alias, targets[0].column, field), lookup_type, value)
+ clause.add(condition, AND)
require_outer = lookup_type == 'isnull' and value is True and not current_negated
if current_negated and (lookup_type != 'isnull' or value is False):
@@ -1185,7 +1208,7 @@ def need_having(self, obj):
if not self._aggregates:
return False
if not isinstance(obj, Node):
- return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)
+ return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0]
or (hasattr(obj[1], 'contains_aggregate')
and obj[1].contains_aggregate(self.aggregates)))
return any(self.need_having(c) for c in obj.children)
@@ -1273,7 +1296,7 @@ def _add_q(self, q_object, used_aliases, branch_negated=False,
needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner
- def names_to_path(self, names, opts, allow_many):
+ def names_to_path(self, names, opts, allow_many=True):
"""
Walks the names path and turns them PathInfo tuples. Note that a
single name in 'names' can generate multiple PathInfos (m2m for
@@ -1293,9 +1316,10 @@ def names_to_path(self, names, opts, allow_many):
try:
field, model, direct, m2m = opts.get_field_by_name(name)
except FieldDoesNotExist:
- available = opts.get_all_field_names() + list(self.aggregate_select)
- raise FieldError("Cannot resolve keyword %r into field. "
- "Choices are: %s" % (name, ", ".join(available)))
+ # We didn't found the current field, so move position back
+ # one step.
+ pos -= 1
+ break
# Check if we need any joins for concrete inheritance cases (the
# field lives in parent, but we are currently in one of its
# children)
@@ -1330,15 +1354,9 @@ def names_to_path(self, names, opts, allow_many):
final_field = field
targets = (field,)
break
-
- if pos != len(names) - 1:
- if pos == len(names) - 2:
- raise FieldError(
- "Join on field %r not permitted. Did you misspell %r for "
- "the lookup type?" % (name, names[pos + 1]))
- else:
- raise FieldError("Join on field %r not permitted." % name)
- return path, final_field, targets
+ if pos == -1:
+ raise FieldError('Whazaa')
+ return path, final_field, targets, names[pos + 1:]
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
@@ -1367,8 +1385,10 @@ def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
joins = [alias]
# First, generate the path for the names
- path, final_field, targets = self.names_to_path(
+ path, final_field, targets, rest = self.names_to_path(
names, opts, allow_many)
+ if rest:
+ raise FieldError('Invalid lookup')
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@@ -1383,8 +1403,6 @@ def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
alias = self.join(
connection, reuse=reuse, nullable=nullable, join_field=join.join_field)
joins.append(alias)
- if hasattr(final_field, 'field'):
- final_field = final_field.field
return final_field, targets, opts, joins, path
def trim_joins(self, targets, joins, path):
@@ -1455,7 +1473,7 @@ def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
query.bump_prefix(self)
query.where.add(
(Constraint(query.select[0].col[0], pk.column, pk),
- 'exact', Col(alias, pk.column)),
+ 'exact', Col(alias, pk, pk)),
AND
)
View
10 django/db/models/sql/where.py
@@ -101,7 +101,7 @@ def as_sql(self, qn, connection):
for child in self.children:
try:
if hasattr(child, 'as_sql'):
- sql, params = child.as_sql(qn=qn, connection=connection)
+ sql, params = qn.compile(child)
else:
# A leaf node in the tree.
sql, params = self.make_atom(child, qn, connection)
@@ -193,13 +193,13 @@ def make_atom(self, child, qn, connection):
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), []
else:
# A smart object with an as_sql() method.
- field_sql, field_params = lvalue.as_sql(qn, connection)
+ field_sql, field_params = qn.compile(lvalue)
is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
if hasattr(params, 'as_sql'):
- extra, params = params.as_sql(qn, connection)
+ extra, params = qn.compile(params)
cast_sql = ''
else:
extra = ''
@@ -282,6 +282,8 @@ def relabel_aliases(self, change_map):
if hasattr(child, 'relabel_aliases'):
# For example another WhereNode
child.relabel_aliases(change_map)
+ elif hasattr(child, 'relabeled_clone'):
+ self.children[pos] = child.relabeled_clone(change_map)
elif isinstance(child, (list, tuple)):
# tuple starting with Constraint
child = (child[0].relabeled_clone(change_map),) + child[1:]
@@ -350,7 +352,7 @@ def __init__(self, alias, col, field):
self.alias, self.col, self.field = alias, col, field
def prepare(self, lookup_type, value):
- if self.field:
+ if self.field and not hasattr(value, 'as_sql'):
return self.field.get_prep_lookup(lookup_type, value)
return value
View
2  tests/aggregation/tests.py
@@ -443,7 +443,7 @@ def test_annotation(self):
vals = Author.objects.filter(pk=1).aggregate(Count("friends__id"))
self.assertEqual(vals, {"friends__id__count": 2})
- books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__ge=2).order_by("pk")
+ books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk")
self.assertQuerysetEqual(
books, [
"The Definitive Guide to Django: Web Development Done Right",
View
0  tests/custom_lookups/__init__.py
No changes.
View
7 tests/custom_lookups/models.py
@@ -0,0 +1,7 @@
+from django.db import models
+
+
+class Author(models.Model):
+ name = models.CharField(max_length=20)
+ age = models.IntegerField(null=True)
+ birthdate = models.DateField(null=True)
View
136 tests/custom_lookups/tests.py
@@ -0,0 +1,136 @@
+from copy import copy
+from datetime import date
+import unittest
+
+from django.test import TestCase
+from .models import Author
+from django.db import models
+from django.db import connection
+from django.db.backends.utils import add_implementation
+
+
+class Div3Lookup(models.lookups.Lookup):
+ lookup_name = 'div3'
+
+ def as_sql(self, qn, connection):
+ lhs, params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.process_rhs(qn, connection)
+ params.extend(rhs_params)
+ return '%s %%%% 3 = %s' % (lhs, rhs), params
+
+
+class InMonth(models.lookups.Lookup):
+ """
+ InMonth matches if the column's month is contained in the value's month.
+ """
+ lookup_name = 'inmonth'
+
+ def as_sql(self, qn, connection):
+ lhs, params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.process_rhs(qn, connection)
+ # We need to be careful so that we get the params in right
+ # places.
+ full_params = params[:]
+ full_params.extend(rhs_params)
+ full_params.extend(params)
+ full_params.extend(rhs_params)
+ return ("%s >= date_trunc('month', %s) and "
+ "%s < date_trunc('month', %s) + interval '1 months'" %
+ (lhs, rhs, lhs, rhs), full_params)
+
+
+class LookupTests(TestCase):
+ def test_basic_lookup(self):
+ a1 = Author.objects.create(name='a1', age=1)
+ a2 = Author.objects.create(name='a2', age=2)
+ a3 = Author.objects.create(name='a3', age=3)
+ a4 = Author.objects.create(name='a4', age=4)
+ models.IntegerField.register_lookup(Div3Lookup)
+ try:
+ self.assertQuerysetEqual(
+ Author.objects.filter(age__div3=0),
+ [a3], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(age__div3=1).order_by('age'),
+ [a1, a4], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(age__div3=2),
+ [a2], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(age__div3=3),
+ [], lambda x: x
+ )
+ finally:
+ models.IntegerField._unregister_lookup(Div3Lookup)
+
+ @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
+ def test_birthdate_month(self):
+ a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
+ a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
+ a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
+ a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
+ models.DateField.register_lookup(InMonth)
+ try:
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)),
+ [a3], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)),
+ [a2], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)),
+ [a1], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)),
+ [a4], lambda x: x
+ )
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)),
+ [], lambda x: x
+ )
+ finally:
+ models.DateField._unregister_lookup(InMonth)
+
+ def test_custom_compiles(self):
+ a1 = Author.objects.create(name='a1', age=1)
+ a2 = Author.objects.create(name='a2', age=2)
+ a3 = Author.objects.create(name='a3', age=3)
+ a4 = Author.objects.create(name='a4', age=4)
+
+ class AnotherEqual(models.lookups.Exact):
+ lookup_name = 'anotherequal'
+ models.Field.register_lookup(AnotherEqual)
+ try:
+ @add_implementation(AnotherEqual, connection.vendor)
+ def custom_eq_sql(node, compiler):
+ return '1 = 1', []
+
+ self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
+ self.assertQuerysetEqual(
+ Author.objects.filter(name__anotherequal='asdf').order_by('name'),
+ [a1, a2, a3, a4], lambda x: x)
+
+ @add_implementation(AnotherEqual, connection.vendor)
+ def another_custom_eq_sql(node, compiler):
+ # If you need to override one method, it seems this is the best
+ # option.
+ node = copy(node)
+
+ class OverriddenAnotherEqual(AnotherEqual):
+ def get_rhs_op(self, connection, rhs):
+ return ' <> %s'
+ node.__class__ = OverriddenAnotherEqual
+ return node.as_sql(compiler, compiler.connection)
+ self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
+ self.assertQuerysetEqual(
+ Author.objects.filter(name__anotherequal='a1').order_by('name'),
+ [a2, a3, a4], lambda x: x
+ )
+ finally:
+ models.Field._unregister_lookup(AnotherEqual)
View
13 tests/queries/tests.py
@@ -2620,8 +2620,15 @@ class DummyNode(object):
def as_sql(self, qn, connection):
return 'dummy', []
+ class MockCompiler(object):
+ def compile(self, node):
+ return node.as_sql(self, connection)
+
+ def __call__(self, name):
+ return connection.ops.quote_name(name)
+
def test_empty_full_handling_conjunction(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()])
self.assertEqual(w.as_sql(qn, connection), ('', []))
w.negate()
@@ -2646,7 +2653,7 @@ def test_empty_full_handling_conjunction(self):
self.assertEqual(w.as_sql(qn, connection), ('', []))
def test_empty_full_handling_disjunction(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('', []))
w.negate()
@@ -2673,7 +2680,7 @@ def test_empty_full_handling_disjunction(self):
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', []))
def test_empty_nodes(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w])
self.assertEqual(w.as_sql(qn, connection), (None, []))
Please sign in to comment.
Something went wrong with that request. Please try again.