Skip to content

Commit

Permalink
Update result type for code gen
Browse files Browse the repository at this point in the history
  • Loading branch information
MichalBrylka committed Jan 4, 2024
1 parent 891d67d commit 45a7204
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 66 deletions.
Expand Up @@ -33,7 +33,7 @@ private static void RunCase(string index)

var actual = ScrubGeneratorComments(generatedTrees.Single());

actual = IgnoreNewLinesComparer.NormalizeNewLines(actual);
actual = NormalizeNewLines(actual);

Approvals.Verify(WriterFactory.CreateTextWriter(actual, "cs"));
}
Expand Down
Expand Up @@ -39,7 +39,7 @@ private static void RunCase(string name)

var actual = ScrubGeneratorComments(sources.Single());

actual = IgnoreNewLinesComparer.NormalizeNewLines(actual);
actual = NormalizeNewLines(actual);

Approvals.Verify(WriterFactory.CreateTextWriter(actual, "cs"));
}
Expand Down
10 changes: 6 additions & 4 deletions Nemesis.TextParsers.CodeGen.Tests/CodeGenUtils.cs
Expand Up @@ -106,6 +106,8 @@ public static IReadOnlyList<string> GetGeneratedTreesOnly(Compilation compilatio
((CompilationUnitSyntax)tree.GetRoot())
.ToFullString()).ToList();
}

public static string NormalizeNewLines(string text) => text.Replace("\r\n", "\n").Replace("\r", "\n");
}

internal class IgnoreNewLinesComparer : IComparer<string>, IEqualityComparer<string>
Expand All @@ -114,14 +116,14 @@ internal class IgnoreNewLinesComparer : IComparer<string>, IEqualityComparer<str

public static readonly IEqualityComparer<string> EqualityComparer = new IgnoreNewLinesComparer();

public int Compare(string? x, string? y) => string.CompareOrdinal(NormalizeNewLines(x), NormalizeNewLines(y));
public int Compare(string? x, string? y) => string.CompareOrdinal(RemoveNewLines(x), RemoveNewLines(y));

public bool Equals(string? x, string? y) => NormalizeNewLines(x) == NormalizeNewLines(y);
public bool Equals(string? x, string? y) => RemoveNewLines(x) == RemoveNewLines(y);

public int GetHashCode(string s) => NormalizeNewLines(s)?.GetHashCode() ?? 0;
public int GetHashCode(string s) => RemoveNewLines(s)?.GetHashCode() ?? 0;

//for NET 6+ use string.ReplaceLineEndings()
public static string? NormalizeNewLines(string? s) => s?
private static string? RemoveNewLines(string? s) => s?
.Replace(Environment.NewLine, "")
.Replace("\n", "")
.Replace("\r", "");
Expand Down
Expand Up @@ -14,21 +14,21 @@ internal class EnumTransformerGeneratorTests
.SetName($"EnumCodeGen_{i + 1:00}_{t.name}"));

[TestCaseSource(nameof(EnumCodeGenTestCases))]
public void Generate_ShouldReturnValid_MetaInput_And_Output(string source, TransformerMeta expectedMeta, string expectedCodeGen)
public void Generate_ShouldReturnValid_MetaInput_And_Output(string source, EnumTransformerInput expectedMeta, string expectedCodeGen)
{
var compilation = CreateValidCompilation(source);

var (sources, metas) = new EnumTransformerGenerator().RunIncrementalGeneratorAndCaptureInputs<TransformerMeta>(compilation);
var (sources, inputs) = new EnumTransformerGenerator().RunIncrementalGeneratorAndCaptureInputs<EnumTransformerInput>(compilation);

Assert.Multiple(() =>
{
Assert.That(metas, Has.Count.EqualTo(1));
Assert.That(inputs, Has.Count.EqualTo(1));
Assert.That(sources, Has.Count.EqualTo(1));
var meta = metas[0];
var meta = inputs[0];
meta.Should().BeEquivalentTo(expectedMeta);
var source = ScrubGeneratorComments(sources.First());
var source = ScrubGeneratorComments(sources[0]);
Assert.That(source, Is.EqualTo(expectedCodeGen).Using(IgnoreNewLinesComparer.EqualityComparer));
});
}
Expand Down Expand Up @@ -57,7 +57,7 @@ internal enum Casing { A, a, B, b, C, c, Good }
});
}

