Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore transactions for data import #480

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
103 changes: 57 additions & 46 deletions import_export/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from django import VERSION
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.management.color import no_style
from django.db import connections, transaction, DEFAULT_DB_ALIAS
from django.db.models.fields import FieldDoesNotExist
Expand Down Expand Up @@ -57,7 +58,7 @@ def emit(self, record):

logging.getLogger(__name__).addHandler(NullHandler())

USE_TRANSACTIONS = getattr(settings, 'IMPORT_EXPORT_USE_TRANSACTIONS', False)
USE_TRANSACTIONS = getattr(settings, 'IMPORT_EXPORT_USE_TRANSACTIONS', True)


class ResourceOptions(object):
Expand Down Expand Up @@ -261,36 +262,42 @@ def get_or_init_instance(self, instance_loader, row):
else:
return (self.init_instance(row), True)

def save_instance(self, instance, dry_run=False):
def save_instance(self, instance, using_transactions=True, dry_run=False):
"""
Takes care of saving the object to the database.

Keep in mind that this is done by calling ``instance.save()``, so
objects are not created in bulk!
"""
self.before_save_instance(instance, dry_run)
if not dry_run:
self.before_save_instance(instance, using_transactions, dry_run)
if not using_transactions and dry_run:
# we don't have transactions and we want to do a dry_run
pass
else:
instance.save()
self.after_save_instance(instance, dry_run)
self.after_save_instance(instance, using_transactions, dry_run)

def before_save_instance(self, instance, dry_run):
def before_save_instance(self, instance, using_transactions, dry_run):
"""
Override to add additional logic. Does nothing by default.
"""
pass

def after_save_instance(self, instance, dry_run):
def after_save_instance(self, instance, using_transactions, dry_run):
"""
Override to add additional logic. Does nothing by default.
"""
pass

def delete_instance(self, instance, dry_run=False):
def delete_instance(self, instance, using_transactions=True, dry_run=False):
"""
Calls :meth:`instance.delete` unless ``dry_run`` is set and no transaction is being used.
"""
self.before_delete_instance(instance, dry_run)
if not dry_run:
if not using_transactions and dry_run:
# we don't have transactions and we want to do a dry_run
pass
else:
instance.delete()
self.after_delete_instance(instance, dry_run)

Expand Down Expand Up @@ -324,14 +331,17 @@ def import_obj(self, obj, data, dry_run):
continue
self.import_field(field, obj, data)

def save_m2m(self, obj, data, dry_run):
def save_m2m(self, obj, data, using_transactions, dry_run):
"""
Saves m2m fields.

Model instance need to have a primary key value before
a many-to-many relationship can be used.
"""
if not dry_run:
if not using_transactions and dry_run:
# we don't have transactions and we want to do a dry_run
pass
else:
for field in self.get_fields():
if not isinstance(field.widget, widgets.ManyToManyWidget):
continue
Expand Down Expand Up @@ -373,27 +383,15 @@ def get_diff_headers(self):
"""
return self.get_export_headers()

def before_import(self, dataset, dry_run, **kwargs):
def before_import(self, dataset, using_transactions, dry_run, **kwargs):
"""
Override to add additional logic. Does nothing by default.

This method receives the ``dataset`` that's going to be imported, the
``dry_run`` parameter which determines whether changes are saved to
the database, and any additional keyword arguments passed to
``import_data`` in a ``kwargs`` dict.
"""
pass

def after_import(self, dataset, result, dry_run, **kwargs):
def after_import(self, dataset, result, using_transactions, dry_run, **kwargs):
"""
Override to add additional logic. Does nothing by default.

This method receives the ``dataset`` that's just been imported, the
``result`` of the import and the ``dry_run`` parameter which determines
whether changes will be saved to the database, and any additional
keyword arguments passed to ``import_data`` in a ``kwargs`` dict. This
method runs after the main import finishes but before the changes are
committed or rolled back.
"""
pass

Expand All @@ -415,7 +413,7 @@ def after_import_instance(self, instance, new, **kwargs):
"""
pass

def import_row(self, row, instance_loader, dry_run=False, **kwargs):
def import_row(self, row, instance_loader, using_transactions=True, dry_run=False, **kwargs):
"""
Imports data from ``tablib.Dataset``. Refer to :doc:`import_workflow`
for a more complete description of the whole import process.
Expand All @@ -424,6 +422,9 @@ def import_row(self, row, instance_loader, dry_run=False, **kwargs):

:param instance_loader: The instance loader to be used to load the row

:param using_transactions: If ``using_transactions`` is set, a transaction
is being used to wrap the import

:param dry_run: If ``dry_run`` is set, or an error occurs, the transaction
(if one is being used) will be rolled back.
"""
Expand Down Expand Up @@ -456,7 +457,7 @@ def import_row(self, row, instance_loader, dry_run=False, **kwargs):
else:
with transaction.atomic():
self.save_instance(instance, dry_run)
self.save_m2m(instance, row, dry_run)
self.save_m2m(instance, row, using_transactions, dry_run)
# Add object info to RowResult for LogEntry
row_result.object_repr = force_text(instance)
row_result.object_id = instance.pk
Expand All @@ -472,7 +473,6 @@ def import_row(self, row, instance_loader, dry_run=False, **kwargs):
row_result.errors.append(self.get_error_result_class()(e, tb_info, row))
return row_result

