Skip to content

Commit

Permalink
Merge pull request #66862 from CyrusNajmabadi/makeSynchronousDown
Browse files Browse the repository at this point in the history
Move the 'make method synchronous' fixer down to shared layer
  • Loading branch information
CyrusNajmabadi committed Feb 15, 2023
2 parents fced3be + 7976ea9 commit 1cd0dcc
Show file tree
Hide file tree
Showing 38 changed files with 141 additions and 128 deletions.
1 change: 1 addition & 0 deletions src/Analyzers/CSharp/CodeFixes/CSharpCodeFixes.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
<Compile Include="$(MSBuildThisFileDirectory)MakeLocalFunctionStatic\MakeLocalFunctionStaticCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeLocalFunctionStatic\PassInCapturedVariablesAsArgumentsCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMemberStatic\CSharpMakeMemberStaticCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMethodSynchronous\CSharpMakeMethodSynchronousCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeRefStruct\MakeRefStructCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeStatementAsynchronous\CSharpMakeStatementAsynchronousCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeStructReadOnly\CSharpMakeStructReadOnlyCodeFixProvider.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
Expand All @@ -13,7 +11,6 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.MakeMethodSynchronous;
using Microsoft.CodeAnalysis.Shared.Extensions;
using static Microsoft.CodeAnalysis.MakeMethodAsynchronous.AbstractMakeMethodAsynchronousCodeFixProvider;

