diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index 63f06e5b..117772a7 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -390,6 +390,10 @@ internal static SeparatedSyntaxList SeparatedList() internal static IsPatternExpressionSyntax IsPatternExpression(ExpressionSyntax expression, PatternSyntax pattern) => SyntaxFactory.IsPatternExpression(expression, Token(TriviaList(Space), SyntaxKind.IsKeyword, TriviaList(Space)), pattern); + internal static BinaryPatternSyntax BinaryPattern(SyntaxKind kind, PatternSyntax left, PatternSyntax right) => SyntaxFactory.BinaryPattern(kind, left, TokenWithSpaces(GetBinaryPatternOperatorTokenKind(kind)), right); + + internal static RelationalPatternSyntax RelationalPattern(SyntaxToken operatorToken, ExpressionSyntax expression) => SyntaxFactory.RelationalPattern(operatorToken, expression); + internal static ConditionalExpressionSyntax ConditionalExpression(ExpressionSyntax condition, ExpressionSyntax whenTrue, ExpressionSyntax whenFalse) => SyntaxFactory.ConditionalExpression(condition, Token(TriviaList(Space), SyntaxKind.QuestionToken, TriviaList(Space)), whenTrue, Token(TriviaList(Space), SyntaxKind.ColonToken, TriviaList(Space)), whenFalse); internal static IfStatementSyntax IfStatement(ExpressionSyntax condition, StatementSyntax whenTrue) => IfStatement(condition, whenTrue, null); @@ -595,6 +599,14 @@ private static SyntaxKind GetLiteralExpressionTokenKind(SyntaxKind kind) }; } + private static SyntaxKind GetBinaryPatternOperatorTokenKind(SyntaxKind kind) + => kind switch + { + SyntaxKind.OrPattern => SyntaxKind.OrKeyword, + SyntaxKind.AndPattern => SyntaxKind.AndKeyword, + _ => throw new ArgumentOutOfRangeException(), + }; + private static SyntaxToken XmlReplaceBracketTokens(SyntaxToken originalToken, SyntaxToken rewrittenToken) { if (rewrittenToken.IsKind(SyntaxKind.LessThanToken) && string.Equals("<", rewrittenToken.Text, StringComparison.Ordinal)) diff --git a/src/Microsoft.Windows.CsWin32/Generator.Struct.cs b/src/Microsoft.Windows.CsWin32/Generator.Struct.cs index bd6d14eb..3f9ed349 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Struct.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Struct.cs @@ -161,6 +161,20 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle var fieldTypeInfo = (PrimitiveTypeHandleInfo)fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null); CustomAttributeValue decodedAttribute = bitfieldAttribute.DecodeValue(CustomAttributeTypeProvider.Instance); + (int? fieldBitLength, bool signed) = fieldTypeInfo.PrimitiveTypeCode switch + { + PrimitiveTypeCode.Byte => (8, false), + PrimitiveTypeCode.SByte => (8, true), + PrimitiveTypeCode.UInt16 => (16, false), + PrimitiveTypeCode.Int16 => (16, true), + PrimitiveTypeCode.UInt32 => (32, false), + PrimitiveTypeCode.Int32 => (32, true), + PrimitiveTypeCode.UInt64 => (64, false), + PrimitiveTypeCode.Int64 => (64, true), + PrimitiveTypeCode.UIntPtr => (null, false), + PrimitiveTypeCode.IntPtr => ((int?)null, true), + _ => throw new NotImplementedException(), + }; string propName = (string)decodedAttribute.FixedArguments[0].Value!; byte propOffset = (byte)(long)decodedAttribute.FixedArguments[1].Value!; byte propLength = (byte)(long)decodedAttribute.FixedArguments[2].Value!; @@ -171,18 +185,25 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle continue; } - TypeSyntax propertyType = propLength switch + long minValue = signed ? -(1L << (propLength - 1)) : 0; + long maxValue = (1L << (propLength - (signed ? 1 : 0))) - 1; + int? leftPad = fieldBitLength.HasValue ? fieldBitLength - (propOffset + propLength) : null; + int rightPad = propOffset; + (TypeSyntax propertyType, int propertyBitLength) = propLength switch { - 1 => PredefinedType(Token(SyntaxKind.BoolKeyword)), - <= 8 => PredefinedType(Token(SyntaxKind.ByteKeyword)), - <= 16 => PredefinedType(Token(SyntaxKind.UShortKeyword)), - <= 32 => PredefinedType(Token(SyntaxKind.UIntKeyword)), - <= 64 => PredefinedType(Token(SyntaxKind.ULongKeyword)), + 1 => (PredefinedType(Token(SyntaxKind.BoolKeyword)), 1), + <= 8 => (PredefinedType(Token(signed ? SyntaxKind.SByteKeyword : SyntaxKind.ByteKeyword)), 8), + <= 16 => (PredefinedType(Token(signed ? SyntaxKind.ShortKeyword : SyntaxKind.UShortKeyword)), 16), + <= 32 => (PredefinedType(Token(signed ? SyntaxKind.IntKeyword : SyntaxKind.UIntKeyword)), 32), + <= 64 => (PredefinedType(Token(signed ? SyntaxKind.LongKeyword : SyntaxKind.ULongKeyword)), 64), _ => throw new NotSupportedException(), }; - AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration); - AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration); + AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration) + .AddModifiers(TokenWithSpace(SyntaxKind.ReadOnlyKeyword)) + .AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining))); + AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration) + .AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining))); ulong maskNoOffset = (1UL << propLength) - 1; ulong mask = maskNoOffset << propOffset; @@ -203,36 +224,61 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle ExpressionSyntax notMaskNoOffset = UncheckedExpression(CastExpression(propertyType, PrefixUnaryExpression(SyntaxKind.BitwiseNotExpression, maskNoOffsetExpr))); LiteralExpressionSyntax propOffsetExpr = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(propOffset)); - // get => (byte)((field & unchecked((FIELDTYPE)getterMask)) >> propOffset); + // signed: + // get => (byte)((field << leftPad) >> (leftPad + rightPad))); + // unsigned: + // get => (byte)((field >> rightPad) & maskNoOffset); ExpressionSyntax getterExpression = - CastExpression(propertyType, ParenthesizedExpression(BinaryExpression( - SyntaxKind.RightShiftExpression, - ParenthesizedExpression(BinaryExpression( - SyntaxKind.BitwiseAndExpression, - fieldAccess, - UncheckedExpression(CastExpression(fieldType, maskExpr)))), - propOffsetExpr))); + CastExpression(propertyType, ParenthesizedExpression( + signed ? + BinaryExpression( + SyntaxKind.RightShiftExpression, + ParenthesizedExpression(BinaryExpression( + SyntaxKind.LeftShiftExpression, + fieldAccess, + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(leftPad!.Value)))), + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(leftPad.Value + rightPad))) + : BinaryExpression( + SyntaxKind.BitwiseAndExpression, + ParenthesizedExpression(BinaryExpression(SyntaxKind.RightShiftExpression, fieldAccess, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(rightPad)))), + maskNoOffsetExpr))); getter = getter .WithExpressionBody(ArrowExpressionClause(getterExpression)) .WithSemicolonToken(SemicolonWithLineFeed); - // if ((value & ~maskNoOffset) != 0) throw new ArgumentOutOfRangeException(nameof(value)); - // field = (int)((field & unchecked((int)~mask)) | ((int)value << propOffset))); IdentifierNameSyntax valueName = IdentifierName("value"); - setter = setter.WithBody(Block().AddStatements( - IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, valueName, notMaskNoOffset)), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))), - ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))).AddArgumentListArguments(Argument(InvocationExpression(IdentifierName("nameof")).WithArgumentList(ArgumentList().AddArguments(Argument(valueName))))))), - ExpressionStatement(AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - fieldAccess, - CastExpression(fieldType, ParenthesizedExpression( - BinaryExpression( - SyntaxKind.BitwiseOrExpression, - //// (field & unchecked((int)~mask)) - fieldAndNotMask, - //// ((int)value << propOffset) - ParenthesizedExpression(BinaryExpression(SyntaxKind.LeftShiftExpression, CastExpression(fieldType, valueName), propOffsetExpr))))))))); + + List setterStatements = new(); + if (propertyBitLength > propLength) + { + // The allowed range is smaller than the property type, so we need to check that the value fits. + // signed: + // global::System.Debug.Assert(value is >= minValue and <= maxValue); + // unsigned: + // global::System.Debug.Assert(value is <= maxValue); + RelationalPatternSyntax max = RelationalPattern(TokenWithSpace(SyntaxKind.LessThanEqualsToken), CastExpression(propertyType, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(maxValue)))); + RelationalPatternSyntax? min = signed ? RelationalPattern(TokenWithSpace(SyntaxKind.GreaterThanEqualsToken), CastExpression(propertyType, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(minValue)))) : null; + setterStatements.Add(ExpressionStatement(InvocationExpression( + ParseName("global::System.Diagnostics.Debug.Assert"), + ArgumentList().AddArguments(Argument( + IsPatternExpression( + valueName, + min is null ? max : BinaryPattern(SyntaxKind.AndPattern, min, max))))))); + } + + // field = (int)((field & unchecked((int)~mask)) | ((int)(value & mask) << propOffset))); + ExpressionSyntax valueAndMaskNoOffset = ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, valueName, maskNoOffsetExpr)); + setterStatements.Add(ExpressionStatement(AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + fieldAccess, + CastExpression(fieldType, ParenthesizedExpression( + BinaryExpression( + SyntaxKind.BitwiseOrExpression, + //// (field & unchecked((int)~mask)) + fieldAndNotMask, + //// ((int)(value & mask) << propOffset) + ParenthesizedExpression(BinaryExpression(SyntaxKind.LeftShiftExpression, CastExpression(fieldType, valueAndMaskNoOffset), propOffsetExpr)))))))); + setter = setter.WithBody(Block().AddStatements(setterStatements.ToArray())); } else { @@ -261,11 +307,12 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle } string bitDescription = propLength == 1 ? $"bit {propOffset}" : $"bits {propOffset}-{propOffset + propLength - 1}"; + string allowedRange = propLength == 1 ? string.Empty : $" Allowed values are [{minValue}..{maxValue}]."; PropertyDeclarationSyntax bitfieldProperty = PropertyDeclaration(propertyType.WithTrailingTrivia(Space), Identifier(propName).WithTrailingTrivia(LineFeed)) .AddModifiers(TokenWithSpace(this.Visibility)) .WithAccessorList(AccessorList().AddAccessors(getter, setter)) - .WithLeadingTrivia(ParseLeadingTrivia($"/// Gets or sets {bitDescription} in the field.\n")); + .WithLeadingTrivia(ParseLeadingTrivia($"/// Gets or sets {bitDescription} in the field.{allowedRange}\n")); members.Add(bitfieldProperty); } diff --git a/test/GenerationSandbox.Tests/BitFieldTests.cs b/test/GenerationSandbox.Tests/BitFieldTests.cs index 59419466..14f3a9ec 100644 --- a/test/GenerationSandbox.Tests/BitFieldTests.cs +++ b/test/GenerationSandbox.Tests/BitFieldTests.cs @@ -3,6 +3,7 @@ using Windows.Win32.Devices.Usb; using Windows.Win32.UI.Shell; +using Windows.Win32.UI.TabletPC; public class BitFieldTests { @@ -21,13 +22,28 @@ public void Bool() Assert.Equal(unchecked((int)0xfffffffb), s._bitfield); } +#if DEBUG [Fact] public void ThrowWhenSetValueIsOutOfBounds() { BM_REQUEST_TYPE._BM s = default; - Assert.Throws(() => s.Type = 0b100); + TestUtils.AssertDebugAssertFailed(() => s.Type = 0b100); } + [Fact] + public void ThrowWhenSetValueIsOutOfBounds_Signed() + { + FLICK_DATA s = default; + + // Assert after each invalid set that what ended up being set did not exceed the bounds of the bitfield. + TestUtils.AssertDebugAssertFailed(() => s.iFlickDirection = -5); + Assert.Equal(0, s._bitfield & ~0xe0); + + TestUtils.AssertDebugAssertFailed(() => s.iFlickDirection = 4); + Assert.Equal(0, s._bitfield & ~0xe0); + } +#endif + [Fact] public void SetValueMultiBit() { @@ -41,4 +57,53 @@ public void SetValueMultiBit() s.Type = 0; Assert.Equal(0b10011111, s._bitfield); } + + [Fact] + public void SignedField() + { + FLICK_DATA s = default; + + // iFlickDirection: 3 bits => range -4..3 + const int mask = 0b111_00000; + s.iFlickDirection = -1; + Assert.Equal(0b111_00000, s._bitfield); + Assert.Equal(-1, s.iFlickDirection); + + s.iFlickDirection = 1; + Assert.Equal(0b001_00000, s._bitfield); + Assert.Equal(1, s.iFlickDirection); + + int oldFieldValue = s._bitfield; + for (sbyte i = -4; i <= 3; i++) + { + // Assert that a valid value is retained via the property. + s.iFlickDirection = i; + Assert.Equal(i, s.iFlickDirection); + + // Assert that no other bits were touched. + Assert.Equal(oldFieldValue & ~mask, s._bitfield & ~mask); + } + + // Repeat the test, but with all 1s in other locations. + s._bitfield = unchecked((int)0xffffffff); + oldFieldValue = s._bitfield; + for (sbyte i = -4; i <= 3; i++) + { + // Assert that a valid value is retained via the property. + s.iFlickDirection = i; + Assert.Equal(i, s.iFlickDirection); + + // Assert that no other bits were touched. + Assert.Equal(oldFieldValue & ~mask, s._bitfield & ~mask); + } + } + + [Fact] + public void SignedField_HasBoolFor1Bit() + { + FLICK_DATA s = default; + Assert.False(s.fMenuModifier); + s.fMenuModifier = true; + Assert.Equal(0b10_0000_0000, s._bitfield); + } } diff --git a/test/GenerationSandbox.Tests/NativeMethods.txt b/test/GenerationSandbox.Tests/NativeMethods.txt index 08e79ead..47b4c279 100644 --- a/test/GenerationSandbox.Tests/NativeMethods.txt +++ b/test/GenerationSandbox.Tests/NativeMethods.txt @@ -9,6 +9,7 @@ CSIDL_DESKTOP DISPLAYCONFIG_VIDEO_SIGNAL_INFO EnumWindows FILE_ACCESS_RIGHTS +FLICK_DATA GetProcAddress GetTickCount GetWindowText diff --git a/test/GenerationSandbox.Tests/TestUtils.cs b/test/GenerationSandbox.Tests/TestUtils.cs new file mode 100644 index 00000000..8498e96c --- /dev/null +++ b/test/GenerationSandbox.Tests/TestUtils.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics; + +internal static class TestUtils +{ +#if DEBUG // Only tests that are conditioned for Debug mode can assert this. + internal static void AssertDebugAssertFailed(Action action) + { + // We're mutating a static collection. + // Protect against concurrent tests mutating the collection while we're using it. + lock (Trace.Listeners) + { + TraceListener[] listeners = Trace.Listeners.Cast().ToArray(); + Trace.Listeners.Clear(); + Trace.Listeners.Add(new ThrowingTraceListener()); + + try + { + action(); + Assert.Fail("Expected Debug.Assert to fail."); + } + catch (DebugAssertFailedException) + { + // PASS + } + finally + { + Trace.Listeners.Clear(); + Trace.Listeners.AddRange(listeners); + } + } + } +#endif + + private class DebugAssertFailedException : Exception + { + } + + private class ThrowingTraceListener : TraceListener + { + public override void Fail(string? message) => throw new DebugAssertFailedException(); + + public override void Fail(string? message, string? detailMessage) => throw new DebugAssertFailedException(); + + public override void Write(string? message) + { + } + + public override void WriteLine(string? message) + { + } + } +}