Fixed #27718 -- Added QuerySet.union(), intersection(), difference(). #7727

Merged
merged 1 commit into from
Jan 14, 2017
6 changes: 6 additions & 0 deletions django/db/backends/base/features.py
@@ -221,6 +221,12 @@ class BaseDatabaseFeatures(object):
    # Place FOR UPDATE right after FROM clause. Used on MSSQL.
    for_update_after_from = False

    # Combinatorial flags
    supports_select_union = True
Member

How did you choose the defaults? I think it would be simplest to default to True and have backends opt-out as needed (that eliminates the risk that a backend fails to opt-in) but feel free to explain your thinking.

Member

I agree; defaulting to on sounds more sensible to me, since they're part of the SQL standard.

Member Author

I was looking for the lowest common denominator, but I have no problem switching that to True. I'd leave supports_slicing_ordering_in_compound as False, though, since most databases do not support it and, as far as I know, it is an extension to the standard.

    supports_select_intersection = True
    supports_select_difference = True
    supports_slicing_ordering_in_compound = False

    def __init__(self, connection):
        self.connection = connection
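
The opt-out pattern agreed on in the review thread can be sketched in isolation. The classes below are simplified stand-ins, not Django's real backend classes, and `check_combinator` is a hypothetical helper that mirrors the `getattr(features, 'supports_select_{}'.format(combinator))` lookup the compiler performs later in this diff.

```python
class BaseFeatures:
    # Defaults: the standard set operators are assumed to be available.
    supports_select_union = True
    supports_select_intersection = True
    supports_select_difference = True


class MySQLLikeFeatures(BaseFeatures):
    # A backend opts out of whatever it cannot do.
    supports_select_intersection = False
    supports_select_difference = False


class UnsupportedCombinator(Exception):
    """Hypothetical stand-in for django.db.DatabaseError."""


def check_combinator(features, combinator):
    # Hypothetical helper: look the feature flag up by name.
    if not getattr(features, 'supports_select_{}'.format(combinator)):
        raise UnsupportedCombinator(
            '{} not supported on this database backend.'.format(combinator)
        )
    return True
```

With defaults set to True, a backend that says nothing still gets all three operators, which is the risk reduction the first review comment describes.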

5 changes: 5 additions & 0 deletions django/db/backends/base/operations.py
@@ -29,6 +29,11 @@ class BaseDatabaseOperations(object):
        'PositiveSmallIntegerField': (0, 32767),
        'PositiveIntegerField': (0, 2147483647),
    }
    set_operators = {
        'union': 'UNION',
        'intersection': 'INTERSECT',
        'difference': 'EXCEPT',
    }

    def __init__(self, connection):
        self.connection = connection
3 changes: 3 additions & 0 deletions django/db/backends/mysql/features.py
@@ -29,6 +29,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
    supports_column_check_constraints = False
    can_clone_databases = True
    supports_temporal_subtraction = True
    supports_select_intersection = False
    supports_select_difference = False
    supports_slicing_ordering_in_compound = True

    @cached_property
    def _mysql_storage_engine(self):
4 changes: 4 additions & 0 deletions django/db/backends/oracle/operations.py
@@ -41,6 +41,10 @@ class DatabaseOperations(BaseDatabaseOperations):
END;
/"""

    def __init__(self, *args, **kwargs):
        super(DatabaseOperations, self).__init__(*args, **kwargs)
        self.set_operators['difference'] = 'MINUS'

    def autoinc_sql(self, table, column):
        # To simulate auto-incrementing primary keys in Oracle, we have to
        # create a sequence and a trigger.
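
One thing worth noting about this override: `set_operators` is a class-level dict on the base class, so item assignment through `self` changes that one shared dict rather than a per-backend copy. The sketch below uses simplified stand-in classes (not Django's) to show the effect, together with one possible alternative that builds a new dict on the subclass. Whether this matters in practice depends on whether several backends are loaded in the same process.

```python
class BaseOps:
    set_operators = {
        'union': 'UNION',
        'intersection': 'INTERSECT',
        'difference': 'EXCEPT',
    }


class MutatingOracleOps(BaseOps):
    def __init__(self):
        # Item assignment mutates the dict object shared with BaseOps.
        self.set_operators['difference'] = 'MINUS'


class CopyingOracleOps(BaseOps):
    # Alternative: a fresh dict on the subclass leaves the base untouched.
    set_operators = dict(BaseOps.set_operators, difference='MINUS')
```

Creating a `MutatingOracleOps` instance changes `BaseOps.set_operators['difference']` for every class that shares the dict, while `CopyingOracleOps` keeps its override local.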
1 change: 1 addition & 0 deletions django/db/backends/postgresql/features.py
@@ -31,6 +31,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
    greatest_least_ignores_nulls = True
    can_clone_databases = True
    supports_temporal_subtraction = True
    supports_slicing_ordering_in_compound = True

    @cached_property
    def has_select_for_update_skip_locked(self):
27 changes: 27 additions & 0 deletions django/db/models/query.py
@@ -816,6 +816,33 @@ def complex_filter(self, filter_obj):
        else:
            return self._filter_or_exclude(None, **filter_obj)

    def _combinator_query(self, combinator, *other_qs, **kwargs):
        # Clone the query to inherit the select list and everything
        clone = self._clone()
        # Clear limits and ordering so they can be reapplied
        clone.query.clear_ordering(True)
        clone.query.clear_limits()
        clone.query.combined_queries = (self.query,) + tuple(qs.query for qs in other_qs)
        clone.query.combinator = combinator
        clone.query.combinator_all = kwargs.pop('all', False)
        return clone

    def union(self, *other_qs, **kwargs):
Contributor

Would it make sense to have union_all as a separate method? That way there would be no need for unbounded kwargs. If not, at least assert here that no kwargs other than 'all' are given.

Member Author

I like the kwargs more than union_all. Starting with Django 2.0 we can use keyword-only arguments, so the unbounded-kwargs issue goes away…

        if kwargs:
            unexpected_kwarg = next((k for k in kwargs.keys() if k != 'all'), None)
            if unexpected_kwarg:
                raise TypeError(
                    "union() received an unexpected keyword argument '%s'" %
                    (unexpected_kwarg,)
                )
        return self._combinator_query('union', *other_qs, **kwargs)

    def intersection(self, *other_qs):
        return self._combinator_query('intersection', *other_qs)

    def difference(self, *other_qs):
        return self._combinator_query('difference', *other_qs)
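
The keyword-only alternative mentioned in the review thread can be sketched for Python 3 as follows. `FakeQuerySet` is a hypothetical stand-in that only records its arguments; it is not Django's `QuerySet`, and the tuple it returns exists purely for illustration.

```python
class FakeQuerySet:
    def __init__(self, name):
        self.name = name

    def _combinator_query(self, combinator, *other_qs, all=False):
        # Record what would be combined instead of building SQL.
        names = (self.name,) + tuple(qs.name for qs in other_qs)
        return (combinator, names, all)

    def union(self, *other_qs, all=False):
        # `all` follows *other_qs, so it is keyword-only and Python itself
        # rejects any other keyword argument with a TypeError.
        return self._combinator_query('union', *other_qs, all=all)
```

With this signature the manual `unexpected_kwarg` check becomes unnecessary, since a call such as `qs.union(other, distinct=True)` fails before the method body runs.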

    def select_for_update(self, nowait=False, skip_locked=False):
        """
        Returns a new QuerySet instance that will select objects with a
166 changes: 106 additions & 60 deletions django/db/models/sql/compiler.py
@@ -309,6 +309,21 @@ def get_order_by(self):
        seen = set()

        for expr, is_ref in order_by:
            if self.query.combinator:
                src = expr.get_source_expressions()[0]
                # Relabel order by columns to raw numbers if this is a combined
                # query; necessary since the columns can't be referenced by the
                # fully qualified name and the simple column names may collide.
                for idx, (sel_expr, _, col_alias) in enumerate(self.select):
                    if is_ref and col_alias == src.refs:
                        src = src.source
                    elif col_alias:
                        continue
                    if src == sel_expr:
                        expr.set_source_expressions([RawSQL('%d' % (idx + 1), ())])
                        break
                else:
                    raise DatabaseError('ORDER BY term does not match any column in the result set.')
            resolved = expr.resolve_expression(
                self.query, allow_joins=True, reuse=None)
            sql, params = self.compile(resolved)
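
Why positional references help can be seen with plain SQL. The sketch below uses Python's `sqlite3` module with two made-up tables (`authors` and `editors`) whose columns have different names. Ordering the compound result by position `1` works regardless of what each side calls its column.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE authors (name TEXT);
    CREATE TABLE editors (title TEXT);
    INSERT INTO authors VALUES ('carol'), ('alice');
    INSERT INTO editors VALUES ('bob');
""")

# ORDER BY 1 refers to the first column of the combined result set.
rows = conn.execute(
    "SELECT name FROM authors UNION SELECT title FROM editors ORDER BY 1"
).fetchall()
names = [row[0] for row in rows]
```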
@@ -360,6 +375,30 @@ def compile(self, node, select_format=False):
            return node.output_field.select_format(self, sql, params)
        return sql, params

    def get_combinator_sql(self, combinator, all):

Small nit: I would suggest using a name that is not a built-in. Perhaps all_ or _all.

        features = self.connection.features
        compilers = [
            query.get_compiler(self.using, self.connection)
            for query in self.query.combined_queries
        ]
        if not features.supports_slicing_ordering_in_compound:
            for query, compiler in zip(self.query.combined_queries, compilers):
                if query.low_mark or query.high_mark:
                    raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
                if compiler.get_order_by():
                    raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
        parts = (compiler.as_sql() for compiler in compilers)
        combinator_sql = self.connection.ops.set_operators[combinator]
        if all and combinator == 'union':
            combinator_sql += ' ALL'
        braces = '({})' if features.supports_slicing_ordering_in_compound else '{}'
        sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts))
        result = [' {} '.format(combinator_sql).join(sql_parts)]
        params = []
        for part in args_parts:
            params.extend(part)
        return result, params
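
The string assembly at the end of `get_combinator_sql()` can be tried without a database. `combine_sql` below is a hypothetical standalone function, not part of this patch. It takes already-compiled `(sql, params)` pairs, and its `wrap` flag plays the role of `supports_slicing_ordering_in_compound`. The `union_all` name follows the reviewer's suggestion to avoid shadowing the built-in `all`.

```python
SET_OPERATORS = {
    'union': 'UNION',
    'intersection': 'INTERSECT',
    'difference': 'EXCEPT',
}


def combine_sql(parts, combinator, union_all=False, wrap=False):
    """Join compiled (sql, params) pairs with a set operator."""
    operator = SET_OPERATORS[combinator]
    # ALL is only appended for unions, as in the patch.
    if union_all and combinator == 'union':
        operator += ' ALL'
    template = '({})' if wrap else '{}'
    sql = ' {} '.format(operator).join(template.format(s) for s, _ in parts)
    # Parameters keep the order of the sub-selects they belong to.
    params = [p for _, part_params in parts for p in part_params]
    return sql, params
```

For two parts this yields, for example, `(SELECT ...) EXCEPT (SELECT ...)` when `wrap=True`, with the parameter lists concatenated in the same order as the sub-selects.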

def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
@@ -377,69 +416,76 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
# docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()

for_update_part = None
where, w_params = self.compile(self.where) if self.where is not None else ("", [])
having, h_params = self.compile(self.having) if self.having is not None else ("", [])
params = []
result = ['SELECT']

if self.query.distinct:
result.append(self.connection.ops.distinct_sql(distinct_fields))

out_cols = []
col_idx = 1
for _, (s_sql, s_params), alias in self.select + extra_select:
if alias:
s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias))
elif with_col_aliases:
s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx)
col_idx += 1
params.extend(s_params)
out_cols.append(s_sql)

result.append(', '.join(out_cols))

result.append('FROM')
result.extend(from_)
params.extend(f_params)

for_update_part = None
if self.query.select_for_update and self.connection.features.has_select_for_update:
if self.connection.get_autocommit():
raise TransactionManagementError("select_for_update cannot be used outside of a transaction.")

nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
# If it's a NOWAIT/SKIP LOCKED query but the backend doesn't
# support it, raise a DatabaseError to prevent a possible
# deadlock.
if nowait and not self.connection.features.has_select_for_update_nowait:
raise DatabaseError('NOWAIT is not supported on this database backend.')
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
raise DatabaseError('SKIP LOCKED is not supported on this database backend.')
for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)

if for_update_part and self.connection.features.for_update_after_from:
result.append(for_update_part)

if where:
result.append('WHERE %s' % where)
params.extend(w_params)

grouping = []
for g_sql, g_params in group_by:
grouping.append(g_sql)
params.extend(g_params)
if grouping:
if distinct_fields:
raise NotImplementedError(
"annotate() + distinct(fields) is not implemented.")
if not order_by:
order_by = self.connection.ops.force_no_ordering()
result.append('GROUP BY %s' % ', '.join(grouping))

if having:
result.append('HAVING %s' % having)
params.extend(h_params)
combinator = self.query.combinator
features = self.connection.features
if combinator:
if not getattr(features, 'supports_select_{}'.format(combinator)):
raise DatabaseError('{} not supported on this database backend.'.format(combinator))
result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
else:
result = ['SELECT']
params = []

if self.query.distinct:
result.append(self.connection.ops.distinct_sql(distinct_fields))

out_cols = []
col_idx = 1
for _, (s_sql, s_params), alias in self.select + extra_select:
if alias:
s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias))
elif with_col_aliases:
s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx)
col_idx += 1
params.extend(s_params)
out_cols.append(s_sql)

result.append(', '.join(out_cols))

result.append('FROM')
result.extend(from_)
params.extend(f_params)

if self.query.select_for_update and self.connection.features.has_select_for_update:
if self.connection.get_autocommit():
raise TransactionManagementError('select_for_update cannot be used outside of a transaction.')

nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
# If it's a NOWAIT/SKIP LOCKED query but the backend
# doesn't support it, raise a DatabaseError to prevent a
# possible deadlock.
if nowait and not self.connection.features.has_select_for_update_nowait:
raise DatabaseError('NOWAIT is not supported on this database backend.')
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
raise DatabaseError('SKIP LOCKED is not supported on this database backend.')
for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)

if for_update_part and self.connection.features.for_update_after_from:
result.append(for_update_part)

if where:
result.append('WHERE %s' % where)
params.extend(w_params)

grouping = []
for g_sql, g_params in group_by:
grouping.append(g_sql)
params.extend(g_params)
if grouping:
if distinct_fields:
raise NotImplementedError('annotate() + distinct(fields) is not implemented.')
if not order_by:
order_by = self.connection.ops.force_no_ordering()
result.append('GROUP BY %s' % ', '.join(grouping))

if having:
result.append('HAVING %s' % having)
params.extend(h_params)

if order_by:
ordering = []
8 changes: 8 additions & 0 deletions django/db/models/sql/query.py
@@ -186,6 +186,11 @@ def __init__(self, model, where=WhereNode):
        self.annotation_select_mask = None
        self._annotation_select_cache = None

        # Set combination attributes
        self.combinator = None
        self.combinator_all = False
        self.combined_queries = ()

        # These are for extensions. The contents are more or less appended
        # verbatim to the appropriate clause.
        # The _extra attribute is an OrderedDict, lazily created similarly to
@@ -303,6 +308,9 @@ def clone(self, klass=None, memo=None, **kwargs):
        # used.
        obj._annotation_select_cache = None
        obj.max_depth = self.max_depth
        obj.combinator = self.combinator
        obj.combinator_all = self.combinator_all
        obj.combined_queries = self.combined_queries
        obj._extra = self._extra.copy() if self._extra is not None else None
        if self.extra_select_mask is None:
            obj.extra_select_mask = None
55 changes: 55 additions & 0 deletions docs/ref/models/querysets.txt
@@ -801,6 +801,61 @@ typically caches its results. If the data in the database might have changed
since a ``QuerySet`` was evaluated, you can get updated results for the same
query by calling ``all()`` on a previously evaluated ``QuerySet``.

``union()``
~~~~~~~~~~~

.. method:: union(*other_qs, all=False)

.. versionadded:: 1.11

Uses SQL's ``UNION`` operator to combine the results of two or more
``QuerySet``\s. For example:

    >>> qs1.union(qs2, qs3)

The ``UNION`` operator selects only distinct values by default. To allow
duplicate values, use the ``all=True`` argument.

``union()``, ``intersection()``, and ``difference()`` return model instances
of the type of the first ``QuerySet`` even if the arguments are ``QuerySet``\s
of other models. Passing different models works as long as the ``SELECT`` list
is the same in all ``QuerySet``\s (at least the types, the names don't matter
as long as the types are in the same order).

In addition, only ``LIMIT``, ``OFFSET``, and ``ORDER BY`` (i.e. slicing and
:meth:`order_by`) are allowed on the resulting ``QuerySet``. Further, databases
place restrictions on what operations are allowed in the combined queries. For
example, most databases don't allow ``LIMIT`` or ``OFFSET`` in the combined
queries.
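
The distinct-by-default behaviour and the restriction on ``LIMIT`` can be observed directly in SQL. This is a small sketch using Python's `sqlite3` module with a made-up table `t`. It shows SQLite's behaviour only, and other databases may differ on the second point.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE t (x INTEGER);
    INSERT INTO t VALUES (1), (2);
""")

# UNION drops duplicate rows; UNION ALL keeps them.
distinct_rows = conn.execute("SELECT x FROM t UNION SELECT x FROM t").fetchall()
all_rows = conn.execute("SELECT x FROM t UNION ALL SELECT x FROM t").fetchall()

# SQLite rejects LIMIT on the first SELECT of a compound statement.
try:
    conn.execute("SELECT x FROM t LIMIT 1 UNION SELECT x FROM t")
    limit_in_part_allowed = True
except sqlite3.OperationalError:
    limit_in_part_allowed = False
```

The two row counts differ because only ``UNION ALL`` keeps the repeated rows, which is what ``all=True`` asks for.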

``intersection()``
~~~~~~~~~~~~~~~~~~

.. method:: intersection(*other_qs)

.. versionadded:: 1.11

Uses SQL's ``INTERSECT`` operator to return the shared elements of two or more
``QuerySet``\s. For example:

    >>> qs1.intersection(qs2, qs3)

See :meth:`union` for some restrictions.

``difference()``
~~~~~~~~~~~~~~~~

.. method:: difference(*other_qs)

.. versionadded:: 1.11

Uses SQL's ``EXCEPT`` operator to keep only elements present in the
``QuerySet`` but not in some other ``QuerySet``\s. For example:

    >>> qs1.difference(qs2, qs3)

See :meth:`union` for some restrictions.
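
What ``INTERSECT`` and ``EXCEPT`` return can be checked with a few rows. As above, this is only a sketch using Python's `sqlite3` module and two made-up tables `a` and `b`. On Oracle the second operator is spelled ``MINUS``, as the backend change in this pull request reflects.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE a (x INTEGER);
    CREATE TABLE b (x INTEGER);
    INSERT INTO a VALUES (1), (2), (3);
    INSERT INTO b VALUES (2), (3), (4);
""")

# Rows present in both tables.
shared = conn.execute(
    "SELECT x FROM a INTERSECT SELECT x FROM b ORDER BY 1"
).fetchall()
# Rows present in a but not in b.
only_in_a = conn.execute(
    "SELECT x FROM a EXCEPT SELECT x FROM b ORDER BY 1"
).fetchall()
```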

``select_related()``
~~~~~~~~~~~~~~~~~~~~

3 changes: 3 additions & 0 deletions docs/releases/1.11.txt
@@ -386,6 +386,9 @@ Models
* The new ``F`` expression ``bitleftshift()`` and ``bitrightshift()`` methods
  allow :ref:`bitwise shift operations <using-f-expressions-in-filters>`.

* Added :meth:`.QuerySet.union`, :meth:`~.QuerySet.intersection`, and
  :meth:`~.QuerySet.difference`.

Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

3 changes: 3 additions & 0 deletions tests/basic/tests.py
@@ -589,6 +589,9 @@ class ManagerTest(SimpleTestCase):
        '_insert',
        '_update',
        'raw',
        'union',
        'intersection',
        'difference',
    ]

    def test_manager_methods(self):