Skip to content

Commit

Permalink
Add IVTable interfaces to COM structs
Browse files Browse the repository at this point in the history
Closes #831
  • Loading branch information
AArnott committed Jun 27, 2023
1 parent 2c6ad76 commit fc1bb9c
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 24 deletions.
29 changes: 24 additions & 5 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ private static Guid DecodeGuidFromAttribute(CustomAttribute guidAttribute)

private static bool IsHresult(TypeHandleInfo? typeHandleInfo) => typeHandleInfo is HandleTypeHandleInfo handleInfo && handleInfo.IsType("HRESULT");

private static bool GenerateCcwFor(string interfaceName) => interfaceName is not ("IUnknown" or "IDispatch" or "IInspectable");

private static bool GenerateCcwFor(MetadataReader reader, StringHandle typeName) => !(reader.StringComparer.Equals(typeName, "IUnknown") || reader.StringComparer.Equals(typeName, "IDispatch") || reader.StringComparer.Equals(typeName, "IInspectable"));

/// <summary>
/// Generates a type to represent a COM interface.
/// </summary>
Expand Down Expand Up @@ -121,8 +125,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
allMethods.AddRange(methodsThisType);

// We do *not* emit CCW methods for IUnknown, because those are provided by ComWrappers.
if (ccwThisParameter is not null &&
(qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IUnknown") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IDispatch") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IInspectable")))
if (ccwThisParameter is not null && !GenerateCcwFor(qualifiedBaseType.Reader, baseType.Name))
{
ccwMethodsToSkip.AddRange(methodsThisType);
}
Expand Down Expand Up @@ -482,7 +485,7 @@ static ExpressionSyntax ThisPointer(PointerTypeSyntax? typedPointer = null)

if (ccwThisParameter is not null)
{
// PopulateVTable must be public in order to (implicitly) implement an interface that WinForms declares.
// PopulateVTable must be public in order to (implicitly) implement the IVTable<TComInterface, TVTable> interface.
// public static void PopulateVTable(Vtbl* vtable)
MethodDeclarationSyntax populateVtblMethodDecl = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier("PopulateVTable"))
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword))
Expand All @@ -503,7 +506,7 @@ static ExpressionSyntax ThisPointer(PointerTypeSyntax? typedPointer = null)
BaseListSyntax baseList = BaseList(SeparatedList<BaseTypeSyntax>());

CustomAttribute? guidAttribute = this.FindGuidAttribute(typeDef.GetCustomAttributes());
var staticMembers = this.DeclareStaticCOMInterfaceMembers(guidAttribute);
var staticMembers = this.DeclareStaticCOMInterfaceMembers(originalIfaceName, ifaceName, ccwThisParameter is not null, guidAttribute, context);
members.AddRange(staticMembers.Members);
baseList = baseList.AddTypes(staticMembers.BaseTypes.ToArray());

Expand Down Expand Up @@ -755,11 +758,27 @@ static ExpressionSyntax ThisPointer(PointerTypeSyntax? typedPointer = null)
return ifaceDeclaration;
}

