Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #1544. Add new parameter to Field called m2m_add which changes… #1545

Merged
merged 6 commits into from
Feb 21, 2023
Merged
13 changes: 11 additions & 2 deletions import_export/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ class Field:
:param saves_null_values: Controls whether null values are saved on the object
:param dehydrate_method: Lets you choose your own method for dehydration rather
than using `dehydrate_{field_name}` syntax.
:param m2m_add: changes save of this field to add the values, if they do not exist,
to a ManyToMany field instead of setting all values. Only useful if field is
a ManyToMany field.
"""
empty_values = [None, '']

def __init__(self, attribute=None, column_name=None, widget=None,
default=NOT_PROVIDED, readonly=False, saves_null_values=True, dehydrate_method=None):
default=NOT_PROVIDED, readonly=False, saves_null_values=True,
dehydrate_method=None, m2m_add=False):
self.attribute = attribute
self.default = default
self.column_name = column_name
Expand All @@ -44,6 +48,7 @@ def __init__(self, attribute=None, column_name=None, widget=None,
self.readonly = readonly
self.saves_null_values = saves_null_values
self.dehydrate_method = dehydrate_method
self.m2m_add = m2m_add

def __repr__(self):
"""
Expand Down Expand Up @@ -116,7 +121,11 @@ def save(self, obj, data, is_m2m=False, **kwargs):
if not is_m2m:
setattr(obj, attrs[-1], cleaned)
else:
getattr(obj, attrs[-1]).set(cleaned)
if self.m2m_add:
new_values = [val for val in cleaned if val not in getattr(obj, attrs[-1]).all()]
matthewhegarty marked this conversation as resolved.
Show resolved Hide resolved
getattr(obj, attrs[-1]).add(*new_values)
else:
getattr(obj, attrs[-1]).set(cleaned)

def export(self, obj):
"""
Expand Down
36 changes: 36 additions & 0 deletions tests/core/tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import date
from unittest import mock

from django.test import TestCase

Expand Down Expand Up @@ -108,3 +109,38 @@ def testget_dehydrate_method_without_params_raises_attribute_error(self):
FieldError,
field.get_dehydrate_method
)

def test_m2m_add_true(self):
m2m_related_manager = mock.Mock(spec=["add", "set", "all"])
m2m_related_manager.all.return_value = []
self.obj.aliases = m2m_related_manager
field = fields.Field(column_name='aliases', attribute='aliases', m2m_add=True)
row = {
'aliases': ["Foo", "Bar"],
}
field.save(self.obj, row, is_m2m=True)

self.assertEqual(m2m_related_manager.add.call_count, 1)
self.assertEqual(m2m_related_manager.set.call_count, 0)
self.assertSequenceEqual(m2m_related_manager.add.call_args.args, ('Foo', 'Bar'))

row = {
'aliases': ["apple"],
}
field.save(self.obj, row, is_m2m=True)
self.assertEqual(m2m_related_manager.add.call_args.args[0], 'apple')

def test_m2m_add_False(self):
m2m_related_manager = mock.Mock(spec=["add", "set", "all"])
self.obj.aliases = m2m_related_manager
field = fields.Field(column_name='aliases', attribute='aliases')
row = {
'aliases': ["Foo", "Bar"],
}
field.save(self.obj, row, is_m2m=True)

self.assertEqual(m2m_related_manager.add.call_count, 0)
self.assertEqual(m2m_related_manager.set.call_count, 1)
self.assertEqual(m2m_related_manager.set.call_args.args[0], ['Foo', 'Bar'])


40 changes: 40 additions & 0 deletions tests/core/tests/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,46 @@ class Meta:
self.assertIn(cat1, book.categories.all())
self.assertIn(cat2, book.categories.all())

def test_m2m_add(self):
cat1 = Category.objects.create(name='Cat 1')
cat2 = Category.objects.create(name='Cat 2')
cat3 = Category.objects.create(name='Cat 3')
cat4 = Category.objects.create(name='Cat 4')
headers = ['id', 'name', 'categories']
row = [None, 'FooBook', "Cat 1|Cat 2"]
dataset = tablib.Dataset(row, headers=headers)

class BookM2MResource(resources.ModelResource):
categories = fields.Field(
attribute='categories',
m2m_add=True,
widget=widgets.ManyToManyWidget(Category, field='name',
separator='|')
)

class Meta:
model = Book

resource = BookM2MResource()
resource.import_data(dataset, raise_errors=True)
book = Book.objects.get(name='FooBook')
self.assertIn(cat1, book.categories.all())
self.assertIn(cat2, book.categories.all())
self.assertNotIn(cat3, book.categories.all())
self.assertNotIn(cat4, book.categories.all())

row1 = [book.id, 'FooBook', "Cat 1|Cat 2"] # This should have no effect, since Cat 1 and Cat 2 already exist
row2 = [book.id, 'FooBook', "Cat 3|Cat 4"]
dataset = tablib.Dataset(row1, row2, headers=headers)
resource.import_data(dataset, raise_errors=True)
book2 = Book.objects.get(name='FooBook')
self.assertEqual(book.id, book2.id)
self.assertEqual(book.categories.count(), 4)
self.assertIn(cat1, book2.categories.all())
self.assertIn(cat2, book2.categories.all())
self.assertIn(cat3, book2.categories.all())
self.assertIn(cat4, book2.categories.all())

def test_related_one_to_one(self):
# issue #17 - Exception when attempting access something on the
# related_name
Expand Down