Skip to content

Commit

Permalink
Match with state - fixes #121 (#133)
Browse files Browse the repository at this point in the history
* Add: Match methods with state object.

Add `Match<TState, TMatchOutput>` and `Match<TState>` for variant and specific methods which accept a `state` object and pass it to the provided functions.
  • Loading branch information
panoukos41 committed Apr 1, 2023
1 parent 34d58fb commit c2e60c2
Show file tree
Hide file tree
Showing 3 changed files with 407 additions and 1 deletion.
189 changes: 188 additions & 1 deletion src/UnionGeneration/UnionSourceBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Text;
using System.Text;

namespace Dunet.UnionGeneration;

Expand Down Expand Up @@ -122,6 +122,53 @@ UnionDeclaration union
builder.AppendLine(" );");
builder.AppendLine();

// public abstract TMatchOutput Match<TState, TMatchOutput>(
// TState state,
// System.Func<TState, UnionVariant1<T1, T2, ...>, TMatchOutput> @unionVariant1,
// System.Func<TState, UnionVariant2<T1, T2, ...>, TMatchOutput> @unionVariant2,
// ...
// );
builder.AppendLine(" public abstract TMatchOutput Match<TState, TMatchOutput>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
for (int i = 0; i < union.Variants.Count; ++i)
{
var variant = union.Variants[i];
builder.Append($" System.Func<TState, {variant.Identifier}");
builder.AppendTypeParams(variant.TypeParameters);
builder.Append($", TMatchOutput> {variant.Identifier.ToMethodParameterCase()}");
if (i < union.Variants.Count - 1)
{
builder.Append(",");
}
builder.AppendLine();
}
builder.AppendLine(" );");

// public abstract void Match<TState>(
// TState state,
// System.Action<TState, UnionVariant1<T1, T2, ...>> @unionVariant1,
// System.Action<TState, UnionVariant2<T1, T2, ...>> @unionVariant2,
// ...
// );
builder.AppendLine(" public abstract void Match<TState>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
for (int i = 0; i < union.Variants.Count; ++i)
{
var variant = union.Variants[i];
builder.Append($" System.Action<TState, {variant.Identifier}");
builder.AppendTypeParams(variant.TypeParameters);
builder.Append($"> {variant.Identifier.ToMethodParameterCase()}");
if (i < union.Variants.Count - 1)
{
builder.Append(",");
}
builder.AppendLine();
}
builder.AppendLine(" );");
builder.AppendLine();

return builder;
}

Expand Down Expand Up @@ -164,6 +211,44 @@ UnionDeclaration union

builder.AppendLine();

foreach (var variant in union.Variants)
{
// public abstract TMatchOutput MatchSpecific<TState, TMatchOutput>(
// TState state,
// System.Func<TState, Specific<T1, T2, ...>, TMatchOutput> @specific,
// System.Func<TState, TMatchOutput> @else
// );
builder.AppendLine($" public abstract TMatchOutput Match{variant.Identifier}<TState, TMatchOutput>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Func<TState, {variant.Identifier}");
builder.AppendTypeParams(variant.TypeParameters);
builder.AppendLine($", TMatchOutput> {variant.Identifier.ToMethodParameterCase()},");
builder.AppendLine($" System.Func<TState, TMatchOutput> @else");
builder.AppendLine(" );");
}

builder.AppendLine();

foreach (var variant in union.Variants)
{
// public abstract void MatchSpecific<TState>(
// TState state,
// System.Action<TState, Specific<T1, T2, ...>> @specific,
// System.Action<TState> @else
// );
builder.AppendLine($" public abstract void Match{variant.Identifier}<TState>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Action<TState, {variant.Identifier}");
builder.AppendTypeParams(variant.TypeParameters);
builder.AppendLine($"> {variant.Identifier.ToMethodParameterCase()},");
builder.AppendLine($" System.Action<TState> @else");
builder.AppendLine(" );");
}

builder.AppendLine();

return builder;
}

