Skip to content

Commit b26668f

Browse files
committed
CSHARP-5481: ScalarDiscriminatorConvention class should implement IScalarDiscriminatorConvention interface.
1 parent 51dbb28 commit b26668f

File tree

16 files changed

+1083
-300
lines changed

16 files changed

+1083
-300
lines changed

src/MongoDB.Bson/Serialization/BsonSerializer.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,37 @@ internal static void EnsureKnownTypesAreRegistered(Type nominalType)
778778
}
779779
}
780780

781+
// internal static methods
782+
internal static BsonValue[] GetDiscriminatorsForTypeAndSubTypes(Type type)
783+
{
784+
// note: EnsureKnownTypesAreRegistered handles its own locking so call from outside any lock
785+
EnsureKnownTypesAreRegistered(type);
786+
787+
var discriminators = new List<BsonValue>();
788+
789+
__configLock.EnterReadLock();
790+
try
791+
{
792+
foreach (var entry in __discriminators)
793+
{
794+
var discriminator = entry.Key;
795+
var actualTypes = entry.Value;
796+
797+
var matchingType = actualTypes.SingleOrDefault(t => t == type || t.IsSubclassOf(type));
798+
if (matchingType != null)
799+
{
800+
discriminators.Add(discriminator);
801+
}
802+
}
803+
}
804+
finally
805+
{
806+
__configLock.ExitReadLock();
807+
}
808+
809+
return discriminators.OrderBy(x => x).ToArray();
810+
}
811+
781812
// private static methods
782813
private static void CreateSerializerRegistry()
783814
{

src/MongoDB.Bson/Serialization/Conventions/ScalarDiscriminatorConvention.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
*/
1515

1616
using System;
17+
using System.Collections.Concurrent;
1718

1819
namespace MongoDB.Bson.Serialization.Conventions
1920
{
2021
/// <summary>
2122
/// Represents a discriminator convention where the discriminator is provided by the class map of the actual type.
2223
/// </summary>
23-
public class ScalarDiscriminatorConvention : StandardDiscriminatorConvention
24+
public class ScalarDiscriminatorConvention : StandardDiscriminatorConvention, IScalarDiscriminatorConvention
2425
{
26+
private readonly ConcurrentDictionary<Type, BsonValue[]> _cachedTypeAndSubTypeDiscriminators = new();
27+
2528
// constructors
2629
/// <summary>
2730
/// Initializes a new instance of the ScalarDiscriminatorConvention class.
@@ -52,5 +55,11 @@ public override BsonValue GetDiscriminator(Type nominalType, Type actualType)
5255
return null;
5356
}
5457
}
58+
59+
/// <inheritdoc/>
60+
public BsonValue[] GetDiscriminatorsForTypeAndSubTypes(Type type)
61+
{
62+
return _cachedTypeAndSubTypeDiscriminators.GetOrAdd(type, BsonSerializer.GetDiscriminatorsForTypeAndSubTypes);
63+
}
5564
}
5665
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
using MongoDB.Bson;
1919
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
2020
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages;
2122
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors;
23+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
2224

2325
namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers
2426
{
@@ -351,6 +353,33 @@ elemMatchOperation.Filter is AstFieldOperationFilter elemFilter &&
351353
}
352354
}
353355

356+
public override AstNode VisitFilterExpression(AstFilterExpression node)
357+
{
358+
var inputExpression = VisitAndConvert(node.Input);
359+
var condExpression = VisitAndConvert(node.Cond);
360+
var limitExpression = VisitAndConvert(node.Limit);
361+
362+
if (condExpression is AstConstantExpression condConstantExpression &&
363+
condConstantExpression.Value is BsonBoolean condBsonBoolean)
364+
{
365+
if (condBsonBoolean.Value)
366+
{
367+
// { $filter : { input : <input>, as : "x", cond : true } } => <input>
368+
if (limitExpression == null)
369+
{
370+
return inputExpression;
371+
}
372+
}
373+
else
374+
{
375+
// { $filter : { input : <input>, as : "x", cond : false, optional-limit } } => []
376+
return AstExpression.Constant(new BsonArray());
377+
}
378+
}
379+
380+
return node.Update(inputExpression, condExpression, limitExpression);
381+
}
382+
354383
public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
355384
{
356385
if (TrySimplifyAsFieldPath(node, out var simplified))
@@ -448,6 +477,26 @@ public override AstNode VisitNotFilterOperation(AstNotFilterOperation node)
448477
return base.VisitNotFilterOperation(node);
449478
}
450479

