Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Fixed #19385 again, now with real code changes

The commit of 266de5f included only
tests, this time also code changes included...
  • Loading branch information...
commit 97774429aeb54df4c09895c07cd1b09e70201f7d 1 parent 266de5f
@akaariai akaariai authored
View
102 django/contrib/contenttypes/generic.py
@@ -8,10 +8,11 @@
from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
-from django.db.models import signals
from django.db import models, router, DEFAULT_DB_ALIAS
-from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
+from django.db.models import signals
+from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.related import PathInfo
+from django.db.models.sql.where import Constraint
from django.forms import ModelForm
from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance
from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets
@@ -149,17 +150,14 @@ def __set__(self, instance, value):
setattr(instance, self.fk_field, fk)
setattr(instance, self.cache_attr, value)
-class GenericRelation(RelatedField, Field):
+class GenericRelation(ForeignObject):
"""Provides an accessor to generic related objects (e.g. comments)"""
def __init__(self, to, **kwargs):
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
- kwargs['rel'] = GenericRel(to,
- related_name=kwargs.pop('related_name', None),
- limit_choices_to=kwargs.pop('limit_choices_to', None),
- symmetrical=kwargs.pop('symmetrical', True))
-
-
+ kwargs['rel'] = GenericRel(
+ self, to, related_name=kwargs.pop('related_name', None),
+ limit_choices_to=kwargs.pop('limit_choices_to', None),)
# Override content-type/object-id field names on the related class
self.object_id_field_name = kwargs.pop("object_id_field", "object_id")
self.content_type_field_name = kwargs.pop("content_type_field", "content_type")
@@ -167,47 +165,44 @@ def __init__(self, to, **kwargs):
kwargs['blank'] = True
kwargs['editable'] = False
kwargs['serialize'] = False
- Field.__init__(self, **kwargs)
-
- def get_path_info(self):
- from_field = self.model._meta.pk
+ # This construct is somewhat of an abuse of ForeignObject. This field
+ # represents a relation from pk to object_id field. But, this relation
+ # isn't direct, the join is generated reverse along foreign key. So,
+ # the from_field is object_id field, to_field is pk because of the
+ # reverse join.
+ super(GenericRelation, self).__init__(
+ to, to_fields=[],
+ from_fields=[self.object_id_field_name], **kwargs)
+
+ def resolve_related_fields(self):
+ self.to_fields = [self.model._meta.pk.name]
+ return [(self.rel.to._meta.get_field_by_name(self.object_id_field_name)[0],
+ self.model._meta.pk)]
+
+ def get_reverse_path_info(self):
opts = self.rel.to._meta
target = opts.get_field_by_name(self.object_id_field_name)[0]
- # Note that we are using different field for the join_field
- # than from_field or to_field. This is a hack, but we need the
- # GenericRelation to generate the extra SQL.
- return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)],
- opts, target, self)
+ return [PathInfo(self.model._meta, opts, (target,), self.rel, True, False)]
def get_choices_default(self):
- return Field.get_choices(self, include_blank=False)
+ return super(GenericRelation, self).get_choices(include_blank=False)
def value_to_string(self, obj):
qs = getattr(obj, self.name).all()
return smart_text([instance._get_pk_val() for instance in qs])
- def m2m_db_table(self):
- return self.rel.to._meta.db_table
-
- def m2m_column_name(self):
- return self.object_id_field_name
-
- def m2m_reverse_name(self):
- return self.rel.to._meta.pk.column
-
- def m2m_target_field_name(self):
- return self.model._meta.pk.name
-
- def m2m_reverse_target_field_name(self):
- return self.rel.to._meta.pk.name
+ def get_joining_columns(self, reverse_join=False):
+ if not reverse_join:
+ # This error message is meant for the user, and from user
+ # perspective this is a reverse join along the GenericRelation.
+ raise ValueError('Joining in reverse direction not allowed.')
+ return super(GenericRelation, self).get_joining_columns(reverse_join)
def contribute_to_class(self, cls, name):
- super(GenericRelation, self).contribute_to_class(cls, name)
-
+ super(GenericRelation, self).contribute_to_class(cls, name, virtual_only=True)
# Save a reference to which model this class is on for future use
self.model = cls
-
- # Add the descriptor for the m2m relation
+ # Add the descriptor for the relation
setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self))
def contribute_to_related_class(self, cls, related):
@@ -219,21 +214,18 @@ def set_attributes_from_rel(self):
def get_internal_type(self):
return "ManyToManyField"
- def db_type(self, connection):
- # Since we're simulating a ManyToManyField, in effect, best return the
- # same db_type as well.
- return None
-
def get_content_type(self):
"""
Returns the content type associated with this field's model.
"""
return ContentType.objects.get_for_model(self.model)
- def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
- extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column
- contenttype = self.get_content_type().pk
- return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype]
+ def get_extra_restriction(self, where_class, alias, remote_alias):
+ field = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0]
+ contenttype_pk = self.get_content_type().pk
+ cond = where_class()
+ cond.add((Constraint(remote_alias, field.column, field), 'exact', contenttype_pk), 'AND')
+ return cond
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
"""
@@ -273,12 +265,12 @@ def __get__(self, instance, instance_type=None):
qn = connection.ops.quote_name
content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(instance)
+ join_cols = self.field.get_joining_columns(reverse_join=True)[0]
manager = RelatedManager(
model = rel_model,
instance = instance,
- symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model),
- source_col_name = qn(self.field.m2m_column_name()),
- target_col_name = qn(self.field.m2m_reverse_name()),
+ source_col_name = qn(join_cols[0]),
+ target_col_name = qn(join_cols[1]),
content_type = content_type,
content_type_field_name = self.field.content_type_field_name,
object_id_field_name = self.field.object_id_field_name,
@@ -378,14 +370,10 @@ def create(self, **kwargs):
return GenericRelatedObjectManager
-class GenericRel(ManyToManyRel):
- def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True):
- self.to = to
- self.related_name = related_name
- self.limit_choices_to = limit_choices_to or {}
- self.symmetrical = symmetrical
- self.multiple = True
- self.through = None
+class GenericRel(ForeignObjectRel):
+
+ def __init__(self, field, to, related_name=None, limit_choices_to=None):
+ super(GenericRel, self).__init__(field, to, related_name, limit_choices_to)
class BaseGenericInlineFormSet(BaseModelFormSet):
"""
View
12 django/core/management/validation.py
@@ -153,8 +153,16 @@ def get_validation_errors(outfile, app=None):
continue
# Make sure the related field specified by a ForeignKey is unique
- if not f.rel.to._meta.get_field(f.rel.field_name).unique:
- e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.rel.field_name, f.rel.to.__name__))
+ if f.requires_unique_target:
+ if len(f.foreign_related_fields) > 1:
+ has_unique_field = False
+ for rel_field in f.foreign_related_fields:
+ has_unique_field = has_unique_field or rel_field.unique
+ if not has_unique_field:
+ e.add(opts, "Field combination '%s' under model '%s' must have a unique=True constraint" % (','.join([rel_field.name for rel_field in f.foreign_related_fields]), f.rel.to.__name__))
+ else:
+ if not f.foreign_related_fields[0].unique:
+ e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.foreign_related_fields[0].name, f.rel.to.__name__))
rel_opts = f.rel.to._meta
rel_name = f.related.get_accessor_name()
View
6 django/db/backends/mysql/compiler.py
@@ -17,6 +17,12 @@ def resolve_columns(self, row, fields=()):
values.append(value)
return row[:index_extra_select] + tuple(values)
+ def as_subquery_condition(self, alias, columns):
+ qn = self.quote_name_unless_alias
+ qn2 = self.connection.ops.quote_name
+ sql, params = self.as_sql()
+ return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params
+
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass
View
2  django/db/models/__init__.py
@@ -8,7 +8,7 @@
from django.db.models.fields import *
from django.db.models.fields.subclassing import SubfieldBase
from django.db.models.fields.files import FileField, ImageField
-from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
+from django.db.models.fields.related import ForeignKey, ForeignObject, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError
from django.db.models import signals
from django.utils.decorators import wraps
View
19 django/db/models/base.py
@@ -10,7 +10,7 @@
from django.core.exceptions import (ObjectDoesNotExist,
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
from django.db.models.fields import AutoField, FieldDoesNotExist
-from django.db.models.fields.related import (ManyToOneRel,
+from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
OneToOneField, add_lazy_relation)
from django.db import (router, transaction, DatabaseError,
DEFAULT_DB_ALIAS)
@@ -333,12 +333,12 @@ def __init__(self, *args, **kwargs):
# The reason for the kwargs check is that standard iterator passes in by
# args, and instantiation for iteration is 33% faster.
args_len = len(args)
- if args_len > len(self._meta.fields):
+ if args_len > len(self._meta.concrete_fields):
# Daft, but matches old exception sans the err msg.
raise IndexError("Number of args exceeds number of fields")
- fields_iter = iter(self._meta.fields)
if not kwargs:
+ fields_iter = iter(self._meta.concrete_fields)
# The ordering of the zip calls matter - zip throws StopIteration
# when an iter throws it. So if the first iter throws it, the second
# is *not* consumed. We rely on this, so don't change the order
@@ -347,6 +347,7 @@ def __init__(self, *args, **kwargs):
setattr(self, field.attname, val)
else:
# Slower, kwargs-ready version.
+ fields_iter = iter(self._meta.fields)
for val, field in zip(args, fields_iter):
setattr(self, field.attname, val)
kwargs.pop(field.name, None)
@@ -363,11 +364,12 @@ def __init__(self, *args, **kwargs):
# data-descriptor object (DeferredAttribute) without triggering its
# __get__ method.
if (field.attname not in kwargs and
- isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)):
+ (isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)
+ or field.column is None)):
# This field will be populated on request.
continue
if kwargs:
- if isinstance(field.rel, ManyToOneRel):
+ if isinstance(field.rel, ForeignObjectRel):
try:
# Assume object instance was passed in.
rel_obj = kwargs.pop(field.name)
@@ -394,6 +396,7 @@ def __init__(self, *args, **kwargs):
val = field.get_default()
else:
val = field.get_default()
+
if is_related_object:
# If we are passed a related instance, set it using the
# field.name instead of field.attname (e.g. "user" instead of
@@ -528,7 +531,7 @@ def save(self, force_insert=False, force_update=False, using=None,
# automatically do a "update_fields" save on the loaded fields.
elif not force_insert and self._deferred and using == self._state.db:
field_names = set()
- for field in self._meta.fields:
+ for field in self._meta.concrete_fields:
if not field.primary_key and not hasattr(field, 'through'):
field_names.add(field.attname)
deferred_fields = [
@@ -614,7 +617,7 @@ def _save_table(self, raw=False, cls=None, force_insert=False,
for a single table.
"""
meta = cls._meta
- non_pks = [f for f in meta.local_fields if not f.primary_key]
+ non_pks = [f for f in meta.local_concrete_fields if not f.primary_key]
if update_fields:
non_pks = [f for f in non_pks
@@ -652,7 +655,7 @@ def _save_table(self, raw=False, cls=None, force_insert=False,
**{field.name: getattr(self, field.attname)}).count()
self._order = order_value
- fields = meta.local_fields
+ fields = meta.local_concrete_fields
if not pk_set:
fields = [f for f in fields if not isinstance(f, AutoField)]
View
15 django/db/models/deletion.py
@@ -1,4 +1,3 @@
-from functools import wraps
from operator import attrgetter
from django.db import connections, transaction, IntegrityError
@@ -196,17 +195,13 @@ def collect(self, objs, source=None, nullable=False, collect_related=True,
self.fast_deletes.append(sub_objs)
elif sub_objs:
field.rel.on_delete(self, field, sub_objs, self.using)
-
- # TODO This entire block is only needed as a special case to
- # support cascade-deletes for GenericRelation. It should be
- # removed/fixed when the ORM gains a proper abstraction for virtual
- # or composite fields, and GFKs are reworked to fit into that.
- for relation in model._meta.many_to_many:
- if not relation.rel.through:
- sub_objs = relation.bulk_related_objects(new_objs, self.using)
+ for field in model._meta.virtual_fields:
+ if hasattr(field, 'bulk_related_objects'):
+ # Its something like generic foreign key.
+ sub_objs = field.bulk_related_objects(new_objs, self.using)
self.collect(sub_objs,
source=model,
- source_attr=relation.rel.related_name,
+ source_attr=field.rel.related_name,
nullable=True)
def related_objects(self, related, objs):
View
7 django/db/models/fields/__init__.py
@@ -292,10 +292,13 @@ def set_attributes_from_name(self, name):
if self.verbose_name is None and self.name:
self.verbose_name = self.name.replace('_', ' ')
- def contribute_to_class(self, cls, name):
+ def contribute_to_class(self, cls, name, virtual_only=False):
self.set_attributes_from_name(name)
self.model = cls
- cls._meta.add_field(self)
+ if virtual_only:
+ cls._meta.add_virtual_field(self)
+ else:
+ cls._meta.add_field(self)
if self.choices:
setattr(cls, 'get_%s_display' % self.name,
curry(cls._get_FIELD_display, field=self))
View
532 django/db/models/fields/related.py
@@ -7,7 +7,6 @@
PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
from django.db.models.related import RelatedObject, PathInfo
from django.db.models.query import QuerySet
-from django.db.models.query_utils import QueryWrapper
from django.db.models.deletion import CASCADE
from django.utils.encoding import smart_text
from django.utils import six
@@ -93,22 +92,27 @@ def do_pending_lookups(sender, **kwargs):
#HACK
-class RelatedField(object):
- def contribute_to_class(self, cls, name):
+class RelatedField(Field):
+ def db_type(self, connection):
+ '''By default related field will not have a column
+ as it relates columns to another table'''
+ return None
+
+ def contribute_to_class(self, cls, name, virtual_only=False):
sup = super(RelatedField, self)
# Store the opts for related_query_name()
self.opts = cls._meta
if hasattr(sup, 'contribute_to_class'):
- sup.contribute_to_class(cls, name)
+ sup.contribute_to_class(cls, name, virtual_only=virtual_only)
if not cls._meta.abstract and self.rel.related_name:
- self.rel.related_name = self.rel.related_name % {
- 'class': cls.__name__.lower(),
- 'app_label': cls._meta.app_label.lower(),
- }
-
+ related_name = self.rel.related_name % {
+ 'class': cls.__name__.lower(),
+ 'app_label': cls._meta.app_label.lower()
+ }
+ self.rel.related_name = related_name
other = self.rel.to
if isinstance(other, six.string_types) or other._meta.pk is None:
def resolve_related_class(field, model, cls):
@@ -122,7 +126,6 @@ def set_attributes_from_rel(self):
self.name = self.name or (self.rel.to._meta.model_name + '_' + self.rel.to._meta.pk.name)
if self.verbose_name is None:
self.verbose_name = self.rel.to._meta.verbose_name
- self.rel.field_name = self.rel.field_name or self.rel.to._meta.pk.name
def do_related_class(self, other, cls):
self.set_attributes_from_rel()
@@ -130,94 +133,6 @@ def do_related_class(self, other, cls):
if not cls._meta.abstract:
self.contribute_to_related_class(other, self.related)
- def get_prep_lookup(self, lookup_type, value):
- if hasattr(value, 'prepare'):
- return value.prepare()
- if hasattr(value, '_prepare'):
- return value._prepare()
- # FIXME: lt and gt are explicitly allowed to make
- # get_(next/prev)_by_date work; other lookups are not allowed since that
- # gets messy pretty quick. This is a good candidate for some refactoring
- # in the future.
- if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']:
- return self._pk_trace(value, 'get_prep_lookup', lookup_type)
- if lookup_type in ('range', 'in'):
- return [self._pk_trace(v, 'get_prep_lookup', lookup_type) for v in value]
- elif lookup_type == 'isnull':
- return []
- raise TypeError("Related Field has invalid lookup: %s" % lookup_type)
-
- def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
- if not prepared:
- value = self.get_prep_lookup(lookup_type, value)
- if hasattr(value, 'get_compiler'):
- value = value.get_compiler(connection=connection)
- if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
- # If the value has a relabeled_clone method it means the
- # value will be handled later on.
- if hasattr(value, 'relabeled_clone'):
- return value
- if hasattr(value, 'as_sql'):
- sql, params = value.as_sql()
- else:
- sql, params = value._as_sql(connection=connection)
- return QueryWrapper(('(%s)' % sql), params)
-
- # FIXME: lt and gt are explicitly allowed to make
- # get_(next/prev)_by_date work; other lookups are not allowed since that
- # gets messy pretty quick. This is a good candidate for some refactoring
- # in the future.
- if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']:
- return [self._pk_trace(value, 'get_db_prep_lookup', lookup_type,
- connection=connection, prepared=prepared)]
- if lookup_type in ('range', 'in'):
- return [self._pk_trace(v, 'get_db_prep_lookup', lookup_type,
- connection=connection, prepared=prepared)
- for v in value]
- elif lookup_type == 'isnull':
- return []
- raise TypeError("Related Field has invalid lookup: %s" % lookup_type)
-
- def _pk_trace(self, value, prep_func, lookup_type, **kwargs):
- # Value may be a primary key, or an object held in a relation.
- # If it is an object, then we need to get the primary key value for
- # that object. In certain conditions (especially one-to-one relations),
- # the primary key may itself be an object - so we need to keep drilling
- # down until we hit a value that can be used for a comparison.
- v = value
-
- # In the case of an FK to 'self', this check allows to_field to be used
- # for both forwards and reverse lookups across the FK. (For normal FKs,
- # it's only relevant for forward lookups).
- if isinstance(v, self.rel.to):
- field_name = getattr(self.rel, "field_name", None)
- else:
- field_name = None
- try:
- while True:
- if field_name is None:
- field_name = v._meta.pk.name
- v = getattr(v, field_name)
- field_name = None
- except AttributeError:
- pass
- except exceptions.ObjectDoesNotExist:
- v = None
-
- field = self
- while field.rel:
- if hasattr(field.rel, 'field_name'):
- field = field.rel.to._meta.get_field(field.rel.field_name)
- else:
- field = field.rel.to._meta.pk
-
- if lookup_type in ('range', 'in'):
- v = [v]
- v = getattr(field, prep_func)(lookup_type, v, **kwargs)
- if isinstance(v, list):
- v = v[0]
- return v
-
def related_query_name(self):
# This method defines the name that can be used to identify this
# related object in a table-spanning query. It uses the lower-cased
@@ -254,8 +169,8 @@ def get_prefetch_queryset(self, instances):
rel_obj_attr = attrgetter(self.related.field.attname)
instance_attr = lambda obj: obj._get_pk_val()
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
- params = {'%s__pk__in' % self.related.field.name: list(instances_dict)}
- qs = self.get_queryset(instance=instances[0]).filter(**params)
+ query = {'%s__in' % self.related.field.name: instances}
+ qs = self.get_query_set(instance=instances[0]).filter(**query)
# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
rel_obj_cache_name = self.related.field.get_cache_name()
@@ -274,7 +189,9 @@ def __get__(self, instance, instance_type=None):
if related_pk is None:
rel_obj = None
else:
- params = {'%s__pk' % self.related.field.name: related_pk}
+ params = {}
+ for lh_field, rh_field in self.related.field.related_fields:
+ params['%s__%s' % (self.related.field.name, rh_field.name)] = getattr(instance, rh_field.attname)
try:
rel_obj = self.get_queryset(instance=instance).get(**params)
except self.related.model.DoesNotExist:
@@ -314,13 +231,14 @@ def __set__(self, instance, value):
raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' %
(value, instance._state.db, value._state.db))
- related_pk = getattr(instance, self.related.field.rel.get_related_field().attname)
- if related_pk is None:
+ related_pk = tuple([getattr(instance, field.attname) for field in self.related.field.foreign_related_fields])
+ if None in related_pk:
raise ValueError('Cannot assign "%r": "%s" instance isn\'t saved in the database.' %
(value, instance._meta.object_name))
# Set the value of the related field to the value of the related object's related field
- setattr(value, self.related.field.attname, related_pk)
+ for index, field in enumerate(self.related.field.local_related_fields):
+ setattr(value, field.attname, related_pk[index])
# Since we already know what the related object is, seed the related
# object caches now, too. This avoids another db hit if you get the
@@ -352,16 +270,12 @@ def get_queryset(self, **db_hints):
else:
return QuerySet(self.field.rel.to).using(db)
- def get_prefetch_queryset(self, instances):
- other_field = self.field.rel.get_related_field()
- rel_obj_attr = attrgetter(other_field.attname)
- instance_attr = attrgetter(self.field.attname)
+ def get_prefetch_query_set(self, instances):
+ rel_obj_attr = self.field.get_foreign_related_value
+ instance_attr = self.field.get_local_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
- if other_field.rel:
- params = {'%s__pk__in' % self.field.rel.field_name: list(instances_dict)}
- else:
- params = {'%s__in' % self.field.rel.field_name: list(instances_dict)}
- qs = self.get_queryset(instance=instances[0]).filter(**params)
+ query = {'%s__in' % self.field.related_query_name(): instances}
+ qs = self.get_query_set(instance=instances[0]).filter(**query)
# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
if not self.field.rel.multiple:
@@ -377,16 +291,14 @@ def __get__(self, instance, instance_type=None):
try:
rel_obj = getattr(instance, self.cache_name)
except AttributeError:
- val = getattr(instance, self.field.attname)
- if val is None:
+ val = self.field.get_local_related_value(instance)
+ if None in val:
rel_obj = None
else:
- other_field = self.field.rel.get_related_field()
- if other_field.rel:
- params = {'%s__%s' % (self.field.rel.field_name, other_field.rel.field_name): val}
- else:
- params = {'%s__exact' % self.field.rel.field_name: val}
- qs = self.get_queryset(instance=instance)
+ params = {rh_field.attname: getattr(instance, lh_field.attname)
+ for lh_field, rh_field in self.field.related_fields}
+ params.update(self.field.get_extra_descriptor_filter(instance))
+ qs = self.get_query_set(instance=instance)
# Assuming the database enforces foreign keys, this won't fail.
rel_obj = qs.get(**params)
if not self.field.rel.multiple:
@@ -440,11 +352,11 @@ def __set__(self, instance, value):
setattr(related, self.field.related.get_cache_name(), None)
# Set the value of the related field
- try:
- val = getattr(value, self.field.rel.get_related_field().attname)
- except AttributeError:
- val = None
- setattr(instance, self.field.attname, val)
+ for lh_field, rh_field in self.field.related_fields:
+ try:
+ setattr(instance, lh_field.attname, getattr(value, rh_field.attname))
+ except AttributeError:
+ setattr(instance, lh_field.attname, None)
# Since we already know what the related object is, seed the related
# object caches now, too. This avoids another db hit if you get the
@@ -487,15 +399,12 @@ def related_manager_cls(self):
superclass = self.related.model._default_manager.__class__
rel_field = self.related.field
rel_model = self.related.model
- attname = rel_field.rel.get_related_field().attname
class RelatedManager(superclass):
def __init__(self, instance):
super(RelatedManager, self).__init__()
self.instance = instance
- self.core_filters = {
- '%s__%s' % (rel_field.name, attname): getattr(instance, attname)
- }
+ self.core_filters= {'%s__exact' % rel_field.name: instance}
self.model = rel_model
def get_queryset(self):
@@ -504,20 +413,22 @@ def get_queryset(self):
except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance)
qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters)
- val = getattr(self.instance, attname)
- if val is None or val == '' and connections[db].features.interprets_empty_strings_as_nulls:
- return qs.none()
+ empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
+ for field in rel_field.foreign_related_fields:
+ val = getattr(self.instance, field.attname)
+ if val is None or (val == '' and empty_strings_as_null):
+ return qs.none()
qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
return qs
def get_prefetch_queryset(self, instances):
- rel_obj_attr = attrgetter(rel_field.attname)
- instance_attr = attrgetter(attname)
+ rel_obj_attr = rel_field.get_local_related_value
+ instance_attr = rel_field.get_foreign_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
db = self._db or router.db_for_read(self.model, instance=instances[0])
- query = {'%s__%s__in' % (rel_field.name, attname): list(instances_dict)}
- qs = super(RelatedManager, self).get_queryset().using(db).filter(**query)
- # Since we just bypassed this class' get_queryset(), we must manage
+ query = {'%s__in' % rel_field.name: instances}
+ qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
+ # Since we just bypassed this class' get_query_set(), we must manage
# the reverse relation manually.
for rel_obj in qs:
instance = instances_dict[rel_obj_attr(rel_obj)]
@@ -550,10 +461,10 @@ def get_or_create(self, **kwargs):
# remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel_field.null:
def remove(self, *objs):
- val = getattr(self.instance, attname)
+ val = rel_field.get_foreign_related_value(self.instance)
for obj in objs:
# Is obj actually part of this descriptor set?
- if getattr(obj, rel_field.attname) == val:
+ if rel_field.get_local_related_value(obj) == val:
setattr(obj, rel_field.name, None)
obj.save()
else:
@@ -577,16 +488,26 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical
super(ManyRelatedManager, self).__init__()
self.model = model
self.query_field_name = query_field_name
- self.core_filters = {'%s__pk' % query_field_name: instance._get_pk_val()}
+
+ source_field = through._meta.get_field(source_field_name)
+ source_related_fields = source_field.related_fields
+
+ self.core_filters = {}
+ for lh_field, rh_field in source_related_fields:
+ self.core_filters['%s__%s' % (query_field_name, rh_field.name)] = getattr(instance, rh_field.attname)
+
self.instance = instance
self.symmetrical = symmetrical
+ self.source_field = source_field
self.source_field_name = source_field_name
self.target_field_name = target_field_name
self.reverse = reverse
self.through = through
self.prefetch_cache_name = prefetch_cache_name
- self._fk_val = self._get_fk_val(instance, source_field_name)
- if self._fk_val is None:
+ self.related_val = source_field.get_foreign_related_value(instance)
+ # Used for single column related auto created models
+ self._fk_val = self.related_val[0]
+ if None in self.related_val:
raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' %
(instance, source_field_name))
@@ -620,11 +541,9 @@ def get_queryset(self):
def get_prefetch_queryset(self, instances):
instance = instances[0]
- from django.db import connections
db = self._db or router.db_for_read(instance.__class__, instance=instance)
- query = {'%s__pk__in' % self.query_field_name:
- set(obj._get_pk_val() for obj in instances)}
- qs = super(ManyRelatedManager, self).get_queryset().using(db)._next_is_sticky().filter(**query)
+ query = {'%s__in' % self.query_field_name: instances}
+ qs = super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query)
# M2M: need to annotate the query in order to get the primary model
# that the secondary model was actually related to. We know that
@@ -634,16 +553,14 @@ def get_prefetch_queryset(self, instances):
# For non-autocreated 'through' models, can't assume we are
# dealing with PK values.
fk = self.through._meta.get_field(self.source_field_name)
- source_col = fk.column
join_table = self.through._meta.db_table
connection = connections[db]
qn = connection.ops.quote_name
- qs = qs.extra(select={'_prefetch_related_val':
- '%s.%s' % (qn(join_table), qn(source_col))})
- select_attname = fk.rel.get_related_field().get_attname()
+ qs = qs.extra(select={'_prefetch_related_val_%s' % f.attname:
+ '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
return (qs,
- attrgetter('_prefetch_related_val'),
- attrgetter(select_attname),
+ lambda result: tuple([getattr(result, '_prefetch_related_val_%s' % f.attname) for f in fk.local_related_fields]),
+ lambda inst: tuple([getattr(inst, f.attname) for f in fk.foreign_related_fields]),
False,
self.prefetch_cache_name)
@@ -795,7 +712,7 @@ def _clear_items(self, source_field_name):
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db)
self.through._default_manager.using(db).filter(**{
- source_field_name: self._fk_val
+ source_field_name: self.related_val
}).delete()
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are clearing the
@@ -918,19 +835,18 @@ def __set__(self, instance, value):
manager.clear()
manager.add(*value)
-
-class ManyToOneRel(object):
- def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
- parent_link=False, on_delete=None):
+class ForeignObjectRel(object):
+ def __init__(self, field, to, related_name=None, limit_choices_to=None,
+ parent_link=False, on_delete=None):
try:
to._meta
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
assert isinstance(to, six.string_types), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT
- self.to, self.field_name = to, field_name
+
+ self.field = field
+ self.to = to
self.related_name = related_name
- if limit_choices_to is None:
- limit_choices_to = {}
- self.limit_choices_to = limit_choices_to
+ self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
self.multiple = True
self.parent_link = parent_link
self.on_delete = on_delete
@@ -939,6 +855,20 @@ def is_hidden(self):
"Should the related object be hidden?"
return self.related_name and self.related_name[-1] == '+'
+ def get_joining_columns(self):
+ return self.field.get_reverse_joining_columns()
+
+ def get_extra_restriction(self, where_class, alias, related_alias):
+ return self.field.get_extra_restriction(where_class, related_alias, alias)
+
+class ManyToOneRel(ForeignObjectRel):
+ def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
+ parent_link=False, on_delete=None):
+ super(ManyToOneRel, self).__init__(
+ field, to, related_name=related_name, limit_choices_to=limit_choices_to,
+ parent_link=parent_link, on_delete=on_delete)
+ self.field_name = field_name
+
def get_related_field(self):
"""
Returns the Field in the 'to' object to which this relationship is
@@ -952,9 +882,9 @@ def get_related_field(self):
class OneToOneRel(ManyToOneRel):
- def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
- parent_link=False, on_delete=None):
- super(OneToOneRel, self).__init__(to, field_name,
+ def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
+ parent_link=False, on_delete=None):
+ super(OneToOneRel, self).__init__(field, to, field_name,
related_name=related_name, limit_choices_to=limit_choices_to,
parent_link=parent_link, on_delete=on_delete
)
@@ -963,7 +893,7 @@ def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
class ManyToManyRel(object):
def __init__(self, to, related_name=None, limit_choices_to=None,
- symmetrical=True, through=None, db_constraint=True):
+ symmetrical=True, through=None, db_constraint=True):
if through and not db_constraint:
raise ValueError("Can't supply a through model and db_constraint=False")
self.to = to
@@ -989,7 +919,199 @@ def get_related_field(self):
return self.to._meta.pk
-class ForeignKey(RelatedField, Field):
+class ForeignObject(RelatedField):
+ requires_unique_target = True
+ generate_reverse_relation = True
+
+ def __init__(self, to, from_fields, to_fields, **kwargs):
+ self.from_fields = from_fields
+ self.to_fields = to_fields
+
+ if 'rel' not in kwargs:
+ kwargs['rel'] = ForeignObjectRel(
+ self, to,
+ related_name=kwargs.pop('related_name', None),
+ limit_choices_to=kwargs.pop('limit_choices_to', None),
+ parent_link=kwargs.pop('parent_link', False),
+ on_delete=kwargs.pop('on_delete', CASCADE),
+ )
+ kwargs['verbose_name'] = kwargs.get('verbose_name', None)
+
+ super(ForeignObject, self).__init__(**kwargs)
+
+ def resolve_related_fields(self):
+ if len(self.from_fields) < 1 or len(self.from_fields) != len(self.to_fields):
+ raise ValueError('Foreign Object from and to fields must be the same non-zero length')
+ related_fields = []
+ for index in range(len(self.from_fields)):
+ from_field_name = self.from_fields[index]
+ to_field_name = self.to_fields[index]
+ from_field = (self if from_field_name == 'self'
+ else self.opts.get_field_by_name(from_field_name)[0])
+ to_field = (self.rel.to._meta.pk if to_field_name is None
+ else self.rel.to._meta.get_field_by_name(to_field_name)[0])
+ related_fields.append((from_field, to_field))
+ return related_fields
+
+ @property
+ def related_fields(self):
+ if not hasattr(self, '_related_fields'):
+ self._related_fields = self.resolve_related_fields()
+ return self._related_fields
+
+ @property
+ def reverse_related_fields(self):
+ return [(rhs_field, lhs_field) for lhs_field, rhs_field in self.related_fields]
+
+ @property
+ def local_related_fields(self):
+ return tuple([lhs_field for lhs_field, rhs_field in self.related_fields])
+
+ @property
+ def foreign_related_fields(self):
+ return tuple([rhs_field for lhs_field, rhs_field in self.related_fields])
+
+ def get_local_related_value(self, instance):
+ return self.get_instance_value_for_fields(instance, self.local_related_fields)
+
+ def get_foreign_related_value(self, instance):
+ return self.get_instance_value_for_fields(instance, self.foreign_related_fields)
+
+ @staticmethod
+ def get_instance_value_for_fields(instance, fields):
+ return tuple([getattr(instance, field.attname) for field in fields])
+
+ def get_attname_column(self):
+ attname, column = super(ForeignObject, self).get_attname_column()
+ return attname, None
+
+ def get_joining_columns(self, reverse_join=False):
+ source = self.reverse_related_fields if reverse_join else self.related_fields
+ return tuple([(lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source])
+
+ def get_reverse_joining_columns(self):
+ return self.get_joining_columns(reverse_join=True)
+
+ def get_extra_descriptor_filter(self, instance):
+ """
+ Returns an extra filter condition for related object fetching when
+ user does 'instance.fieldname', that is the extra filter is used in
+ the descriptor of the field.
+
+ The filter should be something usable in .filter(**kwargs) call, and
+ will be ANDed together with the joining columns condition.
+
+ A parallel method is get_extra_relation_restriction() which is used in
+ JOIN and subquery conditions.
+ """
+ return {}
+
+ def get_extra_restriction(self, where_class, alias, related_alias):
+ """
+ Returns a pair condition used for joining and subquery pushdown. The
+ condition is something that responds to as_sql(qn, connection) method.
+
+ Note that currently referring both the 'alias' and 'related_alias'
+ will not work in some conditions, like subquery pushdown.
+
+ A parallel method is get_extra_descriptor_filter() which is used in
+ instance.fieldname related object fetching.
+ """
+ return None
+
+ def get_path_info(self):
+ """
+ Get path from this field to the related model.
+ """
+ opts = self.rel.to._meta
+ from_opts = self.model._meta
+ return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)]
+
+ def get_reverse_path_info(self):
+ """
+ Get path from the related model to this field's model.
+ """
+ opts = self.model._meta
+ from_opts = self.rel.to._meta
+ pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
+ return pathinfos
+
+ def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
+ raw_value):
+ from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
+ root_constraint = constraint_class()
+ assert len(targets) == len(sources)
+
+ def get_normalized_value(value):
+
+ from django.db.models import Model
+ if isinstance(value, Model):
+ value_list = []
+ for source in sources:
+ # Account for one-to-one relations when sent a different model
+ while not isinstance(value, source.model):
+ source = source.rel.to._meta.get_field(source.rel.field_name)
+ value_list.append(getattr(value, source.attname))
+ return tuple(value_list)
+ elif not isinstance(value, tuple):
+ return (value,)
+ return value
+
+ is_multicolumn = len(self.related_fields) > 1
+ if (hasattr(raw_value, '_as_sql') or
+ hasattr(raw_value, 'get_compiler')):
+ root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets],
+ [source.name for source in sources], raw_value),
+ AND)
+ elif lookup_type == 'isnull':
+ root_constraint.add(
+ (Constraint(alias, targets[0].column, targets[0]), lookup_type, raw_value), AND)
+ elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
+ and not is_multicolumn)):
+ value = get_normalized_value(raw_value)
+ for index, source in enumerate(sources):
+ root_constraint.add(
+ (Constraint(alias, targets[index].column, sources[index]), lookup_type,
+ value[index]), AND)
+ elif lookup_type in ['range', 'in'] and not is_multicolumn:
+ values = [get_normalized_value(value) for value in raw_value]
+ value = [val[0] for val in values]
+ root_constraint.add(
+ (Constraint(alias, targets[0].column, sources[0]), lookup_type, value), AND)
+ elif lookup_type == 'in':
+ values = [get_normalized_value(value) for value in raw_value]
+ for value in values:
+ value_constraint = constraint_class()
+ for index, target in enumerate(targets):
+ value_constraint.add(
+ (Constraint(alias, target.column, sources[index]), 'exact', value[index]),
+ AND)
+ root_constraint.add(value_constraint, OR)
+ else:
+ raise TypeError('Related Field got invalid lookup: %s' % lookup_type)
+ return root_constraint
+
+ @property
+ def attnames(self):
+ return tuple([field.attname for field in self.local_related_fields])
+
+ def get_defaults(self):
+ return tuple([field.get_default() for field in self.local_related_fields])
+
+ def contribute_to_class(self, cls, name, virtual_only=False):
+ super(ForeignObject, self).contribute_to_class(cls, name, virtual_only=virtual_only)
+ setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
+
+ def contribute_to_related_class(self, cls, related):
+ # Internal FK's - i.e., those with a related name ending with '+' -
+ # and swapped models don't get a related descriptor.
+ if not self.rel.is_hidden() and not related.model._meta.swapped:
+ setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
+ if self.rel.limit_choices_to:
+ cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to)
+
+
+class ForeignKey(ForeignObject):
empty_strings_allowed = False
default_error_messages = {
'invalid': _('Model %(model)s with pk %(pk)r does not exist.')
@@ -999,7 +1121,7 @@ class ForeignKey(RelatedField, Field):
def __init__(self, to, to_field=None, rel_class=ManyToOneRel,
db_constraint=True, **kwargs):
try:
- to._meta.model_name
+ to_name = to._meta.object_name.lower()
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
assert isinstance(to, six.string_types), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT)
else:
@@ -1008,44 +1130,33 @@ def __init__(self, to, to_field=None, rel_class=ManyToOneRel,
# the to_field during FK construction. It won't be guaranteed to
# be correct until contribute_to_class is called. Refs #12190.
to_field = to_field or (to._meta.pk and to._meta.pk.name)
- kwargs['verbose_name'] = kwargs.get('verbose_name', None)
if 'db_index' not in kwargs:
kwargs['db_index'] = True
self.db_constraint = db_constraint
- kwargs['rel'] = rel_class(to, to_field,
+
+ kwargs['rel'] = rel_class(
+ self, to, to_field,
related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None),
parent_link=kwargs.pop('parent_link', False),
on_delete=kwargs.pop('on_delete', CASCADE),
)
- super(ForeignKey, self).__init__(**kwargs)
+ super(ForeignKey, self).__init__(to, ['self'], [to_field], **kwargs)
- def get_path_info(self):
- """
- Get path from this field to the related model.
- """
- opts = self.rel.to._meta
- target = self.rel.get_related_field()
- from_opts = self.model._meta
- return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self
+ @property
+ def related_field(self):
+ return self.foreign_related_fields[0]
def get_reverse_path_info(self):
"""
Get path from the related model to this field's model.
"""
opts = self.model._meta
- from_field = self.rel.get_related_field()
- from_opts = from_field.model._meta
- pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)]
- if from_field.model is self.model:
- # Recursive foreign key to self.
- target = opts.get_field_by_name(
- self.rel.field_name)[0]
- else:
- target = opts.pk
- return pathinfos, opts, target, self
+ from_opts = self.rel.to._meta
+ pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
+ return pathinfos
def validate(self, value, model_instance):
if self.rel.parent_link:
@@ -1066,21 +1177,26 @@ def validate(self, value, model_instance):
def get_attname(self):
return '%s_id' % self.name
+ def get_attname_column(self):
+ attname = self.get_attname()
+ column = self.db_column or attname
+ return attname, column
+
def get_validator_unique_lookup_type(self):
- return '%s__%s__exact' % (self.name, self.rel.get_related_field().name)
+ return '%s__%s__exact' % (self.name, self.related_field.name)
def get_default(self):
"Here we check if the default value is an object and return the to_field if so."
field_default = super(ForeignKey, self).get_default()
if isinstance(field_default, self.rel.to):
- return getattr(field_default, self.rel.get_related_field().attname)
+ return getattr(field_default, self.related_field.attname)
return field_default
def get_db_prep_save(self, value, connection):
if value == '' or value == None:
return None
else:
- return self.rel.get_related_field().get_db_prep_save(value,
+ return self.related_field.get_db_prep_save(value,
connection=connection)
def value_to_string(self, obj):
@@ -1093,19 +1209,10 @@ def value_to_string(self, obj):
choice_list = self.get_choices_default()
if len(choice_list) == 2:
return smart_text(choice_list[1][0])
- return Field.value_to_string(self, obj)
-
- def contribute_to_class(self, cls, name):
- super(ForeignKey, self).contribute_to_class(cls, name)
- setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
+ return super(ForeignKey, self).value_to_string(obj)
def contribute_to_related_class(self, cls, related):
- # Internal FK's - i.e., those with a related name ending with '+' -
- # and swapped models don't get a related descriptor.
- if not self.rel.is_hidden() and not related.model._meta.swapped:
- setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
- if self.rel.limit_choices_to:
- cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to)
+ super(ForeignKey, self).contribute_to_related_class(cls, related)
if self.rel.field_name is None:
self.rel.field_name = cls._meta.pk.name
@@ -1130,7 +1237,7 @@ def db_type(self, connection):
# in which case the column type is simply that of an IntegerField.
# If the database needs similar types for key fields however, the only
# thing we can do is making AutoField an IntegerField.
- rel_field = self.rel.get_related_field()
+ rel_field = self.related_field
if (isinstance(rel_field, AutoField) or
(not connection.features.related_fields_match_type and
isinstance(rel_field, (PositiveIntegerField,
@@ -1212,7 +1319,7 @@ def set_managed(field, model, cls):
})
-class ManyToManyField(RelatedField, Field):
+class ManyToManyField(RelatedField):
description = _("Many-to-many relationship")
def __init__(self, to, db_constraint=True, **kwargs):
@@ -1252,14 +1359,14 @@ def _get_path_info(self, direct=False):
linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0]
linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
if direct:
- join1infos, _, _, _ = linkfield1.get_reverse_path_info()
- join2infos, opts, target, final_field = linkfield2.get_path_info()
+ join1infos = linkfield1.get_reverse_path_info()
+ join2infos = linkfield2.get_path_info()
else:
- join1infos, _, _, _ = linkfield2.get_reverse_path_info()
- join2infos, opts, target, final_field = linkfield1.get_path_info()
+ join1infos = linkfield2.get_reverse_path_info()
+ join2infos = linkfield1.get_path_info()
pathinfos.extend(join1infos)
pathinfos.extend(join2infos)
- return pathinfos, opts, target, final_field
+ return pathinfos
def get_path_info(self):
return self._get_path_info(direct=True)
@@ -1402,8 +1509,3 @@ def formfield(self, **kwargs):
initial = initial()
defaults['initial'] = [i._get_pk_val() for i in initial]
return super(ManyToManyField, self).formfield(**defaults)
-
- def db_type(self, connection):
- # A ManyToManyField is not represented by a single column,
- # so return None.
- return None
View
38 django/db/models/options.py
@@ -10,6 +10,7 @@
from django.db.models.fields.proxy import OrderWrt
from django.db.models.loading import get_models, app_cache_ready
from django.utils import six
+from django.utils.functional import cached_property
from django.utils.datastructures import SortedDict
from django.utils.encoding import force_text, smart_text, python_2_unicode_compatible
from django.utils.translation import activate, deactivate_all, get_language, string_concat
@@ -173,6 +174,22 @@ def add_field(self, field):
if hasattr(self, '_field_cache'):
del self._field_cache
del self._field_name_cache
+ # The fields, concrete_fields and local_concrete_fields are
+ # implemented as cached properties for performance reasons.
+ # The attrs will not exists if the cached property isn't
+ # accessed yet, hence the try-excepts.
+ try:
+ del self.fields
+ except AttributeError:
+ pass
+ try:
+ del self.concrete_fields
+ except AttributeError:
+ pass
+ try:
+ del self.local_concrete_fields
+ except AttributeError:
+ pass
if hasattr(self, '_name_map'):
del self._name_map
@@ -245,7 +262,8 @@ def _swapped(self):
return None
swapped = property(_swapped)
- def _fields(self):
+ @cached_property
+ def fields(self):
"""
The getter for self.fields. This returns the list of field objects
available to this model (including through parent models).
@@ -258,7 +276,14 @@ def _fields(self):
except AttributeError:
self._fill_fields_cache()
return self._field_name_cache
- fields = property(_fields)
+
+ @cached_property
+ def concrete_fields(self):
+ return [f for f in self.fields if f.column is not None]
+
+ @cached_property
+ def local_concrete_fields(self):
+ return [f for f in self.local_fields if f.column is not None]
def get_fields_with_model(self):
"""
@@ -272,6 +297,10 @@ def get_fields_with_model(self):
self._fill_fields_cache()
return self._field_cache
+ def get_concrete_fields_with_model(self):
+ return [(field, model) for field, model in self.get_fields_with_model() if
+ field.column is not None]
+
def _fill_fields_cache(self):
cache = []
for parent in self.parents:
@@ -377,6 +406,9 @@ def init_name_map(self):
cache[f.name] = (f, model, True, True)
for f, model in self.get_fields_with_model():
cache[f.name] = (f, model, True, False)
+ for f in self.virtual_fields:
+ if hasattr(f, 'related'):
+ cache[f.name] = (f.related, None if f.model == self.model else f.model, True, False)
if app_cache_ready():
self._name_map = cache
return cache
@@ -432,7 +464,7 @@ def _fill_related_objects_cache(self):
for klass in get_models(include_auto_created=True, only_installed=False):
if not klass._meta.swapped:
for f in klass._meta.local_fields:
- if f.rel and not isinstance(f.rel.to, six.string_types):
+ if f.rel and not isinstance(f.rel.to, six.string_types) and f.generate_reverse_relation:
if self == f.rel.to._meta:
cache[f.related] = None
proxy_cache[f.related] = None
View
26 django/db/models/query.py
@@ -261,13 +261,13 @@ def iterator(self):
only_load = self.query.get_loaded_field_names()
if not fill_cache:
- fields = self.model._meta.fields
+ fields = self.model._meta.concrete_fields
load_fields = []
# If only/defer clauses have been specified,
# build the list of fields that are to be loaded.
if only_load:
- for field, model in self.model._meta.get_fields_with_model():
+ for field, model in self.model._meta.get_concrete_fields_with_model():
if model is None:
model = self.model
try:
@@ -280,7 +280,7 @@ def iterator(self):
load_fields.append(field.name)
index_start = len(extra_select)
- aggregate_start = index_start + len(load_fields or self.model._meta.fields)
+ aggregate_start = index_start + len(load_fields or self.model._meta.concrete_fields)
skip = None
if load_fields and not fill_cache:
@@ -312,7 +312,11 @@ def iterator(self):
if skip:
obj = model_cls(**dict(zip(init_list, row_data)))
else:
- obj = model(*row_data)
+ try:
+ obj = model(*row_data)
+ except IndexError:
+ import ipdb; ipdb.set_trace()
+ pass
# Store the source database of the object
obj._state.db = db
@@ -962,7 +966,7 @@ def _setup_aggregate_query(self, aggregates):
"""
opts = self.model._meta
if self.query.group_by is None:
- field_names = [f.attname for f in opts.fields]
+ field_names = [f.attname for f in opts.concrete_fields]
self.query.add_fields(field_names, False)
self.query.set_group_by()
@@ -1055,7 +1059,7 @@ def _setup_query(self):
else:
# Default to all fields.
self.extra_names = None
- self.field_names = [f.attname for f in self.model._meta.fields]
+ self.field_names = [f.attname for f in self.model._meta.concrete_fields]
self.aggregate_names = None
self.query.select = []
@@ -1266,7 +1270,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
skip = set()
init_list = []
# Build the list of fields that *haven't* been requested
- for field, model in klass._meta.get_fields_with_model():
+ for field, model in klass._meta.get_concrete_fields_with_model():
if field.name not in load_fields:
skip.add(field.attname)
elif from_parent and issubclass(from_parent, model.__class__):
@@ -1285,22 +1289,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
else:
# Load all fields on klass
- field_count = len(klass._meta.fields)
+ field_count = len(klass._meta.concrete_fields)
# Check if we need to skip some parent fields.
- if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
+ if from_parent and len(klass._meta.local_concrete_fields) != len(klass._meta.concrete_fields):
# Only load those fields which haven't been already loaded into
# 'from_parent'.
non_seen_models = [p for p in klass._meta.get_parent_list()
if not issubclass(from_parent, p)]
# Load local fields, too...
non_seen_models.append(klass)
- field_names = [f.attname for f in klass._meta.fields
+ field_names = [f.attname for f in klass._meta.concrete_fields
if f.model in non_seen_models]
field_count = len(field_names)
# Try to avoid populating field_names variable for perfomance reasons.
# If field_names variable is set, we use **kwargs based model init
# which is slower than normal init.
- if field_count == len(klass._meta.fields):
+ if field_count == len(klass._meta.concrete_fields):
field_names = ()
restricted = requested is not None
View
2  django/db/models/related.py
@@ -7,7 +7,7 @@
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple('PathInfo',
- 'from_field to_field from_opts to_opts join_field '
+ 'from_opts to_opts target_fields join_field '
'm2m direct')
class RelatedObject(object):
View
108 django/db/models/sql/compiler.py
@@ -2,10 +2,9 @@
from django.conf import settings
from django.core.exceptions import FieldError
-from django.db import transaction
from django.db.backends.util import truncate_name
from django.db.models.constants import LOOKUP_SEP
-from django.db.models.query_utils import select_related_descend
+from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
@@ -33,7 +32,7 @@ def pre_sql_setup(self):
# cleaned. We are not using a clone() of the query here.
"""
if not self.query.tables:
- self.query.join((None, self.query.model._meta.db_table, None, None))
+ self.query.join((None, self.query.model._meta.db_table, None))
if (not self.query.select and self.query.default_cols and not
self.query.included_inherited_models):
self.query.setup_inherited_models()
@@ -273,7 +272,7 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
# be used by local fields.
seen_models = {None: start_alias}
- for field, model in opts.get_fields_with_model():
+ for field, model in opts.get_concrete_fields_with_model():
if from_parent and model is not None and issubclass(from_parent, model):
# Avoid loading data for already loaded parents.
continue
@@ -314,9 +313,10 @@ def get_distinct(self):
for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP)
- field, col, alias, _, _ = self._setup_joins(parts, opts, None)
- col, alias = self._final_join_removal(col, alias)
- result.append("%s.%s" % (qn(alias), qn2(col)))
+ field, cols, alias, _, _ = self._setup_joins(parts, opts, None)
+ cols, alias = self._final_join_removal(cols, alias)
+ for col in cols:
+ result.append("%s.%s" % (qn(alias), qn2(col)))
return result
@@ -387,15 +387,16 @@ def get_ordering(self):
elif get_order_dir(field)[0] not in self.query.extra_select:
# 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc.
- for table, col, order in self.find_ordering_name(field,
+ for table, cols, order in self.find_ordering_name(field,
self.query.model._meta, default_order=asc):
- if (table, col) not in processed_pairs:
- elt = '%s.%s' % (qn(table), qn2(col))
- processed_pairs.add((table, col))
- if distinct and elt not in select_aliases:
- ordering_aliases.append(elt)
- result.append('%s %s' % (elt, order))
- group_by.append((elt, []))
+ for col in cols:
+ if (table, col) not in processed_pairs:
+ elt = '%s.%s' % (qn(table), qn2(col))
+ processed_pairs.add((table, col))
+ if distinct and elt not in select_aliases:
+ ordering_aliases.append(elt)
+ result.append('%s %s' % (elt, order))
+ group_by.append((elt, []))
else:
elt = qn2(col)
if distinct and col not in select_aliases:
@@ -414,7 +415,7 @@ def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
"""
name, order = get_order_dir(name, default_order)
pieces = name.split(LOOKUP_SEP)
- field, col, alias, joins, opts = self._setup_joins(pieces, opts, alias)
+ field, cols, alias, joins, opts = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model.
@@ -432,8 +433,8 @@ def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
results.extend(self.find_ordering_name(item, opts, alias,
order, already_seen))
return results
- col, alias = self._final_join_removal(col, alias)
- return [(alias, col, order)]
+ cols, alias = self._final_join_removal(cols, alias)
+ return [(alias, cols, order)]
def _setup_joins(self, pieces, opts, alias):
"""
@@ -446,13 +447,13 @@ def _setup_joins(self, pieces, opts, alias):
"""
if not alias:
alias = self.query.get_initial_alias()
- field, target, opts, joins, _ = self.query.setup_joins(
+ field, targets, opts, joins, _ = self.query.setup_joins(
pieces, opts, alias)
# We will later on need to promote those joins that were added to the
# query afresh above.
joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2]
alias = joins[-1]
- col = target.column
+ cols = [target.column for target in targets]
if not field.rel:
# To avoid inadvertent trimming of a necessary alias, use the
# refcount to show that we are referencing a non-relation field on
@@ -463,9 +464,9 @@ def _setup_joins(self, pieces, opts, alias):
# Ordering or distinct must not affect the returned set, and INNER
# JOINS for nullable fields could do this.
self.query.promote_joins(joins_to_promote)
- return field, col, alias, joins, opts
+ return field, cols, alias, joins, opts
- def _final_join_removal(self, col, alias):
+ def _final_join_removal(self, cols, alias):
"""
A helper method for get_distinct and get_ordering. This method will
trim extra not-needed joins from the tail of the join chain.
@@ -477,12 +478,14 @@ def _final_join_removal(self, col, alias):
if alias:
while 1:
join = self.query.alias_map[alias]
- if col != join.rhs_join_col:
+ lhs_cols, rhs_cols = zip(*[(lhs_col, rhs_col) for lhs_col, rhs_col in join.join_cols])
+ if set(cols) != set(rhs_cols):
break
+
+ cols = [lhs_cols[rhs_cols.index(col)] for col in cols]
self.query.unref_alias(alias)
alias = join.lhs_alias
- col = join.lhs_join_col
- return col, alias
+ return cols, alias
def get_from_clause(self):
"""
@@ -504,22 +507,30 @@ def get_from_clause(self):
if not self.query.alias_refcount[alias]:
continue
try:
- name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias]
+ name, alias, join_type, lhs, join_cols, _, join_field = self.query.alias_map[alias]
except KeyError:
# Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them.
continue
alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first:
- if join_field and hasattr(join_field, 'get_extra_join_sql'):
- extra_cond, extra_params = join_field.get_extra_join_sql(
- self.connection, qn, lhs, alias)
+ extra_cond = join_field.get_extra_restriction(
+ self.query.where_class, alias, lhs)
+ if extra_cond:
+ extra_sql, extra_params = extra_cond.as_sql(
+ qn, self.connection)
+ extra_sql = 'AND (%s)' % extra_sql
from_params.extend(extra_params)
else:
- extra_cond = ""
- result.append('%s %s%s ON (%s.%s = %s.%s%s)' %
- (join_type, qn(name), alias_str, qn(lhs),
- qn2(lhs_col), qn(alias), qn2(col), extra_cond))
+ extra_sql = ""
+ result.append('%s %s%s ON ('
+ % (join_type, qn(name), alias_str))
+ for index, (lhs_col, rhs_col) in enumerate(join_cols):
+ if index != 0:
+ result.append(' AND ')
+ result.append('%s.%s = %s.%s' %
+ (qn(lhs), qn2(lhs_col), qn(alias), qn2(rhs_col)))
+ result.append('%s)' % extra_sql)
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
@@ -545,7 +556,7 @@ def get_grouping(self, having_group_by, ordering_group_by):
select_cols = self.query.select + self.query.related_select_cols
# Just the column, not the fields.
select_cols = [s[0] for s in select_cols]
- if (len(self.query.model._meta.fields) == len(self.query.select)
+ if (len(self.query.model._meta.concrete_fields) == len(self.query.select)
and self.connection.features.allows_group_by_pk):
self.query.group_by = [
(self.query.model._meta.db_table, self.query.model._meta.pk.column)
@@ -623,14 +634,13 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
table = f.rel.to._meta.db_table
promote = nullable or f.null
alias = self.query.join_parent_model(opts, model, root_alias, {})
-
- alias = self.query.join((alias, table, f.column,
- f.rel.get_related_field().column),
+ join_cols = f.get_joining_columns()
+ alias = self.query.join((alias, table, join_cols),
outer_if_first=promote, join_field=f)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend(
- SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields))
+ SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.concrete_fields))
if restricted:
next = requested.get(f.name, {})
else:
@@ -653,7 +663,7 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
alias = self.query.join_parent_model(opts, f.rel.to, root_alias, {})
table = model._meta.db_table
alias = self.query.join(
- (alias, table, f.rel.get_related_field().column, f.column),
+ (alias, table, f.get_joining_columns(reverse_join=True)),
outer_if_first=True, join_field=f
)
from_parent = (opts.model if issubclass(model, opts.model)
@@ -662,7 +672,7 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
opts=model._meta, as_pairs=True, from_parent=from_parent)
self.query.related_select_cols.extend(
SelectInfo(col, field) for col, field
- in zip(columns, model._meta.fields))
+ in zip(columns, model._meta.concrete_fields))
next = requested.get(f.related_query_name(), {})
# Use True here because we are looking at the _reverse_ side of
# the relation, which is always nullable.
@@ -706,7 +716,7 @@ def results_iter(self):
if self.query.select:
fields = [f.field for f in self.query.select]
else:
- fields = self.query.model._meta.fields
+ fields = self.query.model._meta.concrete_fields
fields = fields + [f.field for f in self.query.related_select_cols]
# If the field was deferred, exclude it from being passed
@@ -776,6 +786,22 @@ def execute_sql(self, result_type=MULTI):
return list(result)
return result
+ def as_subquery_condition(self, alias, columns):
+ qn = self.quote_name_unless_alias
+ qn2 = self.connection.ops.quote_name
+ if len(columns) == 1:
+ sql, params = self.as_sql()
+ return '%s.%s IN (%s)' % (qn(alias), qn2(columns[0]), sql), params
+
+ for index, select_col in enumerate(self.query.select):
+ lhs = '%s.%s' % (qn(select_col.col[0]), qn2(select_col.col[1]))
+ rhs = '%s.%s' % (qn(alias), qn2(columns[index]))
+ self.query.where.add(
+ QueryWrapper('%s = %s' % (lhs, rhs), []), 'AND')
+
+ sql, params = self.as_sql()
+ return 'EXISTS (%s)' % sql, params
+
class SQLInsertCompiler(SQLCompiler):
def placeholder(self, field, val):
View
2  django/db/models/sql/constants.py
@@ -25,7 +25,7 @@
# dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo',
'table_name rhs_alias join_type lhs_alias '
- 'lhs_join_col rhs_join_col nullable join_field')
+ 'join_cols nullable join_field')
# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')
View
7 django/db/models/sql/expressions.py
@@ -55,13 +55,14 @@ def prepare_leaf(self, node, query, allow_joins):
self.cols.append((node, query.aggregate_select[node.name]))
else:
try:
- field, source, opts, join_list, path = query.setup_joins(
+ field, sources, opts, join_list, path = query.setup_joins(
field_list, query.get_meta(),
query.get_initial_alias(), self.reuse)
- target, _, join_list = query.trim_joins(source, join_list, path)
+ targets, _, join_list = query.trim_joins(sources, join_list, path)
if self.reuse is not None:
self.reuse.update(join_list)
- self.cols.append((node, (join_list[-1], target.column)))
+ for t in targets:
+ self.cols.append((node, (join_list[-1], t.column)))
except FieldDoesNotExist:
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (self.name,
View
174 django/db/models/sql/query.py
@@ -452,13 +452,13 @@ def combine(self, rhs, connector):
# Now, add the joins from rhs query into the new query (skipping base
# table).
for alias in rhs.tables[1:]:
- table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias]
+ table, _, join_type, lhs, join_cols, nullable, join_field = rhs.alias_map[alias]
promote = (join_type == self.LOUTER)
# If the left side of the join was already relabeled, use the
# updated alias.
lhs = change_map.get(lhs, lhs)
new_alias = self.join(
- (lhs, table, lhs_col, col), reuse=reuse,
+ (lhs, table, join_cols), reuse=reuse,
outer_if_first=not conjunction, nullable=nullable,
join_field=join_field)
if promote:
@@ -682,7 +682,7 @@ def promote_joins(self, aliases, unconditional=False):
aliases = list(aliases)
while aliases:
alias = aliases.pop(0)
- if self.alias_map[alias].rhs_join_col is None:
+ if self.alias_map[alias].join_cols[0][1] is None:
# This is the base table (first FROM entry) - this table
# isn't really joined at all in the query, so we should not
# alter its join type.
@@ -818,7 +818,7 @@ def get_initial_alias(self):
alias = self.tables[0]
self.ref_alias(alias)
else:
- alias = self.join((None, self.model._meta.db_table, None, None))
+ alias = self.join((None, self.model._meta.db_table, None))
return alias
def count_active_tables(self):
@@ -834,11 +834,12 @@ def join(self, connection, reuse=None, outer_if_first=False,
"""
Returns an alias for the join in 'connection', either reusing an
existing alias for that join or creating a new one. 'connection' is a
- tuple (lhs, table, lhs_col, col) where 'lhs' is either an existing
- table alias or a table name. The join correspods to the SQL equivalent
- of::
+ tuple (lhs, table, join_cols) where 'lhs' is either an existing
+ table alias or a table name. 'join_cols' is a tuple of tuples containing
+ columns to join on ((l_id1, r_id1), (l_id2, r_id2)). The join corresponds
+ to the SQL equivalent of::
- lhs.lhs_col = table.col
+ lhs.l_id1 = table.r_id1 AND lhs.l_id2 = table.r_id2
The 'reuse' parameter can be either None which means all joins
(matching the connection) are reusable, or it can be a set containing
@@ -855,7 +856,7 @@ def join(self, connection, reuse=None, outer_if_first=False,
The 'join_field' is the field we are joining along (if any).
"""
- lhs, table, lhs_col, col = connection
+ lhs, table, join_cols = connection
assert lhs is None or join_field is not None
existing = self.join_map.get(connection, ())
if reuse is None:
@@ -884,7 +885,7 @@ def join(self, connection, reuse=None, outer_if_first=False,
join_type = self.LOUTER
else:
join_type = self.INNER
- join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable,
+ join = JoinInfo(table, alias, join_type, lhs, join_cols or ((None, None),), nullable,
join_field)
self.alias_map[alias] = join
if connection in self.join_map:
@@ -941,7 +942,7 @@ def join_parent_model(self, opts, model, alias, seen):
continue
link_field = int_opts.get_ancestor_link(int_model)
int_opts = int_model._meta
- connection = (alias, int_opts.db_table, link_field.column, int_opts.pk.column)
+ connection = (alias, int_opts.db_table, link_field.get_joining_columns())
alias = seen[int_model] = self.join(connection, nullable=False,
join_field=link_field)
return alias or seen[None]
@@ -982,18 +983,20 @@ def add_aggregate(self, aggregate, model, alias, is_summary):
# - this is an annotation over a model field
# then we need to explore the joins that are required.
- field, source, opts, join_list, path = self.setup_joins(
+ field, sources, opts, join_list, path = self.setup_joins(
field_list, opts, self.get_initial_alias())
# Process the join chain to see if it can be trimmed
- target, _, join_list = self.trim_joins(source, join_list, path)
+ targets, _, join_list = self.trim_joins(sources, join_list, path)
# If the aggregate references a model or field that requires a join,
# those joins must be LEFT OUTER - empty join rows must be returned
# in order for zeros to be returned for those aggregates.
self.promote_joins(join_list, True)
- col = (join_list[-1], target.column)
+ col = targets[0].column
+ source = sources[0]
+ col = (join_list[-1], col)
else:
# The simplest cases. No joins required -
# just reference the provided column alias.
@@ -1086,7 +1089,7 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
allow_many = not branch_negated
try:
- field, target, opts, join_list, path = self.setup_joins(
+ field, sources, opts, join_list, path = self.setup_joins(
parts, opts, alias, can_reuse, allow_many,
allow_explicit_fk=True)
if can_reuse is not None:
@@ -1106,13 +1109,19 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
# the far end (fewer tables in a query is better). Note that join
# promotion must happen before join trimming to have the join type
# information available when reusing joins.
- target, alias, join_list = self.trim_joins(target, join_list, path)
- clause.add((Constraint(alias, target.column, field), lookup_type, value),
- AND)
+ targets, alias, join_list = self.trim_joins(sources, join_list, path)
+
+ if hasattr(field, 'get_lookup_constraint'):
+ constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
+ lookup_type, value)
+ else:
+ constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
+ clause.add(constraint, AND)
if current_negated and (lookup_type != 'isnull' or value is False):
self.promote_joins(join_list)
if (lookup_type != 'isnull' and (
- self.is_nullable(target) or self.alias_map[join_list[-1]].join_type == self.LOUTER)):
+ self.is_nullable(targets[0]) or
+ self.alias_map[join_list[-1]].join_type == self.LOUTER)):
# The condition added here will be SQL like this:
# NOT (col IS NOT NULL), where the first NOT is added in
# upper layers of code. The reason for addition is that if col
@@ -1122,7 +1131,7 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
# (col IS NULL OR col != someval)
# <=>
# NOT (col IS NOT NULL AND col = someval).
- clause.add((Constraint(alias, target.column, None), 'isnull', False), AND)
+ clause.add((Constraint(alias, targets[0].column, None), 'isnull', False), AND)
return clause
def add_filter(self, filter_clause):
@@ -1272,22 +1281,26 @@ def names_to_path(self, names, opts, allow_many, allow_explicit_fk):
opts = int_model._meta
else:
final_field = opts.parents[int_model]
- target = final_field.rel.get_related_field()
+ targets = (final_field.rel.get_related_field(),)
opts = int_model._meta
- path.append(PathInfo(final_field, target, final_field.model._meta,
- opts, final_field, False, True))
+ path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True))
if hasattr(field, 'get_path_info'):
- pathinfos, opts, target, final_field = field.get_path_info()
+ pathinfos = field.get_path_info()
if not allow_many:
for inner_pos, p in enumerate(pathinfos):
if p.m2m:
names_with_path.append((name, pathinfos[0:inner_pos + 1]))
raise MultiJoin(pos + 1, names_with_path)
+ last = pathinfos[-1]
path.extend(pathinfos)
+ final_field = last.join_field
+ opts = last.to_opts
+ targets = last.target_fields
names_with_path.append((name, pathinfos))
else:
# Local non-relational field.
- final_field = target = field
+ final_field = field
+ targets = (field,)
break
if pos != len(names) - 1:
@@ -1297,7 +1310,7 @@ def names_to_path(self, names, opts, allow_many, allow_explicit_fk):
"the lookup type?" % (name, names[pos + 1]))
else:
raise FieldError("Join on field %r not permitted." % name)
- return path, final_field, target
+ return path, final_field, targets
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
allow_explicit_fk=False):
@@ -1330,7 +1343,7 @@ def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
"""
joins = [alias]
# First, generate the path for the names
- path, final_field, target = self.names_to_path(
+ path, final_field, targets = self.names_to_path(
names, opts, allow_many, allow_explicit_fk)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
@@ -1338,17 +1351,19 @@ def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
for pos, join in enumerate(path):
opts = join.to_opts
if join.direct:
- nullable = self.is_nullable(join.from_field)
+ nullable = self.is_nullable(join.join_field)
else:
nullable = True
- connection = alias, opts.db_table, join.from_field.column, join.to_field.column
+ connection = alias, opts.db_table, join.join_field.get_joining_columns()
reuse = can_reuse if join.m2m else None
alias = self.join(connection, reuse=reuse,
nullable=nullable, join_field=join.join_field)
joins.append(alias)
- return final_field, target, opts, joins, path
+ if hasattr(final_field, 'field'):
+ final_field = final_field.field
+ return final_field, targets, opts, joins, path
- def trim_joins(self, target, joins, path):
+ def trim_joins(self, targets, joins, path):
"""
The 'target' parameter is the final field being joined to, 'joins'
is the full list of join aliases. The 'path' contain the PathInfos
@@ -1362,13 +1377,16 @@ def trim_joins(self, target, joins, path):
trimmed as we don't know if there is anything on the other side of
the join.
"""
- for info in reversed(path):
- if info.to_field == target and info.direct:
- target = info.from_field
- self.unref_alias(joins.pop())
- else:
+ for pos, info in enumerate(reversed(path)):
+ if len(joins) == 1 or not info.direct:
break
- return target, joins[-1], joins
+ join_targets = set(t.column for t in info.join_field.foreign_related_fields)
+ cur_targets = set(t.column for t in targets)
+ if not cur_targets.issubset(join_targets):
+ break
+ targets = tuple(r[0] for r in info.join_field.related_fields if r[1].column in cur_targets)
+ self.unref_alias(joins.pop())
+ return targets, joins[-1], joins
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
"""
@@ -1413,17 +1431,31 @@ def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
trimmed_prefix = []
paths_in_prefix = trimmed_joins
for name, path in names_with_path:
- if paths_in_prefix - len(path) > 0:
- trimmed_prefix.append(name)
- paths_in_prefix -= len(path)
- else:
- trimmed_prefix.append(
- path[paths_in_prefix - len(path)].from_field.name)
+ if paths_in_prefix - len(path) < 0:
break
+ trimmed_prefix.append(name)
+ paths_in_prefix -= len(path)
+ join_field = path[paths_in_prefix].join_field
+ # TODO: This should be made properly multicolumn
+ # join aware. It is likely better to not use build_filter
+ # at all, instead construct joins up to the correct point,
+ # then construct the needed equality constraint manually,
+ # or maybe using SubqueryConstraint would work, too.
+ # The foreign_related_fields attribute is right here, we
+ # don't ever split joins for direct case.
+ trimmed_prefix.append(
+ join_field.field.foreign_related_fields[0].name)
trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
- return self.build_filter(
+ condition = self.build_filter(
('%s__in' % trimmed_prefix, query),
current_negated=True, branch_negated=True, can_reuse=can_reuse)
+ # Intentionally leave the other alias as blank, if the condition
+ # refers it, things will break here.
+ extra_restriction = join_field.get_extra_restriction(
+ self.where_class, None, [t for t in query.tables if query.alias_refcount[t]][0])
+ if extra_restriction:
+ query.where.add(extra_restriction, 'AND')
+ return condition
def set_empty(self):
self.where = EmptyWhere()
@@ -1502,20 +1534,17 @@ def add_fields(self, field_names, allow_m2m=True):
try:
for name in field_names:
- field, target, u2, joins, u3 = self.setup_joins(
+ field, targets, u2, joins, path = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, None, allow_m2m,
True)
- final_alias = joins[-1]
- col = target.column
- if len(joins) > 1:
- join = self.alias_map[final_alias]
- if col == join.rhs_join_col:
- self.unref_alias(final_alias)
- final_alias = join.lhs_alias
- col = join.lhs_join_col
- joins = joins[:-1]
+
+ # Trim last join if possible
+ targets, final_alias, remaining_joins = self.trim_joins(targets, joins[-2:], path)
+ joins = joins[:-2] + remaining_joins
+
self.promote_joins(joins[1:])
- self.select.append(SelectInfo((final_alias, col), field))
+ for target in targets:
+ self.select.append(SelectInfo((final_alias, target.column), target))
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
@@ -1590,7 +1619,7 @@ def add_count_column(self):
opts = self.model._meta
if not self.select:
count = self.aggregates_module.Count(
- (self.join((None, opts.db_table, None, None)), opts.pk.column),
+ (self.join((None, opts.db_table, None)), opts.pk.column),
is_summary=True, distinct=True)
else:
# Because of SQL portability issues, multi-column, distinct