Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper APIs for variable-length inline arrays #1130

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space)));

internal static SizeOfExpressionSyntax SizeOfExpression(TypeSyntax type) => SyntaxFactory.SizeOfExpression(Token(SyntaxKind.SizeOfKeyword), Token(SyntaxKind.OpenParenToken), type, Token(SyntaxKind.CloseParenToken));

internal static MemberAccessExpressionSyntax MemberAccessExpression(SyntaxKind kind, ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.MemberAccessExpression(kind, expression, Token(GetMemberAccessExpressionOperatorTokenKind(kind)), name);

internal static ConditionalAccessExpressionSyntax ConditionalAccessExpression(ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.ConditionalAccessExpression(expression, Token(SyntaxKind.QuestionToken), MemberBindingExpression(name));
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Features.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public partial class Generator
private readonly bool canUseSpan;
private readonly bool canCallCreateSpan;
private readonly bool canUseUnsafeAsRef;
private readonly bool canUseUnsafeAdd;
private readonly bool canUseUnsafeNullRef;
private readonly bool canUseUnmanagedCallersOnlyAttribute;
private readonly bool canUseSetLastPInvokeError;
Expand Down
148 changes: 148 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Struct.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle

// If the last field has the [FlexibleArray] attribute, we must disable marshaling since the struct
// is only ever valid when accessed via a pointer since the struct acts as a header of an arbitrarily-sized array.
FieldDefinitionHandle flexibleArrayFieldHandle = default;
MethodDeclarationSyntax? sizeOfMethod = null;
if (typeDef.GetFields().LastOrDefault() is FieldDefinitionHandle { IsNil: false } lastFieldHandle)
{
FieldDefinition lastField = this.Reader.GetFieldDefinition(lastFieldHandle);
if (MetadataUtilities.FindAttribute(this.Reader, lastField.GetCustomAttributes(), InteropDecorationNamespace, FlexibleArrayAttribute) is not null)
{
flexibleArrayFieldHandle = lastFieldHandle;
context = context with { AllowMarshaling = false };
}
}
Expand Down Expand Up @@ -80,6 +83,37 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
.WithArgumentList(BracketedArgumentList(SingletonSeparatedList(Argument(size)))))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword), Token(SyntaxKind.FixedKeyword));
}
else if (fieldDefHandle == flexibleArrayFieldHandle)
{
CustomAttributeHandleCollection fieldAttributes = fieldDef.GetCustomAttributes();
var fieldTypeInfo = (ArrayTypeHandleInfo)fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null);
TypeSyntax fieldType = fieldTypeInfo.ElementType.ToTypeSyntax(typeSettings, GeneratingElement.StructMember, fieldAttributes).Type;

if (fieldType is PointerTypeSyntax or FunctionPointerTypeSyntax)
{
// These types are not allowed as generic type arguments (https://github.com/dotnet/runtime/issues/13627)
// so we have to generate a special nested struct dedicated to this type instead of using the generic type.
StructDeclarationSyntax helperStruct = this.DeclareVariableLengthInlineArrayHelper(context, fieldType);
additionalMembers = additionalMembers.Add(helperStruct);

field = FieldDeclaration(
VariableDeclaration(IdentifierName(helperStruct.Identifier.ValueText)))
.AddDeclarationVariables(fieldDeclarator)
.AddModifiers(TokenWithSpace(this.Visibility));
}
else
{
this.RequestVariableLengthInlineArrayHelper(context);
field = FieldDeclaration(
VariableDeclaration(
GenericName($"global::Windows.Win32.VariableLengthInlineArray")
.WithTypeArgumentList(TypeArgumentList().AddArguments(fieldType))))
.AddDeclarationVariables(fieldDeclarator)
.AddModifiers(TokenWithSpace(this.Visibility));
}

sizeOfMethod = this.DeclareSizeOfMethod(name, fieldType, typeSettings);
}
else
{
CustomAttributeHandleCollection fieldAttributes = fieldDef.GetCustomAttributes();
Expand Down Expand Up @@ -334,6 +368,12 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
}
}

