From 096a14f19c8c4467c20409b54eb6d984a7de6c06 Mon Sep 17 00:00:00 2001 From: Kristian Hellang Date: Fri, 18 Aug 2023 14:04:43 +0200 Subject: [PATCH] Leverage keyed services for decoration instead of DecoratedType hack --- src/Scrutor/ClosedTypeDecorationStrategy.cs | 6 +- src/Scrutor/DecoratedType.cs | 119 ------------------ src/Scrutor/DecorationStrategy.cs | 10 +- src/Scrutor/OpenGenericDecorationStrategy.cs | 6 +- .../ServiceCollectionExtensions.Decoration.cs | 39 ++++-- src/Scrutor/ServiceDescriptorExtensions.cs | 48 +++++-- .../ServiceCollectionExtensions.cs | 2 +- 7 files changed, 83 insertions(+), 147 deletions(-) delete mode 100644 src/Scrutor/DecoratedType.cs diff --git a/src/Scrutor/ClosedTypeDecorationStrategy.cs b/src/Scrutor/ClosedTypeDecorationStrategy.cs index c2076ff..06c5768 100644 --- a/src/Scrutor/ClosedTypeDecorationStrategy.cs +++ b/src/Scrutor/ClosedTypeDecorationStrategy.cs @@ -16,16 +16,16 @@ public ClosedTypeDecorationStrategy(Type serviceType, Type? decoratorType, Func< public override bool CanDecorate(Type serviceType) => ServiceType == serviceType; - public override Func CreateDecorator(Type serviceType) + public override Func CreateDecorator(Type serviceType, string serviceKey) { if (DecoratorType is not null) { - return TypeDecorator(serviceType, DecoratorType); + return TypeDecorator(serviceType, serviceKey, DecoratorType); } if (DecoratorFactory is not null) { - return FactoryDecorator(serviceType, DecoratorFactory); + return FactoryDecorator(serviceType, serviceKey, DecoratorFactory); } throw new InvalidOperationException($"Both serviceType and decoratorFactory can not be null."); diff --git a/src/Scrutor/DecoratedType.cs b/src/Scrutor/DecoratedType.cs deleted file mode 100644 index 69ad63c..0000000 --- a/src/Scrutor/DecoratedType.cs +++ /dev/null @@ -1,119 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Globalization; -using System.Reflection; -using System.Runtime.InteropServices; - -namespace Scrutor; - -public class DecoratedType : Type -{ - public DecoratedType(Type type) => ProxiedType = type; - private Type ProxiedType { get; } - - // We use object reference equality here to ensure that only the decorating object can match. - public override bool Equals(Type? o) => ReferenceEquals(this, o); - public override bool Equals(object? o) => ReferenceEquals(this, o); - public override int GetHashCode() => ProxiedType.GetHashCode(); - public override string? Namespace => ProxiedType.Namespace; - public override string? AssemblyQualifiedName => ProxiedType.AssemblyQualifiedName; - public override string? FullName => ProxiedType.FullName; - public override Assembly Assembly => ProxiedType.Assembly; - public override Module Module => ProxiedType.Module; - public override Type? DeclaringType => ProxiedType.DeclaringType; - public override MethodBase? DeclaringMethod => ProxiedType.DeclaringMethod; - public override Type? ReflectedType => ProxiedType.ReflectedType; - public override Type UnderlyingSystemType => ProxiedType.UnderlyingSystemType; - -#if NETCOREAPP3_1_OR_GREATER - public override bool IsTypeDefinition => ProxiedType.IsTypeDefinition; -#endif - protected override bool IsArrayImpl() => ProxiedType.HasElementType; - protected override bool IsByRefImpl() => ProxiedType.IsByRef; - protected override bool IsPointerImpl() => ProxiedType.IsPointer; - public override bool IsConstructedGenericType => ProxiedType.IsConstructedGenericType; - public override bool IsGenericParameter => ProxiedType.IsGenericParameter; -#if NETCOREAPP3_1_OR_GREATER - public override bool IsGenericTypeParameter => ProxiedType.IsGenericTypeParameter; - public override bool IsGenericMethodParameter => ProxiedType.IsGenericMethodParameter; -#endif - public override bool IsGenericType => ProxiedType.IsGenericType; - public override bool IsGenericTypeDefinition => ProxiedType.IsGenericTypeDefinition; -#if NETCOREAPP3_1_OR_GREATER - public override bool IsSZArray => ProxiedType.IsSZArray; - public override bool IsVariableBoundArray => ProxiedType.IsVariableBoundArray; - public override bool IsByRefLike => ProxiedType.IsByRefLike; -#endif - protected override bool HasElementTypeImpl() => ProxiedType.HasElementType; - public override Type? GetElementType() => ProxiedType.GetElementType(); - public override int GetArrayRank() => ProxiedType.GetArrayRank(); - public override Type GetGenericTypeDefinition() => ProxiedType.GetGenericTypeDefinition(); - public override Type[] GetGenericArguments() => ProxiedType.GetGenericArguments(); - public override int GenericParameterPosition => ProxiedType.GenericParameterPosition; - public override GenericParameterAttributes GenericParameterAttributes => ProxiedType.GenericParameterAttributes; - public override Type[] GetGenericParameterConstraints() => ProxiedType.GetGenericParameterConstraints(); - protected override TypeAttributes GetAttributeFlagsImpl() => ProxiedType.Attributes; - protected override bool IsCOMObjectImpl() => ProxiedType.IsCOMObject; - protected override bool IsContextfulImpl() => ProxiedType.IsContextful; - public override bool IsEnum => ProxiedType.IsEnum; - protected override bool IsMarshalByRefImpl() => ProxiedType.IsMarshalByRef; - protected override bool IsPrimitiveImpl() => ProxiedType.IsPrimitive; - protected override bool IsValueTypeImpl() => ProxiedType.IsValueType; -#if NETCOREAPP3_1_OR_GREATER - public override bool IsSignatureType => ProxiedType.IsSignatureType; -#endif - public override bool IsSecurityCritical => ProxiedType.IsSecurityCritical; - public override bool IsSecuritySafeCritical => ProxiedType.IsSecuritySafeCritical; - public override bool IsSecurityTransparent => ProxiedType.IsSecurityTransparent; - public override StructLayoutAttribute? StructLayoutAttribute => ProxiedType.StructLayoutAttribute; - protected override ConstructorInfo? GetConstructorImpl(BindingFlags bindingAttr, Binder? binder, CallingConventions callConvention, Type[] types, ParameterModifier[]? modifiers) - => ProxiedType.GetConstructor(bindingAttr, binder, callConvention, types, modifiers); - public override ConstructorInfo[] GetConstructors(BindingFlags bindingAttr) => ProxiedType.GetConstructors(bindingAttr); - public override EventInfo? GetEvent(string name, BindingFlags bindingAttr) => ProxiedType.GetEvent(name, bindingAttr); - public override EventInfo[] GetEvents() => ProxiedType.GetEvents(); - public override EventInfo[] GetEvents(BindingFlags bindingAttr) => ProxiedType.GetEvents(bindingAttr); - public override FieldInfo? GetField(string name, BindingFlags bindingAttr) => ProxiedType.GetField(name, bindingAttr); - public override FieldInfo[] GetFields(BindingFlags bindingAttr) => ProxiedType.GetFields(bindingAttr); - public override MemberInfo[] GetMember(string name, BindingFlags bindingAttr) => ProxiedType.GetMember(name, bindingAttr); - public override MemberInfo[] GetMember(string name, MemberTypes type, BindingFlags bindingAttr) => ProxiedType.GetMember(name, type, bindingAttr); -#if NET6_0 - public override MemberInfo GetMemberWithSameMetadataDefinitionAs(MemberInfo member) => ProxiedType.GetMemberWithSameMetadataDefinitionAs(member); -#endif - public override MemberInfo[] GetMembers(BindingFlags bindingAttr) => ProxiedType.GetMembers(bindingAttr); - protected override MethodInfo? GetMethodImpl(string name, BindingFlags bindingAttr, Binder? binder, CallingConventions callConvention, Type[]? types, ParameterModifier[]? modifiers) - => ProxiedType.GetMethod(name, bindingAttr, binder, callConvention, types!, modifiers); - public override MethodInfo[] GetMethods(BindingFlags bindingAttr) => ProxiedType.GetMethods(bindingAttr); - public override Type? GetNestedType(string name, BindingFlags bindingAttr) => ProxiedType.GetNestedType(name, bindingAttr); - public override Type[] GetNestedTypes(BindingFlags bindingAttr) => ProxiedType.GetNestedTypes(bindingAttr); - protected override PropertyInfo? GetPropertyImpl(string name, BindingFlags bindingAttr, Binder? binder, Type? returnType, Type[]? types, ParameterModifier[]? modifiers) - => ProxiedType.GetProperty(name, bindingAttr, binder, returnType, types!, modifiers); - public override PropertyInfo[] GetProperties(BindingFlags bindingAttr) => ProxiedType.GetProperties(bindingAttr); - public override MemberInfo[] GetDefaultMembers() => ProxiedType.GetDefaultMembers(); - public override RuntimeTypeHandle TypeHandle => ProxiedType.TypeHandle; - protected override TypeCode GetTypeCodeImpl() => GetTypeCode(ProxiedType); - public override Guid GUID => ProxiedType.GUID; - public override Type? BaseType => ProxiedType.BaseType; - public override object? InvokeMember(string name, BindingFlags invokeAttr, Binder? binder, object? target, object?[]? args, ParameterModifier[]? modifiers, CultureInfo? culture, string[]? namedParameters) => - ProxiedType.InvokeMember(name, invokeAttr, binder, target, args, modifiers, culture, namedParameters); - public override Type? GetInterface(string name, bool ignoreCase) => ProxiedType.GetInterface(name, ignoreCase); - public override Type[] GetInterfaces() => ProxiedType.GetInterfaces(); - public override InterfaceMapping GetInterfaceMap(Type interfaceType) => ProxiedType.GetInterfaceMap(interfaceType); - public override bool IsInstanceOfType(object? o) => ProxiedType.IsInstanceOfType(o); - public override bool IsEquivalentTo(Type? other) => ProxiedType.IsEquivalentTo(other); - public override Type GetEnumUnderlyingType() => ProxiedType.GetEnumUnderlyingType(); - public override Array GetEnumValues() => ProxiedType.GetEnumValues(); - public override Type MakeArrayType() => ProxiedType.MakeArrayType(); - public override Type MakeArrayType(int rank) => ProxiedType.MakeArrayType(rank); - public override Type MakeByRefType() => ProxiedType.MakeByRefType(); - public override Type MakeGenericType(params Type[] typeArguments) => ProxiedType.MakeGenericType(typeArguments); - public override Type MakePointerType() => ProxiedType.MakePointerType(); - public override string ToString() => "Type: " + Name; - public override MemberTypes MemberType => ProxiedType.MemberType; - public override string Name => $"{ProxiedType.Name}+Decorated"; - public override IEnumerable CustomAttributes => ProxiedType.CustomAttributes; - public override int MetadataToken => ProxiedType.MetadataToken; - public override object[] GetCustomAttributes(bool inherit) => ProxiedType.GetCustomAttributes(inherit); - public override object[] GetCustomAttributes(Type attributeType, bool inherit) => ProxiedType.GetCustomAttributes(attributeType, inherit); - public override bool IsDefined(Type attributeType, bool inherit) => ProxiedType.IsDefined(attributeType, inherit); - public override IList GetCustomAttributesData() => ProxiedType.GetCustomAttributesData(); -} diff --git a/src/Scrutor/DecorationStrategy.cs b/src/Scrutor/DecorationStrategy.cs index ff6340b..746da94 100644 --- a/src/Scrutor/DecorationStrategy.cs +++ b/src/Scrutor/DecorationStrategy.cs @@ -14,7 +14,7 @@ protected DecorationStrategy(Type serviceType) public abstract bool CanDecorate(Type serviceType); - public abstract Func CreateDecorator(Type serviceType); + public abstract Func CreateDecorator(Type serviceType, string serviceKey); internal static DecorationStrategy WithType(Type serviceType, Type decoratorType) => Create(serviceType, decoratorType, decoratorFactory: null); @@ -22,15 +22,15 @@ internal static DecorationStrategy WithType(Type serviceType, Type decoratorType internal static DecorationStrategy WithFactory(Type serviceType, Func decoratorFactory) => Create(serviceType, decoratorType: null, decoratorFactory); - protected static Func TypeDecorator(Type serviceType, Type decoratorType) => serviceProvider => + protected static Func TypeDecorator(Type serviceType, string serviceKey, Type decoratorType) => (serviceProvider, _) => { - var instanceToDecorate = serviceProvider.GetRequiredService(serviceType); + var instanceToDecorate = serviceProvider.GetRequiredKeyedService(serviceType, serviceKey); return ActivatorUtilities.CreateInstance(serviceProvider, decoratorType, instanceToDecorate); }; - protected static Func FactoryDecorator(Type decorated, Func decoratorFactory) => serviceProvider => + protected static Func FactoryDecorator(Type serviceType, string serviceKey, Func decoratorFactory) => (serviceProvider, _) => { - var instanceToDecorate = serviceProvider.GetRequiredService(decorated); + var instanceToDecorate = serviceProvider.GetRequiredKeyedService(serviceType, serviceKey); return decoratorFactory(instanceToDecorate, serviceProvider); }; diff --git a/src/Scrutor/OpenGenericDecorationStrategy.cs b/src/Scrutor/OpenGenericDecorationStrategy.cs index 40a7ab7..62f014f 100644 --- a/src/Scrutor/OpenGenericDecorationStrategy.cs +++ b/src/Scrutor/OpenGenericDecorationStrategy.cs @@ -20,19 +20,19 @@ public override bool CanDecorate(Type serviceType) => && serviceType.GetGenericTypeDefinition() == ServiceType.GetGenericTypeDefinition() && (DecoratorType is null || serviceType.HasCompatibleGenericArguments(DecoratorType)); - public override Func CreateDecorator(Type serviceType) + public override Func CreateDecorator(Type serviceType, string serviceKey) { if (DecoratorType is not null) { var genericArguments = serviceType.GetGenericArguments(); var closedDecorator = DecoratorType.MakeGenericType(genericArguments); - return TypeDecorator(serviceType, closedDecorator); + return TypeDecorator(serviceType, serviceKey, closedDecorator); } if (DecoratorFactory is not null) { - return FactoryDecorator(serviceType, DecoratorFactory); + return FactoryDecorator(serviceType, serviceKey, DecoratorFactory); } throw new InvalidOperationException($"Both serviceType and decoratorFactory can not be null."); diff --git a/src/Scrutor/ServiceCollectionExtensions.Decoration.cs b/src/Scrutor/ServiceCollectionExtensions.Decoration.cs index 8d08174..36f235b 100644 --- a/src/Scrutor/ServiceCollectionExtensions.Decoration.cs +++ b/src/Scrutor/ServiceCollectionExtensions.Decoration.cs @@ -8,6 +8,8 @@ namespace Microsoft.Extensions.DependencyInjection; [PublicAPI] public static partial class ServiceCollectionExtensions { + private const string DecoratedServiceKeySuffix = "+Decorated"; + /// /// Decorates all registered services of type /// using the specified type . @@ -250,27 +252,48 @@ public static bool TryDecorate(this IServiceCollection services, DecorationStrat { var serviceDescriptor = services[i]; - if (serviceDescriptor.ServiceType is DecoratedType) + if (IsDecorated(serviceDescriptor) || !strategy.CanDecorate(serviceDescriptor.ServiceType)) { - continue; // Service has already been decorated. + continue; } - if (!strategy.CanDecorate(serviceDescriptor.ServiceType)) + var serviceKey = GetDecoratorKey(serviceDescriptor); + if (serviceKey is null) { - continue; // Unable to decorate using the specified strategy. + return false; } - var decoratedType = new DecoratedType(serviceDescriptor.ServiceType); - // Insert decorated - services.Add(serviceDescriptor.WithServiceType(decoratedType)); + services.Add(serviceDescriptor.WithServiceKey(serviceKey)); // Replace decorator - services[i] = serviceDescriptor.WithImplementationFactory(strategy.CreateDecorator(decoratedType)); + services[i] = serviceDescriptor.WithImplementationFactory(strategy.CreateDecorator(serviceDescriptor.ServiceType, serviceKey)); decorated = true; } return decorated; } + + private static string? GetDecoratorKey(ServiceDescriptor descriptor) + { + var uniqueId = Guid.NewGuid().ToString("n"); + + if (descriptor.ServiceKey is null) + { + return $"{descriptor.ServiceType.Name}+{uniqueId}{DecoratedServiceKeySuffix}"; + } + + if (descriptor.ServiceKey is string stringKey) + { + return $"{stringKey}+{uniqueId}{DecoratedServiceKeySuffix}"; + } + + return null; + } + + private static bool IsDecorated(ServiceDescriptor descriptor) + { + return descriptor.ServiceKey is string stringKey && stringKey.EndsWith(DecoratedServiceKeySuffix); + } } diff --git a/src/Scrutor/ServiceDescriptorExtensions.cs b/src/Scrutor/ServiceDescriptorExtensions.cs index 57d2aac..8c2e90d 100644 --- a/src/Scrutor/ServiceDescriptorExtensions.cs +++ b/src/Scrutor/ServiceDescriptorExtensions.cs @@ -5,14 +5,46 @@ namespace Scrutor; internal static class ServiceDescriptorExtensions { - public static ServiceDescriptor WithImplementationFactory(this ServiceDescriptor descriptor, Func implementationFactory) => - new(descriptor.ServiceType, implementationFactory, descriptor.Lifetime); + public static ServiceDescriptor WithImplementationFactory(this ServiceDescriptor descriptor, Func implementationFactory) => + new(descriptor.ServiceType, descriptor.ServiceKey, implementationFactory, descriptor.Lifetime); - public static ServiceDescriptor WithServiceType(this ServiceDescriptor descriptor, Type serviceType) => descriptor switch + public static ServiceDescriptor WithServiceKey(this ServiceDescriptor descriptor, string serviceKey) { - { ImplementationType: not null } => new ServiceDescriptor(serviceType, descriptor.ImplementationType, descriptor.Lifetime), - { ImplementationFactory: not null } => new ServiceDescriptor(serviceType, descriptor.ImplementationFactory, descriptor.Lifetime), - { ImplementationInstance: not null } => new ServiceDescriptor(serviceType, descriptor.ImplementationInstance), - _ => throw new ArgumentException($"No implementation factory or instance or type found for {descriptor.ServiceType}.", nameof(descriptor)) - }; + if (descriptor.IsKeyedService) + { + if (descriptor.KeyedImplementationType is not null) + { + return new ServiceDescriptor(descriptor.ServiceType, serviceKey, descriptor.KeyedImplementationType, descriptor.Lifetime); + } + + if (descriptor.KeyedImplementationInstance is not null) + { + return new ServiceDescriptor(descriptor.ServiceType, serviceKey, descriptor.KeyedImplementationInstance); + } + + if (descriptor.KeyedImplementationFactory is not null) + { + return new ServiceDescriptor(descriptor.ServiceType, serviceKey, descriptor.KeyedImplementationFactory, descriptor.Lifetime); + } + + throw new InvalidOperationException($"One of the following properties must be set: {nameof(ServiceDescriptor.KeyedImplementationType)}, {nameof(ServiceDescriptor.KeyedImplementationInstance)} or {nameof(ServiceDescriptor.KeyedImplementationFactory)}"); + } + + if (descriptor.ImplementationType is not null) + { + return new ServiceDescriptor(descriptor.ServiceType, serviceKey, descriptor.ImplementationType, descriptor.Lifetime); + } + + if (descriptor.ImplementationInstance is not null) + { + return new ServiceDescriptor(descriptor.ServiceType, serviceKey, descriptor.ImplementationInstance); + } + + if (descriptor.ImplementationFactory is not null) + { + return new ServiceDescriptor(descriptor.ServiceType, serviceKey, (sp, key) => descriptor.ImplementationFactory(sp), descriptor.Lifetime); + } + + throw new InvalidOperationException($"One of the following properties must be set: {nameof(ServiceDescriptor.ImplementationType)}, {nameof(ServiceDescriptor.ImplementationInstance)} or {nameof(ServiceDescriptor.ImplementationFactory)}"); + } } diff --git a/test/Scrutor.Tests/ServiceCollectionExtensions.cs b/test/Scrutor.Tests/ServiceCollectionExtensions.cs index 6a6cb84..011cb3f 100644 --- a/test/Scrutor.Tests/ServiceCollectionExtensions.cs +++ b/test/Scrutor.Tests/ServiceCollectionExtensions.cs @@ -18,6 +18,6 @@ public static ServiceDescriptor[] GetDescriptors(this IServiceCollection serv public static ServiceDescriptor[] GetDescriptors(this IServiceCollection services, Type serviceType) { - return services.Where(x => x.ServiceType == serviceType).ToArray(); + return services.Where(x => x.ServiceType == serviceType && x.ServiceKey is null).ToArray(); } }