Skip to content

Commit

Permalink
CachedInstanceLoader defaults to empty when import_id is missing (#1225)
Browse files Browse the repository at this point in the history
  • Loading branch information
DonQueso89 committed Dec 31, 2020
1 parent 3c1d10e commit 3ad28a0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,4 @@ The following is a list of much appreciated contributors:
* kjpc-tech (Kyle)
* Matthew Hegarty
* jinmay (jinmyeong Cho)
* DonQueso89 (Kees van Ekeren)
24 changes: 15 additions & 9 deletions import_export/instance_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,21 @@ def __init__(self, *args, **kwargs):
pk_field_name = self.resource.get_import_id_fields()[0]
self.pk_field = self.resource.fields[pk_field_name]

ids = [self.pk_field.clean(row) for row in self.dataset.dict]
qs = self.get_queryset().filter(**{
"%s__in" % self.pk_field.attribute: ids
})
# If the pk field is missing, all instances in dataset are new
# and cache is empty.
self.all_instances = {}
if self.dataset.dict and self.pk_field.column_name in self.dataset.dict[0]:
ids = [self.pk_field.clean(row) for row in self.dataset.dict]
qs = self.get_queryset().filter(**{
"%s__in" % self.pk_field.attribute: ids
})

self.all_instances = {
self.pk_field.get_value(instance): instance
for instance in qs
}
self.all_instances = {
self.pk_field.get_value(instance): instance
for instance in qs
}

def get_instance(self, row):
return self.all_instances.get(self.pk_field.clean(row))
if self.all_instances:
return self.all_instances.get(self.pk_field.clean(row))
return None
25 changes: 25 additions & 0 deletions tests/core/tests/test_instance_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,28 @@ def test_all_instances(self):
def test_get_instance(self):
obj = self.instance_loader.get_instance(self.dataset.dict[0])
self.assertEqual(obj, self.book)



class CachedInstanceLoaderWithAbsentImportIdFieldTest(TestCase):
"""Ensure that the cache is empty when the PK field is absent
in the inbound dataset.
"""

def setUp(self):
self.resource = resources.modelresource_factory(Book)()
self.dataset = tablib.Dataset(headers=['name', 'author_email'])
self.book = Book.objects.create(name="Some book")
self.book2 = Book.objects.create(name="Some other book")
row = ['Some book', 'test@example.com']
self.dataset.append(row)
self.instance_loader = instance_loaders.CachedInstanceLoader(
self.resource, self.dataset)

def test_all_instances(self):
self.assertEqual(self.instance_loader.all_instances, {})
self.assertEqual(len(self.instance_loader.all_instances), 0)

def test_get_instance(self):
obj = self.instance_loader.get_instance(self.dataset.dict[0])
self.assertEqual(obj, None)

0 comments on commit 3ad28a0

Please sign in to comment.