From 1c99be2424c5b9801f597eab2217898c60189a5d Mon Sep 17 00:00:00 2001 From: rstam Date: Sat, 15 Nov 2025 11:59:28 -0500 Subject: [PATCH 1/2] CSHARP-5632: Consolidate driver project Type extension methods in Misc\TypeExtensions.cs --- .../FieldValueSerializerHelper.cs | 32 ++- .../Misc/TypeExtensions.cs | 126 ++++++++++-- .../Linq3Implementation/MongoQueryProvider.cs | 8 +- ...essionToAggregationExpressionTranslator.cs | 2 - ...MethodToAggregationExpressionTranslator.cs | 1 - ...MethodToAggregationExpressionTranslator.cs | 1 - ...MethodToAggregationExpressionTranslator.cs | 1 - .../Support/ReflectionExtensions.cs | 191 ------------------ 8 files changed, 129 insertions(+), 233 deletions(-) delete mode 100644 src/MongoDB.Driver/Support/ReflectionExtensions.cs diff --git a/src/MongoDB.Driver/FieldValueSerializerHelper.cs b/src/MongoDB.Driver/FieldValueSerializerHelper.cs index 68880f7fe18..94cb14f600c 100644 --- a/src/MongoDB.Driver/FieldValueSerializerHelper.cs +++ b/src/MongoDB.Driver/FieldValueSerializerHelper.cs @@ -20,7 +20,7 @@ using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; -using MongoDB.Driver.Support; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; namespace MongoDB.Driver { @@ -63,7 +63,7 @@ public static IBsonSerializer GetSerializerForValueType(IBsonSerializer fieldSer var fieldSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(fieldType); // synthesize a NullableSerializer using the field serializer - if (valueType.IsNullable() && valueType.GetNullableUnderlyingType() == fieldType) + if (valueType.IsNullable(out var nonNullableValueType) && nonNullableValueType == fieldType) { var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(fieldType); var nullableSerializerConstructor = nullableSerializerType.GetTypeInfo().GetConstructor(new[] { fieldSerializerInterfaceType }); @@ -80,24 +80,21 @@ public static IBsonSerializer GetSerializerForValueType(IBsonSerializer fieldSer return (IBsonSerializer)enumConvertingSerializerConstructor.Invoke(new object[] { fieldSerializer }); } - if (valueType.IsNullable() && valueType.GetNullableUnderlyingType().IsConvertibleToEnum()) + if (valueType.IsNullable(out nonNullableValueType) && nonNullableValueType.IsConvertibleToEnum()) { - var underlyingValueType = valueType.GetNullableUnderlyingType(); - var underlyingValueSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(underlyingValueType); - var enumConvertingSerializerType = typeof(EnumConvertingSerializer<,>).MakeGenericType(underlyingValueType, fieldType); + var nonNullableValueSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(nonNullableValueType); + var enumConvertingSerializerType = typeof(EnumConvertingSerializer<,>).MakeGenericType(nonNullableValueType, fieldType); var enumConvertingSerializerConstructor = enumConvertingSerializerType.GetTypeInfo().GetConstructor(new[] { fieldSerializerInterfaceType }); var enumConvertingSerializer = enumConvertingSerializerConstructor.Invoke(new object[] { fieldSerializer }); - var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(underlyingValueType); - var nullableSerializerConstructor = nullableSerializerType.GetTypeInfo().GetConstructor(new[] { underlyingValueSerializerInterfaceType }); + var nullableSerializerType = typeof(NullableSerializer<>).MakeGenericType(nonNullableValueType); + var nullableSerializerConstructor = nullableSerializerType.GetTypeInfo().GetConstructor(new[] { nonNullableValueSerializerInterfaceType }); return (IBsonSerializer)nullableSerializerConstructor.Invoke(new object[] { enumConvertingSerializer }); } } // synthesize a NullableEnumConvertingSerializer using the field serializer - if (fieldType.IsNullableEnum() && valueType.IsNullable()) + if (fieldType.IsNullableEnum(out var nonNullableFieldType) && valueType.IsNullable(out nonNullableValueType)) { - var nonNullableFieldType = fieldType.GetNullableUnderlyingType(); - var nonNullableValueType = valueType.GetNullableUnderlyingType(); var nonNullableFieldSerializer = ((IChildSerializerConfigurable)fieldSerializer).ChildSerializer; var nonNullableFieldSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(nonNullableFieldType); var nullableEnumConvertingSerializerType = typeof(NullableEnumConvertingSerializer<,>).MakeGenericType(nonNullableValueType, nonNullableFieldType); @@ -106,18 +103,15 @@ public static IBsonSerializer GetSerializerForValueType(IBsonSerializer fieldSer } // synthesize an IEnumerableSerializer serializer using the item serializer from the field serializer - Type fieldIEnumerableInterfaceType; - Type valueIEnumerableInterfaceType; - Type itemType; if ( - (fieldIEnumerableInterfaceType = fieldType.FindIEnumerable()) != null && - (valueIEnumerableInterfaceType = valueType.FindIEnumerable()) != null && - (itemType = fieldIEnumerableInterfaceType.GetSequenceElementType()) == valueIEnumerableInterfaceType.GetSequenceElementType() && + fieldType.ImplementsIEnumerable(out var fieldItemType) && + valueType.ImplementsIEnumerable(out var valueItemType) && + fieldItemType == valueItemType && fieldSerializer is IChildSerializerConfigurable) { var itemSerializer = ((IChildSerializerConfigurable)fieldSerializer).ChildSerializer; - var itemSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(itemType); - var ienumerableSerializerType = typeof(IEnumerableSerializer<>).MakeGenericType(itemType); + var itemSerializerInterfaceType = typeof(IBsonSerializer<>).MakeGenericType(fieldItemType); + var ienumerableSerializerType = typeof(IEnumerableSerializer<>).MakeGenericType(fieldItemType); var ienumerableSerializerConstructor = ienumerableSerializerType.GetTypeInfo().GetConstructor(new[] { itemSerializerInterfaceType }); return (IBsonSerializer)ienumerableSerializerConstructor.Invoke(new object[] { itemSerializer }); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs index f6d2f758940..7b3f5eb4ebb 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs @@ -16,7 +16,9 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Runtime.CompilerServices; +using MongoDB.Bson; namespace MongoDB.Driver.Linq.Linq3Implementation.Misc { @@ -52,6 +54,14 @@ internal static class TypeExtensions typeof(ValueTuple<,,,,,,,>) }; + public static object GetDefaultValue(this Type type) + { + var genericMethod = typeof(TypeExtensions) + .GetMethod(nameof(GetDefaultValueGeneric), BindingFlags.NonPublic | BindingFlags.Static) + .MakeGenericMethod(type); + return genericMethod.Invoke(null, null); + } + public static Type GetIEnumerableGenericInterface(this Type enumerableType) { if (enumerableType.TryGetIEnumerableGenericInterface(out var ienumerableGenericInterface)) @@ -136,6 +146,18 @@ public static bool ImplementsIList(this Type type, out Type itemType) return false; } + public static bool ImplementsIQueryable(this Type type, out Type itemType) + { + if (TryGetIQueryableGenericInterface(type, out var iqueryableType)) + { + itemType = iqueryableType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + public static bool Is(this Type type, Type comparand) { if (type == comparand) @@ -175,6 +197,28 @@ public static bool IsArray(this Type type, out Type itemType) return false; } + public static bool IsBooleanOrNullableBoolean(this Type type) + { + return + type == typeof(bool) || + type.IsNullable(out var valueType) && valueType == typeof(bool); + } + + public static bool IsConvertibleToEnum(this Type type) + { + return + type == typeof(sbyte) || + type == typeof(short) || + type == typeof(int) || + type == typeof(long) || + type == typeof(byte) || + type == typeof(ushort) || + type == typeof(uint) || + type == typeof(ulong) || + type == typeof(Enum) || + type == typeof(string); + } + public static bool IsEnum(this Type type, out Type underlyingType) { if (type.IsEnum) @@ -189,27 +233,15 @@ public static bool IsEnum(this Type type, out Type underlyingType) } } - public static bool IsEnum(this Type type, out Type enumType, out Type underlyingType) + public static bool IsEnumOrNullableEnum(this Type type, out Type enumType, out Type underlyingType) { - if (type.IsEnum) + if (type.IsEnum(out underlyingType)) { enumType = type; - underlyingType = Enum.GetUnderlyingType(type); return true; } - else - { - enumType = null; - underlyingType = null; - return false; - } - } - public static bool IsEnumOrNullableEnum(this Type type, out Type enumType, out Type underlyingType) - { - return - type.IsEnum(out enumType, out underlyingType) || - type.IsNullableEnum(out enumType, out underlyingType); + return IsNullableEnum(type, out enumType, out underlyingType); } public static bool IsNullable(this Type type) @@ -236,11 +268,29 @@ public static bool IsNullableEnum(this Type type) return type.IsNullable(out var valueType) && valueType.IsEnum; } + public static bool IsNullableEnum(this Type type, out Type enumType) + { + if (type.IsNullable(out var valueType) && valueType.IsEnum) + { + enumType = valueType; + return true; + } + + enumType = null; + return false; + } + public static bool IsNullableEnum(this Type type, out Type enumType, out Type underlyingType) { + if (type.IsNullable(out var valueType) && valueType.IsEnum(out underlyingType)) + { + enumType = valueType; + return true; + } + enumType = null; underlyingType = null; - return type.IsNullable(out var valueType) && valueType.IsEnum(out enumType, out underlyingType); + return false; } public static bool IsNullableOf(this Type type, Type valueType) @@ -256,6 +306,24 @@ public static bool IsReadOnlySpanOf(this Type type, Type itemType) type.GetGenericArguments()[0] == itemType; } + public static bool IsNumeric(this Type type) + { + return + type == typeof(int) || + type == typeof(long) || + type == typeof(double) || + type == typeof(float) || + type == typeof(decimal) || + type == typeof(Decimal128); + } + + public static bool IsNumericOrNullableNumeric(this Type type) + { + return + type.IsNumeric() || + type.IsNullable(out var valueType) && valueType.IsNumeric(); + } + public static bool IsSameAsOrNullableOf(this Type type, Type valueType) { return type == valueType || type.IsNullableOf(valueType); @@ -348,5 +416,31 @@ public static bool TryGetIListGenericInterface(this Type type, out Type ilistGen ilistGenericInterface = null; return false; } + + public static bool TryGetIQueryableGenericInterface(this Type type, out Type iqueryableGenericInterface) + { + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IQueryable<>)) + { + iqueryableGenericInterface = type; + return true; + } + + foreach (var interfaceType in type.GetInterfaces()) + { + if (interfaceType.IsGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IQueryable<>)) + { + iqueryableGenericInterface = interfaceType; + return true; + } + } + + iqueryableGenericInterface = null; + return false; + } + + private static TValue GetDefaultValueGeneric() + { + return default(TValue); + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQueryProvider.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQueryProvider.cs index 2717a7e71d7..e70a90a116e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQueryProvider.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQueryProvider.cs @@ -21,8 +21,8 @@ using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecutableQueryTranslators; -using MongoDB.Driver.Support; namespace MongoDB.Driver.Linq.Linq3Implementation { @@ -105,7 +105,11 @@ internal MongoQueryProvider( // public methods public override IQueryable CreateQuery(Expression expression) { - var outputType = expression.Type.GetSequenceElementType(); + if (!expression.Type.ImplementsIQueryable(out var outputType)) + { + throw new ExpressionNotSupportedException(expression, because: "expression type does not implement IQueryable"); + } + var queryType = typeof(MongoQuery<,>).MakeGenericType(typeof(TDocument), outputType); return (IQueryable)Activator.CreateInstance(queryType, new object[] { this, expression }); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs index 3881a1135a3..0b5e37352f0 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs @@ -18,11 +18,9 @@ using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; -using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators; -using MongoDB.Driver.Support; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs index 3d66a096ad8..ce74b191fa0 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs @@ -20,7 +20,6 @@ using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; using MongoDB.Driver.Linq.Linq3Implementation.Serializers; -using MongoDB.Driver.Support; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs index 17cb037c144..e6e9cf24e1d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs @@ -18,7 +18,6 @@ using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; -using MongoDB.Driver.Support; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs index 932d4c6856b..c95fe42865b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs @@ -19,7 +19,6 @@ using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; -using MongoDB.Driver.Support; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { diff --git a/src/MongoDB.Driver/Support/ReflectionExtensions.cs b/src/MongoDB.Driver/Support/ReflectionExtensions.cs deleted file mode 100644 index 7bae4fa777e..00000000000 --- a/src/MongoDB.Driver/Support/ReflectionExtensions.cs +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright 2015-present MongoDB Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using MongoDB.Bson; - -namespace MongoDB.Driver.Support -{ - internal static class ReflectionExtensions - { - public static object GetDefaultValue(this Type type) - { - var genericMethod = typeof(ReflectionExtensions) - .GetMethod(nameof(GetDefaultValueGeneric), BindingFlags.NonPublic | BindingFlags.Static) - .MakeGenericMethod(type); - return genericMethod.Invoke(null, null); - } - - public static bool ImplementsInterface(this Type type, Type iface) - { - if (type.Equals(iface)) - { - return true; - } - - var typeInfo = type.GetTypeInfo(); - if (typeInfo.IsGenericType && type.GetGenericTypeDefinition().Equals(iface)) - { - return true; - } - - return typeInfo.GetInterfaces().Any(i => i.ImplementsInterface(iface)); - } - - public static bool IsBooleanOrNullableBoolean(this Type type) - { - if (type.IsConstructedGenericType && - type.GetGenericTypeDefinition() is var genericTypeDefinition && - genericTypeDefinition == typeof(Nullable<>)) - { - var valueType = type.GetGenericArguments()[0]; - return valueType == typeof(bool); - } - else - { - return type == typeof(bool); - } - } - - public static bool IsNullable(this Type type) - { - return type.GetTypeInfo().IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); - } - - public static bool IsNullableEnum(this Type type) - { - if (!IsNullable(type)) - { - return false; - } - - return GetNullableUnderlyingType(type).GetTypeInfo().IsEnum; - } - - public static bool IsNumeric(this Type type) - { - return - type == typeof(int) || - type == typeof(long) || - type == typeof(double) || - type == typeof(float) || - type == typeof(decimal) || - type == typeof(Decimal128); - } - - public static bool IsNumericOrNullableNumeric(this Type type) - { - if (type.IsConstructedGenericType && - type.GetGenericTypeDefinition() is var genericTypeDefinition && - genericTypeDefinition == typeof(Nullable<>)) - { - var valueType = type.GetGenericArguments()[0]; - return IsNumeric(valueType); - } - else - { - return IsNumeric(type); - } - } - - public static bool IsConvertibleToEnum(this Type type) - { - return - type == typeof(sbyte) || - type == typeof(short) || - type == typeof(int) || - type == typeof(long) || - type == typeof(byte) || - type == typeof(ushort) || - type == typeof(uint) || - type == typeof(ulong) || - type == typeof(Enum) || - type == typeof(string); - } - - public static Type GetNullableUnderlyingType(this Type type) - { - if (!IsNullable(type)) - { - throw new ArgumentException("Type must be nullable.", "type"); - } - - return type.GetTypeInfo().GetGenericArguments()[0]; - } - - public static Type GetSequenceElementType(this Type type) - { - Type ienum = FindIEnumerable(type); - if (ienum == null) { return type; } - return ienum.GetTypeInfo().GetGenericArguments()[0]; - } - - public static Type FindIEnumerable(this Type seqType) - { - if (seqType == null || seqType == typeof(string)) - { - return null; - } - - var seqTypeInfo = seqType.GetTypeInfo(); - if (seqTypeInfo.IsGenericType && seqTypeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>)) - { - return seqType; - } - - if (seqTypeInfo.IsArray) - { - return typeof(IEnumerable<>).MakeGenericType(seqType.GetElementType()); - } - - if (seqTypeInfo.IsGenericType) - { - foreach (Type arg in seqTypeInfo.GetGenericArguments()) - { - Type ienum = typeof(IEnumerable<>).MakeGenericType(arg); - if (ienum.GetTypeInfo().IsAssignableFrom(seqType)) - { - return ienum; - } - } - } - - Type[] ifaces = seqTypeInfo.GetInterfaces(); - if (ifaces != null && ifaces.Length > 0) - { - foreach (Type iface in ifaces) - { - Type ienum = FindIEnumerable(iface); - if (ienum != null) { return ienum; } - } - } - - if (seqTypeInfo.BaseType != null && seqTypeInfo.BaseType != typeof(object)) - { - return FindIEnumerable(seqTypeInfo.BaseType); - } - - return null; - } - - private static TValue GetDefaultValueGeneric() - { - return default(TValue); - } - } -} From 90e19ff402244c62015854d14d9a4661bfb04475 Mon Sep 17 00:00:00 2001 From: rstam Date: Mon, 17 Nov 2025 15:43:46 -0800 Subject: [PATCH 2/2] CSHARP-5632: Requested changes --- .../Misc/TypeExtensions.cs | 93 +++++-------------- 1 file changed, 22 insertions(+), 71 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs index 7b3f5eb4ebb..636c616deb3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs @@ -24,7 +24,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Misc { internal static class TypeExtensions { - private static readonly Type[] __dictionaryInterfaces = + private static readonly Type[] __dictionaryInterfaceDefinitions = { typeof(IDictionary<,>), typeof(IReadOnlyDictionary<,>) @@ -102,7 +102,7 @@ public static bool Implements(this Type type, Type @interface) public static bool ImplementsDictionaryInterface(this Type type, out Type keyType, out Type valueType) { - if (TryGetGenericInterface(type, __dictionaryInterfaces, out var dictionaryInterface)) + if (TryGetGenericInterface(type, __dictionaryInterfaceDefinitions, out var dictionaryInterface)) { var genericArguments = dictionaryInterface.GetGenericArguments(); keyType = genericArguments[0]; @@ -226,11 +226,9 @@ public static bool IsEnum(this Type type, out Type underlyingType) underlyingType = Enum.GetUnderlyingType(type); return true; } - else - { - underlyingType = null; - return false; - } + + underlyingType = null; + return false; } public static bool IsEnumOrNullableEnum(this Type type, out Type enumType, out Type underlyingType) @@ -256,11 +254,9 @@ public static bool IsNullable(this Type type, out Type valueType) valueType = type.GetGenericArguments()[0]; return true; } - else - { - valueType = null; - return false; - } + + valueType = null; + return false; } public static bool IsNullableEnum(this Type type) @@ -366,77 +362,32 @@ public static bool IsValueTuple(this Type type) __valueTupleTypeDefinitions.Contains(typeDefinition); } - public static bool TryGetGenericInterface(this Type type, Type[] interfaceDefinitions, out Type genericInterface) + public static bool TryGetGenericInterface(this Type type, Type genericInterfaceDefintion, out Type genericInterface) { genericInterface = - type.IsConstructedGenericType && interfaceDefinitions.Contains(type.GetGenericTypeDefinition()) ? + type.IsConstructedGenericType && type.GetGenericTypeDefinition() == genericInterfaceDefintion ? type : - type.GetInterfaces().FirstOrDefault(i => i.IsConstructedGenericType && interfaceDefinitions.Contains(i.GetGenericTypeDefinition())); + type.GetInterfaces().FirstOrDefault(i => i.IsConstructedGenericType && i.GetGenericTypeDefinition() == genericInterfaceDefintion); return genericInterface != null; } - public static bool TryGetIEnumerableGenericInterface(this Type type, out Type ienumerableGenericInterface) + public static bool TryGetGenericInterface(this Type type, Type[] genericInterfaceDefinitions, out Type genericInterface) { - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>)) - { - ienumerableGenericInterface = type; - return true; - } - - foreach (var interfaceType in type.GetInterfaces()) - { - if (interfaceType.IsGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) - { - ienumerableGenericInterface = interfaceType; - return true; - } - } - - ienumerableGenericInterface = null; - return false; + genericInterface = + type.IsConstructedGenericType && genericInterfaceDefinitions.Contains(type.GetGenericTypeDefinition()) ? + type : + type.GetInterfaces().FirstOrDefault(i => i.IsConstructedGenericType && genericInterfaceDefinitions.Contains(i.GetGenericTypeDefinition())); + return genericInterface != null; } - public static bool TryGetIListGenericInterface(this Type type, out Type ilistGenericInterface) - { - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IList<>)) - { - ilistGenericInterface = type; - return true; - } - - foreach (var interfaceType in type.GetInterfaces()) - { - if (interfaceType.IsGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IList<>)) - { - ilistGenericInterface = interfaceType; - return true; - } - } + public static bool TryGetIEnumerableGenericInterface(this Type type, out Type ienumerableGenericInterface) + => TryGetGenericInterface(type, typeof(IEnumerable<>), out ienumerableGenericInterface); - ilistGenericInterface = null; - return false; - } + public static bool TryGetIListGenericInterface(this Type type, out Type ilistGenericInterface) + => TryGetGenericInterface(type, typeof(IList<>), out ilistGenericInterface); public static bool TryGetIQueryableGenericInterface(this Type type, out Type iqueryableGenericInterface) - { - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IQueryable<>)) - { - iqueryableGenericInterface = type; - return true; - } - - foreach (var interfaceType in type.GetInterfaces()) - { - if (interfaceType.IsGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IQueryable<>)) - { - iqueryableGenericInterface = interfaceType; - return true; - } - } - - iqueryableGenericInterface = null; - return false; - } + => TryGetGenericInterface(type, typeof(IQueryable<>), out iqueryableGenericInterface); private static TValue GetDefaultValueGeneric() {