Skip to content

Commit

Permalink
Declare accessor properties for bitfields
Browse files Browse the repository at this point in the history
Closes #987
  • Loading branch information
AArnott committed Jan 12, 2024
1 parent 7f4d00f commit 555bab9
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Invariants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public partial class Generator
{
internal const string InteropDecorationNamespace = "Windows.Win32.Foundation.Metadata";
internal const string NativeArrayInfoAttribute = "NativeArrayInfoAttribute";
internal const string NativeBitfieldAttribute = "NativeBitfieldAttribute";
internal const string MemorySizeAttribute = "MemorySizeAttribute";
internal const string RAIIFreeAttribute = "RAIIFreeAttribute";
internal const string DoNotReleaseAttribute = "DoNotReleaseAttribute";
Expand Down
124 changes: 124 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Struct.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ namespace Microsoft.Windows.CsWin32;

public partial class Generator
{
private static byte GetLengthInBytes(PrimitiveTypeCode code) => code switch
{
PrimitiveTypeCode.SByte or PrimitiveTypeCode.Byte => 1,
PrimitiveTypeCode.Int16 or PrimitiveTypeCode.UInt16 => 2,
PrimitiveTypeCode.Int32 or PrimitiveTypeCode.UInt32 => 4,
PrimitiveTypeCode.Int64 or PrimitiveTypeCode.UInt64 => 8,
PrimitiveTypeCode.IntPtr or PrimitiveTypeCode.UIntPtr => 8, // Assume this -- guessing high isn't a problem for our use case.
_ => throw new NotSupportedException($"Unsupported primitive type code: {code}"),
};

private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle, Context context)
{
TypeDefinition typeDef = this.Reader.GetTypeDefinition(typeDefHandle);
Expand Down Expand Up @@ -130,6 +140,120 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
}

members.Add(field);

foreach (CustomAttribute bitfieldAttribute in MetadataUtilities.FindAttributes(this.Reader, fieldDef.GetCustomAttributes(), InteropDecorationNamespace, NativeBitfieldAttribute))
{
var fieldTypeInfo = (PrimitiveTypeHandleInfo)fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null);

CustomAttributeValue<TypeSyntax> decodedAttribute = bitfieldAttribute.DecodeValue(CustomAttributeTypeProvider.Instance);
string propName = (string)decodedAttribute.FixedArguments[0].Value!;
byte propOffset = (byte)(long)decodedAttribute.FixedArguments[1].Value!;
byte propLength = (byte)(long)decodedAttribute.FixedArguments[2].Value!;
if (propLength == 0)
{
// D3DKMDT_DISPLAYMODE_FLAGS has an "Anonymous" 0-length bitfield,
// but that's totally useless and breaks our math later on, so skip it.
continue;
}

TypeSyntax propertyType = propLength switch
{
1 => PredefinedType(Token(SyntaxKind.BoolKeyword)),
<= 8 => PredefinedType(Token(SyntaxKind.ByteKeyword)),
<= 16 => PredefinedType(Token(SyntaxKind.UShortKeyword)),
<= 32 => PredefinedType(Token(SyntaxKind.UIntKeyword)),
<= 64 => PredefinedType(Token(SyntaxKind.ULongKeyword)),
_ => throw new NotSupportedException(),
};

AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration);
AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration);

ulong maskNoOffset = (1UL << propLength) - 1;
ulong mask = maskNoOffset << propOffset;
int fieldLengthInHexChars = GetLengthInBytes(fieldTypeInfo.PrimitiveTypeCode) * 2;
LiteralExpressionSyntax maskExpr = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(mask, fieldLengthInHexChars), mask));

ExpressionSyntax fieldAccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName(fieldName));
TypeSyntax fieldType = field.Declaration.Type.WithoutTrailingTrivia();

//// unchecked((int)~mask)
ExpressionSyntax notMask = UncheckedExpression(CastExpression(fieldType, PrefixUnaryExpression(SyntaxKind.BitwiseNotExpression, maskExpr)));
//// (field & unchecked((int)~mask))
ExpressionSyntax fieldAndNotMask = ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, fieldAccess, notMask));

if (propLength > 1)
{
LiteralExpressionSyntax maskNoOffsetExpr = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(maskNoOffset, fieldLengthInHexChars), maskNoOffset));
ExpressionSyntax notMaskNoOffset = UncheckedExpression(CastExpression(propertyType, PrefixUnaryExpression(SyntaxKind.BitwiseNotExpression, maskNoOffsetExpr)));
LiteralExpressionSyntax propOffsetExpr = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(propOffset));

// get => (byte)((field & unchecked((FIELDTYPE)getterMask)) >> propOffset);
ExpressionSyntax getterExpression =
CastExpression(propertyType, ParenthesizedExpression(BinaryExpression(
SyntaxKind.RightShiftExpression,
ParenthesizedExpression(BinaryExpression(
SyntaxKind.BitwiseAndExpression,
fieldAccess,
UncheckedExpression(CastExpression(fieldType, maskExpr)))),
propOffsetExpr)));
getter = getter
.WithExpressionBody(ArrowExpressionClause(getterExpression))
.WithSemicolonToken(SemicolonWithLineFeed);

