Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP

Comparing changes

Choose two branches to see what's changed or to start a new pull request. If you need to, you can also compare across forks.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks.
base fork: django/django
...
head fork: akaariai/django
compare: ticket_19385
Checking mergeability… Don't worry, you can still create the pull request.
  • 13 commits
  • 13 files changed
  • 0 commit comments
  • 1 contributor
View
24 django/contrib/contenttypes/generic.py
@@ -11,6 +11,7 @@
from django.db.models import signals
from django.db import models, router, DEFAULT_DB_ALIAS
from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
+from django.db.models.related import PathInfo
from django.forms import ModelForm
from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance
from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets
@@ -160,6 +161,21 @@ def __init__(self, to, **kwargs):
kwargs['serialize'] = False
Field.__init__(self, **kwargs)
+ def get_path_info(self, direct=False):
+ # This join is going to be one-directional...
+ assert direct is False
+ # Gotcha! This is just a fake m2m field - a generic relation
+ # field).
+ from_field = self.model._meta.pk
+ opts = self.rel.to._meta
+ target = opts.get_field_by_name(self.object_id_field_name)[0]
+ # Note that we are using different field for the join_field
+ # than from_field or to_field. This is a hack, but we need the
+ # GenericRelation to generate the extra SQL.
+ # TODO: check the m2m and direct flags (the last two Falses there)
+ return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)],
+ opts, target, self)
+
def get_choices_default(self):
return Field.get_choices(self, include_blank=False)
@@ -211,10 +227,14 @@ def get_content_type(self):
"""
return ContentType.objects.get_for_model(self.model)
- def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
+ def get_join_sql(self, connection, qn, lhs_alias, rhs_alias, direct):
+ rhs_col = self.rel.to._meta.get_field_by_name(self.object_id_field_name)[0].column
+ lhs_col = self.model._meta.pk.column
+ clause1 = connection.ops.join_sql(qn, lhs_alias, rhs_alias, lhs_col, rhs_col)
extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column
contenttype = self.get_content_type().pk
- return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype]
+ clause2 = '%s.%s = %%s' % (qn(rhs_alias), qn(extra_col))
+ return "%s AND %s" % (clause1, clause2), [contenttype]
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
"""
View
3  django/db/backends/__init__.py
@@ -468,6 +468,9 @@ def __init__(self, connection):
self.connection = connection
self._cache = None
+ def join_sql(self, qn, lhs_alias, rhs_alias, lhs_col, rhs_col):
+ return '%s.%s = %s.%s' % (qn(lhs_alias), qn(lhs_col), qn(rhs_alias), qn(rhs_col))
+
def autoinc_sql(self, table, column):
"""
Returns any SQL needed to support auto-incrementing primary keys, or
View
40 django/db/models/fields/__init__.py
@@ -9,7 +9,7 @@
from itertools import tee
from django.db import connection
-from django.db.models.query_utils import QueryWrapper
+from django.db.models.query_utils import QueryWrapper, InvalidQuery
from django.conf import settings
from django import forms
from django.core import exceptions, validators
@@ -269,6 +269,9 @@ def get_attname_column(self):
def get_cache_name(self):
return '_%s_cache' % self.name
+ def get_related_cache_name(self):
+ return self.related.get_cache_name()
+
def get_internal_type(self):
return self.__class__.__name__
@@ -513,6 +516,41 @@ def __repr__(self):
return '<%s: %s>' % (path, name)
return '<%s>' % path
+ def select_related_descend(self, restricted, requested, load_fields, reverse=False):
+ """
+ Returns True if this field should be used to descend deeper for
+ select_related() purposes. Used by both the query construction code
+ (sql.query.fill_related_selections()) and the model instance creation code
+ (query.get_klass_info()).
+
+ Arguments:
+ * restricted - a boolean field, indicating if the field list has been
+ manually restricted using a requested clause)
+ * requested - The select_related() dictionary.
+ * load_fields - the set of fields to be loaded on this model
+ * reverse - boolean, True if we are checking a reverse select related
+ """
+ if not self.rel:
+ return False
+ if self.rel.parent_link and not reverse:
+ return False
+ if restricted:
+ if reverse and self.related_query_name() not in requested:
+ return False
+ if not reverse and self.name not in requested:
+ return False
+ if not restricted and self.null:
+ return False
+ if load_fields:
+ if self.name not in load_fields:
+ if restricted and self.name in requested:
+ raise InvalidQuery("Field %s.%s cannot be both deferred"
+ " and traversed using select_related"
+ " at the same time." %
+ (self.model._meta.object_name, self.name))
+ return False
+ return True
+
class AutoField(Field):
description = _("Integer")
View
45 django/db/models/fields/related.py
@@ -5,7 +5,7 @@
from django.db.models import signals, get_model
from django.db.models.fields import (AutoField, Field, IntegerField,
PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
-from django.db.models.related import RelatedObject
+from django.db.models.related import RelatedObject, PathInfo
from django.db.models.query import QuerySet
from django.db.models.query_utils import QueryWrapper
from django.db.models.deletion import CASCADE
@@ -16,7 +16,6 @@
from django.core import exceptions
from django import forms
-
RECURSIVE_RELATIONSHIP_CONSTANT = 'self'
pending_lookups = {}
@@ -1004,6 +1003,33 @@ def __init__(self, to, to_field=None, rel_class=ManyToOneRel, **kwargs):
)
Field.__init__(self, **kwargs)
+ def get_join_sql(self, connection, qn, lhs_alias, rhs_alias, direct):
+ lhs_col = self.column
+ rhs_col = self.rel.get_related_field().column
+ if not direct:
+ lhs_col, rhs_col = rhs_col, lhs_col
+ return connection.ops.join_sql(qn, lhs_alias, rhs_alias, lhs_col, rhs_col), []
+
+ def get_path_info(self, direct=True):
+ if direct:
+ opts = self.rel.to._meta
+ target = self.rel.get_related_field()
+ from_opts = self.model._meta
+ return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self
+ else:
+ opts = self.model._meta
+ from_field = self.rel.get_related_field()
+ from_opts = from_field.model._meta
+ pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)]
+ if from_field.model is self.model:
+ # Recursive foreign key to self.
+ target = opts.get_field_by_name(
+ self.rel.field_name)[0]
+ else:
+ target = opts.pk
+ return pathinfos, opts, target, self
+
+
def validate(self, value, model_instance):
if self.rel.parent_link:
return
@@ -1198,6 +1224,21 @@ def __init__(self, to, **kwargs):
msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.')
self.help_text = string_concat(self.help_text, ' ', msg)
+ def get_path_info(self, direct=True):
+ pathinfos = []
+ int_model = self.rel.through
+ linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0]
+ linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
+ if direct:
+ join1infos, _, _, _ = linkfield1.get_path_info(direct=False)
+ join2infos, opts, target, final_field = linkfield2.get_path_info()
+ else:
+ join1infos, _, _, _ = linkfield2.get_path_info(direct=False)
+ join2infos, opts, target, final_field = linkfield1.get_path_info()
+ pathinfos.extend(join1infos)
+ pathinfos.extend(join2infos)
+ return pathinfos, opts, target, final_field
+
def get_choices_default(self):
return Field.get_choices(self, include_blank=False)
View
22 django/db/models/query.py
@@ -11,8 +11,7 @@
from django.db import connections, router, transaction, IntegrityError
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import AutoField
-from django.db.models.query_utils import (Q, select_related_descend,
- deferred_class_factory, InvalidQuery)
+from django.db.models.query_utils import (Q, deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector
from django.db.models import sql
from django.utils.functional import partition
@@ -1386,21 +1385,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
restricted = requested is not None
related_fields = []
- for f in klass._meta.fields:
- if select_related_descend(f, restricted, requested, load_fields):
+ for f, _ in klass._meta.get_fields_with_model():
+ if f.select_related_descend(restricted, requested, load_fields):
+ _, next_opts, _, _ = f.get_path_info()
if restricted:
next = requested[f.name]
else:
next = None
- klass_info = get_klass_info(f.rel.to, max_depth=max_depth, cur_depth=cur_depth+1,
+ klass_info = get_klass_info(next_opts.model, max_depth=max_depth, cur_depth=cur_depth+1,
requested=next, only_load=only_load)
related_fields.append((f, klass_info))
reverse_related_fields = []
if restricted:
for o in klass._meta.get_all_related_objects():
- if o.field.unique and select_related_descend(o.field, restricted, requested,
- only_load.get(o.model), reverse=True):
+ if o.field.unique and o.field.select_related_descend(
+ restricted, requested, only_load.get(o.model), reverse=True):
next = requested[o.field.related_query_name()]
parent = klass if issubclass(o.model, klass) else None
klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
@@ -1414,7 +1414,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
-def get_cached_row(row, index_start, using, klass_info, offset=0,
+def get_cached_row(row, index_start, using, klass_info, offset=0,
parent_data=()):
"""
Helper function that recursively returns an object with the specified
@@ -1476,7 +1476,7 @@ def get_cached_row(row, index_start, using, klass_info, offset=0,
if f.unique and rel_obj is not None:
# If the field is unique, populate the
# reverse descriptor cache on the related object
- setattr(rel_obj, f.related.get_cache_name(), obj)
+ setattr(rel_obj, f.get_related_cache_name(), obj)
# Now do the same, but for reverse related objects.
# Only handle the restricted case - i.e., don't do a depth
@@ -1489,14 +1489,14 @@ def get_cached_row(row, index_start, using, klass_info, offset=0,
parent_data.append((rel_field, getattr(obj, rel_field.attname)))
# Recursively retrieve the data for the related object
cached_row = get_cached_row(row, index_end, using, klass_info,
- parent_data=parent_data)
+ parent_data=parent_data)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
# populate the reverse descriptor cache
- setattr(obj, f.related.get_cache_name(), rel_obj)
+ setattr(obj, f.get_related_cache_name(), rel_obj)
if rel_obj is not None:
# If the related object exists, populate
# the descriptor cache.
View
36 django/db/models/query_utils.py
@@ -127,42 +127,6 @@ def _check_parent_chain(self, instance, name):
return None
-def select_related_descend(field, restricted, requested, load_fields, reverse=False):
- """
- Returns True if this field should be used to descend deeper for
- select_related() purposes. Used by both the query construction code
- (sql.query.fill_related_selections()) and the model instance creation code
- (query.get_klass_info()).
-
- Arguments:
- * field - the field to be checked
- * restricted - a boolean field, indicating if the field list has been
- manually restricted using a requested clause)
- * requested - The select_related() dictionary.
- * load_fields - the set of fields to be loaded on this model
- * reverse - boolean, True if we are checking a reverse select related
- """
- if not field.rel:
- return False
- if field.rel.parent_link and not reverse:
- return False
- if restricted:
- if reverse and field.related_query_name() not in requested:
- return False
- if not reverse and field.name not in requested:
- return False
- if not restricted and field.null:
- return False
- if load_fields:
- if field.name not in load_fields:
- if restricted and field.name in requested:
- raise InvalidQuery("Field %s.%s cannot be both deferred"
- " and traversed using select_related"
- " at the same time." %
- (field.model._meta.object_name, field.name))
- return False
- return True
-
# This function is needed because data descriptors must be defined on a class
# object, not an instance, to have any effect.
View
12 django/db/models/related.py
@@ -1,6 +1,15 @@
+from collections import namedtuple
+
from django.utils.encoding import smart_text
from django.db.models.fields import BLANK_CHOICE_DASH
+# PathInfo is used when converting lookups (fk__somecol). The contents
+# describe the relation in Model terms (model Options and Fields for both
+# sides of the relation. The join_field is the field backing the relation.
+PathInfo = namedtuple('PathInfo',
+ 'from_field to_field from_opts to_opts join_field '
+ 'm2m direct')
+
class BoundRelatedObject(object):
def __init__(self, related_object, field_mapping, original):
self.relation = related_object
@@ -67,3 +76,6 @@ def get_accessor_name(self):
def get_cache_name(self):
return "_%s_cache" % self.get_accessor_name()
+
+ def get_path_info(self):
+ return self.field.get_path_info(direct=False)
View
113 django/db/models/sql/compiler.py
@@ -4,7 +4,6 @@
from django.db import transaction
from django.db.backends.util import truncate_name
from django.db.models.constants import LOOKUP_SEP
-from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
GET_ITERATOR_CHUNK_SIZE, REUSE_ALL, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
@@ -261,27 +260,17 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
qn2 = self.connection.ops.quote_name
aliases = set()
only_load = self.deferred_to_columns()
-
+ seen = self.query.included_inherited_models.copy()
if start_alias:
- seen = {None: start_alias}
+ seen[None] = start_alias
for field, model in opts.get_fields_with_model():
+ if not hasattr(field, 'column'):
+ # Some sort of virtual field
+ continue
if from_parent and model is not None and issubclass(from_parent, model):
# Avoid loading data for already loaded parents.
continue
- if start_alias:
- try:
- alias = seen[model]
- except KeyError:
- link_field = opts.get_ancestor_link(model)
- alias = self.query.join((start_alias, model._meta.db_table,
- link_field.column, model._meta.pk.column),
- join_field=link_field)
- seen[model] = alias
- else:
- # If we're starting from the base model of the queryset, the
- # aliases will have already been set up in pre_sql_setup(), so
- # we can save time here.
- alias = self.query.included_inherited_models[model]
+ alias = self.query.join_parent_model(opts, model, start_alias, seen)
table = self.query.alias_map[alias].table_name
if table in only_load and field.column not in only_load[table]:
continue
@@ -480,11 +469,13 @@ def _final_join_removal(self, col, alias):
if alias:
while 1:
join = self.query.alias_map[alias]
- if col != join.rhs_join_col:
+ if not join.direct:
+ break
+ elif col != join.join_field.column:
break
self.query.unref_alias(alias)
alias = join.lhs_alias
- col = join.lhs_join_col
+ col = join.join_field.rel.get_related_field().column
return col, alias
def get_from_clause(self):
@@ -507,22 +498,18 @@ def get_from_clause(self):
if not self.query.alias_refcount[alias]:
continue
try:
- name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias]
+ name, alias, join_type, lhs, _, join_field, direct = self.query.alias_map[alias]
except KeyError:
# Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them.
continue
alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first:
- if join_field and hasattr(join_field, 'get_extra_join_sql'):
- extra_cond, extra_params = join_field.get_extra_join_sql(
- self.connection, qn, lhs, alias)
- from_params.extend(extra_params)
- else:
- extra_cond = ""
- result.append('%s %s%s ON (%s.%s = %s.%s%s)' %
- (join_type, qn(name), alias_str, qn(lhs),
- qn2(lhs_col), qn(alias), qn2(col), extra_cond))
+ join_sql, join_params = join_field.get_join_sql(
+ self.connection, qn, lhs, alias, direct)
+ result.append('%s %s%s ON (%s)' %
+ (join_type, qn(name), alias_str, join_sql))
+ from_params.extend(join_params)
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
@@ -618,45 +605,25 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
# in the field's local model. So, for those fields we want to use
# the f.model - that is the field's local model.
field_model = model or f.model
- if not select_related_descend(f, restricted, requested,
- only_load.get(field_model)):
+ if not f.select_related_descend(restricted, requested,
+ only_load.get(field_model)):
continue
- table = f.rel.to._meta.db_table
promote = nullable or f.null
- if model:
- int_opts = opts
- alias = root_alias
- alias_chain = []
- for int_model in opts.get_base_chain(model):
- # Proxy model have elements in base chain
- # with no parents, assign the new options
- # object and skip to the next base in that
- # case
- if not int_opts.parents[int_model]:
- int_opts = int_model._meta
- continue
- lhs_col = int_opts.parents[int_model].column
- int_opts = int_model._meta
- alias = self.query.join((alias, int_opts.db_table, lhs_col,
- int_opts.pk.column),
- promote=promote)
- alias_chain.append(alias)
- else:
- alias = root_alias
+ alias = self.query.join_parent_model(opts, model, root_alias, {})
- alias = self.query.join((alias, table, f.column,
- f.rel.get_related_field().column),
- promote=promote, join_field=f)
+ pathinfos, f_opts, _, _ = f.get_path_info()
+ table = f_opts.db_table
+ alias = self.query.join((alias, table, f, pathinfos[0].direct), promote=promote)
columns, aliases = self.get_default_columns(start_alias=alias,
- opts=f.rel.to._meta, as_pairs=True)
+ opts=f_opts, as_pairs=True)
self.query.related_select_cols.extend(
- SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields))
+ SelectInfo(col, field) for col, field in zip(columns, f_opts.fields))
if restricted:
next = requested.get(f.name, {})
else:
next = False
new_nullable = f.null or promote
- self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
+ self.fill_related_selections(f_opts, alias, cur_depth + 1,
next, restricted, new_nullable)
if restricted:
@@ -666,35 +633,15 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
if o.field.unique
]
for f, model in related_fields:
- if not select_related_descend(f, restricted, requested,
- only_load.get(model), reverse=True):
+ if not f.select_related_descend(restricted, requested,
+ only_load.get(model), reverse=True):
continue
+ alias = self.query.join_parent_model(opts, f.rel.to, root_alias, {})
table = model._meta.db_table
- int_opts = opts
- alias = root_alias
- alias_chain = []
- chain = opts.get_base_chain(f.rel.to)
- if chain is not None:
- for int_model in chain:
- # Proxy model have elements in base chain
- # with no parents, assign the new options
- # object and skip to the next base in that
- # case
- if not int_opts.parents[int_model]:
- int_opts = int_model._meta
- continue
- lhs_col = int_opts.parents[int_model].column
- int_opts = int_model._meta
- alias = self.query.join(
- (alias, int_opts.db_table, lhs_col, int_opts.pk.column),
- promote=True,
- )
- alias_chain.append(alias)
alias = self.query.join(
- (alias, table, f.rel.get_related_field().column, f.column),
- promote=True, join_field=f
- )
+ (alias, table, f, False),
+ promote=True)
from_parent = (opts.model if issubclass(model, opts.model)
else None)
columns, aliases = self.get_default_columns(start_alias=alias,
View
8 django/db/models/sql/constants.py
@@ -24,13 +24,7 @@
# dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo',
'table_name rhs_alias join_type lhs_alias '
- 'lhs_join_col rhs_join_col nullable join_field')
-
-# PathInfo is used when converting lookups (fk__somecol). The contents
-# describe the join in Model terms (model Options and Fields for both
-# sides of the join. The rel_field is the field we are joining along.
-PathInfo = namedtuple('PathInfo',
- 'from_field to_field from_opts to_opts join_field')
+ 'nullable join_field direct')
# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')
View
196 django/db/models/sql/query.py
@@ -18,9 +18,10 @@
from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist
from django.db.models.loading import get_model
+from django.db.models.related import PathInfo
from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
- ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo, PathInfo)
+ ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@@ -211,6 +212,7 @@ def __getstate__(self):
field_id = (model.app_label, model.object_name, join_info.join_field.name)
new_alias_map[alias] = join_info._replace(join_field=field_id)
obj_dict['alias_map'] = new_alias_map
+ del obj_dict['join_map']
return obj_dict
def __setstate__(self, obj_dict):
@@ -232,7 +234,12 @@ def __setstate__(self, obj_dict):
new_alias_map[alias] = join_info._replace(
join_field=get_model(field_id[0], field_id[1])._meta.get_field(field_id[2]))
obj_dict['alias_map'] = new_alias_map
-
+ join_map = obj_dict['join_map'] = {}
+ for alias, join in new_alias_map.items():
+ connection = join.lhs_alias, join.table_name, join.join_field, join.direct
+ if not connection in join_map:
+ join_map[connection] = []
+ join_map[connection].append(alias)
self.__dict__.update(obj_dict)
def prepare(self):
@@ -498,15 +505,14 @@ def combine(self, rhs, connector):
# Now, add the joins from rhs query into the new query (skipping base
# table).
for alias in rhs.tables[1:]:
- table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias]
+ table, _, join_type, lhs, nullable, join_field, direct = rhs.alias_map[alias]
promote = (join_type == self.LOUTER)
# If the left side of the join was already relabeled, use the
# updated alias.
lhs = change_map.get(lhs, lhs)
new_alias = self.join(
- (lhs, table, lhs_col, col), reuse=reuse, promote=promote,
- outer_if_first=not conjunction, nullable=nullable,
- join_field=join_field)
+ (lhs, table, join_field, direct), reuse=reuse, promote=promote,
+ outer_if_first=not conjunction, nullable=nullable)
# We can't reuse the same join again in the query. If we have two
# distinct joins for the same connection in rhs query, then the
# combined query must have two joins, too.
@@ -743,7 +749,7 @@ def promote_joins(self, aliases, unconditional=False):
aliases = list(aliases)
while aliases:
alias = aliases.pop(0)
- if self.alias_map[alias].rhs_join_col is None:
+ if self.alias_map[alias].join_field is None:
# This is the base table (first FROM entry) - this table
# isn't really joined at all in the query, so we should not
# alter its join type.
@@ -892,7 +898,7 @@ def count_active_tables(self):
return len([1 for count in self.alias_refcount.values() if count])
def join(self, connection, reuse=REUSE_ALL, promote=False,
- outer_if_first=False, nullable=False, join_field=None):
+ outer_if_first=False, nullable=False):
"""
Returns an alias for the join in 'connection', either reusing an
existing alias for that join or creating a new one. 'connection' is a
@@ -924,7 +930,7 @@ def join(self, connection, reuse=REUSE_ALL, promote=False,
The 'join_field' is the field we are joining along (if any).
"""
- lhs, table, lhs_col, col = connection
+ lhs, table, join_field, direct = connection
existing = self.join_map.get(connection, ())
if reuse == REUSE_ALL:
reuse = existing
@@ -933,12 +939,6 @@ def join(self, connection, reuse=REUSE_ALL, promote=False,
else:
reuse = [a for a in existing if a in reuse]
for alias in reuse:
- if join_field and self.alias_map[alias].join_field != join_field:
- # The join_map doesn't contain join_field (mainly because
- # fields in Query structs are problematic in pickling), so
- # check that the existing join is created using the same
- # join_field used for the under work join.
- continue
self.ref_alias(alias)
if promote or (lhs and self.alias_map[lhs].join_type == self.LOUTER):
self.promote_joins([alias])
@@ -957,8 +957,7 @@ def join(self, connection, reuse=REUSE_ALL, promote=False,
join_type = self.LOUTER
else:
join_type = self.INNER
- join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable,
- join_field)
+ join = JoinInfo(table, alias, join_type, lhs, nullable, join_field, direct)
self.alias_map[alias] = join
if connection in self.join_map:
self.join_map[connection] += (alias,)
@@ -985,11 +984,38 @@ def setup_inherited_models(self):
for field, model in opts.get_fields_with_model():
if model not in seen:
- link_field = opts.get_ancestor_link(model)
- seen[model] = self.join((root_alias, model._meta.db_table,
- link_field.column, model._meta.pk.column))
+ self.join_parent_model(opts, model, root_alias, seen)
self.included_inherited_models = seen
+ def join_parent_model(self, opts, model, alias, seen):
+ """
+ Makes sure the given 'model' is joined in the query. The 'model' needs
+ to be a parent of opts' model.
+
+ The 'alias' is the root alias for starting the join, 'seen' is a dict
+ of model -> alias of existing joins.
+ """
+ if model in seen:
+ return seen[model]
+ int_opts = opts
+ chain = opts.get_base_chain(model)
+ if chain is None:
+ return alias
+ for int_model in chain:
+ if int_model in seen:
+ return seen[int_model]
+ # Proxy model have elements in base chain
+ # with no parents, assign the new options
+ # object and skip to the next base in that
+ # case
+ if not int_opts.parents[int_model]:
+ int_opts = int_model._meta
+ continue
+ link_field = int_opts.get_ancestor_link(int_model)
+ int_opts = int_model._meta
+ alias = seen[int_model] = self.join((alias, int_opts.db_table, link_field, True))
+ return alias
+
def remove_inherited_models(self):
"""
Undoes the effects of setup_inherited_models(). Should be called
@@ -1218,7 +1244,11 @@ def add_filter(self, filter_expr, connector=AND, negate=False,
if len(join_list) > 1:
for alias in join_list:
if self.alias_map[alias].join_type == self.LOUTER:
- j_col = self.alias_map[alias].rhs_join_col
+ join = self.alias_map[alias]
+ if join.direct:
+ j_col = join.join_field.rel.get_related_field().column
+ else:
+ j_col = join.join_field.column
# The join promotion logic should never produce
# a LOUTER join for the base join - assert that.
assert j_col is not None
@@ -1309,7 +1339,6 @@ def names_to_path(self, names, opts, allow_many=False,
contain the same value as the final field).
"""
path = []
- multijoin_pos = None
for pos, name in enumerate(names):
if name == 'pk':
name = opts.pk.name
@@ -1343,92 +1372,19 @@ def names_to_path(self, names, opts, allow_many=False,
target = final_field.rel.get_related_field()
opts = int_model._meta
path.append(PathInfo(final_field, target, final_field.model._meta,
- opts, final_field))
- # We have five different cases to solve: foreign keys, reverse
- # foreign keys, m2m fields (also reverse) and non-relational
- # fields. We are mostly just using the related field API to
- # fetch the from and to fields. The m2m fields are handled as
- # two foreign keys, first one reverse, the second one direct.
- if direct and not field.rel and not m2m:
+ opts, final_field, False, True))
+ if hasattr(field, 'get_path_info'):
+ pathinfos, opts, target, final_field = field.get_path_info()
+ path.extend(pathinfos)
+ else:
# Local non-relational field.
final_field = target = field
break
- elif direct and not m2m:
- # Foreign Key
- opts = field.rel.to._meta
- target = field.rel.get_related_field()
- final_field = field
- from_opts = field.model._meta
- path.append(PathInfo(field, target, from_opts, opts, field))
- elif not direct and not m2m:
- # Revere foreign key
- final_field = to_field = field.field
- opts = to_field.model._meta
- from_field = to_field.rel.get_related_field()
- from_opts = from_field.model._meta
- path.append(
- PathInfo(from_field, to_field, from_opts, opts, to_field))
- if from_field.model is to_field.model:
- # Recursive foreign key to self.
- target = opts.get_field_by_name(
- field.field.rel.field_name)[0]
- else:
- target = opts.pk
- elif direct and m2m:
- if not field.rel.through:
- # Gotcha! This is just a fake m2m field - a generic relation
- # field).
- from_field = opts.pk
- opts = field.rel.to._meta
- target = opts.get_field_by_name(field.object_id_field_name)[0]
- final_field = field
- # Note that we are using different field for the join_field
- # than from_field or to_field. This is a hack, but we need the
- # GenericRelation to generate the extra SQL.
- path.append(PathInfo(from_field, target, field.model._meta, opts,
- field))
- else:
- # m2m field. We are travelling first to the m2m table along a
- # reverse relation, then from m2m table to the target table.
- from_field1 = opts.get_field_by_name(
- field.m2m_target_field_name())[0]
- opts = field.rel.through._meta
- to_field1 = opts.get_field_by_name(field.m2m_field_name())[0]
- path.append(
- PathInfo(from_field1, to_field1, from_field1.model._meta,
- opts, to_field1))
- final_field = from_field2 = opts.get_field_by_name(
- field.m2m_reverse_field_name())[0]
- opts = field.rel.to._meta
- target = to_field2 = opts.get_field_by_name(
- field.m2m_reverse_target_field_name())[0]
- path.append(
- PathInfo(from_field2, to_field2, from_field2.model._meta,
- opts, from_field2))
- elif not direct and m2m:
- # This one is just like above, except we are travelling the
- # fields in opposite direction.
- field = field.field
- from_field1 = opts.get_field_by_name(
- field.m2m_reverse_target_field_name())[0]
- int_opts = field.rel.through._meta
- to_field1 = int_opts.get_field_by_name(
- field.m2m_reverse_field_name())[0]
- path.append(
- PathInfo(from_field1, to_field1, from_field1.model._meta,
- int_opts, to_field1))
- final_field = from_field2 = int_opts.get_field_by_name(
- field.m2m_field_name())[0]
- opts = field.opts
- target = to_field2 = opts.get_field_by_name(
- field.m2m_target_field_name())[0]
- path.append(PathInfo(from_field2, to_field2, from_field2.model._meta,
- opts, from_field2))
-
- if m2m and multijoin_pos is None:
- multijoin_pos = pos
- if not direct and not path[-1].to_field.unique and multijoin_pos is None:
- multijoin_pos = pos
+ multijoin_pos = None
+ for m2mpos, pathinfo in enumerate(path):
+ if pathinfo.m2m:
+ multijoin_pos = m2mpos
+ break
if pos != len(names) - 1:
if pos == len(names) - 2:
@@ -1478,15 +1434,13 @@ def setup_joins(self, names, opts, alias, can_reuse, allow_many=True,
# joins at this stage - we will need the information about join type
# of the trimmed joins.
for pos, join in enumerate(path):
- from_field, to_field, from_opts, opts, join_field = join
- direct = join_field == from_field
+ from_field, to_field, from_opts, opts, join_field, _, direct = join
if direct:
nullable = self.is_nullable(from_field)
else:
nullable = True
- connection = alias, opts.db_table, from_field.column, to_field.column
- alias = self.join(connection, reuse=can_reuse, nullable=nullable,
- join_field=join_field)
+ connection = alias, opts.db_table, join_field, direct
+ alias = self.join(connection, reuse=can_reuse, nullable=nullable)
joins.append(alias)
return final_field, target, opts, joins, path
@@ -1505,7 +1459,7 @@ def trim_joins(self, target, joins, path):
the join.
"""
for info in reversed(path):
- direct = info.join_field == info.from_field
+ direct = info.direct
if info.to_field == target and direct:
target = info.from_field
self.unref_alias(joins.pop())
@@ -1642,18 +1596,11 @@ def add_fields(self, field_names, allow_m2m=True):
try:
for name in field_names:
- field, target, u2, joins, u3 = self.setup_joins(
+ field, target, u2, joins, path = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, REUSE_ALL, allow_m2m,
True)
+ col, _, joins = self.trim_joins(target, joins, path)
final_alias = joins[-1]
- col = target.column
- if len(joins) > 1:
- join = self.alias_map[final_alias]
- if col == join.rhs_join_col:
- self.unref_alias(final_alias)
- final_alias = join.lhs_alias
- col = join.lhs_join_col
- joins = joins[:-1]
self.promote_joins(joins[1:])
self.select.append(SelectInfo((final_alias, col), field))
except MultiJoin:
@@ -1935,7 +1882,7 @@ def set_start(self, start):
alias = self.get_initial_alias()
field, col, opts, joins, extra = self.setup_joins(
start.split(LOOKUP_SEP), opts, alias, REUSE_ALL)
- select_col = self.alias_map[joins[1]].lhs_join_col
+ select_field = self.alias_map[joins[1]].join_field
select_alias = alias
# The call to setup_joins added an extra reference to everything in
@@ -1948,12 +1895,13 @@ def set_start(self, start):
# is *always* the same value as lhs).
for alias in joins[1:]:
join_info = self.alias_map[alias]
- if (join_info.lhs_join_col != select_col
- or join_info.join_type != self.INNER):
+ if (join_info.join_field != select_field
+ or not join_info.direct):
break
self.unref_alias(select_alias)
select_alias = join_info.rhs_alias
- select_col = join_info.rhs_join_col
+ select_field = join_info.join_field
+ select_col = select_field.rel.get_related_field().column
self.select = [SelectInfo((select_alias, select_col), None)]
self.remove_inherited_models()
View
0  tests/modeltests/multicolumn_joins/__init__.py
No changes.
View
140 tests/modeltests/multicolumn_joins/models.py
@@ -0,0 +1,140 @@
+"""
+Note that these tests rely heavily on non-public APIs and are really hacky in
+nature. So, if these are broken by changes it doesn't mean the change isn't
+valid - fixing these tests is also an option...
+
+What we are interested in is some way to generate multicolumn joins using the
+ORM, not the specific way to do it.
+"""
+from django.db import models
+from django.db.models.options import Options
+from django.db.models.related import PathInfo
+from django.utils.encoding import python_2_unicode_compatible
+from django.utils.translation import get_language
+
+class FakeFKField(object):
+ null = False
+
+ @classmethod
+ def get_path_info(cls, direct=True):
+ if direct:
+ opts = Article._meta
+ target = Article._meta.get_field_by_name('id')[0]
+ from_opts = ArticleComment._meta
+ else:
+ opts = ArticleComment._meta
+ target = ArticleComment._meta.pk
+ from_opts = Article._meta
+ return [PathInfo(cls, target, from_opts, opts, cls, False, True)], opts, target, cls
+
+ @classmethod
+ def get_join_sql(cls, connection, qn, lhs_alias, rhs_alias, direct):
+ # Note - we don't care about direct here - we have exactly the same
+ # column names on both sides...
+ lhs, rhs = qn(lhs_alias), qn(rhs_alias)
+ headline, pub_date = qn('headline'), qn('pub_date')
+
+ return ('%s.%s = %s.%s AND %s.%s = %s.%s' %
+ (lhs, headline, rhs, headline, lhs, pub_date, rhs, pub_date)), []
+
+class TranslationField(object):
+ null = True
+ name = 'translation'
+ unique = True
+
+ @classmethod
+ def get_path_info(cls):
+ opts = ArticleTranslation._meta
+ target = ArticleTranslation._meta.pk
+ from_opts = Article._meta
+ return [PathInfo(cls, target, from_opts, opts, cls, False, False)], opts, target, target
+
+ @classmethod
+ def get_join_sql(cls, connection, qn, lhs_alias, rhs_alias, direct):
+ lhs, rhs = qn(lhs_alias), qn(rhs_alias)
+ lang = qn('lang')
+ article_id = qn('article_id')
+ id = qn('id')
+ return ('%s.%s = %s.%s AND %s.%s = %%s' %
+ (lhs, id, rhs, article_id, rhs, lang)), [get_language()]
+
+ @classmethod
+ def select_related_descend(cls, restricted, requested, only_load):
+ if requested and 'translation' in requested:
+ return True
+ return restricted
+
+ @classmethod
+ def get_cache_name(cls):
+ return 'translation'
+
+ @classmethod
+ def get_related_cache_name(cls):
+ return 'article'
+
+class WrappedMeta(object):
+ def __init__(self, meta):
+ self.__dict__['meta'] = meta
+
+ def __getattr__(self, attr):
+ return getattr(self.meta, attr)
+
+ def __setattr__(self, attr, value):
+ setattr(self.meta, attr, value)
+
+class FakeRelObject(object):
+
+ @classmethod
+ def get_path_info(self):
+ return FakeFKField.get_path_info(direct=False)
+
+class ArticleMeta(Options):
+
+ def get_field_by_name(self, name):
+ if name == 'comments':
+ return FakeRelObject, None, False, True
+ elif name == 'translation':
+ return TranslationField, None, False, False
+ else:
+ return super(ArticleMeta, self).get_field_by_name(name)
+
+ def get_fields_with_model(self):
+ return super(ArticleMeta, self).get_fields_with_model() + ((TranslationField, None),)
+
+@python_2_unicode_compatible
+class Article(models.Model):
+ headline = models.CharField(max_length=100, default='Default headline')
+ pub_date = models.DateTimeField()
+ data = models.TextField()
+
+ class Meta:
+ ordering = ('pub_date', 'headline')
+ unique_together = (('headline', 'pub_date'),)
+
+ def __str__(self):
+ return self.headline
+Article._meta.__class__ = ArticleMeta
+
+
+class ArticleCommentMeta(Options):
+ def get_field_by_name(self, name):
+ if name == 'article':
+ return FakeFKField, None, True, False
+ else:
+ return super(ArticleCommentMeta, self).get_field_by_name(name)
+
+class ArticleComment(models.Model):
+ headline = models.CharField(max_length=100)
+ pub_date = models.DateTimeField()
+ comment = models.TextField()
+ArticleComment._meta.__class__ = ArticleCommentMeta
+
+class ArticleTranslation(models.Model):
+ article = models.ForeignKey(Article, related_name='translations')
+ lang = models.CharField(max_length=2)
+ title = models.CharField(max_length=100)
+
+ class Meta:
+ unique_together = (('article', 'lang'),)
+
+TranslationField.model = ArticleTranslation
View
81 tests/modeltests/multicolumn_joins/tests.py
@@ -0,0 +1,81 @@
+from __future__ import absolute_import, unicode_literals
+
+from datetime import datetime
+
+from django.db.models import Q, Count
+from django.test import TestCase
+from django.utils.translation import activate, deactivate
+
+from .models import Article, ArticleComment, ArticleTranslation
+
+
+class ModelTest(TestCase):
+ def setUp(self):
+ self.a1 = Article.objects.create(headline='h1', pub_date=datetime.now(), data='d1')
+ self.a2 = Article.objects.create(headline='h2', pub_date=datetime.now(), data='d2')
+ self.c1 = ArticleComment.objects.create(headline='h1', pub_date=self.a1.pub_date, comment='c1')
+ self.c2 = ArticleComment.objects.create(headline='h1', pub_date=self.a1.pub_date, comment='c2')
+ self.c3 = ArticleComment.objects.create(headline='h2', pub_date=self.a2.pub_date, comment='c3')
+
+ def test_basic_join(self):
+ self.assertEqual(ArticleComment.objects.filter(article__data='d1').count(), 2)
+ self.assertEqual(ArticleComment.objects.filter(article__data='d2').count(), 1)
+
+ def test_basic_revjoin(self):
+ self.assertEqual(Article.objects.filter(comments__comment='c1').count(), 1)
+ self.assertEqual(Article.objects.filter(comments__comment__in=['c1', 'c3']).count(), 2)
+ self.assertEqual(Article.objects.filter(comments__comment='c3').count(), 1)
+
+ def test_translation_join(self):
+ try:
+ ArticleTranslation.objects.create(article=self.a1, lang='fi', title='Otsikko')
+ ArticleTranslation.objects.create(article=self.a1, lang='en', title='Title')
+ activate('fi')
+ self.assertEqual(Article.objects.filter(translation__title='Otsikko').count(), 1)
+ self.assertEqual(Article.objects.filter(translation__title='Title').count(), 0)
+ activate('en')
+ self.assertEqual(Article.objects.filter(translation__title='Otsikko').count(), 0)
+ self.assertEqual(Article.objects.filter(translation__title='Title').count(), 1)
+ # isnull works too (that is, we have proper left joins!)
+ self.assertEqual(Article.objects.filter(
+ Q(translation__title='Title') | Q(translation__isnull=True)).count(), 2)
+ self.assertEqual(Article.objects.filter(
+ Q(translation__title='Otsikko') | Q(translation__isnull=True)).count(), 1)
+ self.assertEqual(Article.objects.aggregate(Count('translation'))['translation__count'], 1)
+ finally:
+ deactivate()
+
+ def test_select_related(self):
+ ArticleTranslation.objects.create(article=self.a1, lang='fi', title='Otsikko')
+ ArticleTranslation.objects.create(article=self.a1, lang='en', title='Title')
+ qs = Article.objects.select_related('translation').order_by('headline')
+ try:
+ with self.assertNumQueries(1):
+ activate('fi')
+ objs = list(qs)
+ self.assertEqual(objs[0].translation.title, 'Otsikko')
+ # Caching works
+ self.assertIs(objs[0].translation.article, objs[0])
+ self.assertIs(objs[1].translation, None)
+ with self.assertNumQueries(1):
+ activate('en')
+ # .all() to get rid of the cache...
+ objs = list(qs.all())
+ self.assertEqual(objs[0].translation.title, 'Title')
+ self.assertEqual(objs[1].translation, None)
+ ArticleTranslation.objects.create(article=self.a2, lang='en', title='Title 2')
+ # Some ordering stuff still...
+ with self.assertNumQueries(1):
+ qs = Article.objects.select_related('translation').order_by('translation__title')
+ self.assertTrue(str(qs.query).count('JOIN'), 1)
+ objs = list(qs)
+ self.assertEqual(objs[0].translation.title, 'Title')
+ self.assertEqual(objs[1].translation.title, 'Title 2')
+ with self.assertNumQueries(1):
+ qs = Article.objects.select_related('translation').order_by('-translation__title')
+ self.assertTrue(str(qs.query).count('JOIN'), 1)
+ objs = list(qs)
+ self.assertEqual(objs[0].translation.title, 'Title 2')
+ self.assertEqual(objs[1].translation.title, 'Title')
+ finally:
+ deactivate()

No commit comments for this range

Something went wrong with that request. Please try again.