Skip to content

Commit

Permalink
Merge pull request #829 from microsoft/fix751
Browse files Browse the repository at this point in the history
Many more CCW and other fixes
  • Loading branch information
AArnott committed Dec 10, 2022
2 parents d1c2fd7 + b0d8b7f commit 76e706e
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 55 deletions.
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
<PackageVersion Include="System.Runtime.CompilerServices.Unsafe" Version="6.0.0" />
<PackageVersion Include="System.Text.Encodings.Web" Version="4.7.2" />
<PackageVersion Include="System.Text.Json" Version="4.7.2" />
<PackageVersion Include="Xunit.Combinatorial" Version="1.5.25" />
<PackageVersion Include="Xunit.Combinatorial" Version="1.6.12-alpha" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.4.5" />
<PackageVersion Include="xunit" Version="2.4.2" />
</ItemGroup>
Expand Down
92 changes: 60 additions & 32 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
var vtblMembers = new List<MemberDeclarationSyntax>();
TypeSyntaxSettings typeSettings = this.comSignatureTypeSettings;
IdentifierNameSyntax pThisLocal = IdentifierName("pThis");
ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" && !this.IsNonCOMInterface(typeDef) ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
List<QualifiedMethodDefinitionHandle> ccwMethodsToSkip = new();
List<MemberDeclarationSyntax> ccwEntrypointMethods = new();
IdentifierNameSyntax vtblParamName = IdentifierName("vtable");
BlockSyntax populateVTableBody = Block();
IdentifierNameSyntax objectLocal = IdentifierName("__object");
Expand All @@ -119,7 +120,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type

// We do *not* emit CCW methods for IUnknown, because those are provided by ComWrappers.
if (ccwThisParameter is not null &&
(qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IUnknown") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IDispatch")))
(qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IUnknown") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IDispatch") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IInspectable")))
{
ccwMethodsToSkip.AddRange(methodsThisType);
}
Expand All @@ -132,6 +133,8 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)),
originalIfaceName,
allowNonConsecutiveAccessors: true);
ISet<string>? ifaceDeclaredProperties = ccwThisParameter is not null ? this.GetDeclarableProperties(allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)), originalIfaceName, allowNonConsecutiveAccessors: false) : null;

foreach (QualifiedMethodDefinitionHandle methodDefHandle in allMethods)
{
methodCounter++;
Expand All @@ -147,6 +150,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type

ParameterListSyntax parameterList = methodDefinition.Generator.CreateParameterList(methodDefinition.Method, signature, typeSettings);
ParameterListSyntax parameterListPreserveSig = parameterList; // preserve a copy that has no mutations.
bool requiresMarshaling = parameterList.Parameters.Any(p => p.AttributeLists.SelectMany(al => al.Attributes).Any(a => a.Name is IdentifierNameSyntax { Identifier.ValueText: "MarshalAs" }) || p.Modifiers.Any(SyntaxKind.RefKeyword) || p.Modifiers.Any(SyntaxKind.OutKeyword) || p.Modifiers.Any(SyntaxKind.InKeyword));
FunctionPointerParameterListSyntax funcPtrParameters = FunctionPointerParameterList()
.AddParameters(FunctionPointerParameter(PointerType(ifaceName)))
.AddParameters(parameterList.Parameters.Select(p => FunctionPointerParameter(p.Type!).WithModifiers(p.Modifiers)).ToArray())
Expand Down Expand Up @@ -174,7 +178,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
.AddArguments(Argument(pThisLocal))
.AddArguments(parameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText)).WithRefKindKeyword(p.Modifiers.Count > 0 ? p.Modifiers[0] : default)).ToArray())));

MemberDeclarationSyntax propertyOrMethod;
MemberDeclarationSyntax? propertyOrMethod;
MethodDeclarationSyntax? methodDeclaration = null;

// We can declare this method as a property accessor if it represents a property.
Expand Down Expand Up @@ -212,18 +216,6 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
resultLocalDeclaration,
vtblInvocationStatement,
returnStatement)).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));

if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
{
//// *inputArg = @object.Property;
StatementSyntax propertyGet = ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName)));
this.TryGenerateConstantOrThrow("S_OK");
AddCcwThunk(propertyGet, returnSOK);
}

break;
case SyntaxKind.SetAccessorDeclaration:
// vtblInvoke(pThis, value).ThrowOnFailure();
Expand All @@ -233,18 +225,6 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
VariableDeclaration(PointerType(ifaceName)).AddVariables(
VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))),
vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));

if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
{
//// @object.Property = inputArg;
StatementSyntax propertySet = ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName),
IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)));
this.TryGenerateConstantOrThrow("S_OK");
AddCcwThunk(propertySet, returnSOK);
}

break;
default:
throw new NotSupportedException("Unsupported accessor kind: " + accessorKind);
Expand All @@ -258,7 +238,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
// Add the accessor to the existing property declaration.
PropertyDeclarationSyntax priorDeclaration = (PropertyDeclarationSyntax)members[priorPropertyDeclarationIndex];
members[priorPropertyDeclarationIndex] = priorDeclaration.WithAccessorList(priorDeclaration.AccessorList!.AddAccessors(accessor));
continue;
propertyOrMethod = null;
}
else
{
Expand Down Expand Up @@ -359,8 +339,38 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
propertyOrMethod = methodDeclaration;

members.AddRange(methodDefinition.Generator.DeclareFriendlyOverloads(methodDefinition.Method, methodDeclaration, IdentifierName(ifaceName.Identifier.ValueText), FriendlyOverloadOf.StructMethod, helperMethodsInStruct));
}