Expand Down Expand Up @@ -213,6 +298,52 @@ VariantDeclaration variant
}
builder.AppendLine($" ) => {variant.Identifier.ToMethodParameterCase()}(this);");

// public override TMatchOutput Match<TState, TMatchOutput>(
// TState state,
// System.Func<TState, UnionVariant1<T1, T2, ...>, TMatchOutput> @unionVariant1,
// System.Func<TState, UnionVariant2<T1, T2, ...>, TMatchOutput> @unionVariant2,
// ...
// ) => unionVariantX(state, this);
builder.AppendLine(" public override TMatchOutput Match<TState, TMatchOutput>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
for (int i = 0; i < union.Variants.Count; ++i)
{
var variantParam = union.Variants[i];
builder.Append($" System.Func<TState, {variantParam.Identifier}");
builder.AppendTypeParams(variantParam.TypeParameters);
builder.Append($", TMatchOutput> {variantParam.Identifier.ToMethodParameterCase()}");
if (i < union.Variants.Count - 1)
{
builder.Append(",");
}
builder.AppendLine();
}
builder.AppendLine($" ) => {variant.Identifier.ToMethodParameterCase()}(state, this);");

// public override void Match<TState>(
// TState state,
// System.Action<TState, UnionVariant1<T1, T2, ...>> @unionVariant1,
// System.Action<TState, UnionVariant2<T1, T2, ...>> @unionVariant2,
// ...
// ) => unionVariantX(state, this);
builder.AppendLine(" public override void Match<TState>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
for (int i = 0; i < union.Variants.Count; ++i)
{
var variantParam = union.Variants[i];
builder.Append($" System.Action<TState, {variantParam.Identifier}");
builder.AppendTypeParams(variantParam.TypeParameters);
builder.Append($"> {variantParam.Identifier.ToMethodParameterCase()}");
if (i < union.Variants.Count - 1)
{
builder.Append(",");
}
builder.AppendLine();
}
builder.AppendLine($" ) => {variant.Identifier.ToMethodParameterCase()}(state, this);");

return builder;
}

Expand Down Expand Up @@ -272,6 +403,62 @@ VariantDeclaration variant
}
}

// public override TMatchOutput MatchVariantX<TState, TMatchOutput>(
// TState state,
// System.Func<TState, UnionVariant1<T1, T2, ...>, TMatchOutput> @unionVariantX,
// System.Func<TState, TMatchOutput> @else,
// ...
// ) => unionVariantX(state, this);
foreach (var specificVariant in union.Variants)
{
builder.AppendLine(
$" public override TMatchOutput Match{specificVariant.Identifier}<TState, TMatchOutput>("
);
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Func<TState, {specificVariant.Identifier}");
builder.AppendTypeParams(specificVariant.TypeParameters);
builder.AppendLine(
$", TMatchOutput> {specificVariant.Identifier.ToMethodParameterCase()},"
);
builder.AppendLine($" System.Func<TState, TMatchOutput> @else");
builder.Append(" ) => ");
if (specificVariant.Identifier == variant.Identifier)
{
builder.AppendLine($"{specificVariant.Identifier.ToMethodParameterCase()}(state, this);");
}
else
{
builder.AppendLine("@else(state);");
}
}

// public override void MatchVariantX<TState>(
// TState state,
// System.Action<TState, UnionVariant1<T1, T2, ...>> @unionVariantX,
// System.Action<TState> @else,
// ...
// ) => unionVariantX(state, this);
foreach (var specificVariant in union.Variants)
{
builder.AppendLine($" public override void Match{specificVariant.Identifier}<TState>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Action<TState, {specificVariant.Identifier}");
builder.AppendTypeParams(specificVariant.TypeParameters);
builder.AppendLine($"> {specificVariant.Identifier.ToMethodParameterCase()},");
builder.AppendLine($" System.Action<TState> @else");
builder.Append(" ) => ");
if (specificVariant.Identifier == variant.Identifier)
{
builder.AppendLine($"{specificVariant.Identifier.ToMethodParameterCase()}(state, this);");
}
else
{
builder.AppendLine("@else(state);");
}
}

