diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs index 6ec395354d2fc..7a0cbac591a4a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs @@ -16,6 +16,7 @@ namespace Microsoft.Interop public readonly record struct CustomTypeMarshallerData( ManagedTypeInfo MarshallerType, ManagedTypeInfo NativeType, + bool HasState, MarshallerShape Shape, bool IsStrictlyBlittable, ManagedTypeInfo? BufferElementType); @@ -73,6 +74,11 @@ public static bool IsLinearCollectionEntryPoint(ITypeSymbol entryPointType) return false; } + public static bool HasEntryPointMarshallerAttribute(ITypeSymbol entryPointType) + { + return entryPointType.GetAttributes().Any(attr => attr.AttributeClass.ToDisplayString() == TypeNames.CustomMarshallerAttribute); + } + public static bool TryGetMarshallersFromEntryType( INamedTypeSymbol entryPointType, ITypeSymbol managedType, @@ -288,7 +294,20 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault private static CustomTypeMarshallerData? GetMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation) { - (MarshallerShape shape, Dictionary methodsByShape) = MarshallerShapeHelper.GetShapeForType(marshallerType, managedType, compilation); + if (marshallerType is { IsStatic: true, TypeKind: TypeKind.Class }) + { + return GetStatelessMarshallerDataForType(marshallerType, direction, managedType, compilation); + } + if (marshallerType.IsValueType) + { + return GetStatefulMarshallerDataForType(marshallerType, direction, managedType, compilation); + } + return null; + } + + private static CustomTypeMarshallerData? GetStatelessMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation) + { + (MarshallerShape shape, Dictionary methodsByShape) = StatelessMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, compilation); ITypeSymbol? nativeType = null; if (direction.HasFlag(MarshallingDirection.ManagedToUnmanaged)) @@ -339,6 +358,56 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault return new CustomTypeMarshallerData( ManagedTypeInfo.CreateTypeInfoForTypeSymbol(marshallerType), ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType), + HasState: false, + shape, + nativeType.IsStrictlyBlittable(), + bufferElementType); + } + + private static CustomTypeMarshallerData? GetStatefulMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation) + { + (MarshallerShape shape, StatefulMarshallerShapeHelper.MarshallerMethods methods) = StatefulMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, compilation); + + ITypeSymbol? nativeType = null; + if (direction.HasFlag(MarshallingDirection.ManagedToUnmanaged)) + { + if (!shape.HasFlag(MarshallerShape.CallerAllocatedBuffer) && !shape.HasFlag(MarshallerShape.ToUnmanaged)) + return null; + + if (methods.ToUnmanaged is not null) + { + nativeType = methods.ToUnmanaged.ReturnType; + } + } + + if (nativeType is null && direction.HasFlag(MarshallingDirection.UnmanagedToManaged)) + { + if (!shape.HasFlag(MarshallerShape.GuaranteedUnmarshal) && !shape.HasFlag(MarshallerShape.ToManaged)) + return null; + + if (methods.FromUnmanaged is not null) + { + nativeType = methods.FromUnmanaged.Parameters[0].Type; + } + } + + // Bidirectional requires ToUnmanaged without the caller-allocated buffer + if (direction.HasFlag(MarshallingDirection.Bidirectional) && !shape.HasFlag(MarshallerShape.ToUnmanaged)) + return null; + + if (nativeType is null) + return null; + + ManagedTypeInfo bufferElementType = null; + if (methods.FromManagedWithBuffer is not null) + { + bufferElementType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(((INamedTypeSymbol)methods.FromManagedWithBuffer.Parameters[1].Type).TypeArguments[0]); + } + + return new CustomTypeMarshallerData( + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(marshallerType), + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType), + HasState: true, shape, nativeType.IsStrictlyBlittable(), bufferElementType); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallerShape.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallerShape.cs index 47e0e8b6e94a7..d115e7fbb9305 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallerShape.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallerShape.cs @@ -16,10 +16,12 @@ public enum MarshallerShape None = 0x0, ToUnmanaged = 0x1, CallerAllocatedBuffer = 0x2, - PinnableReference = 0x4, - ToManaged = 0x8, - GuaranteedUnmarshal = 0x10, - Free = 0x20, + StatelessPinnableReference = 0x4, + StatefulPinnableReference = 0x8, + ToManaged = 0x10, + GuaranteedUnmarshal = 0x20, + Free = 0x40, + NotifyInvokeSucceeded = 0x80, } public static class ShapeMemberNames @@ -36,6 +38,20 @@ public static class Stateless public const string ConvertToManagedGuaranteed = nameof(ConvertToManagedGuaranteed); public const string ConvertToUnmanaged = nameof(ConvertToUnmanaged); } + + public static class Stateful + { + // Managed to Unmanaged + public const string FromManaged = nameof(FromManaged); + public const string ToUnmanaged = nameof(ToUnmanaged); + // Unmanaged to managed + public const string ToManaged = nameof(ToManaged); + public const string ToManagedGuaranteed = nameof(ToManagedGuaranteed); + public const string FromUnmanaged = nameof(FromUnmanaged); + // Optional features + public const string Free = nameof(Free); + public const string NotifyInvokeSucceeded = nameof(NotifyInvokeSucceeded); + } } public static class LinearCollection @@ -53,10 +69,28 @@ public static class Stateless public const string GetManagedValuesDestination = nameof(GetManagedValuesDestination); public const string GetUnmanagedValuesSource = nameof(GetUnmanagedValuesSource); } + + public static class Stateful + { + // Managed to Unmanaged + public const string FromManaged = nameof(FromManaged); + public const string ToUnmanaged = nameof(ToUnmanaged); + public const string GetManagedValuesSource = nameof(GetManagedValuesSource); + public const string GetUnmanagedValuesDestination = nameof(GetUnmanagedValuesDestination); + // Unmanaged to managed + public const string GetManagedValuesDestination = nameof(GetManagedValuesDestination); + public const string GetUnmanagedValuesSource = nameof(GetUnmanagedValuesSource); + public const string ToManaged = nameof(ToManaged); + public const string ToManagedGuaranteed = nameof(ToManagedGuaranteed); + public const string FromUnmanaged = nameof(FromUnmanaged); + // Optional features + public const string Free = nameof(Free); + public const string NotifyInvokeSucceeded = nameof(NotifyInvokeSucceeded); + } } } - public static class MarshallerShapeHelper + public static class StatelessMarshallerShapeHelper { public static (MarshallerShape, Dictionary) GetShapeForType(ITypeSymbol marshallerType, ITypeSymbol managedType, Compilation compilation) { @@ -80,9 +114,9 @@ public static (MarshallerShape, Dictionary) GetS if (method is not null) AddMethod(MarshallerShape.GuaranteedUnmarshal, method); - method = GetStatelessGetPinnableReference(marshallerType); + method = GetStatelessGetPinnableReference(marshallerType, managedType); if (method is not null) - AddMethod(MarshallerShape.PinnableReference, method); + AddMethod(MarshallerShape.StatelessPinnableReference, method); method = GetStatelessFree(marshallerType); if (method is not null) @@ -104,12 +138,13 @@ void AddMethod(MarshallerShape shapeToAdd, IMethodSymbol methodToAdd) .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: true }); } - private static IMethodSymbol? GetStatelessGetPinnableReference(ITypeSymbol type) + private static IMethodSymbol? GetStatelessGetPinnableReference(ITypeSymbol type, ITypeSymbol managedType) { return type.GetMembers(ShapeMemberNames.GetPinnableReference) .OfType() .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1 } and - ({ ReturnsByRef: true } or { ReturnsByRefReadonly: true })); + ({ ReturnsByRef: true } or { ReturnsByRefReadonly: true }) + && SymbolEqualityComparer.Default.Equals(m.Parameters[0].Type, managedType)); } private static IMethodSymbol? GetConvertToUnmanagedMethod(ITypeSymbol type, ITypeSymbol managedType) @@ -174,4 +209,235 @@ static bool IsSpanOfUnmanagedType(ITypeSymbol typeToCheck, ITypeSymbol spanOfT, && SymbolEqualityComparer.Default.Equals(managedType, m.ReturnType)); } } + + public static class StatefulMarshallerShapeHelper + { + public record MarshallerMethods + { + public IMethodSymbol? FromManaged { get; init; } + public IMethodSymbol? FromManagedWithBuffer { get; init; } + public IMethodSymbol? ToManaged { get; init; } + public IMethodSymbol? ToManagedGuranteed { get; init; } + public IMethodSymbol? FromUnmanaged { get; init; } + public IMethodSymbol? ToUnmanaged { get; init; } + public IMethodSymbol? Free { get; init; } + public IMethodSymbol? NotifyInvokeSucceeded { get; init; } + } + + public static (MarshallerShape shape, MarshallerMethods methods) GetShapeForType(ITypeSymbol marshallerType, ITypeSymbol managedType, Compilation compilation) + { + MarshallerShape shape = MarshallerShape.None; + MarshallerMethods methods = new(); + + ITypeSymbol? unmanagedType = null; + + IMethodSymbol? fromManaged = GetFromManagedMethod(marshallerType, managedType); + INamedTypeSymbol spanOfT = compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!; + IMethodSymbol? fromManagedWithCallerAllocatedBuffer = GetFromManagedWithCallerAllocatedBufferMethod(marshallerType, managedType, spanOfT, out _); + + IMethodSymbol? toUnmanaged = GetToUnmanagedMethod(marshallerType); + + if ((fromManaged, fromManagedWithCallerAllocatedBuffer) is not (null, null) && toUnmanaged is not null) + { + unmanagedType = toUnmanaged.ReturnType; + if (unmanagedType.IsUnmanagedType) + { + if (fromManagedWithCallerAllocatedBuffer is not null) + { + shape |= MarshallerShape.CallerAllocatedBuffer; + } + if (fromManaged is not null) + { + shape |= MarshallerShape.ToUnmanaged; + } + methods = methods with + { + FromManaged = fromManaged, + FromManagedWithBuffer = fromManagedWithCallerAllocatedBuffer, + ToUnmanaged = toUnmanaged + }; + } + } + + IMethodSymbol toManaged = GetToManagedMethod(marshallerType, managedType); + IMethodSymbol toManagedGuaranteed = GetToManagedGuaranteedMethod(marshallerType, managedType); + IMethodSymbol fromUnmanaged = GetFromUnmanagedMethod(marshallerType, unmanagedType); + if ((toManaged, toManagedGuaranteed) is not (null, null) && fromUnmanaged is not null) + { + if (toManagedGuaranteed is not null) + { + shape |= MarshallerShape.GuaranteedUnmarshal; + } + if (toManaged is not null) + { + shape |= MarshallerShape.ToManaged; + } + methods = methods with + { + FromUnmanaged = fromUnmanaged, + ToManaged = toManaged, + ToManagedGuranteed = toManagedGuaranteed + }; + } + + IMethodSymbol free = GetStatefulFreeMethod(marshallerType); + if (free is not null) + { + shape |= MarshallerShape.Free; + methods = methods with { Free = free }; + } + + IMethodSymbol notifyInvokeSucceeded = GetNotifyInvokeSucceededMethod(marshallerType); + if (notifyInvokeSucceeded is not null) + { + shape |= MarshallerShape.NotifyInvokeSucceeded; + methods = methods with { NotifyInvokeSucceeded = notifyInvokeSucceeded }; + } + + if (GetStatelessGetPinnableReference(marshallerType, managedType) is not null) + { + shape |= MarshallerShape.StatelessPinnableReference; + } + if (GetStatefulGetPinnableReference(marshallerType) is not null) + { + shape |= MarshallerShape.StatefulPinnableReference; + } + + return (shape, methods); + } + + private static IMethodSymbol? GetFromManagedMethod(ITypeSymbol type, ITypeSymbol managedType) + { + return type.GetMembers(ShapeMemberNames.Value.Stateful.FromManaged) + .OfType() + .FirstOrDefault(m => m is { IsStatic: false, Parameters.Length: 1, ReturnsVoid: true } + && SymbolEqualityComparer.Default.Equals(managedType, m.Parameters[0].Type)); + } + + private static IMethodSymbol? GetFromManagedWithCallerAllocatedBufferMethod( + ITypeSymbol type, + ITypeSymbol managedType, + ITypeSymbol spanOfT, + out ITypeSymbol? spanElementType) + { + spanElementType = null; + IEnumerable methods = type.GetMembers(ShapeMemberNames.Value.Stateful.FromManaged) + .OfType() + .Where(m => m is { IsStatic: false, Parameters.Length: 2, ReturnsVoid: true } + && SymbolEqualityComparer.Default.Equals(managedType, m.Parameters[0].Type)); + + foreach (IMethodSymbol method in methods) + { + if (IsSpanOfUnmanagedType(method.Parameters[1].Type, spanOfT, out spanElementType)) + { + return method; + } + } + + return null; + + static bool IsSpanOfUnmanagedType(ITypeSymbol typeToCheck, ITypeSymbol spanOfT, out ITypeSymbol? typeArgument) + { + typeArgument = null; + if (typeToCheck is INamedTypeSymbol namedType + && SymbolEqualityComparer.Default.Equals(spanOfT, namedType.ConstructedFrom) + && namedType.TypeArguments.Length == 1 + && namedType.TypeArguments[0].IsUnmanagedType) + { + typeArgument = namedType.TypeArguments[0]; + return true; + } + + return false; + } + } + + private static IMethodSymbol? GetToManagedMethod(ITypeSymbol type, ITypeSymbol managedType) + { + return type.GetMembers(ShapeMemberNames.Value.Stateful.ToManaged) + .OfType() + .FirstOrDefault(m => m is { IsStatic: false, Parameters.Length: 0, ReturnsVoid: false, ReturnsByRef: false, ReturnsByRefReadonly: false } + && SymbolEqualityComparer.Default.Equals(managedType, m.ReturnType)); + } + + private static IMethodSymbol? GetToManagedGuaranteedMethod(ITypeSymbol type, ITypeSymbol managedType) + { + return type.GetMembers(ShapeMemberNames.Value.Stateful.ToManagedGuaranteed) + .OfType() + .FirstOrDefault(m => m is { IsStatic: false, Parameters.Length: 0, ReturnsVoid: false, ReturnsByRef: false, ReturnsByRefReadonly: false } + && SymbolEqualityComparer.Default.Equals(managedType, m.ReturnType)); + } + + private static IMethodSymbol? GetToUnmanagedMethod(ITypeSymbol type) + { + return type.GetMembers(ShapeMemberNames.Value.Stateful.ToUnmanaged) + .OfType() + .FirstOrDefault(m => m is { IsStatic: false, Parameters.Length: 0, ReturnsVoid: false, ReturnsByRef: false, ReturnsByRefReadonly: false }); + } + + private static IMethodSymbol? GetFromUnmanagedMethod(ITypeSymbol type, ITypeSymbol? unmanagedType) + { + IMethodSymbol[] candidates = type.GetMembers(ShapeMemberNames.Value.Stateful.FromUnmanaged) + .OfType() + .Where(m => m is { IsStatic: false, Parameters.Length: 1, ReturnsVoid: true }) + .ToArray(); + + // If there are multiple overloads of FromUnmanaged, we'll treat it as not present. + // Otherwise we get into a weird state where bidirectional marshallers would support overloads + // of FromUnmanaged as we'd have an unmanaged type to check, but unmanaged->managed marshallers + // would not support it as there's no way to know which overload is the correct overload. + if (candidates.Length != 1) + { + return null; + } + + if (unmanagedType is null) + { + // We don't know the unmanaged type to expected for the parameter, so just assume that the only overload of FromUnmanaged + // is correct. + return candidates[0]; + } + + if (SymbolEqualityComparer.Default.Equals(candidates[0].Parameters[0].Type, unmanagedType)) + { + // We know the unmanaged type and it matches. + // Use the method as we know it will work. + return candidates[0]; + } + + // The unmanaged type doesn't match the expected type, so we don't have an overload that will work. + return null; + } + + private static IMethodSymbol? GetStatefulFreeMethod(ITypeSymbol type) + { + return type.GetMembers(ShapeMemberNames.Value.Stateful.Free) + .OfType() + .FirstOrDefault(m => m is { IsStatic: false, Parameters.Length: 0, ReturnsVoid: true }); + } + + private static IMethodSymbol? GetNotifyInvokeSucceededMethod(ITypeSymbol type) + { + return type.GetMembers(ShapeMemberNames.Value.Stateful.NotifyInvokeSucceeded) + .OfType() + .FirstOrDefault(m => m is { IsStatic: false, Parameters.Length: 0, ReturnsVoid: true }); + } + + private static IMethodSymbol? GetStatelessGetPinnableReference(ITypeSymbol type, ITypeSymbol managedType) + { + return type.GetMembers(ShapeMemberNames.GetPinnableReference) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1 } and + ({ ReturnsByRef: true } or { ReturnsByRefReadonly: true }) + && SymbolEqualityComparer.Default.Equals(m.Parameters[0].Type, managedType)); + } + + private static IMethodSymbol? GetStatefulGetPinnableReference(ITypeSymbol type) + { + return type.GetMembers(ShapeMemberNames.GetPinnableReference) + .OfType() + .FirstOrDefault(m => m is { IsStatic: false, Parameters.Length: 0 } and + ({ ReturnsByRef: true } or { ReturnsByRefReadonly: true })); + } + } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index b96d5cca47d5a..4adec7d230f57 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -235,9 +235,19 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo }; } - ICustomTypeMarshallingStrategy marshallingStrategy = new StatelessValueMarshalling(marshallerData.MarshallerType.Syntax, marshallerData.NativeType.Syntax, marshallerData.Shape); - if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) - marshallingStrategy = new CallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax); + ICustomTypeMarshallingStrategy marshallingStrategy; + if (marshallerData.HasState) + { + marshallingStrategy = new StatefulValueMarshalling(marshallerData.MarshallerType.Syntax, marshallerData.NativeType.Syntax, marshallerData.Shape); + if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) + marshallingStrategy = new StatefulCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax); + } + else + { + marshallingStrategy = new StatelessValueMarshalling(marshallerData.MarshallerType.Syntax, marshallerData.NativeType.Syntax, marshallerData.Shape); + if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) + marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax); + } IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomNativeTypeMarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomNativeTypeMarshallingGenerator.cs index 90e0df1bc497f..c8eda3b81a1ff 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomNativeTypeMarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomNativeTypeMarshallingGenerator.cs @@ -69,6 +69,15 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont return _nativeTypeMarshaller.GeneratePinnedMarshalStatements(info, context); } break; + case StubCodeContext.Stage.NotifyForSuccessfulInvoke: + if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) + { + if (_nativeTypeMarshaller is ICustomTypeMarshallingStrategy strategyWithGuaranteedUnmarshal) + { + return strategyWithGuaranteedUnmarshal.GenerateNotifyForSuccessfulInvokeStatements(info, context); + } + } + break; case StubCodeContext.Stage.UnmarshalCapture: if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs index 124ffa05eccd7..1b52e8712cca3 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs @@ -39,6 +39,7 @@ internal interface ICustomTypeMarshallingStrategyBase internal interface ICustomTypeMarshallingStrategy : ICustomTypeMarshallingStrategyBase { IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context); + IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context); } /// @@ -69,7 +70,7 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i if (!_shape.HasFlag(MarshallerShape.Free)) yield break; - // = .ConvertToManaged(); + // .Free(); yield return ExpressionStatement( InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, @@ -163,18 +164,23 @@ public IEnumerable GeneratePinStatements(TypePositionInfo info, { return Array.Empty(); } + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } } /// /// Marshaller that enables support for a stackalloc constructor variant on a native type. /// - internal sealed class CallerAllocatedBufferMarshalling : ICustomTypeMarshallingStrategy + internal sealed class StatelessCallerAllocatedBufferMarshalling : ICustomTypeMarshallingStrategy { private readonly ICustomTypeMarshallingStrategy _innerMarshaller; private readonly TypeSyntax _marshallerType; private readonly TypeSyntax _bufferElementType; - public CallerAllocatedBufferMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, TypeSyntax marshallerType, TypeSyntax bufferElementType) + public StatelessCallerAllocatedBufferMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, TypeSyntax marshallerType, TypeSyntax bufferElementType) { _innerMarshaller = innerMarshaller; _marshallerType = marshallerType; @@ -188,6 +194,13 @@ public CallerAllocatedBufferMarshalling(ICustomTypeMarshallingStrategy innerMars public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) { if (CanUseCallerAllocatedBuffer(info, context)) + { + return GenerateCallerAllocatedBufferMarshalStatements(); + } + + return _innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments); + + IEnumerable GenerateCallerAllocatedBufferMarshalStatements() { string bufferIdentifier = context.GetAdditionalIdentifier(info, "buffer"); @@ -227,13 +240,6 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i Argument(IdentifierName(bufferIdentifier)) }))))); } - else - { - foreach (StatementSyntax statement in _innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments)) - { - yield return statement; - } - } } public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); @@ -248,5 +254,295 @@ private static bool CanUseCallerAllocatedBuffer(TypePositionInfo info, StubCodeC { return context.SingleFrameSpansNativeContext && (!info.IsByRef || info.RefKind == RefKind.In); } + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); + } + + internal sealed class StatefulValueMarshalling : ICustomTypeMarshallingStrategy + { + internal const string MarshallerIdentifier = "marshaller"; + private readonly TypeSyntax _marshallerTypeSyntax; + private readonly TypeSyntax _nativeTypeSyntax; + private readonly MarshallerShape _shape; + + public StatefulValueMarshalling(TypeSyntax marshallerTypeSyntax, TypeSyntax nativeTypeSyntax, MarshallerShape shape) + { + _marshallerTypeSyntax = marshallerTypeSyntax; + _nativeTypeSyntax = nativeTypeSyntax; + _shape = shape; + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return _nativeTypeSyntax; + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true; + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + if (!_shape.HasFlag(MarshallerShape.Free)) + yield break; + + // .Free(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Free)), + ArgumentList())); + } + + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + if (!_shape.HasFlag(MarshallerShape.GuaranteedUnmarshal)) + yield break; + + (string managedIdentifier, _) = context.GetIdentifiers(info); + + // = .ToManagedGuaranteed(); + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Value.Stateful.ToManagedGuaranteed)), + ArgumentList()))); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + if (!_shape.HasFlag(MarshallerShape.ToUnmanaged)) + yield break; + + (string managedIdentifier, _) = context.GetIdentifiers(info); + + // .FromManaged(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Value.Stateful.FromManaged)), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(managedIdentifier)))))); + } + + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + if (!_shape.HasFlag(MarshallerShape.ToUnmanaged) && !_shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) + yield break; + + (_, string nativeIdentifier) = context.GetIdentifiers(info); + + // = .ToUnmanaged(); + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(nativeIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Value.Stateful.ToUnmanaged)), + ArgumentList()))); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + if (!_shape.HasFlag(MarshallerShape.ToManaged)) + yield break; + + (string managedIdentifier, _) = context.GetIdentifiers(info); + + // = .ToManaged(); + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Value.Stateful.ToManaged)), + ArgumentList()))); + } + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + if (!_shape.HasFlag(MarshallerShape.ToManaged) && !_shape.HasFlag(MarshallerShape.GuaranteedUnmarshal)) + yield break; + + (_, string nativeIdentifier) = context.GetIdentifiers(info); + + // .FromUnmanaged(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Value.Stateful.FromUnmanaged)), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(nativeIdentifier)))))); + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + yield return MarshallerHelpers.Declare( + _marshallerTypeSyntax, + context.GetAdditionalIdentifier(info, MarshallerIdentifier), + ImplicitObjectCreationExpression(ArgumentList(), initializer: null)); + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) + { + if (!_shape.HasFlag(MarshallerShape.NotifyInvokeSucceeded)) + yield break; + + // .NotifyInvokeSucceeded(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Value.Stateful.NotifyInvokeSucceeded)), + ArgumentList())); + } + } + + /// + /// Marshaller that enables support for a stackalloc constructor variant on a native type. + /// + internal sealed class StatefulCallerAllocatedBufferMarshalling : ICustomTypeMarshallingStrategy + { + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; + private readonly TypeSyntax _marshallerType; + private readonly TypeSyntax _bufferElementType; + + public StatefulCallerAllocatedBufferMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, TypeSyntax marshallerType, TypeSyntax bufferElementType) + { + _innerMarshaller = innerMarshaller; + _marshallerType = marshallerType; + _bufferElementType = bufferElementType; + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return _innerMarshaller.AsNativeType(info); + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateCleanupStatements(info, context); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + if (CanUseCallerAllocatedBuffer(info, context)) + { + return GenerateCallerAllocatedBufferMarshalStatements(); + } + + return _innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments); + + IEnumerable GenerateCallerAllocatedBufferMarshalStatements() + { + // TODO: Update once we can consume the scoped keword. We should be able to simplify this once we get that API. + string stackPtrIdentifier = context.GetAdditionalIdentifier(info, "stackptr"); + // * __stackptr = stackalloc [<_bufferSize>]; + yield return LocalDeclarationStatement( + VariableDeclaration( + PointerType(_bufferElementType), + SingletonSeparatedList( + VariableDeclarator(stackPtrIdentifier) + .WithInitializer(EqualsValueClause( + StackAllocArrayCreationExpression( + ArrayType( + _bufferElementType, + SingletonList(ArrayRankSpecifier(SingletonSeparatedList( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + _marshallerType, + IdentifierName(ShapeMemberNames.BufferSize)) + )))))))))); + + + (string managedIdentifier, _) = context.GetIdentifiers(info); + + // .FromManaged(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, StatefulValueMarshalling.MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Value.Stateful.FromManaged)), + ArgumentList(SeparatedList( + new[] + { + Argument(IdentifierName(managedIdentifier)), + Argument( + ObjectCreationExpression( + GenericName(Identifier(TypeNames.System_Span), + TypeArgumentList(SingletonSeparatedList( + _bufferElementType)))) + .WithArgumentList( + ArgumentList(SeparatedList(new ArgumentSyntax[] + { + Argument(IdentifierName(stackPtrIdentifier)), + Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + _marshallerType, + IdentifierName(ShapeMemberNames.BufferSize))) + })))) + })))); + } + } + + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + } + + private static bool CanUseCallerAllocatedBuffer(TypePositionInfo info, StubCodeContext context) + { + return context.SingleFrameSpansNativeContext && (!info.IsByRef || info.RefKind == RefKind.In); + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GeneratePinStatements(info, context); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateSetupStatements(info, context); + } + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateUnmarshalStatements(info, context); + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.UsesNativeIdentifier(info, context); + } + + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs index 7bf836506a911..8f9d334d313ed 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs @@ -45,18 +45,23 @@ public static ForStatementSyntax GetForLoop(ExpressionSyntax lengthExpression, s } public static LocalDeclarationStatementSyntax Declare(TypeSyntax typeSyntax, string identifier, bool initializeToDefault) + { + return Declare(typeSyntax, identifier, initializeToDefault ? LiteralExpression(SyntaxKind.DefaultLiteralExpression) : null); + } + + public static LocalDeclarationStatementSyntax Declare(TypeSyntax typeSyntax, string identifier, ExpressionSyntax? initializer) { VariableDeclaratorSyntax decl = VariableDeclarator(identifier); - if (initializeToDefault) + if (initializer is not null) { decl = decl.WithInitializer( EqualsValueClause( - LiteralExpression(SyntaxKind.DefaultLiteralExpression))); + initializer)); } // ; // or - // = default; + // = ; return LocalDeclarationStatement( VariableDeclaration( typeSyntax, diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs index 81021677c4f02..e4e72e0535d4d 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs @@ -584,7 +584,7 @@ private MarshallingInfo CreateNativeMarshallingInfo( ref int maxIndirectionDepthUsed) { bool isLinearCollectionMarshalling = ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(entryPointType); - if (ManualTypeMarshallingHelper.TryGetMarshallersFromEntryType(entryPointType, type, isLinearCollectionMarshalling, _compilation, out CustomTypeMarshallers? marshallers)) + if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(entryPointType)) { if (!entryPointType.IsStatic) { @@ -592,10 +592,15 @@ private MarshallingInfo CreateNativeMarshallingInfo( return NoMarshallingInfo.Instance; } - bool isPinnableManagedType = !isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null; - return isLinearCollectionMarshalling - ? NoMarshallingInfo.Instance // TODO: handle linear collection marshallers - : new NativeMarshallingAttributeInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType), marshallers.Value, isPinnableManagedType); + if (ManualTypeMarshallingHelper.TryGetMarshallersFromEntryType(entryPointType, type, isLinearCollectionMarshalling, _compilation, out CustomTypeMarshallers? marshallers)) + { + bool isPinnableManagedType = !isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null; + return isLinearCollectionMarshalling + ? NoMarshallingInfo.Instance // TODO: handle linear collection marshallers + : new NativeMarshallingAttributeInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType), marshallers.Value, isPinnableManagedType); + } + + return NoMarshallingInfo.Instance; } return CreateNativeMarshallingInfo_V1(type, entryPointType, attrData, isMarshalUsingAttribute, indirectionLevel, parsedCountInfo, useSiteAttributes, inspectedElements, ref maxIndirectionDepthUsed); diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CustomMarshallingTests.cs index f35fdcd05ba52..e51e8602825ab 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CustomMarshallingTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CustomMarshallingTests.cs @@ -5,6 +5,7 @@ using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; using System.Text; +using NativeExports; using SharedTypes; using Xunit; @@ -61,6 +62,28 @@ public static int ConvertToManagedGuaranteed(int unmanaged) } } + internal partial class Stateful + { + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "subtract_return_int")] + public static partial IntWrapperWithNotification SubtractInts(IntWrapperWithNotification x, IntWrapperWithNotification y); + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "subtract_out_int")] + public static partial void SubtractInts(IntWrapperWithNotification x, IntWrapperWithNotification y, out IntWrapperWithNotification result); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "negate_bools")] + public static partial void NegateBools( + [MarshalUsing(typeof(BoolStructMarshallerStateful))] BoolStruct boolStruct, + [MarshalUsing(typeof(BoolStructMarshallerStateful))] out BoolStruct pBoolStructOut); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "and_bools_ref")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool AndBoolsRef([MarshalUsing(typeof(BoolStructMarshallerStateful))] in BoolStruct boolStruct); + + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_int_ref")] + [return: MarshalUsing(typeof(IntWrapperMarshallerStateful))] + public static partial IntWrapper DoubleIntRef([MarshalUsing(typeof(IntWrapperMarshallerStateful))] IntWrapper pInt); + } + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_replace_ref_ushort")] public static partial void ReverseReplaceString([MarshalUsing(typeof(Utf16StringMarshaller))] ref string s); @@ -196,6 +219,97 @@ public void NonBlittableStructPinnableMarshalerPassByRef() Assert.Equal(expected, str); } + [Fact] + public void NotifyInvokeSucceededInNoReturn() + { + bool xNotified = false; + bool yNotified = false; + IntWrapperWithNotification x = new() { Value = 23 }; + x.InvokeSucceeded += (sender, args) => xNotified = true; + IntWrapperWithNotification y = new() { Value = 897 }; + y.InvokeSucceeded += (sender, args) => yNotified = true; + + int oldNumInvokeSucceededOnUninitialized = IntWrapperWithNotification.NumInvokeSucceededOnUninitialized; + + int result = NativeExportsNE.Stateful.SubtractInts(x, y).Value; + + Assert.Equal(x.Value - y.Value, result); + Assert.True(xNotified); + Assert.True(yNotified); + Assert.Equal(oldNumInvokeSucceededOnUninitialized, IntWrapperWithNotification.NumInvokeSucceededOnUninitialized); + } + + [Fact] + public void NotifyInvokeSucceededInNoOut() + { + bool xNotified = false; + bool yNotified = false; + IntWrapperWithNotification x = new() { Value = 23 }; + x.InvokeSucceeded += (sender, args) => xNotified = true; + IntWrapperWithNotification y = new() { Value = 897 }; + y.InvokeSucceeded += (sender, args) => yNotified = true; + + int oldNumInvokeSucceededOnUninitialized = IntWrapperWithNotification.NumInvokeSucceededOnUninitialized; + + NativeExportsNE.Stateful.SubtractInts(x, y, out IntWrapperWithNotification result); + + Assert.Equal(x.Value - y.Value, result.Value); + Assert.True(xNotified); + Assert.True(yNotified); + Assert.Equal(oldNumInvokeSucceededOnUninitialized, IntWrapperWithNotification.NumInvokeSucceededOnUninitialized); + } + + [Fact] + public void NonBlittableStructWithoutAllocation_Stateful() + { + var boolStruct = new BoolStruct + { + b1 = true, + b2 = false, + b3 = true + }; + + NativeExportsNE.Stateful.NegateBools(boolStruct, out BoolStruct boolStructNegated); + + Assert.Equal(!boolStruct.b1, boolStructNegated.b1); + Assert.Equal(!boolStruct.b2, boolStructNegated.b2); + Assert.Equal(!boolStruct.b3, boolStructNegated.b3); + } + + [Theory] + [InlineData(true, true, true)] + [InlineData(true, true, false)] + [InlineData(true, false, true)] + [InlineData(true, false, false)] + [InlineData(false, true, true)] + [InlineData(false, true, false)] + [InlineData(false, false, true)] + [InlineData(false, false, false)] + public void NonBlittableStructIn_Stateful(bool b1, bool b2, bool b3) + { + var container = new BoolStruct + { + b1 = b1, + b2 = b2, + b3 = b3 + }; + + Assert.Equal(b1 && b2 && b3, NativeExportsNE.Stateful.AndBoolsRef(container)); + } + + [Fact] + public void NonBlittableType_Stateful_Marshalling_Free() + { + int originalValue = 42; + var wrapper = new IntWrapper { i = originalValue }; + + var retVal = NativeExportsNE.Stateful.DoubleIntRef(wrapper); + + // We don't pin the managed value, so it shouldn't update. + Assert.Equal(originalValue, wrapper.i); + Assert.Equal(originalValue * 2, retVal.i); + } + private static string ReverseChars(string value) { if (value == null) diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs index 940cd5ab2b02e..f24f0c4ed2727 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs @@ -698,7 +698,13 @@ public struct Native { } public static Native ConvertToUnmanaged(S s) => default; } "; - private static string StatelessIn = @" + public static string NonStaticMarshallerEntryPoint => BasicParameterByValue("S") + + NonBlittableUserDefinedType() + + NonStatic; + + public static class Stateless + { + private static string In = @" [CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(Marshaller))] public static class Marshaller { @@ -707,7 +713,7 @@ public struct Native { } public static Native ConvertToUnmanaged(S s) => default; } "; - private static string StatelessInBuffer = @" + private static string InBuffer = @" [CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(Marshaller))] public static class Marshaller { @@ -717,7 +723,7 @@ public struct Native { } public static Native ConvertToUnmanaged(S s, System.Span buffer) => default; } "; - private static string StatelessOut = @" + private static string Out = @" [CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedOut, typeof(Marshaller))] public static class Marshaller { @@ -726,7 +732,7 @@ public struct Native { } public static S ConvertToManaged(Native n) => default; } "; - private static string StatelessOutGuaranteed = @" + private static string OutGuaranteed = @" [CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedOut, typeof(Marshaller))] public static class Marshaller { @@ -735,7 +741,7 @@ public struct Native { } public static S ConvertToManagedGuaranteed(Native n) => default; } "; - public static string StatelessRef = @" + public static string Ref = @" [CustomMarshaller(typeof(S), Scenario.Default, typeof(Marshaller))] public static class Marshaller { @@ -745,7 +751,7 @@ public struct Native { } public static S ConvertToManaged(Native n) => default; } "; - public static string StatelessRefBuffer = @" + public static string RefBuffer = @" [CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(Marshaller))] [CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedOut, typeof(Marshaller))] public static class Marshaller @@ -757,7 +763,7 @@ public struct Native { } public static S ConvertToManaged(Native n) => default; } "; - public static string StatelessRefOptionalBuffer = @" + public static string RefOptionalBuffer = @" [CustomMarshaller(typeof(S), Scenario.Default, typeof(Marshaller))] public static class Marshaller { @@ -770,61 +776,257 @@ public struct Native { } } "; - public static string ManagedToNativeOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") - + NonBlittableUserDefinedType() - + StatelessIn; + public static string ManagedToNativeOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") + + NonBlittableUserDefinedType() + + In; - public static string NativeToManagedOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") - + NonBlittableUserDefinedType() - + StatelessOut; + public static string NativeToManagedOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") + + NonBlittableUserDefinedType() + + Out; - public static string NativeToManagedGuaranteedOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") - + NonBlittableUserDefinedType() - + StatelessOutGuaranteed; + public static string NativeToManagedGuaranteedOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") + + NonBlittableUserDefinedType() + + OutGuaranteed; - public static string ManagedToNativeOnlyReturnValue => BasicReturnType("S") - + NonBlittableUserDefinedType() - + StatelessIn; + public static string ManagedToNativeOnlyReturnValue => BasicReturnType("S") + + NonBlittableUserDefinedType() + + In; - public static string NativeToManagedOnlyReturnValue => BasicReturnType("S") - + NonBlittableUserDefinedType() - + StatelessOut; + public static string NativeToManagedOnlyReturnValue => BasicReturnType("S") + + NonBlittableUserDefinedType() + + Out; - public static string NativeToManagedGuaranteedOnlyReturnValue => BasicReturnType("S") - + NonBlittableUserDefinedType() - + StatelessOut; + public static string NativeToManagedGuaranteedOnlyReturnValue => BasicReturnType("S") + + NonBlittableUserDefinedType() + + Out; - public static string NativeToManagedOnlyInParameter => BasicParameterWithByRefModifier("in", "S") - + NonBlittableUserDefinedType() - + StatelessOut; + public static string NativeToManagedOnlyInParameter => BasicParameterWithByRefModifier("in", "S") + + NonBlittableUserDefinedType() + + Out; - public static string ParametersAndModifiers = BasicParametersAndModifiers("S", UsingSystemRuntimeInteropServicesMarshalling) - + NonBlittableUserDefinedType(defineNativeMarshalling: true) - + StatelessRef; + public static string ParametersAndModifiers = BasicParametersAndModifiers("S", UsingSystemRuntimeInteropServicesMarshalling) + + NonBlittableUserDefinedType(defineNativeMarshalling: true) + + Ref; - public static string MarshalUsingParametersAndModifiers = MarshalUsingParametersAndModifiers("S", "Marshaller") - + NonBlittableUserDefinedType(defineNativeMarshalling: false) - + StatelessRef; + public static string MarshalUsingParametersAndModifiers = MarshalUsingParametersAndModifiers("S", "Marshaller") + + NonBlittableUserDefinedType(defineNativeMarshalling: false) + + Ref; - public static string NonStaticMarshallerEntryPoint => BasicParameterByValue("S") - + NonBlittableUserDefinedType() - + NonStatic; + public static string StackallocByValueInParameter => BasicParameterByValue("S") + + NonBlittableUserDefinedType() + + InBuffer; - public static string StackallocByValueInParameter => BasicParameterByValue("S") - + NonBlittableUserDefinedType() - + StatelessInBuffer; + public static string StackallocParametersAndModifiersNoRef = BasicParametersAndModifiersNoRef("S") + + NonBlittableUserDefinedType() + + RefBuffer; - public static string StackallocParametersAndModifiersNoRef = BasicParametersAndModifiersNoRef("S") - + NonBlittableUserDefinedType() - + StatelessRefBuffer; + public static string StackallocOnlyRefParameter = BasicParameterWithByRefModifier("ref", "S") + + NonBlittableUserDefinedType() + + RefBuffer; - public static string StackallocOnlyRefParameter = BasicParameterWithByRefModifier("ref", "S") - + NonBlittableUserDefinedType() - + StatelessRefBuffer; + public static string OptionalStackallocParametersAndModifiers = BasicParametersAndModifiers("S", UsingSystemRuntimeInteropServicesMarshalling) + + NonBlittableUserDefinedType() + + RefOptionalBuffer; + } - public static string OptionalStackallocParametersAndModifiers = BasicParametersAndModifiers("S", UsingSystemRuntimeInteropServicesMarshalling) - + NonBlittableUserDefinedType() - + StatelessRefOptionalBuffer; + public static class Stateful + { + private static string In = @" +[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + public struct M + { + public void FromManaged(S s) {} + public Native ToUnmanaged() => default; + } +} +"; + + private static string InBuffer = @" +[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + public struct M + { + public const int BufferSize = 0x100; + public void FromManaged(S s, System.Span buffer) {} + public Native ToUnmanaged() => default; + } +} +"; + private static string Out = @" +[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedOut, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + public struct M + { + public void FromUnmanaged(Native n) {} + public S ToManaged() => default; + } +} +"; + private static string OutGuaranteed = @" +[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedOut, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + public struct M + { + public void FromUnmanaged(Native n) {} + public S ToManagedGuaranteed() => default; + } +} +"; + public static string Ref = @" +[CustomMarshaller(typeof(S), Scenario.Default, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + public struct M + { + public void FromManaged(S s) {} + public Native ToUnmanaged() => default; + public void FromUnmanaged(Native n) {} + public S ToManaged() => default; + } +} +"; + public static string RefWithFree = @" +[CustomMarshaller(typeof(S), Scenario.Default, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + public struct M + { + public void FromManaged(S s) {} + public Native ToUnmanaged() => default; + public void FromUnmanaged(Native n) {} + public S ToManaged() => default; + public void Free() {} + } +} +"; + public static string RefWithNotifyInvokeSucceeded = @" +[CustomMarshaller(typeof(S), Scenario.Default, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + public struct M + { + public void FromManaged(S s) {} + public Native ToUnmanaged() => default; + public void FromUnmanaged(Native n) {} + public S ToManaged() => default; + public void NotifyInvokeSucceeded() {} + } +} +"; + public static string RefBuffer = @" +[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedIn, typeof(M))] +[CustomMarshaller(typeof(S), Scenario.ManagedToUnmanagedOut, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + public struct M + { + public const int BufferSize = 0x100; + public void FromManaged(S s, System.Span buffer) {} + public Native ToUnmanaged() => default; + public void FromUnmanaged(Native n) {} + public S ToManaged() => default; + } +} +"; + public static string RefOptionalBuffer = @" +[CustomMarshaller(typeof(S), Scenario.Default, typeof(M))] +public static class Marshaller +{ + public struct Native { } + + + public struct M + { + public const int BufferSize = 0x100; + public void FromManaged(S s) {} + public void FromManaged(S s, System.Span buffer) {} + public Native ToUnmanaged() => default; + public void FromUnmanaged(Native n) {} + public S ToManaged() => default; + } +} +"; + public static string ManagedToNativeOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") + + NonBlittableUserDefinedType() + + In; + + public static string NativeToManagedOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") + + NonBlittableUserDefinedType() + + Out; + + public static string NativeToManagedGuaranteedOnlyOutParameter => BasicParameterWithByRefModifier("out", "S") + + NonBlittableUserDefinedType() + + OutGuaranteed; + + public static string ManagedToNativeOnlyReturnValue => BasicReturnType("S") + + NonBlittableUserDefinedType() + + In; + + public static string NativeToManagedOnlyReturnValue => BasicReturnType("S") + + NonBlittableUserDefinedType() + + Out; + + public static string NativeToManagedGuaranteedOnlyReturnValue => BasicReturnType("S") + + NonBlittableUserDefinedType() + + Out; + + public static string NativeToManagedOnlyInParameter => BasicParameterWithByRefModifier("in", "S") + + NonBlittableUserDefinedType() + + Out; + + public static string ParametersAndModifiers = BasicParametersAndModifiers("S", UsingSystemRuntimeInteropServicesMarshalling) + + NonBlittableUserDefinedType(defineNativeMarshalling: true) + + Ref; + + public static string ParametersAndModifiersWithFree = BasicParametersAndModifiers("S", UsingSystemRuntimeInteropServicesMarshalling) + + NonBlittableUserDefinedType(defineNativeMarshalling: true) + + RefWithFree; + + public static string ParametersAndModifiersWithNotifyInvokeSucceeded = BasicParametersAndModifiers("S", UsingSystemRuntimeInteropServicesMarshalling) + + NonBlittableUserDefinedType(defineNativeMarshalling: true) + + RefWithNotifyInvokeSucceeded; + + public static string MarshalUsingParametersAndModifiers = MarshalUsingParametersAndModifiers("S", "Marshaller") + + NonBlittableUserDefinedType(defineNativeMarshalling: false) + + Ref; + + public static string StackallocByValueInParameter => BasicParameterByValue("S") + + NonBlittableUserDefinedType() + + InBuffer; + + public static string StackallocParametersAndModifiersNoRef = BasicParametersAndModifiersNoRef("S") + + NonBlittableUserDefinedType() + + RefBuffer; + + public static string StackallocOnlyRefParameter = BasicParameterWithByRefModifier("ref", "S") + + NonBlittableUserDefinedType() + + RefBuffer; + + public static string OptionalStackallocParametersAndModifiers = BasicParametersAndModifiers("S", UsingSystemRuntimeInteropServicesMarshalling) + + NonBlittableUserDefinedType() + + RefOptionalBuffer; + } } public static class CustomStructMarshalling_V1 diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs index 62de36e5448aa..935980dbc7395 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs @@ -92,11 +92,15 @@ public static IEnumerable CodeSnippetsToCompile() yield return new object[] { CodeSnippets.MarshalUsingArrayParameterWithSizeParam(isByRef: false), 2, 0 }; // Custom type marshalling with invalid members - yield return new object[] { CodeSnippets.CustomStructMarshalling.ManagedToNativeOnlyOutParameter, 1, 0 }; - yield return new object[] { CodeSnippets.CustomStructMarshalling.ManagedToNativeOnlyReturnValue, 1, 0 }; - yield return new object[] { CodeSnippets.CustomStructMarshalling.NativeToManagedOnlyInParameter, 1, 0 }; yield return new object[] { CodeSnippets.CustomStructMarshalling.NonStaticMarshallerEntryPoint, 2, 0 }; - yield return new object[] { CodeSnippets.CustomStructMarshalling.StackallocOnlyRefParameter, 1, 0 }; + yield return new object[] { CodeSnippets.CustomStructMarshalling.Stateless.ManagedToNativeOnlyOutParameter, 1, 0 }; + yield return new object[] { CodeSnippets.CustomStructMarshalling.Stateless.ManagedToNativeOnlyReturnValue, 1, 0 }; + yield return new object[] { CodeSnippets.CustomStructMarshalling.Stateless.NativeToManagedOnlyInParameter, 1, 0 }; + yield return new object[] { CodeSnippets.CustomStructMarshalling.Stateless.StackallocOnlyRefParameter, 1, 0 }; + yield return new object[] { CodeSnippets.CustomStructMarshalling.Stateful.ManagedToNativeOnlyOutParameter, 1, 0 }; + yield return new object[] { CodeSnippets.CustomStructMarshalling.Stateful.ManagedToNativeOnlyReturnValue, 1, 0 }; + yield return new object[] { CodeSnippets.CustomStructMarshalling.Stateful.NativeToManagedOnlyInParameter, 1, 0 }; + yield return new object[] { CodeSnippets.CustomStructMarshalling.Stateful.StackallocOnlyRefParameter, 1, 0 }; yield return new object[] { CodeSnippets.CustomStructMarshalling_V1.TwoStageRefReturn, 3, 0 }; yield return new object[] { CodeSnippets.CustomStructMarshalling_V1.ManagedToNativeOnlyOutParameter, 1, 0 }; yield return new object[] { CodeSnippets.CustomStructMarshalling_V1.ManagedToNativeOnlyReturnValue, 1, 0 }; diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs index 202d8a782d008..725aba057ce43 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs @@ -170,15 +170,26 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.SafeHandleWithCustomDefaultConstructorAccessibility(privateCtor: true) }; // Custom type marshalling - yield return new[] { CodeSnippets.CustomStructMarshalling.ParametersAndModifiers }; - yield return new[] { CodeSnippets.CustomStructMarshalling.MarshalUsingParametersAndModifiers }; - yield return new[] { CodeSnippets.CustomStructMarshalling.NativeToManagedOnlyOutParameter }; - yield return new[] { CodeSnippets.CustomStructMarshalling.NativeToManagedGuaranteedOnlyOutParameter }; - yield return new[] { CodeSnippets.CustomStructMarshalling.NativeToManagedOnlyReturnValue }; - yield return new[] { CodeSnippets.CustomStructMarshalling.NativeToManagedGuaranteedOnlyReturnValue }; - yield return new[] { CodeSnippets.CustomStructMarshalling.StackallocByValueInParameter }; - yield return new[] { CodeSnippets.CustomStructMarshalling.StackallocParametersAndModifiersNoRef }; - yield return new[] { CodeSnippets.CustomStructMarshalling.OptionalStackallocParametersAndModifiers }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.ParametersAndModifiers }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.MarshalUsingParametersAndModifiers }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.NativeToManagedOnlyOutParameter }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.NativeToManagedGuaranteedOnlyOutParameter }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.NativeToManagedOnlyReturnValue }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.NativeToManagedGuaranteedOnlyReturnValue }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.StackallocByValueInParameter }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.StackallocParametersAndModifiersNoRef }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateless.OptionalStackallocParametersAndModifiers }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.ParametersAndModifiers }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.ParametersAndModifiersWithFree }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.ParametersAndModifiersWithNotifyInvokeSucceeded }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.MarshalUsingParametersAndModifiers }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.NativeToManagedOnlyOutParameter }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.NativeToManagedGuaranteedOnlyOutParameter }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.NativeToManagedOnlyReturnValue }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.NativeToManagedGuaranteedOnlyReturnValue }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.StackallocByValueInParameter }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.StackallocParametersAndModifiersNoRef }; + yield return new[] { CodeSnippets.CustomStructMarshalling.Stateful.OptionalStackallocParametersAndModifiers }; yield return new[] { CodeSnippets.CustomStructMarshalling_V1.ParametersAndModifiers }; yield return new[] { CodeSnippets.CustomStructMarshalling_V1.StackallocParametersAndModifiersNoRef }; yield return new[] { CodeSnippets.CustomStructMarshalling_V1.StackallocTwoStageParametersAndModifiersNoRef }; diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs index 609ef8c8ab5d2..f9f34a6e4dec5 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs @@ -9,6 +9,7 @@ using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; using System.Text; +using static SharedTypes.IntWrapperWithNotificationMarshaller; namespace SharedTypes { @@ -164,4 +165,121 @@ public static void Free(int* unmanaged) Marshal.FreeCoTaskMem((IntPtr)unmanaged); } } + + + [CustomMarshaller(typeof(IntWrapper), Scenario.Default, typeof(Marshaller))] + public static unsafe class IntWrapperMarshallerStateful + { + public struct Marshaller + { + private IntWrapper managed; + private int* native; + public void FromManaged(IntWrapper wrapper) + { + managed = wrapper; + } + + public int* ToUnmanaged() + { + native = (int*)Marshal.AllocCoTaskMem(sizeof(int)); + *native = managed.i; + return native; + } + + public void FromUnmanaged(int* value) + { + native = value; + } + + public IntWrapper ToManaged() => managed = new IntWrapper() { i = *native }; + + public void Free() + { + Marshal.FreeCoTaskMem((IntPtr)native); + } + } + } + + [NativeMarshalling(typeof(IntWrapperWithNotificationMarshaller))] + public struct IntWrapperWithNotification + { + [ThreadStatic] + public static int NumInvokeSucceededOnUninitialized = 0; + + private bool initialized; + public int Value; + public event EventHandler InvokeSucceeded; + + public IntWrapperWithNotification() + { + initialized = true; + } + + public void RaiseInvokeSucceeded() + { + if (!initialized) + { + NumInvokeSucceededOnUninitialized++; + } + InvokeSucceeded?.Invoke(this, EventArgs.Empty); + } + } + + [CustomMarshaller(typeof(IntWrapperWithNotification), Scenario.Default, typeof(Marshaller))] + public static class IntWrapperWithNotificationMarshaller + { + public struct Marshaller + { + private IntWrapperWithNotification _managed; + + public void FromManaged(IntWrapperWithNotification managed) =>_managed = managed; + + public int ToUnmanaged() => _managed.Value; + + public void FromUnmanaged(int i) => _managed.Value = i; + + public IntWrapperWithNotification ToManaged() => _managed; + + public void NotifyInvokeSucceeded() => _managed.RaiseInvokeSucceeded(); + } + } + + [CustomMarshaller(typeof(BoolStruct), Scenario.Default, typeof(Marshaller))] + public static class BoolStructMarshallerStateful + { + public struct BoolStructNative + { + public byte b1; + public byte b2; + public byte b3; + } + + public struct Marshaller + { + private BoolStructNative _boolStructNative; + public void FromManaged(BoolStruct managed) + { + _boolStructNative = new BoolStructNative + { + b1 = (byte)(managed.b1 ? 1 : 0), + b2 = (byte)(managed.b2 ? 1 : 0), + b3 = (byte)(managed.b3 ? 1 : 0) + }; + } + + public BoolStructNative ToUnmanaged() => _boolStructNative; + + public void FromUnmanaged(BoolStructNative value) => _boolStructNative = value; + + public BoolStruct ToManaged() + { + return new BoolStruct + { + b1 = _boolStructNative.b1 != 0, + b2 = _boolStructNative.b2 != 0, + b3 = _boolStructNative.b3 != 0 + }; + } + } + } }