internal static readonly IEnumerable<(string name, string source, TransformerMeta expectedMeta, string expectedCodeGen)> EnumCodeGenCases =
internal static readonly IEnumerable<(string name, string source, EnumTransformerInput expectedMeta, string expectedCodeGen)> EnumCodeGenCases =
[
("Month", """
[Auto.AutoEnumTransformer(CaseInsensitive = true, AllowParsingNumerics = true, TransformerClassName = "MonthCodeGenTransformer")]
Expand Down
Expand Up @@ -63,15 +63,15 @@ internal sealed class {{ATTRIBUTE_NAME}} : global::System.Attribute
}
""";

private static Result<TransformerMeta?, Diagnostic> GetTypeToGenerate(GeneratorAttributeSyntaxContext context, CancellationToken ct)
private static Result<EnumTransformerInput, Diagnostic> GetTypeToGenerate(GeneratorAttributeSyntaxContext context, CancellationToken ct)
{
if (context.TargetSymbol is not INamedTypeSymbol enumSymbol ||
enumSymbol.TypeKind != TypeKind.Enum ||
enumSymbol is IErrorTypeSymbol ||
(enumSymbol.GetAttributes() is var attributes && attributes.Length == 0)
)
{
return (TransformerMeta?)null;
return Result<EnumTransformerInput, Diagnostic>.None();
}

ct.ThrowIfCancellationRequested();
Expand Down Expand Up @@ -126,7 +126,7 @@ internal sealed class {{ATTRIBUTE_NAME}} : global::System.Attribute
}
}

if (!autoAttributeFound) return (TransformerMeta?)null;
if (!autoAttributeFound) return Result<EnumTransformerInput, Diagnostic>.None();

ct.ThrowIfCancellationRequested();

Expand All @@ -139,7 +139,7 @@ internal sealed class {{ATTRIBUTE_NAME}} : global::System.Attribute
.Select(static symbol => symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
.ToArray();

var meta = new TransformerMeta(
var input = new EnumTransformerInput(
transformerName ?? $"{enumSymbol.Name}Transformer",
transformerNamespace ?? (enumSymbol.ContainingNamespace.IsGlobalNamespace ? string.Empty : enumSymbol.ContainingNamespace.ToString()),
caseInsensitive ?? true, allowParsingNumerics ?? true,
Expand All @@ -148,20 +148,20 @@ internal sealed class {{ATTRIBUTE_NAME}} : global::System.Attribute
enumSymbol.DeclaredAccessibility == Accessibility.Public,
hasFlags, underlyingType);

if (meta.CaseInsensitive &&
if (input.CaseInsensitive &&
memberNames.Length != new HashSet<string>(memberNames, StringComparer.OrdinalIgnoreCase).Count)
{
return CreateDiagnostics(CaseInsensitiveIncompatibleMemberNames, enumSymbol);
}
else
return meta;
return input;
static Diagnostic CreateDiagnostics(DiagnosticDescriptor rule, ISymbol? symbol) =>
Diagnostic.Create(rule, symbol?.Locations[0] ?? Location.None, symbol?.ContainingNamespace?.ToString(), symbol?.Name);
}
}

internal readonly record struct TransformerMeta(
internal readonly record struct EnumTransformerInput(
string TransformerName, string TransformerNamespace,
bool CaseInsensitive, bool AllowParsingNumerics,

Expand Down
42 changes: 21 additions & 21 deletions Nemesis.TextParsers.CodeGen/Enums/EnumTransformerGenerator.cs
Expand Up @@ -16,45 +16,45 @@ public override void Initialize(IncrementalGeneratorInitializationContext contex
ATTRIBUTE_FULL_NAME,
predicate: static (node, _) => node is EnumDeclarationSyntax,
transform: GetTypeToGenerate)
.Where(static result => result.IsError || (result.IsSuccess && result.Value is not null))
.Where(static result => !result.IsNone)
.WithTrackingName(INPUTS);

context.RegisterSourceOutput(transformersToGenerate,
static (spc, result) => Execute(result, spc));
}

private static void Execute(Result<TransformerMeta?, Diagnostic> result, SourceProductionContext context)
private static void Execute(Result<EnumTransformerInput, Diagnostic> result, SourceProductionContext context)
{
if (result.IsError && result.Error is { } error)
context.ReportDiagnostic(error);
else if (result.IsSuccess && result.Value is { } meta)
{
var source = Render(in meta);
context.AddSource($"{meta.TransformerName}.g.cs", SourceText.From(source, Encoding.UTF8));
}
result.Invoke(
input =>
{
var source = Render(in input);
context.AddSource($"{input.TransformerName}.g.cs", SourceText.From(source, Encoding.UTF8));
},
context.ReportDiagnostic);
}

private static string Render(in TransformerMeta meta)
private static string Render(in EnumTransformerInput input)
{
var sb = new StringBuilder(HEADER, 1024).AppendLine();

var enumName = meta.EnumFullyQualifiedName;
var numberType = meta.UnderlyingType;
var memberNames = meta.MemberNames;
var enumName = input.EnumFullyQualifiedName;
var numberType = input.UnderlyingType;
var memberNames = input.MemberNames;

sb.AppendLine($$"""
using System;
using Nemesis.TextParsers;
{{(
string.IsNullOrEmpty(meta.TransformerNamespace)
string.IsNullOrEmpty(input.TransformerNamespace)
? ""
: $"""
namespace {meta.TransformerNamespace};
namespace {input.TransformerNamespace};

"""
)}}
{{CODE_GEN_ATTRIBUTES}}
{{(meta.IsPublic ? "public" : "internal")}} sealed class {{meta.TransformerName}} : TransformerBase<{{enumName}}>
{{(input.IsPublic ? "public" : "internal")}} sealed class {{input.TransformerName}} : TransformerBase<{{enumName}}>
{
""");

Expand All @@ -76,7 +76,7 @@ namespace {meta.TransformerNamespace};
""").AppendLine();

//ParseCore
if (meta.IsFlagEnum)
if (input.IsFlagEnum)
{
sb.AppendLine($$"""
protected override {{enumName}} ParseCore(in ReadOnlySpan<char> input)
Expand Down Expand Up @@ -117,7 +117,7 @@ namespace {meta.TransformerNamespace};
input = input.Trim();
""");

if (meta.AllowParsingNumerics)
if (input.AllowParsingNumerics)
{
sb.AppendLine($$"""
if (IsNumeric(input) && {{numberType}}.TryParse(input
Expand Down Expand Up @@ -166,8 +166,8 @@ namespace {meta.TransformerNamespace};
""").AppendLine();
}

var numberParsingText = meta.AllowParsingNumerics ? $" or number within {numberType} range. " : ". ";
var caseInsensitiveText = meta.CaseInsensitive ? "Ignore case option on." : "Case sensitive option on.";
var numberParsingText = input.AllowParsingNumerics ? $" or number within {numberType} range. " : ". ";
var caseInsensitiveText = input.CaseInsensitive ? "Ignore case option on." : "Case sensitive option on.";
var exceptionMessage = $$"""
Enum of type '{{enumName}}' cannot be parsed from '{input.ToString()}'.
Valid values are: [{{string.Join(" or ", memberNames)}}]{{numberParsingText}}
Expand All @@ -178,7 +178,7 @@ namespace {meta.TransformerNamespace};
"""");


var stringComparison = meta.CaseInsensitive ? nameof(StringComparison.OrdinalIgnoreCase) : nameof(StringComparison.Ordinal);
var stringComparison = input.CaseInsensitive ? nameof(StringComparison.OrdinalIgnoreCase) : nameof(StringComparison.Ordinal);
if (hasAnyMembers)
sb.AppendLine().AppendLine($$"""
static bool IsEqual(ReadOnlySpan<char> input, string label) =>
Expand Down
16 changes: 9 additions & 7 deletions Nemesis.TextParsers.CodeGen/IncrementalGenerator.cs
Expand Up @@ -26,32 +26,34 @@ public IReadOnlyList<string> RunIncrementalGeneratorAndGetGeneratedSources(Compi
return GetGeneratedOutput(result, requiredCardinality);
}

public (IReadOnlyList<string> Sources, IReadOnlyList<TMeta> Meta) RunIncrementalGeneratorAndCaptureInputs<TMeta>(Compilation compilation, int requiredCardinality = 1)
where TMeta : struct
public (IReadOnlyList<string> Sources, IReadOnlyList<TInput> Inputs) RunIncrementalGeneratorAndCaptureInputs<TInput>(Compilation compilation, int requiredCardinality = 1)
{
var result = RunIncrementalGenerator(compilation);
var generatedSources = GetGeneratedOutput(result, requiredCardinality);

IReadOnlyList<TMeta> meta = [];
if (result.TrackedSteps.TryGetValue(INPUTS, out var metaValue))
{
var stepResults = metaValue.Single().Outputs
.Select(o => (
Result: (Result<TMeta?, Diagnostic>)o.Value,
Result: (Result<TInput, Diagnostic>)o.Value,
o.Reason
))
.ToList();


if (stepResults.Any(r => r.Reason != IncrementalStepRunReason.New))
throw new NotSupportedException($"All generation steps are expected to be new");
if (stepResults.Any(r => r.Result.IsSuccess == false || r.Result.Value is null))

if (stepResults.Any(r => !r.Result.IsSuccess))
throw new NotSupportedException($"All generation steps are expected to be succesful");

meta = stepResults.Select(s => s.Result.Value!.Value!).ToList();
var meta = new List<TInput>(stepResults.Count);
stepResults.ForEach(r => r.Result.Invoke(meta.Add));

return (generatedSources, meta);
}

return (generatedSources, meta);
return (generatedSources, []);
}

private IReadOnlyList<string> GetGeneratedOutput(GeneratorRunResult result, int requiredCardinality)
Expand Down
55 changes: 42 additions & 13 deletions Nemesis.TextParsers.CodeGen/Utils/Result.cs
Expand Up @@ -3,32 +3,61 @@ namespace Nemesis.TextParsers.CodeGen.Utils;

internal readonly struct Result<TValue, TError>
{
public TValue? Value { get; }
public TError? Error { get; }
private readonly State _state;
private readonly TValue? _value;
private readonly TError? _error;

public bool IsError { get; }
public TValue Value => IsSuccess ? _value! : throw new InvalidOperationException("Value can only be retrieved in 'Success' state");

public bool IsSuccess => !IsError;
public bool IsSuccess => _state == State.Success;
public bool IsError => _state == State.Error;
public bool IsNone => _state == State.None;

private Result(TValue value)
{
Value = value;
Error = default;
IsError = false;
_value = value;
_error = default;

_state = State.Success;
}

private Result(TError error)
{
Value = default;
Error = error;
IsError = true;
_value = default;
_error = error;

_state = State.Error;
}

public Result()
{
_value = default;
_error = default;

_state = State.None;
}

public static Result<TValue, TError> None() => new();

public static implicit operator Result<TValue, TError>(TValue result) => new(result);
public static implicit operator Result<TValue, TError>(TError error) => new(error);

public TResult Match<TResult>(Func<TValue, TResult> success, Func<TError, TResult> failure) =>
IsError ? failure(Error!) : success(Value!);
/*public TResult Match<TResult>(Func<TValue, TResult> success, Func<TError, TResult> failure) =>
IsError ? failure(Error!) : success(Value!);*/

public void Invoke(Action<TValue> success, Action<TError>? failure = null)
{
if (IsSuccess) success(_value!);
else if (IsError) failure?.Invoke(_error!);
}

public override string? ToString() => _state switch
{
State.Success => _value?.ToString(),
State.Error => _error?.ToString(),
State.None => "<None>",
_ => throw new NotSupportedException($"State {_state} is not supported")
};

public override string? ToString() => IsError ? Error?.ToString() : Value?.ToString();
enum State : byte { None, Success, Error }
}
12 changes: 6 additions & 6 deletions Nemesis.TextParsers.Tests/Utils/TestHelper.cs
Expand Up @@ -16,8 +16,8 @@ public static string AssertException(Exception actual, Type expectedException, s

if (expectedErrorMessagePart != null)
Assert.That(
IgnoreNewLinesComparer.NormalizeNewLines(actual?.Message),
Does.Contain(IgnoreNewLinesComparer.NormalizeNewLines(expectedErrorMessagePart))
IgnoreNewLinesComparer.RemoveNewLines(actual?.Message),
Does.Contain(IgnoreNewLinesComparer.RemoveNewLines(expectedErrorMessagePart))
);

if (!logMessage) return "";
Expand Down Expand Up @@ -183,14 +183,14 @@ internal class IgnoreNewLinesComparer : IComparer<string>, IEqualityComparer<str

public static readonly IEqualityComparer<string> EqualityComparer = new IgnoreNewLinesComparer();

public int Compare(string? x, string? y) => string.CompareOrdinal(NormalizeNewLines(x), NormalizeNewLines(y));
public int Compare(string? x, string? y) => string.CompareOrdinal(RemoveNewLines(x), RemoveNewLines(y));

public bool Equals(string? x, string? y) => NormalizeNewLines(x) == NormalizeNewLines(y);
public bool Equals(string? x, string? y) => RemoveNewLines(x) == RemoveNewLines(y);

public int GetHashCode(string s) => NormalizeNewLines(s)?.GetHashCode() ?? 0;
public int GetHashCode(string s) => RemoveNewLines(s)?.GetHashCode() ?? 0;

//for NET 6+ use string.ReplaceLineEndings()
public static string? NormalizeNewLines(string? s) => s?
public static string? RemoveNewLines(string? s) => s?
.Replace(Environment.NewLine, "")
.Replace("\n", "")
.Replace("\r", "");
Expand Down

0 comments on commit 45a7204

Please sign in to comment.