Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of signed bit fields #1126

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ internal static SeparatedSyntaxList<TNode> SeparatedList<TNode>()

internal static IsPatternExpressionSyntax IsPatternExpression(ExpressionSyntax expression, PatternSyntax pattern) => SyntaxFactory.IsPatternExpression(expression, Token(TriviaList(Space), SyntaxKind.IsKeyword, TriviaList(Space)), pattern);

internal static BinaryPatternSyntax BinaryPattern(SyntaxKind kind, PatternSyntax left, PatternSyntax right) => SyntaxFactory.BinaryPattern(kind, left, TokenWithSpaces(GetBinaryPatternOperatorTokenKind(kind)), right);

internal static RelationalPatternSyntax RelationalPattern(SyntaxToken operatorToken, ExpressionSyntax expression) => SyntaxFactory.RelationalPattern(operatorToken, expression);

internal static ConditionalExpressionSyntax ConditionalExpression(ExpressionSyntax condition, ExpressionSyntax whenTrue, ExpressionSyntax whenFalse) => SyntaxFactory.ConditionalExpression(condition, Token(TriviaList(Space), SyntaxKind.QuestionToken, TriviaList(Space)), whenTrue, Token(TriviaList(Space), SyntaxKind.ColonToken, TriviaList(Space)), whenFalse);

internal static IfStatementSyntax IfStatement(ExpressionSyntax condition, StatementSyntax whenTrue) => IfStatement(condition, whenTrue, null);
Expand Down Expand Up @@ -595,6 +599,14 @@ private static SyntaxKind GetLiteralExpressionTokenKind(SyntaxKind kind)
};
}

private static SyntaxKind GetBinaryPatternOperatorTokenKind(SyntaxKind kind)
=> kind switch
{
SyntaxKind.OrPattern => SyntaxKind.OrKeyword,
SyntaxKind.AndPattern => SyntaxKind.AndKeyword,
_ => throw new ArgumentOutOfRangeException(),
};

private static SyntaxToken XmlReplaceBracketTokens(SyntaxToken originalToken, SyntaxToken rewrittenToken)
{
if (rewrittenToken.IsKind(SyntaxKind.LessThanToken) && string.Equals("<", rewrittenToken.Text, StringComparison.Ordinal))
Expand Down
113 changes: 80 additions & 33 deletions src/Microsoft.Windows.CsWin32/Generator.Struct.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,20 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
var fieldTypeInfo = (PrimitiveTypeHandleInfo)fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null);

CustomAttributeValue<TypeSyntax> decodedAttribute = bitfieldAttribute.DecodeValue(CustomAttributeTypeProvider.Instance);
(int? fieldBitLength, bool signed) = fieldTypeInfo.PrimitiveTypeCode switch
{
PrimitiveTypeCode.Byte => (8, false),
PrimitiveTypeCode.SByte => (8, true),
PrimitiveTypeCode.UInt16 => (16, false),
PrimitiveTypeCode.Int16 => (16, true),
PrimitiveTypeCode.UInt32 => (32, false),
PrimitiveTypeCode.Int32 => (32, true),
PrimitiveTypeCode.UInt64 => (64, false),
PrimitiveTypeCode.Int64 => (64, true),
PrimitiveTypeCode.UIntPtr => (null, false),
PrimitiveTypeCode.IntPtr => ((int?)null, true),
_ => throw new NotImplementedException(),
};
string propName = (string)decodedAttribute.FixedArguments[0].Value!;
byte propOffset = (byte)(long)decodedAttribute.FixedArguments[1].Value!;
byte propLength = (byte)(long)decodedAttribute.FixedArguments[2].Value!;
Expand All @@ -171,18 +185,25 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
continue;
}

