Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always use pointers when referencing structs with variable-length inline arrays #1127

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,16 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi

TypeHandleInfo parameterTypeInfo = originalSignature.ParameterTypes[param.SequenceNumber - 1];
bool isManagedParameterType = this.IsManagedType(parameterTypeInfo);
bool mustRemainAsPointer = parameterTypeInfo is PointerTypeHandleInfo { ElementType: HandleTypeHandleInfo pointedElement } && this.IsStructWithFlexibleArray(pointedElement);

IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);

if (isReserved && !isOut)
if (mustRemainAsPointer)
{
// This block intentionally left blank, so as to disable further processing that might try to
// replace a pointer with a `ref` or similar modifier.
}
else if (isReserved && !isOut)
{
// Remove the parameter and supply the default value for the type to the extern method.
arguments[param.SequenceNumber - 1] = Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression));
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Invariants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public partial class Generator
private const string AlsoUsableForAttribute = "AlsoUsableForAttribute";
private const string InvalidHandleValueAttribute = "InvalidHandleValueAttribute";
private const string CanReturnMultipleSuccessValuesAttribute = "CanReturnMultipleSuccessValuesAttribute";
private const string FlexibleArrayAttribute = "FlexibleArrayAttribute";
private const string CanReturnErrorsAsSuccessAttribute = "CanReturnErrorsAsSuccessAttribute";
private const string SimpleFileNameAnnotation = "SimpleFileName";
private const string NamespaceContainerAnnotation = "NamespaceContainer";
Expand Down
42 changes: 42 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.MetadataHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@ internal bool TryGetTypeDefHandle(string @namespace, string name, out TypeDefini
return false;
}

internal bool TryGetTypeDefHandle(EntityHandle entityHandle, out QualifiedTypeDefinitionHandle typeDefHandle)
{
if (entityHandle.IsNil)
{
typeDefHandle = default;
return false;
}

switch (entityHandle.Kind)
{
case HandleKind.TypeReference:
return this.TryGetTypeDefHandle((TypeReferenceHandle)entityHandle, out typeDefHandle);
case HandleKind.TypeDefinition:
typeDefHandle = new QualifiedTypeDefinitionHandle(this, (TypeDefinitionHandle)entityHandle);
return true;
default:
typeDefHandle = default;
return false;
}
}

internal bool IsNonCOMInterface(TypeDefinition interfaceTypeDef)
{
if (this.Reader.StringComparer.Equals(interfaceTypeDef.Name, "IUnknown"))
Expand Down Expand Up @@ -164,6 +185,27 @@ internal bool IsInterface(TypeReferenceHandle typeRefHandle)

internal bool IsDelegate(TypeDefinition typeDef) => (typeDef.Attributes & TypeAttributes.Class) == TypeAttributes.Class && typeDef.BaseType.Kind == HandleKind.TypeReference && this.Reader.StringComparer.Equals(this.Reader.GetTypeReference((TypeReferenceHandle)typeDef.BaseType).Name, nameof(MulticastDelegate));

internal bool IsStructWithFlexibleArray(HandleTypeHandleInfo typeInfo)
{
return this.TryGetTypeDefHandle(typeInfo.Handle, out QualifiedTypeDefinitionHandle typeHandle)
&& typeHandle.Generator.IsStructWithFlexibleArray(typeHandle.DefinitionHandle);
}

internal bool IsStructWithFlexibleArray(TypeDefinitionHandle typeDefHandle)
{
TypeDefinition typeDef = this.Reader.GetTypeDefinition(typeDefHandle);
foreach (FieldDefinitionHandle fieldHandle in typeDef.GetFields())
{
FieldDefinition field = this.Reader.GetFieldDefinition(fieldHandle);
if (MetadataUtilities.FindAttribute(this.Reader, field.GetCustomAttributes(), InteropDecorationNamespace, FlexibleArrayAttribute) is not null)
{
return true;
}
}

return false;
}

internal bool IsManagedType(TypeHandleInfo typeHandleInfo)
{
TypeHandleInfo elementType =
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Struct.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
context = context with { AllowMarshaling = false };
}

// If the last field has the [FlexibleArray] attribute, we must disable marshaling since the struct
// is only ever valid when accessed via a pointer since the struct acts as a header of an arbitrarily-sized array.
if (typeDef.GetFields().LastOrDefault() is FieldDefinitionHandle { IsNil: false } lastFieldHandle)
{
FieldDefinition lastField = this.Reader.GetFieldDefinition(lastFieldHandle);
if (MetadataUtilities.FindAttribute(this.Reader, lastField.GetCustomAttributes(), InteropDecorationNamespace, FlexibleArrayAttribute) is not null)
{
context = context with { AllowMarshaling = false };
}
}

TypeSyntaxSettings typeSettings = context.Filter(this.fieldTypeSettings);

bool hasUtf16CharField = false;
Expand Down
4 changes: 3 additions & 1 deletion src/Microsoft.Windows.CsWin32/PointerTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
}