namespace Microsoft.CodeAnalysis.CSharp.MakeMethodSynchronous
{
Expand All @@ -34,12 +31,12 @@ public CSharpMakeMethodSynchronousCodeFixProvider()
protected override bool IsAsyncSupportingFunctionSyntax(SyntaxNode node)
=> node.IsAsyncSupportingFunctionSyntax();

protected override SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbolOpt, SyntaxNode node, KnownTypes knownTypes)
protected override SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbol, SyntaxNode node, KnownTypes knownTypes)
{
switch (node)
{
case MethodDeclarationSyntax method: return FixMethod(methodSymbolOpt, method, knownTypes);
case LocalFunctionStatementSyntax localFunction: return FixLocalFunction(methodSymbolOpt, localFunction, knownTypes);
case MethodDeclarationSyntax method: return FixMethod(methodSymbol, method, knownTypes);
case LocalFunctionStatementSyntax localFunction: return FixLocalFunction(methodSymbol, localFunction, knownTypes);
case AnonymousMethodExpressionSyntax method: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(method);
case ParenthesizedLambdaExpressionSyntax lambda: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(lambda);
case SimpleLambdaExpressionSyntax lambda: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(lambda);
Expand Down Expand Up @@ -73,12 +70,14 @@ private static TypeSyntax FixMethodReturnType(IMethodSymbol methodSymbol, TypeSy
// If the return type is Task<T>, then make the new return type "T".
newReturnType = returnType.GetTypeArguments()[0].GenerateTypeSyntax().WithTriviaFrom(returnTypeSyntax);
}
else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTTypeOpt))
else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTTypeOpt) &&
knownTypes.IEnumerableOfTType != null)
{
// If the return type is IAsyncEnumerable<T>, then make the new return type IEnumerable<T>.
newReturnType = knownTypes.IEnumerableOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
}
else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTTypeOpt))
else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTTypeOpt) &&
knownTypes.IEnumeratorOfTType != null)
{
// If the return type is IAsyncEnumerator<T>, then make the new return type IEnumerator<T>.
newReturnType = knownTypes.IEnumeratorOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Iterator\AddYieldTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Iterator\ChangeToIEnumerableTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMemberStatic\MakeMemberStaticTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMethodSynchronous\MakeMethodSynchronousTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeRefStruct\MakeRefStructTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeStatementAsynchronous\CSharpMakeStatementAsynchronousCodeFixTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeTypeAbstract\MakeTypeAbstractTests.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CSharp.MakeMethodSynchronous;
using Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions;
using Microsoft.CodeAnalysis.Test.Utilities;
using Microsoft.CodeAnalysis.Testing;
using Roslyn.Test.Utilities;
using Xunit;
using VerifyCS = Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions.CSharpCodeFixVerifier<
Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer,
Microsoft.CodeAnalysis.CSharp.MakeMethodSynchronous.CSharpMakeMethodSynchronousCodeFixProvider>;

namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.Diagnostics.MakeMethodSynchronous
namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.MakeMethodSynchronous
{
using VerifyCS = CSharpCodeFixVerifier<
EmptyDiagnosticAnalyzer,
CSharpMakeMethodSynchronousCodeFixProvider>;

public class MakeMethodSynchronousTests
{
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)]
Expand Down
1 change: 1 addition & 0 deletions src/Analyzers/Core/CodeFixes/CodeFixes.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Iterator\AbstractIteratorCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeFieldReadonly\AbstractMakeFieldReadonlyCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMemberStatic\AbstractMakeMemberStaticCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMethodSynchronous\AbstractMakeMethodSynchronousCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeTypeAbstract\AbstractMakeTypeAbstractCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeTypePartial\AbstractMakeTypePartialCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MatchFolderAndNamespace\AbstractChangeNamespaceToMatchFolderCodeFixProvider.cs" />
Expand Down
3 changes: 3 additions & 0 deletions src/Analyzers/Core/CodeFixes/CodeFixesResources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,7 @@
<data name="Remove_tag" xml:space="preserve">
<value>Remove tag</value>
</data>
<data name="Make_method_synchronous" xml:space="preserve">
<value>Make method synchronous</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
Expand All @@ -19,52 +16,57 @@
using Microsoft.CodeAnalysis.Rename;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
using static Microsoft.CodeAnalysis.MakeMethodAsynchronous.AbstractMakeMethodAsynchronousCodeFixProvider;

namespace Microsoft.CodeAnalysis.MakeMethodSynchronous
{
internal abstract class AbstractMakeMethodSynchronousCodeFixProvider : CodeFixProvider
{
protected abstract bool IsAsyncSupportingFunctionSyntax(SyntaxNode node);
protected abstract SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbolOpt, SyntaxNode node, KnownTypes knownTypes);
protected abstract SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbol, SyntaxNode node, KnownTypes knownTypes);

public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;

public override Task RegisterCodeFixesAsync(CodeFixContext context)
{
context.RegisterCodeFix(
CodeAction.Create(
FeaturesResources.Make_method_synchronous,
c => FixNodeAsync(context.Document, context.Diagnostics.First(), c),
nameof(FeaturesResources.Make_method_synchronous)),
context.Diagnostics);
var cancellationToken = context.CancellationToken;
var diagnostic = context.Diagnostics.First();

var token = diagnostic.Location.FindToken(cancellationToken);
var node = token.GetAncestor(IsAsyncSupportingFunctionSyntax);
if (node != null)
{
context.RegisterCodeFix(
CodeAction.Create(
CodeFixesResources.Make_method_synchronous,
cancellationToken => FixNodeAsync(context.Document, node, cancellationToken),
nameof(CodeFixesResources.Make_method_synchronous)),
context.Diagnostics);
}

return Task.CompletedTask;
}

private const string AsyncSuffix = "Async";

private async Task<Solution> FixNodeAsync(
Document document, Diagnostic diagnostic, CancellationToken cancellationToken)
Document document, SyntaxNode node, CancellationToken cancellationToken)
{
var token = diagnostic.Location.FindToken(cancellationToken);
var node = token.GetAncestor(IsAsyncSupportingFunctionSyntax);

// See if we're on an actual method declaration (otherwise we're on a lambda declaration).
// If we're on a method declaration, we'll get an IMethodSymbol back. In that case, check
// if it has the 'Async' suffix, and remove that suffix if so.
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var methodSymbolOpt = semanticModel.GetDeclaredSymbol(node, cancellationToken) as IMethodSymbol;
var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var methodSymbol = (IMethodSymbol?)(semanticModel.GetDeclaredSymbol(node, cancellationToken) ?? semanticModel.GetSymbolInfo(node, cancellationToken).GetAnySymbol());
Contract.ThrowIfNull(methodSymbol);

var isOrdinaryOrLocalFunction = methodSymbolOpt.IsOrdinaryMethodOrLocalFunction();
if (isOrdinaryOrLocalFunction &&
methodSymbolOpt.Name.Length > AsyncSuffix.Length &&
methodSymbolOpt.Name.EndsWith(AsyncSuffix))
if (methodSymbol.IsOrdinaryMethodOrLocalFunction() &&
methodSymbol.Name.Length > AsyncSuffix.Length &&
methodSymbol.Name.EndsWith(AsyncSuffix))
{
return await RenameThenRemoveAsyncTokenAsync(document, node, methodSymbolOpt, cancellationToken).ConfigureAwait(false);
return await RenameThenRemoveAsyncTokenAsync(document, node, methodSymbol, cancellationToken).ConfigureAwait(false);
}
else
{
return await RemoveAsyncTokenAsync(document, methodSymbolOpt, node, cancellationToken).ConfigureAwait(false);
return await RemoveAsyncTokenAsync(document, methodSymbol, node, cancellationToken).ConfigureAwait(false);
}
}

