Skip to content

Commit

Permalink
CSHARP-4557: Add support for ContainsKey in LINQ3.
Browse files Browse the repository at this point in the history
  • Loading branch information
rstam authored and DmitryLukyanov committed Mar 23, 2023
1 parent 4643231 commit 4284dde
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 0 deletions.
Expand Up @@ -847,6 +847,11 @@ public static AstExpression Trunc(AstExpression arg)
return new AstUnaryExpression(AstUnaryOperator.Trunc, arg);
}

public static AstExpression Type(AstExpression arg)
{
return new AstUnaryExpression(AstUnaryOperator.Type, arg);
}

public static AstExpression Unary(AstUnaryOperator @operator, AstExpression arg)
{
return new AstUnaryExpression(@operator, arg);
Expand Down
Expand Up @@ -100,6 +100,11 @@ public static AstFieldOperationFilter Eq(AstFilterField field, BsonValue value)
return new AstFieldOperationFilter(field, new AstComparisonFilterOperation(AstComparisonFilterOperator.Eq, value));
}

public static AstFieldOperationFilter Exists(AstFilterField field)
{
return new AstFieldOperationFilter(field, new AstExistsFilterOperation(exists: true));
}

public static AstFilter Expr(AstExpression expression)
{
return new AstExprFilter(expression);
Expand Down
Expand Up @@ -35,6 +35,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
case "CompareTo": return CompareToMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Concat": return ConcatMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Contains": return ContainsMethodToAggregationExpressionTranslator.Translate(context, expression);
case "ContainsKey": return ContainsKeyMethodToAggregationExpressionTranslator.Translate(context, expression);
case "CovariancePopulation": return CovariancePopulationMethodToAggregationExpressionTranslator.Translate(context, expression);
case "CovarianceSample": return CovarianceSampleMethodToAggregationExpressionTranslator.Translate(context, expression);
case "DefaultIfEmpty": return DefaultIfEmptyMethodToAggregationExpressionTranslator.Translate(context, expression);
Expand Down
@@ -0,0 +1,86 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq.Expressions;
using System.Reflection;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Options;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
internal static class ContainsKeyMethodToAggregationExpressionTranslator
{
// public methods
public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
{
var method = expression.Method;
var arguments = expression.Arguments;

if (IsContainsKeyMethod(method))
{
var dictionaryExpression = expression.Object;
var keyExpression = arguments[0];

var dictionaryTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, dictionaryExpression);
var dictionarySerializer = GetDictionarySerializer(expression, dictionaryTranslation);
var dictionaryRepresentation = dictionarySerializer.DictionaryRepresentation;

var keyTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, keyExpression);

AstExpression ast;
switch (dictionaryRepresentation)
{
case DictionaryRepresentation.Document:
if (keyExpression.Type != typeof(string))
{
throw new ExpressionNotSupportedException(expression, because: "ContainsKey requires key to be of type string when DictionaryRepresentation is: Document");
}
ast = AstExpression.Ne(AstExpression.Type(AstExpression.GetField(dictionaryTranslation.Ast, keyTranslation.Ast)), "missing");
break;

default:
throw new ExpressionNotSupportedException(expression, because: $"ContainsKey is not supported when DictionaryRepresentation is: {dictionaryRepresentation}");
}

return new AggregationExpression(expression, ast, BooleanSerializer.Instance);
}

throw new ExpressionNotSupportedException(expression);
}

private static IBsonDictionarySerializer GetDictionarySerializer(Expression expression, AggregationExpression dictionaryTranslation)
{
if (dictionaryTranslation.Serializer is IBsonDictionarySerializer dictionarySerializer)
{
return dictionarySerializer;
}

throw new ExpressionNotSupportedException(expression, because: $"class {dictionaryTranslation.Serializer.GetType().FullName} does not implement the IBsonDictionarySerializer interface");
}

