diff --git a/RefactoringEssentials/CSharp/CodeRefactorings/Synced/ConvertInstanceToStaticMethodCodeRefactoringProvider.cs b/RefactoringEssentials/CSharp/CodeRefactorings/Synced/ConvertInstanceToStaticMethodCodeRefactoringProvider.cs new file mode 100644 index 00000000..2a15e4c5 --- /dev/null +++ b/RefactoringEssentials/CSharp/CodeRefactorings/Synced/ConvertInstanceToStaticMethodCodeRefactoringProvider.cs @@ -0,0 +1,317 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeRefactorings; +using Microsoft.CodeAnalysis.Formatting; +using Microsoft.CodeAnalysis.Text; +using Microsoft.CodeAnalysis.FindSymbols; + +namespace RefactoringEssentials.CSharp +{ + /// + /// Converts an instance method to a static method adding an additional parameter as "this" replacement. + /// + [ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = "Convert instance to static method")] + public class ConvertInstanceToStaticMethodCodeRefactoringProvider : SpecializedCodeRefactoringProvider + { + protected override IEnumerable GetActions(Document document, SemanticModel semanticModel, SyntaxNode root, TextSpan span, MethodDeclarationSyntax node, CancellationToken cancellationToken) + { + TypeDeclarationSyntax enclosingTypeDeclaration = node.Ancestors().OfType().FirstOrDefault(); + if (enclosingTypeDeclaration == null) + yield break; + if (node.Modifiers.Any(SyntaxKind.StaticKeyword)) + yield break; + + var declaringTypeSymbol = semanticModel.GetDeclaredSymbol(enclosingTypeDeclaration); + var methodSymbol = semanticModel.GetDeclaredSymbol(node); + + yield return CodeActionFactory.Create(span, DiagnosticSeverity.Info, GettextCatalog.GetString("Convert to static method"), t2 => + { + return PerformAction(document, semanticModel, root, enclosingTypeDeclaration, declaringTypeSymbol, node, methodSymbol, cancellationToken); + }); + } + + class MethodReferencesInDocument + { + public readonly Document Document; + public readonly IEnumerable References; + + public MethodReferencesInDocument(Document document, IEnumerable references) + { + this.Document = document; + this.References = references; + } + } + + class ReferencingInvocationExpression + { + public readonly bool IsInChangedMethod; + public readonly InvocationExpressionSyntax InvocationExpression; + + public ReferencingInvocationExpression(bool isInChangedMethod, InvocationExpressionSyntax invocationExpression) + { + this.IsInChangedMethod = isInChangedMethod; + this.InvocationExpression = invocationExpression; + } + } + + async Task PerformAction(Document document, SemanticModel model, SyntaxNode root, TypeDeclarationSyntax enclosingTypeDeclaration, INamedTypeSymbol declaringTypeSymbol, MethodDeclarationSyntax methodDeclaration, IMethodSymbol methodSymbol, CancellationToken cancellationToken) + { + // Collect all invocations of changed method + var methodReferencesVisitor = new MethodReferencesVisitor(document.Project.Solution, methodSymbol, methodDeclaration, cancellationToken); + await methodReferencesVisitor.Collect(); + + // Collect all references to type members and "this" expressions inside of changed method + var memberReferencesVisitor = new MemberReferencesVisitor(model, declaringTypeSymbol.GetMembers().Where(m => m != methodSymbol), cancellationToken); + memberReferencesVisitor.Collect(methodDeclaration.Body); + + Solution solution = document.Project.Solution; + + List trackedNodesInMainDoc = new List(); + trackedNodesInMainDoc.Add(methodDeclaration); + var methodReferencesInMainDocument = methodReferencesVisitor.NodesToChange.FirstOrDefault(n => n.Document.Id == document.Id); + if (methodReferencesInMainDocument != null) + { + trackedNodesInMainDoc.AddRange(methodReferencesInMainDocument.References.Select(r => r.InvocationExpression)); + } + trackedNodesInMainDoc.AddRange(memberReferencesVisitor.NodesToChange); + + var newMainRoot = root.TrackNodes(trackedNodesInMainDoc); + + foreach (var invocationsInDocument in methodReferencesVisitor.NodesToChange) + { + SyntaxNode thisDocRoot = null; + var thisDocumentId = invocationsInDocument.Document.Id; + if (document.Id == thisDocumentId) + { + // We are in same document as changed method declaration, reuse new root from outside + thisDocRoot = newMainRoot; + } + else + { + thisDocRoot = await invocationsInDocument.Document.GetSyntaxRootAsync(); + if (thisDocRoot == null) + continue; + thisDocRoot = thisDocRoot.TrackNodes(invocationsInDocument.References.Select(r => r.InvocationExpression)); + } + + foreach (var referencingInvocation in invocationsInDocument.References) + { + // Change this method invocation to invocation of a static method with instance parameter + var thisInvocation = thisDocRoot.GetCurrentNode(referencingInvocation.InvocationExpression); + + ExpressionSyntax invocationExpressionPart = null; + SimpleNameSyntax methodName = null; + var memberAccessExpr = thisInvocation.Expression as MemberAccessExpressionSyntax; + if (memberAccessExpr != null) + { + invocationExpressionPart = memberAccessExpr.Expression; + methodName = memberAccessExpr.Name; + } + + if (invocationExpressionPart == null) + { + var identifier = thisInvocation.Expression as IdentifierNameSyntax; + if (identifier != null) + { + // If changed method references itself, use "instance" as additional parameter! In other methods of affected class, use "this"! + if (referencingInvocation.IsInChangedMethod) + invocationExpressionPart = SyntaxFactory.IdentifierName("instance").WithLeadingTrivia(identifier.GetLeadingTrivia()); + else + invocationExpressionPart = SyntaxFactory.ThisExpression().WithLeadingTrivia(identifier.GetLeadingTrivia()); + methodName = identifier; + } + } + + if (invocationExpressionPart == null) + continue; + + List invocationArguments = new List(); + invocationArguments.Add(SyntaxFactory.Argument(invocationExpressionPart.WithoutLeadingTrivia())); + invocationArguments.AddRange(referencingInvocation.InvocationExpression.ArgumentList.Arguments); + + thisDocRoot = thisDocRoot.ReplaceNode( + thisInvocation, + SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName(enclosingTypeDeclaration.Identifier.WithoutTrivia()).WithLeadingTrivia(invocationExpressionPart.GetLeadingTrivia()), + methodName.WithoutLeadingTrivia() + ), + SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(invocationArguments)).WithAdditionalAnnotations(Formatter.Annotation) + )); + } + + + if (document.Id == thisDocumentId) + { + // Write new root back to outside + newMainRoot = thisDocRoot; + } + else + { + // Another document, replace it with modified version in solution + solution = solution.WithDocumentSyntaxRoot(thisDocumentId, thisDocRoot); + } + } + + foreach (var changedNode in memberReferencesVisitor.NodesToChange) + { + var trackedNode = newMainRoot.GetCurrentNode(changedNode); + + var thisExpression = trackedNode as ThisExpressionSyntax; + if (thisExpression != null) + { + // Replace "this" with instance parameter name + newMainRoot = newMainRoot.ReplaceNode( + thisExpression, + SyntaxFactory.IdentifierName("instance").WithLeadingTrivia(thisExpression.GetLeadingTrivia()) + ); + } + + var memberIdentifier = trackedNode as IdentifierNameSyntax; + if (memberIdentifier != null) + { + newMainRoot = newMainRoot.ReplaceNode( + memberIdentifier, + SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName("instance").WithLeadingTrivia(memberIdentifier.GetLeadingTrivia()), + memberIdentifier.WithoutLeadingTrivia()) + ); + } + } + + List parameters = new List(); + parameters.Add(SyntaxFactory.Parameter( + SyntaxFactory.List(), + SyntaxFactory.TokenList(), + SyntaxFactory.ParseTypeName(enclosingTypeDeclaration.Identifier.ValueText), + SyntaxFactory.Identifier("instance"), null) + .WithAdditionalAnnotations(Formatter.Annotation)); + parameters.AddRange(methodDeclaration.ParameterList.Parameters); + + var staticModifierLeadingTrivia = + methodDeclaration.Modifiers.Any() ? SyntaxFactory.TriviaList() : methodDeclaration.GetLeadingTrivia(); + var methodDeclarationLeadingTrivia = + methodDeclaration.Modifiers.Any() ? methodDeclaration.GetLeadingTrivia() : SyntaxFactory.TriviaList(); + + var trackedMethodDeclaration = newMainRoot.GetCurrentNode(methodDeclaration); + newMainRoot = newMainRoot.ReplaceNode((SyntaxNode)trackedMethodDeclaration, trackedMethodDeclaration + .WithLeadingTrivia(methodDeclarationLeadingTrivia) + .WithModifiers(trackedMethodDeclaration.Modifiers.Add(SyntaxFactory.Token(SyntaxKind.StaticKeyword).WithLeadingTrivia(staticModifierLeadingTrivia).WithTrailingTrivia(SyntaxFactory.TriviaList(SyntaxFactory.Whitespace(" "))))) + .WithParameterList(SyntaxFactory.ParameterList(SyntaxFactory.SeparatedList(parameters)).WithTrailingTrivia(trackedMethodDeclaration.ParameterList.GetTrailingTrivia()))); + return solution.WithDocumentSyntaxRoot(document.Id, newMainRoot); + } + + class MethodReferencesVisitor + { + readonly Solution solution; + readonly MethodDeclarationSyntax changedMethodDeclaration; + readonly ISymbol methodSymbol; + readonly CancellationToken cancellationToken; + + public readonly List NodesToChange = new List(); + + public MethodReferencesVisitor(Solution solution, ISymbol methodSymbol, MethodDeclarationSyntax changedMethodDeclaration, CancellationToken cancellationToken) + { + this.solution = solution; + this.methodSymbol = methodSymbol; + this.changedMethodDeclaration = changedMethodDeclaration; + this.cancellationToken = cancellationToken; + } + + public async Task Collect() + { + var invocations = await SymbolFinder.FindCallersAsync(methodSymbol, solution); + var invocationsPerDocument = from invocation in invocations + from location in invocation.Locations + where location.SourceTree != null + group location by location.SourceTree into locationGroup + select locationGroup; + + foreach (var locationsInDocument in invocationsPerDocument) + { + var document = solution.GetDocument(locationsInDocument.Key); + if (document == null) + continue; + + var root = await document.GetSyntaxRootAsync(cancellationToken); + if (root == null) + continue; + + NodesToChange.Add(new MethodReferencesInDocument( + document, + locationsInDocument.Select(loc => + { + if (!loc.IsInSource) + return null; + + var node = root.FindNode(loc.SourceSpan); + if (node == null) + return null; + + var invocationExpression = node.AncestorsAndSelf().OfType().FirstOrDefault(); + if (invocationExpression == null) + return null; + + return new ReferencingInvocationExpression(invocationExpression.Ancestors().Contains(changedMethodDeclaration), invocationExpression); + }) + .Where(r => r != null))); + } + } + } + + class MemberReferencesVisitor : CSharpSyntaxWalker + { + readonly SemanticModel semanticModel; + readonly IEnumerable referenceSymbols; + readonly CancellationToken cancellationToken; + + public readonly List NodesToChange = new List(); + + public MemberReferencesVisitor(SemanticModel semanticModel, IEnumerable referenceSymbols, CancellationToken cancellationToken) + { + this.semanticModel = semanticModel; + this.referenceSymbols = referenceSymbols; + this.cancellationToken = cancellationToken; + } + + public void Collect(SyntaxNode root) + { + this.Visit(root); + } + + public override void VisitIdentifierName(IdentifierNameSyntax node) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (!(node.Parent is MemberAccessExpressionSyntax)) + { + var thisSymbolInfo = semanticModel.GetSymbolInfo(node); + if (referenceSymbols.Contains(thisSymbolInfo.Symbol)) + { + NodesToChange.Add(node); + } + } + + base.VisitIdentifierName(node); + } + + public override void VisitThisExpression(ThisExpressionSyntax node) + { + cancellationToken.ThrowIfCancellationRequested(); + + NodesToChange.Add(node); + + base.VisitThisExpression(node); + } + } + } +} + diff --git a/RefactoringEssentials/CSharp/Diagnostics/CodeActionFactory.cs b/RefactoringEssentials/CSharp/Diagnostics/CodeActionFactory.cs index 85498e57..a7bac4d4 100644 --- a/RefactoringEssentials/CSharp/Diagnostics/CodeActionFactory.cs +++ b/RefactoringEssentials/CSharp/Diagnostics/CodeActionFactory.cs @@ -12,27 +12,45 @@ static class CodeActionFactory public static CodeAction Create(TextSpan textSpan, DiagnosticSeverity severity, string description, Document changedDocument) { if (description == null) - throw new ArgumentNullException("description"); + throw new ArgumentNullException(nameof(description)); if (changedDocument == null) - throw new ArgumentNullException("changedDocument"); + throw new ArgumentNullException(nameof(changedDocument)); return new DocumentChangeAction(textSpan, severity, description, ct => Task.FromResult(changedDocument)); } public static CodeAction Create(TextSpan textSpan, DiagnosticSeverity severity, string description, Func> createChangedDocument) { if (description == null) - throw new ArgumentNullException("description"); + throw new ArgumentNullException(nameof(description)); if (createChangedDocument == null) - throw new ArgumentNullException("createChangedDocument"); + throw new ArgumentNullException(nameof(createChangedDocument)); return new DocumentChangeAction(textSpan, severity, description, createChangedDocument); } + public static CodeAction Create(TextSpan textSpan, DiagnosticSeverity severity, string description, Solution changedSolution) + { + if (description == null) + throw new ArgumentNullException(nameof(description)); + if (changedSolution == null) + throw new ArgumentNullException(nameof(changedSolution)); + return new DocumentChangeAction(textSpan, severity, description, ct => Task.FromResult(changedSolution)); + } + + public static CodeAction Create(TextSpan textSpan, DiagnosticSeverity severity, string description, Func> createChangedSolution) + { + if (description == null) + throw new ArgumentNullException(nameof(description)); + if (createChangedSolution == null) + throw new ArgumentNullException(nameof(createChangedSolution)); + return new DocumentChangeAction(textSpan, severity, description, createChangedSolution); + } + public static CodeAction CreateInsertion(TextSpan textSpan, DiagnosticSeverity severity, string description, Func> createInsertion) { if (description == null) - throw new ArgumentNullException("description"); + throw new ArgumentNullException(nameof(description)); if (createInsertion == null) - throw new ArgumentNullException("createInsertion"); + throw new ArgumentNullException(nameof(createInsertion)); return new InsertionAction(textSpan, severity, description, createInsertion); } } diff --git a/RefactoringEssentials/CSharp/Diagnostics/DocumentChangeAction.cs b/RefactoringEssentials/CSharp/Diagnostics/DocumentChangeAction.cs index 8905e589..be96264f 100644 --- a/RefactoringEssentials/CSharp/Diagnostics/DocumentChangeAction.cs +++ b/RefactoringEssentials/CSharp/Diagnostics/DocumentChangeAction.cs @@ -13,6 +13,7 @@ public sealed class DocumentChangeAction : NRefactoryCodeAction { readonly string title; readonly Func> createChangedDocument; + readonly Func> createChangedSolution; public override string Title { @@ -28,12 +29,30 @@ public DocumentChangeAction(TextSpan textSpan, DiagnosticSeverity severity, stri this.createChangedDocument = createChangedDocument; } + public DocumentChangeAction(TextSpan textSpan, DiagnosticSeverity severity, string title, Func> createChangedSolution) : base(textSpan, severity) + { + this.title = title; + this.createChangedSolution = createChangedSolution; + } + protected override Task GetChangedDocumentAsync(CancellationToken cancellationToken) { + if (createChangedDocument == null) + return base.GetChangedDocumentAsync(cancellationToken); + var task = createChangedDocument.Invoke(cancellationToken); return task; } + protected override Task GetChangedSolutionAsync(CancellationToken cancellationToken) + { + if (createChangedSolution == null) + return base.GetChangedSolutionAsync(cancellationToken); + + var task = createChangedSolution.Invoke(cancellationToken); + return task; + } + protected override async Task PostProcessChangesAsync(Document document, CancellationToken cancellationToken) { document = await Simplifier.ReduceAsync(document, Simplifier.Annotation, cancellationToken: cancellationToken).ConfigureAwait(false); diff --git a/RefactoringEssentials/CodeRefactorings.CSharp.html b/RefactoringEssentials/CodeRefactorings.CSharp.html index c1324518..5def2e9f 100644 --- a/RefactoringEssentials/CodeRefactorings.CSharp.html +++ b/RefactoringEssentials/CodeRefactorings.CSharp.html @@ -15,7 +15,7 @@ -->

Supported Refactorings

-

101 code refactorings for C#

+

102 code refactorings for C#

  • Adds another accessor (AddAnotherAccessorCodeRefactoringProvider)
  • Add braces (AddBracesCodeRefactoringProvider)
  • @@ -54,6 +54,7 @@

    Supported Refactorings

  • Convert 'if' to 'return' (ConvertIfStatementToReturnStatementAction)
  • Convert 'if' to 'switch' (ConvertIfStatementToSwitchStatementCodeRefactoringProvider)
  • Convert implict to explicit implementation (ConvertImplicitToExplicitImplementationCodeRefactoringProvider)
  • +
  • Convert instance to static method (ConvertInstanceToStaticMethodCodeRefactoringProvider)
  • Convert string interpolation to 'string.Format' (ConvertInterpolatedStringToStringFormatCodeRefactoringProvider)
  • Converts expression of lambda body to statement (ConvertLambdaBodyExpressionToStatementCodeRefactoringProvider)
  • Converts statement of lambda body to expression (ConvertLambdaStatementToExpressionCodeRefactoringProvider)
  • diff --git a/RefactoringEssentials/RefactoringEssentials.csproj b/RefactoringEssentials/RefactoringEssentials.csproj index eaba7dca..7c309cda 100644 --- a/RefactoringEssentials/RefactoringEssentials.csproj +++ b/RefactoringEssentials/RefactoringEssentials.csproj @@ -61,6 +61,7 @@ + diff --git a/Tests/CSharp/CodeRefactorings/ConvertInstanceToStaticMethodCodeRefactoringTests.cs b/Tests/CSharp/CodeRefactorings/ConvertInstanceToStaticMethodCodeRefactoringTests.cs new file mode 100644 index 00000000..552774ac --- /dev/null +++ b/Tests/CSharp/CodeRefactorings/ConvertInstanceToStaticMethodCodeRefactoringTests.cs @@ -0,0 +1,279 @@ +using NUnit.Framework; +using RefactoringEssentials.CSharp; +using RefactoringEssentials.Tests.CSharp.CodeRefactorings; + +namespace RefactoringEssentials.Tests.CSharp +{ + /// + /// Tests for ConvertInstanceToStaticMethodCodeRefactoringProvider. + /// + [TestFixture] + public class ConvertInstanceToStaticMethodCodeRefactoringTests : CSharpCodeRefactoringTestBase + { + [Test] + public void MethodWithoutParameters1() + { + Test(@" +class Foo +{ + void $Test() + { + int a = 0; + } +}", @" +class Foo +{ + static void Test(Foo instance) + { + int a = 0; + } +}"); + } + + [Test] + public void MethodWithoutParameters2() + { + Test(@" +class Foo +{ + public void $Test() + { + int a = 0; + } +}", @" +class Foo +{ + public static void Test(Foo instance) + { + int a = 0; + } +}"); + } + + [Test] + public void MethodWithParameters() + { + Test(@" +class Foo +{ + public void $Test(int b) + { + int a = 0; + } +}", @" +class Foo +{ + public static void Test(Foo instance, int b) + { + int a = 0; + } +}"); + } + + [Test] + public void AlreadyStaticMethod() + { + TestWrongContext(@" +class Foo +{ + public static void $Test(int b) + { + int a = 0; + } +}"); + } + + [Test] + public void MethodUsingInstanceMember() + { + Test(@" +class Foo +{ + int member; + + void AnotherMethod(int a) + { + } + + void $Test() + { + int a = 0; + member = a; + AnotherMethod(a); + } +}", @" +class Foo +{ + int member; + + void AnotherMethod(int a) + { + } + + static void Test(Foo instance) + { + int a = 0; + instance.member = a; + instance.AnotherMethod(a); + } +}"); + } + + [Test] + public void MethodUsingInstanceMemberWithThis() + { + Test(@" +class Foo +{ + int member; + + void AnotherMethod(int a) + { + } + + void $Test() + { + int a = 0; + this.member = a; + this.AnotherMethod(a); + } +}", @" +class Foo +{ + int member; + + void AnotherMethod(int a) + { + } + + static void Test(Foo instance) + { + int a = 0; + instance.member = a; + instance.AnotherMethod(a); + } +}"); + } + + [Test] + public void RecursiveMethodCall() + { + Test(@" +class Foo +{ + int member; + + void $Test() + { + int a = 0; + Test(); + } +}", @" +class Foo +{ + int member; + + static void Test(Foo instance) + { + int a = 0; + Foo.Test(instance); + } +}"); + } + + [Test] + public void RecursiveMethodCallWithThis() + { + Test(@" +class Foo +{ + int member; + + void $Test() + { + int a = 0; + this.Test(); + } +}", @" +class Foo +{ + int member; + + static void Test(Foo instance) + { + int a = 0; + Foo.Test(instance); + } +}"); + } + + [Test] + public void MethodWithInternalReference() + { + Test(@" +class Foo +{ + int member; + + void AnotherMethod(int a) + { + Test(); + } + + void $Test() + { + } +}", @" +class Foo +{ + int member; + + void AnotherMethod(int a) + { + Foo.Test(this); + } + + static void Test(Foo instance) + { + } +}"); + } + + [Test] + public void MethodWithExternalReference() + { + Test(@" +class Foo +{ + public void $Test() + { + int a = 0; + } +} + +class Foo2 +{ + void Test(Foo foo) + { + foo.Test(); + } +}", @" +class Foo +{ + public static void Test(Foo instance) + { + int a = 0; + } +} + +class Foo2 +{ + void Test(Foo foo) + { + Foo.Test(foo); + } +}"); + } + } +} + diff --git a/Tests/Tests.csproj b/Tests/Tests.csproj index 14c29eca..26340638 100644 --- a/Tests/Tests.csproj +++ b/Tests/Tests.csproj @@ -173,6 +173,7 @@ +