TypeSyntax propertyType = propLength switch
long minValue = signed ? -(1L << (propLength - 1)) : 0;
long maxValue = (1L << (propLength - (signed ? 1 : 0))) - 1;
int? leftPad = fieldBitLength.HasValue ? fieldBitLength - (propOffset + propLength) : null;
int rightPad = propOffset;
(TypeSyntax propertyType, int propertyBitLength) = propLength switch
{
1 => PredefinedType(Token(SyntaxKind.BoolKeyword)),
<= 8 => PredefinedType(Token(SyntaxKind.ByteKeyword)),
<= 16 => PredefinedType(Token(SyntaxKind.UShortKeyword)),
<= 32 => PredefinedType(Token(SyntaxKind.UIntKeyword)),
<= 64 => PredefinedType(Token(SyntaxKind.ULongKeyword)),
1 => (PredefinedType(Token(SyntaxKind.BoolKeyword)), 1),
<= 8 => (PredefinedType(Token(signed ? SyntaxKind.SByteKeyword : SyntaxKind.ByteKeyword)), 8),
<= 16 => (PredefinedType(Token(signed ? SyntaxKind.ShortKeyword : SyntaxKind.UShortKeyword)), 16),
<= 32 => (PredefinedType(Token(signed ? SyntaxKind.IntKeyword : SyntaxKind.UIntKeyword)), 32),
<= 64 => (PredefinedType(Token(signed ? SyntaxKind.LongKeyword : SyntaxKind.ULongKeyword)), 64),
_ => throw new NotSupportedException(),
};

AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration);
AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration);
AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
.AddModifiers(TokenWithSpace(SyntaxKind.ReadOnlyKeyword))
.AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining)));
AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration)
.AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining)));

ulong maskNoOffset = (1UL << propLength) - 1;
ulong mask = maskNoOffset << propOffset;
Expand All @@ -203,36 +224,61 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
ExpressionSyntax notMaskNoOffset = UncheckedExpression(CastExpression(propertyType, PrefixUnaryExpression(SyntaxKind.BitwiseNotExpression, maskNoOffsetExpr)));
LiteralExpressionSyntax propOffsetExpr = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(propOffset));

// get => (byte)((field & unchecked((FIELDTYPE)getterMask)) >> propOffset);
// signed:
// get => (byte)((field << leftPad) >> (leftPad + rightPad)));
// unsigned:
// get => (byte)((field >> rightPad) & maskNoOffset);
ExpressionSyntax getterExpression =
CastExpression(propertyType, ParenthesizedExpression(BinaryExpression(
SyntaxKind.RightShiftExpression,
ParenthesizedExpression(BinaryExpression(
SyntaxKind.BitwiseAndExpression,
fieldAccess,
UncheckedExpression(CastExpression(fieldType, maskExpr)))),
propOffsetExpr)));
CastExpression(propertyType, ParenthesizedExpression(
signed ?
BinaryExpression(
SyntaxKind.RightShiftExpression,
ParenthesizedExpression(BinaryExpression(
SyntaxKind.LeftShiftExpression,
fieldAccess,
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(leftPad!.Value)))),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(leftPad.Value + rightPad)))
: BinaryExpression(
SyntaxKind.BitwiseAndExpression,
ParenthesizedExpression(BinaryExpression(SyntaxKind.RightShiftExpression, fieldAccess, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(rightPad)))),
maskNoOffsetExpr)));
getter = getter
.WithExpressionBody(ArrowExpressionClause(getterExpression))
.WithSemicolonToken(SemicolonWithLineFeed);

// if ((value & ~maskNoOffset) != 0) throw new ArgumentOutOfRangeException(nameof(value));
// field = (int)((field & unchecked((int)~mask)) | ((int)value << propOffset)));
IdentifierNameSyntax valueName = IdentifierName("value");
setter = setter.WithBody(Block().AddStatements(
IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, valueName, notMaskNoOffset)), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))).AddArgumentListArguments(Argument(InvocationExpression(IdentifierName("nameof")).WithArgumentList(ArgumentList().AddArguments(Argument(valueName))))))),
ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
fieldAccess,
CastExpression(fieldType, ParenthesizedExpression(
BinaryExpression(
SyntaxKind.BitwiseOrExpression,
//// (field & unchecked((int)~mask))
fieldAndNotMask,
//// ((int)value << propOffset)
ParenthesizedExpression(BinaryExpression(SyntaxKind.LeftShiftExpression, CastExpression(fieldType, valueName), propOffsetExpr)))))))));

