From 8ec3156bd37c984f2a08d94019046c50dab2aae3 Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Fri, 18 Nov 2022 11:38:30 -0700 Subject: [PATCH] Structs that represent COM interfaces prefer non-PreserveSig Fixes #787 --- .../Generator.Com.cs | 93 ++++++++++++++----- .../GeneratedForm.cs | 14 +++ .../NativeMethods.json | 7 +- .../NativeMethods.txt | 1 + .../GeneratorTests.cs | 1 + 5 files changed, 94 insertions(+), 22 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/Generator.Com.cs b/src/Microsoft.Windows.CsWin32/Generator.Com.cs index 503c65e7..dbd504d2 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Com.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Com.cs @@ -5,6 +5,7 @@ namespace Microsoft.Windows.CsWin32; public partial class Generator { + private static readonly IdentifierNameSyntax HRThrowOnFailureMethodName = IdentifierName("ThrowOnFailure"); private readonly HashSet injectedPInvokeHelperMethodsToFriendlyOverloadsExtensions = new(); private static Guid DecodeGuidFromAttribute(CustomAttribute guidAttribute) @@ -117,13 +118,13 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type MethodSignature signature = methodDefinition.Method.DecodeSignature(SignatureHandleProvider.Instance, null); CustomAttributeHandleCollection? returnTypeAttributes = methodDefinition.Generator.GetReturnTypeCustomAttributes(methodDefinition.Method); - TypeSyntaxAndMarshaling returnType = signature.ReturnType.ToTypeSyntax(typeSettings, returnTypeAttributes); + TypeSyntax returnType = signature.ReturnType.ToTypeSyntax(typeSettings, returnTypeAttributes).Type; ParameterListSyntax parameterList = methodDefinition.Generator.CreateParameterList(methodDefinition.Method, signature, typeSettings); FunctionPointerParameterListSyntax funcPtrParameters = FunctionPointerParameterList() .AddParameters(FunctionPointerParameter(PointerType(ifaceName))) .AddParameters(parameterList.Parameters.Select(p => FunctionPointerParameter(p.Type!).WithModifiers(p.Modifiers)).ToArray()) - .AddParameters(FunctionPointerParameter(returnType.Type)); + .AddParameters(FunctionPointerParameter(returnType)); TypeSyntax unmanagedDelegateType = FunctionPointerType().WithCallingConvention( FunctionPointerCallingConvention(TokenWithSpace(SyntaxKind.UnmanagedKeyword)) @@ -139,6 +140,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type // Build up an unmanaged delegate cast directly from the vtbl pointer and invoke it. // By doing this, we make the emitted code more trimmable by not referencing the full virtual method table and its full set of types // when the app may only invoke a subset of the methods. + //// ((delegate *unmanaged [Stdcall])lpVtbl[3])(pThis, pClassID) IdentifierNameSyntax pThisLocal = IdentifierName("pThis"); ExpressionSyntax vtblIndexingExpression = ParenthesizedExpression( CastExpression(unmanagedDelegateType, ElementAccessExpression(vtblFieldName).AddArgumentListArguments(Argument(methodOffset)))); @@ -156,7 +158,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type declaredProperties.Contains(propertyName.Identifier.ValueText)) { StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, IdentifierName("ThrowOnFailure")), + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, HRThrowOnFailureMethodName), ArgumentList())); BlockSyntax? body; @@ -227,19 +229,66 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type } else { - StatementSyntax vtblInvocationStatement = IsVoid(returnType.Type) - ? ExpressionStatement(vtblInvocation) - : ReturnStatement(vtblInvocation); - BlockSyntax? body = Block().AddStatements( - FixedStatement( - VariableDeclaration(PointerType(ifaceName)).AddVariables( - VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), - vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); + StatementSyntax fixedBody; + bool preserveSig = this.UsePreserveSigForComMethod(methodDefinition.Method, signature, ifaceName.Identifier.ValueText, methodName); + if (preserveSig) + { + // return ... + fixedBody = IsVoid(returnType) + ? ExpressionStatement(vtblInvocation) + : ReturnStatement(vtblInvocation); + } + else + { + // hrReturningInvocation().ThrowOnFailure(); + StatementSyntax InvokeVtblAndThrow() => ExpressionStatement(InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, vtblInvocation, HRThrowOnFailureMethodName), + ArgumentList())); + + ParameterSyntax? lastParameter = parameterList.Parameters.Count > 0 ? parameterList.Parameters[parameterList.Parameters.Count - 1] : null; + if (lastParameter?.HasAnnotation(IsRetValAnnotation) is true) + { + // Move the retval parameter to the return value position. + parameterList = parameterList.WithParameters(parameterList.Parameters.RemoveAt(parameterList.Parameters.Count - 1)); + returnType = lastParameter.Modifiers.Any(SyntaxKind.OutKeyword) ? lastParameter.Type! : ((PointerTypeSyntax)lastParameter.Type!).ElementType; + + // Guid __retVal = default(Guid); + IdentifierNameSyntax retValLocalName = IdentifierName("__retVal"); + LocalDeclarationStatementSyntax localRetValDecl = LocalDeclarationStatement(VariableDeclaration(returnType).AddVariables( + VariableDeclarator(retValLocalName.Identifier).WithInitializer(EqualsValueClause(DefaultExpression(returnType))))); + + // Modify the vtbl invocation's last argument to point to our own local variable. + ArgumentSyntax lastArgument = lastParameter.Modifiers.Any(SyntaxKind.OutKeyword) + ? Argument(retValLocalName).WithRefKindKeyword(TokenWithSpace(SyntaxKind.OutKeyword)) + : Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, retValLocalName)); + vtblInvocation = vtblInvocation.WithArgumentList( + vtblInvocation.ArgumentList.WithArguments(vtblInvocation.ArgumentList.Arguments.Replace(vtblInvocation.ArgumentList.Arguments.Last(), lastArgument))); + + // return __retVal; + ReturnStatementSyntax returnStatement = ReturnStatement(retValLocalName); + + fixedBody = Block().AddStatements(localRetValDecl, InvokeVtblAndThrow(), returnStatement); + } + else + { + // Remove the return type + returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); + + fixedBody = InvokeVtblAndThrow(); + } + } + + // fixed (IPersist* pThis = &this) + FixedStatementSyntax fixedStatement = FixedStatement( + VariableDeclaration(PointerType(ifaceName)).AddVariables( + VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), + fixedBody).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)); + BlockSyntax body = Block().AddStatements(fixedStatement); methodDeclaration = MethodDeclaration( List(), modifiers: TokenList(TokenWithSpace(SyntaxKind.PublicKeyword)), // always use public so struct can implement the COM interface - returnType.Type.WithTrailingTrivia(TriviaList(Space)), + returnType.WithTrailingTrivia(TriviaList(Space)), explicitInterfaceSpecifier: null!, SafeIdentifier(methodName), null!, @@ -247,7 +296,6 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type List(), body: body, semicolonToken: default); - methodDeclaration = returnType.AddReturnMarshalAs(methodDeclaration); if (methodName == nameof(object.GetType) && parameterList.Parameters.Count == 0) { @@ -428,14 +476,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type ParameterListSyntax? parameterList = this.CreateParameterList(methodDefinition, signature, this.comSignatureTypeSettings); - bool preserveSig = interfaceAsSubtype - || !IsHresult(signature.ReturnType) - || (methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig - || this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnMultipleSuccessValuesAttribute) is not null - || this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnErrorsAsSuccessAttribute) is not null - || this.options.ComInterop.PreserveSigMethods.Contains($"{ifaceName}.{methodName}") - || this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString()); - + bool preserveSig = interfaceAsSubtype || this.UsePreserveSigForComMethod(methodDefinition, signature, actualIfaceName, methodName); if (!preserveSig) { ParameterSyntax? lastParameter = parameterList.Parameters.Count > 0 ? parameterList.Parameters[parameterList.Parameters.Count - 1] : null; @@ -614,6 +655,16 @@ private unsafe (List Members, List Base return (members, baseTypes); } + private bool UsePreserveSigForComMethod(MethodDefinition methodDefinition, MethodSignature signature, string ifaceName, string methodName) + { + return !IsHresult(signature.ReturnType) + || (methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig + || this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnMultipleSuccessValuesAttribute) is not null + || this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnErrorsAsSuccessAttribute) is not null + || this.options.ComInterop.PreserveSigMethods.Contains($"{ifaceName}.{methodName}") + || this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString()); + } + private ISet GetDeclarableProperties(IEnumerable methods, bool allowNonConsecutiveAccessors) { Dictionary goodProperties = new(StringComparer.Ordinal); diff --git a/test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs b/test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs index 03dbd16c..ba58436a 100644 --- a/test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs +++ b/test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs @@ -3,9 +3,23 @@ #pragma warning disable CA1812 // dead code +using Windows.Win32.Foundation; +using Windows.Win32.System.Com; +using Windows.Win32.System.Com.Events; + /// /// Contains "tests" that never run. Merely compiling is enough to verify the generated code has the right API shape. /// internal static unsafe class GeneratedForm { + private static unsafe void COMStructsPreserveSig() + { + IEventSubscription o = default; + + // Default is non-preservesig + VARIANT v = o.GetPublisherProperty(null); + + // NativeMethods.json opts into PreserveSig for this particular method. + HRESULT hr = o.GetSubscriberProperty(null, out v); + } } diff --git a/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.json b/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.json index 11e28dd2..d22036ed 100644 --- a/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.json +++ b/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.json @@ -2,5 +2,10 @@ "$schema": "..\\..\\src\\Microsoft.Windows.CsWin32\\settings.schema.json", "emitSingleFile": true, "multiTargetingFriendlyAPIs": true, - "allowMarshaling": false + "allowMarshaling": false, + "comInterop": { + "preserveSigMethods": [ + "IEventSubscription.GetSubscriberProperty" + ] + } } diff --git a/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt b/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt index c110008a..0e0b9e49 100644 --- a/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt +++ b/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt @@ -1 +1,2 @@ IPersistFile +IEventSubscription diff --git a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs index fba074ee..7d1d5e8a 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs @@ -152,6 +152,7 @@ public void SupportedOSPlatform_AppearsOnFriendlyOverloads() "PZZWSTR", "PCZZSTR", "PCZZWSTR", + "IEventSubscription", "IRealTimeStylusSynchronization", // uses the `lock` C# keyword. "IHTMLInputElement", // has a field named `checked`, a C# keyword. "NCryptImportKey", // friendly overload takes SafeHandle backed by a UIntPtr instead of IntPtr