builder.AppendLine(" }");
builder.AppendLine();

Expand Down
127 changes: 127 additions & 0 deletions test/UnionGeneration/MatchMethodWithStateTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
namespace Dunet.Test.UnionGeneration;

public sealed class MatchMethodWithStateTests
{
[Fact]
public void CanUseUnionTypesInDedicatedMatchMethod()
{
// Arrange.
var source = """
using Dunet;

Shape shape = new Shape.Rectangle(3, 4);
double state = 2d;

var area = shape.Match(
state,
static (s, circle) => s + 3.14 * circle.Radius * circle.Radius,
static (s, rectangle) => s + rectangle.Length * rectangle.Width,
static (s, triangle) => s + triangle.Base * triangle.Height / 2
);

[Union]
partial record Shape
{
partial record Circle(double Radius);
partial record Rectangle(double Length, double Width);
partial record Triangle(double Base, double Height);
}
""";
// Act.
var result = Compiler.Compile(source);

// Assert.
using var scope = new AssertionScope();
result.CompilationErrors.Should().BeEmpty();
result.GenerationErrors.Should().BeEmpty();
}

[Theory]
[InlineData("Shape shape = new Shape.Rectangle(3, 4);", 14d)]
[InlineData("Shape shape = new Shape.Circle(1);", 5.14d)]
[InlineData("Shape shape = new Shape.Triangle(4, 2);", 6d)]
public void MatchMethodCallsCorrectFunctionArgument(
string shapeDeclaration,
double expectedArea
)
{
// Arrange.
var source = $$"""
using Dunet;

static double GetArea()
{
{{shapeDeclaration}}
double state = 2d;
return shape.Match(
state,
static (s, circle) => s + 3.14 * circle.Radius * circle.Radius,
static (s, rectangle) => s + rectangle.Length * rectangle.Width,
static (s, triangle) => s + triangle.Base * triangle.Height / 2
);
}

[Union]
partial record Shape
{
partial record Circle(double Radius);
partial record Rectangle(double Length, double Width);
partial record Triangle(double Base, double Height);
}
""";
// Act.
var result = Compiler.Compile(source);
var actualArea = result.Assembly?.ExecuteStaticMethod<double>("GetArea");

// Assert.
using var scope = new AssertionScope();
result.CompilationErrors.Should().BeEmpty();
result.GenerationErrors.Should().BeEmpty();
actualArea.Should().BeApproximately(expectedArea, 0.0000000001d);
}

[Theory]
[InlineData("Keyword keyword = new Keyword.New();" , "string state = \"new\";", "new")]
[InlineData("Keyword keyword = new Keyword.Base();", "string state = \"base\";", "base")]
[InlineData("Keyword keyword = new Keyword.Null();", "string state = \"null\";", "null")]
public void CanMatchOnUnionVariantsNamedAfterKeywords(
string keywordDeclaration,
string stateDeclaration,
string expectedKeyword
)
{
// Arrange.
var source = $$"""
using Dunet;

static string GetKeyword()
{
{{keywordDeclaration}}
{{stateDeclaration}}
return keyword.Match(
state,
static (s, @new) => s,
static (s, @base) => s,
static (s, @null) => s
);
}

[Union]
partial record Keyword
{
partial record New;
partial record Base;
partial record Null;
}
""";
// Act.
var result = Compiler.Compile(source);
var actualKeyword = result.Assembly?.ExecuteStaticMethod<string>("GetKeyword");

// Assert.
using var scope = new AssertionScope();
result.CompilationErrors.Should().BeEmpty();
result.GenerationErrors.Should().BeEmpty();
actualKeyword.Should().Be(expectedKeyword);
}
}
Loading

0 comments on commit c2e60c2

Please sign in to comment.