Permalink
Browse files

[soc2010/query-refactor] Implemented values (and values_list).

  • Loading branch information...
alex committed Jun 21, 2010
1 parent 28c9044 commit 1fda238ce897e342402983e72a6226104141cc6d
Showing with 54 additions and 6 deletions.
  1. +22 −5 django/contrib/mongodb/compiler.py
  2. +1 −1 django/db/models/query.py
  3. +31 −0 tests/regressiontests/mongodb/tests.py
@@ -56,16 +56,28 @@ def negate(self, k, v):
return {k: {"$not": v}}
return {k: {"$ne": v}}
- def build_query(self, aggregates=False):
- assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1
+ def get_fields(self, aggregates):
+ if self.query.select:
+ fields = []
+ for alias, field in self.query.select:
+ assert alias == self.query.model._meta.db_table
+ if field == self.query.model._meta.pk.column:
+ field = "_id"
+ fields.append(field)
+ return fields
if not aggregates:
assert self.query.default_cols
+ return None
+
+ def build_query(self, aggregates=False):
+ assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1
assert not self.query.distinct
assert not self.query.extra
assert not self.query.having
filters = self.get_filters(self.query.where)
- cursor = self.connection.db[self.query.model._meta.db_table].find(filters)
+ fields = self.get_fields(aggregates=aggregates)
+ cursor = self.connection.db[self.query.model._meta.db_table].find(filters, fields=fields)
if self.query.order_by:
cursor = cursor.sort([
(ordering.lstrip("-"), DESCENDING if ordering.startswith("-") else ASCENDING)
@@ -79,10 +91,15 @@ def build_query(self, aggregates=False):
def results_iter(self):
query = self.build_query()
+ fields = self.get_fields(aggregates=False)
+ if fields is None:
+ fields = [
+ f.column if f is not self.query.model._meta.pk else "_id"
+ for f in self.query.model._meta.fields
+ ]
for row in query:
yield tuple(
- row[f.column if f is not self.query.model._meta.pk else "_id"]
- for f in self.query.model._meta.fields
+ row[f] for f in fields
)
def has_results(self):
@@ -843,12 +843,12 @@ def _setup_query(self):
if self._fields:
self.extra_names = []
self.aggregate_names = []
+ self.query.default_cols = False
if not self.query.extra and not self.query.aggregates:
# Short cut - if there are no extra or aggregates, then
# the values() clause must be just field names.
self.field_names = list(self._fields)
else:
- self.query.default_cols = False
self.field_names = []
for f in self._fields:
# we inspect the full extra_select list since we might
@@ -113,6 +113,37 @@ def test_slicing(self):
artists[2:],
lambda a: a,
)
+
+ def test_values(self):
+ a = Artist.objects.create(name="Steve Perry", good=True)
+
+ self.assertQuerysetEqual(
+ Artist.objects.values(), [
+ {"name": "Steve Perry", "good": True, "current_group_id": None, "id": a.pk},
+ ],
+ lambda a: a,
+ )
+
+ self.assertQuerysetEqual(
+ Artist.objects.values("name"), [
+ {"name": "Steve Perry"},
+ ],
+ lambda a: a,
+ )
+
+ self.assertQuerysetEqual(
+ Artist.objects.values_list("name"), [
+ ("Steve Perry",)
+ ],
+ lambda a: a,
+ )
+
+ self.assertQuerysetEqual(
+ Artist.objects.values_list("name", flat=True), [
+ "Steve Perry",
+ ],
+ lambda a: a,
+ )
def test_not_equals(self):
q = Group.objects.create(name="Queen", year_formed=1971)

0 comments on commit 1fda238

Please sign in to comment.