Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Added documentation, polished implementation

  • Loading branch information...
commit 2adf50428d59a783078b0da3d5d035106640c899 1 parent 32c0435
Anssi Kääriäinen akaariai authored
15 django/db/models/lookups.py
View
@@ -7,8 +7,8 @@
class Extract(object):
- def __init__(self, constraint_class, lhs):
- self.constraint_class, self.lhs = constraint_class, lhs
+ def __init__(self, lhs):
+ self.lhs = lhs
def get_lookup(self, lookup):
return self.output_type.get_lookup(lookup)
@@ -21,15 +21,18 @@ def output_type(self):
return self.lhs.output_type
def relabeled_clone(self, relabels):
- return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels))
+ return self.__class__(self.lhs.relabeled_clone(relabels))
+
+ def get_cols(self):
+ return self.lhs.get_cols()
class Lookup(object):
lookup_name = None
extract_class = None
- def __init__(self, constraint_class, lhs, rhs):
- self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
+ def __init__(self, lhs, rhs):
+ self.lhs, self.rhs = lhs, rhs
if rhs is None:
if not self.extract_class:
raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name)
@@ -37,7 +40,7 @@ def __init__(self, constraint_class, lhs, rhs):
self.rhs = self.get_prep_lookup()
def get_extract(self):
- return self.extract_class(self.constraint_class, self.lhs)
+ return self.extract_class(self.lhs)
def get_prep_lookup(self):
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
3  django/db/models/sql/compiler.py
View
@@ -71,7 +71,8 @@ def quote_name_unless_alias(self, name):
def compile(self, node):
if node.__class__ in self.connection.compile_implementations:
- return self.connection.compile_implementations[node.__class__](node, self)
+ return self.connection.compile_implementations[node.__class__](
+ node, self, self.connection)
else:
return node.as_sql(self, self.connection)
8 django/db/models/sql/query.py
View
@@ -18,6 +18,7 @@
from django.db.models.aggregates import refs_aggregate
from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist
+from django.db.models.lookups import Extract
from django.db.models.query_utils import Q
from django.db.models.related import PathInfo
from django.db.models.sql import aggregates as base_aggregates_module
@@ -1088,9 +1089,12 @@ def build_lookup(self, lookups, lhs, rhs):
if next:
if not lookups:
# This was the last lookup, so return value lookup.
- return next(self.where_class, lhs, rhs)
+ if issubclass(next, Extract):
+ lhs = next(lhs)
+ next = lhs.get_lookup('exact')
+ return next(lhs, rhs)
else:
- lhs = next(self.where_class, lhs, None).get_extract()
+ lhs = next(lhs)
# A field's get_lookup() can return None to opt for backwards
# compatibility path.
elif len(lookups) > 1:
243 docs/ref/models/lookups.txt
View
@@ -0,0 +1,243 @@
+==============
+Custom lookups
+==============
+
+.. module:: django.db.models.lookups
+ :synopsis: Custom lookups
+
+.. currentmodule:: django.db.models
+
+(This documentation is a candidate for a complete rewrite, but contains
+useful information on how to test the current implementation.)
+
+This documentation contains instructions on how to create custom lookups
+for model fields.
+
+Django's ORM works using lookup paths when building query filters and other
+query structures. For example, in the query
+Book.objects.filter(author__age__lte=30), author__age__lte is the lookup path.
+
+The lookup path consists of three different parts. The first part is the
+related lookups; above, author refers to Book's related model Author. The
+second part of the lookup path is the final field; above, this is Author's
+field age. Finally, the lte part is commonly called just the lookup (TODO:
+this nomenclature is confusing, can we invent something better).
+
+This documentation concentrates on writing custom lookups, that is custom
+implementations for lte or any other lookup you wish to use.
+
+Django will fetch a ``Lookup`` class from the final field using the field's
+method get_lookup(lookup_name). This method can do three things:
+
+ 1. Return a Lookup class
+ 2. Raise a FieldError
+ 3. Return None
+
+Returning None is only available during the backwards compatibility period
+and will not be allowed in Django 1.9 or later. Returning None is interpreted
+as a request to use the old way of lookup handling inside the ORM.
+
+The returned Lookup will be used to build the query.
+
+The Lookup class
+~~~~~~~~~~~~~~~~
+
+The API is as follows:
+
+.. attribute:: lookup_name
+
+A string used by Django to distinguish different lookups.
+
+.. method:: __init__(lhs, rhs)
+
+The lhs is the field reference (the reference to field age in the
+author__age__lte=30 example), and rhs is the value (30 in the example).
+
+.. attribute:: Lookup.lhs
+
+The left hand side part of this lookup. You can assume it implements the
+query part interface (TODO: write interface definition...).
+
+.. method:: Lookup.as_sql(qn, connection)
+
+This method is used to produce the query string of the Lookup. A typical
+implementation is usually something like::
+
+ def as_sql(self, qn, connection):
+ lhs, params = self.process_lhs(qn, connection)
+ rhs, rhs_params = self.process_rhs(qn, connection)
+ params.extend(rhs_params)
+ return '%s <OPERATOR> %s' % (lhs, rhs), params
+
+where the <OPERATOR> is some query operator. The qn is a callable that
+can be used to convert strings to quoted variants (that is, colname to
+"colname"). Note that the quotation is *not* safe against SQL injection.
+
+In addition, the qn implements a compile() method which can be used to turn
+anything with an as_sql() method into a query string. You should always call
+qn.compile(part) instead of part.as_sql(qn, connection) so that 3rd party
+backends have the ability to customize the produced query string. More on this
+later on.
+
+The connection is the used connection.
+
+.. method:: Lookup.process_lhs(qn, connection, lhs=None)
+
+This method is used to convert the left hand side of the lookup into a query
+string. The left hand side can be a field reference or a nested lookup. The
+lhs kwarg can be used to convert something other than self.lhs to a query string.
+
+.. method:: Lookup.process_rhs(qn, connection, rhs=None)
+
+The process_rhs method is used to convert the right hand side into a query string.
+The rhs is the value given in the filter clause. It can be a raw value to
+compare against, an F() reference to another field or even a QuerySet.
+
+.. method:: get_extract()
+
+The get_extract method is used in nested lookups. It must return an Extract instance.
+
+.. classattribute:: Lookup.extract_class
+
+The default implementation of get_extract() will return an instance of extract_class.
+
+In addition there are some private methods - that is, implementing just the above
+mentioned attributes and methods is not enough, you must subclass Lookup instead.
+
+The Extract class
+~~~~~~~~~~~~~~~~~
+
+An Extract is something that converts a value to another value in the query string.
+For example, you could have an Extract that produces modulo 3 of the given value.
+In SQL this would be something like "author"."age" % 3.
+
+Extracts are used in nested lookups. The Extract class must implement the query
+part interface.
+
+A simple Lookup example
+~~~~~~~~~~~~~~~~~~~~~~~
+
+This is how to write a simple div3 lookup for IntegerField::
+
+ from django.db.models import Lookup, IntegerField
+ class Div3(Lookup):
+ lookup_name = 'div3'
+
+ def as_sql(self, qn, connection):
+ lhs_sql, params = self.process_lhs(qn, connection)
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ params.extend(rhs_params)
+ # We need double-escaping for the %%%% operator.
+ return '%s %%%% %s' % (lhs_sql, rhs_sql), params
+
+ IntegerField.register_lookup(Div3)
+
+Now all IntegerFields or subclasses of IntegerField will have
+a div3 lookup. For example you could do Author.objects.filter(age__div3=2).
+This query would return every author whose age % 3 == 2.
+
+A simple nested lookup example
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Here is how to write an Extract and a Lookup for IntegerField. The example
+lookup can be used similarly to the above div3 lookup, and in addition it
+supports nesting lookups::
+
+ class Div3Extract(Extract):
+ lookup_name = 'div3'
+
+ def as_sql(self, qn, connection):
+ lhs, lhs_params = qn.compile(self.lhs)
+ return '%s %%%% 3' % (lhs,), lhs_params
+
+ IntegerField.register_lookup(Div3Extract)
+
+Note that if you already added Div3 for IntegerField in the above
+example, now Div3Extract will override that lookup.
+
+This lookup can be used like Div3 lookup, but in addition it supports
+nesting, too. The default output type for Extracts is the same type as the
+lhs' output_type. So, the Div3Extract supports all the same lookups as
+IntegerField. For example Author.objects.filter(age__div3__in=[1, 2])
+returns all authors for which age % 3 in (1, 2).
+
+A more complex nested lookup
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We will write a Year lookup that extracts the year from a date field. This
+lookup will convert the output type of the field - the lhs (or "input")
+field is DateField, but the output is of type IntegerField::
+
+ from django.db.models import IntegerField, DateField
+ from django.db.models.lookups import Extract
+
+ class YearExtract(Extract):
+ lookup_name = 'year'
+
+ def as_sql(self, qn, connection):
+ lhs_sql, params = qn.compile(self.lhs)
+ # hmmh - this is internal API...
+ return connection.ops.date_extract_sql('year', lhs_sql), params
+
+ @property
+ def output_type(self):
+ return IntegerField()
+
+ DateField.register_lookup(YearExtract)
+
+Now you could write Author.objects.filter(birthdate__year=1981). This will
+produce SQL like 'EXTRACT('year' from "author"."birthdate") = 1981'. The
+produced SQL depends on the backend used. In addition, you can use any lookup
+defined for IntegerField, even div3 if you added that. So,
+Author.objects.filter(birthdate__year__div3=2) will return every author
+with birthdate.year % 3 == 2.
+
+We could go further and add an optimized implementation for exact lookups::
+
+ from django.db.models.lookups import Lookup
+
+ class YearExtractOptimized(YearExtract):
+ def get_lookup(self, lookup):
+ if lookup == 'exact':
+ return YearExact
+ return super(YearExtractOptimized, self).get_lookup(lookup)
+
+ class YearExact(Lookup):
+ def as_sql(self, qn, connection):
+ # We will need to skip the extract part, and instead go
+ # directly with the originating field, that is self.lhs.lhs
+ lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ # Note that we must be careful so that we have params in the
+ # same order as we have the parts in the SQL.
+ params = []
+ params.extend(lhs_params)
+ params.extend(rhs_params)
+ params.extend(lhs_params)
+ params.extend(rhs_params)
+ # We use PostgreSQL specific SQL here. Note that we must do the
+ # conversions in SQL instead of in Python to support F() references.
+ return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
+ "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
+ {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
+
+Note that we used PostgreSQL specific SQL above. What if we want to support
+MySQL, too? This can be done by registering a different compiling implementation
+for MySQL::
+
+ from django.db.backends.utils import add_implementation
+ @add_implementation(YearExact, 'mysql')
+ def mysql_year_exact(node, qn, connection):
+ lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
+ rhs_sql, rhs_params = node.process_rhs(qn, connection)
+ params = []
+ params.extend(lhs_params)
+ params.extend(rhs_params)
+ params.extend(lhs_params)
+ params.extend(rhs_params)
+ return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
+ "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
+ {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
+
+Now, on MySQL, instead of calling as_sql() of YearExact, Django will use the
+above compile implementation.
76 tests/custom_lookups/tests.py
View
@@ -20,16 +20,13 @@ def as_sql(self, qn, connection):
class Div3Extract(models.lookups.Extract):
+ lookup_name = 'div3'
+
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return '%s %%%% 3' % (lhs,), lhs_params
-class Div3LookupWithExtract(Div3Lookup):
- lookup_name = 'div3'
- extract_class = Div3Extract
-
-
class YearLte(models.lookups.LessThanOrEqual):
"""
The purpose of this lookup is to efficiently compare the year of the field.
@@ -50,6 +47,8 @@ def as_sql(self, qn, connection):
class YearExtract(models.lookups.Extract):
+ lookup_name = 'year'
+
def as_sql(self, qn, connection):
lhs_sql, params = qn.compile(self.lhs)
return connection.ops.date_extract_sql('year', lhs_sql), params
@@ -61,12 +60,44 @@ def output_type(self):
def get_lookup(self, lookup):
if lookup == 'lte':
return YearLte
+ elif lookup == 'exact':
+ return YearExact
else:
return super(YearExtract, self).get_lookup(lookup)
-class YearWithExtract(models.lookups.Year):
- extract_class = YearExtract
+class YearExact(models.lookups.Lookup):
+ def as_sql(self, qn, connection):
+ # We will need to skip the extract part, and instead go
+ # directly with the originating field, that is self.lhs.lhs
+ lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
+ rhs_sql, rhs_params = self.process_rhs(qn, connection)
+ # Note that we must be careful so that we have params in the
+ # same order as we have the parts in the SQL.
+ params = []
+ params.extend(lhs_params)
+ params.extend(rhs_params)
+ params.extend(lhs_params)
+ params.extend(rhs_params)
+ # We use PostgreSQL specific SQL here. Note that we must do the
+ # conversions in SQL instead of in Python to support F() references.
+ return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
+ "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
+ {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
+
+
+@add_implementation(YearExact, 'mysql')
+def mysql_year_exact(node, qn, connection):
+ lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
+ rhs_sql, rhs_params = node.process_rhs(qn, connection)
+ params = []
+ params.extend(lhs_params)
+ params.extend(rhs_params)
+ params.extend(lhs_params)
+ params.extend(rhs_params)
+ return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
+ "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
+ {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
class InMonth(models.lookups.Lookup):
@@ -158,7 +189,7 @@ class AnotherEqual(models.lookups.Exact):
models.Field.register_lookup(AnotherEqual)
try:
@add_implementation(AnotherEqual, connection.vendor)
- def custom_eq_sql(node, compiler):
+ def custom_eq_sql(node, qn, connection):
return '1 = 1', []
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
@@ -167,7 +198,7 @@ def custom_eq_sql(node, compiler):
[a1, a2, a3, a4], lambda x: x)
@add_implementation(AnotherEqual, connection.vendor)
- def another_custom_eq_sql(node, compiler):
+ def another_custom_eq_sql(node, qn, connection):
# If you need to override one method, it seems this is the best
# option.
node = copy(node)
@@ -176,7 +207,7 @@ class OverriddenAnotherEqual(AnotherEqual):
def get_rhs_op(self, connection, rhs):
return ' <> %s'
node.__class__ = OverriddenAnotherEqual
- return node.as_sql(compiler, compiler.connection)
+ return node.as_sql(qn, connection)
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='a1').order_by('name'),
@@ -186,7 +217,7 @@ def get_rhs_op(self, connection, rhs):
models.Field._unregister_lookup(AnotherEqual)
def test_div3_extract(self):
- models.IntegerField.register_lookup(Div3LookupWithExtract)
+ models.IntegerField.register_lookup(Div3Extract)
try:
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
@@ -194,25 +225,28 @@ def test_div3_extract(self):
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
+ baseqs.filter(age__div3=2),
+ [a2], lambda x: x)
+ self.assertQuerysetEqual(
baseqs.filter(age__div3__lte=3),
[a1, a2, a3, a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x)
finally:
- models.IntegerField._unregister_lookup(Div3LookupWithExtract)
+ models.IntegerField._unregister_lookup(Div3Extract)
class YearLteTests(TestCase):
def setUp(self):
- models.DateField.register_lookup(YearWithExtract)
+ models.DateField.register_lookup(YearExtract)
self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
def tearDown(self):
- models.DateField._unregister_lookup(YearWithExtract)
+ models.DateField._unregister_lookup(YearExtract)
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
def test_year_lte(self):
@@ -221,6 +255,11 @@ def test_year_lte(self):
baseqs.filter(birthdate__year__lte=2012),
[self.a1, self.a2, self.a3, self.a4], lambda x: x)
self.assertQuerysetEqual(
+ baseqs.filter(birthdate__year=2012),
+ [self.a2, self.a3, self.a4], lambda x: x)
+
+ self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query))
+ self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lte=2011),
[self.a1], lambda x: x)
# The non-optimized version works, too.
@@ -253,3 +292,12 @@ def test_year_lte_sql(self):
'<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
self.assertIn(
'-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
+
+ @unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used')
+ def test_mysql_year_exact(self):
+ self.assertQuerysetEqual(
+ Author.objects.filter(birthdate__year=2012).order_by('name'),
+ [self.a2, self.a3, self.a4], lambda x: x)
+ self.assertIn(
+ 'concat(',
+ str(Author.objects.filter(birthdate__year=2012).query))
Please sign in to comment.
Something went wrong with that request. Please try again.