Permalink
Browse files

Fixed #7596. Added Model.objects.bulk_create, and made use of it in s…

…everal places. This provides a performance benefit when inserting multiple objects. Thanks to Russ for the review, and Simon Meers for the MySQL implementation.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16739 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
1 parent e55bbf4 commit 7deb25b8dd5aa1ed02b5e30cbc67cd1fb0c3d6e6 @alex alex committed Sep 9, 2011
@@ -46,17 +46,15 @@ def create_permissions(app, created_models, verbosity, **kwargs):
"content_type", "codename"
))
- for ctype, (codename, name) in searched_perms:
- # If the permissions exists, move on.
- if (ctype.pk, codename) in all_perms:
- continue
- p = auth_app.Permission.objects.create(
- codename=codename,
- name=name,
- content_type=ctype
- )
- if verbosity >= 2:
- print "Adding permission '%s'" % p
+ objs = [
+ auth_app.Permission(codename=codename, name=name, content_type=ctype)
+ for ctype, (codename, name) in searched_perms
+ if (ctype.pk, codename) not in all_perms
+ ]
+ auth_app.Permission.objects.bulk_create(objs)
+ if verbosity >= 2:
+ for obj in objs:
+ print "Adding permission '%s'" % obj
def create_superuser(app, created_models, verbosity, **kwargs):
@@ -8,25 +8,41 @@ def update_contenttypes(app, created_models, verbosity=2, **kwargs):
entries that no longer have a matching model class.
"""
ContentType.objects.clear_cache()
- content_types = list(ContentType.objects.filter(app_label=app.__name__.split('.')[-2]))
app_models = get_models(app)
if not app_models:
return
- for klass in app_models:
- opts = klass._meta
- try:
- ct = ContentType.objects.get(app_label=opts.app_label,
- model=opts.object_name.lower())
- content_types.remove(ct)
- except ContentType.DoesNotExist:
- ct = ContentType(name=smart_unicode(opts.verbose_name_raw),
- app_label=opts.app_label, model=opts.object_name.lower())
- ct.save()
- if verbosity >= 2:
- print "Adding content type '%s | %s'" % (ct.app_label, ct.model)
- # The presence of any remaining content types means the supplied app has an
- # undefined model. Confirm that the content type is stale before deletion.
- if content_types:
+ # They all have the same app_label, get the first one.
+ app_label = app_models[0]._meta.app_label
+ app_models = dict(
+ (model._meta.object_name.lower(), model)
+ for model in app_models
+ )
+ # Get all the content types
+ content_types = dict(
+ (ct.model, ct)
+ for ct in ContentType.objects.filter(app_label=app_label)
+ )
+ to_remove = [
+ ct
+ for (model_name, ct) in content_types.iteritems()
+ if model_name not in app_models
+ ]
+
+ cts = ContentType.objects.bulk_create([
+ ContentType(
+ name=smart_unicode(model._meta.verbose_name_raw),
+ app_label=app_label,
+ model=model_name,
+ )
+ for (model_name, model) in app_models.iteritems()
+ if model_name not in content_types
+ ])
+ if verbosity >= 2:
+ for ct in cts:
+ print "Adding content type '%s | %s'" % (ct.app_label, ct.model)
+
+ # Confirm that the content type is stale before deletion.
+ if to_remove:
if kwargs.get('interactive', False):
content_type_display = '\n'.join([' %s | %s' % (ct.app_label, ct.model) for ct in content_types])
ok_to_delete = raw_input("""The following content types are stale and need to be deleted:
@@ -42,7 +58,7 @@ def update_contenttypes(app, created_models, verbosity=2, **kwargs):
ok_to_delete = False
if ok_to_delete == 'yes':
- for ct in content_types:
+ for ct in to_remove:
if verbosity >= 2:
print "Deleting stale content type '%s | %s'" % (ct.app_label, ct.model)
ct.delete()
@@ -301,8 +301,10 @@ class BaseDatabaseFeatures(object):
can_use_chunked_reads = True
can_return_id_from_insert = False
+ has_bulk_insert = False
uses_autocommit = False
uses_savepoints = False
+ can_combine_inserts_with_and_without_auto_increment_pk = False
# If True, don't use integer foreign keys referring to, e.g., positive
# integer primary keys.
@@ -124,6 +124,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_pk = True
related_fields_match_type = True
allow_sliced_subqueries = False
+ has_bulk_insert = True
has_select_for_update = True
has_select_for_update_nowait = False
supports_forward_references = False
@@ -263,6 +264,10 @@ def year_lookup_bounds(self, value):
def max_name_length(self):
return 64
+    def bulk_insert_sql(self, fields, num_values):
+        """
+        Returns the ``VALUES`` clause for a multi-row INSERT: one
+        parenthesised group of ``%s`` placeholders per row, each group
+        holding one placeholder per entry in ``fields``, repeated
+        ``num_values`` times and separated by commas.
+        """
+        placeholders = ", ".join("%s" for _ in range(len(fields)))
+        row_sql = "(" + placeholders + ")"
+        all_rows = ", ".join(row_sql for _ in range(num_values))
+        return "VALUES " + all_rows
+
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql'
operators = {
@@ -74,6 +74,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_defer_constraint_checks = True
has_select_for_update = True
has_select_for_update_nowait = True
+ has_bulk_insert = True
class DatabaseWrapper(BaseDatabaseWrapper):
@@ -180,3 +180,7 @@ def last_executed_query(self, cursor, sql, params):
def return_insert_id(self):
return "RETURNING %s", ()
+
+    def bulk_insert_sql(self, fields, num_values):
+        """
+        Builds the ``VALUES (...), (...)`` tail of a multi-row INSERT.
+
+        Every row is a parenthesised, comma-separated run of ``%s``
+        placeholders (one per entry in ``fields``); ``num_values`` such
+        rows are emitted.
+        """
+        column_count = len(fields)
+        single_row = "(%s)" % ", ".join("%s" for _ in range(column_count))
+        rows = [single_row for _ in range(num_values)]
+        return "VALUES " + ", ".join(rows)
@@ -58,6 +58,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_unspecified_pk = True
supports_1000_query_parameters = False
supports_mixed_date_datetime_comparisons = False
+ has_bulk_insert = True
+ can_combine_inserts_with_and_without_auto_increment_pk = True
def _supports_stddev(self):
"""Confirm support for STDDEV and related stats functions
@@ -106,7 +108,7 @@ def drop_foreignkey_sql(self):
return ""
def pk_default_value(self):
- return 'NULL'
+ return "NULL"
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
@@ -154,6 +156,14 @@ def convert_values(self, value, field):
# No field, or the field isn't known to be a decimal or integer
return value
+    def bulk_insert_sql(self, fields, num_values):
+        """
+        Returns the SQL used in place of a VALUES clause to insert
+        ``num_values`` rows in one statement, expressed as a compound SELECT.
+
+        The first SELECT aliases each placeholder with the quoted column name
+        of the corresponding entry in ``fields``; every further row is a plain
+        list of ``%s`` placeholders.
+        """
+        res = []
+        res.append("SELECT %s" % ", ".join(
+            "%%s AS %s" % self.quote_name(f.column) for f in fields
+        ))
+        # UNION ALL, not UNION: a plain UNION discards duplicate rows, so
+        # objects whose inserted values are identical would collapse into a
+        # single row and fewer records than requested would be created.
+        res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1))
+        return " ".join(res)
+
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'sqlite'
# SQLite requires LIKE statements to include an ESCAPE clause if the value
@@ -540,24 +540,16 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False,
order_value = manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()
self._order = order_value
+ fields = meta.local_fields
if not pk_set:
if force_update:
raise ValueError("Cannot force an update in save() with no primary key.")
- values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
- for f in meta.local_fields if not isinstance(f, AutoField)]
- else:
- values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
- for f in meta.local_fields]
+ fields = [f for f in fields if not isinstance(f, AutoField)]
record_exists = False
update_pk = bool(meta.has_auto_field and not pk_set)
- if values:
- # Create a new record.
- result = manager._insert(values, return_id=update_pk, using=using)
- else:
- # Create a new record with defaults for everything.
- result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using)
+ result = manager._insert([self], fields=fields, return_id=update_pk, using=using, raw=raw)
if update_pk:
setattr(self, meta.pk.attname, result)
@@ -430,15 +430,15 @@ def add(self, *objs):
add.alters_data = True
def create(self, **kwargs):
- kwargs.update({rel_field.name: instance})
+ kwargs[rel_field.name] = instance
db = router.db_for_write(rel_model, instance=instance)
return super(RelatedManager, self.db_manager(db)).create(**kwargs)
create.alters_data = True
def get_or_create(self, **kwargs):
# Update kwargs with the related object that this
# ForeignRelatedObjectsDescriptor knows about.
- kwargs.update({rel_field.name: instance})
+ kwargs[rel_field.name] = instance
db = router.db_for_write(rel_model, instance=instance)
return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
get_or_create.alters_data = True
@@ -578,11 +578,13 @@ def _add_items(self, source_field_name, target_field_name, *objs):
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=new_ids, using=db)
# Add the ones that aren't there already
- for obj_id in new_ids:
- self.through._default_manager.using(db).create(**{
+ self.through._default_manager.using(db).bulk_create([
+ self.through(**{
'%s_id' % source_field_name: self._pk_val,
'%s_id' % target_field_name: obj_id,
})
+ for obj_id in new_ids
+ ])
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are inserting the
# duplicate data row for symmetrical reverse entries.
@@ -701,12 +703,12 @@ class ReverseManyRelatedObjectsDescriptor(object):
def __init__(self, m2m_field):
self.field = m2m_field
- def _through(self):
+ @property
+ def through(self):
# through is provided so that you have easy access to the through
# model (Book.authors.through) for inlines, etc. This is done as
# a property to ensure that the fully resolved value is returned.
return self.field.rel.through
- through = property(_through)
def __get__(self, instance, instance_type=None):
if instance is None:
@@ -136,6 +136,9 @@ def get_or_create(self, **kwargs):
def create(self, **kwargs):
return self.get_query_set().create(**kwargs)
+    def bulk_create(self, *args, **kwargs):
+        """
+        Forwards all positional and keyword arguments to ``bulk_create()``
+        on the query set returned by ``get_query_set()`` and returns its
+        result.
+        """
+        queryset = self.get_query_set()
+        return queryset.bulk_create(*args, **kwargs)
+
def filter(self, *args, **kwargs):
return self.get_query_set().filter(*args, **kwargs)
@@ -193,8 +196,8 @@ def using(self, *args, **kwargs):
def exists(self, *args, **kwargs):
return self.get_query_set().exists(*args, **kwargs)
- def _insert(self, values, **kwargs):
- return insert_query(self.model, values, **kwargs)
+    def _insert(self, objs, fields, **kwargs):
+        """
+        Passes ``objs`` and ``fields``, plus any extra keyword arguments,
+        to ``insert_query()`` for this manager's model and returns whatever
+        it returns.
+        """
+        model = self.model
+        return insert_query(model, objs, fields, **kwargs)
def _update(self, values, **kwargs):
return self.get_query_set()._update(values, **kwargs)
@@ -5,10 +5,12 @@
import copy
from django.db import connections, router, transaction, IntegrityError
+from django.db.models.fields import AutoField
from django.db.models.query_utils import (Q, select_related_descend,
deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector
from django.db.models import signals, sql
+from django.utils.functional import partition
# Used to control how many objects are worked with at once in some cases (e.g.
# when deleting objects).
@@ -352,6 +354,41 @@ def create(self, **kwargs):
obj.save(force_insert=True, using=self.db)
return obj
+    def bulk_create(self, objs):
+        """
+        Inserts each of the instances into the database. This does *not* call
+        save() on each of the instances, does not send any pre/post save
+        signals, and does not set the primary key attribute if it is an
+        autoincrement field.
+
+        ``objs`` may be any iterable of model instances. Returns the
+        instances as a list (an empty list when there is nothing to insert).
+        Raises ValueError if the model has parents (an inherited model).
+        """
+        # So this case is fun. When you bulk insert you don't get the primary
+        # keys back (if it's an autoincrement), so you can't insert into the
+        # child tables which references this. There are two workarounds, 1)
+        # this could be implemented if you didn't have an autoincrement pk,
+        # and 2) you could do it by doing O(n) normal inserts into the parent
+        # tables to get the primary keys back, and then doing a single bulk
+        # insert into the childmost table. We're punting on these for now
+        # because they are relatively rare cases.
+        if self.model._meta.parents:
+            raise ValueError("Can't bulk create an inherited model")
+        # Materialize once: a generator is always truthy (defeating the empty
+        # check below) and could only be iterated a single time.
+        objs = list(objs)
+        if not objs:
+            return objs
+        self._for_write = True
+        connection = connections[self.db]
+        fields = self.model._meta.local_fields
+        if (connection.features.can_combine_inserts_with_and_without_auto_increment_pk
+            and self.model._meta.has_auto_field):
+            # The backend accepts rows with and without an explicit
+            # auto-increment pk in one statement, so insert everything at once.
+            self.model._base_manager._insert(objs, fields=fields, using=self.db)
+        else:
+            # NOTE(review): the unpacking order assumes partition() returns
+            # (items failing the predicate, items passing it) -- confirm
+            # against django.utils.functional.partition.
+            objs_with_pk, objs_without_pk = partition(
+                lambda o: o.pk is None,
+                objs
+            )
+            if objs_with_pk:
+                self.model._base_manager._insert(objs_with_pk, fields=fields, using=self.db)
+            if objs_without_pk:
+                # Leave the AutoField column out for objects that have no pk.
+                fields_without_auto = [f for f in fields if not isinstance(f, AutoField)]
+                self.model._base_manager._insert(objs_without_pk, fields=fields_without_auto, using=self.db)
+        # Hand the instances back so callers can iterate over what was
+        # inserted (e.g. to report each created object).
+        return objs
+
def get_or_create(self, **kwargs):
"""
Looks up an object with the given kwargs, creating one if necessary.
@@ -1437,12 +1474,12 @@ def model_fields(self):
self._model_fields[converter(column)] = field
return self._model_fields
-def insert_query(model, values, return_id=False, raw_values=False, using=None):
+def insert_query(model, objs, fields, return_id=False, raw=False, using=None):
"""
Inserts a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented. It is not
part of the public API.
"""
query = sql.InsertQuery(model)
- query.insert_values(values, raw_values)
+ query.insert_values(fields, objs, raw=raw)
return query.get_compiler(using=using).execute_sql(return_id)
Oops, something went wrong. Retry.

0 comments on commit 7deb25b

Please sign in to comment.