Skip to content
Browse files

[soc2010/query-refactor] Implemented count() (and by extension the Co…

…unt() aggregate on the primary key).

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/query-refactor@13353 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
1 parent 8f441f0 commit f522555392ed9e133431437dd815ba0e84bc2394 @alex alex committed Jun 14, 2010
View
5 django/contrib/mongodb/base.py
@@ -46,6 +46,11 @@ def flush(self, style, only_django=False):
tables = self.connection.introspection.table_names()
for table in tables:
self.connection.db.drop_collection(table)
+
+ def check_aggregate_support(self, aggregate):
+ # TODO: this really should use the generic aggregates, not the SQL ones
+ from django.db.models.sql.aggregates import Count
+ return isinstance(aggregate, Count)
class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, *args, **kwargs):
View
19 django/contrib/mongodb/compiler.py
@@ -32,10 +32,10 @@ def make_atom(self, lhs, lookup_type, value_annotation, params_or_value):
column = "_id"
return column, params[0]
- def build_query(self):
- assert not self.query.aggregates
- assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) == 1
- assert self.query.default_cols
+ def build_query(self, aggregates=False):
+ assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1
+ if not aggregates:
+ assert self.query.default_cols
assert not self.query.distinct
assert not self.query.extra
assert not self.query.having
@@ -60,6 +60,17 @@ def has_results(self):
return False
else:
return True
+
+ def get_aggregates(self):
+ assert len(self.query.aggregates) == 1
+ agg = self.query.aggregates.values()[0]
+ assert (
+ isinstance(agg, self.query.aggregates_module.Count) and (
+ agg.col == "*" or
+ isinstance(agg.col, tuple) and agg.col == (self.query.model._meta.db_table, self.query.model._meta.pk.column)
+ )
+ )
+ return [self.build_query(aggregates=True).count()]
class SQLInsertCompiler(SQLCompiler):
View
5 django/db/models/sql/compiler.py
@@ -675,7 +675,10 @@ def has_results(self):
self.query.clear_ordering(True)
self.query.set_limits(high=1)
return bool(self.execute_sql(SINGLE))
-
+
+ def get_aggregates(self):
+ return self.execute_sql(SINGLE)
+
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
View
2 django/db/models/sql/query.py
@@ -363,7 +363,7 @@ def get_aggregation(self, using):
query.related_select_cols = []
query.related_select_fields = []
- result = query.get_compiler(using).execute_sql(SINGLE)
+ result = query.get_compiler(using).get_aggregates()
if result is None:
result = [None for q in query.aggregate_select.items()]
View
19 tests/regressiontests/mongodb/tests.py
@@ -1,3 +1,4 @@
+from django.db.models import Count
from django.test import TestCase
from models import Artist
@@ -25,3 +26,21 @@ def test_update(self):
l = Artist.objects.get(pk=pk)
self.assertTrue(not l.good)
+
+ def test_count(self):
+ Artist.objects.create(name="Billy Joel", good=True)
+ Artist.objects.create(name="John Mellencamp", good=True)
+ Artist.objects.create(name="Warren Zevon", good=True)
+ Artist.objects.create(name="Matisyahu", good=True)
+ Artist.objects.create(name="Gary US Bonds", good=True)
+
+ self.assertEqual(Artist.objects.count(), 5)
+ self.assertEqual(Artist.objects.filter(good=True).count(), 5)
+
+ Artist.objects.create(name="Bon Iver", good=False)
+
+ self.assertEqual(Artist.objects.count(), 6)
+ self.assertEqual(Artist.objects.filter(good=True).count(), 5)
+ self.assertEqual(Artist.objects.filter(good=False).count(), 1)
+
+ self.assertEqual(Artist.objects.aggregate(c=Count("pk")), {"c": 6})

0 comments on commit f522555

Please sign in to comment.
Something went wrong with that request. Please try again.