Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Structs that represent COM interfaces prefer non-PreserveSig #793

Merged
merged 1 commit into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
93 changes: 72 additions & 21 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace Microsoft.Windows.CsWin32;

public partial class Generator
{
private static readonly IdentifierNameSyntax HRThrowOnFailureMethodName = IdentifierName("ThrowOnFailure");
private readonly HashSet<string> injectedPInvokeHelperMethodsToFriendlyOverloadsExtensions = new();

private static Guid DecodeGuidFromAttribute(CustomAttribute guidAttribute)
Expand Down Expand Up @@ -117,13 +118,13 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type

MethodSignature<TypeHandleInfo> signature = methodDefinition.Method.DecodeSignature(SignatureHandleProvider.Instance, null);
CustomAttributeHandleCollection? returnTypeAttributes = methodDefinition.Generator.GetReturnTypeCustomAttributes(methodDefinition.Method);
TypeSyntaxAndMarshaling returnType = signature.ReturnType.ToTypeSyntax(typeSettings, returnTypeAttributes);
TypeSyntax returnType = signature.ReturnType.ToTypeSyntax(typeSettings, returnTypeAttributes).Type;

ParameterListSyntax parameterList = methodDefinition.Generator.CreateParameterList(methodDefinition.Method, signature, typeSettings);
FunctionPointerParameterListSyntax funcPtrParameters = FunctionPointerParameterList()
.AddParameters(FunctionPointerParameter(PointerType(ifaceName)))
.AddParameters(parameterList.Parameters.Select(p => FunctionPointerParameter(p.Type!).WithModifiers(p.Modifiers)).ToArray())
.AddParameters(FunctionPointerParameter(returnType.Type));
.AddParameters(FunctionPointerParameter(returnType));

TypeSyntax unmanagedDelegateType = FunctionPointerType().WithCallingConvention(
FunctionPointerCallingConvention(TokenWithSpace(SyntaxKind.UnmanagedKeyword))
Expand All @@ -139,6 +140,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
// Build up an unmanaged delegate cast directly from the vtbl pointer and invoke it.
// By doing this, we make the emitted code more trimmable by not referencing the full virtual method table and its full set of types
// when the app may only invoke a subset of the methods.
//// ((delegate *unmanaged [Stdcall]<IPersist*,global::System.Guid* ,winmdroot.Foundation.HRESULT>)lpVtbl[3])(pThis, pClassID)
IdentifierNameSyntax pThisLocal = IdentifierName("pThis");
ExpressionSyntax vtblIndexingExpression = ParenthesizedExpression(
CastExpression(unmanagedDelegateType, ElementAccessExpression(vtblFieldName).AddArgumentListArguments(Argument(methodOffset))));
Expand All @@ -156,7 +158,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
declaredProperties.Contains(propertyName.Identifier.ValueText))
{
StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, IdentifierName("ThrowOnFailure")),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, HRThrowOnFailureMethodName),
ArgumentList()));

BlockSyntax? body;
Expand Down Expand Up @@ -227,27 +229,73 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
}
else
{
StatementSyntax vtblInvocationStatement = IsVoid(returnType.Type)
? ExpressionStatement(vtblInvocation)
: ReturnStatement(vtblInvocation);
BlockSyntax? body = Block().AddStatements(
FixedStatement(
VariableDeclaration(PointerType(ifaceName)).AddVariables(
VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))),
vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));
StatementSyntax fixedBody;
bool preserveSig = this.UsePreserveSigForComMethod(methodDefinition.Method, signature, ifaceName.Identifier.ValueText, methodName);
if (preserveSig)
{
// return ...
fixedBody = IsVoid(returnType)
? ExpressionStatement(vtblInvocation)
: ReturnStatement(vtblInvocation);
}
else
{
// hrReturningInvocation().ThrowOnFailure();
StatementSyntax InvokeVtblAndThrow() => ExpressionStatement(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, vtblInvocation, HRThrowOnFailureMethodName),
ArgumentList()));

ParameterSyntax? lastParameter = parameterList.Parameters.Count > 0 ? parameterList.Parameters[parameterList.Parameters.Count - 1] : null;
if (lastParameter?.HasAnnotation(IsRetValAnnotation) is true)
{
// Move the retval parameter to the return value position.
parameterList = parameterList.WithParameters(parameterList.Parameters.RemoveAt(parameterList.Parameters.Count - 1));
returnType = lastParameter.Modifiers.Any(SyntaxKind.OutKeyword) ? lastParameter.Type! : ((PointerTypeSyntax)lastParameter.Type!).ElementType;

// Guid __retVal = default(Guid);
IdentifierNameSyntax retValLocalName = IdentifierName("__retVal");
LocalDeclarationStatementSyntax localRetValDecl = LocalDeclarationStatement(VariableDeclaration(returnType).AddVariables(
VariableDeclarator(retValLocalName.Identifier).WithInitializer(EqualsValueClause(DefaultExpression(returnType)))));

// Modify the vtbl invocation's last argument to point to our own local variable.
ArgumentSyntax lastArgument = lastParameter.Modifiers.Any(SyntaxKind.OutKeyword)
? Argument(retValLocalName).WithRefKindKeyword(TokenWithSpace(SyntaxKind.OutKeyword))
: Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, retValLocalName));
vtblInvocation = vtblInvocation.WithArgumentList(
vtblInvocation.ArgumentList.WithArguments(vtblInvocation.ArgumentList.Arguments.Replace(vtblInvocation.ArgumentList.Arguments.Last(), lastArgument)));

