Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
frewsxcv committed Apr 26, 2015
2 parents 24df08c + dcabf68 commit 72156f5
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 26 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Changelog
=========

0.14.0 (04.26.2015)
~~~~~~~~~~~~~~~~~~~
* Prevent extra JOIN when prefetching
* https://github.com/alex/django-taggit/pull/275
* Prevent _meta warnings with Django 1.8
* https://github.com/alex/django-taggit/pull/299

0.13.0 (04.02.2015)
~~~~~~~~~~~~~~~~~~~
* Django 1.8 support
Expand Down
2 changes: 1 addition & 1 deletion taggit/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = (0, 13, 0)
VERSION = (0, 14, 0)
38 changes: 20 additions & 18 deletions taggit/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from taggit.forms import TagField
from taggit.models import GenericTaggedItemBase, TaggedItem
from taggit.utils import require_instance_manager
from taggit.utils import _get_field, require_instance_manager

try:
from django.contrib.contenttypes.fields import GenericRelation
Expand Down Expand Up @@ -101,11 +101,12 @@ def __init__(self, through, model, instance, prefetch_cache_name):
def is_cached(self, instance):
return self.prefetch_cache_name in instance._prefetched_objects_cache

def get_queryset(self):
def get_queryset(self, extra_filters=None):
try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
except (AttributeError, KeyError):
return self.through.tags_for(self.model, self.instance)
kwargs = extra_filters if extra_filters else {}
return self.through.tags_for(self.model, self.instance, **kwargs)