private unsafe (List<MemberDeclarationSyntax> Members, List<BaseTypeSyntax> BaseTypes) DeclareStaticCOMInterfaceMembers(CustomAttribute? guidAttribute)
private unsafe (IReadOnlyList<MemberDeclarationSyntax> Members, IReadOnlyList<BaseTypeSyntax> BaseTypes) DeclareStaticCOMInterfaceMembers(
string originalIfaceName,
IdentifierNameSyntax ifaceName,
bool populateVtblDeclared,
CustomAttribute? guidAttribute,
Context context)
{
List<MemberDeclarationSyntax> members = new();
List<BaseTypeSyntax> baseTypes = new();

// IVTable<ComStructType, ComStructType.Vtbl>
// Static interface members require C# 11 and .NET 7 at minimum.
if (populateVtblDeclared && this.IsFeatureAvailable(Feature.InterfaceStaticMembers) && !context.AllowMarshaling && GenerateCcwFor(originalIfaceName))
{
this.RequestComHelpers(context);
baseTypes.Add(SimpleBaseType(GenericName("IVTable").AddTypeArgumentListArguments(
ifaceName,
QualifiedName(ifaceName, IdentifierName("Vtbl")))));
}

// IComIID
if (guidAttribute.HasValue)
{
Guid guidAttributeValue = DecodeGuidFromAttribute(guidAttribute.Value);
Expand Down
9 changes: 5 additions & 4 deletions src/Microsoft.Windows.CsWin32/Generator.Invariants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ public partial class Generator
private const string OriginalDelegateAnnotation = "OriginalDelegate";

private static readonly Dictionary<string, MethodDeclarationSyntax> PInvokeHelperMethods;
private static readonly ClassDeclarationSyntax ComHelperClass;
private static readonly InterfaceDeclarationSyntax IVTableInterface;
private static readonly InterfaceDeclarationSyntax IVTableGenericInterface;
private static readonly Dictionary<string, MethodDeclarationSyntax> Win32SdkMacros;

private static readonly string AutoGeneratedHeader = @"// ------------------------------------------------------------------------------
Expand Down Expand Up @@ -324,13 +325,13 @@ public partial class Generator
.Add("ULARGE_INTEGER", "Use the C# ulong keyword instead.")
.Add("OVERLAPPED", "Use System.Threading.NativeOverlapped instead.")
.Add("POINT", "Use System.Drawing.Point instead.")
.Add("POINTF", "Use System.Drawing.PointF instead.")
.Add("IUnknown", "This COM interface is implicit in the runtime. Interfaces that derive from it should apply the [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] attribute instead.")
.Add("IDispatch", "This COM interface is implicit in the runtime. Interfaces that derive from it should apply the [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] attribute instead.");
.Add("POINTF", "Use System.Drawing.PointF instead.");

/// <summary>
/// Gets a map of interop APIs that should not be generated when marshaling is allowed, and messages to emit in diagnostics if these APIs are ever directly requested.
/// </summary>
internal static ImmutableDictionary<string, string> BannedAPIsWithMarshaling { get; } = BannedAPIsWithoutMarshaling
.Add("IUnknown", "This COM interface is implicit in the runtime. Interfaces that derive from it should apply the [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] attribute instead.")
.Add("IDispatch", "This COM interface is implicit in the runtime. Interfaces that derive from it should apply the [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] attribute instead.")
.Add("VARIANT", "Use `object` instead of VARIANT when in COM interface mode. VARIANT can only be emitted when emitting COM interfaces as structs.");
}
11 changes: 11 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Templates.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ private static bool TryFetchTemplate(string name, Generator? generator, [NotNull
return true;
}

private static void FetchTemplate<T>(string name, Generator? generator, out T member)
where T : MemberDeclarationSyntax
{
if (!TryFetchTemplate(name, generator, out MemberDeclarationSyntax? localMember))
{
throw new GenerationFailedException("Missing embedded resource.");
}

member = (T)localMember;
}

private IEnumerable<MemberDeclarationSyntax> ExtractMembersFromTemplate(string name) => ((TypeDeclarationSyntax)this.FetchTemplate($"{name}")).Members;

/// <summary>
Expand Down
36 changes: 25 additions & 11 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public partial class Generator : IGenerator, IDisposable
private readonly TypeSyntaxSettings functionPointerTypeSettings;
private readonly TypeSyntaxSettings errorMessageTypeSettings;

private readonly ClassDeclarationSyntax comHelperClass;

private readonly Dictionary<string, IReadOnlyList<ISymbol>> findTypeSymbolIfAlreadyAvailableCache = new(StringComparer.Ordinal);
private readonly Rental<MetadataReader> metadataReader;
private readonly GeneratorOptions options;
Expand Down Expand Up @@ -53,12 +55,8 @@ static Generator()

Win32SdkMacros = ((ClassDeclarationSyntax)member).Members.OfType<MethodDeclarationSyntax>().ToDictionary(m => m.Identifier.ValueText, m => m);

if (!TryFetchTemplate("ComHelpers", null, out member))
{
throw new GenerationFailedException("Missing embedded resource.");
}

ComHelperClass = (ClassDeclarationSyntax)member;
FetchTemplate("IVTable", null, out IVTableInterface);
FetchTemplate("IVTable`2", null, out IVTableGenericInterface);
}

/// <summary>
Expand Down Expand Up @@ -113,6 +111,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option
AddSymbolIf(this.canUseUnsafeAsRef, "canUseUnsafeAsRef");
AddSymbolIf(this.canUseUnsafeNullRef, "canUseUnsafeNullRef");
AddSymbolIf(compilation?.GetTypeByMetadataName("System.Drawing.Point") is not null, "canUseSystemDrawing");
AddSymbolIf(this.IsFeatureAvailable(Feature.InterfaceStaticMembers), "canUseInterfaceStaticMembers");

if (extraSymbols.Count > 0)
{
Expand Down Expand Up @@ -147,10 +146,15 @@ void AddSymbolIf(bool condition, string symbol)
this.errorMessageTypeSettings = this.generalTypeSettings with { QualifyNames = true, Generator = null }; // Avoid risk of infinite recursion from errors in ToTypeSyntax

this.methodsAndConstantsClassName = IdentifierName(options.ClassName);

FetchTemplate("ComHelpers", this, out this.comHelperClass);
}

private enum Feature
{
/// <summary>
/// Indicates that interfaces can declare static members. This requires at least .NET 7 and C# 11.
/// </summary>
InterfaceStaticMembers,
}

Expand Down Expand Up @@ -794,9 +798,17 @@ internal void RequestComHelpers(Context context)
{
if (this.IsWin32Sdk)
{
const string specialType = "ComHelpers";
this.RequestInteropType("Windows.Win32.Foundation", "HRESULT", context);
this.volatileCode.GenerateSpecialType(specialType, () => this.volatileCode.AddSpecialType(specialType, ComHelperClass));
this.volatileCode.GenerateSpecialType("ComHelpers", () => this.volatileCode.AddSpecialType("ComHelpers", this.comHelperClass));
if (this.IsFeatureAvailable(Feature.InterfaceStaticMembers) && !context.AllowMarshaling)
{
this.volatileCode.GenerateSpecialType("IVTable", () => this.volatileCode.AddSpecialType("IVTable", IVTableInterface));
this.volatileCode.GenerateSpecialType("IVTable`2", () => this.volatileCode.AddSpecialType("IVTable`2", IVTableGenericInterface));
if (!this.TryGenerate("IUnknown", default))
{
throw new GenerationFailedException("Unable to generate IUnknown.");
}
}
}
else if (this.SuperGenerator is not null && this.SuperGenerator.TryGetGenerator("Windows.Win32", out Generator? generator))
{
Expand Down Expand Up @@ -865,7 +877,7 @@ internal void RequestInteropType(TypeDefinitionHandle typeDefHandle, Context con
}
}

bool hasUnmanagedName = this.HasUnmanagedSuffix(context.AllowMarshaling, this.IsManagedType(typeDefHandle));
bool hasUnmanagedName = this.HasUnmanagedSuffix(this.Reader, typeDef.Name, context.AllowMarshaling, this.IsManagedType(typeDefHandle));
this.volatileCode.GenerateType(typeDefHandle, hasUnmanagedName, delegate
{
if (this.RequestInteropTypeHelper(typeDefHandle, context) is MemberDeclarationSyntax typeDeclaration)
Expand Down Expand Up @@ -1027,10 +1039,12 @@ internal void GetBaseTypeInfo(TypeDefinition typeDef, out StringHandle baseTypeN
return specialDeclaration;
}

internal bool HasUnmanagedSuffix(bool allowMarshaling, bool isManagedType) => !allowMarshaling && isManagedType && this.options.AllowMarshaling;
internal bool HasUnmanagedSuffix(string originalName, bool allowMarshaling, bool isManagedType) => !allowMarshaling && isManagedType && this.options.AllowMarshaling && originalName is not "IUnknown";

internal bool HasUnmanagedSuffix(MetadataReader reader, StringHandle typeName, bool allowMarshaling, bool isManagedType) => !allowMarshaling && isManagedType && this.options.AllowMarshaling && !reader.StringComparer.Equals(typeName, "IUnknown");

internal string GetMangledIdentifier(string normalIdentifier, bool allowMarshaling, bool isManagedType) =>
this.HasUnmanagedSuffix(allowMarshaling, isManagedType) ? normalIdentifier + UnmanagedInteropSuffix : normalIdentifier;
this.HasUnmanagedSuffix(normalIdentifier, allowMarshaling, isManagedType) ? normalIdentifier + UnmanagedInteropSuffix : normalIdentifier;

/// <summary>
/// Disposes of managed and unmanaged resources.
Expand Down
6 changes: 4 additions & 2 deletions src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,21 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
bool isInterface;
bool isNonCOMConformingInterface;
bool isManagedType = inputs.Generator?.IsManagedType(this) ?? false;
bool hasUnmanagedSuffix = inputs.Generator?.HasUnmanagedSuffix(inputs.AllowMarshaling, isManagedType) ?? false;
string simpleNameSuffix = hasUnmanagedSuffix ? Generator.UnmanagedInteropSuffix : string.Empty;
switch (this.Handle.Kind)
{
case HandleKind.TypeDefinition:
TypeDefinition td = this.reader.GetTypeDefinition((TypeDefinitionHandle)this.Handle);
bool hasUnmanagedSuffix = inputs.Generator?.HasUnmanagedSuffix(this.reader, td.Name, inputs.AllowMarshaling, isManagedType) ?? false;
string simpleNameSuffix = hasUnmanagedSuffix ? Generator.UnmanagedInteropSuffix : string.Empty;
nameSyntax = inputs.QualifyNames ? GetNestingQualifiedName(inputs.Generator, this.reader, td, hasUnmanagedSuffix, isInterfaceNestedInStruct: false) : IdentifierName(this.reader.GetString(td.Name) + simpleNameSuffix);
isInterface = (td.Attributes & TypeAttributes.Interface) == TypeAttributes.Interface;
isNonCOMConformingInterface = isInterface && inputs.Generator?.IsNonCOMInterface(td) is true;
break;
case HandleKind.TypeReference:
var trh = (TypeReferenceHandle)this.Handle;
TypeReference tr = this.reader.GetTypeReference(trh);
hasUnmanagedSuffix = inputs.Generator?.HasUnmanagedSuffix(this.reader, tr.Name, inputs.AllowMarshaling, isManagedType) ?? false;
simpleNameSuffix = hasUnmanagedSuffix ? Generator.UnmanagedInteropSuffix : string.Empty;
nameSyntax = inputs.QualifyNames ? GetNestingQualifiedName(inputs, this.reader, tr, hasUnmanagedSuffix) : IdentifierName(this.reader.GetString(tr.Name) + simpleNameSuffix);
isInterface = inputs.Generator?.IsInterface(trh) is true;
isNonCOMConformingInterface = isInterface && inputs.Generator?.IsNonCOMInterface(trh) is true;
Expand Down
15 changes: 15 additions & 0 deletions src/Microsoft.Windows.CsWin32/templates/ComHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,19 @@
@object = ComWrappers.ComInterfaceDispatch.GetInstance<TInterface>((ComWrappers.ComInterfaceDispatch*)@this);
return @object is null ? COR_E_OBJECTDISPOSED : S_OK;
}

#if canUseInterfaceStaticMembers
internal static void PopulateIUnknown<TComInterface>(System.Com.IUnknown.Vtbl* vtable)
where TComInterface : unmanaged
{
PopulateIUnknownImpl<TComInterface>(vtable);
if (vtable->QueryInterface_1 is null)
{
throw new NotImplementedException("v-tables cannot be accessed unless the Windows.Win32.ComHelpers.PopulateIUnknownImpl partial method is implemented.");
}
}

static partial void PopulateIUnknownImpl<TComInterface>(System.Com.IUnknown.Vtbl* vtable)
where TComInterface : unmanaged;
#endif
}
8 changes: 8 additions & 0 deletions src/Microsoft.Windows.CsWin32/templates/IVTable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/// <summary>
/// Non generic interface that allows constraining against a COM wrapper type directly. COM structs should
/// implement <see cref="IVTable{TComInterface, TVTable}"/>.
/// </summary>
internal unsafe interface IVTable
{
static abstract System.Com.IUnknown.Vtbl* VTable { get; }
}
16 changes: 16 additions & 0 deletions src/Microsoft.Windows.CsWin32/templates/IVTable`2.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
internal unsafe interface IVTable<TComInterface, TVTable> : IVTable
where TVTable : unmanaged
where TComInterface : unmanaged, IVTable<TComInterface, TVTable>
{
private protected static abstract void PopulateVTable(TVTable* vtable);

static System.Com.IUnknown.Vtbl* IVTable.VTable { get; } = (System.Com.IUnknown.Vtbl*)CreateVTable();

private static TVTable* CreateVTable()
{
TVTable* vtbl = (TVTable*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(TVTable), sizeof(TVTable));
ComHelpers.PopulateIUnknown<TComInterface>((System.Com.IUnknown.Vtbl*)vtbl);
TComInterface.PopulateVTable(vtbl);
return vtbl;
}
}
45 changes: 45 additions & 0 deletions test/GenerationSandbox.Unmarshalled.Tests/ComHelpers.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

#if NET7_0_OR_GREATER

using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Windows.Win32.Foundation;
using Windows.Win32.System.Com;

namespace Windows.Win32;

// The `unsafe` modifier is only allowed to appear on the class declaration -- not the partial method declaration.
// See https://github.com/dotnet/csharplang/discussions/7298 for more.
internal unsafe partial class ComHelpers
{
static partial void PopulateIUnknownImpl<TComInterface>(IUnknown.Vtbl* vtable)
where TComInterface : unmanaged
{
// IUnknown member initialization of the v-table would go here.
vtable->QueryInterface_1 = (delegate* unmanaged[Stdcall]<IUnknown*, Guid*, void**, HRESULT>)&QueryInterface;
vtable->AddRef_2 = (delegate* unmanaged[Stdcall]<IUnknown*, uint>)&AddRef;
vtable->Release_3 = (delegate* unmanaged[Stdcall]<IUnknown*, uint>)&Release;
}

[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static HRESULT QueryInterface(IUnknown* punk, Guid* iid, void** ppvObject)
{
throw new NotImplementedException();
}

[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static uint AddRef(IUnknown* punk)
{
throw new NotImplementedException();
}

[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static uint Release(IUnknown* punk)
{
throw new NotImplementedException();
}
}

#endif
11 changes: 11 additions & 0 deletions test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ private static unsafe void IStream_GetsCCW()
}
#endif

#if NET7_0_OR_GREATER
private static unsafe void GetVTable()
{
IUnknown.Vtbl* vtbl = GetVtable<IStream>();

static IUnknown.Vtbl* GetVtable<T>()
where T : IVTable
=> T.VTable;
}
#endif

private static unsafe void IUnknownGetsVtbl()
{
// WinForms needs the v-table to be declared for these base interfaces.
Expand Down
Loading

0 comments on commit fc1bb9c

Please sign in to comment.