Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move the 'make method synchronous' fixer down to shared layer #66862

Merged
merged 8 commits into from
Feb 15, 2023
1 change: 1 addition & 0 deletions src/Analyzers/CSharp/CodeFixes/CSharpCodeFixes.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
<Compile Include="$(MSBuildThisFileDirectory)MakeLocalFunctionStatic\MakeLocalFunctionStaticCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeLocalFunctionStatic\PassInCapturedVariablesAsArgumentsCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMemberStatic\CSharpMakeMemberStaticCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMethodSynchronous\CSharpMakeMethodSynchronousCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeRefStruct\MakeRefStructCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeStatementAsynchronous\CSharpMakeStatementAsynchronousCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeStructReadOnly\CSharpMakeStructReadOnlyCodeFixProvider.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
Expand All @@ -13,7 +11,6 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.MakeMethodSynchronous;
using Microsoft.CodeAnalysis.Shared.Extensions;
using static Microsoft.CodeAnalysis.MakeMethodAsynchronous.AbstractMakeMethodAsynchronousCodeFixProvider;

namespace Microsoft.CodeAnalysis.CSharp.MakeMethodSynchronous
{
Expand All @@ -34,12 +31,12 @@ public CSharpMakeMethodSynchronousCodeFixProvider()
protected override bool IsAsyncSupportingFunctionSyntax(SyntaxNode node)
=> node.IsAsyncSupportingFunctionSyntax();

protected override SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbolOpt, SyntaxNode node, KnownTypes knownTypes)
protected override SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbol, SyntaxNode node, KnownTypes knownTypes)
{
switch (node)
{
case MethodDeclarationSyntax method: return FixMethod(methodSymbolOpt, method, knownTypes);
case LocalFunctionStatementSyntax localFunction: return FixLocalFunction(methodSymbolOpt, localFunction, knownTypes);
case MethodDeclarationSyntax method: return FixMethod(methodSymbol, method, knownTypes);
case LocalFunctionStatementSyntax localFunction: return FixLocalFunction(methodSymbol, localFunction, knownTypes);
case AnonymousMethodExpressionSyntax method: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(method);
case ParenthesizedLambdaExpressionSyntax lambda: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(lambda);
case SimpleLambdaExpressionSyntax lambda: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(lambda);
Expand Down Expand Up @@ -73,12 +70,14 @@ private static TypeSyntax FixMethodReturnType(IMethodSymbol methodSymbol, TypeSy
// If the return type is Task<T>, then make the new return type "T".
newReturnType = returnType.GetTypeArguments()[0].GenerateTypeSyntax().WithTriviaFrom(returnTypeSyntax);
}
else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTTypeOpt))
else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTTypeOpt) &&
knownTypes.IEnumerableOfTType != null)
{
// If the return type is IAsyncEnumerable<T>, then make the new return type IEnumerable<T>.
newReturnType = knownTypes.IEnumerableOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
}
else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTTypeOpt))
else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTTypeOpt) &&
knownTypes.IEnumeratorOfTType != null)
{
// If the return type is IAsyncEnumerator<T>, then make the new return type IEnumerator<T>.
newReturnType = knownTypes.IEnumeratorOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Iterator\AddYieldTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Iterator\ChangeToIEnumerableTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMemberStatic\MakeMemberStaticTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMethodSynchronous\MakeMethodSynchronousTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeRefStruct\MakeRefStructTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeStatementAsynchronous\CSharpMakeStatementAsynchronousCodeFixTests.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeTypeAbstract\MakeTypeAbstractTests.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CSharp.MakeMethodSynchronous;
using Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions;
using Microsoft.CodeAnalysis.Test.Utilities;
using Microsoft.CodeAnalysis.Testing;
using Roslyn.Test.Utilities;
using Xunit;
using VerifyCS = Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions.CSharpCodeFixVerifier<
Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer,
Microsoft.CodeAnalysis.CSharp.MakeMethodSynchronous.CSharpMakeMethodSynchronousCodeFixProvider>;

namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.Diagnostics.MakeMethodSynchronous
namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.MakeMethodSynchronous
{
using VerifyCS = CSharpCodeFixVerifier<
EmptyDiagnosticAnalyzer,
CSharpMakeMethodSynchronousCodeFixProvider>;

public class MakeMethodSynchronousTests
{
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodSynchronous)]
Expand Down
1 change: 1 addition & 0 deletions src/Analyzers/Core/CodeFixes/CodeFixes.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Iterator\AbstractIteratorCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeFieldReadonly\AbstractMakeFieldReadonlyCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMemberStatic\AbstractMakeMemberStaticCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeMethodSynchronous\AbstractMakeMethodSynchronousCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeTypeAbstract\AbstractMakeTypeAbstractCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MakeTypePartial\AbstractMakeTypePartialCodeFixProvider.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MatchFolderAndNamespace\AbstractChangeNamespaceToMatchFolderCodeFixProvider.cs" />
Expand Down
3 changes: 3 additions & 0 deletions src/Analyzers/Core/CodeFixes/CodeFixesResources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,7 @@
<data name="Remove_tag" xml:space="preserve">
<value>Remove tag</value>
</data>
<data name="Make_method_synchronous" xml:space="preserve">
<value>Make method synchronous</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
Expand All @@ -19,52 +16,57 @@
using Microsoft.CodeAnalysis.Rename;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
using static Microsoft.CodeAnalysis.MakeMethodAsynchronous.AbstractMakeMethodAsynchronousCodeFixProvider;

namespace Microsoft.CodeAnalysis.MakeMethodSynchronous
{
internal abstract class AbstractMakeMethodSynchronousCodeFixProvider : CodeFixProvider
{
protected abstract bool IsAsyncSupportingFunctionSyntax(SyntaxNode node);
protected abstract SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbolOpt, SyntaxNode node, KnownTypes knownTypes);
protected abstract SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbol, SyntaxNode node, KnownTypes knownTypes);

public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;

public override Task RegisterCodeFixesAsync(CodeFixContext context)
{
context.RegisterCodeFix(
CodeAction.Create(
FeaturesResources.Make_method_synchronous,
c => FixNodeAsync(context.Document, context.Diagnostics.First(), c),
nameof(FeaturesResources.Make_method_synchronous)),
context.Diagnostics);
var cancellationToken = context.CancellationToken;
var diagnostic = context.Diagnostics.First();

var token = diagnostic.Location.FindToken(cancellationToken);
var node = token.GetAncestor(IsAsyncSupportingFunctionSyntax);
if (node != null)
{
context.RegisterCodeFix(
CodeAction.Create(
CodeFixesResources.Make_method_synchronous,
cancellationToken => FixNodeAsync(context.Document, node, cancellationToken),
nameof(CodeFixesResources.Make_method_synchronous)),
context.Diagnostics);
}

return Task.CompletedTask;
}

private const string AsyncSuffix = "Async";

private async Task<Solution> FixNodeAsync(
Document document, Diagnostic diagnostic, CancellationToken cancellationToken)
Document document, SyntaxNode node, CancellationToken cancellationToken)
{
var token = diagnostic.Location.FindToken(cancellationToken);
var node = token.GetAncestor(IsAsyncSupportingFunctionSyntax);

// See if we're on an actual method declaration (otherwise we're on a lambda declaration).
// If we're on a method declaration, we'll get an IMethodSymbol back. In that case, check
// if it has the 'Async' suffix, and remove that suffix if so.
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var methodSymbolOpt = semanticModel.GetDeclaredSymbol(node, cancellationToken) as IMethodSymbol;
var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var methodSymbol = (IMethodSymbol?)(semanticModel.GetDeclaredSymbol(node, cancellationToken) ?? semanticModel.GetSymbolInfo(node, cancellationToken).GetAnySymbol());
Contract.ThrowIfNull(methodSymbol);

var isOrdinaryOrLocalFunction = methodSymbolOpt.IsOrdinaryMethodOrLocalFunction();
if (isOrdinaryOrLocalFunction &&
methodSymbolOpt.Name.Length > AsyncSuffix.Length &&
methodSymbolOpt.Name.EndsWith(AsyncSuffix))
if (methodSymbol.IsOrdinaryMethodOrLocalFunction() &&
methodSymbol.Name.Length > AsyncSuffix.Length &&
methodSymbol.Name.EndsWith(AsyncSuffix))
{
return await RenameThenRemoveAsyncTokenAsync(document, node, methodSymbolOpt, cancellationToken).ConfigureAwait(false);
return await RenameThenRemoveAsyncTokenAsync(document, node, methodSymbol, cancellationToken).ConfigureAwait(false);
}
else
{
return await RemoveAsyncTokenAsync(document, methodSymbolOpt, node, cancellationToken).ConfigureAwait(false);
return await RemoveAsyncTokenAsync(document, methodSymbol, node, cancellationToken).ConfigureAwait(false);
}
}

