Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Many more CCW and other fixes #829

Merged
merged 9 commits into from
Dec 10, 2022
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
<PackageVersion Include="System.Runtime.CompilerServices.Unsafe" Version="6.0.0" />
<PackageVersion Include="System.Text.Encodings.Web" Version="4.7.2" />
<PackageVersion Include="System.Text.Json" Version="4.7.2" />
<PackageVersion Include="Xunit.Combinatorial" Version="1.5.25" />
<PackageVersion Include="Xunit.Combinatorial" Version="1.6.12-alpha" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.4.5" />
<PackageVersion Include="xunit" Version="2.4.2" />
</ItemGroup>
Expand Down
92 changes: 60 additions & 32 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
var vtblMembers = new List<MemberDeclarationSyntax>();
TypeSyntaxSettings typeSettings = this.comSignatureTypeSettings;
IdentifierNameSyntax pThisLocal = IdentifierName("pThis");
ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" && !this.IsNonCOMInterface(typeDef) ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
List<QualifiedMethodDefinitionHandle> ccwMethodsToSkip = new();
List<MemberDeclarationSyntax> ccwEntrypointMethods = new();
IdentifierNameSyntax vtblParamName = IdentifierName("vtable");
BlockSyntax populateVTableBody = Block();
IdentifierNameSyntax objectLocal = IdentifierName("__object");
Expand All @@ -119,7 +120,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type

// We do *not* emit CCW methods for IUnknown, because those are provided by ComWrappers.
if (ccwThisParameter is not null &&
(qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IUnknown") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IDispatch")))
(qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IUnknown") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IDispatch") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IInspectable")))
{
ccwMethodsToSkip.AddRange(methodsThisType);
}
Expand All @@ -132,6 +133,8 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)),
originalIfaceName,
allowNonConsecutiveAccessors: true);
ISet<string>? ifaceDeclaredProperties = ccwThisParameter is not null ? this.GetDeclarableProperties(allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)), originalIfaceName, allowNonConsecutiveAccessors: false) : null;

foreach (QualifiedMethodDefinitionHandle methodDefHandle in allMethods)
{
methodCounter++;
Expand All @@ -147,6 +150,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type

ParameterListSyntax parameterList = methodDefinition.Generator.CreateParameterList(methodDefinition.Method, signature, typeSettings);
ParameterListSyntax parameterListPreserveSig = parameterList; // preserve a copy that has no mutations.
bool requiresMarshaling = parameterList.Parameters.Any(p => p.AttributeLists.SelectMany(al => al.Attributes).Any(a => a.Name is IdentifierNameSyntax { Identifier.ValueText: "MarshalAs" }) || p.Modifiers.Any(SyntaxKind.RefKeyword) || p.Modifiers.Any(SyntaxKind.OutKeyword) || p.Modifiers.Any(SyntaxKind.InKeyword));
FunctionPointerParameterListSyntax funcPtrParameters = FunctionPointerParameterList()
.AddParameters(FunctionPointerParameter(PointerType(ifaceName)))
.AddParameters(parameterList.Parameters.Select(p => FunctionPointerParameter(p.Type!).WithModifiers(p.Modifiers)).ToArray())
Expand Down Expand Up @@ -174,7 +178,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
.AddArguments(Argument(pThisLocal))
.AddArguments(parameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText)).WithRefKindKeyword(p.Modifiers.Count > 0 ? p.Modifiers[0] : default)).ToArray())));

MemberDeclarationSyntax propertyOrMethod;
MemberDeclarationSyntax? propertyOrMethod;
MethodDeclarationSyntax? methodDeclaration = null;

// We can declare this method as a property accessor if it represents a property.
Expand Down Expand Up @@ -212,18 +216,6 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
resultLocalDeclaration,
vtblInvocationStatement,
returnStatement)).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));

if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
{
//// *inputArg = @object.Property;
StatementSyntax propertyGet = ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName)));
this.TryGenerateConstantOrThrow("S_OK");
AddCcwThunk(propertyGet, returnSOK);
}

break;
case SyntaxKind.SetAccessorDeclaration:
// vtblInvoke(pThis, value).ThrowOnFailure();
Expand All @@ -233,18 +225,6 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
VariableDeclaration(PointerType(ifaceName)).AddVariables(
VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))),
vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));

if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
{
//// @object.Property = inputArg;
StatementSyntax propertySet = ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName),
IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)));
this.TryGenerateConstantOrThrow("S_OK");
AddCcwThunk(propertySet, returnSOK);
}

break;
default:
throw new NotSupportedException("Unsupported accessor kind: " + accessorKind);
Expand All @@ -258,7 +238,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
// Add the accessor to the existing property declaration.
PropertyDeclarationSyntax priorDeclaration = (PropertyDeclarationSyntax)members[priorPropertyDeclarationIndex];
members[priorPropertyDeclarationIndex] = priorDeclaration.WithAccessorList(priorDeclaration.AccessorList!.AddAccessors(accessor));
continue;
propertyOrMethod = null;
}
else
{
Expand Down Expand Up @@ -359,8 +339,38 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
propertyOrMethod = methodDeclaration;

members.AddRange(methodDefinition.Generator.DeclareFriendlyOverloads(methodDefinition.Method, methodDeclaration, IdentifierName(ifaceName.Identifier.ValueText), FriendlyOverloadOf.StructMethod, helperMethodsInStruct));
}