// Add a SizeOf method, if there is a FlexibleArray field.
if (sizeOfMethod is not null)
{
members.Add(sizeOfMethod);
}

// Add the additional members, taking care to not introduce redundant declarations.
members.AddRange(additionalMembers.Where(c => c is not StructDeclarationSyntax cs || !members.OfType<StructDeclarationSyntax>().Any(m => m.Identifier.ValueText == cs.Identifier.ValueText)));

Expand Down Expand Up @@ -370,6 +410,95 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
return result;
}

private StructDeclarationSyntax DeclareVariableLengthInlineArrayHelper(Context context, TypeSyntax fieldType)
{
IdentifierNameSyntax firstElementFieldName = IdentifierName("e0");
List<MemberDeclarationSyntax> members = new();

// internal unsafe T e0;
members.Add(FieldDeclaration(VariableDeclaration(fieldType).AddVariables(VariableDeclarator(firstElementFieldName.Identifier)))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword)));

if (this.canUseUnsafeAdd)
{
////[MethodImpl(MethodImplOptions.AggressiveInlining)]
////get { fixed (int** p = &e0) return *(p + index); }
IdentifierNameSyntax pLocal = IdentifierName("p");
AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
.WithBody(Block().AddStatements(
FixedStatement(
VariableDeclaration(PointerType(fieldType)).AddVariables(
VariableDeclarator(pLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, firstElementFieldName)))),
ReturnStatement(PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.AddExpression, pLocal, IdentifierName("index"))))))))
.AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining)));

////[MethodImpl(MethodImplOptions.AggressiveInlining)]
////set { fixed (int** p = &e0) *(p + index) = value; }
AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration)
.WithBody(Block().AddStatements(
FixedStatement(
VariableDeclaration(PointerType(fieldType)).AddVariables(
VariableDeclarator(pLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, firstElementFieldName)))),
ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.AddExpression, pLocal, IdentifierName("index")))),
IdentifierName("value"))))))
.AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining)));

////internal unsafe T this[int index]
members.Add(IndexerDeclaration(fieldType.WithTrailingTrivia(Space))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword))
.AddParameterListParameters(Parameter(Identifier("index")).WithType(PredefinedType(TokenWithSpace(SyntaxKind.IntKeyword))))
.AddAccessorListAccessors(getter, setter));
}

// internal partial struct VariableLengthInlineArrayHelper
return StructDeclaration(Identifier("VariableLengthInlineArrayHelper"))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.PartialKeyword))
.AddMembers(members.ToArray());
}

private MethodDeclarationSyntax DeclareSizeOfMethod(TypeSyntax structType, TypeSyntax elementType, TypeSyntaxSettings typeSettings)
{
PredefinedTypeSyntax intType = PredefinedType(TokenWithSpace(SyntaxKind.IntKeyword));
IdentifierNameSyntax countName = IdentifierName("count");
IdentifierNameSyntax localName = IdentifierName("v");
List<StatementSyntax> statements = new();

// int v = sizeof(OUTER_STRUCT);
statements.Add(LocalDeclarationStatement(VariableDeclaration(intType).AddVariables(
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(SizeOfExpression(structType))))));

// if (count > 1)
// v += checked((count - 1) * sizeof(ELEMENT_TYPE));
// else if (count < 0)
// throw new ArgumentOutOfRangeException(nameof(count));
statements.Add(IfStatement(
BinaryExpression(SyntaxKind.GreaterThanExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1))),
ExpressionStatement(AssignmentExpression(
SyntaxKind.AddAssignmentExpression,
localName,
CheckedExpression(BinaryExpression(
SyntaxKind.MultiplyExpression,
ParenthesizedExpression(BinaryExpression(SyntaxKind.SubtractExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1)))),
SizeOfExpression(elementType))))),
ElseClause(IfStatement(
BinaryExpression(SyntaxKind.LessThanExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))))).WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken)))).WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken)));

