Permalink
Browse files

queryset-refactor: Model inheritance support.

This adds both types of model inheritance: abstract base classes (ABCs) and
multi-table inheritance. See the documentation and tests / examples for details.

Still a few known bugs here, so don't file tickets (I know about them). Not
quite ready for prime-time usage, but it mostly works as expected.


git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7126 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
malcolmt committed Feb 17, 2008
1 parent 2d05885 commit da6570bf082620205737c82cd7deb7185daaf538
@@ -26,7 +26,7 @@ def django_table_list(only_existing=False):
for app in models.get_apps():
for model in models.get_models(app):
tables.append(model._meta.db_table)
- tables.extend([f.m2m_db_table() for f in model._meta.many_to_many])
+ tables.extend([f.m2m_db_table() for f in model._meta.local_many_to_many])
if only_existing:
existing = table_list()
tables = [t for t in tables if t in existing]
@@ -54,12 +54,12 @@ def sequence_list():
for app in apps:
for model in models.get_models(app):
- for f in model._meta.fields:
+ for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
- for f in model._meta.many_to_many:
+ for f in model._meta.local_many_to_many:
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
@@ -147,7 +147,7 @@ def sql_delete(app, style):
if cursor and table_name_converter(model._meta.db_table) in table_names:
# The table exists, so it needs to be dropped
opts = model._meta
- for f in opts.fields:
+ for f in opts.local_fields:
if f.rel and f.rel.to not in to_delete:
references_to_delete.setdefault(f.rel.to, []).append( (model, f) )
@@ -179,7 +179,7 @@ def sql_delete(app, style):
# Output DROP TABLE statements for many-to-many tables.
for model in app_models:
opts = model._meta
- for f in opts.many_to_many:
+ for f in opts.local_many_to_many:
if isinstance(f.rel, generic.GenericRel):
continue
if cursor and table_name_converter(f.m2m_db_table()) in table_names:
@@ -256,7 +256,7 @@ def sql_model_create(model, style, known_models=set()):
pending_references = {}
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
- for f in opts.fields:
+ for f in opts.local_fields:
col_type = f.db_type()
tablespace = f.db_tablespace or opts.db_tablespace
if col_type is None:
@@ -351,7 +351,7 @@ def many_to_many_sql_for_model(model, style):
final_output = []
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
- for f in opts.many_to_many:
+ for f in opts.local_many_to_many:
if not isinstance(f.rel, generic.GenericRel):
tablespace = f.db_tablespace or opts.db_tablespace
if tablespace and connection.features.supports_tablespaces and connection.features.autoindexes_primary_keys:
@@ -458,7 +458,7 @@ def sql_indexes_for_model(model, style):
output = []
qn = connection.ops.quote_name
- for f in model._meta.fields:
+ for f in model._meta.local_fields:
if f.db_index and not ((f.primary_key or f.unique) and connection.features.autoindexes_primary_keys):
unique = f.unique and 'UNIQUE ' or ''
tablespace = f.db_tablespace or model._meta.db_tablespace
@@ -32,7 +32,7 @@ def get_validation_errors(outfile, app=None):
opts = cls._meta
# Do field-specific validation.
- for f in opts.fields:
+ for f in opts.local_fields:
if f.name == 'id' and not f.primary_key and opts.pk.name == 'id':
e.add(opts, '"%s": You can\'t use "id" as a field name, because each model automatically gets an "id" field if none of the fields have primary_key=True. You need to either remove/rename your "id" field or add primary_key=True to a field.' % f.name)
if f.name.endswith('_'):
@@ -69,8 +69,8 @@ def get_validation_errors(outfile, app=None):
if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255:
e.add(opts, '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' % (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]])))
- # Check to see if the related field will clash with any
- # existing fields, m2m fields, m2m related objects or related objects
+ # Check to see if the related field will clash with any existing
+ # fields, m2m fields, m2m related objects or related objects
if f.rel:
if f.rel.to not in models.get_models():
e.add(opts, "'%s' has relation with model %s, which has not been installed" % (f.name, f.rel.to))
@@ -87,7 +87,7 @@ def get_validation_errors(outfile, app=None):
e.add(opts, "Accessor for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_query_name:
e.add(opts, "Reverse query name for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
- for r in rel_opts.many_to_many:
+ for r in rel_opts.local_many_to_many:
if r.name == rel_name:
e.add(opts, "Accessor for field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_query_name:
@@ -104,9 +104,10 @@ def get_validation_errors(outfile, app=None):
if r.get_accessor_name() == rel_query_name:
e.add(opts, "Reverse query name for field '%s' clashes with related field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
- for i, f in enumerate(opts.many_to_many):
+ for i, f in enumerate(opts.local_many_to_many):
# Check to see if the related m2m field will clash with any
- # existing fields, m2m fields, m2m related objects or related objects
+ # existing fields, m2m fields, m2m related objects or related
+ # objects
if f.rel.to not in models.get_models():
e.add(opts, "'%s' has m2m relation with model %s, which has not been installed" % (f.name, f.rel.to))
# it is a string and we could not find the model it refers to
@@ -117,17 +118,17 @@ def get_validation_errors(outfile, app=None):
rel_opts = f.rel.to._meta
rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name()
rel_query_name = f.related_query_name()
- # If rel_name is none, there is no reverse accessor.
- # (This only occurs for symmetrical m2m relations to self).
- # If this is the case, there are no clashes to check for this field, as
- # there are no reverse descriptors for this field.
+ # If rel_name is none, there is no reverse accessor (this only
+ # occurs for symmetrical m2m relations to self). If this is the
+ # case, there are no clashes to check for this field, as there are
+ # no reverse descriptors for this field.
if rel_name is not None:
for r in rel_opts.fields:
if r.name == rel_name:
e.add(opts, "Accessor for m2m field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_query_name:
e.add(opts, "Reverse query name for m2m field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
- for r in rel_opts.many_to_many:
+ for r in rel_opts.local_many_to_many:
if r.name == rel_name:
e.add(opts, "Accessor for m2m field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_query_name:
View
@@ -1,9 +1,14 @@
+import types
+import sys
+import os
+from itertools import izip
+
import django.db.models.manipulators
import django.db.models.manager
from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
from django.db.models.fields import AutoField, ImageField, FieldDoesNotExist
-from django.db.models.fields.related import OneToOneRel, ManyToOneRel
+from django.db.models.fields.related import OneToOneRel, ManyToOneRel, OneToOneField
from django.db.models.query import delete_objects, Q
from django.db.models.options import Options, AdminOptions
from django.db import connection, transaction
@@ -14,40 +19,53 @@
from django.utils.functional import curry
from django.utils.encoding import smart_str, force_unicode, smart_unicode
from django.conf import settings
-from itertools import izip
-import types
-import sys
-import os
class ModelBase(type):
"Metaclass for all models"
def __new__(cls, name, bases, attrs):
# If this isn't a subclass of Model, don't do anything special.
try:
parents = [b for b in bases if issubclass(b, Model)]
- if not parents:
- return super(ModelBase, cls).__new__(cls, name, bases, attrs)
except NameError:
# 'Model' isn't defined yet, meaning we're looking at Django's own
# Model class, defined below.
+ parents = []
+ if not parents:
return super(ModelBase, cls).__new__(cls, name, bases, attrs)
# Create the class.
new_class = type.__new__(cls, name, bases, {'__module__': attrs.pop('__module__')})
- new_class.add_to_class('_meta', Options(attrs.pop('Meta', None)))
+ meta = attrs.pop('Meta', None)
+ # FIXME: Promote Meta to a newstyle class before attaching it to the
+ # model.
+ ## if meta:
+ ## new_class.Meta = meta
+ new_class.add_to_class('_meta', Options(meta))
+ # FIXME: Need to be smarter here. Exception is an old-style class in
+ # Python <= 2.4, new-style in Python 2.5+. This construction is only
+ # really correct for old-style classes.
new_class.add_to_class('DoesNotExist', types.ClassType('DoesNotExist', (ObjectDoesNotExist,), {}))
- new_class.add_to_class('MultipleObjectsReturned',
- types.ClassType('MultipleObjectsReturned', (MultipleObjectsReturned, ), {}))
+ new_class.add_to_class('MultipleObjectsReturned', types.ClassType('MultipleObjectsReturned', (MultipleObjectsReturned, ), {}))
- # Build complete list of parents
+ # Do the appropriate setup for any model parents.
+ abstract_parents = []
for base in parents:
- # Things without _meta aren't functional models, so they're
- # uninteresting parents.
- if hasattr(base, '_meta'):
- new_class._meta.parents.append(base)
- new_class._meta.parents.extend(base._meta.parents)
-
+ if not hasattr(base, '_meta'):
+ # Things without _meta aren't functional models, so they're
+ # uninteresting parents.
+ continue
+ if not base._meta.abstract:
+ attr_name = '%s_ptr' % base._meta.module_name
+ field = OneToOneField(base, name=attr_name, auto_created=True)
+ new_class.add_to_class(attr_name, field)
+ new_class._meta.parents[base] = field
+ else:
+ abstract_parents.append(base)
+ if getattr(new_class, '_default_manager', None) is not None:
+ # We have a parent who set the default manager. We need to override
+ # this.
+ new_class._default_manager = None
if getattr(new_class._meta, 'app_label', None) is None:
# Figure out the app_label by looking one level up.
# For 'django.contrib.sites.models', this would be 'sites'.
@@ -63,21 +81,26 @@ def __new__(cls, name, bases, attrs):
for obj_name, obj in attrs.items():
new_class.add_to_class(obj_name, obj)
- # Add Fields inherited from parents
- for parent in new_class._meta.parents:
- for field in parent._meta.fields:
- # Only add parent fields if they aren't defined for this class.
- try:
- new_class._meta.get_field(field.name)
- except FieldDoesNotExist:
- field.contribute_to_class(new_class, field.name)
+ for parent in abstract_parents:
+ names = [f.name for f in new_class._meta.local_fields + new_class._meta.many_to_many]
+ for field in parent._meta.local_fields:
+ if field.name in names:
+ raise TypeError('Local field %r in class %r clashes with field of similar name from abstract base class %r'
+ % (field.name, name, parent.__name__))
+ new_class.add_to_class(field.name, field)
- new_class._prepare()
+ if new_class._meta.abstract:
+ # Abstract base models can't be instantiated and don't appear in
+ # the list of models for an app. We do the final setup for them a
+ # little differently from normal models.
+ return new_class
+ new_class._prepare()
register_models(new_class._meta.app_label, new_class)
+
# Because of the way imports happen (recursively), we may or may not be
- # the first class for this model to register with the framework. There
- # should only be one class for each model, so we must always return the
+ # the first time this model tries to register with the framework. There
+ # should only be one class for each model, so we always return the
# registered version.
return get_model(new_class._meta.app_label, name, False)
@@ -113,8 +136,10 @@ def _prepare(cls):
class Model(object):
__metaclass__ = ModelBase
- def _get_pk_val(self):
- return getattr(self, self._meta.pk.attname)
+ def _get_pk_val(self, meta=None):
+ if not meta:
+ meta = self._meta
+ return getattr(self, meta.pk.attname)
def _set_pk_val(self, value):
return setattr(self, self._meta.pk.attname, value)
@@ -207,19 +232,30 @@ def __init__(self, *args, **kwargs):
raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0]
dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self)
- def save(self, raw=False):
- dispatcher.send(signal=signals.pre_save, sender=self.__class__,
- instance=self, raw=raw)
+ def save(self, raw=False, cls=None):
+ if not cls:
+ dispatcher.send(signal=signals.pre_save, sender=self.__class__,
+ instance=self, raw=raw)
+ cls = self.__class__
+ meta = self._meta
+ signal = True
+ else:
+ meta = cls._meta
+ signal = False
+
+ for parent, field in meta.parents.items():
+ self.save(raw, parent)
+ setattr(self, field.attname, self._get_pk_val(parent._meta))
- non_pks = [f for f in self._meta.fields if not f.primary_key]
+ non_pks = [f for f in self._meta.local_fields if not f.primary_key]
# First, try an UPDATE. If that doesn't update anything, do an INSERT.
- pk_val = self._get_pk_val()
+ pk_val = self._get_pk_val(meta)
# Note: the comparison with '' is required for compatibility with
# oldforms-style model creation.
pk_set = pk_val is not None and smart_unicode(pk_val) != u''
record_exists = True
- manager = self.__class__._default_manager
+ manager = cls._default_manager
if pk_set:
# Determine whether a record with the primary key already exists.
if manager.filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by():
@@ -231,16 +267,16 @@ def save(self, raw=False):
record_exists = False
if not pk_set or not record_exists:
if not pk_set:
- values = [(f.name, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in self._meta.fields if not isinstance(f, AutoField)]
+ values = [(f.name, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields if not isinstance(f, AutoField)]
else:
- values = [(f.name, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in self._meta.fields]
+ values = [(f.name, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True))) for f in meta.local_fields]
- if self._meta.order_with_respect_to:
- field = self._meta.order_with_respect_to
+ if meta.order_with_respect_to:
+ field = meta.order_with_respect_to
values.append(('_order', manager.filter(**{field.name: getattr(self, field.attname)}).count()))
record_exists = False
- update_pk = bool(self._meta.has_auto_field and not pk_set)
+ update_pk = bool(meta.has_auto_field and not pk_set)
if values:
# Create a new record.
result = manager._insert(_return_id=update_pk, **dict(values))
@@ -250,12 +286,13 @@ def save(self, raw=False):
_raw_values=True, pk=connection.ops.pk_default_value())
if update_pk:
- setattr(self, self._meta.pk.attname, result)
+ setattr(self, meta.pk.attname, result)
transaction.commit_unless_managed()
- # Run any post-save hooks.
- dispatcher.send(signal=signals.post_save, sender=self.__class__,
- instance=self, created=(not record_exists), raw=raw)
+ if signal:
+ # Run any post-save hooks.
+ dispatcher.send(signal=signals.post_save, sender=self.__class__,
+ instance=self, created=(not record_exists), raw=raw)
save.alters_data = True
Oops, something went wrong.

0 comments on commit da6570b

Please sign in to comment.