Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed #31685 -- Added support for updating conflicts to QuerySet.bulk_create(). #13065

Merged
merged 1 commit into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions django/db/backends/base/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ class BaseDatabaseFeatures:
# Does the backend support ignoring constraint or uniqueness errors during
# INSERT?
supports_ignore_conflicts = True
# Does the backend support updating rows on constraint or uniqueness errors
# during INSERT?
supports_update_conflicts = False
supports_update_conflicts_with_target = False

# Does this backend require casting the results of CASE expressions used
# in UPDATE statements to ensure the expression has the correct type?
Expand Down
4 changes: 2 additions & 2 deletions django/db/backends/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,8 @@ def explain_query_prefix(self, format=None, **options):
raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
return self.explain_prefix

def insert_statement(self, ignore_conflicts=False):
def insert_statement(self, on_conflict=None):
    """
    Return the SQL keywords that open an INSERT statement.

    This implementation does not inspect `on_conflict` and always returns
    the plain form.
    """
    return 'INSERT INTO'

def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
    """
    Return the conflict-handling SQL suffix for an INSERT statement.

    This implementation does not inspect any of its arguments and always
    returns an empty string.
    """
    return ''
1 change: 1 addition & 0 deletions django/db/backends/mysql/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_select_difference = False
supports_slicing_ordering_in_compound = True
supports_index_on_text_field = False
supports_update_conflicts = True
create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure ()
BEGIN
Expand Down
31 changes: 29 additions & 2 deletions django/db/backends/mysql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str

Expand Down Expand Up @@ -365,8 +366,10 @@ def regex_lookup(self, lookup_type):
match_option = 'c' if lookup_type == 'regex' else 'i'
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option

def insert_statement(self, ignore_conflicts=False):
return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
def insert_statement(self, on_conflict=None):
    """
    Return the SQL keywords that open an INSERT statement.

    `OnConflict.IGNORE` selects the `INSERT IGNORE INTO` form; any other
    value is delegated to the parent implementation.
    """
    if on_conflict != OnConflict.IGNORE:
        return super().insert_statement(on_conflict=on_conflict)
    return 'INSERT IGNORE INTO'

def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'
Expand All @@ -388,3 +391,27 @@ def conditional_expression_supported_in_where_clause(self, expression):
if getattr(expression, 'conditional', False):
return False
return super().conditional_expression_supported_in_where_clause(expression)

def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
    """
    Return the conflict-handling SQL suffix for an INSERT statement.

    For `OnConflict.UPDATE`, build an `ON DUPLICATE KEY UPDATE` clause that
    assigns every entry of `update_fields` (passed through
    `self.quote_name()`) from the row that was being inserted. `fields` and
    `unique_fields` are not used on that path. Any other `on_conflict`
    value is delegated to the parent implementation.
    """
    if on_conflict != OnConflict.UPDATE:
        return super().on_conflict_suffix_sql(
            fields, on_conflict, update_fields, unique_fields,
        )

    connection = self.connection
    row_alias_sql = ''
    assignment_sql = '{column} = VALUES({column})'
    if connection.mysql_is_mariadb:
        # VALUES() was renamed to VALUE() in MariaDB 10.3.3+.
        if connection.mysql_version >= (10, 3, 3):
            assignment_sql = '{column} = VALUE({column})'
    elif connection.mysql_version >= (8, 0, 19):
        # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
        # aliases for the new row and its columns available in MySQL
        # 8.0.19+.
        row_alias_sql = 'AS new '
        assignment_sql = '{column} = new.{column}'

    assignments = ', '.join(
        assignment_sql.format(column=self.quote_name(name))
        for name in update_fields
    )
    return f'{row_alias_sql}ON DUPLICATE KEY UPDATE {assignments}'
2 changes: 2 additions & 0 deletions django/db/backends/postgresql/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_deferrable_unique_constraints = True
has_json_operators = True
json_key_contains_list_matching_requires_list = True
supports_update_conflicts = True
supports_update_conflicts_with_target = True
test_collations = {
'non_default': 'sv-x-icu',
'swedish_ci': 'sv-x-icu',
Expand Down
17 changes: 15 additions & 2 deletions django/db/backends/postgresql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict


class DatabaseOperations(BaseDatabaseOperations):
Expand Down Expand Up @@ -272,5 +273,17 @@ def explain_query_prefix(self, format=None, **options):
prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
return prefix

def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
    """
    Return the conflict-handling SQL suffix for an INSERT statement.

    `OnConflict.IGNORE` yields `ON CONFLICT DO NOTHING`.
    `OnConflict.UPDATE` yields an `ON CONFLICT(...) DO UPDATE SET ...`
    clause whose target lists the entries of `unique_fields` and whose
    assignments copy each entry of `update_fields` from `EXCLUDED`; both
    lists are passed through `self.quote_name()`. Any other value is
    delegated to the parent implementation.
    """
    if on_conflict == OnConflict.IGNORE:
        return 'ON CONFLICT DO NOTHING'
    if on_conflict == OnConflict.UPDATE:
        conflict_target = ', '.join(
            self.quote_name(name) for name in unique_fields
        )
        assignments = ', '.join(
            '%s = EXCLUDED.%s' % (quoted, quoted)
            for quoted in (self.quote_name(name) for name in update_fields)
        )
        return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
            conflict_target, assignments,
        )
    return super().on_conflict_suffix_sql(
        fields, on_conflict, update_fields, unique_fields,
    )