// return v;
statements.Add(ReturnStatement(localName));

// internal static unsafe int SizeOf(int count)
MethodDeclarationSyntax sizeOfMethod = MethodDeclaration(intType, Identifier("SizeOf"))
.AddParameterListParameters(Parameter(countName.Identifier).WithType(intType))
.WithBody(Block().AddStatements(statements.ToArray()))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.UnsafeKeyword))
.WithLeadingTrivia(ParseLeadingTrivia("/// <summary>Computes the amount of memory that must be allocated to store this struct, including the specified number of elements in the variable length inline array at the end.</summary>\n"));

return sizeOfMethod;
}

private (TypeSyntax FieldType, SyntaxList<MemberDeclarationSyntax> AdditionalMembers, AttributeSyntax? MarshalAsAttribute) ReinterpretFieldType(FieldDefinition fieldDef, TypeSyntax originalType, CustomAttributeHandleCollection customAttributes, Context context)
{
TypeSyntaxSettings typeSettings = context.Filter(this.fieldTypeSettings);
Expand Down Expand Up @@ -397,4 +526,23 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle

return (originalType, default(SyntaxList<MemberDeclarationSyntax>), marshalAs);
}

private void RequestVariableLengthInlineArrayHelper(Context context)
{
if (this.IsWin32Sdk)
{
if (!this.IsTypeAlreadyFullyDeclared($"{this.Namespace}.{this.variableLengthInlineArrayStruct.Identifier.ValueText}"))
{
this.DeclareUnscopedRefAttributeIfNecessary();
this.volatileCode.GenerateSpecialType("VariableLengthInlineArray", () => this.volatileCode.AddSpecialType("VariableLengthInlineArray", this.variableLengthInlineArrayStruct));
}
}
else if (this.SuperGenerator is not null && this.SuperGenerator.TryGetGenerator("Windows.Win32", out Generator? generator))
{
generator.volatileCode.GenerationTransaction(delegate
{
generator.RequestVariableLengthInlineArrayHelper(context);
});
}
}
}
6 changes: 5 additions & 1 deletion src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public partial class Generator : IGenerator, IDisposable
private readonly TypeSyntaxSettings errorMessageTypeSettings;

private readonly ClassDeclarationSyntax comHelperClass;
private readonly StructDeclarationSyntax variableLengthInlineArrayStruct;

private readonly Dictionary<string, IReadOnlyList<ISymbol>> findTypeSymbolIfAlreadyAvailableCache = new(StringComparer.Ordinal);
private readonly Rental<MetadataReader> metadataReader;
Expand Down Expand Up @@ -86,7 +87,8 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option

this.canUseSpan = this.compilation?.GetTypeByMetadataName(typeof(Span<>).FullName) is not null;
this.canCallCreateSpan = this.compilation?.GetTypeByMetadataName(typeof(MemoryMarshal).FullName)?.GetMembers("CreateSpan").Any() is true;
this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true;
this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("Add").Any() is true;
this.canUseUnsafeAdd = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true;
this.canUseUnsafeNullRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("NullRef").Any() is true;
this.canUseUnmanagedCallersOnlyAttribute = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute") is not null;
this.canUseSetLastPInvokeError = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.Marshal")?.GetMembers("GetLastSystemError").IsEmpty is false;
Expand All @@ -110,6 +112,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option
AddSymbolIf(this.canUseSpan, "canUseSpan");
AddSymbolIf(this.canCallCreateSpan, "canCallCreateSpan");
AddSymbolIf(this.canUseUnsafeAsRef, "canUseUnsafeAsRef");
AddSymbolIf(this.canUseUnsafeAdd, "canUseUnsafeAdd");
AddSymbolIf(this.canUseUnsafeNullRef, "canUseUnsafeNullRef");
AddSymbolIf(compilation?.GetTypeByMetadataName("System.Drawing.Point") is not null, "canUseSystemDrawing");
AddSymbolIf(this.IsFeatureAvailable(Feature.InterfaceStaticMembers), "canUseInterfaceStaticMembers");
Expand Down Expand Up @@ -149,6 +152,7 @@ void AddSymbolIf(bool condition, string symbol)
this.methodsAndConstantsClassName = IdentifierName(options.ClassName);