// if ((value & ~maskNoOffset) != 0) throw new ArgumentOutOfRangeException(nameof(value));
// field = (int)((field & unchecked((int)~mask)) | ((int)value << propOffset)));
IdentifierNameSyntax valueName = IdentifierName("value");
setter = setter.WithBody(Block().AddStatements(
IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, valueName, notMaskNoOffset)), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))).AddArgumentListArguments(Argument(InvocationExpression(IdentifierName("nameof")).WithArgumentList(ArgumentList().AddArguments(Argument(valueName))))))),
ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
fieldAccess,
CastExpression(fieldType, ParenthesizedExpression(
BinaryExpression(
SyntaxKind.BitwiseOrExpression,
//// (field & unchecked((int)~mask))
fieldAndNotMask,
//// ((int)value << propOffset)
ParenthesizedExpression(BinaryExpression(SyntaxKind.LeftShiftExpression, CastExpression(fieldType, valueName), propOffsetExpr)))))))));
}
else
{
// get => (field & getterMask) != 0;
getter = getter
.WithExpressionBody(ArrowExpressionClause(BinaryExpression(
SyntaxKind.NotEqualsExpression,
ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, fieldAccess, maskExpr)),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))))
.WithSemicolonToken(SemicolonWithLineFeed);

// set => field = (byte)(value ? field | getterMask : field & unchecked((int)~getterMask));
setter = setter
.WithExpressionBody(ArrowExpressionClause(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
fieldAccess,
CastExpression(
fieldType,
ParenthesizedExpression(
ConditionalExpression(
IdentifierName("value"),
BinaryExpression(SyntaxKind.BitwiseOrExpression, fieldAccess, maskExpr),
fieldAndNotMask))))))
.WithSemicolonToken(SemicolonWithLineFeed);
}

string bitDescription = propLength == 1 ? $"bit {propOffset}" : $"bits {propOffset}-{propOffset + propLength - 1}";

PropertyDeclarationSyntax bitfieldProperty = PropertyDeclaration(propertyType.WithTrailingTrivia(Space), Identifier(propName).WithTrailingTrivia(LineFeed))
.AddModifiers(TokenWithSpace(this.Visibility))
.WithAccessorList(AccessorList().AddAccessors(getter, setter))
.WithLeadingTrivia(ParseLeadingTrivia($"/// <summary>Gets or sets {bitDescription} in the <see cref=\"{fieldName}\" /> field.</summary>\n"));

members.Add(bitfieldProperty);
}
}
catch (Exception ex)
{
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,11 @@ internal static ArrayCreationExpressionSyntax NewByteArray(ReadOnlySpan<byte> by
InitializerExpression(SyntaxKind.ArrayInitializerExpression, SeparatedList(elements)));
}

internal static unsafe string ToHex<T>(T value)
internal static unsafe string ToHex<T>(T value, int? hexLength = null)
where T : unmanaged
{
int fullHexLength = sizeof(T) * 2;
string hex = string.Format(CultureInfo.InvariantCulture, "0x{0:X" + fullHexLength + "}", value);
hexLength ??= sizeof(T) * 2;
string hex = string.Format(CultureInfo.InvariantCulture, "0x{0:X" + hexLength + "}", value);
return hex;
}

Expand Down
44 changes: 44 additions & 0 deletions test/GenerationSandbox.Tests/BitFieldTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Windows.Win32.Devices.Usb;
using Windows.Win32.UI.Shell;

public class BitFieldTests
{
[Fact]
public void Bool()
{
SHELLFLAGSTATE s = default;
Assert.False(s.fNoConfirmRecycle);
s.fNoConfirmRecycle = true;
Assert.Equal(0b100, s._bitfield);
Assert.True(s.fNoConfirmRecycle);

s._bitfield = unchecked((int)0xffffffff);
Assert.True(s.fNoConfirmRecycle);
s.fNoConfirmRecycle = false;
Assert.Equal(unchecked((int)0xfffffffb), s._bitfield);
}

[Fact]
public void ThrowWhenSetValueIsOutOfBounds()
{
BM_REQUEST_TYPE._BM s = default;
Assert.Throws<ArgumentOutOfRangeException>(() => s.Type = 0b100);
}

[Fact]
public void SetValueMultiBit()
{
BM_REQUEST_TYPE._BM s = default;
s.Type = 0b11;
Assert.Equal(0b1100000, s._bitfield);
Assert.Equal(0b11, s.Type);

s._bitfield = 0xff;
Assert.Equal(0b11, s.Type);
s.Type = 0;
Assert.Equal(0b10011111, s._bitfield);
}
}
2 changes: 2 additions & 0 deletions test/GenerationSandbox.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
BCRYPT_KEY_HANDLE
BM_REQUEST_TYPE
BOOL
BOOLEAN
CHAR
Expand Down Expand Up @@ -37,6 +38,7 @@ RegLoadAppKey
RM_PROCESS_INFO
S_OK
SHDESCRIPTIONID
SHELLFLAGSTATE
ShellLink
ShellWindowFindWindowOptions
ShellWindows
Expand Down
68 changes: 68 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/StructTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,71 @@ public void FieldWithAssociatedEnum()
Assert.Equal("SHDID_ID", Assert.IsType<IdentifierNameSyntax>(property.Type).Identifier.ValueText);
}