Expand All @@ -79,38 +81,36 @@ private async Task<Solution> RenameThenRemoveAsyncTokenAsync(Document document,

// Rename the method to remove the 'Async' suffix, then remove the 'async' keyword.
var newSolution = await Renamer.RenameSymbolAsync(solution, methodSymbol, new SymbolRenameOptions(), newName, cancellationToken).ConfigureAwait(false);
var newDocument = newSolution.GetDocument(document.Id);
var newDocument = newSolution.GetRequiredDocument(document.Id);
var newRoot = await newDocument.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
if (syntaxPath.TryResolve(newRoot, out SyntaxNode newNode))
if (syntaxPath.TryResolve(newRoot, out SyntaxNode? newNode))
{
var semanticModel = await newDocument.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var newMethod = (IMethodSymbol)semanticModel.GetDeclaredSymbol(newNode, cancellationToken);
var semanticModel = await newDocument.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var newMethod = (IMethodSymbol)semanticModel.GetRequiredDeclaredSymbol(newNode, cancellationToken);
return await RemoveAsyncTokenAsync(newDocument, newMethod, newNode, cancellationToken).ConfigureAwait(false);
}

return newSolution;
}

private async Task<Solution> RemoveAsyncTokenAsync(
Document document, IMethodSymbol methodSymbolOpt, SyntaxNode node, CancellationToken cancellationToken)
Document document, IMethodSymbol methodSymbol, SyntaxNode node, CancellationToken cancellationToken)
{
var compilation = await document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
var compilation = await document.Project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
var knownTypes = new KnownTypes(compilation);

var annotation = new SyntaxAnnotation();
var newNode = RemoveAsyncTokenAndFixReturnType(methodSymbolOpt, node, knownTypes)
var newNode = RemoveAsyncTokenAndFixReturnType(methodSymbol, node, knownTypes)
.WithAdditionalAnnotations(Formatter.Annotation, annotation);

var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var newRoot = root.ReplaceNode(node, newNode);

var newDocument = document.WithSyntaxRoot(newRoot);
var newSolution = newDocument.Project.Solution;

if (methodSymbolOpt == null)
{
if (!methodSymbol.IsOrdinaryMethodOrLocalFunction())
return newSolution;
}

return await RemoveAwaitFromCallersAsync(
newDocument, annotation, cancellationToken).ConfigureAwait(false);
Expand All @@ -119,17 +119,26 @@ private async Task<Solution> RenameThenRemoveAsyncTokenAsync(Document document,
private static async Task<Solution> RemoveAwaitFromCallersAsync(
Document document, SyntaxAnnotation annotation, CancellationToken cancellationToken)
{
var syntaxRoot = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var syntaxRoot = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var methodDeclaration = syntaxRoot.GetAnnotatedNodes(annotation).FirstOrDefault();
if (methodDeclaration != null)
{
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

if (semanticModel.GetDeclaredSymbol(methodDeclaration, cancellationToken) is IMethodSymbol methodSymbol)
{
#if CODE_STYLE

var references = await SymbolFinder.FindReferencesAsync(
methodSymbol, document.Project.Solution, cancellationToken).ConfigureAwait(false);

#else

var references = await SymbolFinder.FindRenamableReferencesAsync(
ImmutableArray.Create<ISymbol>(methodSymbol), document.Project.Solution, cancellationToken).ConfigureAwait(false);

#endif

var referencedSymbol = references.FirstOrDefault(r => Equals(r.Definition, methodSymbol));
if (referencedSymbol != null)
{
Expand Down Expand Up @@ -162,8 +171,8 @@ private async Task<Solution> RenameThenRemoveAsyncTokenAsync(Document document,
Solution currentSolution, IGrouping<Document, ReferenceLocation> group, CancellationToken cancellationToken)
{
var document = group.Key;
var syntaxFactsService = document.GetLanguageService<ISyntaxFactsService>();
var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var syntaxFactsService = document.GetRequiredLanguageService<ISyntaxFactsService>();
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

var editor = new SyntaxEditor(root, currentSolution.Services);

Expand Down Expand Up @@ -221,30 +230,31 @@ private async Task<Solution> RenameThenRemoveAsyncTokenAsync(Document document,
if (syntaxFacts.IsExpressionOfAwaitExpression(invocationExpression))
{
// Handle the case where we're directly awaited.
var awaitExpression = invocationExpression.Parent;
var awaitExpression = invocationExpression.GetRequiredParent();
editor.ReplaceNode(awaitExpression, (currentAwaitExpression, generator) =>
syntaxFacts.GetExpressionOfAwaitExpression(currentAwaitExpression)
.WithTriviaFrom(currentAwaitExpression));
}
else if (syntaxFacts.IsExpressionOfMemberAccessExpression(invocationExpression))
{
// Check for the .ConfigureAwait case.
var parentMemberAccessExpression = invocationExpression.Parent;
var parentMemberAccessExpressionNameNode = syntaxFacts.GetNameOfMemberAccessExpression(
parentMemberAccessExpression);
var parentMemberAccessExpression = invocationExpression.GetRequiredParent();
var parentMemberAccessExpressionNameNode = syntaxFacts.GetNameOfMemberAccessExpression(parentMemberAccessExpression);

var parentMemberAccessExpressionName = syntaxFacts.GetIdentifierOfSimpleName(parentMemberAccessExpressionNameNode).ValueText;
if (parentMemberAccessExpressionName == nameof(Task.ConfigureAwait))
{
var parentExpression = parentMemberAccessExpression.Parent;
if (syntaxFacts.IsExpressionOfAwaitExpression(parentExpression))
{
var awaitExpression = parentExpression.Parent;
var awaitExpression = parentExpression.GetRequiredParent();
editor.ReplaceNode(awaitExpression, (currentAwaitExpression, generator) =>
{
var currentConfigureAwaitInvocation = syntaxFacts.GetExpressionOfAwaitExpression(currentAwaitExpression);
var currentMemberAccess = syntaxFacts.GetExpressionOfInvocationExpression(currentConfigureAwaitInvocation);
var currentInvocationExpression = syntaxFacts.GetExpressionOfMemberAccessExpression(currentMemberAccess);
Contract.ThrowIfNull(currentInvocationExpression);
return currentInvocationExpression.WithTriviaFrom(currentAwaitExpression);
});
}
Expand Down
Loading

0 comments on commit 1cd0dcc

Please sign in to comment.