Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Fix for get_top query in MySQL; all tests now pass in MySQL

git-svn-id: https://django-voting.googlecode.com/svn/trunk@54 662f01ad-f42a-0410-a340-718c64ddaef4
  • Loading branch information...
commit c7e532685a9b5f77fb871fd8fb53cdb2788a13f6 1 parent 142a566
@insin insin authored
Showing with 41 additions and 13 deletions.
  1. +25 −9 managers.py
  2. +14 −4 tests/settings.py
  3. +2 −0  tests/tests.py
View
34 managers.py
@@ -1,3 +1,4 @@
+from django.conf import settings
from django.db import backend, connection, models
from django.contrib.contenttypes.models import ContentType
@@ -16,7 +17,9 @@ def get_score(self, obj):
cursor = connection.cursor()
cursor.execute(query, [ctype.id, obj._get_pk_val()])
result = cursor.fetchall()[0]
- return {'score': result[0] or 0, 'num_votes': result[1]}
+ # MySQL returns floats and longs respectively for these
+ # results, so we need to convert them to ints explicitly.
+ return {'score': result[0] and int(result[0]) or 0, 'num_votes': int(result[1])}
def get_scores_in_bulk(self, objects):
"""
@@ -36,9 +39,9 @@ def get_scores_in_bulk(self, objects):
cursor = connection.cursor()
cursor.execute(query, [ctype.id] + [obj._get_pk_val() for obj in objects])
results = cursor.fetchall()
- return dict([(object_id, {
- 'score': score,
- 'num_votes': num_votes,
+ return dict([(int(object_id), {
+ 'score': int(score),
+ 'num_votes': int(num_votes),
}) for object_id, score, num_votes in results])
def record_vote(self, obj, user, vote):
@@ -71,14 +74,27 @@ def get_top(self, Model, limit=10, reversed=False):
"""
ctype = ContentType.objects.get_for_model(Model)
query = """
- SELECT object_id, SUM(vote)
+ SELECT object_id, SUM(vote) as %s
FROM %s
WHERE content_type_id = %%s
- GROUP BY object_id""" % backend.quote_name(self.model._meta.db_table)
+ GROUP BY object_id""" % (
+ backend.quote_name('score'),
+ backend.quote_name(self.model._meta.db_table),
+ )
+
+ # MySQL has issues with re-using the aggregate function in the
+ # HAVING clause, so we alias the score and use this alias for
+ # its benefit.
+ if settings.DATABASE_ENGINE == 'mysql':
+ having_score = backend.quote_name('score')
+ else:
+ having_score = 'SUM(vote)'
if reversed:
- query += ' HAVING SUM(vote) < 0 ORDER BY SUM(vote) ASC LIMIT %s'
+ having_sql = ' HAVING %(having_score)s < 0 ORDER BY %(having_score)s ASC LIMIT %%s'
else:
- query += ' HAVING SUM(vote) > 0 ORDER BY SUM(vote) DESC LIMIT %s'
+ having_sql = ' HAVING %(having_score)s > 0 ORDER BY %(having_score)s DESC LIMIT %%s'
+ query += having_sql % {'having_score': having_score}
+
cursor = connection.cursor()
cursor.execute(query, [ctype.id, limit])
results = cursor.fetchall()
@@ -90,7 +106,7 @@ def get_top(self, Model, limit=10, reversed=False):
# relations, missing objects are silently ignored.
for id, score in results:
if id in objects:
- yield objects[id], score
+ yield objects[id], int(score)
def get_bottom(self, Model, limit=10):
"""
View
18 tests/settings.py
@@ -5,10 +5,20 @@
DATABASE_ENGINE = 'sqlite3'
DATABASE_NAME = os.path.join(DIRNAME, 'database.db')
-DATABASE_USER = ''
-DATABASE_PASSWORD = ''
-DATABASE_HOST = ''
-DATABASE_PORT = ''
+
+#DATABASE_ENGINE = 'mysql'
+#DATABASE_NAME = 'tagging_test'
+#DATABASE_USER = 'root'
+#DATABASE_PASSWORD = ''
+#DATABASE_HOST = 'localhost'
+#DATABASE_PORT = '3306'
+
+#DATABASE_ENGINE = 'postgresql_psycopg2'
+#DATABASE_NAME = 'tagging_test'
+#DATABASE_USER = 'postgres'
+#DATABASE_PASSWORD = ''
+#DATABASE_HOST = 'localhost'
+#DATABASE_PORT = '5432'
INSTALLED_APPS = (
'django.contrib.auth',
View
2  tests/tests.py
@@ -80,4 +80,6 @@
>>> list(Vote.objects.get_bottom(Item))
[(<Item: test3>, -4), (<Item: test4>, -3), (<Item: test2>, -2)]
+>>> Vote.objects.get_scores_in_bulk([i1, i2, i3, i4])
+{1: {'score': 0, 'num_votes': 4}, 2: {'score': -2, 'num_votes': 4}, 3: {'score': -4, 'num_votes': 4}, 4: {'score': -3, 'num_votes': 3}}
"""
Please sign in to comment.
Something went wrong with that request. Please try again.