2 changes: 2 additions & 0 deletions django/db/backends/sqlite3/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
order_by_nulls_first = True
supports_json_field_contains = False
supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0)
supports_update_conflicts_with_target = supports_update_conflicts
test_collations = {
'ci': 'nocase',
'cs': 'binary',
Expand Down
23 changes: 21 additions & 2 deletions django/db/backends/sqlite3/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, models
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.constants import OnConflict
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
Expand Down Expand Up @@ -370,8 +371,10 @@ def subtract_temporals(self, internal_type, lhs, rhs):
return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params

def insert_statement(self, ignore_conflicts=False):
return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
def insert_statement(self, on_conflict=None):
    """
    Return the SQL keywords that open an INSERT statement.

    `OnConflict.IGNORE` selects the `INSERT OR IGNORE INTO` form; any other
    value is delegated to the parent implementation.
    """
    if on_conflict != OnConflict.IGNORE:
        return super().insert_statement(on_conflict=on_conflict)
    return 'INSERT OR IGNORE INTO'

def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
Expand All @@ -384,3 +387,19 @@ def return_insert_columns(self, fields):
) for field in fields
]
return 'RETURNING %s' % ', '.join(columns), ()

def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
    """
    Return the conflict-handling SQL suffix for an INSERT statement.

    An `ON CONFLICT(...) DO UPDATE SET ...` clause is produced only when
    `on_conflict` is `OnConflict.UPDATE` and the connection's features
    report `supports_update_conflicts_with_target`. Every other case is
    delegated to the parent implementation.
    """
    wants_update = on_conflict == OnConflict.UPDATE
    # The feature flag is only consulted once an update was requested.
    if not wants_update or not self.connection.features.supports_update_conflicts_with_target:
        return super().on_conflict_suffix_sql(
            fields, on_conflict, update_fields, unique_fields,
        )
    conflict_target = ', '.join(
        self.quote_name(name) for name in unique_fields
    )
    assignments = ', '.join(
        '%s = EXCLUDED.%s' % (quoted, quoted)
        for quoted in (self.quote_name(name) for name in update_fields)
    )
    return 'ON CONFLICT(%s) DO UPDATE SET %s' % (conflict_target, assignments)
