Skip to content

Commit

Permalink
Merge pull request #727 from microsoft/fix719
Browse files Browse the repository at this point in the history
Generate decimal/DECIMAL converters
  • Loading branch information
AArnott committed Oct 13, 2022
2 parents ea3d883 + 1e1abee commit 2cceaa6
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ internal static SyntaxToken XmlTextNewLine(string text, bool continueXmlDocument

internal static MethodDeclarationSyntax MethodDeclaration(SyntaxList<AttributeListSyntax> attributeLists, SyntaxTokenList modifiers, TypeSyntax returnType, ExplicitInterfaceSpecifierSyntax explicitInterfaceSpecifier, SyntaxToken identifier, TypeParameterListSyntax typeParameterList, ParameterListSyntax parameterList, SyntaxList<TypeParameterConstraintClauseSyntax> constraintClauses, BlockSyntax body, SyntaxToken semicolonToken) => SyntaxFactory.MethodDeclaration(attributeLists, modifiers, returnType.WithTrailingTrivia(TriviaList(Space)), explicitInterfaceSpecifier, identifier, typeParameterList, parameterList, constraintClauses, body, semicolonToken);

internal static MemberDeclarationSyntax? ParseMemberDeclaration(string text) => SyntaxFactory.ParseMemberDeclaration(text);
internal static MemberDeclarationSyntax? ParseMemberDeclaration(string text, ParseOptions? options) => SyntaxFactory.ParseMemberDeclaration(text, options: options);

internal static SingleVariableDesignationSyntax SingleVariableDesignation(SyntaxToken identifier) => SyntaxFactory.SingleVariableDesignation(identifier);

Expand Down
30 changes: 27 additions & 3 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2563,7 +2563,7 @@ private static NativeArrayInfo DecodeNativeArrayInfoAttribute(CustomAttribute na
return sr.ReadToEnd().Replace("\r\n", "\n").Replace("\t", string.Empty);
}

private static bool TryFetchTemplate(string name, Generator? visibilityModifier, [NotNullWhen(true)] out MemberDeclarationSyntax? member)
private static bool TryFetchTemplate(string name, Generator? generator, [NotNullWhen(true)] out MemberDeclarationSyntax? member)
{
string? template = FetchTemplateText(name);
if (template == null)
Expand All @@ -2572,8 +2572,15 @@ private static bool TryFetchTemplate(string name, Generator? visibilityModifier,
return false;
}

member = ParseMemberDeclaration(template) ?? throw new GenerationFailedException($"Unable to parse a type from a template: {name}");
member = visibilityModifier?.ElevateVisibility(member) ?? member;
member = ParseMemberDeclaration(template, generator?.parseOptions) ?? throw new GenerationFailedException($"Unable to parse a type from a template: {name}");

// Strip out #if/#else/#endif trivia, which was already evaluated with the parse options we passed in.
if (generator?.parseOptions is not null)
{
member = (MemberDeclarationSyntax)member.Accept(DirectiveTriviaRemover.Instance)!;
}

member = generator?.ElevateVisibility(member) ?? member;
return true;
}

Expand Down Expand Up @@ -3966,6 +3973,7 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
case "RECT":
case "SIZE":
case "SYSTEMTIME":
case "DECIMAL":
members.AddRange(this.ExtractMembersFromTemplate(name.Identifier.ValueText));
break;
default:
Expand Down Expand Up @@ -6852,6 +6860,22 @@ internal Grouping(TKey key, IEnumerable<TElement> values)
}
}

private class DirectiveTriviaRemover : CSharpSyntaxRewriter
{
internal static readonly DirectiveTriviaRemover Instance = new();

private DirectiveTriviaRemover()
{
}

public override SyntaxTrivia VisitTrivia(SyntaxTrivia trivia) =>
trivia.IsKind(SyntaxKind.IfDirectiveTrivia) ||
trivia.IsKind(SyntaxKind.ElseDirectiveTrivia) ||
trivia.IsKind(SyntaxKind.EndIfDirectiveTrivia) ||
trivia.IsKind(SyntaxKind.DisabledTextTrivia)
? default : trivia;
}

