Skip to content

Commit

Permalink
Extend async check to sections
Browse files Browse the repository at this point in the history
  • Loading branch information
ltrzesniewski committed Nov 29, 2023
1 parent 5e4a9db commit 7577e38
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 10 deletions.
13 changes: 13 additions & 0 deletions src/RazorBlade.Analyzers.Tests/RazorBladeSourceGeneratorTests.cs
Expand Up @@ -337,6 +337,19 @@ After section
);
}

[Test]
public Task should_detect_async_sections()
{
return Verify(
"""
@using System.Threading.Tasks
@if(42.ToString() == "42") {
@section SectionName { @await Task.FromResult(42) }
}
"""
);
}

private static GeneratorDriverRunResult Generate(string input,
string? csharpCode,
bool embeddedLibrary,
Expand Down
@@ -0,0 +1,54 @@
//HintName: TestNamespace.TestFile.Razor.g.cs
#pragma checksum "./TestFile.cshtml" "{ff1816ec-aa5e-4d10-87f7-6f4963833460}" "ad80bc3dc0df64bf01f11639ab5672ca37eb6742"
// <auto-generated/>
#pragma warning disable 1591
namespace TestNamespace
{
#line hidden
#nullable restore
#line 1 "./TestFile.cshtml"
using System.Threading.Tasks;

#line default
#line hidden
#nullable disable
#nullable restore
internal partial class TestFile : global::RazorBlade.HtmlTemplate
#nullable disable
{
#pragma warning disable 1998
protected async override global::System.Threading.Tasks.Task ExecuteAsync()
{
#nullable restore
#line 2 "./TestFile.cshtml"
if(42.ToString() == "42") {


#line default
#line hidden
#nullable disable
DefineSection("SectionName", async() => {
WriteLiteral(" ");
#nullable restore
#line (3,29)-(3,54) 6 "./TestFile.cshtml"
Write(await Task.FromResult(42));
#line default
#line hidden
#nullable disable
WriteLiteral(" ");
}
);
#nullable restore
#line 3 "./TestFile.cshtml"

}

#line default
#line hidden
#nullable disable
}
#pragma warning restore 1998
}
}
#pragma warning restore 1591
@@ -0,0 +1,22 @@
//HintName: TestNamespace.TestFile.RazorBlade.g.cs
// <auto-generated/>

#nullable restore

namespace TestNamespace
{
partial class TestFile
{
/// <inheritdoc cref="M:RazorBlade.RazorTemplate.Render(System.Threading.CancellationToken)" />
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
[global::System.Obsolete("The generated template is async. Use RenderAsync instead.", DiagnosticId = "RB0003")]
public new string Render(global::System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
=> base.Render(cancellationToken);

/// <inheritdoc cref="M:RazorBlade.RazorTemplate.Render(System.IO.TextWriter,System.Threading.CancellationToken)" />
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
[global::System.Obsolete("The generated template is async. Use RenderAsync instead.", DiagnosticId = "RB0003")]
public new void Render(global::System.IO.TextWriter textWriter, global::System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
=> base.Render(textWriter, cancellationToken);
}
}
65 changes: 55 additions & 10 deletions src/RazorBlade.Analyzers/LibraryCodeGenerator.cs
Expand Up @@ -52,6 +52,7 @@ internal class LibraryCodeGenerator
private INamedTypeSymbol? _classSymbol;
private ImmutableArray<Diagnostic> _diagnostics;
private Compilation _compilation;
private SemanticModel? _semanticModel;

public LibraryCodeGenerator(RazorCSharpDocument generatedDoc,
Compilation compilation,
Expand Down Expand Up @@ -86,7 +87,7 @@ public string Generate(CancellationToken cancellationToken)
using (_writer.BuildClassDeclaration(["partial"], _classSymbol.Name, null, Array.Empty<string>(), Array.Empty<TypeParameter>(), useNullableContext: false))
{
GenerateConstructors();
GenerateConditionalOnAsync();
GenerateConditionalOnAsync(cancellationToken);
}
}

Expand All @@ -107,17 +108,17 @@ private void Analyze(CancellationToken cancellationToken)
.AddSyntaxTrees(syntaxTree)
.AddSyntaxTrees(_additionalSyntaxTrees);

var semanticModel = _compilation.GetSemanticModel(syntaxTree);
_semanticModel = _compilation.GetSemanticModel(syntaxTree);

var classDeclarationNode = syntaxTree.GetRoot(cancellationToken)
.DescendantNodes()
.FirstOrDefault(static i => i.IsKind(SyntaxKind.ClassDeclaration));

_classSymbol = classDeclarationNode is ClassDeclarationSyntax classDeclarationSyntax
? semanticModel.GetDeclaredSymbol(classDeclarationSyntax, cancellationToken)
? _semanticModel.GetDeclaredSymbol(classDeclarationSyntax, cancellationToken)
: null;

_diagnostics = semanticModel.GetDiagnostics(cancellationToken: cancellationToken);
_diagnostics = _semanticModel.GetDiagnostics(cancellationToken: cancellationToken);
}

private void GenerateConstructors()
Expand Down Expand Up @@ -164,26 +165,27 @@ private void GenerateConstructors()
}
}