private static bool IsContainsKeyMethod(MethodInfo method)
{
return
!method.IsStatic &&
method.IsPublic &&
method.ReturnType == typeof(bool) &&
method.Name == "ContainsKey" &&
method.GetParameters() is var parameters &&
parameters.Length == 1;
}
}
}
@@ -0,0 +1,88 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq.Expressions;
using System.Reflection;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Options;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ToFilterFieldTranslators;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.MethodTranslators
{
internal static class ContainsKeyMethodToFilterTranslator
{
public static AstFilter Translate(TranslationContext context, MethodCallExpression expression)
{
var method = expression.Method;
var arguments = expression.Arguments;

if (IsContainsKeyMethod(method))
{
var dictionaryExpression = expression.Object;
var keyExpression = arguments[0];

var dictionaryField = ExpressionToFilterFieldTranslator.Translate(context, dictionaryExpression);
var dictionarySerializer = GetDictionarySerializer(expression, dictionaryField);
var valueSerializer = dictionarySerializer.ValueSerializer;
var dictionaryRepresentation = dictionarySerializer.DictionaryRepresentation;

switch (dictionaryRepresentation)
{
case DictionaryRepresentation.Document:
var key = GetKeyStringConstant(expression, keyExpression);
var keyField = dictionaryField.SubField(key, valueSerializer);
return AstFilter.Exists(keyField);

default:
throw new ExpressionNotSupportedException(expression, because: $"ContainsKey is not supported when DictionaryRepresentation is: {dictionaryRepresentation}");
}
}

throw new ExpressionNotSupportedException(expression);
}

private static IBsonDictionarySerializer GetDictionarySerializer(Expression expression, AstFilterField field)
{
if (field.Serializer is IBsonDictionarySerializer dictionarySerializer)
{
return dictionarySerializer;
}

throw new ExpressionNotSupportedException(expression, because: $"class {field.Serializer.GetType().FullName} does not implement the IBsonDictionarySerializer interface");
}

private static string GetKeyStringConstant(Expression expression, Expression keyExpression)
{
if (keyExpression is ConstantExpression keyConstantExpression && keyExpression.Type == typeof(string))
{
return (string)keyConstantExpression.Value;
}

throw new ExpressionNotSupportedException(expression, because: "key must be a string constant");
}

private static bool IsContainsKeyMethod(MethodInfo method)
{
return
!method.IsStatic &&
method.IsPublic &&
method.ReturnType == typeof(bool) &&
method.Name == "ContainsKey" &&
method.GetParameters() is var parameters &&
parameters.Length == 1;
}
}
}
Expand Up @@ -25,6 +25,7 @@ public static AstFilter Translate(TranslationContext context, MethodCallExpressi
switch (expression.Method.Name)
{
case "Contains": return ContainsMethodToFilterTranslator.Translate(context, expression);
case "ContainsKey": return ContainsKeyMethodToFilterTranslator.Translate(context, expression);
case "EndsWith": return EndsWithMethodToFilterTranslator.Translate(context, expression);
case "Equals": return EqualsMethodToFilterTranslator.Translate(context, expression);
case "HasFlag": return HasFlagMethodToFilterTranslator.Translate(context, expression);
Expand Down
@@ -0,0 +1,90 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.Linq;
using FluentAssertions;
using MongoDB.Driver.Linq;
using MongoDB.TestHelpers.XunitExtensions;
using Xunit;

namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira
{
public class CSharp4557Tests : Linq3IntegrationTest
{
[Theory]
[ParameterAttributeData]
public void Where_with_ContainsKey_should_work(
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
{
var collection = CreateCollection(linqProvider);

var queryable = collection
.AsQueryable()
.Where(x => x.Foo.ContainsKey("bar"));

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $match : { 'Foo.bar' : { $exists : true } } }");

var results = queryable.ToList();
results.Select(x => x.Id).Should().Equal(2);
}

[Theory]
[ParameterAttributeData]
public void Select_with_ContainsKey_should_work(
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
{
var collection = CreateCollection(linqProvider);

var queryable = collection
.AsQueryable()
.Select(x => x.Foo.ContainsKey("bar"));

if (linqProvider == LinqProvider.V2)
{
var exception = Record.Exception(() => Translate(collection, queryable));
exception.Should().BeOfType<NotSupportedException>();
}
else
{
var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _v : { $ne : [{ $type : '$Foo.bar' }, 'missing'] }, _id : 0 } }");

var results = queryable.ToList();
results.Should().Equal(false, true);
}
}

private IMongoCollection<C> CreateCollection(LinqProvider linqProvider)
{
var collection = GetCollection<C>("C", linqProvider);

CreateCollection(
collection,
new C { Id = 1, Foo = new Dictionary<string, int> { { "foo", 100 } } },
new C { Id = 2, Foo = new Dictionary<string, int> { { "bar", 100 } } });

return collection;
}

private class C
{
public int Id { get; set; }
public Dictionary<string, int> Foo { get; set; }
}
}
}

0 comments on commit 4284dde

Please sign in to comment.