List<StatementSyntax> setterStatements = new();
if (propertyBitLength > propLength)
{
// The allowed range is smaller than the property type, so we need to check that the value fits.
// signed:
// global::System.Debug.Assert(value is >= minValue and <= maxValue);
// unsigned:
// global::System.Debug.Assert(value is <= maxValue);
RelationalPatternSyntax max = RelationalPattern(TokenWithSpace(SyntaxKind.LessThanEqualsToken), CastExpression(propertyType, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(maxValue))));
RelationalPatternSyntax? min = signed ? RelationalPattern(TokenWithSpace(SyntaxKind.GreaterThanEqualsToken), CastExpression(propertyType, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(minValue)))) : null;
setterStatements.Add(ExpressionStatement(InvocationExpression(
ParseName("global::System.Diagnostics.Debug.Assert"),
ArgumentList().AddArguments(Argument(
IsPatternExpression(
valueName,
min is null ? max : BinaryPattern(SyntaxKind.AndPattern, min, max)))))));
}

// field = (int)((field & unchecked((int)~mask)) | ((int)(value & mask) << propOffset)));
ExpressionSyntax valueAndMaskNoOffset = ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, valueName, maskNoOffsetExpr));
setterStatements.Add(ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
fieldAccess,
CastExpression(fieldType, ParenthesizedExpression(
BinaryExpression(
SyntaxKind.BitwiseOrExpression,
//// (field & unchecked((int)~mask))
fieldAndNotMask,
//// ((int)(value & mask) << propOffset)
ParenthesizedExpression(BinaryExpression(SyntaxKind.LeftShiftExpression, CastExpression(fieldType, valueAndMaskNoOffset), propOffsetExpr))))))));
setter = setter.WithBody(Block().AddStatements(setterStatements.ToArray()));
}
else
{
Expand Down Expand Up @@ -261,11 +307,12 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
}

string bitDescription = propLength == 1 ? $"bit {propOffset}" : $"bits {propOffset}-{propOffset + propLength - 1}";
string allowedRange = propLength == 1 ? string.Empty : $" Allowed values are [{minValue}..{maxValue}].";

PropertyDeclarationSyntax bitfieldProperty = PropertyDeclaration(propertyType.WithTrailingTrivia(Space), Identifier(propName).WithTrailingTrivia(LineFeed))
.AddModifiers(TokenWithSpace(this.Visibility))
.WithAccessorList(AccessorList().AddAccessors(getter, setter))
.WithLeadingTrivia(ParseLeadingTrivia($"/// <summary>Gets or sets {bitDescription} in the <see cref=\"{fieldName}\" /> field.</summary>\n"));
.WithLeadingTrivia(ParseLeadingTrivia($"/// <summary>Gets or sets {bitDescription} in the <see cref=\"{fieldName}\" /> field.{allowedRange}</summary>\n"));

members.Add(bitfieldProperty);
}
Expand Down
67 changes: 66 additions & 1 deletion test/GenerationSandbox.Tests/BitFieldTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Windows.Win32.Devices.Usb;
using Windows.Win32.UI.Shell;
using Windows.Win32.UI.TabletPC;

public class BitFieldTests
{
Expand All @@ -21,13 +22,28 @@ public void Bool()
Assert.Equal(unchecked((int)0xfffffffb), s._bitfield);
}

#if DEBUG
[Fact]
public void ThrowWhenSetValueIsOutOfBounds()
{
BM_REQUEST_TYPE._BM s = default;
Assert.Throws<ArgumentOutOfRangeException>(() => s.Type = 0b100);
TestUtils.AssertDebugAssertFailed(() => s.Type = 0b100);
}