if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
{
if (this.TryGetPropertyAccessorInfo(methodDefinition.Method, originalIfaceName, out propertyName, out accessorKind, out propertyType) &&
ifaceDeclaredProperties!.Contains(propertyName.Identifier.ValueText))
{
switch (accessorKind)
{
case SyntaxKind.GetAccessorDeclaration:
//// *inputArg = @object.Property;
StatementSyntax propertyGet = ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName)));
this.TryGenerateConstantOrThrow("S_OK");
AddCcwThunk(propertyGet, returnSOK);
break;
case SyntaxKind.SetAccessorDeclaration:
//// @object.Property = inputArg;
StatementSyntax propertySet = ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName),
IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)));
this.TryGenerateConstantOrThrow("S_OK");
AddCcwThunk(propertySet, returnSOK);
break;
default:
throw new NotSupportedException("Unsupported accessor kind: " + accessorKind);
}
}
else
{
// Prepare the args for the thunk call. The Interface we thunk into *always* uses PreserveSig, which is super convenient for us.
ArgumentListSyntax args = ArgumentList().AddArguments(parameterListPreserveSig.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText))).ToArray());
Expand All @@ -385,6 +395,20 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
return;
}

if (requiresMarshaling)
{
// Oops. This method requires marshaling, which isn't supported in a native-callable function.
// Abandon all efforts to add CCW support to this interface.
ccwThisParameter = null;
foreach (MethodDeclarationSyntax ccwEntrypointMethod in ccwEntrypointMethods)
{
members.Remove(ccwEntrypointMethod);
}

ccwEntrypointMethods.Clear();
return;
}

this.RequestComHelpers(context);
bool hrReturnType = returnTypePreserveSig is QualifiedNameSyntax { Right.Identifier.ValueText: "HRESULT" };

Expand Down Expand Up @@ -434,6 +458,7 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
ccwBody,
semicolonToken: default);
members.Add(ccwMethod);
ccwEntrypointMethods.Add(ccwMethod);

populateVTableBody = populateVTableBody.AddStatements(
ExpressionStatement(AssignmentExpression(
Expand All @@ -442,9 +467,12 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
PrefixUnaryExpression(SyntaxKind.AddressOfExpression, SafeIdentifierName(methodName)))));
}

// Add documentation if we can find it.
propertyOrMethod = this.AddApiDocumentation($"{ifaceName}.{methodName}", propertyOrMethod);
members.Add(propertyOrMethod);
if (propertyOrMethod is not null)
{
// Add documentation if we can find it.
propertyOrMethod = this.AddApiDocumentation($"{ifaceName}.{methodName}", propertyOrMethod);
members.Add(propertyOrMethod);
}
}

// We expose the vtbl struct to support CCWs.
Expand Down
66 changes: 65 additions & 1 deletion src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,91 @@ namespace Microsoft.Windows.CsWin32;

internal static class SimpleSyntaxFactory
{
/// <summary>
/// C# keywords that must be escaped or changed when they appear as identifiers from metadata.
/// </summary>
/// <remarks>
/// This list comes from <see href="https://learn.microsoft.com/en-us/dotnet/csharp/language-reference/keywords/">this documentation</see>.
/// </remarks>
internal static readonly HashSet<string> CSharpKeywords = new HashSet<string>(StringComparer.Ordinal)
{
"abstract",
"as",
"base",
"bool",
"break",
"byte",
"case",
"catch",
"char",
"checked",
"class",
"const",
"continue",
"decimal",
"default",
"delegate",
"do",
"double",
"else",
"enum",
"event",
"explicit",
"extern",
"false",
"finally",
"fixed",
"float",
"for",
"foreach",
"goto",
"if",
"implicit",
"in",
"is",
"int",
"interface",
"internal",
"is",
"lock",
"long",
"namespace",
"new",
"null",
"object",
"operator",
"out",
"override",
"params",
"private",
"protected",
"public",
"readonly",
"ref",
"return",
"sbyte",
"sealed",
"short",
"sizeof",
"stackalloc",
"static",
"string",
"struct",
"switch",
"this",
"throw",
"true",
"try",
"typeof",
"uint",
"ulong",
"unchecked",
"unsafe",
"ushort",
"using",
"virtual",
"void",
"volatile",
"while",
};

internal static readonly XmlTextSyntax DocCommentStart = XmlText(" ").WithLeadingTrivia(DocumentationCommentExterior("///"));
Expand Down
8 changes: 8 additions & 0 deletions test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ private static unsafe void COMStructsPreserveSig()
o.MachineName = bstr;
}

#if NET5_0_OR_GREATER
private static unsafe void IStream_GetsCCW()
{
IStream.Vtbl vtbl;
IStream.PopulateVTable(&vtbl);
}
#endif

private static unsafe void IUnknownGetsVtbl()
{
// WinForms needs the v-table to be declared for these base interfaces.
Expand Down
5 changes: 3 additions & 2 deletions test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
IPersistFile
IEventSubscription
IEventSubscription
IPersistFile
IStream
14 changes: 14 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,20 @@ public void MethodWithHRParameter()
this.AssertNoDiagnostics();
}

[Theory]
[InlineData("IVssCreateWriterMetadata")] // A non-COM compliant interface (since it doesn't derive from IUnknown).
[InlineData("IProtectionPolicyManagerInterop3")] // An IInspectable-derived interface.
[InlineData("ICompositionCapabilitiesInteropFactory")] // An interface with managed types.
[InlineData("IPicture")] // An interface with properties that cannot be represented as properties.
public void InterestingComInterfaces(string api)
{
this.compilation = this.starterCompilations["net6.0"];
this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = false });
Assert.True(this.generator.TryGenerate(api, CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
}

[Fact]
public void ComOutPtrTypedAsOutObject()
{
Expand Down
Loading

0 comments on commit 76e706e

Please sign in to comment.