Skip to content

Commit

Permalink
CSHARP-4744: Improve optimization of Count with predicate in Group.
Browse files Browse the repository at this point in the history
  • Loading branch information
rstam committed Aug 9, 2023
1 parent a88dc2e commit 208438f
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,17 @@ public static AggregationExpression Translate(TranslationContext context, Method
}

var predicateLambda = (LambdaExpression)arguments[1];
var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
var predicateTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, predicateLambda, sourceItemSerializer, asRoot: false);
var filteredSourceAst = AstExpression.Filter(
input: sourceTranslation.Ast,
cond: predicateTranslation.Ast,
@as: predicateLambda.Parameters[0].Name);
ast = AstExpression.Size(filteredSourceAst);
var predicateParameter = predicateLambda.Parameters[0];
var predicateParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
var predicateSymbol = context.CreateSymbol(predicateParameter, predicateParameterSerializer);
var predicateContext = context.WithSymbol(predicateSymbol);
var predicateTranslation = ExpressionToAggregationExpressionTranslator.Translate(predicateContext, predicateLambda.Body);

ast = AstExpression.Sum(
AstExpression.Map(
input: sourceTranslation.Ast,
@as: predicateSymbol.Var,
@in: AstExpression.Cond(predicateTranslation.Ast, 1, 0)));
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -686,8 +686,8 @@ public void IGrouping_Count_with_predicate_of_root_should_work()
var stages = Translate(collection, queryable);
var expectedStages = new[]
{
"{ $group : { _id : '$_id', _elements : { $push : '$$ROOT' } } }", // MQL could be optimized further
"{ $project : { _id : '$_id', Result : { $size : { $filter : { input : '$_elements', as : 'e', cond : { $eq : ['$$e.X', 1] } } } } } }",
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $eq : ['$X', 1] }, then : 1, else : 0 } } } } }",
"{ $project : { _id : '$_id', Result : '$__agg0' } }",
"{ $sort : { _id : 1 } }"
};
AssertStages(stages, expectedStages);
Expand All @@ -711,8 +711,8 @@ public void IGrouping_Count_with_predicate_of_scalar_should_work()
var stages = Translate(collection, queryable);
var expectedStages = new[]
{
"{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further
"{ $project : { _id : '$_id', Result : { $size : { $filter : { input : '$_elements', as : 'e', cond : { $eq : ['$$e', 1] } } } } } }",
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $eq : ['$X', 1] }, then : 1, else : 0 } } } } }",
"{ $project : { _id : '$_id', Result : '$__agg0' } }",
"{ $sort : { _id : 1 } }"
};
AssertStages(stages, expectedStages);
Expand Down Expand Up @@ -1376,8 +1376,8 @@ public void IGrouping_LongCount_with_predicate_of_root_should_work()
var stages = Translate(collection, queryable);
var expectedStages = new[]
{
"{ $group : { _id : '$_id', _elements : { $push : '$$ROOT' } } }", // MQL could be optimized further
"{ $project : { _id : '$_id', Result : { $size : { $filter : { input : '$_elements', as : 'e', cond : { $eq : ['$$e.X', 1] } } } } } }",
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $eq : ['$X', 1] }, then : 1, else : 0 } } } } }",
"{ $project : { _id : '$_id', Result : '$__agg0' } }",
"{ $sort : { _id : 1 } }"
};
AssertStages(stages, expectedStages);
Expand All @@ -1401,8 +1401,8 @@ public void IGrouping_LongCount_with_predicate_of_scalar_should_work()
var stages = Translate(collection, queryable);
var expectedStages = new[]
{
"{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further
"{ $project : { _id : '$_id', Result : { $size : { $filter : { input : '$_elements', as : 'e', cond : { $eq : ['$$e', 1] } } } } } }",
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $eq : ['$X', 1] }, then : 1, else : 0 } } } } }",
"{ $project : { _id : '$_id', Result : '$__agg0' } }",
"{ $sort : { _id : 1 } }"
};
AssertStages(stages, expectedStages);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ public void GroupBy_with_bool_should_work()
var stages = Translate(collection, queryable);
var expectedStages = new[]
{
"{ $group : { _id : '$_id', _elements : { $push : '$$ROOT' } } }",
"{ $project : { Value : { $size : { $filter : { input : '$_elements', as : 'x', cond : '$$x.Bool' } } }, _id : 0 } }"
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : '$Bool', then : 1, else : 0 } } } } }",
"{ $project : { Value : '$__agg0', _id : 0 } }"
};
AssertStages(stages, expectedStages);
}
Expand All @@ -52,8 +52,8 @@ public void GroupBy_with_nullable_bool_should_work()
var stages = Translate(collection, queryable);
var expectedStages = new[]
{
"{ $group : { _id : '$_id', _elements : { $push : '$$ROOT' } } }",
"{ $project : { Value : { $size : { $filter : { input : '$_elements', as : 'x', cond : { $and : [{ $ne : ['$$x.NullableBool', null] }, '$$x.NullableBool'] } } } }, _id : 0 } }"
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $and : [{ $ne : ['$NullableBool', null] }, '$NullableBool'] }, then : 1, else : 0 } } } } }",
"{ $project : { Value : '$__agg0', _id : 0 } }"
};
AssertStages(stages, expectedStages);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq;
using FluentAssertions;
using MongoDB.Bson;
using MongoDB.Bson.Serialization.Attributes;
using MongoDB.Driver.Linq;
using MongoDB.TestHelpers.XunitExtensions;
using Xunit;

namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
{
public class CSharp4744Tests : Linq3IntegrationTest
{
[Theory]
[ParameterAttributeData]
public void ReplaceOne(
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
{
var collection = GetCollection(linqProvider);

var queryable = collection.AsQueryable()
.GroupBy(x => x.FooName, (x, y) => new Summary()
{
FooName = x,
Count = y.Count(x => x.State == State.Running)
});

var stages = Translate(collection, queryable);
if (linqProvider == LinqProvider.V2)
{
AssertStages(
stages,
"{ $group: { _id : '$FooName', Count : { $sum : { $cond : [{ $eq : ['$State', 1] }, 1, 0] } } } }"); // note: 1 instead of "Running" is an error
}
else
{
AssertStages(
stages,
"{ $group: { _id : '$FooName', __agg0 : { $sum : { $cond : { if : { $eq : ['$State', 'Running'] }, then : 1, else : 0 } } } } }",
"{ $project : { FooName : '$_id', Count : '$__agg0', _id : 0 } }");
}
}

private IMongoCollection<Foo> GetCollection(LinqProvider linqProvider)
{
var collection = GetCollection<Foo>("test", linqProvider);
CreateCollection(collection);
return collection;
}

public enum State
{
Started,
Running,
Complete
}

public class Foo
{
public string FooName;
[BsonRepresentation(BsonType.String)]
public State State;
}

public class Summary
{
public string FooName;
public int Count;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ public void Should_translate_count_with_a_predicate()

AssertStages(
result.Stages,
"{ $group : { _id : '$A', _elements : { $push : '$$ROOT' } } }",
"{ $project : { Result : { $size : { $filter : { input : '$_elements', as : 'x', cond : { $ne : ['$$x.A', 'Awesome' ] } } } }, _id : 0 } }");
"{ $group : { _id : '$A', __agg0 : { $sum : { $cond : { if : { $ne : ['$A', 'Awesome'] }, then : 1, else : 0 } } } } }",
"{ $project : { Result : '$__agg0', _id : 0 } }");

result.Value.Result.Should().Be(1);
}
Expand All @@ -182,12 +182,12 @@ public void Should_translate_where_with_a_predicate_and_count()
[Fact]
public void Should_translate_where_select_and_count_with_predicates()
{
var result = Group(x => x.A, g => new { Result = g.Select(x => new { A = x.A }).Count(x => x.A != "Awesome") });
var result = Group(x => x.A, g => new { Result = g.Select(x => new { B = x.A }).Count(x => x.B != "Awesome") });

AssertStages(
result.Stages,
"{ $group : { _id : '$A', __agg0 : { $push : { A : '$A' } } } }",
"{ $project : { Result : { $size : { $filter : { input : '$__agg0', as : 'x', cond : { $ne : ['$$x.A', 'Awesome'] } } } }, _id : 0 } }");
"{ $group : { _id : '$A', __agg0 : { $push : { B : '$A' } } } }",
"{ $project : { Result : { $sum : { $map : { input : '$__agg0', as : 'x', in : { $cond : { if : { $ne : ['$$x.B', 'Awesome'] }, then : 1, else : 0 } } } } }, _id : 0 } }");

result.Value.Result.Should().Be(1);
}
Expand Down

0 comments on commit 208438f

Please sign in to comment.