[Fact]
public void ThrowWhenSetValueIsOutOfBounds_Signed()
{
FLICK_DATA s = default;

// Assert after each invalid set that what ended up being set did not exceed the bounds of the bitfield.
TestUtils.AssertDebugAssertFailed(() => s.iFlickDirection = -5);
Assert.Equal(0, s._bitfield & ~0xe0);

TestUtils.AssertDebugAssertFailed(() => s.iFlickDirection = 4);
Assert.Equal(0, s._bitfield & ~0xe0);
}
#endif

[Fact]
public void SetValueMultiBit()
{
Expand All @@ -41,4 +57,53 @@ public void SetValueMultiBit()
s.Type = 0;
Assert.Equal(0b10011111, s._bitfield);
}

[Fact]
public void SignedField()
{
FLICK_DATA s = default;

// iFlickDirection: 3 bits => range -4..3
const int mask = 0b111_00000;
s.iFlickDirection = -1;
Assert.Equal(0b111_00000, s._bitfield);
Assert.Equal(-1, s.iFlickDirection);

s.iFlickDirection = 1;
Assert.Equal(0b001_00000, s._bitfield);
Assert.Equal(1, s.iFlickDirection);

int oldFieldValue = s._bitfield;
for (sbyte i = -4; i <= 3; i++)
{
// Assert that a valid value is retained via the property.
s.iFlickDirection = i;
Assert.Equal(i, s.iFlickDirection);

// Assert that no other bits were touched.
Assert.Equal(oldFieldValue & ~mask, s._bitfield & ~mask);
}

// Repeat the test, but with all 1s in other locations.
s._bitfield = unchecked((int)0xffffffff);
oldFieldValue = s._bitfield;
for (sbyte i = -4; i <= 3; i++)
{
// Assert that a valid value is retained via the property.
s.iFlickDirection = i;
Assert.Equal(i, s.iFlickDirection);

// Assert that no other bits were touched.
Assert.Equal(oldFieldValue & ~mask, s._bitfield & ~mask);
}
}

[Fact]
public void SignedField_HasBoolFor1Bit()
{
FLICK_DATA s = default;
Assert.False(s.fMenuModifier);
s.fMenuModifier = true;
Assert.Equal(0b10_0000_0000, s._bitfield);
}
}
1 change: 1 addition & 0 deletions test/GenerationSandbox.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CSIDL_DESKTOP
DISPLAYCONFIG_VIDEO_SIGNAL_INFO
EnumWindows
FILE_ACCESS_RIGHTS
FLICK_DATA
GetProcAddress
GetTickCount
GetWindowText
Expand Down
55 changes: 55 additions & 0 deletions test/GenerationSandbox.Tests/TestUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Diagnostics;

internal static class TestUtils
{
#if DEBUG // Only tests that are conditioned for Debug mode can assert this.
internal static void AssertDebugAssertFailed(Action action)
{
// We're mutating a static collection.
// Protect against concurrent tests mutating the collection while we're using it.
lock (Trace.Listeners)
{
TraceListener[] listeners = Trace.Listeners.Cast<TraceListener>().ToArray();
Trace.Listeners.Clear();
Trace.Listeners.Add(new ThrowingTraceListener());

try
{
action();
Assert.Fail("Expected Debug.Assert to fail.");
}
catch (DebugAssertFailedException)
{
// PASS
}
finally
{
Trace.Listeners.Clear();
Trace.Listeners.AddRange(listeners);
}
}
}
#endif

private class DebugAssertFailedException : Exception
{
}

private class ThrowingTraceListener : TraceListener
{
public override void Fail(string? message) => throw new DebugAssertFailedException();

public override void Fail(string? message, string? detailMessage) => throw new DebugAssertFailedException();

public override void Write(string? message)
{
}

public override void WriteLine(string? message)
{
}
}
}