FetchTemplate("ComHelpers", this, out this.comHelperClass);
FetchTemplate("VariableLengthInlineArray`1", this, out this.variableLengthInlineArrayStruct);
}

internal enum GeneratingElement
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
internal struct VariableLengthInlineArray<T>
where T : unmanaged
{
internal T e0;

#if canUseUnsafeAdd
internal ref T this[int index]
{
[UnscopedRef]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => ref Unsafe.Add(ref this.e0, index);
}
#endif

#if canUseSpan
[UnscopedRef]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal Span<T> AsSpan(int length)
{
#if canCallCreateSpan
return MemoryMarshal.CreateSpan(ref this.e0, length);
#else
unsafe
{
fixed (void* p = &this.e0)
{
return new Span<T>(p, length);
}
}
#endif
}
#endif
}
37 changes: 37 additions & 0 deletions test/GenerationSandbox.Tests/FlexibleArrayTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Runtime.InteropServices;
using Windows.Win32.System.Ole;

public class FlexibleArrayTests
{
[Fact]
public unsafe void FlexibleArraySizing()
{
const int count = 3;
PAGESET* pPageSet = (PAGESET*)Marshal.AllocHGlobal(PAGESET.SizeOf(count));
try
{
pPageSet->rgPages[0].nFromPage = 0;

Span<PAGERANGE> pageRange = pPageSet->rgPages.AsSpan(count);
for (int i = 0; i < count; i++)
{
pageRange[i].nFromPage = i * 2;
pageRange[i].nToPage = (i * 2) + 1;
}
}
finally
{
Marshal.FreeHGlobal((IntPtr)pPageSet);
}
}

[Fact]
public void SizeOf_Minimum1Element()
{
Assert.Equal(PAGESET.SizeOf(1), PAGESET.SizeOf(0));
Assert.Equal(Marshal.SizeOf<PAGERANGE>(), PAGESET.SizeOf(2) - PAGESET.SizeOf(1));
}
}
1 change: 1 addition & 0 deletions test/GenerationSandbox.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ MAKELRESULT
MAKEWPARAM
MAX_PATH
NTSTATUS
PAGESET
PathParseIconLocation
PROCESS_BASIC_INFORMATION
PZZSTR
Expand Down
16 changes: 16 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/StructTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,22 @@ public void StructConstantsAreGeneratedAsConstants()
Assert.NotEmpty(type.Members.OfType<FieldDeclarationSyntax>().Where(f => f.Modifiers.Any(SyntaxKind.ConstKeyword)));
}

[Theory]
[MemberData(nameof(TFMData))]
public void FlexibleArrayMember(string tfm)
{
this.compilation = this.starterCompilations[tfm];
this.GenerateApi("BITMAPINFO");
var type = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType("BITMAPINFO"));
FieldDeclarationSyntax flexArrayField = Assert.Single(type.Members.OfType<FieldDeclarationSyntax>(), m => m.Declaration.Variables.Any(v => v.Identifier.ValueText == "bmiColors"));
var fieldType = Assert.IsType<GenericNameSyntax>(Assert.IsType<QualifiedNameSyntax>(flexArrayField.Declaration.Type).Right);
Assert.Equal("VariableLengthInlineArray", fieldType.Identifier.ValueText);
Assert.Equal("RGBQUAD", Assert.IsType<QualifiedNameSyntax>(Assert.Single(fieldType.TypeArgumentList.Arguments)).Right.Identifier.ValueText);

// Verify that the SizeOf method was generated.
Assert.Single(this.FindGeneratedMethod("SizeOf"));
}

[Theory]
[CombinatorialData]
public void InterestingStructs(
Expand Down