private class WhitespaceRewriter : CSharpSyntaxRewriter
{
private readonly List<SyntaxTrivia> indentationLevels = new List<SyntaxTrivia> { default };
Expand Down
39 changes: 39 additions & 0 deletions src/Microsoft.Windows.CsWin32/templates/DECIMAL.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
internal partial struct DECIMAL
{
public DECIMAL(decimal value)
{
unchecked
{
const int SignMask = (int)0x80000000;
#if NET5_0_OR_GREATER
Span<int> bits = stackalloc int[4];
decimal.GetBits(value, bits);
#else
int[] bits = decimal.GetBits(value);
#endif
uint lo32 = (uint)bits[0];
uint mid32 = (uint)bits[1];
uint hi32 = (uint)bits[2];
byte scale = (byte)(bits[3] >> 16);
byte sign = (bits[3] & SignMask) == SignMask ? (byte)0x80 : (byte)0x00;
this.Anonymous2 = new _Anonymous2_e__Union() { Anonymous = new _Anonymous2_e__Union._Anonymous_e__Struct() { Lo32 = lo32, Mid32 = mid32 } };
this.Hi32 = hi32;
this.Anonymous1 = new _Anonymous1_e__Union() { Anonymous = new _Anonymous1_e__Union._Anonymous_e__Struct() { scale = scale, sign = sign } };
this.wReserved = 0;
}
}

public static implicit operator decimal(DECIMAL value)
{
return new decimal(
(int)value.Anonymous2.Anonymous.Lo32,
(int)value.Anonymous2.Anonymous.Mid32,
(int)value.Hi32,
value.Anonymous1.Anonymous.sign == 0x80,
value.Anonymous1.Anonymous.scale);
}

#if NET5_0_OR_GREATER
public static implicit operator DECIMAL(decimal value) => new DECIMAL(value);
#endif
}
24 changes: 24 additions & 0 deletions test/GenerationSandbox.Tests/BasicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ public BasicTests(ITestOutputHelper logger)

internal delegate uint GetTickCountDelegate();

public static object[][] InterestingDecimalValue => new object[][]
{
new object[] { 0.0m },
new object[] { 1.2m },
new object[] { -1.2m },
new object[] { decimal.MinValue },
new object[] { decimal.MaxValue },
};

[Fact]
public void GetTickCount_Nonzero()
{
Expand Down Expand Up @@ -105,6 +114,21 @@ public unsafe void BSTR_AsSpan()
}
}

[Theory]
[MemberData(nameof(InterestingDecimalValue))]
public void DecimalConversion(decimal value)
{
DECIMAL nativeDecimal = new(value);
decimal valueRoundTripped = nativeDecimal;
Assert.Equal(value, valueRoundTripped);

#if NET5_0_OR_GREATER
nativeDecimal = value;
valueRoundTripped = nativeDecimal;
Assert.Equal(value, valueRoundTripped);
#endif
}

[Fact]
public void HandlesOverrideEquals()
{
Expand Down
30 changes: 27 additions & 3 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Testing;
using Microsoft.CodeAnalysis.Text;
using VerifyTest = Microsoft.CodeAnalysis.CSharp.Testing.CSharpSourceGeneratorTest<
Microsoft.Windows.CsWin32.SourceGenerator,
Microsoft.CodeAnalysis.Testing.Verifiers.XUnitVerifier>;
using VerifyTest = Microsoft.CodeAnalysis.CSharp.Testing.CSharpSourceGeneratorTest<Microsoft.Windows.CsWin32.SourceGenerator, Microsoft.CodeAnalysis.Testing.Verifiers.XUnitVerifier>;

public class GeneratorTests : IDisposable, IAsyncLifetime
{
Expand Down Expand Up @@ -42,6 +40,7 @@ public class GeneratorTests : IDisposable, IAsyncLifetime

private readonly ITestOutputHelper logger;
private readonly Dictionary<string, CSharpCompilation> starterCompilations = new();
private readonly Dictionary<string, string[]> preprocessorSymbolsByTfm = new();
private CSharpCompilation compilation;
private CSharpParseOptions parseOptions;
private Generator? generator;
Expand Down Expand Up @@ -110,6 +109,20 @@ public async Task InitializeAsync()
this.starterCompilations.Add("net6.0-x86", await this.CreateCompilationAsync(MyReferenceAssemblies.Net.Net60, Platform.X86));
this.starterCompilations.Add("net6.0-x64", await this.CreateCompilationAsync(MyReferenceAssemblies.Net.Net60, Platform.X64));

foreach (string tfm in this.starterCompilations.Keys)
{
if (tfm.StartsWith("net6"))
{
AddSymbols("NET5_0_OR_GREATER", "NET6_0_OR_GREATER", "NET6_0");
}
else
{
AddSymbols();
}

void AddSymbols(params string[] symbols) => this.preprocessorSymbolsByTfm.Add(tfm, symbols);
}

this.compilation = this.starterCompilations["netstandard2.0"];
}

Expand Down Expand Up @@ -557,6 +570,17 @@ public void ComOutPtrTypedAsOutObject()
Assert.Contains(this.FindGeneratedMethod(methodName), m => m.ParameterList.Parameters.Last() is { } last && last.Modifiers.Any(SyntaxKind.OutKeyword) && last.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.ObjectKeyword } });
}

[Theory, CombinatorialData]
public void Decimal([CombinatorialValues("net472", "net6.0")] string tfm)
{
this.compilation = this.starterCompilations[tfm];
this.parseOptions = this.parseOptions.WithPreprocessorSymbols(this.preprocessorSymbolsByTfm[tfm]);
this.generator = this.CreateGenerator();
Assert.True(this.generator.TryGenerate("DECIMAL", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
}

[Fact]
public void ComOutPtrTypedAsIntPtr()
{
Expand Down

0 comments on commit 2cceaa6

Please sign in to comment.