[Fact]
public void Bitfield_Bool()
{
this.GenerateApi("SHELLFLAGSTATE");
StructDeclarationSyntax structDecl = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType("SHELLFLAGSTATE"));

// Verify that the struct has a single field of type int.
FieldDeclarationSyntax bitfield = Assert.Single(structDecl.Members.OfType<FieldDeclarationSyntax>());
Assert.True(bitfield.Declaration.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.IntKeyword } });

// Verify that many other *properties* are added that access into the bitfield.
// The actual behavior of the properties is verified in the functional unit tests.
List<PropertyDeclarationSyntax> properties = structDecl.Members.OfType<PropertyDeclarationSyntax>().ToList();
Assert.Contains(properties, p => p.Identifier.ValueText == "fShowAllObjects" && p.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.BoolKeyword } });
Assert.Contains(properties, p => p.Identifier.ValueText == "fShowExtensions" && p.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.BoolKeyword } });
}

[Fact]
public void Bitfield_UIntPtr()
{
this.GenerateApi("PSAPI_WORKING_SET_BLOCK");
StructDeclarationSyntax structDecl = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType("_Anonymous_e__Struct"));

// Verify that the struct has a single field of type int.
FieldDeclarationSyntax bitfield = Assert.Single(structDecl.Members.OfType<FieldDeclarationSyntax>());
Assert.True(bitfield.Declaration.Type is IdentifierNameSyntax { Identifier.ValueText: "nuint" });

// Verify that many other *properties* are added that access into the bitfield.
// The actual behavior of the properties is verified in the functional unit tests.
List<PropertyDeclarationSyntax> properties = structDecl.Members.OfType<PropertyDeclarationSyntax>().ToList();
Assert.Contains(properties, p => p.Identifier.ValueText == "Protection" && p.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.ByteKeyword } });
Assert.Contains(properties, p => p.Identifier.ValueText == "Shared" && p.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.BoolKeyword } });
}

[Fact]
public void Bitfield_Multiple()
{
this.GenerateApi("AM_COLCON");
StructDeclarationSyntax structDecl = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType("AM_COLCON"));
Assert.Equal(4, structDecl.Members.OfType<FieldDeclarationSyntax>().Count());

// Verify that each field produced 2 properties of type byte.
List<PropertyDeclarationSyntax> properties = structDecl.Members.OfType<PropertyDeclarationSyntax>().ToList();
Assert.Equal(4 * 2, properties.Count);
Assert.All(properties, p => Assert.True(p.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.ByteKeyword } }));
Assert.Contains(properties, p => p.Identifier.ValueText == "emph1col");
Assert.Contains(properties, p => p.Identifier.ValueText == "patcon");
}

[Fact]
public void Bitfield_MultiplePropertyTypes()
{
this.GenerateApi("BM_REQUEST_TYPE");
StructDeclarationSyntax structDecl = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType("_BM"));

List<PropertyDeclarationSyntax> fields = structDecl.Members.OfType<PropertyDeclarationSyntax>().ToList();
Assert.Equal(4, fields.Count);
Assert.Equal(SyntaxKind.ByteKeyword, GetPropertyType("Recipient"));
Assert.Equal(SyntaxKind.ByteKeyword, GetPropertyType("Reserved"));
Assert.Equal(SyntaxKind.ByteKeyword, GetPropertyType("Type"));
Assert.Equal(SyntaxKind.BoolKeyword, GetPropertyType("Dir"));

SyntaxKind GetPropertyType(string name) => ((PredefinedTypeSyntax)fields.Single(f => f.Identifier.ValueText == name).Type).Keyword.Kind();
}

[Theory]
[InlineData("PCSTR")]
public void SpecialStruct_ByRequest(string structName)
Expand All @@ -178,6 +243,9 @@ public void SpecialStruct_ByRequest(string structName)
"DEVICE_RELATIONS", // ends with an inline "flexible" array
"D3DHAL_CONTEXTCREATEDATA", // contains a field that is a pointer to a struct that is normally managed
"MIB_TCPTABLE", // a struct that references another struct with a nested anonymous type, that loosely references an enum in the same namespace (by way of an attribute).
"WHEA_XPF_TLB_CHECK", // a struct with a ulong bitfield with one field exceeding 32-bits in length.
"TRANSPORT_PROPERTIES", // a struct with a long bitfield with one subfield expressed as ulong.
"D3DKMDT_DISPLAYMODE_FLAGS", // a struct with an interesting bool/byte conversion.
"WSD_EVENT")] // has a pointer field to a managed struct
string name,
bool allowMarshaling)
Expand Down

0 comments on commit 555bab9

Please sign in to comment.