Skip to content

Commit

Permalink
Expand tests, ensure foreign keys are not dereferenced
Browse files Browse the repository at this point in the history
  • Loading branch information
orf committed Jun 4, 2018
1 parent 2a4873d commit 5c1809c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
23 changes: 17 additions & 6 deletions django/db/models/query.py
Expand Up @@ -486,20 +486,31 @@ def bulk_save(self, objs, fields, batch_size=None):
raise ValueError('All objects must have a primary key set')
if not objs:
return
fields = list(fields)
fields = [self.model._meta.get_field(name) for name in list(fields)]
if any(not f.concrete for f in fields):
raise ValueError('bulk_save can only be used with concrete fields')

field_names = [f.attname for f in fields]

ops = connections[self.db].ops
# We use PK twice in the resulting update query, once in the filter
# and once in the WHEN.
batch_size = (batch_size or max(ops.bulk_batch_size(['pk', 'pk'] + fields, objs), 1))
max_batch_size = ops.bulk_batch_size(['pk', 'pk'] + fields, objs)
if not batch_size:
batch_size = max_batch_size
else:
batch_size = min(batch_size, max_batch_size)

batches = (objs[i:i + batch_size] for i in range(0, len(objs), batch_size))
with transaction.atomic(using=self.db, savepoint=False):
for batch_objs in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
for batch_objs in batches:
pks = [obj.pk for obj in batch_objs]
update_kwargs = {
field: Case(
*(When(pk=obj.pk, then=Value(getattr(obj, field)))
name: Case(
*(When(pk=obj.pk, then=Value(getattr(obj, name)))
for obj in batch_objs)
)
for field in fields
for name in field_names
}
self.filter(pk__in=pks).update(**update_kwargs)
bulk_save.alters_data = True
Expand Down
3 changes: 2 additions & 1 deletion tests/bulk_save/models.py
Expand Up @@ -8,7 +8,8 @@ class Town(models.Model):
class Place(models.Model):
name = models.CharField(max_length=255)
rating = models.IntegerField()
town = models.ForeignKey(Town, related_name='places', on_delete=models.CASCADE, null=True)
town = models.ForeignKey(Town, related_name='places', on_delete=models.SET_NULL, null=True)
db_custom_column = models.IntegerField(db_column='custom_column_name', default=0)

if connection.vendor == 'postgresql':
from django.contrib.postgres.fields import ArrayField, JSONField
Expand Down
21 changes: 18 additions & 3 deletions tests/bulk_save/test_bulk_save.py
Expand Up @@ -3,11 +3,13 @@
from django.db import connection
from django.test import TestCase

from .models import Place, Pub
from .models import Place, Pub, Town


class BulkSaveTests(TestCase):
def setUp(self):
self.town = Town.objects.create(name='Saffron Walden')

self.pubs = [
Pub.objects.create(name='The Temeraire', rating=-10, ambiance=-10),
Pub.objects.create(name='The Rose and Crown', rating=5, ambiance=7),
Expand Down Expand Up @@ -77,6 +79,21 @@ def test_no_models_no_queries(self):
with self.assertNumQueries(0):
Pub.objects.bulk_save([], fields=['name'])

def test_only_concrete_fields_allowed(self):
msg = "bulk_save can only be used with concrete fields"
with self.assertRaisesMessage(ValueError, msg):
Town.objects.bulk_save([self.town], fields=['places'])

def test_custom_db_columns(self):
Place.objects.bulk_save(self.places, fields=['db_custom_column'])

def test_foreign_keys_do_not_lookup(self):
for place in self.places:
place.town = self.town

with self.assertNumQueries(1):
Place.objects.bulk_save(self.places, ['town'])


@unittest.skipUnless(connection.vendor == 'postgres', 'Postgres tests')
class PostgresComplexFieldsTest(TestCase):
Expand All @@ -97,5 +114,3 @@ def test_array_of_json(self):
place.array_of_json = [{'a': 'b'}, {'c': 'd'}]
Place.objects.bulk_save([place], fields=['array_of_json'])
self.assertDictEqual(Place.objects.get().array_of_json, [{'a': 'b'}, {'c': 'd'}])

# connection.features.gis_enabled

0 comments on commit 5c1809c

Please sign in to comment.