bool xOptional = (parameterAttributes & ParameterAttributes.Optional) == ParameterAttributes.Optional;
if (xOptional && forElement == Generator.GeneratingElement.InterfaceMember && nativeArrayInfo is null)
bool mustUsePointers = xOptional && forElement == Generator.GeneratingElement.InterfaceMember && nativeArrayInfo is null;
mustUsePointers |= this.ElementType is HandleTypeHandleInfo handleElementType && inputs.Generator?.IsStructWithFlexibleArray(handleElementType) is true;
if (mustUsePointers)
{
// Disable marshaling because pointers to optional parameters cannot be passed by reference when used as parameters of a COM interface method.
return new TypeSyntaxAndMarshaling(PointerType(this.ElementType.ToTypeSyntax(
Expand Down
11 changes: 11 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,17 @@ public void ITypeNameBuilder_ToStringOverload(bool allowMarshaling)
this.GenerateApi(typeName);
}

[Fact]
public void ReferencesToStructWithFlexibleArrayAreAlwaysPointers()
{
this.GenerateApi("IAMLine21Decoder");
Assert.All(this.FindGeneratedMethod("SetOutputFormat"), m => Assert.IsType<PointerTypeSyntax>(m.ParameterList.Parameters[0].Type));

// Assert that the 'unmanaged' declaration of the struct is the *only* declaration.
Assert.Single(this.FindGeneratedType("BITMAPINFO"));
Assert.Empty(this.FindGeneratedType("BITMAPINFO_unmanaged"));
}

[Theory]
[CombinatorialData]
public void COMInterfaceIIDInterfaceOnAppropriateTFMs(
Expand Down
11 changes: 11 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/ExternMethodTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ public void DefaultEntryPointIsNotEmitted()
Assert.DoesNotContain(attribute.ArgumentList!.Arguments, a => a.NameEquals?.Name.Identifier.ValueText == "EntryPoint");
}

[Fact]
public void ReferencesToStructWithFlexibleArrayAreAlwaysPointers()
{
this.GenerateApi("CreateDIBSection");
Assert.All(this.FindGeneratedMethod("CreateDIBSection"), m => Assert.IsType<PointerTypeSyntax>(m.ParameterList.Parameters[1].Type));

// Assert that the 'unmanaged' declaration of the struct is the *only* declaration.
Assert.Single(this.FindGeneratedType("BITMAPINFO"));
Assert.Empty(this.FindGeneratedType("BITMAPINFO_unmanaged"));
}

private static AttributeSyntax? FindDllImportAttribute(SyntaxList<AttributeListSyntax> attributeLists) => attributeLists.SelectMany(al => al.Attributes).FirstOrDefault(a => a.Name.ToString() == "DllImport");

private IEnumerable<MethodDeclarationSyntax> GenerateMethod(string methodName)
Expand Down