From d80cf0f05ecc73fa98a27d904a8887be099ceceb Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Thu, 25 Jan 2024 13:10:07 -0700 Subject: [PATCH] Add helper APIs for variable-length inline arrays Closes #387 --- .../FastSyntaxFactory.cs | 2 + .../Generator.Features.cs | 1 + .../Generator.Struct.cs | 148 ++++++++++++++++++ src/Microsoft.Windows.CsWin32/Generator.cs | 6 +- .../templates/VariableLengthInlineArray`1.cs | 33 ++++ .../FlexibleArrayTests.cs | 37 +++++ .../GenerationSandbox.Tests/NativeMethods.txt | 1 + .../StructTests.cs | 16 ++ 8 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 src/Microsoft.Windows.CsWin32/templates/VariableLengthInlineArray`1.cs create mode 100644 test/GenerationSandbox.Tests/FlexibleArrayTests.cs diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index 117772a7..a1919a02 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -102,6 +102,8 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space))); + internal static SizeOfExpressionSyntax SizeOfExpression(TypeSyntax type) => SyntaxFactory.SizeOfExpression(Token(SyntaxKind.SizeOfKeyword), Token(SyntaxKind.OpenParenToken), type, Token(SyntaxKind.CloseParenToken)); + internal static MemberAccessExpressionSyntax MemberAccessExpression(SyntaxKind kind, ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.MemberAccessExpression(kind, expression, Token(GetMemberAccessExpressionOperatorTokenKind(kind)), name); internal static ConditionalAccessExpressionSyntax ConditionalAccessExpression(ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.ConditionalAccessExpression(expression, Token(SyntaxKind.QuestionToken), MemberBindingExpression(name)); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Features.cs b/src/Microsoft.Windows.CsWin32/Generator.Features.cs index be224f6c..edb7bcd7 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Features.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Features.cs @@ -8,6 +8,7 @@ public partial class Generator private readonly bool canUseSpan; private readonly bool canCallCreateSpan; private readonly bool canUseUnsafeAsRef; + private readonly bool canUseUnsafeAdd; private readonly bool canUseUnsafeNullRef; private readonly bool canUseUnmanagedCallersOnlyAttribute; private readonly bool canUseSetLastPInvokeError; diff --git a/src/Microsoft.Windows.CsWin32/Generator.Struct.cs b/src/Microsoft.Windows.CsWin32/Generator.Struct.cs index 5b64bb2b..f9bc112d 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Struct.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Struct.cs @@ -28,11 +28,14 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle // If the last field has the [FlexibleArray] attribute, we must disable marshaling since the struct // is only ever valid when accessed via a pointer since the struct acts as a header of an arbitrarily-sized array. + FieldDefinitionHandle flexibleArrayFieldHandle = default; + MethodDeclarationSyntax? sizeOfMethod = null; if (typeDef.GetFields().LastOrDefault() is FieldDefinitionHandle { IsNil: false } lastFieldHandle) { FieldDefinition lastField = this.Reader.GetFieldDefinition(lastFieldHandle); if (MetadataUtilities.FindAttribute(this.Reader, lastField.GetCustomAttributes(), InteropDecorationNamespace, FlexibleArrayAttribute) is not null) { + flexibleArrayFieldHandle = lastFieldHandle; context = context with { AllowMarshaling = false }; } } @@ -80,6 +83,37 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle .WithArgumentList(BracketedArgumentList(SingletonSeparatedList(Argument(size))))) .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword), Token(SyntaxKind.FixedKeyword)); } + else if (fieldDefHandle == flexibleArrayFieldHandle) + { + CustomAttributeHandleCollection fieldAttributes = fieldDef.GetCustomAttributes(); + var fieldTypeInfo = (ArrayTypeHandleInfo)fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null); + TypeSyntax fieldType = fieldTypeInfo.ElementType.ToTypeSyntax(typeSettings, GeneratingElement.StructMember, fieldAttributes).Type; + + if (fieldType is PointerTypeSyntax or FunctionPointerTypeSyntax) + { + // These types are not allowed as generic type arguments (https://github.com/dotnet/runtime/issues/13627) + // so we have to generate a special nested struct dedicated to this type instead of using the generic type. + StructDeclarationSyntax helperStruct = this.DeclareVariableLengthInlineArrayHelper(context, fieldType); + additionalMembers = additionalMembers.Add(helperStruct); + + field = FieldDeclaration( + VariableDeclaration(IdentifierName(helperStruct.Identifier.ValueText))) + .AddDeclarationVariables(fieldDeclarator) + .AddModifiers(TokenWithSpace(this.Visibility)); + } + else + { + this.RequestVariableLengthInlineArrayHelper(context); + field = FieldDeclaration( + VariableDeclaration( + GenericName($"global::Windows.Win32.VariableLengthInlineArray") + .WithTypeArgumentList(TypeArgumentList().AddArguments(fieldType)))) + .AddDeclarationVariables(fieldDeclarator) + .AddModifiers(TokenWithSpace(this.Visibility)); + } + + sizeOfMethod = this.DeclareSizeOfMethod(name, fieldType, typeSettings); + } else { CustomAttributeHandleCollection fieldAttributes = fieldDef.GetCustomAttributes(); @@ -334,6 +368,12 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle } } + // Add a SizeOf method, if there is a FlexibleArray field. + if (sizeOfMethod is not null) + { + members.Add(sizeOfMethod); + } + // Add the additional members, taking care to not introduce redundant declarations. members.AddRange(additionalMembers.Where(c => c is not StructDeclarationSyntax cs || !members.OfType().Any(m => m.Identifier.ValueText == cs.Identifier.ValueText))); @@ -370,6 +410,95 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle return result; } + private StructDeclarationSyntax DeclareVariableLengthInlineArrayHelper(Context context, TypeSyntax fieldType) + { + IdentifierNameSyntax firstElementFieldName = IdentifierName("e0"); + List members = new(); + + // internal unsafe T e0; + members.Add(FieldDeclaration(VariableDeclaration(fieldType).AddVariables(VariableDeclarator(firstElementFieldName.Identifier))) + .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword))); + + if (this.canUseUnsafeAdd) + { + ////[MethodImpl(MethodImplOptions.AggressiveInlining)] + ////get { fixed (int** p = &e0) return *(p + index); } + IdentifierNameSyntax pLocal = IdentifierName("p"); + AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration) + .WithBody(Block().AddStatements( + FixedStatement( + VariableDeclaration(PointerType(fieldType)).AddVariables( + VariableDeclarator(pLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, firstElementFieldName)))), + ReturnStatement(PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.AddExpression, pLocal, IdentifierName("index")))))))) + .AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining))); + + ////[MethodImpl(MethodImplOptions.AggressiveInlining)] + ////set { fixed (int** p = &e0) *(p + index) = value; } + AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration) + .WithBody(Block().AddStatements( + FixedStatement( + VariableDeclaration(PointerType(fieldType)).AddVariables( + VariableDeclarator(pLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, firstElementFieldName)))), + ExpressionStatement(AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.AddExpression, pLocal, IdentifierName("index")))), + IdentifierName("value")))))) + .AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining))); + + ////internal unsafe T this[int index] + members.Add(IndexerDeclaration(fieldType.WithTrailingTrivia(Space)) + .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword)) + .AddParameterListParameters(Parameter(Identifier("index")).WithType(PredefinedType(TokenWithSpace(SyntaxKind.IntKeyword)))) + .AddAccessorListAccessors(getter, setter)); + } + + // internal partial struct VariableLengthInlineArrayHelper + return StructDeclaration(Identifier("VariableLengthInlineArrayHelper")) + .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.PartialKeyword)) + .AddMembers(members.ToArray()); + } + + private MethodDeclarationSyntax DeclareSizeOfMethod(TypeSyntax structType, TypeSyntax elementType, TypeSyntaxSettings typeSettings) + { + PredefinedTypeSyntax intType = PredefinedType(TokenWithSpace(SyntaxKind.IntKeyword)); + IdentifierNameSyntax countName = IdentifierName("count"); + IdentifierNameSyntax localName = IdentifierName("v"); + List statements = new(); + + // int v = sizeof(OUTER_STRUCT); + statements.Add(LocalDeclarationStatement(VariableDeclaration(intType).AddVariables( + VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(SizeOfExpression(structType)))))); + + // if (count > 1) + // v += checked((count - 1) * sizeof(ELEMENT_TYPE)); + // else if (count < 0) + // throw new ArgumentOutOfRangeException(nameof(count)); + statements.Add(IfStatement( + BinaryExpression(SyntaxKind.GreaterThanExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1))), + ExpressionStatement(AssignmentExpression( + SyntaxKind.AddAssignmentExpression, + localName, + CheckedExpression(BinaryExpression( + SyntaxKind.MultiplyExpression, + ParenthesizedExpression(BinaryExpression(SyntaxKind.SubtractExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1)))), + SizeOfExpression(elementType))))), + ElseClause(IfStatement( + BinaryExpression(SyntaxKind.LessThanExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))), + ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))))).WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken)))).WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken))); + + // return v; + statements.Add(ReturnStatement(localName)); + + // internal static unsafe int SizeOf(int count) + MethodDeclarationSyntax sizeOfMethod = MethodDeclaration(intType, Identifier("SizeOf")) + .AddParameterListParameters(Parameter(countName.Identifier).WithType(intType)) + .WithBody(Block().AddStatements(statements.ToArray())) + .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.UnsafeKeyword)) + .WithLeadingTrivia(ParseLeadingTrivia("/// Computes the amount of memory that must be allocated to store this struct, including the specified number of elements in the variable length inline array at the end.\n")); + + return sizeOfMethod; + } + private (TypeSyntax FieldType, SyntaxList AdditionalMembers, AttributeSyntax? MarshalAsAttribute) ReinterpretFieldType(FieldDefinition fieldDef, TypeSyntax originalType, CustomAttributeHandleCollection customAttributes, Context context) { TypeSyntaxSettings typeSettings = context.Filter(this.fieldTypeSettings); @@ -397,4 +526,23 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle return (originalType, default(SyntaxList), marshalAs); } + + private void RequestVariableLengthInlineArrayHelper(Context context) + { + if (this.IsWin32Sdk) + { + if (!this.IsTypeAlreadyFullyDeclared($"{this.Namespace}.{this.variableLengthInlineArrayStruct.Identifier.ValueText}")) + { + this.DeclareUnscopedRefAttributeIfNecessary(); + this.volatileCode.GenerateSpecialType("VariableLengthInlineArray", () => this.volatileCode.AddSpecialType("VariableLengthInlineArray", this.variableLengthInlineArrayStruct)); + } + } + else if (this.SuperGenerator is not null && this.SuperGenerator.TryGetGenerator("Windows.Win32", out Generator? generator)) + { + generator.volatileCode.GenerationTransaction(delegate + { + generator.RequestVariableLengthInlineArrayHelper(context); + }); + } + } } diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 8131d5cd..38380313 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -22,6 +22,7 @@ public partial class Generator : IGenerator, IDisposable private readonly TypeSyntaxSettings errorMessageTypeSettings; private readonly ClassDeclarationSyntax comHelperClass; + private readonly StructDeclarationSyntax variableLengthInlineArrayStruct; private readonly Dictionary> findTypeSymbolIfAlreadyAvailableCache = new(StringComparer.Ordinal); private readonly Rental metadataReader; @@ -86,7 +87,8 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option this.canUseSpan = this.compilation?.GetTypeByMetadataName(typeof(Span<>).FullName) is not null; this.canCallCreateSpan = this.compilation?.GetTypeByMetadataName(typeof(MemoryMarshal).FullName)?.GetMembers("CreateSpan").Any() is true; - this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true; + this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("Add").Any() is true; + this.canUseUnsafeAdd = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true; this.canUseUnsafeNullRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("NullRef").Any() is true; this.canUseUnmanagedCallersOnlyAttribute = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute") is not null; this.canUseSetLastPInvokeError = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.Marshal")?.GetMembers("GetLastSystemError").IsEmpty is false; @@ -110,6 +112,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option AddSymbolIf(this.canUseSpan, "canUseSpan"); AddSymbolIf(this.canCallCreateSpan, "canCallCreateSpan"); AddSymbolIf(this.canUseUnsafeAsRef, "canUseUnsafeAsRef"); + AddSymbolIf(this.canUseUnsafeAdd, "canUseUnsafeAdd"); AddSymbolIf(this.canUseUnsafeNullRef, "canUseUnsafeNullRef"); AddSymbolIf(compilation?.GetTypeByMetadataName("System.Drawing.Point") is not null, "canUseSystemDrawing"); AddSymbolIf(this.IsFeatureAvailable(Feature.InterfaceStaticMembers), "canUseInterfaceStaticMembers"); @@ -149,6 +152,7 @@ void AddSymbolIf(bool condition, string symbol) this.methodsAndConstantsClassName = IdentifierName(options.ClassName); FetchTemplate("ComHelpers", this, out this.comHelperClass); + FetchTemplate("VariableLengthInlineArray`1", this, out this.variableLengthInlineArrayStruct); } internal enum GeneratingElement diff --git a/src/Microsoft.Windows.CsWin32/templates/VariableLengthInlineArray`1.cs b/src/Microsoft.Windows.CsWin32/templates/VariableLengthInlineArray`1.cs new file mode 100644 index 00000000..d29ed33d --- /dev/null +++ b/src/Microsoft.Windows.CsWin32/templates/VariableLengthInlineArray`1.cs @@ -0,0 +1,33 @@ +internal struct VariableLengthInlineArray + where T : unmanaged +{ + internal T e0; + +#if canUseUnsafeAdd + internal ref T this[int index] + { + [UnscopedRef] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.Add(ref this.e0, index); + } +#endif + +#if canUseSpan + [UnscopedRef] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal Span AsSpan(int length) + { +#if canCallCreateSpan + return MemoryMarshal.CreateSpan(ref this.e0, length); +#else + unsafe + { + fixed (void* p = &this.e0) + { + return new Span(p, length); + } + } +#endif + } +#endif +} diff --git a/test/GenerationSandbox.Tests/FlexibleArrayTests.cs b/test/GenerationSandbox.Tests/FlexibleArrayTests.cs new file mode 100644 index 00000000..6d9cd95a --- /dev/null +++ b/test/GenerationSandbox.Tests/FlexibleArrayTests.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.InteropServices; +using Windows.Win32.System.Ole; + +public class FlexibleArrayTests +{ + [Fact] + public unsafe void FlexibleArraySizing() + { + const int count = 3; + PAGESET* pPageSet = (PAGESET*)Marshal.AllocHGlobal(PAGESET.SizeOf(count)); + try + { + pPageSet->rgPages[0].nFromPage = 0; + + Span pageRange = pPageSet->rgPages.AsSpan(count); + for (int i = 0; i < count; i++) + { + pageRange[i].nFromPage = i * 2; + pageRange[i].nToPage = (i * 2) + 1; + } + } + finally + { + Marshal.FreeHGlobal((IntPtr)pPageSet); + } + } + + [Fact] + public void SizeOf_Minimum1Element() + { + Assert.Equal(PAGESET.SizeOf(1), PAGESET.SizeOf(0)); + Assert.Equal(Marshal.SizeOf(), PAGESET.SizeOf(2) - PAGESET.SizeOf(1)); + } +} diff --git a/test/GenerationSandbox.Tests/NativeMethods.txt b/test/GenerationSandbox.Tests/NativeMethods.txt index 47b4c279..2f06e9cb 100644 --- a/test/GenerationSandbox.Tests/NativeMethods.txt +++ b/test/GenerationSandbox.Tests/NativeMethods.txt @@ -30,6 +30,7 @@ MAKELRESULT MAKEWPARAM MAX_PATH NTSTATUS +PAGESET PathParseIconLocation PROCESS_BASIC_INFORMATION PZZSTR diff --git a/test/Microsoft.Windows.CsWin32.Tests/StructTests.cs b/test/Microsoft.Windows.CsWin32.Tests/StructTests.cs index 100e727b..0a55db05 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/StructTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/StructTests.cs @@ -244,6 +244,22 @@ public void StructConstantsAreGeneratedAsConstants() Assert.NotEmpty(type.Members.OfType().Where(f => f.Modifiers.Any(SyntaxKind.ConstKeyword))); } + [Theory] + [MemberData(nameof(TFMData))] + public void FlexibleArrayMember(string tfm) + { + this.compilation = this.starterCompilations[tfm]; + this.GenerateApi("BITMAPINFO"); + var type = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType("BITMAPINFO")); + FieldDeclarationSyntax flexArrayField = Assert.Single(type.Members.OfType(), m => m.Declaration.Variables.Any(v => v.Identifier.ValueText == "bmiColors")); + var fieldType = Assert.IsType(Assert.IsType(flexArrayField.Declaration.Type).Right); + Assert.Equal("VariableLengthInlineArray", fieldType.Identifier.ValueText); + Assert.Equal("RGBQUAD", Assert.IsType(Assert.Single(fieldType.TypeArgumentList.Arguments)).Right.Identifier.ValueText); + + // Verify that the SizeOf method was generated. + Assert.Single(this.FindGeneratedMethod("SizeOf")); + } + [Theory] [CombinatorialData] public void InterestingStructs(