Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions django_mongodb_backend/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def count(self, compiler, connection, resolve_inner_expression=False):
# If distinct=True or resolve_inner_expression=False, sum the size of the
# set.
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
lhs_mql = {"$ifNull": [lhs_mql, []]}
# None shouldn't be counted, so subtract 1 if it's present.
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
return {"$add": [{"$size": lhs_mql}, exits_null]}
Expand Down
18 changes: 3 additions & 15 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, *args, **kwargs):
self.subqueries = []
# Atlas search stage.
self.search_pipeline = []
self.wrap_for_global_aggregation = False

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
Expand Down Expand Up @@ -234,21 +235,8 @@ def _build_aggregation_pipeline(self, ids, group):
"""Build the aggregation pipeline for grouping."""
pipeline = []
if not ids:
group["_id"] = None
pipeline.append({"$facet": {"group": [{"$group": group}]}})
pipeline.append(
{
"$addFields": {
key: {
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": key,
}
}
for key in group
}
}
)
pipeline.append({"$group": {"_id": None, **group}})
self.wrap_for_global_aggregation = True
else:
group["_id"] = ids
pipeline.append({"$group": group})
Expand Down
17 changes: 17 additions & 0 deletions django_mongodb_backend/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, compiler):
# $lookup stage that encapsulates the pipeline for performing a nested
# subquery.
self.subquery_lookup = None
self.wrap_for_global_aggregation = compiler.wrap_for_global_aggregation

def __repr__(self):
return f"<MongoQuery: {self.match_mql!r} ORDER {self.ordering!r}>"
Expand Down Expand Up @@ -91,6 +92,22 @@ def get_pipeline(self):
pipeline.append({"$match": self.match_mql})
if self.aggregation_pipeline:
pipeline.extend(self.aggregation_pipeline)
if self.wrap_for_global_aggregation:
pipeline = [
{"$collStats": {}},
{
"$lookup": {
"from": self.compiler.collection_name,
"as": "wrapped",
"pipeline": pipeline,
}
},
{
"$replaceWith": {
"$cond": [{"$eq": ["$wrapped", []]}, {}, {"$first": "$wrapped"}]
}
},
]
if self.project_fields:
pipeline.append({"$project": self.project_fields})
if self.combinator_pipeline:
Expand Down
Loading