Expand All @@ -79,38 +81,36 @@ private async Task<Solution> RenameThenRemoveAsyncTokenAsync(Document document,

// Rename the method to remove the 'Async' suffix, then remove the 'async' keyword.
var newSolution = await Renamer.RenameSymbolAsync(solution, methodSymbol, new SymbolRenameOptions(), newName, cancellationToken).ConfigureAwait(false);
var newDocument = newSolution.GetDocument(document.Id);
var newDocument = newSolution.GetRequiredDocument(document.Id);
var newRoot = await newDocument.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
if (syntaxPath.TryResolve(newRoot, out SyntaxNode newNode))
if (syntaxPath.TryResolve(newRoot, out SyntaxNode? newNode))
{
var semanticModel = await newDocument.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var newMethod = (IMethodSymbol)semanticModel.GetDeclaredSymbol(newNode, cancellationToken);
var semanticModel = await newDocument.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var newMethod = (IMethodSymbol)semanticModel.GetRequiredDeclaredSymbol(newNode, cancellationToken);
return await RemoveAsyncTokenAsync(newDocument, newMethod, newNode, cancellationToken).ConfigureAwait(false);
}

return newSolution;
}

private async Task<Solution> RemoveAsyncTokenAsync(
Document document, IMethodSymbol methodSymbolOpt, SyntaxNode node, CancellationToken cancellationToken)
Document document, IMethodSymbol methodSymbol, SyntaxNode node, CancellationToken cancellationToken)
{
var compilation = await document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
var compilation = await document.Project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false);
var knownTypes = new KnownTypes(compilation);