// return __retVal;
ReturnStatementSyntax returnStatement = ReturnStatement(retValLocalName);

fixedBody = Block().AddStatements(localRetValDecl, InvokeVtblAndThrow(), returnStatement);
}
else
{
// Remove the return type
returnType = PredefinedType(Token(SyntaxKind.VoidKeyword));

fixedBody = InvokeVtblAndThrow();
}
}

// fixed (IPersist* pThis = &this)
FixedStatementSyntax fixedStatement = FixedStatement(
VariableDeclaration(PointerType(ifaceName)).AddVariables(
VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))),
fixedBody).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword));
BlockSyntax body = Block().AddStatements(fixedStatement);

methodDeclaration = MethodDeclaration(
List<AttributeListSyntax>(),
modifiers: TokenList(TokenWithSpace(SyntaxKind.PublicKeyword)), // always use public so struct can implement the COM interface
returnType.Type.WithTrailingTrivia(TriviaList(Space)),
returnType.WithTrailingTrivia(TriviaList(Space)),
explicitInterfaceSpecifier: null!,
SafeIdentifier(methodName),
null!,
parameterList,
List<TypeParameterConstraintClauseSyntax>(),
body: body,
semicolonToken: default);
methodDeclaration = returnType.AddReturnMarshalAs(methodDeclaration);

if (methodName == nameof(object.GetType) && parameterList.Parameters.Count == 0)
{
Expand Down Expand Up @@ -428,14 +476,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type

ParameterListSyntax? parameterList = this.CreateParameterList(methodDefinition, signature, this.comSignatureTypeSettings);

bool preserveSig = interfaceAsSubtype
|| !IsHresult(signature.ReturnType)
|| (methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnMultipleSuccessValuesAttribute) is not null
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnErrorsAsSuccessAttribute) is not null
|| this.options.ComInterop.PreserveSigMethods.Contains($"{ifaceName}.{methodName}")
|| this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString());

bool preserveSig = interfaceAsSubtype || this.UsePreserveSigForComMethod(methodDefinition, signature, actualIfaceName, methodName);
if (!preserveSig)
{
ParameterSyntax? lastParameter = parameterList.Parameters.Count > 0 ? parameterList.Parameters[parameterList.Parameters.Count - 1] : null;
Expand Down Expand Up @@ -614,6 +655,16 @@ private unsafe (List<MemberDeclarationSyntax> Members, List<BaseTypeSyntax> Base
return (members, baseTypes);
}

private bool UsePreserveSigForComMethod(MethodDefinition methodDefinition, MethodSignature<TypeHandleInfo> signature, string ifaceName, string methodName)
{
return !IsHresult(signature.ReturnType)
|| (methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnMultipleSuccessValuesAttribute) is not null
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnErrorsAsSuccessAttribute) is not null
|| this.options.ComInterop.PreserveSigMethods.Contains($"{ifaceName}.{methodName}")
|| this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString());
}

private ISet<string> GetDeclarableProperties(IEnumerable<MethodDefinition> methods, bool allowNonConsecutiveAccessors)
{
Dictionary<string, (TypeSyntax Type, int Index)> goodProperties = new(StringComparer.Ordinal);
Expand Down
14 changes: 14 additions & 0 deletions test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,23 @@

#pragma warning disable CA1812 // dead code

using Windows.Win32.Foundation;
using Windows.Win32.System.Com;
using Windows.Win32.System.Com.Events;

/// <summary>
/// Contains "tests" that never run. Merely compiling is enough to verify the generated code has the right API shape.
/// </summary>
internal static unsafe class GeneratedForm
{
private static unsafe void COMStructsPreserveSig()
{
IEventSubscription o = default;

// Default is non-preservesig
VARIANT v = o.GetPublisherProperty(null);

// NativeMethods.json opts into PreserveSig for this particular method.
HRESULT hr = o.GetSubscriberProperty(null, out v);
}
}
7 changes: 6 additions & 1 deletion test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,10 @@
"$schema": "..\\..\\src\\Microsoft.Windows.CsWin32\\settings.schema.json",
"emitSingleFile": true,
"multiTargetingFriendlyAPIs": true,
"allowMarshaling": false
"allowMarshaling": false,
"comInterop": {
"preserveSigMethods": [
"IEventSubscription.GetSubscriberProperty"
]
}
}
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
IPersistFile
IEventSubscription
1 change: 1 addition & 0 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ public void SupportedOSPlatform_AppearsOnFriendlyOverloads()
"PZZWSTR",
"PCZZSTR",
"PCZZWSTR",
"IEventSubscription",
"IRealTimeStylusSynchronization", // uses the `lock` C# keyword.
"IHTMLInputElement", // has a field named `checked`, a C# keyword.
"NCryptImportKey", // friendly overload takes SafeHandle backed by a UIntPtr instead of IntPtr
Expand Down