480+
public override AstNode VisitPipeline(AstPipeline node)
481+
{
482+
var stages = VisitAndConvert(node.Stages);
483+
484+
// { $match : { } } => remove redundant stage
485+
if (stages.Any(stage => IsMatchEverythingStage(stage)))
486+
{
487+
stages = stages.Where(stage => !IsMatchEverythingStage(stage)).AsReadOnlyList();
488+
}
489+
490+
return node.Update(stages);
491+
492+
static bool IsMatchEverythingStage(AstStage stage)
493+
{
494+
return
495+
stage is AstMatchStage matchStage &&
496+
matchStage.Filter is AstMatchesEverythingFilter;
497+
}
498+
}
499+
451500
public override AstNode VisitSliceExpression(AstSliceExpression node)
452501
{
453502
node = (AstSliceExpression)base.VisitSliceExpression(node);
@@ -498,15 +547,34 @@ arrayConstant.Value is BsonArray bsonArrayConstant &&
498547

499548
public override AstNode VisitUnaryExpression(AstUnaryExpression node)
500549
{
550+
var arg = VisitAndConvert(node.Arg);
551+
501552
// { $first : <arg> } => { $arrayElemAt : [<arg>, 0] } (or -1 for $last)
502553
if (node.Operator == AstUnaryOperator.First || node.Operator == AstUnaryOperator.Last)
503554
{
504-
var simplifiedArg = VisitAndConvert(node.Arg);
505555
var index = node.Operator == AstUnaryOperator.First ? 0 : -1;
506-
return AstExpression.ArrayElemAt(simplifiedArg, index);
556+
return AstExpression.ArrayElemAt(arg, index);
557+
}
558+
559+
// { $not : booleanConstant } => !booleanConstant
560+
if (node.Operator is AstUnaryOperator.Not &&
561+
arg is AstConstantExpression argConstantExpression &&
562+
argConstantExpression.Value is BsonBoolean argBsonBoolean)
563+
{
564+
return AstExpression.Constant(!argBsonBoolean.Value);
565+
}
566+
567+
// { $not : { $eq : [expr1, expr2] } } => { $ne : [expr1, expr2] }
568+
// { $not : { $ne : [expr1, expr2] } } => { $eq : [expr1, expr2] }
569+
if (node.Operator is AstUnaryOperator.Not &&
570+
arg is AstBinaryExpression argBinaryExpression &&
571+
argBinaryExpression.Operator is AstBinaryOperator.Eq or AstBinaryOperator.Ne)
572+
{
573+
var oppositeComparisonOperator = argBinaryExpression.Operator == AstBinaryOperator.Eq ? AstBinaryOperator.Ne : AstBinaryOperator.Eq;
574+
return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2);
507575
}
508576

509-
return base.VisitUnaryExpression(node);
577+
return node.Update(arg);
510578
}
511579
}
512580
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using System.Linq.Expressions;
18+
using System.Reflection;
19+
using MongoDB.Bson.Serialization;
20+
using MongoDB.Bson.Serialization.Conventions;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
22+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
23+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
24+
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
25+
26+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
27+
{
28+
internal static class OfTypeMethodToAggregationExpressionTranslator
29+
{
30+
private static MethodInfo[] __ofTypeMethods =
31+
{
32+
EnumerableMethod.OfType,
33+
QueryableMethod.OfType
34+
};
35+
36+
public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression)
37+
{
38+
var method = expression.Method;
39+
var arguments = expression.Arguments;
40+
41+
if (method.IsOneOf(__ofTypeMethods))
42+
{
43+
var sourceExpression = arguments[0];
44+
var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
45+
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
46+
47+
var sourceAst = sourceTranslation.Ast;
48+
var sourceSerializer = sourceTranslation.Serializer;
49+
if (sourceSerializer is IWrappedValueSerializer wrappedValueSerializer)
50+
{
51+
sourceAst = AstExpression.GetField(sourceAst, wrappedValueSerializer.FieldName);
52+
sourceSerializer = wrappedValueSerializer.ValueSerializer;
53+
}
54+
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer);
55+
56+
var nominalType = itemSerializer.ValueType;
57+
var nominalTypeSerializer = itemSerializer;
58+
var actualType = method.GetGenericArguments().Single();
59+
var actualTypeSerializer = BsonSerializer.LookupSerializer(actualType);
60+
61+
AstExpression ast;
62+
if (nominalType == actualType)
63+
{
64+
ast = sourceAst;
65+
}
66+
else
67+
{
68+
var discriminatorConvention = nominalTypeSerializer.GetDiscriminatorConvention();
69+
var itemVar = AstExpression.Var("item");
70+
var discriminatorField = AstExpression.GetField(itemVar, discriminatorConvention.ElementName);
71+
72+
var ofTypeExpression = discriminatorConvention switch
73+
{
74+
IHierarchicalDiscriminatorConvention hierarchicalDiscriminatorConvention => DiscriminatorAstExpression.TypeIs(discriminatorField, hierarchicalDiscriminatorConvention, nominalType, actualType),
75+
IScalarDiscriminatorConvention scalarDiscriminatorConvention => DiscriminatorAstExpression.TypeIs(discriminatorField, scalarDiscriminatorConvention, nominalType, actualType),
76+
_ => throw new ExpressionNotSupportedException(expression, because: "OfType is not supported with the configured discriminator convention")
77+
};
78+
79+
ast = AstExpression.Filter(
80+
input: sourceAst,
81+
cond: ofTypeExpression,
82+
@as: "item");
83+
}
84+
85+
var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, actualTypeSerializer);
86+
return new TranslatedExpression(expression, ast, resultSerializer);
87+
}
88+
89+
throw new ExpressionNotSupportedException(expression);
90+
}
91+
}
92+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,15 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
4141
var sourceExpression = arguments[0];
4242
var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
4343
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
44-
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
44+
45+
var sourceAst = sourceTranslation.Ast;
46+
var sourceSerializer = sourceTranslation.Serializer;
47+
if (sourceSerializer is IWrappedValueSerializer wrappedValueSerializer)
48+
{
49+
sourceAst = AstExpression.GetField(sourceAst, wrappedValueSerializer.FieldName);
50+
sourceSerializer = wrappedValueSerializer.ValueSerializer;
51+
}
52+
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer);
4553

4654
var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]);
4755
var predicateParameter = predicateLambda.Parameters[0];
@@ -57,7 +65,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
5765
}
5866

5967
var ast = AstExpression.Filter(
60-
sourceTranslation.Ast,
68+
sourceAst,
6169
predicateTranslation.Ast,
6270
@as: predicateSymbol.Var.Name,
6371
limitTranslation?.Ast);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/OfTypeMethodToAggregationExpressionTranslator.cs

Lines changed: 0 additions & 80 deletions
This file was deleted.

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/TypeIsExpressionToAggregationExpressionTranslator.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
* limitations under the License.
1414
*/
1515

16-
using System.Linq;
1716
using System.Linq.Expressions;
18-
using MongoDB.Bson;
1917
using MongoDB.Bson.Serialization;
2018
using MongoDB.Bson.Serialization.Conventions;
2119
using MongoDB.Bson.Serialization.Serializers;

0 commit comments

Comments
 (0)