private void GenerateConditionalOnAsync()
private void GenerateConditionalOnAsync(CancellationToken cancellationToken)
{
const string executeAsyncMethodName = "ExecuteAsync";
const string defineSectionMethodName = "DefineSection";

var conditionalOnAsyncAttribute = _compilation.GetTypeByMetadataName("RazorBlade.Support.ConditionalOnAsyncAttribute");
if (conditionalOnAsyncAttribute is null)
return;

var executeMethodSymbol = _classSymbol?.GetMembers("ExecuteAsync")
var executeMethodSymbol = _classSymbol?.GetMembers(executeAsyncMethodName)
.OfType<IMethodSymbol>()
.FirstOrDefault(i => i.Parameters.IsEmpty && i.IsAsync);

var methodLocation = executeMethodSymbol?.Locations.FirstOrDefault();
if (methodLocation is null)
return;

// CS1998 = This async method lacks 'await' operators and will run synchronously.
var isTemplateSync = _diagnostics.Any(i => i.Id == "CS1998" && i.Location == methodLocation);

var isTemplateSync = IsTemplateSync();
var hiddenMethodSignatures = new HashSet<string>(StringComparer.Ordinal);

for (var baseClass = _classSymbol?.BaseType; baseClass is not (null or { SpecialType: SpecialType.System_Object }); baseClass = baseClass.BaseType)
foreach (var baseClass in _classSymbol.SelfAndBasesTypes().Skip(1))
{
foreach (var methodSymbol in baseClass.GetMembers().OfType<IMethodSymbol>())
{
Expand Down Expand Up @@ -233,6 +235,49 @@ private void GenerateConditionalOnAsync()
}
}

bool IsTemplateSync()
{
// CS1998 = This async method lacks 'await' operators and will run synchronously.
// The ExecuteAsync and all the DefineSection methods need to have this diagnostic for the template to be considered synchronous.

var diagnosticLocations = _diagnostics.Where(i => i.Id == "CS1998").Select(i => i.Location).ToHashSet();
if (diagnosticLocations.Count == 0 || !diagnosticLocations.Contains(methodLocation))
return false;

if (executeMethodSymbol?.DeclaringSyntaxReferences is not [var syntaxRef]
|| syntaxRef.GetSyntax(cancellationToken) is not MethodDeclarationSyntax { Body: { } executeMethodBody })
return true;

var defineSectionMethod = _classSymbol.SelfAndBasesTypes()
.SelectMany(t => t.GetMembers(defineSectionMethodName))
.OfType<IMethodSymbol>()
.FirstOrDefault(m => m.Parameters.Length == 2);

if (defineSectionMethod is null)
return true;

foreach (var node in executeMethodBody.DescendantNodes())
{
if (node is InvocationExpressionSyntax
{
ArgumentList.Arguments:
[
{ Expression: LiteralExpressionSyntax { RawKind: (int)SyntaxKind.StringLiteralExpression } },
{ Expression: ParenthesizedLambdaExpressionSyntax lambda }
],
Expression: IdentifierNameSyntax { Identifier.ValueText: defineSectionMethodName } expression
}
&& !diagnosticLocations.Contains(lambda.ArrowToken.GetLocation())
&& SymbolEqualityComparer.Default.Equals(_semanticModel.GetSymbolInfo(expression, cancellationToken).Symbol, defineSectionMethod)
)
{
return false;
}
}

return true;
}

static string GetMethodSignatureFootprint(IMethodSymbol methodSymbol)
{
var sb = new StringBuilder();
Expand Down
12 changes: 12 additions & 0 deletions src/RazorBlade.Analyzers/Support/Extensions.cs
Expand Up @@ -7,6 +7,9 @@ namespace RazorBlade.Analyzers.Support;

internal static class Extensions
{
public static HashSet<T> ToHashSet<T>(this IEnumerable<T> items)
=> new(items);

public static IncrementalValuesProvider<T> WhereNotNull<T>(this IncrementalValuesProvider<T?> provider)
where T : class
=> provider.Where(static item => item is not null)!;
Expand All @@ -33,6 +36,15 @@ public static string EscapeCSharpKeyword(this string name)
? "@" + name
: name;

public static IEnumerable<INamedTypeSymbol> SelfAndBasesTypes(this INamedTypeSymbol? symbol)
{
while (symbol is not null)
{
yield return symbol;
symbol = symbol.BaseType;
}
}

private sealed class LambdaComparer<T>(Func<T, T, bool> equals, Func<T, int> getHashCode) : IEqualityComparer<T>
{
public bool Equals(T? x, T? y)
Expand Down

0 comments on commit 7577e38

Please sign in to comment.