if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
{
if (this.TryGetPropertyAccessorInfo(methodDefinition.Method, originalIfaceName, out propertyName, out accessorKind, out propertyType) &&
ifaceDeclaredProperties!.Contains(propertyName.Identifier.ValueText))
{
switch (accessorKind)
{
case SyntaxKind.GetAccessorDeclaration:
//// *inputArg = @object.Property;
StatementSyntax propertyGet = ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName)));
this.TryGenerateConstantOrThrow("S_OK");
AddCcwThunk(propertyGet, returnSOK);
break;
case SyntaxKind.SetAccessorDeclaration:
//// @object.Property = inputArg;
StatementSyntax propertySet = ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName),
IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)));
this.TryGenerateConstantOrThrow("S_OK");
AddCcwThunk(propertySet, returnSOK);
break;
default:
throw new NotSupportedException("Unsupported accessor kind: " + accessorKind);
}
}
else
{
// Prepare the args for the thunk call. The Interface we thunk into *always* uses PreserveSig, which is super convenient for us.
ArgumentListSyntax args = ArgumentList().AddArguments(parameterListPreserveSig.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText))).ToArray());
Expand All @@ -385,6 +395,20 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
return;
}

if (requiresMarshaling)
{
// Oops. This method requires marshaling, which isn't supported in a native-callable function.
// Abandon all efforts to add CCW support to this interface.
ccwThisParameter = null;
foreach (MethodDeclarationSyntax ccwEntrypointMethod in ccwEntrypointMethods)
{
members.Remove(ccwEntrypointMethod);
}

ccwEntrypointMethods.Clear();
return;
}

this.RequestComHelpers(context);
bool hrReturnType = returnTypePreserveSig is QualifiedNameSyntax { Right.Identifier.ValueText: "HRESULT" };

Expand Down Expand Up @@ -434,6 +458,7 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
ccwBody,
semicolonToken: default);
members.Add(ccwMethod);
ccwEntrypointMethods.Add(ccwMethod);

populateVTableBody = populateVTableBody.AddStatements(
ExpressionStatement(AssignmentExpression(
Expand All @@ -442,9 +467,12 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
PrefixUnaryExpression(SyntaxKind.AddressOfExpression, SafeIdentifierName(methodName)))));
}

// Add documentation if we can find it.
propertyOrMethod = this.AddApiDocumentation($"{ifaceName}.{methodName}", propertyOrMethod);
members.Add(propertyOrMethod);
if (propertyOrMethod is not null)
{
// Add documentation if we can find it.
propertyOrMethod = this.AddApiDocumentation($"{ifaceName}.{methodName}", propertyOrMethod);
members.Add(propertyOrMethod);
}
}

// We expose the vtbl struct to support CCWs.
Expand Down
66 changes: 65 additions & 1 deletion src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,91 @@ namespace Microsoft.Windows.CsWin32;

internal static class SimpleSyntaxFactory
{
/// <summary>
/// C# keywords that must be escaped or changed when they appear as identifiers from metadata.
/// </summary>
/// <remarks>
/// This list comes from <see href="https://learn.microsoft.com/en-us/dotnet/csharp/language-reference/keywords/">this documentation</see>.
/// </remarks>
internal static readonly HashSet<string> CSharpKeywords = new HashSet<string>(StringComparer.Ordinal)
{
"abstract",
"as",
"base",
"bool",
"break",
"byte",
"case",
"catch",
"char",
"checked",
"class",
"const",
"continue",
"decimal",
"default",
"delegate",
"do",
"double",
"else",
"enum",
"event",
"explicit",
"extern",
"false",
"finally",
"fixed",
"float",
"for",
"foreach",
"goto",
"if",
"implicit",
"in",
"is",
"int",
"interface",
"internal",
"is",
"lock",
"long",
"namespace",
"new",
"null",
"object",
"operator",
"out",
"override",
"params",
"private",
"protected",
"public",
"readonly",
"ref",
"return",
"sbyte",
"sealed",
"short",
"sizeof",
"stackalloc",
"static",
"string",
"struct",
"switch",
"this",
"throw",
"true",
"try",
"typeof",
"uint",
"ulong",
"unchecked",
"unsafe",
"ushort",
"using",
"virtual",
"void",
"volatile",
"while",
};

internal static readonly XmlTextSyntax DocCommentStart = XmlText(" ").WithLeadingTrivia(DocumentationCommentExterior("///"));
Expand Down
8 changes: 8 additions & 0 deletions test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ private static unsafe void COMStructsPreserveSig()
o.MachineName = bstr;
}

#if NET5_0_OR_GREATER
private static unsafe void IStream_GetsCCW()
{
IStream.Vtbl vtbl;
IStream.PopulateVTable(&vtbl);
}
#endif

private static unsafe void IUnknownGetsVtbl()
{
// WinForms needs the v-table to be declared for these base interfaces.
Expand Down
5 changes: 3 additions & 2 deletions test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
IPersistFile
IEventSubscription
IEventSubscription
IPersistFile
IStream
14 changes: 14 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,20 @@ public void MethodWithHRParameter()
this.AssertNoDiagnostics();
}

[Theory]
[InlineData("IVssCreateWriterMetadata")] // A non-COM compliant interface (since it doesn't derive from IUnknown).
[InlineData("IProtectionPolicyManagerInterop3")] // An IInspectable-derived interface.
[InlineData("ICompositionCapabilitiesInteropFactory")] // An interface with managed types.
[InlineData("IPicture")] // An interface with properties that cannot be represented as properties.
public void InterestingComInterfaces(string api)
{
this.compilation = this.starterCompilations["net6.0"];
this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = false });
Assert.True(this.generator.TryGenerate(api, CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
}

[Fact]
public void ComOutPtrTypedAsOutObject()
{
Expand Down
Loading