Permalink
Browse files

Added support for modifying the effect of ``DISTINCT`` clauses so they
only consider some fields (PostgreSQL only).

For this, the ``distinct()`` QuerySet method now accepts an optional
list of model field names and generates ``DISTINCT ON`` clauses in
these cases. Thanks Jeffrey Gelens and Anssi Kääriäinen for their work.

Fixes #6422.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17244 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
1 parent 03eb290 commit 287565779d3ae4d3229ecbb2ff356c79b920e7d0 @ramiro ramiro committed Dec 22, 2011
View
@@ -203,6 +203,7 @@ answer newbie questions, and generally made Django that much better:
Marc Garcia <marc.garcia@accopensys.com>
Andy Gayton <andy-django@thecablelounge.com>
geber@datacollect.com
+ Jeffrey Gelens <jeffrey@gelens.org>
Baishampayan Ghose
Joshua Ginsberg <jag@flowtheory.net>
Dimitris Glezos <dimitris@glezos.com>
@@ -269,6 +270,7 @@ answer newbie questions, and generally made Django that much better:
jpellerin@gmail.com
junzhang.jn@gmail.com
Xia Kai <http://blog.xiaket.org/>
+ Anssi Kääriäinen
Antti Kaihola <http://djangopeople.net/akaihola/>
Peter van Kampen
Bahadır Kandemir <bahadir@pardus.org.tr>
@@ -406,6 +406,9 @@ class BaseDatabaseFeatures(object):
supports_stddev = None
can_introspect_foreign_keys = None
+ # Support for the DISTINCT ON clause
+ can_distinct_on_fields = False
+
def __init__(self, connection):
self.connection = connection
@@ -559,6 +562,17 @@ def fulltext_search_sql(self, field_name):
"""
raise NotImplementedError('Full-text search is not implemented for this database backend')
+ def distinct_sql(self, fields):
+ """
+ Returns an SQL DISTINCT clause which removes duplicate rows from the
+ result set. If any fields are given, only the given fields are being
+ checked for duplicates.
+ """
+ if fields:
+ raise NotImplementedError('DISTINCT ON fields is not supported by this database backend')
+ else:
+ return 'DISTINCT'
+
def last_executed_query(self, cursor, sql, params):
"""
Returns a string of the query last executed by the given cursor, with
@@ -82,6 +82,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_select_for_update_nowait = True
has_bulk_insert = True
supports_tablespaces = True
+ can_distinct_on_fields = True
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'postgresql'
@@ -179,6 +179,12 @@ def max_name_length(self):
return 63
+ def distinct_sql(self, fields):
+ if fields:
+ return 'DISTINCT ON (%s)' % ', '.join(fields)
+ else:
+ return 'DISTINCT'
+
def last_executed_query(self, cursor, sql, params):
# http://initd.org/psycopg/docs/cursor.html#cursor.query
# The query attribute is a Psycopg extension to the DB API 2.0.
@@ -323,6 +323,8 @@ def aggregate(self, *args, **kwargs):
If args is present the expression is passed as a kwarg using
the Aggregate object's default alias.
"""
+ if self.query.distinct_fields:
+ raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
for arg in args:
kwargs[arg.default_alias] = arg
@@ -751,12 +753,14 @@ def order_by(self, *field_names):
obj.query.add_ordering(*field_names)
return obj
- def distinct(self, true_or_false=True):
+ def distinct(self, *field_names):
"""
Returns a new QuerySet instance that will select only distinct results.
"""
+ assert self.query.can_filter(), \
+ "Cannot create distinct fields once a slice has been taken."
obj = self._clone()
- obj.query.distinct = true_or_false
+ obj.query.add_distinct_fields(*field_names)
return obj
def extra(self, select=None, where=None, params=None, tables=None,
@@ -1179,7 +1183,7 @@ def order_by(self, *field_names):
"""
return self
- def distinct(self, true_or_false=True):
+ def distinct(self, fields=None):
"""
Always returns EmptyQuerySet.
"""
@@ -23,6 +23,8 @@ def pre_sql_setup(self):
Does any necessary class setup immediately prior to producing SQL. This
is for things that can't necessarily be done in __init__ because we
might not have all the pieces in place at that time.
+ # TODO: after the query has been executed, the altered state should be
+ # cleaned. We are not using a clone() of the query here.
"""
if not self.query.tables:
self.query.join((None, self.query.model._meta.db_table, None, None))
@@ -60,11 +62,19 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
return '', ()
self.pre_sql_setup()
+ # After executing the query, we must get rid of any joins the query
+ # setup created. So, take note of alias counts before the query ran.
+ # However we do not want to get rid of stuff done in pre_sql_setup(),
+ # as the pre_sql_setup will modify query state in a way that forbids
+ # another run of it.
+ self.refcounts_before = self.query.alias_refcount.copy()
out_cols = self.get_columns(with_col_aliases)
ordering, ordering_group_by = self.get_ordering()
- # This must come after 'select' and 'ordering' -- see docstring of
- # get_from_clause() for details.
+ distinct_fields = self.get_distinct()
+
+ # This must come after 'select', 'ordering' and 'distinct' -- see
+ # docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()
qn = self.quote_name_unless_alias
@@ -76,8 +86,10 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
params.extend(val[1])
result = ['SELECT']
+
if self.query.distinct:
- result.append('DISTINCT')
+ result.append(self.connection.ops.distinct_sql(distinct_fields))
+
result.append(', '.join(out_cols + self.query.ordering_aliases))
result.append('FROM')
@@ -90,6 +102,9 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
grouping, gb_params = self.get_grouping()
if grouping:
+ if distinct_fields:
+ raise NotImplementedError(
+ "annotate() + distinct(fields) not implemented.")
if ordering:
# If the backend can't group by PK (i.e., any database
# other than MySQL), then any fields mentioned in the
@@ -129,6 +144,9 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
raise DatabaseError('NOWAIT is not supported on this database backend.')
result.append(self.connection.ops.for_update_sql(nowait=nowait))
+ # Finally do cleanup - get rid of the joins we created above.
+ self.query.reset_refcounts(self.refcounts_before)
+
return ' '.join(result), tuple(params)
def as_nested_sql(self):
@@ -292,6 +310,26 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
col_aliases.add(field.column)
return result, aliases
+ def get_distinct(self):
+ """
+ Returns a quoted list of fields to use in DISTINCT ON part of the query.
+
+ Note that this method can alter the tables in the query, and thus it
+ must be called before get_from_clause().
+ """
+ qn = self.quote_name_unless_alias
+ qn2 = self.connection.ops.quote_name
+ result = []
+ opts = self.query.model._meta
+
+ for name in self.query.distinct_fields:
+ parts = name.split(LOOKUP_SEP)
+ field, col, alias, _, _ = self._setup_joins(parts, opts, None)
+ col, alias = self._final_join_removal(col, alias)
+ result.append("%s.%s" % (qn(alias), qn2(col)))
+ return result
+
+
def get_ordering(self):
"""
Returns a tuple containing a list representing the SQL elements in the
@@ -384,21 +422,7 @@ def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
"""
name, order = get_order_dir(name, default_order)
pieces = name.split(LOOKUP_SEP)
- if not alias:
- alias = self.query.get_initial_alias()
- field, target, opts, joins, last, extra = self.query.setup_joins(pieces,
- opts, alias, False)
- alias = joins[-1]
- col = target.column
- if not field.rel:
- # To avoid inadvertent trimming of a necessary alias, use the
- # refcount to show that we are referencing a non-relation field on
- # the model.
- self.query.ref_alias(alias)
-
- # Must use left outer joins for nullable fields and their relations.
- self.query.promote_alias_chain(joins,
- self.query.alias_map[joins[0]][JOIN_TYPE] == self.query.LOUTER)
+ field, col, alias, joins, opts = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model.
@@ -416,19 +440,55 @@ def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
results.extend(self.find_ordering_name(item, opts, alias,
order, already_seen))
return results
+ col, alias = self._final_join_removal(col, alias)
+ return [(alias, col, order)]
+
+ def _setup_joins(self, pieces, opts, alias):
+ """
+ A helper method for get_ordering and get_distinct. This method will
+ call query.setup_joins, handle refcounts and then promote the joins.
+
+ Note that get_ordering and get_distinct must produce the same target
+ columns for the same input, as the prefixes of get_ordering and
+ get_distinct must match. Executing SQL where this is not true is an error.
+ """
+ if not alias:
+ alias = self.query.get_initial_alias()
+ field, target, opts, joins, _, _ = self.query.setup_joins(pieces,
+ opts, alias, False)
+ alias = joins[-1]
+ col = target.column
+ if not field.rel:
+ # To avoid inadvertent trimming of a necessary alias, use the
+ # refcount to show that we are referencing a non-relation field on
+ # the model.
+ self.query.ref_alias(alias)
+ # Must use left outer joins for nullable fields and their relations.
+ # Ordering or distinct must not affect the returned set, and INNER
+ # JOINS for nullable fields could do this.
+ self.query.promote_alias_chain(joins,
+ self.query.alias_map[joins[0]][JOIN_TYPE] == self.query.LOUTER)
+ return field, col, alias, joins, opts
+
+ def _final_join_removal(self, col, alias):
+ """
+ A helper method for get_distinct and get_ordering. This method will
+ trim extra not-needed joins from the tail of the join chain.
+
+ This is very similar to what is done in trim_joins, but we will
+ trim LEFT JOINS here. It would be a good idea to consolidate this
+ method and query.trim_joins().
+ """
if alias:
- # We have to do the same "final join" optimisation as in
- # add_filter, since the final column might not otherwise be part of
- # the select set (so we can't order on it).
while 1:
join = self.query.alias_map[alias]
if col != join[RHS_JOIN_COL]:
break
self.query.unref_alias(alias)
alias = join[LHS_ALIAS]
col = join[LHS_JOIN_COL]
- return [(alias, col, order)]
+ return col, alias
def get_from_clause(self):
"""
@@ -438,8 +498,8 @@ def get_from_clause(self):
from-clause via a "select".
This should only be called after any SQL construction methods that
- might change the tables we need. This means the select columns and
- ordering must be done first.
+ might change the tables we need. This means the select columns,
+ ordering and distinct must be done first.
"""
result = []
qn = self.quote_name_unless_alias
@@ -984,6 +1044,7 @@ def as_sql(self, qn=None):
"""
if qn is None:
qn = self.quote_name_unless_alias
+
sql = ('SELECT %s FROM (%s) subquery' % (
', '.join([
aggregate.as_sql(qn, self.connection)
@@ -127,6 +127,7 @@ def __init__(self, model, where=WhereNode):
self.order_by = []
self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.distinct = False
+ self.distinct_fields = []
self.select_for_update = False
self.select_for_update_nowait = False
self.select_related = False
@@ -265,6 +266,7 @@ def clone(self, klass=None, memo=None, **kwargs):
obj.order_by = self.order_by[:]
obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
obj.distinct = self.distinct
+ obj.distinct_fields = self.distinct_fields[:]
obj.select_for_update = self.select_for_update
obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_related = self.select_related
@@ -298,6 +300,7 @@ def clone(self, klass=None, memo=None, **kwargs):
else:
obj.used_aliases = set()
obj.filter_is_sticky = False
+
obj.__dict__.update(kwargs)
if hasattr(obj, '_setup_query'):
obj._setup_query()
@@ -393,7 +396,7 @@ def get_count(self, using):
Performs a COUNT() query using the current filter constraints.
"""
obj = self.clone()
- if len(self.select) > 1 or self.aggregate_select:
+ if len(self.select) > 1 or self.aggregate_select or (self.distinct and self.distinct_fields):
# If a select clause exists, then the query has already started to
# specify the columns that are to be returned.
# In this case, we need to use a subquery to evaluate the count.
@@ -452,6 +455,8 @@ def combine(self, rhs, connector):
"Cannot combine queries once a slice has been taken."
assert self.distinct == rhs.distinct, \
"Cannot combine a unique query with a non-unique query."
+ assert self.distinct_fields == rhs.distinct_fields, \
+ "Cannot combine queries with different distinct fields."
self.remove_inherited_models()
# Work out how to relabel the rhs aliases, if necessary.
@@ -674,9 +679,9 @@ def ref_alias(self, alias):
""" Increases the reference count for this alias. """
self.alias_refcount[alias] += 1
- def unref_alias(self, alias):
+ def unref_alias(self, alias, amount=1):
""" Decreases the reference count for this alias. """
- self.alias_refcount[alias] -= 1
+ self.alias_refcount[alias] -= amount
def promote_alias(self, alias, unconditional=False):
"""
@@ -705,6 +710,15 @@ def promote_alias_chain(self, chain, must_promote=False):
if self.promote_alias(alias, must_promote):
must_promote = True
+ def reset_refcounts(self, to_counts):
+ """
+ This method will reset reference counts for aliases so that they match
+ the value passed in :param to_counts:.
+ """
+ for alias, cur_refcount in self.alias_refcount.copy().items():
+ unref_amount = cur_refcount - to_counts.get(alias, 0)
+ self.unref_alias(alias, unref_amount)
+
def promote_unused_aliases(self, initial_refcounts, used_aliases):
"""
Given a "before" copy of the alias_refcounts dictionary (as
@@ -832,7 +846,8 @@ def get_initial_alias(self):
def count_active_tables(self):
"""
Returns the number of tables in this query with a non-zero reference
- count.
+ count. Note that after execution, the reference counts are zeroed, so
+ tables added in compiler will not be seen by this method.
"""
return len([1 for count in self.alias_refcount.itervalues() if count])
@@ -1596,6 +1611,13 @@ def clear_select_fields(self):
self.select = []
self.select_fields = []
+ def add_distinct_fields(self, *field_names):
+ """
+ Sets the query's "distinct on" field names and marks the query as distinct.
+ """
+ self.distinct_fields = field_names
+ self.distinct = True
+
def add_fields(self, field_names, allow_m2m=True):
"""
Adds the given (model) fields to the select set. The field names are
Oops, something went wrong.

0 comments on commit 2875657

Please sign in to comment.