@atomic()
def import_data(self, dataset, dry_run=False, raise_errors=False,
use_transactions=None, **kwargs):
"""
Expand All @@ -487,9 +487,27 @@ def import_data(self, dataset, dry_run=False, raise_errors=False,
:param use_transactions: If ``True`` import process will be processed
inside transaction.

:param dry_run: If ``dry_run`` is set, or error occurs, transaction
will be rolled back.
:param dry_run: If ``dry_run`` is set or an error occurs, the transaction
(if one is being used) will be rolled back.
"""

if use_transactions is None:
use_transactions = self.get_use_transactions()

connection = connections[DEFAULT_DB_ALIAS]
supports_transactions = getattr(connection.features, "supports_transactions", False)

if use_transactions and not supports_transactions:
raise ImproperlyConfigured

using_transactions = (use_transactions or dry_run) and supports_transactions

if using_transactions:
with transaction.atomic():
return self.import_data_inner(dataset, dry_run, raise_errors, using_transactions, **kwargs)
return self.import_data_inner(dataset, dry_run, raise_errors, using_transactions, **kwargs)

def import_data_inner(self, dataset, dry_run, raise_errors, using_transactions, **kwargs):
result = self.get_result_class()()
result.diff_headers = self.get_diff_headers()
result.totals = OrderedDict([(RowResult.IMPORT_TYPE_NEW, 0),
Expand All @@ -499,25 +517,19 @@ def import_data(self, dataset, dry_run=False, raise_errors=False,
(RowResult.IMPORT_TYPE_ERROR, 0),
('total', len(dataset))])

if use_transactions is None:
use_transactions = self.get_use_transactions()

if use_transactions is True:
if using_transactions:
# when transactions are used we want to create/update/delete object
# as transaction will be rolled back if dry_run is set
real_dry_run = False
sp1 = savepoint()
else:
real_dry_run = dry_run

try:
self.before_import(dataset, real_dry_run, **kwargs)
self.before_import(dataset, using_transactions, dry_run, **kwargs)
except Exception as e:
logging.exception(e)
tb_info = traceback.format_exc()
result.base_errors.append(self.get_error_result_class()(e, tb_info))
if raise_errors:
if use_transactions:
if using_transactions:
savepoint_rollback(sp1)
raise

Expand All @@ -527,11 +539,11 @@ def import_data(self, dataset, dry_run=False, raise_errors=False,
result.totals['total'] = len(dataset)

for row in dataset.dict:
row_result = self.import_row(row, instance_loader, real_dry_run, **kwargs)
row_result = self.import_row(row, instance_loader, using_transactions, dry_run, **kwargs)
if row_result.errors:
result.totals[row_result.IMPORT_TYPE_ERROR] += 1
if raise_errors:
if use_transactions:
if using_transactions:
savepoint_rollback(sp1)
raise row_result.errors[-1].error
else:
Expand All @@ -541,17 +553,17 @@ def import_data(self, dataset, dry_run=False, raise_errors=False,
result.append_row_result(row_result)

try:
self.after_import(dataset, result, real_dry_run, **kwargs)
self.after_import(dataset, result, using_transactions, dry_run, **kwargs)
except Exception as e:
logging.exception(e)
tb_info = traceback.format_exc()
result.base_errors.append(self.get_error_result_class()(e, tb_info))
if raise_errors:
if use_transactions:
if using_transactions:
savepoint_rollback(sp1)
raise

if use_transactions:
if using_transactions:
if dry_run or result.has_errors():
savepoint_rollback(sp1)
else:
Expand All @@ -563,7 +575,6 @@ def get_export_order(self):
order = tuple(self._meta.export_order or ())
return order + tuple(k for k in self.fields.keys() if k not in order)


def before_export(self, queryset, *args, **kwargs):
"""
Override to add additional logic. Does nothing by default.
Expand Down Expand Up @@ -790,7 +801,7 @@ def init_instance(self, row=None):
"""
return self._meta.model()

def after_import(self, dataset, result, dry_run, **kwargs):
def after_import(self, dataset, result, using_transactions, dry_run, **kwargs):
"""
Reset the SQL sequences after new objects are imported
"""
Expand Down
11 changes: 7 additions & 4 deletions tests/core/tests/resources_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def attempted_save(instance, real_dry_run):

def test_before_import_access_to_kwargs(self):
class B(BookResource):
def before_import(self, dataset, dry_run, **kwargs):
def before_import(self, dataset, using_transactions, dry_run, **kwargs):
if 'extra_arg' in kwargs:
dataset.headers[dataset.headers.index('author_email')] = 'old_email'
dataset.insert_col(0,
Expand Down Expand Up @@ -590,8 +590,11 @@ class Meta:
model = Entry
fields = ('id', )

def after_save_instance(self, instance, dry_run):
if not dry_run:
def after_save_instance(self, instance, using_transactions, dry_run):
if not using_transactions and dry_run:
# we don't have transactions and we want to do a dry_run
pass
else:
instance.user.save()

user = User.objects.create(username='foo')
Expand Down Expand Up @@ -656,7 +659,7 @@ def test_m2m_import_with_transactions(self):

id_field = resource.fields['id']
id_diff = row_diff[fields.index(id_field)]
# id diff should exists because in rollbacked transaction
# id diff should exist because, in the rolled-back transaction,
# FooBook has been saved
self.assertTrue(id_diff)

Expand Down