Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks.
base fork: django/django
...
head fork: jtillman/django
Checking mergeability… Don’t worry, you can still create the pull request.
  • 12 commits
  • 17 files changed
  • 0 commit comments
  • 1 contributor
View
17 django/contrib/contenttypes/generic.py
@@ -141,7 +141,7 @@ def __set__(self, instance, value):
setattr(instance, self.fk_field, fk)
setattr(instance, self.cache_attr, value)
-class GenericRelation(RelatedField, Field):
+class GenericRelation(RelatedField):
"""Provides an accessor to generic related objects (e.g. comments)"""
def __init__(self, to, **kwargs):
@@ -159,7 +159,7 @@ def __init__(self, to, **kwargs):
kwargs['blank'] = True
kwargs['editable'] = False
kwargs['serialize'] = False
- Field.__init__(self, **kwargs)
+ super(GenericRelation, self).__init__(**kwargs)
def get_path_info(self):
from_field = self.model._meta.pk
@@ -168,16 +168,23 @@ def get_path_info(self):
# Note that we are using different field for the join_field
# than from_field or to_field. This is a hack, but we need the
# GenericRelation to generate the extra SQL.
- return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)],
- opts, target, self)
+ return [PathInfo(self.model._meta, opts, (target,), self, True, False)]
def get_choices_default(self):
- return Field.get_choices(self, include_blank=False)
+ return super(GenericRelation, self).get_choices(include_blank=False)
def value_to_string(self, obj):
qs = getattr(obj, self.name).all()
return smart_text([instance._get_pk_val() for instance in qs])
+ def get_joining_columns(self, reverse_join=False):
+ # Our second join will happen in the extra sql
+ join_cols = ((self.m2m_target_field_name(), self.m2m_column_name()),)
+ if not reverse_join:
+ raise ValueError('GenericRelation only supports reverse joins.')
+
+ return join_cols
+
def m2m_db_table(self):
return self.rel.to._meta.db_table
View
11 django/core/management/validation.py
@@ -150,8 +150,15 @@ def get_validation_errors(outfile, app=None):
continue
# Make sure the related field specified by a ForeignKey is unique
- if not f.rel.to._meta.get_field(f.rel.field_name).unique:
- e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.rel.field_name, f.rel.to.__name__))
+ if len(f.foreign_related_fields) > 1:
+ has_unique_field = False
+ for rel_field in f.foreign_related_fields:
+ has_unique_field = has_unique_field or rel_field.unique
+ if not has_unique_field:
+ e.add(opts, "Field combination '%s' under model '%s' must have a unique=True constraint" % (','.join([rel_field.name for rel_field in f.foreign_related_fields]), f.rel.to.__name__))
+ else:
+ if not f.foreign_related_fields[0].unique:
+ e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.foreign_related_fields[0].name, f.rel.to.__name__))
rel_opts = f.rel.to._meta
rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name()
View
6 django/db/backends/mysql/compiler.py
@@ -17,6 +17,12 @@ def resolve_columns(self, row, fields=()):
values.append(value)
return row[:index_extra_select] + tuple(values)
+ def as_subquery_condition(self, alias, columns):
+ qn = self.quote_name_unless_alias
+ qn2 = self.connection.ops.quote_name
+ sql, params = self.as_sql()
+ return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params
+
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass
View
2  django/db/models/__init__.py
@@ -10,7 +10,7 @@
from django.db.models.fields import *
from django.db.models.fields.subclassing import SubfieldBase
from django.db.models.fields.files import FileField, ImageField
-from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
+from django.db.models.fields.related import ForeignKey, ForeignObject, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError
from django.db.models import signals
from django.utils.decorators import wraps
View
14 django/db/models/base.py
@@ -11,7 +11,7 @@
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
from django.core import validators
from django.db.models.fields import AutoField, FieldDoesNotExist
-from django.db.models.fields.related import (ManyToOneRel,
+from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
OneToOneField, add_lazy_relation)
from django.db import (router, transaction, DatabaseError,
DEFAULT_DB_ALIAS)
@@ -325,7 +325,7 @@ def __init__(self, *args, **kwargs):
# The reason for the kwargs check is that standard iterator passes in by
# args, and instantiation for iteration is 33% faster.
args_len = len(args)
- if args_len > len(self._meta.fields):
+ if args_len > len(self._meta.column_fields):
# Daft, but matches old exception sans the err msg.
raise IndexError("Number of args exceeds number of fields")
@@ -355,11 +355,12 @@ def __init__(self, *args, **kwargs):
# data-descriptor object (DeferredAttribute) without triggering its
# __get__ method.
if (field.attname not in kwargs and
- isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)):
+ isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)
+ or field.attname not in kwargs and field.column is None):
# This field will be populated on request.
continue
if kwargs:
- if isinstance(field.rel, ManyToOneRel):
+ if isinstance(field.rel, ForeignObjectRel):
try:
# Assume object instance was passed in.
rel_obj = kwargs.pop(field.name)
@@ -386,6 +387,7 @@ def __init__(self, *args, **kwargs):
val = field.get_default()
else:
val = field.get_default()
+
if is_related_object:
# If we are passed a related instance, set it using the
# field.name instead of field.attname (e.g. "user" instead of
@@ -520,7 +522,7 @@ def save(self, force_insert=False, force_update=False, using=None,
# automatically do a "update_fields" save on the loaded fields.
elif not force_insert and self._deferred and using == self._state.db:
field_names = set()
- for field in self._meta.fields:
+ for field in self._meta.column_fields:
if not field.primary_key and not hasattr(field, 'through'):
field_names.add(field.attname)
deferred_fields = [
@@ -629,7 +631,7 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False,
order_value = manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()
self._order = order_value
- fields = meta.local_fields
+ fields = meta.local_column_fields
if not pk_set:
if force_update or update_fields:
raise ValueError("Cannot force an update in save() with no primary key.")
View
442 django/db/models/fields/related.py
@@ -7,7 +7,6 @@
PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
from django.db.models.related import RelatedObject, PathInfo
from django.db.models.query import QuerySet
-from django.db.models.query_utils import QueryWrapper
from django.db.models.deletion import CASCADE
from django.utils.encoding import smart_text
from django.utils import six
@@ -92,7 +91,12 @@ def do_pending_lookups(sender, **kwargs):
#HACK
-class RelatedField(object):
+class RelatedField(Field):
+ def db_type(self, connection):
+ '''By default related field will not have a column
+ as it relates columns to another table'''
+ return None
+
def contribute_to_class(self, cls, name):
sup = super(RelatedField, self)
@@ -121,7 +125,6 @@ def set_attributes_from_rel(self):
self.name = self.name or (self.rel.to._meta.object_name.lower() + '_' + self.rel.to._meta.pk.name)
if self.verbose_name is None:
self.verbose_name = self.rel.to._meta.verbose_name
- self.rel.field_name = self.rel.field_name or self.rel.to._meta.pk.name
def do_related_class(self, other, cls):
self.set_attributes_from_rel()
@@ -129,94 +132,6 @@ def do_related_class(self, other, cls):
if not cls._meta.abstract:
self.contribute_to_related_class(other, self.related)
- def get_prep_lookup(self, lookup_type, value):
- if hasattr(value, 'prepare'):
- return value.prepare()
- if hasattr(value, '_prepare'):
- return value._prepare()
- # FIXME: lt and gt are explicitly allowed to make
- # get_(next/prev)_by_date work; other lookups are not allowed since that
- # gets messy pretty quick. This is a good candidate for some refactoring
- # in the future.
- if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']:
- return self._pk_trace(value, 'get_prep_lookup', lookup_type)
- if lookup_type in ('range', 'in'):
- return [self._pk_trace(v, 'get_prep_lookup', lookup_type) for v in value]
- elif lookup_type == 'isnull':
- return []
- raise TypeError("Related Field has invalid lookup: %s" % lookup_type)
-
- def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
- if not prepared:
- value = self.get_prep_lookup(lookup_type, value)
- if hasattr(value, 'get_compiler'):
- value = value.get_compiler(connection=connection)
- if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
- # If the value has a relabel_aliases method, it will need to
- # be invoked before the final SQL is evaluated
- if hasattr(value, 'relabel_aliases'):
- return value
- if hasattr(value, 'as_sql'):
- sql, params = value.as_sql()
- else:
- sql, params = value._as_sql(connection=connection)
- return QueryWrapper(('(%s)' % sql), params)
-
- # FIXME: lt and gt are explicitly allowed to make
- # get_(next/prev)_by_date work; other lookups are not allowed since that
- # gets messy pretty quick. This is a good candidate for some refactoring
- # in the future.
- if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']:
- return [self._pk_trace(value, 'get_db_prep_lookup', lookup_type,
- connection=connection, prepared=prepared)]
- if lookup_type in ('range', 'in'):
- return [self._pk_trace(v, 'get_db_prep_lookup', lookup_type,
- connection=connection, prepared=prepared)
- for v in value]
- elif lookup_type == 'isnull':
- return []
- raise TypeError("Related Field has invalid lookup: %s" % lookup_type)
-
- def _pk_trace(self, value, prep_func, lookup_type, **kwargs):
- # Value may be a primary key, or an object held in a relation.
- # If it is an object, then we need to get the primary key value for
- # that object. In certain conditions (especially one-to-one relations),
- # the primary key may itself be an object - so we need to keep drilling
- # down until we hit a value that can be used for a comparison.
- v = value
-
- # In the case of an FK to 'self', this check allows to_field to be used
- # for both forwards and reverse lookups across the FK. (For normal FKs,
- # it's only relevant for forward lookups).
- if isinstance(v, self.rel.to):
- field_name = getattr(self.rel, "field_name", None)
- else:
- field_name = None
- try:
- while True:
- if field_name is None:
- field_name = v._meta.pk.name
- v = getattr(v, field_name)
- field_name = None
- except AttributeError:
- pass
- except exceptions.ObjectDoesNotExist:
- v = None
-
- field = self
- while field.rel:
- if hasattr(field.rel, 'field_name'):
- field = field.rel.to._meta.get_field(field.rel.field_name)
- else:
- field = field.rel.to._meta.pk
-
- if lookup_type in ('range', 'in'):
- v = [v]
- v = getattr(field, prep_func)(lookup_type, v, **kwargs)
- if isinstance(v, list):
- v = v[0]
- return v
-
def related_query_name(self):
# This method defines the name that can be used to identify this
# related object in a table-spanning query. It uses the lower-cased
@@ -246,8 +161,8 @@ def get_prefetch_query_set(self, instances):
rel_obj_attr = attrgetter(self.related.field.attname)
instance_attr = lambda obj: obj._get_pk_val()
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
- params = {'%s__pk__in' % self.related.field.name: list(instances_dict)}
- qs = self.get_query_set(instance=instances[0]).filter(**params)
+ query = {'%s__in' % self.related.field.name: instances}
+ qs = self.get_query_set(instance=instances[0]).filter(**query)
# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
rel_obj_cache_name = self.related.field.get_cache_name()
@@ -266,7 +181,9 @@ def __get__(self, instance, instance_type=None):
if related_pk is None:
rel_obj = None
else:
- params = {'%s__pk' % self.related.field.name: related_pk}
+ params = {}
+ for lh_field, rh_field in self.related.field.related_fields:
+ params['%s__%s' % (self.related.field.name, rh_field.name)] = getattr(instance, rh_field.attname)
try:
rel_obj = self.get_query_set(instance=instance).get(**params)
except self.related.model.DoesNotExist:
@@ -306,13 +223,14 @@ def __set__(self, instance, value):
raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' %
(value, instance._state.db, value._state.db))
- related_pk = getattr(instance, self.related.field.rel.get_related_field().attname)
- if related_pk is None:
+ related_pk = tuple([getattr(instance, field.attname) for field in self.related.field.foreign_related_fields])
+ if None in related_pk:
raise ValueError('Cannot assign "%r": "%s" instance isn\'t saved in the database.' %
(value, instance._meta.object_name))
# Set the value of the related field to the value of the related object's related field
- setattr(value, self.related.field.attname, related_pk)
+ for index, field in enumerate(self.related.field.native_related_fields):
+ setattr(value, field.attname, related_pk[index])
# Since we already know what the related object is, seed the related
# object caches now, too. This avoids another db hit if you get the
@@ -345,15 +263,11 @@ def get_query_set(self, **db_hints):
return QuerySet(self.field.rel.to).using(db)
def get_prefetch_query_set(self, instances):
- other_field = self.field.rel.get_related_field()
- rel_obj_attr = attrgetter(other_field.attname)
- instance_attr = attrgetter(self.field.attname)
+ rel_obj_attr = self.field.get_foreign_related_value
+ instance_attr = self.field.get_native_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
- if other_field.rel:
- params = {'%s__pk__in' % self.field.rel.field_name: list(instances_dict)}
- else:
- params = {'%s__in' % self.field.rel.field_name: list(instances_dict)}
- qs = self.get_query_set(instance=instances[0]).filter(**params)
+ query = {'%s__in' % self.field.related_query_name(): instances}
+ qs = self.get_query_set(instance=instances[0]).filter(**query)
# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
if not self.field.rel.multiple:
@@ -369,15 +283,11 @@ def __get__(self, instance, instance_type=None):
try:
rel_obj = getattr(instance, self.cache_name)
except AttributeError:
- val = getattr(instance, self.field.attname)
- if val is None:
+ val = self.field.get_native_related_value(instance)
+ if None in val:
rel_obj = None
else:
- other_field = self.field.rel.get_related_field()
- if other_field.rel:
- params = {'%s__%s' % (self.field.rel.field_name, other_field.rel.field_name): val}
- else:
- params = {'%s__exact' % self.field.rel.field_name: val}
+ params = {rh_field.attname: getattr(instance, lh_field.attname) for lh_field, rh_field in self.field.related_fields}
qs = self.get_query_set(instance=instance)
# Assuming the database enforces foreign keys, this won't fail.
rel_obj = qs.get(**params)
@@ -432,11 +342,11 @@ def __set__(self, instance, value):
setattr(related, self.field.related.get_cache_name(), None)
# Set the value of the related field
- try:
- val = getattr(value, self.field.rel.get_related_field().attname)
- except AttributeError:
- val = None
- setattr(instance, self.field.attname, val)
+ for lh_field, rh_field in self.field.related_fields:
+ try:
+ setattr(instance, lh_field.attname, getattr(value, rh_field.attname))
+ except AttributeError:
+ setattr(instance, lh_field.attname, None)
# Since we already know what the related object is, seed the related
# object caches now, too. This avoids another db hit if you get the
@@ -479,15 +389,12 @@ def related_manager_cls(self):
superclass = self.related.model._default_manager.__class__
rel_field = self.related.field
rel_model = self.related.model
- attname = rel_field.rel.get_related_field().attname
class RelatedManager(superclass):
def __init__(self, instance):
super(RelatedManager, self).__init__()
self.instance = instance
- self.core_filters = {
- '%s__%s' % (rel_field.name, attname): getattr(instance, attname)
- }
+ self.core_filters= {'%s__exact' % rel_field.name: instance}
self.model = rel_model
def get_query_set(self):
@@ -496,17 +403,18 @@ def get_query_set(self):
except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance)
qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
- if getattr(self.instance, attname) is None:
- return qs.none()
+ for field in rel_field.foreign_related_fields:
+ if getattr(self.instance, field.attname) is None:
+ return qs.none()
qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
return qs
def get_prefetch_query_set(self, instances):
- rel_obj_attr = attrgetter(rel_field.attname)
- instance_attr = attrgetter(attname)
+ rel_obj_attr = rel_field.get_native_related_value
+ instance_attr = rel_field.get_foreign_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
db = self._db or router.db_for_read(self.model, instance=instances[0])
- query = {'%s__%s__in' % (rel_field.name, attname): list(instances_dict)}
+ query = {'%s__in' % rel_field.name: instances}
qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
# Since we just bypassed this class' get_query_set(), we must manage
# the reverse relation manually.
@@ -541,10 +449,10 @@ def get_or_create(self, **kwargs):
# remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel_field.null:
def remove(self, *objs):
- val = getattr(self.instance, attname)
+ val = rel_field.get_foreign_related_value(self.instance)
for obj in objs:
# Is obj actually part of this descriptor set?
- if getattr(obj, rel_field.attname) == val:
+ if rel_field.get_native_related_value(obj) == val:
setattr(obj, rel_field.name, None)
obj.save()
else:
@@ -568,16 +476,26 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical
super(ManyRelatedManager, self).__init__()
self.model = model
self.query_field_name = query_field_name
- self.core_filters = {'%s__pk' % query_field_name: instance._get_pk_val()}
+
+ source_field = through._meta.get_field(source_field_name)
+ source_related_fields = source_field.related_fields
+
+ self.core_filters = {}
+ for lh_field, rh_field in source_related_fields:
+ self.core_filters['%s__%s' % (query_field_name, rh_field.name)] = getattr(instance, rh_field.attname)
+
self.instance = instance
self.symmetrical = symmetrical
+ self.source_field = source_field
self.source_field_name = source_field_name
self.target_field_name = target_field_name
self.reverse = reverse
self.through = through
self.prefetch_cache_name = prefetch_cache_name
- self._fk_val = self._get_fk_val(instance, source_field_name)
- if self._fk_val is None:
+ self.related_val = source_field.get_foreign_related_value(instance)
+ # Used for single column related auto created models
+ self._fk_val = self.related_val[0]
+ if None in self.related_val:
raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' %
(instance, source_field_name))
@@ -613,8 +531,7 @@ def get_prefetch_query_set(self, instances):
instance = instances[0]
from django.db import connections
db = self._db or router.db_for_read(instance.__class__, instance=instance)
- query = {'%s__pk__in' % self.query_field_name:
- set(obj._get_pk_val() for obj in instances)}
+ query = {'%s__in' % self.query_field_name: instances}
qs = super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query)
# M2M: need to annotate the query in order to get the primary model
@@ -625,16 +542,16 @@ def get_prefetch_query_set(self, instances):
# For non-autocreated 'through' models, can't assume we are
# dealing with PK values.
fk = self.through._meta.get_field(self.source_field_name)
+ source_cols = [f.column for f in fk.native_related_fields]
source_col = fk.column
join_table = self.through._meta.db_table
connection = connections[db]
qn = connection.ops.quote_name
- qs = qs.extra(select={'_prefetch_related_val':
- '%s.%s' % (qn(join_table), qn(source_col))})
- select_attname = fk.rel.get_related_field().get_attname()
+ qs = qs.extra(select={'_prefetch_related_val_%s' % f.attname:
+ '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.native_related_fields})
return (qs,
- attrgetter('_prefetch_related_val'),
- attrgetter(select_attname),
+ lambda result: tuple([getattr(result, '_prefetch_related_val_%s' % f.attname) for f in fk.native_related_fields]),
+ lambda inst: tuple([getattr(inst, f.attname) for f in fk.foreign_related_fields]),
False,
self.prefetch_cache_name)
@@ -786,7 +703,7 @@ def _clear_items(self, source_field_name):
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db)
self.through._default_manager.using(db).filter(**{
- source_field_name: self._fk_val
+ source_field_name: self.related_val
}).delete()
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are clearing the
@@ -906,19 +823,17 @@ def __set__(self, instance, value):
manager.clear()
manager.add(*value)
-
-class ManyToOneRel(object):
- def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
+class ForeignObjectRel(object):
+ def __init__(self, to, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None):
try:
to._meta
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
assert isinstance(to, six.string_types), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT
- self.to, self.field_name = to, field_name
+
+ self.to = to
self.related_name = related_name
- if limit_choices_to is None:
- limit_choices_to = {}
- self.limit_choices_to = limit_choices_to
+ self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
self.multiple = True
self.parent_link = parent_link
self.on_delete = on_delete
@@ -927,6 +842,13 @@ def is_hidden(self):
"Should the related object be hidden?"
return self.related_name and self.related_name[-1] == '+'
+class ManyToOneRel(ForeignObjectRel):
+ def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
+ parent_link=False, on_delete=None):
+ super(ManyToOneRel, self).__init__(to, related_name=related_name, limit_choices_to=limit_choices_to,
+ parent_link=parent_link, on_delete=on_delete)
+ self.field_name = field_name
+
def get_related_field(self):
"""
Returns the Field in the 'to' object to which this relationship is
@@ -974,7 +896,160 @@ def get_related_field(self):
return self.to._meta.pk
-class ForeignKey(RelatedField, Field):
+class ForeignObject(RelatedField):
+ def __init__(self, to, from_fields, to_fields, **kwargs):
+ self.from_fields = from_fields
+ self.to_fields = to_fields
+
+ if 'rel' not in kwargs:
+ kwargs['rel'] = ForeignObjectRel(to,
+ related_name=kwargs.pop('related_name', None),
+ limit_choices_to=kwargs.pop('limit_choices_to', None),
+ parent_link=kwargs.pop('parent_link', False),
+ on_delete=kwargs.pop('on_delete', CASCADE),
+ )
+
+ kwargs['verbose_name'] = kwargs.get('verbose_name', None)
+
+ super(ForeignObject, self).__init__(**kwargs)
+
+ def resolve_related_fields(self):
+ if len(self.from_fields) < 1 or len(self.from_fields) != len(self.to_fields):
+ raise ValueError('Foreign Object from and to fields must be the same non-zero length')
+
+ related_fields = []
+ for index in range(len(self.from_fields)):
+ from_field_name = self.from_fields[index]
+ to_field_name = self.to_fields[index]
+
+ from_field = self if from_field_name == 'self' else self.opts.get_field_by_name(from_field_name)[0]
+ to_field = self.rel.to._meta.pk if to_field_name is None else self.rel.to._meta.get_field_by_name(to_field_name)[0]
+
+ related_fields.append((from_field, to_field))
+
+ return related_fields
+
+ @property
+ def related_fields(self):
+ if not hasattr(self, '_related_fields'):
+ self._related_fields = self.resolve_related_fields()
+ return self._related_fields
+
+ @property
+ def reverse_related_fields(self):
+ return [(rhs_field, lhs_field) for lhs_field, rhs_field in self.related_fields]
+
+ @property
+ def native_related_fields(self):
+ return tuple([lhs_field for lhs_field, rhs_field in self.related_fields])
+
+ @property
+ def foreign_related_fields(self):
+ return tuple([rhs_field for lhs_field, rhs_field in self.related_fields])
+
+ def get_native_related_value(self, instance):
+ return self.get_instance_value_for_fields(instance, self.native_related_fields)
+
+ def get_foreign_related_value(self, instance):
+ return self.get_instance_value_for_fields(instance, self.foreign_related_fields)
+
+ @staticmethod
+ def get_instance_value_for_fields(instance, fields):
+ return tuple([getattr(instance, field.attname) for field in fields])
+
+ def get_attname_column(self):
+ attname, column = super(ForeignObject, self).get_attname_column()
+ return attname, None
+
+ def get_joining_columns(self, reverse_join=False):
+ source = self.reverse_related_fields if reverse_join else self.related_fields
+ return tuple([(lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source])
+
+ def get_path_info(self):
+ """
+ Get path from this field to the related model.
+ """
+ opts = self.rel.to._meta
+ from_opts = self.model._meta
+ return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)]
+
+ def get_reverse_path_info(self):
+ """
+ Get path from the related model to this field's model.
+ """
+ opts = self.model._meta
+ from_opts = self.rel.to._meta
+
+ pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)]
+ return pathinfos
+
+ def get_lookup_constraint(self, constraint_class, alias, columns, targets, lookup_type, raw_value):
+ from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
+ root_constraint = constraint_class()
+
+ def get_normalized_value(value):
+
+ from django.db.models import Model
+ if isinstance(value, Model):
+ value_list = []
+ for target in targets:
+ field = target
+ # Account for one-to-one relations when sent a different model
+ while not isinstance(value, field.model):
+ field = field.rel.to._meta.get_field(field.rel.field_name)
+ value_list.append(getattr(value, field.attname))
+ return tuple(value_list)
+ elif not isinstance(value, tuple):
+ return (value,)
+ return value
+
+ is_multicolumn = len(self.related_fields) > 1
+ if hasattr(raw_value, '_as_sql') or \
+ hasattr(raw_value, 'get_compiler'):
+ root_constraint.add(SubqueryConstraint(alias, columns, [target.name for target in targets], raw_value), AND)
+ elif lookup_type == 'isnull':
+ root_constraint.add((Constraint(alias, columns[0], None), lookup_type, raw_value), AND)
+ elif lookup_type == 'exact' or lookup_type in ['gt', 'lt', 'gte', 'lte'] and not is_multicolumn:
+ value = get_normalized_value(raw_value)
+ for index, field in enumerate(targets):
+ root_constraint.add((Constraint(alias, columns[index], field), lookup_type, value[index]), AND)
+ elif lookup_type in ['range', 'in'] and not is_multicolumn:
+ values = [get_normalized_value(value) for value in raw_value]
+ value = [val[0] for val in values]
+ root_constraint.add((Constraint(alias, columns[0], targets[0]), lookup_type, value), AND)
+ elif lookup_type == 'in':
+ values = [get_normalized_value(value) for value in raw_value]
+ for value in values:
+ value_constraint = constraint_class()
+ for index, field in enumerate(targets):
+ value_constraint.add((Constraint(alias, columns[index], field), 'exact', value[index]), AND)
+ root_constraint.add(value_constraint, OR)
+ else:
+ raise TypeError('Related Field has invalid lookup: %s' % lookup_type)
+
+ return root_constraint
+
+ @property
+ def attnames(self):
+ return tuple([field.attname for field in self.native_related_fields])
+
+ def get_defaults(self):
+ return tuple([field.get_default() for field in self.native_related_fields])
+
+ def contribute_to_class(self, cls, name):
+ super(ForeignObject, self).contribute_to_class(cls, name)
+ setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
+
+ def contribute_to_related_class(self, cls, related):
+ # Internal FK's - i.e., those with a related name ending with '+' -
+ # and swapped models don't get a related descriptor.
+ if not self.rel.is_hidden() and not related.model._meta.swapped:
+ setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
+ if self.rel.limit_choices_to:
+ cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to)
+
+
+class ForeignKey(ForeignObject):
empty_strings_allowed = False
default_error_messages = {
'invalid': _('Model %(model)s with pk %(pk)r does not exist.')
@@ -992,7 +1067,6 @@ def __init__(self, to, to_field=None, rel_class=ManyToOneRel, **kwargs):
# the to_field during FK construction. It won't be guaranteed to
# be correct until contribute_to_class is called. Refs #12190.
to_field = to_field or (to._meta.pk and to._meta.pk.name)
- kwargs['verbose_name'] = kwargs.get('verbose_name', None)
if 'db_index' not in kwargs:
kwargs['db_index'] = True
@@ -1003,32 +1077,27 @@ def __init__(self, to, to_field=None, rel_class=ManyToOneRel, **kwargs):
parent_link=kwargs.pop('parent_link', False),
on_delete=kwargs.pop('on_delete', CASCADE),
)
- Field.__init__(self, **kwargs)
+ super(ForeignKey, self).__init__(to, ['self'], [to_field], **kwargs)
- def get_path_info(self):
- """
- Get path from this field to the related model.
- """
- opts = self.rel.to._meta
- target = self.rel.get_related_field()
- from_opts = self.model._meta
- return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self
+ @property
+ def related_field(self):
+ return self.foreign_related_fields[0]
def get_reverse_path_info(self):
"""
Get path from the related model to this field's model.
"""
opts = self.model._meta
- from_field = self.rel.get_related_field()
- from_opts = from_field.model._meta
- pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)]
- if from_field.model is self.model:
+ from_opts = self.rel.to._meta
+
+ if self.rel.to is self.model:
# Recursive foreign key to self.
- target = opts.get_field_by_name(
- self.rel.field_name)[0]
+ target = self.related_field
else:
target = opts.pk
- return pathinfos, opts, target, self
+
+ pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)]
+ return pathinfos
def validate(self, value, model_instance):
if self.rel.parent_link:
@@ -1049,21 +1118,26 @@ def validate(self, value, model_instance):
def get_attname(self):
return '%s_id' % self.name
+ def get_attname_column(self):
+ attname = self.get_attname()
+ column = self.db_column or attname
+ return attname, column
+
def get_validator_unique_lookup_type(self):
- return '%s__%s__exact' % (self.name, self.rel.get_related_field().name)
+ return '%s__%s__exact' % (self.name, self.related_field.name)
def get_default(self):
"Here we check if the default value is an object and return the to_field if so."
field_default = super(ForeignKey, self).get_default()
if isinstance(field_default, self.rel.to):
- return getattr(field_default, self.rel.get_related_field().attname)
+ return getattr(field_default, self.related_field.attname)
return field_default
def get_db_prep_save(self, value, connection):
if value == '' or value == None:
return None
else:
- return self.rel.get_related_field().get_db_prep_save(value,
+ return self.related_field.get_db_prep_save(value,
connection=connection)
def value_to_string(self, obj):
@@ -1076,19 +1150,10 @@ def value_to_string(self, obj):
choice_list = self.get_choices_default()
if len(choice_list) == 2:
return smart_text(choice_list[1][0])
- return Field.value_to_string(self, obj)
-
- def contribute_to_class(self, cls, name):
- super(ForeignKey, self).contribute_to_class(cls, name)
- setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
+ return super(ForeignKey, self).value_to_string(obj)
def contribute_to_related_class(self, cls, related):
- # Internal FK's - i.e., those with a related name ending with '+' -
- # and swapped models don't get a related descriptor.
- if not self.rel.is_hidden() and not related.model._meta.swapped:
- setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
- if self.rel.limit_choices_to:
- cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to)
+ super(ForeignKey, self).contribute_to_related_class(cls, related)
if self.rel.field_name is None:
self.rel.field_name = cls._meta.pk.name
@@ -1113,7 +1178,7 @@ def db_type(self, connection):
# in which case the column type is simply that of an IntegerField.
# If the database needs similar types for key fields however, the only
# thing we can do is making AutoField an IntegerField.
- rel_field = self.rel.get_related_field()
+ rel_field = self.related_field
if (isinstance(rel_field, AutoField) or
(not connection.features.related_fields_match_type and
isinstance(rel_field, (PositiveIntegerField,
@@ -1195,7 +1260,7 @@ def set_managed(field, model, cls):
})
-class ManyToManyField(RelatedField, Field):
+class ManyToManyField(RelatedField):
description = _("Many-to-many relationship")
def __init__(self, to, **kwargs):
@@ -1219,7 +1284,7 @@ def __init__(self, to, **kwargs):
if kwargs['rel'].through is not None:
assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used."
- Field.__init__(self, **kwargs)
+ super(ManyToManyField, self).__init__(**kwargs)
msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.')
self.help_text = string_concat(self.help_text, ' ', msg)
@@ -1233,14 +1298,14 @@ def _get_path_info(self, direct=False):
linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0]
linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
if direct:
- join1infos, _, _, _ = linkfield1.get_reverse_path_info()
- join2infos, opts, target, final_field = linkfield2.get_path_info()
+ join1infos = linkfield1.get_reverse_path_info()
+ join2infos = linkfield2.get_path_info()
else:
- join1infos, _, _, _ = linkfield2.get_reverse_path_info()
- join2infos, opts, target, final_field = linkfield1.get_path_info()
+ join1infos = linkfield2.get_reverse_path_info()
+ join2infos = linkfield1.get_path_info()
pathinfos.extend(join1infos)
pathinfos.extend(join2infos)
- return pathinfos, opts, target, final_field
+ return pathinfos
def get_path_info(self):
return self._get_path_info(direct=True)
@@ -1383,8 +1448,3 @@ def formfield(self, **kwargs):
initial = initial()
defaults['initial'] = [i._get_pk_val() for i in initial]
return super(ManyToManyField, self).formfield(**defaults)
-
- def db_type(self, connection):
- # A ManyToManyField is not represented by a single column,
- # so return None.
- return None
View
11 django/db/models/options.py
@@ -250,6 +250,14 @@ def _fields(self):
return self._field_name_cache
fields = property(_fields)
+ @property
+ def column_fields(self):
+ return [f for f in self.fields if f.column is not None]
+
+ @property
+ def local_column_fields(self):
+ return [f for f in self.local_fields if f.column is not None]
+
def get_fields_with_model(self):
"""
Returns a sequence of (field, model) pairs for all fields. The "model"
@@ -262,6 +270,9 @@ def get_fields_with_model(self):
self._fill_fields_cache()
return self._field_cache
+ def get_column_fields_with_model(self):
+ return [(field, model) for field, model in self.get_fields_with_model() if field.column is not None]
+
def _fill_fields_cache(self):
cache = []
for parent in self.parents:
View
16 django/db/models/query.py
@@ -258,13 +258,13 @@ def iterator(self):
only_load = self.query.get_loaded_field_names()
if not fill_cache:
- fields = self.model._meta.fields
+ fields = self.model._meta.column_fields
load_fields = []
# If only/defer clauses have been specified,
# build the list of fields that are to be loaded.
if only_load:
- for field, model in self.model._meta.get_fields_with_model():
+ for field, model in self.model._meta.get_column_fields_with_model():
if model is None:
model = self.model
try:
@@ -966,7 +966,7 @@ def _setup_aggregate_query(self, aggregates):
"""
opts = self.model._meta
if self.query.group_by is None:
- field_names = [f.attname for f in opts.fields]
+ field_names = [f.attname for f in opts.column_fields]
self.query.add_fields(field_names, False)
self.query.set_group_by()
@@ -1244,7 +1244,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
skip = set()
init_list = []
# Build the list of fields that *haven't* been requested
- for field, model in klass._meta.get_fields_with_model():
+ for field, model in klass._meta.get_column_fields_with_model():
if field.name not in load_fields:
skip.add(field.attname)
elif from_parent and issubclass(from_parent, model.__class__):
@@ -1263,22 +1263,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
else:
# Load all fields on klass
- field_count = len(klass._meta.fields)
+ field_count = len(klass._meta.column_fields)
# Check if we need to skip some parent fields.
- if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
+ if from_parent and len(klass._meta.local_column_fields) != len(klass._meta.column_fields):
# Only load those fields which haven't been already loaded into
# 'from_parent'.
non_seen_models = [p for p in klass._meta.get_parent_list()
if not issubclass(from_parent, p)]
# Load local fields, too...
non_seen_models.append(klass)
- field_names = [f.attname for f in klass._meta.fields
+ field_names = [f.attname for f in klass._meta.column_fields
if f.model in non_seen_models]
field_count = len(field_names)
# Try to avoid populating field_names variable for perfomance reasons.
# If field_names variable is set, we use **kwargs based model init
# which is slower than normal init.
- if field_count == len(klass._meta.fields):
+ if field_count == len(klass._meta.column_fields):
field_names = ()
restricted = requested is not None
View
2  django/db/models/related.py
@@ -7,7 +7,7 @@
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple('PathInfo',
- 'from_field to_field from_opts to_opts join_field '
+ 'from_opts to_opts target_fields join_field '
'm2m direct')
class BoundRelatedObject(object):
View
96 django/db/models/sql/compiler.py
@@ -4,7 +4,7 @@
from django.db import transaction
from django.db.backends.util import truncate_name
from django.db.models.constants import LOOKUP_SEP
-from django.db.models.query_utils import select_related_descend
+from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
@@ -30,7 +30,7 @@ def pre_sql_setup(self):
# cleaned. We are not using a clone() of the query here.
"""
if not self.query.tables:
- self.query.join((None, self.query.model._meta.db_table, None, None))
+ self.query.join((None, self.query.model._meta.db_table, None))
if (not self.query.select and self.query.default_cols and not
self.query.included_inherited_models):
self.query.setup_inherited_models()
@@ -262,7 +262,7 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
seen = self.query.included_inherited_models.copy()
if start_alias:
seen[None] = start_alias
- for field, model in opts.get_fields_with_model():
+ for field, model in opts.get_column_fields_with_model():
if from_parent and model is not None and issubclass(from_parent, model):
# Avoid loading data for already loaded parents.
continue
@@ -302,9 +302,10 @@ def get_distinct(self):
for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP)
- field, col, alias, _, _ = self._setup_joins(parts, opts, None)
- col, alias = self._final_join_removal(col, alias)
- result.append("%s.%s" % (qn(alias), qn2(col)))
+ field, cols, alias, _, _ = self._setup_joins(parts, opts, None)
+ cols, alias = self._final_join_removal(cols, alias)
+ for col in cols:
+ result.append("%s.%s" % (qn(alias), qn2(col)))
return result
@@ -375,15 +376,16 @@ def get_ordering(self):
elif get_order_dir(field)[0] not in self.query.extra_select:
# 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc.
- for table, col, order in self.find_ordering_name(field,
+ for table, cols, order in self.find_ordering_name(field,
self.query.model._meta, default_order=asc):
- if (table, col) not in processed_pairs:
- elt = '%s.%s' % (qn(table), qn2(col))
- processed_pairs.add((table, col))
- if distinct and elt not in select_aliases:
- ordering_aliases.append(elt)
- result.append('%s %s' % (elt, order))
- group_by.append((elt, []))
+ for col in cols:
+ if (table, col) not in processed_pairs:
+ elt = '%s.%s' % (qn(table), qn2(col))
+ processed_pairs.add((table, col))
+ if distinct and elt not in select_aliases:
+ ordering_aliases.append(elt)
+ result.append('%s %s' % (elt, order))
+ group_by.append((elt, []))
else:
elt = qn2(col)
if distinct and col not in select_aliases:
@@ -402,7 +404,7 @@ def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
"""
name, order = get_order_dir(name, default_order)
pieces = name.split(LOOKUP_SEP)
- field, col, alias, joins, opts = self._setup_joins(pieces, opts, alias)
+ field, cols, alias, joins, opts = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model,
# append the default ordering for that model.
@@ -420,8 +422,8 @@ def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
results.extend(self.find_ordering_name(item, opts, alias,
order, already_seen))
return results
- col, alias = self._final_join_removal(col, alias)
- return [(alias, col, order)]
+ cols, alias = self._final_join_removal(cols, alias)
+ return [(alias, cols, order)]
def _setup_joins(self, pieces, opts, alias):
"""
@@ -434,13 +436,13 @@ def _setup_joins(self, pieces, opts, alias):
"""
if not alias:
alias = self.query.get_initial_alias()
- field, target, opts, joins, _ = self.query.setup_joins(
+ field, targets, opts, joins, _ = self.query.setup_joins(
pieces, opts, alias)
# We will later on need to promote those joins that were added to the
# query afresh above.
joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2]
alias = joins[-1]
- col = target.column
+ cols = [target.column for target in targets]
if not field.rel:
# To avoid inadvertent trimming of a necessary alias, use the
# refcount to show that we are referencing a non-relation field on
@@ -451,9 +453,9 @@ def _setup_joins(self, pieces, opts, alias):
# Ordering or distinct must not affect the returned set, and INNER
# JOINS for nullable fields could do this.
self.query.promote_joins(joins_to_promote)
- return field, col, alias, joins, opts
+ return field, cols, alias, joins, opts
- def _final_join_removal(self, col, alias):
+ def _final_join_removal(self, cols, alias):
"""
A helper method for get_distinct and get_ordering. This method will
trim extra not-needed joins from the tail of the join chain.
@@ -465,12 +467,14 @@ def _final_join_removal(self, col, alias):
if alias:
while 1:
join = self.query.alias_map[alias]
- if col != join.rhs_join_col:
+ lhs_cols, rhs_cols = zip(*[(lhs_col, rhs_col) for lhs_col, rhs_col in join.join_cols])
+ if set(cols) != set(rhs_cols):
break
+
+ cols = [lhs_cols[rhs_cols.index(col)] for col in cols]
self.query.unref_alias(alias)
alias = join.lhs_alias
- col = join.lhs_join_col
- return col, alias
+ return cols, alias
def get_from_clause(self):
"""
@@ -492,7 +496,7 @@ def get_from_clause(self):
if not self.query.alias_refcount[alias]:
continue
try:
- name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias]
+ name, alias, join_type, lhs, join_cols, _, join_field = self.query.alias_map[alias]
except KeyError:
# Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them.
@@ -505,9 +509,14 @@ def get_from_clause(self):
from_params.extend(extra_params)
else:
extra_cond = ""
- result.append('%s %s%s ON (%s.%s = %s.%s%s)' %
- (join_type, qn(name), alias_str, qn(lhs),
- qn2(lhs_col), qn(alias), qn2(col), extra_cond))
+ result.append('%s %s%s ON ('
+ % (join_type, qn(name), alias_str))
+ for index, (lhs_col, rhs_col) in enumerate(join_cols):
+ if index != 0:
+ result.append(' AND ')
+ result.append('%s.%s = %s.%s' %
+ (qn(lhs), qn2(lhs_col), qn(alias), qn2(rhs_col)))
+ result.append('%s)' % extra_cond)
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
@@ -533,7 +542,7 @@ def get_grouping(self, ordering_group_by):
select_cols = self.query.select + self.query.related_select_cols
# Just the column, not the fields.
select_cols = [s[0] for s in select_cols]
- if (len(self.query.model._meta.fields) == len(self.query.select)
+ if (len(self.query.model._meta.column_fields) == len(self.query.select)
and self.connection.features.allows_group_by_pk):
self.query.group_by = [
(self.query.model._meta.db_table, self.query.model._meta.pk.column)
@@ -609,14 +618,13 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
table = f.rel.to._meta.db_table
promote = nullable or f.null
alias = self.query.join_parent_model(opts, model, root_alias, {})
-
- alias = self.query.join((alias, table, f.column,
- f.rel.get_related_field().column),
+ join_cols = f.get_joining_columns()
+ alias = self.query.join((alias, table, join_cols),
promote=promote, join_field=f)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend(
- SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields))
+ SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.column_fields))
if restricted:
next = requested.get(f.name, {})
else:
@@ -639,7 +647,7 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
alias = self.query.join_parent_model(opts, f.rel.to, root_alias, {})
table = model._meta.db_table
alias = self.query.join(
- (alias, table, f.rel.get_related_field().column, f.column),
+ (alias, table, f.get_joining_columns(reverse_join=True)),
promote=True, join_field=f
)
from_parent = (opts.model if issubclass(model, opts.model)
@@ -648,7 +656,7 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
opts=model._meta, as_pairs=True, from_parent=from_parent)
self.query.related_select_cols.extend(
SelectInfo(col, field) for col, field
- in zip(columns, model._meta.fields))
+ in zip(columns, model._meta.column_fields))
next = requested.get(f.related_query_name(), {})
# Use True here because we are looking at the _reverse_ side of
# the relation, which is always nullable.
@@ -697,7 +705,7 @@ def results_iter(self):
if self.query.select:
fields = [f.field for f in self.query.select]
else:
- fields = self.query.model._meta.fields
+ fields = self.query.model._meta.column_fields
fields = fields + [f.field for f in self.query.related_select_cols]
# If the field was deferred, exclude it from being passed
@@ -767,6 +775,22 @@ def execute_sql(self, result_type=MULTI):
return list(result)
return result
+ def as_subquery_condition(self, alias, columns):
+ qn = self.quote_name_unless_alias
+ qn2 = self.connection.ops.quote_name
+ if len(columns) == 1:
+ sql, params = self.as_sql()
+ return '%s.%s IN (%s)' % (qn(alias), qn2(columns[0]), sql), params
+
+ for index, select_col in enumerate(self.query.select):
+ lhs = '%s.%s' % (qn(select_col.col[0]), qn2(select_col.col[1]))
+ rhs = '%s.%s' % (qn(alias), qn2(columns[index]))
+ self.query.where.add(
+ QueryWrapper('%s = %s' % (lhs, rhs), []), 'AND')
+
+ sql, params = self.as_sql()
+ return 'EXISTS (%s)' % sql, params
+
class SQLInsertCompiler(SQLCompiler):
def placeholder(self, field, val):
View
2  django/db/models/sql/constants.py
@@ -24,7 +24,7 @@
# dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo',
'table_name rhs_alias join_type lhs_alias '
- 'lhs_join_col rhs_join_col nullable join_field')
+ 'join_cols nullable join_field')
# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')
View
7 django/db/models/sql/expressions.py
@@ -49,13 +49,14 @@ def prepare_leaf(self, node, query, allow_joins):
self.cols.append((node, query.aggregate_select[node.name]))
else:
try:
- field, source, opts, join_list, path = query.setup_joins(
+ field, sources, opts, join_list, path = query.setup_joins(
field_list, query.get_meta(),
query.get_initial_alias(), self.reuse)
- col, _, join_list = query.trim_joins(source, join_list, path)
+ cols, _, join_list = query.trim_joins(sources, join_list, path)
if self.reuse is not None:
self.reuse.update(join_list)
- self.cols.append((node, (join_list[-1], col)))
+ for col in cols:
+ self.cols.append((node, (join_list[-1], col)))
except FieldDoesNotExist:
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (self.name,
View
120 django/db/models/sql/query.py
@@ -501,13 +501,13 @@ def combine(self, rhs, connector):
# Now, add the joins from rhs query into the new query (skipping base
# table).
for alias in rhs.tables[1:]:
- table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias]
+ table, _, join_type, lhs, join_cols, nullable, join_field = rhs.alias_map[alias]
promote = (join_type == self.LOUTER)
# If the left side of the join was already relabeled, use the
# updated alias.
lhs = change_map.get(lhs, lhs)
new_alias = self.join(
- (lhs, table, lhs_col, col), reuse=reuse, promote=promote,
+ (lhs, table, join_cols), reuse=reuse, promote=promote,
outer_if_first=not conjunction, nullable=nullable,
join_field=join_field)
# We can't reuse the same join again in the query. If we have two
@@ -746,7 +746,7 @@ def promote_joins(self, aliases, unconditional=False):
aliases = list(aliases)
while aliases:
alias = aliases.pop(0)
- if self.alias_map[alias].rhs_join_col is None:
+ if self.alias_map[alias].join_cols[0][1] is None:
# This is the base table (first FROM entry) - this table
# isn't really joined at all in the query, so we should not
# alter its join type.
@@ -903,7 +903,7 @@ def get_initial_alias(self):
alias = self.tables[0]
self.ref_alias(alias)
else:
- alias = self.join((None, self.model._meta.db_table, None, None))
+ alias = self.join((None, self.model._meta.db_table, None))
return alias
def count_active_tables(self):
@@ -919,11 +919,12 @@ def join(self, connection, reuse=None, promote=False,
"""
Returns an alias for the join in 'connection', either reusing an
existing alias for that join or creating a new one. 'connection' is a
- tuple (lhs, table, lhs_col, col) where 'lhs' is either an existing
- table alias or a table name. The join correspods to the SQL equivalent
- of::
+ tuple (lhs, table, join_cols) where 'lhs' is either an existing
+ table alias or a table name. 'join_cols' is a tuple of tuples containing
+ columns to join on ((l_id1, r_id1), (l_id2, r_id2)). The join corresponds
+ to the SQL equivalent of::
- lhs.lhs_col = table.col
+ lhs.l_id1 = table.r_id1 AND lhs.l_id2 = table.r_id2
The 'reuse' parameter can be either None which means all joins
(matching the connection) are reusable, or it can be a set containing
@@ -946,7 +947,7 @@ def join(self, connection, reuse=None, promote=False,
The 'join_field' is the field we are joining along (if any).
"""
- lhs, table, lhs_col, col = connection
+ lhs, table, join_cols = connection
assert lhs is None or join_field is not None
existing = self.join_map.get(connection, ())
if reuse is None:
@@ -978,7 +979,7 @@ def join(self, connection, reuse=None, promote=False,
join_type = self.LOUTER
else:
join_type = self.INNER
- join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable,
+ join = JoinInfo(table, alias, join_type, lhs, join_cols or ((None, None),), nullable,
join_field)
self.alias_map[alias] = join
if connection in self.join_map:
@@ -1035,7 +1036,7 @@ def join_parent_model(self, opts, model, alias, seen):
continue
link_field = int_opts.get_ancestor_link(int_model)
int_opts = int_model._meta
- connection = (alias, int_opts.db_table, link_field.column, int_opts.pk.column)
+ connection = (alias, int_opts.db_table, link_field.get_joining_columns())
alias = seen[int_model] = self.join(connection, nullable=False,
join_field=link_field)
return alias or seen[None]
@@ -1089,17 +1090,19 @@ def add_aggregate(self, aggregate, model, alias, is_summary):
# - this is an annotation over a model field
# then we need to explore the joins that are required.
- field, source, opts, join_list, path = self.setup_joins(
+ field, sources, opts, join_list, path = self.setup_joins(
field_list, opts, self.get_initial_alias())
# Process the join chain to see if it can be trimmed
- col, _, join_list = self.trim_joins(source, join_list, path)
+ cols, _, join_list = self.trim_joins(sources, join_list, path)
# If the aggregate references a model or field that requires a join,
# those joins must be LEFT OUTER - empty join rows must be returned
# in order for zeros to be returned for those aggregates.
self.promote_joins(join_list, True)
+ col = cols[0]
+ source = sources[0]
col = (join_list[-1], col)
else:
# The simplest cases. No joins required -
@@ -1193,7 +1196,7 @@ def add_filter(self, filter_expr, connector=AND, negate=False,
allow_many = not negate
try:
- field, target, opts, join_list, path = self.setup_joins(
+ field, targets, opts, join_list, path = self.setup_joins(
parts, opts, alias, can_reuse, allow_many,
allow_explicit_fk=True)
if can_reuse is not None:
@@ -1214,16 +1217,20 @@ def add_filter(self, filter_expr, connector=AND, negate=False,
# the far end (fewer tables in a query is better). Note that join
# promotion must happen before join trimming to have the join type
# information available when reusing joins.
- col, alias, join_list = self.trim_joins(target, join_list, path)
+ cols, alias, join_list = self.trim_joins(targets, join_list, path)
+
+ if hasattr(field, 'get_lookup_constraint'):
+ constraint = field.get_lookup_constraint(self.where_class, alias, cols, targets, lookup_type, value)
+ else:
+ constraint = (Constraint(alias, cols[0], field), lookup_type, value)
if having_clause or force_having:
- if (alias, col) not in self.group_by:
- self.group_by.append((alias, col))
- self.having.add((Constraint(alias, col, field), lookup_type, value),
- connector)
+ for col in cols:
+ if (alias, col) not in self.group_by:
+ self.group_by.append((alias, col))
+ self.having.add(constraint, connector)
else:
- self.where.add((Constraint(alias, col, field), lookup_type, value),
- connector)
+ self.where.add(constraint, connector)
if negate:
self.promote_joins(join_list)
@@ -1231,7 +1238,7 @@ def add_filter(self, filter_expr, connector=AND, negate=False,
if len(join_list) > 1:
for alias in join_list:
if self.alias_map[alias].join_type == self.LOUTER:
- j_col = self.alias_map[alias].rhs_join_col
+ j_col = self.alias_map[alias].join_cols[0][1]
# The join promotion logic should never produce
# a LOUTER join for the base join - assert that.
assert j_col is not None
@@ -1252,7 +1259,7 @@ def add_filter(self, filter_expr, connector=AND, negate=False,
# be included in the final resultset. We are essentially creating
# SQL like this here: NOT (col IS NOT NULL), where the first NOT
# is added in upper layers of the code.
- self.where.add((Constraint(alias, col, None), 'isnull', False), AND)
+ self.where.add((Constraint(alias, cols[0], None), 'isnull', False), AND)
def add_q(self, q_object, used_aliases=None, force_having=False):
@@ -1355,16 +1362,20 @@ def names_to_path(self, names, opts, allow_many=False,
opts = int_model._meta
else:
final_field = opts.parents[int_model]
- target = final_field.rel.get_related_field()
+ targets = (final_field.rel.get_related_field(),)
opts = int_model._meta
- path.append(PathInfo(final_field, target, final_field.model._meta,
- opts, final_field, False, True))
+ path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True))
if hasattr(field, 'get_path_info'):
- pathinfos, opts, target, final_field = field.get_path_info()
+ pathinfos = field.get_path_info()
path.extend(pathinfos)
+ last = pathinfos[-1]
+ final_field = last.join_field
+ opts = last.to_opts
+ targets = last.target_fields
else:
# Local non-relational field.
- final_field = target = field
+ final_field = field
+ targets = (field,)
break
multijoin_pos = None
for m2mpos, pathinfo in enumerate(path):
@@ -1381,7 +1392,7 @@ def names_to_path(self, names, opts, allow_many=False,
raise FieldError("Join on field %r not permitted." % name)
if multijoin_pos is not None and len(path) >= multijoin_pos and not allow_many:
raise MultiJoin(multijoin_pos + 1)
- return path, final_field, target
+ return path, final_field, targets
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
allow_explicit_fk=False):
@@ -1414,7 +1425,7 @@ def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
"""
joins = [alias]
# First, generate the path for the names
- path, final_field, target = self.names_to_path(
+ path, final_field, targets = self.names_to_path(
names, opts, allow_many, allow_explicit_fk)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
@@ -1422,17 +1433,17 @@ def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
for pos, join in enumerate(path):
opts = join.to_opts
if join.direct:
- nullable = self.is_nullable(join.from_field)
+ nullable = self.is_nullable(join.join_field)
else:
nullable = True
- connection = alias, opts.db_table, join.from_field.column, join.to_field.column
+ connection = alias, opts.db_table, join.join_field.get_joining_columns(reverse_join=not join.direct)
reuse = can_reuse if join.m2m else None
alias = self.join(connection, reuse=reuse,
nullable=nullable, join_field=join.join_field)
joins.append(alias)
- return final_field, target, opts, joins, path
+ return final_field, targets, opts, joins, path
- def trim_joins(self, target, joins, path):
+ def trim_joins(self, targets, joins, path):
"""
The 'target' parameter is the final field being joined to, 'joins'
is the full list of join aliases. The 'path' contain the PathInfos
@@ -1446,13 +1457,15 @@ def trim_joins(self, target, joins, path):
trimmed as we don't know if there is anything on the other side of
the join.
"""
+ target_cols = [target.column for target in targets]
for info in reversed(path):
- if info.to_field == target and info.direct:
- target = info.from_field
- self.unref_alias(joins.pop())
- else:
+ join = self.alias_map[joins[-1]]
+ lhs_cols, rhs_cols = zip(*[(lhs_col, rhs_col) for lhs_col, rhs_col in join.join_cols])
+ if len(joins) == 1 or not info.direct or set(target_cols) != set(rhs_cols):
break
- return target.column, joins[-1], joins
+ target_cols = [lhs_cols[rhs_cols.index(col)] for col in target_cols]
+ self.unref_alias(joins.pop())
+ return target_cols, joins[-1], joins
def split_exclude(self, filter_expr, prefix, can_reuse):
"""
@@ -1590,20 +1603,17 @@ def add_fields(self, field_names, allow_m2m=True):
try:
for name in field_names:
- field, target, u2, joins, u3 = self.setup_joins(
+ field, targets, u2, joins, path = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, None, allow_m2m,
True)
- final_alias = joins[-1]
- col = target.column
- if len(joins) > 1:
- join = self.alias_map[final_alias]
- if col == join.rhs_join_col:
- self.unref_alias(final_alias)
- final_alias = join.lhs_alias
- col = join.lhs_join_col
- joins = joins[:-1]
+
+ # Trim last join if possible
+ cols, final_alias, remaining_joins = self.trim_joins(targets, joins[-2:], path)
+ joins = joins[:-2] + remaining_joins
+
self.promote_joins(joins[1:])
- self.select.append(SelectInfo((final_alias, col), field))
+ for index, col in enumerate(cols):
+ self.select.append(SelectInfo((final_alias, col), targets[index]))
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
@@ -1678,7 +1688,7 @@ def add_count_column(self):
opts = self.model._meta
if not self.select:
count = self.aggregates_module.Count(
- (self.join((None, opts.db_table, None, None)), opts.pk.column),
+ (self.join((None, opts.db_table, None)), opts.pk.column),
is_summary=True, distinct=True)
else:
# Because of SQL portability issues, multi-column, distinct
@@ -1882,9 +1892,9 @@ def set_start(self, start):
"""
opts = self.model._meta
alias = self.get_initial_alias()
- field, col, opts, joins, extra = self.setup_joins(
+ field, targets, opts, joins, extra = self.setup_joins(
start.split(LOOKUP_SEP), opts, alias)
- select_col = self.alias_map[joins[1]].lhs_join_col
+ select_col = self.alias_map[joins[1]].join_cols[0][0]
select_alias = alias
# The call to setup_joins added an extra reference to everything in
@@ -1897,12 +1907,12 @@ def set_start(self, start):
# is *always* the same value as lhs).
for alias in joins[1:]:
join_info = self.alias_map[alias]
- if (join_info.lhs_join_col != select_col
+ if (join_info.join_cols[0][0] != select_col
or join_info.join_type != self.INNER):
break
self.unref_alias(select_alias)
select_alias = join_info.rhs_alias