diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index ed9c86c5..942b6f21 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -103,9 +103,16 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi TypeHandleInfo parameterTypeInfo = originalSignature.ParameterTypes[param.SequenceNumber - 1]; bool isManagedParameterType = this.IsManagedType(parameterTypeInfo); + bool mustRemainAsPointer = parameterTypeInfo is PointerTypeHandleInfo { ElementType: HandleTypeHandleInfo pointedElement } && this.IsStructWithFlexibleArray(pointedElement); + IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText); - if (isReserved && !isOut) + if (mustRemainAsPointer) + { + // This block intentionally left blank, so as to disable further processing that might try to + // replace a pointer with a `ref` or similar modifier. + } + else if (isReserved && !isOut) { // Remove the parameter and supply the default value for the type to the extern method. arguments[param.SequenceNumber - 1] = Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs b/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs index 5932be99..c63e5c90 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs @@ -77,6 +77,7 @@ public partial class Generator private const string AlsoUsableForAttribute = "AlsoUsableForAttribute"; private const string InvalidHandleValueAttribute = "InvalidHandleValueAttribute"; private const string CanReturnMultipleSuccessValuesAttribute = "CanReturnMultipleSuccessValuesAttribute"; + private const string FlexibleArrayAttribute = "FlexibleArrayAttribute"; private const string CanReturnErrorsAsSuccessAttribute = "CanReturnErrorsAsSuccessAttribute"; private const string SimpleFileNameAnnotation = "SimpleFileName"; private const string NamespaceContainerAnnotation = "NamespaceContainer"; diff --git a/src/Microsoft.Windows.CsWin32/Generator.MetadataHelpers.cs b/src/Microsoft.Windows.CsWin32/Generator.MetadataHelpers.cs index 5f9f5cf4..a993188e 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.MetadataHelpers.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.MetadataHelpers.cs @@ -84,6 +84,27 @@ internal bool TryGetTypeDefHandle(string @namespace, string name, out TypeDefini return false; } + internal bool TryGetTypeDefHandle(EntityHandle entityHandle, out QualifiedTypeDefinitionHandle typeDefHandle) + { + if (entityHandle.IsNil) + { + typeDefHandle = default; + return false; + } + + switch (entityHandle.Kind) + { + case HandleKind.TypeReference: + return this.TryGetTypeDefHandle((TypeReferenceHandle)entityHandle, out typeDefHandle); + case HandleKind.TypeDefinition: + typeDefHandle = new QualifiedTypeDefinitionHandle(this, (TypeDefinitionHandle)entityHandle); + return true; + default: + typeDefHandle = default; + return false; + } + } + internal bool IsNonCOMInterface(TypeDefinition interfaceTypeDef) { if (this.Reader.StringComparer.Equals(interfaceTypeDef.Name, "IUnknown")) @@ -164,6 +185,27 @@ internal bool IsInterface(TypeReferenceHandle typeRefHandle) internal bool IsDelegate(TypeDefinition typeDef) => (typeDef.Attributes & TypeAttributes.Class) == TypeAttributes.Class && typeDef.BaseType.Kind == HandleKind.TypeReference && this.Reader.StringComparer.Equals(this.Reader.GetTypeReference((TypeReferenceHandle)typeDef.BaseType).Name, nameof(MulticastDelegate)); + internal bool IsStructWithFlexibleArray(HandleTypeHandleInfo typeInfo) + { + return this.TryGetTypeDefHandle(typeInfo.Handle, out QualifiedTypeDefinitionHandle typeHandle) + && typeHandle.Generator.IsStructWithFlexibleArray(typeHandle.DefinitionHandle); + } + + internal bool IsStructWithFlexibleArray(TypeDefinitionHandle typeDefHandle) + { + TypeDefinition typeDef = this.Reader.GetTypeDefinition(typeDefHandle); + foreach (FieldDefinitionHandle fieldHandle in typeDef.GetFields()) + { + FieldDefinition field = this.Reader.GetFieldDefinition(fieldHandle); + if (MetadataUtilities.FindAttribute(this.Reader, field.GetCustomAttributes(), InteropDecorationNamespace, FlexibleArrayAttribute) is not null) + { + return true; + } + } + + return false; + } + internal bool IsManagedType(TypeHandleInfo typeHandleInfo) { TypeHandleInfo elementType = diff --git a/src/Microsoft.Windows.CsWin32/Generator.Struct.cs b/src/Microsoft.Windows.CsWin32/Generator.Struct.cs index 3f9ed349..5b64bb2b 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Struct.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Struct.cs @@ -26,6 +26,17 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle context = context with { AllowMarshaling = false }; } + // If the last field has the [FlexibleArray] attribute, we must disable marshaling since the struct + // is only ever valid when accessed via a pointer since the struct acts as a header of an arbitrarily-sized array. + if (typeDef.GetFields().LastOrDefault() is FieldDefinitionHandle { IsNil: false } lastFieldHandle) + { + FieldDefinition lastField = this.Reader.GetFieldDefinition(lastFieldHandle); + if (MetadataUtilities.FindAttribute(this.Reader, lastField.GetCustomAttributes(), InteropDecorationNamespace, FlexibleArrayAttribute) is not null) + { + context = context with { AllowMarshaling = false }; + } + } + TypeSyntaxSettings typeSettings = context.Filter(this.fieldTypeSettings); bool hasUtf16CharField = false; diff --git a/src/Microsoft.Windows.CsWin32/PointerTypeHandleInfo.cs b/src/Microsoft.Windows.CsWin32/PointerTypeHandleInfo.cs index 667b611c..b4c00270 100644 --- a/src/Microsoft.Windows.CsWin32/PointerTypeHandleInfo.cs +++ b/src/Microsoft.Windows.CsWin32/PointerTypeHandleInfo.cs @@ -18,7 +18,9 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs } bool xOptional = (parameterAttributes & ParameterAttributes.Optional) == ParameterAttributes.Optional; - if (xOptional && forElement == Generator.GeneratingElement.InterfaceMember && nativeArrayInfo is null) + bool mustUsePointers = xOptional && forElement == Generator.GeneratingElement.InterfaceMember && nativeArrayInfo is null; + mustUsePointers |= this.ElementType is HandleTypeHandleInfo handleElementType && inputs.Generator?.IsStructWithFlexibleArray(handleElementType) is true; + if (mustUsePointers) { // Disable marshaling because pointers to optional parameters cannot be passed by reference when used as parameters of a COM interface method. return new TypeSyntaxAndMarshaling(PointerType(this.ElementType.ToTypeSyntax( diff --git a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs index 88ebee56..bd541a79 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs @@ -380,6 +380,17 @@ public void ITypeNameBuilder_ToStringOverload(bool allowMarshaling) this.GenerateApi(typeName); } + [Fact] + public void ReferencesToStructWithFlexibleArrayAreAlwaysPointers() + { + this.GenerateApi("IAMLine21Decoder"); + Assert.All(this.FindGeneratedMethod("SetOutputFormat"), m => Assert.IsType(m.ParameterList.Parameters[0].Type)); + + // Assert that the 'unmanaged' declaration of the struct is the *only* declaration. + Assert.Single(this.FindGeneratedType("BITMAPINFO")); + Assert.Empty(this.FindGeneratedType("BITMAPINFO_unmanaged")); + } + [Theory] [CombinatorialData] public void COMInterfaceIIDInterfaceOnAppropriateTFMs( diff --git a/test/Microsoft.Windows.CsWin32.Tests/ExternMethodTests.cs b/test/Microsoft.Windows.CsWin32.Tests/ExternMethodTests.cs index a56b8de0..61d35d94 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/ExternMethodTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/ExternMethodTests.cs @@ -75,6 +75,17 @@ public void DefaultEntryPointIsNotEmitted() Assert.DoesNotContain(attribute.ArgumentList!.Arguments, a => a.NameEquals?.Name.Identifier.ValueText == "EntryPoint"); } + [Fact] + public void ReferencesToStructWithFlexibleArrayAreAlwaysPointers() + { + this.GenerateApi("CreateDIBSection"); + Assert.All(this.FindGeneratedMethod("CreateDIBSection"), m => Assert.IsType(m.ParameterList.Parameters[1].Type)); + + // Assert that the 'unmanaged' declaration of the struct is the *only* declaration. + Assert.Single(this.FindGeneratedType("BITMAPINFO")); + Assert.Empty(this.FindGeneratedType("BITMAPINFO_unmanaged")); + } + private static AttributeSyntax? FindDllImportAttribute(SyntaxList attributeLists) => attributeLists.SelectMany(al => al.Attributes).FirstOrDefault(a => a.Name.ToString() == "DllImport"); private IEnumerable GenerateMethod(string methodName)