6 changes: 6 additions & 0 deletions django/db/models/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""
Constants used across the ORM in general.
"""
from enum import Enum

# Separator used to split filter strings apart.
LOOKUP_SEP = '__'


class OnConflict(Enum):
    """Conflict-handling policies that can be requested for an INSERT."""

    # Leave the existing row untouched when the insert conflicts.
    IGNORE = 'ignore'
    # Update the existing row when the insert conflicts.
    UPDATE = 'update'
119 changes: 106 additions & 13 deletions django/db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
router, transaction,
)
from django.db.models import AutoField, DateField, DateTimeField, sql
from django.db.models.constants import LOOKUP_SEP
from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector
from django.db.models.expressions import Case, Expression, F, Ref, Value, When
from django.db.models.functions import Cast, Trunc
Expand Down Expand Up @@ -466,7 +466,69 @@ def _prepare_for_bulk_create(self, objs):
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
obj._prepare_related_fields_for_save(operation_name='bulk_create')

def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields):
    """
    Validate the conflict-related options of bulk_create().

    Return `OnConflict.IGNORE` or `OnConflict.UPDATE` for the requested
    policy, or None when neither `ignore_conflicts` nor `update_conflicts`
    is set.

    Raise ValueError for contradictory, missing, or invalid options and
    NotSupportedError when the database backend lacks the requested
    feature.
    """
    if ignore_conflicts and update_conflicts:
        raise ValueError(
            'ignore_conflicts and update_conflicts are mutually exclusive.'
        )
    features = connections[self.db].features

    if ignore_conflicts:
        if features.supports_ignore_conflicts:
            return OnConflict.IGNORE
        raise NotSupportedError(
            'This database backend does not support ignoring conflicts.'
        )

    if not update_conflicts:
        return None

    if not features.supports_update_conflicts:
        raise NotSupportedError(
            'This database backend does not support updating conflicts.'
        )
    if not update_fields:
        raise ValueError(
            'Fields that will be updated when a row insertion fails '
            'on conflicts must be provided.'
        )

    # unique_fields must be given exactly when the backend accepts an
    # explicit conflict target.
    supports_target = features.supports_update_conflicts_with_target
    if unique_fields:
        if not supports_target:
            raise NotSupportedError(
                'This database backend does not support updating '
                'conflicts with specifying unique fields that can trigger '
                'the upsert.'
            )
    elif supports_target:
        raise ValueError(
            'Unique fields that can trigger the upsert must be '
            'provided.'
        )

    get_field = self.model._meta.get_field

    # Updating primary keys and non-concrete fields is forbidden.
    resolved_update_fields = [get_field(name) for name in update_fields]
    for field in resolved_update_fields:
        if not field.concrete or field.many_to_many:
            raise ValueError(
                'bulk_create() can only be used with concrete fields in '
                'update_fields.'
            )
    for field in resolved_update_fields:
        if field.primary_key:
            raise ValueError(
                'bulk_create() cannot be used with primary keys in '
                'update_fields.'
            )

    if unique_fields:
        # Primary key is allowed in unique_fields, so the 'pk' alias is
        # skipped rather than resolved.
        resolved_unique_fields = [
            get_field(name) for name in unique_fields if name != 'pk'
        ]
        for field in resolved_unique_fields:
            if not field.concrete or field.many_to_many:
                raise ValueError(
                    'bulk_create() can only be used with concrete fields '
                    'in unique_fields.'
                )
    return OnConflict.UPDATE

def bulk_create(
self, objs, batch_size=None, ignore_conflicts=False,
update_conflicts=False, update_fields=None, unique_fields=None,
):
"""
Insert each of the instances into the database. Do *not* call
save() on each of the instances, do not send any pre/post_save
Expand Down Expand Up @@ -497,6 +559,12 @@ def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
raise ValueError("Can't bulk create a multi-table inherited model")
if not objs:
return objs
on_conflict = self._check_bulk_create_options(
ignore_conflicts,
update_conflicts,
update_fields,
unique_fields,
)
self._for_write = True
opts = self.model._meta
fields = opts.concrete_fields
Expand All @@ -506,7 +574,12 @@ def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
if objs_with_pk:
returned_columns = self._batched_insert(
objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
objs_with_pk,
fields,
batch_size,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
for obj_with_pk, results in zip(objs_with_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
Expand All @@ -518,10 +591,15 @@ def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
if objs_without_pk:
fields = [f for f in fields if not isinstance(f, AutoField)]
returned_columns = self._batched_insert(
objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
objs_without_pk,
fields,
batch_size,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
connection = connections[self.db]
if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts:
if connection.features.can_return_rows_from_bulk_insert and on_conflict is None:
assert len(returned_columns) == len(objs_without_pk)
for obj_without_pk, results in zip(objs_without_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
Expand Down Expand Up @@ -1293,41 +1371,56 @@ def db(self):
# PRIVATE METHODS #
###################

def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False):
def _insert(
    self,
    objs,
    fields,
    returning_fields=None,
    raw=False,
    using=None,
    on_conflict=None,
    update_fields=None,
    unique_fields=None,
):
    """
    Insert a new record for the given model. This provides an interface to
    the InsertQuery class and is how Model.save() is implemented.

    `using` falls back to `self.db` when it is None. The conflict options
    are handed to the InsertQuery unchanged, and the result of the
    compiler's execute_sql(returning_fields) is returned.
    """
    self._for_write = True
    database = self.db if using is None else using
    insert_query = sql.InsertQuery(
        self.model,
        on_conflict=on_conflict,
        update_fields=update_fields,
        unique_fields=unique_fields,
    )
    insert_query.insert_values(fields, objs, raw=raw)
    compiler = insert_query.get_compiler(using=database)
    return compiler.execute_sql(returning_fields)
_insert.alters_data = True
_insert.queryset_only = False

def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False):
def _batched_insert(
    self,
    objs,
    fields,
    batch_size,
    on_conflict=None,
    update_fields=None,
    unique_fields=None,
):
    """
    Helper method for bulk_create() to insert objs one batch at a time.

    Return the rows collected from batches inserted with returning fields.
    Batches inserted with a conflict policy, or on a backend that cannot
    return rows from a bulk insert, contribute nothing to the result.
    """
    connection = connections[self.db]
    backend_limit = max(connection.ops.bulk_batch_size(fields, objs), 1)
    # A falsy batch_size means "use the backend limit"; an explicit one is
    # capped at that limit.
    batch_size = min(batch_size, backend_limit) if batch_size else backend_limit
    can_return_rows = connection.features.can_return_rows_from_bulk_insert
    batches = [
        objs[start:start + batch_size]
        for start in range(0, len(objs), batch_size)
    ]
    inserted_rows = []
    for batch in batches:
        if can_return_rows and on_conflict is None:
            returned = self._insert(
                batch,
                fields=fields,
                using=self.db,
                returning_fields=self.model._meta.db_returning_fields,
            )
            inserted_rows.extend(returned)
        else:
            self._insert(
                batch,
                fields=fields,
                using=self.db,
                on_conflict=on_conflict,
                update_fields=update_fields,
                unique_fields=unique_fields,
            )
    return inserted_rows

def _chain(self):
Expand Down