var annotation = new SyntaxAnnotation();
var newNode = RemoveAsyncTokenAndFixReturnType(methodSymbolOpt, node, knownTypes)
var newNode = RemoveAsyncTokenAndFixReturnType(methodSymbol, node, knownTypes)
.WithAdditionalAnnotations(Formatter.Annotation, annotation);

var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var newRoot = root.ReplaceNode(node, newNode);

var newDocument = document.WithSyntaxRoot(newRoot);
var newSolution = newDocument.Project.Solution;

if (methodSymbolOpt == null)
{
if (!methodSymbol.IsOrdinaryMethodOrLocalFunction())
return newSolution;
}

return await RemoveAwaitFromCallersAsync(
newDocument, annotation, cancellationToken).ConfigureAwait(false);
Expand All @@ -119,17 +119,26 @@ private async Task<Solution> RemoveAsyncTokenAsync(
private static async Task<Solution> RemoveAwaitFromCallersAsync(
Document document, SyntaxAnnotation annotation, CancellationToken cancellationToken)
{
var syntaxRoot = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var syntaxRoot = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var methodDeclaration = syntaxRoot.GetAnnotatedNodes(annotation).FirstOrDefault();
if (methodDeclaration != null)
{
var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

if (semanticModel.GetDeclaredSymbol(methodDeclaration, cancellationToken) is IMethodSymbol methodSymbol)
{
#if CODE_STYLE

var references = await SymbolFinder.FindReferencesAsync(
methodSymbol, document.Project.Solution, cancellationToken).ConfigureAwait(false);

#else

var references = await SymbolFinder.FindRenamableReferencesAsync(
ImmutableArray.Create<ISymbol>(methodSymbol), document.Project.Solution, cancellationToken).ConfigureAwait(false);

#endif

var referencedSymbol = references.FirstOrDefault(r => Equals(r.Definition, methodSymbol));
if (referencedSymbol != null)
{
Expand Down Expand Up @@ -162,8 +171,8 @@ private static async Task<Solution> RemoveAwaitFromCallersAsync(
Solution currentSolution, IGrouping<Document, ReferenceLocation> group, CancellationToken cancellationToken)
{
var document = group.Key;
var syntaxFactsService = document.GetLanguageService<ISyntaxFactsService>();
var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var syntaxFactsService = document.GetRequiredLanguageService<ISyntaxFactsService>();
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

var editor = new SyntaxEditor(root, currentSolution.Services);

Expand Down Expand Up @@ -221,30 +230,31 @@ private static void RemoveAwaitFromCallerIfPresent(
if (syntaxFacts.IsExpressionOfAwaitExpression(invocationExpression))
{
// Handle the case where we're directly awaited.
var awaitExpression = invocationExpression.Parent;
var awaitExpression = invocationExpression.GetRequiredParent();
editor.ReplaceNode(awaitExpression, (currentAwaitExpression, generator) =>
syntaxFacts.GetExpressionOfAwaitExpression(currentAwaitExpression)
.WithTriviaFrom(currentAwaitExpression));
}
else if (syntaxFacts.IsExpressionOfMemberAccessExpression(invocationExpression))
{
// Check for the .ConfigureAwait case.
var parentMemberAccessExpression = invocationExpression.Parent;
var parentMemberAccessExpressionNameNode = syntaxFacts.GetNameOfMemberAccessExpression(
parentMemberAccessExpression);
var parentMemberAccessExpression = invocationExpression.GetRequiredParent();
var parentMemberAccessExpressionNameNode = syntaxFacts.GetNameOfMemberAccessExpression(parentMemberAccessExpression);

var parentMemberAccessExpressionName = syntaxFacts.GetIdentifierOfSimpleName(parentMemberAccessExpressionNameNode).ValueText;
if (parentMemberAccessExpressionName == nameof(Task.ConfigureAwait))
{
var parentExpression = parentMemberAccessExpression.Parent;
if (syntaxFacts.IsExpressionOfAwaitExpression(parentExpression))
{
var awaitExpression = parentExpression.Parent;
var awaitExpression = parentExpression.GetRequiredParent();
editor.ReplaceNode(awaitExpression, (currentAwaitExpression, generator) =>
{
var currentConfigureAwaitInvocation = syntaxFacts.GetExpressionOfAwaitExpression(currentAwaitExpression);
var currentMemberAccess = syntaxFacts.GetExpressionOfInvocationExpression(currentConfigureAwaitInvocation);
var currentInvocationExpression = syntaxFacts.GetExpressionOfMemberAccessExpression(currentMemberAccess);
Contract.ThrowIfNull(currentInvocationExpression);

return currentInvocationExpression.WithTriviaFrom(currentAwaitExpression);
});
}
Expand Down
Loading