def get_prefetch_queryset(self, instances, queryset=None):
if queryset is not None:
Expand All @@ -126,7 +127,7 @@ def get_prefetch_queryset(self, instances, queryset=None):
source_col = fk.column
connection = connections[db]
qn = connection.ops.quote_name
qs = self.get_queryset().using(db)._next_is_sticky().filter(**query).extra(
qs = self.get_queryset(query).using(db).extra(
select={
'_prefetch_related_val': '%s.%s' % (qn(join_table), qn(source_col))
}
Expand Down Expand Up @@ -212,7 +213,7 @@ def similar_objects(self):
if len(lookup_keys) == 1:
# Can we do this without a second query by using a select_related()
# somehow?
f = self.through._meta.get_field_by_name(lookup_keys[0])[0]
f = _get_field(self.through, lookup_keys[0])
objs = f.rel.to._default_manager.filter(**{
"%s__in" % f.rel.field_name: [r["content_object"] for r in qs]
})
Expand Down Expand Up @@ -389,10 +390,10 @@ def related_query_name(self):
return _model_name(self.model)

def m2m_reverse_name(self):
return self.through._meta.get_field_by_name("tag")[0].column
return _get_field(self.through, 'tag').column

def m2m_reverse_field_name(self):
return self.through._meta.get_field_by_name("tag")[0].name
return _get_field(self.through, 'tag').name

def m2m_target_field_name(self):
return self.model._meta.pk.name
Expand Down Expand Up @@ -430,7 +431,7 @@ def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
alias_to_join = rhs_alias
else:
alias_to_join = lhs_alias
extra_col = self.through._meta.get_field_by_name('content_type')[0].column
extra_col = _get_field(self.through, 'content_type').column
content_type_ids = [ContentType.objects.get_for_model(subclass).pk for
subclass in _get_subclasses(self.model)]
if len(content_type_ids) == 1:
Expand All @@ -449,8 +450,8 @@ def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
# This and all the methods till the end of class are only used in django >= 1.6
def _get_mm_case_path_info(self, direct=False):
pathinfos = []
linkfield1 = self.through._meta.get_field_by_name('content_object')[0]
linkfield2 = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
linkfield1 = _get_field(self.through, 'content_object')
linkfield2 = _get_field(self.through, self.m2m_reverse_field_name())
if direct:
join1infos = linkfield1.get_reverse_path_info()
join2infos = linkfield2.get_path_info()
Expand All @@ -465,8 +466,8 @@ def _get_gfk_case_path_info(self, direct=False):
pathinfos = []
from_field = self.model._meta.pk
opts = self.through._meta
object_id_field = opts.get_field_by_name('object_id')[0]
linkfield = self.through._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
object_id_field = _get_field(self.through, 'object_id')
linkfield = _get_field(self.through, self.m2m_reverse_field_name())
if direct:
join1infos = [PathInfo(self.model._meta, opts, [from_field], self.rel, True, False)]
join2infos = linkfield.get_path_info()
Expand Down Expand Up @@ -496,7 +497,7 @@ def get_joining_columns(self, reverse_join=False):
return (("object_id", "id"),)

def get_extra_restriction(self, where_class, alias, related_alias):
extra_col = self.through._meta.get_field_by_name('content_type')[0].column
extra_col = _get_field(self.through, 'content_type').column
content_type_ids = [ContentType.objects.get_for_model(subclass).pk
for subclass in _get_subclasses(self.model)]
return ExtraJoinRestriction(related_alias, extra_col, content_type_ids)
Expand All @@ -506,8 +507,7 @@ def get_reverse_joining_columns(self):

@property
def related_fields(self):
return [(self.through._meta.get_field_by_name('object_id')[0],
self.model._meta.pk)]
return [(_get_field(self.through, 'object_id'), self.model._meta.pk)]

@property
def foreign_related_fields(self):
Expand All @@ -516,9 +516,11 @@ def foreign_related_fields(self):

def _get_subclasses(model):
subclasses = [model]
for f in model._meta.get_all_field_names():
field = model._meta.get_field_by_name(f)[0]

if VERSION < (1, 8):
all_fields = (_get_field(model, f) for f in model._meta.get_all_field_names())
else:
all_fields = model._meta.get_fields()
for field in all_fields:
# Django 1.8 +
if (not RelatedObject and isinstance(field, OneToOneRel) and
getattr(field.field.rel, "parent_link", None)):
Expand Down
21 changes: 14 additions & 7 deletions taggit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext

from taggit.utils import _get_field

try:
from django.contrib.contenttypes.fields import GenericForeignKey
except ImportError: # django < 1.7
Expand Down Expand Up @@ -101,11 +103,11 @@ class Meta:

@classmethod
def tag_model(cls):
return cls._meta.get_field_by_name("tag")[0].rel.to
return _get_field(cls, 'tag').rel.to

@classmethod
def tag_relname(cls):
return cls._meta.get_field_by_name('tag')[0].rel.related_name
return _get_field(cls, 'tag').rel.related_name

@classmethod
def lookup_kwargs(cls, instance):
Expand All @@ -127,14 +129,17 @@ class Meta:
abstract = True

@classmethod
def tags_for(cls, model, instance=None):
def tags_for(cls, model, instance=None, **extra_filters):
kwargs = extra_filters or {}
if instance is not None:
return cls.tag_model().objects.filter(**{
kwargs.update({
'%s__content_object' % cls.tag_relname(): instance
})
return cls.tag_model().objects.filter(**{
return cls.tag_model().objects.filter(**kwargs)
kwargs.update({
'%s__content_object__isnull' % cls.tag_relname(): False
}).distinct()
})
return cls.tag_model().objects.filter(**kwargs).distinct()


class GenericTaggedItemBase(ItemBase):
Expand Down Expand Up @@ -172,13 +177,15 @@ def bulk_lookup_kwargs(cls, instances):
}

@classmethod
def tags_for(cls, model, instance=None):
def tags_for(cls, model, instance=None, **extra_filters):
ct = ContentType.objects.get_for_model(model)
kwargs = {
"%s__content_type" % cls.tag_relname(): ct
}
if instance is not None:
kwargs["%s__object_id" % cls.tag_relname()] = instance.pk
if extra_filters:
kwargs.update(extra_filters)
return cls.tag_model().objects.filter(**kwargs).distinct()


Expand Down
8 changes: 8 additions & 0 deletions taggit/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from __future__ import unicode_literals

from django import VERSION
from django.utils import six
from django.utils.encoding import force_text
from django.utils.functional import wraps


def _get_field(model, name):
if VERSION < (1, 8):
return model._meta.get_field_by_name(name)[0]
else:
return model._meta.get_field(name)


def parse_tags(tagstring):
"""
Parses tag input, with multiple word input being activated and
Expand Down
10 changes: 10 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.contrib.contenttypes.models import ContentType
from django.core import serializers
from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.db import connection
from django.test import TestCase, TransactionTestCase
from django.utils.encoding import force_text

Expand Down Expand Up @@ -380,6 +381,15 @@ def test_internal_type_is_manytomany(self):
TaggableManager().get_internal_type(), 'ManyToManyField'
)

def test_prefetch_no_extra_join(self):
apple = self.food_model.objects.create(name="apple")
apple.tags.add('1', '2')
with self.assertNumQueries(2):
l = list(self.food_model.objects.prefetch_related('tags').all())
join_clause = 'INNER JOIN "%s"' % self.taggeditem_model._meta.db_table
self.assertEqual(connection.queries[-1]['sql'].count(join_clause), 1, connection.queries[-2:])


class TaggableManagerDirectTestCase(TaggableManagerTestCase):
food_model = DirectFood
pet_model = DirectPet
Expand Down

0 comments on commit 72156f5

Please sign in to comment.