Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fixed #10362 -- An update() that only affects a parent model no longe…

…r crashes.

This includes a fairly large refactor of the update() query path (and
the initial portions of constructing the SQL for any query). The
previous code appears to have been only working more or less by accident
and was very fragile.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9967 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 0e93f60c7f6de821e31424b0e7c26586155a7a1a 1 parent 14c8e52
Malcolm Tredinnick malcolmt authored
83 django/db/models/sql/query.py
View
@@ -62,6 +62,7 @@ def __init__(self, model, connection, where=WhereNode):
self.dupe_avoidance = {}
self.used_aliases = set()
self.filter_is_sticky = False
+ self.included_inherited_models = {}
# SQL-related attributes
self.select = []
@@ -171,6 +172,7 @@ def clone(self, klass=None, **kwargs):
obj.default_cols = self.default_cols
obj.default_ordering = self.default_ordering
obj.standard_ordering = self.standard_ordering
+ obj.included_inherited_models = self.included_inherited_models.copy()
obj.ordering_aliases = []
obj.select_fields = self.select_fields[:]
obj.related_select_fields = self.related_select_fields[:]
@@ -304,6 +306,7 @@ def get_aggregation(self):
self.select = []
self.default_cols = False
self.extra_select = {}
+ self.remove_inherited_models()
query.clear_ordering(True)
query.clear_limits()
@@ -458,6 +461,7 @@ def combine(self, rhs, connector):
assert self.distinct == rhs.distinct, \
"Cannot combine a unique query with a non-unique query."
+ self.remove_inherited_models()
# Work out how to relabel the rhs aliases, if necessary.
change_map = {}
used = set()
@@ -540,6 +544,9 @@ def pre_sql_setup(self):
"""
if not self.tables:
self.join((None, self.model._meta.db_table, None, None))
+ if (not self.select and self.default_cols and not
+ self.included_inherited_models):
+ self.setup_inherited_models()
if self.select_related and not self.related_select_cols:
self.fill_related_selections()
@@ -619,7 +626,9 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False):
"""
Computes the default columns for selecting every field in the base
- model.
+ model. Will sometimes be called to pull in related models (e.g. via
+ select_related), in which case "opts" and "start_alias" will be given
+ to provide a starting point for the traversal.
Returns a list of strings, quoted appropriately for use in SQL
directly, as well as a set of aliases used in the select statement (if
@@ -629,22 +638,25 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
result = []
if opts is None:
opts = self.model._meta
- if start_alias:
- table_alias = start_alias
- else:
- table_alias = self.tables[0]
- root_pk = opts.pk.column
- seen = {None: table_alias}
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
aliases = set()
+ if start_alias:
+ seen = {None: start_alias}
+ root_pk = opts.pk.column
for field, model in opts.get_fields_with_model():
- try:
- alias = seen[model]
- except KeyError:
- alias = self.join((table_alias, model._meta.db_table,
- root_pk, model._meta.pk.column))
- seen[model] = alias
+ if start_alias:
+ try:
+ alias = seen[model]
+ except KeyError:
+ alias = self.join((start_alias, model._meta.db_table,
+ root_pk, model._meta.pk.column))
+ seen[model] = alias
+ else:
+ # If we're starting from the base model of the queryset, the
+ # aliases will have already been set up in pre_sql_setup(), so
+ # we can save time here.
+ alias = self.included_inherited_models[model]
if as_pairs:
result.append((alias, field.column))
continue
@@ -996,6 +1008,9 @@ def change_aliases(self, change_map):
if alias == old_alias:
self.tables[pos] = new_alias
break
+ for key, alias in self.included_inherited_models.items():
+ if alias in change_map:
+ self.included_inherited_models[key] = change_map[alias]
# 3. Update any joins that refer to the old alias.
for alias, data in self.alias_map.iteritems():
@@ -1062,9 +1077,11 @@ def join(self, connection, always_create=False, exclusions=(),
lhs.lhs_col = table.col
If 'always_create' is True and 'reuse' is None, a new alias is always
- created, regardless of whether one already exists or not. Otherwise
- 'reuse' must be a set and a new join is created unless one of the
- aliases in `reuse` can be used.
+ created, regardless of whether one already exists or not. If
+ 'always_create' is True and 'reuse' is a set, an alias in 'reuse' that
+ matches the connection will be returned, if possible. If
+ 'always_create' is False, the first existing alias that matches the
+ 'connection' is returned, if any. Otherwise a new join is created.
If 'exclusions' is specified, it is something satisfying the container
protocol ("foo in exclusions" must work) and specifies a list of
@@ -1126,6 +1143,38 @@ def join(self, connection, always_create=False, exclusions=(),
self.rev_join_map[alias] = t_ident
return alias
+ def setup_inherited_models(self):
+ """
+ If the model that is the basis for this QuerySet inherits other models,
+ we need to ensure that those other models have their tables included in
+ the query.
+
+ We do this as a separate step so that subclasses know which
+ tables are going to be active in the query, without needing to compute
+ all the select columns (this method is called from pre_sql_setup(),
+ whereas column determination is a later part, and side-effect, of
+ as_sql()).
+ """
+ opts = self.model._meta
+ root_pk = opts.pk.column
+ root_alias = self.tables[0]
+ seen = {None: root_alias}
+ for field, model in opts.get_fields_with_model():
+ if model not in seen:
+ seen[model] = self.join((root_alias, model._meta.db_table,
+ root_pk, model._meta.pk.column))
+ self.included_inherited_models = seen
+
+ def remove_inherited_models(self):
+ """
+ Undoes the effects of setup_inherited_models(). Should be called
+ whenever select columns (self.select) are set explicitly.
+ """
+ for key, alias in self.included_inherited_models.items():
+ if key:
+ self.unref_alias(alias)
+ self.included_inherited_models = {}
+
def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used=None, requested=None, restricted=None, nullable=None,
dupe_set=None, avoid_set=None):
@@ -1803,6 +1852,7 @@ def add_fields(self, field_names, allow_m2m=True):
names.sort()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
+ self.remove_inherited_models()
def add_ordering(self, *ordering):
"""
@@ -2004,6 +2054,7 @@ def set_start(self, start):
select_alias = join_info[RHS_ALIAS]
select_col = join_info[RHS_JOIN_COL]
self.select = [(select_alias, select_col)]
+ self.remove_inherited_models()
def execute_sql(self, result_type=MULTI):
"""
18 django/db/models/sql/subqueries.py
View
@@ -179,21 +179,9 @@ def pre_sql_setup(self):
query = self.clone(klass=Query)
query.bump_prefix()
query.extra_select = {}
- first_table = query.tables[0]
- if query.alias_refcount[first_table] == 1:
- # We can remove one table from the inner query.
- query.unref_alias(first_table)
- for i in xrange(1, len(query.tables)):
- table = query.tables[i]
- if query.alias_refcount[table]:
- break
- join_info = query.alias_map[table]
- query.select = [(join_info[RHS_ALIAS], join_info[RHS_JOIN_COL])]
- must_pre_select = False
- else:
- query.select = []
- query.add_fields([query.model._meta.pk.name])
- must_pre_select = not self.connection.features.update_can_self_select
+ query.select = []
+ query.add_fields([query.model._meta.pk.name])
+ must_pre_select = count > 1 and not self.connection.features.update_can_self_select
# Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select).
10 tests/regressiontests/model_inheritance_regress/models.py
View
@@ -222,7 +222,7 @@ class QualityControl(Evaluation):
>>> obj = SelfRefChild.objects.create(child_data=37, parent_data=42)
>>> obj.delete()
-# Regression tests for #8076 - get_(next/previous)_by_date should
+# Regression tests for #8076 - get_(next/previous)_by_date should work.
>>> c1 = ArticleWithAuthor(headline='ArticleWithAuthor 1', author="Person 1", pub_date=datetime.datetime(2005, 8, 1, 3, 0))
>>> c1.save()
>>> c2 = ArticleWithAuthor(headline='ArticleWithAuthor 2', author="Person 2", pub_date=datetime.datetime(2005, 8, 1, 10, 0))
@@ -267,4 +267,12 @@ class QualityControl(Evaluation):
>>> fragment.find('pub_date', pos + 1) == -1
True
+# It is possible to call update() and only change a field in an ancestor model
+# (regression test for #10362).
+>>> article = ArticleWithAuthor.objects.create(author="fred", headline="Hey there!", pub_date = datetime.datetime(2009, 3, 1, 8, 0, 0))
+>>> ArticleWithAuthor.objects.filter(author="fred").update(headline="Oh, no!")
+1
+>>> ArticleWithAuthor.objects.filter(pk=article.pk).update(headline="Oh, no!")
+1
+
"""}
Please sign in to comment.
Something went wrong with that request. Please try again.