-
-
Notifications
You must be signed in to change notification settings - Fork 32.8k
Fixed #27718 -- Added QuerySet.union(), intersection(), difference(). #7727
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -816,6 +816,33 @@ def complex_filter(self, filter_obj): | |
else: | ||
return self._filter_or_exclude(None, **filter_obj) | ||
|
||
def _combinator_query(self, combinator, *other_qs, **kwargs): | ||
# Clone the query to inherit the select list and everything | ||
clone = self._clone() | ||
# Clear limits and ordering so they can be reapplied | ||
clone.query.clear_ordering(True) | ||
clone.query.clear_limits() | ||
clone.query.combined_queries = (self.query,) + tuple(qs.query for qs in other_qs) | ||
clone.query.combinator = combinator | ||
clone.query.combinator_all = kwargs.pop('all', False) | ||
return clone | ||
|
||
def union(self, *other_qs, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to have union_all as separate method? That way there would be no need for unbounded kwargs. If not, at least assert here that there are no other kwargs than 'all' given here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the kwargs more than |
||
if kwargs: | ||
unexpected_kwarg = next((k for k in kwargs.keys() if k != 'all'), None) | ||
if unexpected_kwarg: | ||
raise TypeError( | ||
"union() received an unexpected keyword argument '%s'" % | ||
(unexpected_kwarg,) | ||
) | ||
return self._combinator_query('union', *other_qs, **kwargs) | ||
|
||
def intersection(self, *other_qs): | ||
return self._combinator_query('intersection', *other_qs) | ||
|
||
def difference(self, *other_qs): | ||
return self._combinator_query('difference', *other_qs) | ||
|
||
def select_for_update(self, nowait=False, skip_locked=False): | ||
""" | ||
Returns a new QuerySet instance that will select objects with a | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -309,6 +309,21 @@ def get_order_by(self): | |
seen = set() | ||
|
||
for expr, is_ref in order_by: | ||
if self.query.combinator: | ||
src = expr.get_source_expressions()[0] | ||
# Relabel order by columns to raw numbers if this is a combined | ||
# query; necessary since the columns can't be referenced by the | ||
# fully qualified name and the simple column names may collide. | ||
for idx, (sel_expr, _, col_alias) in enumerate(self.select): | ||
if is_ref and col_alias == src.refs: | ||
src = src.source | ||
elif col_alias: | ||
continue | ||
if src == sel_expr: | ||
expr.set_source_expressions([RawSQL('%d' % (idx + 1), ())]) | ||
break | ||
else: | ||
raise DatabaseError('ORDER BY term does not match any column in the result set.') | ||
resolved = expr.resolve_expression( | ||
self.query, allow_joins=True, reuse=None) | ||
sql, params = self.compile(resolved) | ||
|
@@ -360,6 +375,30 @@ def compile(self, node, select_format=False): | |
return node.output_field.select_format(self, sql, params) | ||
return sql, params | ||
|
||
def get_combinator_sql(self, combinator, all): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small nit: I would suggest using a name that is not a built-in. Perhaps |
||
features = self.connection.features | ||
compilers = [ | ||
query.get_compiler(self.using, self.connection) | ||
for query in self.query.combined_queries | ||
] | ||
if not features.supports_slicing_ordering_in_compound: | ||
for query, compiler in zip(self.query.combined_queries, compilers): | ||
if query.low_mark or query.high_mark: | ||
raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.') | ||
if compiler.get_order_by(): | ||
raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.') | ||
parts = (compiler.as_sql() for compiler in compilers) | ||
combinator_sql = self.connection.ops.set_operators[combinator] | ||
if all and combinator == 'union': | ||
combinator_sql += ' ALL' | ||
braces = '({})' if features.supports_slicing_ordering_in_compound else '{}' | ||
sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts)) | ||
result = [' {} '.format(combinator_sql).join(sql_parts)] | ||
params = [] | ||
for part in args_parts: | ||
params.extend(part) | ||
return result, params | ||
|
||
def as_sql(self, with_limits=True, with_col_aliases=False): | ||
""" | ||
Creates the SQL for this query. Returns the SQL string and list of | ||
|
@@ -377,69 +416,76 @@ def as_sql(self, with_limits=True, with_col_aliases=False): | |
# docstring of get_from_clause() for details. | ||
from_, f_params = self.get_from_clause() | ||
|
||
for_update_part = None | ||
where, w_params = self.compile(self.where) if self.where is not None else ("", []) | ||
having, h_params = self.compile(self.having) if self.having is not None else ("", []) | ||
params = [] | ||
result = ['SELECT'] | ||
|
||
if self.query.distinct: | ||
result.append(self.connection.ops.distinct_sql(distinct_fields)) | ||
|
||
out_cols = [] | ||
col_idx = 1 | ||
for _, (s_sql, s_params), alias in self.select + extra_select: | ||
if alias: | ||
s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias)) | ||
elif with_col_aliases: | ||
s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx) | ||
col_idx += 1 | ||
params.extend(s_params) | ||
out_cols.append(s_sql) | ||
|
||
result.append(', '.join(out_cols)) | ||
|
||
result.append('FROM') | ||
result.extend(from_) | ||
params.extend(f_params) | ||
|
||
for_update_part = None | ||
if self.query.select_for_update and self.connection.features.has_select_for_update: | ||
if self.connection.get_autocommit(): | ||
raise TransactionManagementError("select_for_update cannot be used outside of a transaction.") | ||
|
||
nowait = self.query.select_for_update_nowait | ||
skip_locked = self.query.select_for_update_skip_locked | ||
# If it's a NOWAIT/SKIP LOCKED query but the backend doesn't | ||
# support it, raise a DatabaseError to prevent a possible | ||
# deadlock. | ||
if nowait and not self.connection.features.has_select_for_update_nowait: | ||
raise DatabaseError('NOWAIT is not supported on this database backend.') | ||
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked: | ||
raise DatabaseError('SKIP LOCKED is not supported on this database backend.') | ||
for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked) | ||
|
||
if for_update_part and self.connection.features.for_update_after_from: | ||
result.append(for_update_part) | ||
|
||
if where: | ||
result.append('WHERE %s' % where) | ||
params.extend(w_params) | ||
|
||
grouping = [] | ||
for g_sql, g_params in group_by: | ||
grouping.append(g_sql) | ||
params.extend(g_params) | ||
if grouping: | ||
if distinct_fields: | ||
raise NotImplementedError( | ||
"annotate() + distinct(fields) is not implemented.") | ||
if not order_by: | ||
order_by = self.connection.ops.force_no_ordering() | ||
result.append('GROUP BY %s' % ', '.join(grouping)) | ||
|
||
if having: | ||
result.append('HAVING %s' % having) | ||
params.extend(h_params) | ||
combinator = self.query.combinator | ||
features = self.connection.features | ||
if combinator: | ||
if not getattr(features, 'supports_select_{}'.format(combinator)): | ||
raise DatabaseError('{} not supported on this database backend.'.format(combinator)) | ||
result, params = self.get_combinator_sql(combinator, self.query.combinator_all) | ||
else: | ||
result = ['SELECT'] | ||
params = [] | ||
|
||
if self.query.distinct: | ||
result.append(self.connection.ops.distinct_sql(distinct_fields)) | ||
|
||
out_cols = [] | ||
col_idx = 1 | ||
for _, (s_sql, s_params), alias in self.select + extra_select: | ||
if alias: | ||
s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias)) | ||
elif with_col_aliases: | ||
s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx) | ||
col_idx += 1 | ||
params.extend(s_params) | ||
out_cols.append(s_sql) | ||
|
||
result.append(', '.join(out_cols)) | ||
|
||
result.append('FROM') | ||
result.extend(from_) | ||
params.extend(f_params) | ||
|
||
if self.query.select_for_update and self.connection.features.has_select_for_update: | ||
if self.connection.get_autocommit(): | ||
raise TransactionManagementError('select_for_update cannot be used outside of a transaction.') | ||
|
||
nowait = self.query.select_for_update_nowait | ||
skip_locked = self.query.select_for_update_skip_locked | ||
# If it's a NOWAIT/SKIP LOCKED query but the backend | ||
# doesn't support it, raise a DatabaseError to prevent a | ||
# possible deadlock. | ||
if nowait and not self.connection.features.has_select_for_update_nowait: | ||
raise DatabaseError('NOWAIT is not supported on this database backend.') | ||
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked: | ||
raise DatabaseError('SKIP LOCKED is not supported on this database backend.') | ||
for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked) | ||
|
||
if for_update_part and self.connection.features.for_update_after_from: | ||
result.append(for_update_part) | ||
|
||
if where: | ||
result.append('WHERE %s' % where) | ||
params.extend(w_params) | ||
|
||
grouping = [] | ||
for g_sql, g_params in group_by: | ||
grouping.append(g_sql) | ||
params.extend(g_params) | ||
if grouping: | ||
if distinct_fields: | ||
raise NotImplementedError('annotate() + distinct(fields) is not implemented.') | ||
if not order_by: | ||
order_by = self.connection.ops.force_no_ordering() | ||
result.append('GROUP BY %s' % ', '.join(grouping)) | ||
|
||
if having: | ||
result.append('HAVING %s' % having) | ||
params.extend(h_params) | ||
|
||
if order_by: | ||
ordering = [] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How did you choose the defaults? I think it would be simplest to default to True and have backends opt-out as needed (that eliminates the risk that a backend fails to opt-in) but feel free to explain your thinking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, default on sounds more sensible to me since they're part of the SQL standard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was looking for the lowest common denominator, but have no problem to switch that to true. I'd leave
supports_slicing_ordering_in_compound
as False though since most databases do not support it and afaik it is an extension to the standard.