Skip to content

Commit

Permalink
Merge pull request #653 from rtpg/from-upstream
Browse files Browse the repository at this point in the history
Allow for passing in extra kwargs on Tag creation
  • Loading branch information
Asif Saif Uddin committed Jan 14, 2020
2 parents e3b3dc4 + 036cdad commit 058a908
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/custom_tagging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ Custom tag
~~~~~~~~~~

When providing a custom ``Tag`` model it should be a ``ForeignKey`` to your tag
model named ``"tag"``:
model named ``"tag"``. If your custom ``Tag`` model has extra parameters you want to initialize during setup, you can do so by passing it along via the ``tag_kwargs`` parameter of ``TaggableManager.add``. For example ``my_food.tags.add("tag_name1", "tag_name2", tag_kwargs={"my_field":3})``:

.. code-block:: python
Expand Down
31 changes: 19 additions & 12 deletions taggit/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,12 @@ def _lookup_kwargs(self):
return self.through.lookup_kwargs(self.instance)

@require_instance_manager
def add(self, *tags):
def add(self, *tags, **kwargs):
tag_kwargs = kwargs.pop("tag_kwargs", {})

db = router.db_for_write(self.through, instance=self.instance)

tag_objs = self._to_tag_model_instances(tags)
tag_objs = self._to_tag_model_instances(tags, tag_kwargs)
new_ids = {t.pk for t in tag_objs}

# NOTE: can we hardcode 'tag_id' here or should the column name be got
Expand Down Expand Up @@ -165,7 +167,7 @@ def add(self, *tags):
using=db,
)

def _to_tag_model_instances(self, tags):
def _to_tag_model_instances(self, tags, tag_kwargs):
"""
Takes an iterable containing either strings, tag objects, or a mixture
of both and returns set of tag objects.
Expand Down Expand Up @@ -205,19 +207,19 @@ def _to_tag_model_instances(self, tags):
else:
# If str_tags has 0 elements Django actually optimizes that to not
# do a query. Malcolm is very smart.
existing = manager.filter(name__in=str_tags)
tags_to_create = str_tags - {t.name for t in existing}
existing = manager.filter(name__in=str_tags, **tag_kwargs)

tags_to_create = str_tags - set(t.name for t in existing)

tag_objs.update(existing)

for new_tag in tags_to_create:
if case_insensitive:
tag, created = manager.get_or_create(
name__iexact=new_tag, defaults={"name": new_tag}
)
lookup = {"name__iexact": new_tag, **tag_kwargs}
else:
tag, created = manager.get_or_create(name=new_tag)
lookup = {"name": new_tag, **tag_kwargs}

tag, create = manager.get_or_create(**lookup, defaults={"name": new_tag})
tag_objs.add(tag)

return tag_objs
Expand All @@ -237,16 +239,21 @@ def set(self, *tags, **kwargs):
then all existing tags are removed (using `.clear()`) and the new tags
added. Otherwise, only those tags that are not present in the args are
removed and any new tags added.
Any kwarg apart from 'clear' will be passed when adding tags.
"""
db = router.db_for_write(self.through, instance=self.instance)

clear = kwargs.pop("clear", False)
tag_kwargs = kwargs.pop("tag_kwargs", {})

if clear:
self.clear()
self.add(*tags)
self.add(*tags, **kwargs)
else:
# make sure we're working with a collection of a uniform type
objs = self._to_tag_model_instances(tags)
objs = self._to_tag_model_instances(tags, tag_kwargs)

# get the existing tag strings
old_tag_strs = set(
Expand All @@ -263,7 +270,7 @@ def set(self, *tags, **kwargs):
new_objs.append(obj)

self.remove(*old_tag_strs)
self.add(*new_objs)
self.add(*new_objs, **kwargs)

@require_instance_manager
def remove(self, *tags):
Expand Down
20 changes: 20 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ def test_lt(self):
self.assertIs(low < high, False)


class CustomTagCreationTestCase(TestCase):
def test_model_manager_add(self):
apple = OfficialFood.objects.create(name="apple")

# let's add two official tags
apple.tags.add(
"foo", "bar", tag_kwargs={"official": True},
)

# and two unofficial ones
apple.tags.add(
"baz", "wow", tag_kwargs={"official": False},
)

# We should end up with 4 tags
self.assertEquals(apple.tags.count(), 4)
self.assertEquals(apple.tags.filter(official=True).count(), 2)
self.assertEquals(apple.tags.filter(official=False).count(), 2)


class TagModelDirectTestCase(TagModelTestCase):
food_model = DirectFood
tag_model = Tag
Expand Down

0 comments on commit 058a908

Please sign in to comment.