diff --git a/src/Analyzers/MSTest.Analyzers.CodeFixes/FlowTestContextCancellationTokenFixer.cs b/src/Analyzers/MSTest.Analyzers.CodeFixes/FlowTestContextCancellationTokenFixer.cs index 0941d05e6c..83ec8e124d 100644 --- a/src/Analyzers/MSTest.Analyzers.CodeFixes/FlowTestContextCancellationTokenFixer.cs +++ b/src/Analyzers/MSTest.Analyzers.CodeFixes/FlowTestContextCancellationTokenFixer.cs @@ -16,6 +16,8 @@ using MSTest.Analyzers.Helpers; +using Polyfills; + namespace MSTest.Analyzers; /// @@ -31,8 +33,8 @@ public sealed class FlowTestContextCancellationTokenFixer : CodeFixProvider /// public override FixAllProvider GetFixAllProvider() - // See https://github.com/dotnet/roslyn/blob/main/docs/analyzers/FixAllProvider.md for more information on Fix All Providers - => WellKnownFixAllProviders.BatchFixer; + // Use custom FixAllProvider to handle adding TestContext property when needed + => FlowTestContextCancellationTokenFixAllProvider.Instance; /// public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context) @@ -49,37 +51,213 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context) } diagnostic.Properties.TryGetValue(FlowTestContextCancellationTokenAnalyzer.TestContextMemberNamePropertyKey, out string? testContextMemberName); + diagnostic.Properties.TryGetValue(nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState), out string? testContextState); // Register a code action that will invoke the fix context.RegisterCodeFix( CodeAction.Create( title: CodeFixResources.PassCancellationTokenFix, - createChangedDocument: c => AddCancellationTokenParameterAsync(context.Document, invocationExpression, testContextMemberName, c), - equivalenceKey: "AddTestContextCancellationToken"), + createChangedDocument: async c => + { + DocumentEditor editor = await DocumentEditor.CreateAsync(context.Document, context.CancellationToken).ConfigureAwait(false); + return ApplyFix(editor, invocationExpression, testContextMemberName, testContextState, adjustedSymbols: null, c); + }, + equivalenceKey: nameof(FlowTestContextCancellationTokenFixer)), diagnostic); } - private static async Task AddCancellationTokenParameterAsync( - Document document, + internal static Document ApplyFix( + DocumentEditor editor, InvocationExpressionSyntax invocationExpression, string? testContextMemberName, + string? testContextState, + HashSet? adjustedSymbols, CancellationToken cancellationToken) { - DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false); + if (testContextState == nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState.CouldBeInScopeAsProperty)) + { + Debug.Assert(testContextMemberName is null, "TestContext member name should be null when state is CouldBeInScopeAsProperty"); + AddCancellationTokenArgument(editor, invocationExpression, "TestContext"); + TypeDeclarationSyntax? containingTypeDeclaration = invocationExpression.FirstAncestorOrSelf(); + if (containingTypeDeclaration is not null) + { + // adjustedSymbols is null meaning we are only applying a single fix (in that case we add the property). + // If we are in fix all, we then verify if a previous fix has already added the property. + // We only add the property if it wasn't added by a previous fix. + // NOTE: We don't expect GetDeclaredSymbol to return null, but if it did (e.g, error scenario), we add the property. + if (adjustedSymbols is null || + editor.SemanticModel.GetDeclaredSymbol(containingTypeDeclaration, cancellationToken) is not { } symbol || + adjustedSymbols.Add(symbol)) + { + editor.ReplaceNode(containingTypeDeclaration, (containingTypeDeclaration, _) => AddTestContextProperty((TypeDeclarationSyntax)containingTypeDeclaration)); + } + } + } + else if (testContextState == nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState.CouldBeInScopeAsParameter)) + { + Debug.Assert(testContextMemberName is null, "TestContext member name should be null when state is CouldBeInScopeAsParameter"); + AddCancellationTokenArgument(editor, invocationExpression, "testContext"); + MethodDeclarationSyntax? containingMethodDeclaration = invocationExpression.FirstAncestorOrSelf(); + + if (containingMethodDeclaration is not null) + { + // adjustedSymbols is null meaning we are only applying a single fix (in that case we add the parameter). + // If we are in fix all, we then verify if a previous fix has already added the parameter. + // We only add the parameter if it wasn't added by a previous fix. + // NOTE: We don't expect GetDeclaredSymbol to return null, but if it did (e.g, error scenario), we add the property. + if (adjustedSymbols is null || + editor.SemanticModel.GetDeclaredSymbol(containingMethodDeclaration, cancellationToken) is not { } symbol || + adjustedSymbols.Add(symbol)) + { + editor.ReplaceNode(containingMethodDeclaration, (containingMethodDeclaration, _) => AddTestContextParameterToMethod((MethodDeclarationSyntax)containingMethodDeclaration)); + } + } + } + else + { + Guard.NotNull(testContextMemberName); + AddCancellationTokenArgument(editor, invocationExpression, testContextMemberName); + } + + return editor.GetChangedDocument(); + } + + internal static void AddCancellationTokenArgument( + DocumentEditor editor, + InvocationExpressionSyntax invocationExpression, + string testContextMemberName) + { + // Find the containing method to determine the context + MethodDeclarationSyntax? containingMethod = invocationExpression.FirstAncestorOrSelf(); // Create the TestContext.CancellationTokenSource.Token expression MemberAccessExpressionSyntax testContextExpression = SyntaxFactory.MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, SyntaxFactory.MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName(testContextMemberName ?? "testContext"), + SyntaxFactory.IdentifierName(testContextMemberName), SyntaxFactory.IdentifierName("CancellationTokenSource")), SyntaxFactory.IdentifierName("Token")); - ArgumentListSyntax currentArguments = invocationExpression.ArgumentList; - SeparatedSyntaxList newArguments = currentArguments.Arguments.Add(SyntaxFactory.Argument(testContextExpression)); - InvocationExpressionSyntax newInvocation = invocationExpression.WithArgumentList(currentArguments.WithArguments(newArguments)); - editor.ReplaceNode(invocationExpression, newInvocation); - return editor.GetChangedDocument(); + editor.ReplaceNode(invocationExpression, (node, _) => + { + var invocationExpression = (InvocationExpressionSyntax)node; + ArgumentListSyntax currentArguments = invocationExpression.ArgumentList; + SeparatedSyntaxList newArguments = currentArguments.Arguments.Add(SyntaxFactory.Argument(testContextExpression)); + return invocationExpression.WithArgumentList(currentArguments.WithArguments(newArguments)); + }); + } + + internal static MethodDeclarationSyntax AddTestContextParameterToMethod(MethodDeclarationSyntax method) + { + // Create TestContext parameter + ParameterSyntax testContextParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier("testContext")) + .WithType(SyntaxFactory.IdentifierName("TestContext")); + + // Add the parameter to the method + SeparatedSyntaxList updatedParameterList = method.ParameterList.Parameters.Count == 0 + ? SyntaxFactory.SingletonSeparatedList(testContextParameter) + : method.ParameterList.Parameters.Add(testContextParameter); + + return method.WithParameterList(method.ParameterList.WithParameters(updatedParameterList)); + } + + internal static TypeDeclarationSyntax AddTestContextProperty(TypeDeclarationSyntax typeDeclaration) + { + PropertyDeclarationSyntax testContextProperty = SyntaxFactory.PropertyDeclaration( + SyntaxFactory.IdentifierName("TestContext"), + "TestContext") + .WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword))) + .WithAccessorList(SyntaxFactory.AccessorList( + SyntaxFactory.List(new[] + { + SyntaxFactory.AccessorDeclaration(SyntaxKind.GetAccessorDeclaration) + .WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.SemicolonToken)), + SyntaxFactory.AccessorDeclaration(SyntaxKind.SetAccessorDeclaration) + .WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.SemicolonToken)), + }))); + + return typeDeclaration.AddMembers(testContextProperty); + } +} + +/// +/// Custom FixAllProvider for that can add TestContext property when needed. +/// This ensures that when multiple fixes are applied to the same class, the TestContext property is added only once. +/// +internal sealed class FlowTestContextCancellationTokenFixAllProvider : FixAllProvider +{ + public static readonly FlowTestContextCancellationTokenFixAllProvider Instance = new(); + + private FlowTestContextCancellationTokenFixAllProvider() + { + } + + public override Task GetFixAsync(FixAllContext fixAllContext) + => Task.FromResult(new FixAllCodeAction(fixAllContext)); + + private sealed class FixAllCodeAction : CodeAction + { + private readonly FixAllContext _fixAllContext; + + public FixAllCodeAction(FixAllContext fixAllContext) + => _fixAllContext = fixAllContext; + + public override string Title => CodeFixResources.PassCancellationTokenFix; + + public override string? EquivalenceKey => nameof(FlowTestContextCancellationTokenFixer); + + protected override async Task GetChangedSolutionAsync(CancellationToken cancellationToken) + { + FixAllContext fixAllContext = _fixAllContext; + var editor = new SolutionEditor(fixAllContext.Solution); + var fixedSymbols = new HashSet(SymbolEqualityComparer.Default); + + if (fixAllContext.Scope == FixAllScope.Document) + { + DocumentEditor documentEditor = await editor.GetDocumentEditorAsync(fixAllContext.Document!.Id, cancellationToken).ConfigureAwait(false); + foreach (Diagnostic diagnostic in await fixAllContext.GetDocumentDiagnosticsAsync(fixAllContext.Document!).ConfigureAwait(false)) + { + FixOneDiagnostic(documentEditor, diagnostic, fixedSymbols, cancellationToken); + } + } + else if (fixAllContext.Scope == FixAllScope.Project) + { + await FixAllInProjectAsync(fixAllContext, fixAllContext.Project, editor, fixedSymbols, cancellationToken).ConfigureAwait(false); + } + else if (fixAllContext.Scope == FixAllScope.Solution) + { + foreach (Project project in fixAllContext.Solution.Projects) + { + await FixAllInProjectAsync(fixAllContext, project, editor, fixedSymbols, cancellationToken).ConfigureAwait(false); + } + } + + return editor.GetChangedSolution(); + } + + private static async Task FixAllInProjectAsync(FixAllContext fixAllContext, Project project, SolutionEditor editor, HashSet fixedSymbols, CancellationToken cancellationToken) + { + foreach (Diagnostic diagnostic in await fixAllContext.GetAllDiagnosticsAsync(project).ConfigureAwait(false)) + { + DocumentId documentId = editor.OriginalSolution.GetDocumentId(diagnostic.Location.SourceTree)!; + DocumentEditor documentEditor = await editor.GetDocumentEditorAsync(documentId, cancellationToken).ConfigureAwait(false); + FixOneDiagnostic(documentEditor, diagnostic, fixedSymbols, cancellationToken); + } + } + + private static void FixOneDiagnostic(DocumentEditor documentEditor, Diagnostic diagnostic, HashSet fixedSymbols, CancellationToken cancellationToken) + { + SyntaxNode node = documentEditor.OriginalRoot.FindNode(diagnostic.Location.SourceSpan, getInnermostNodeForTie: true); + if (node is not InvocationExpressionSyntax invocationExpression) + { + return; + } + + diagnostic.Properties.TryGetValue(FlowTestContextCancellationTokenAnalyzer.TestContextMemberNamePropertyKey, out string? testContextMemberName); + diagnostic.Properties.TryGetValue(nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState), out string? testContextState); + + FlowTestContextCancellationTokenFixer.ApplyFix(documentEditor, invocationExpression, testContextMemberName, testContextState, fixedSymbols, cancellationToken); + } } } diff --git a/src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs index cf38e10478..19eb3b5cde 100644 --- a/src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs @@ -74,7 +74,7 @@ private static void AnalyzeInvocation( IMethodSymbol method = invocationOperation.TargetMethod; // Check if we're in a context where a TestContext is already available or could be made available. - if (!HasOrCouldHaveTestContextInScope(context.ContainingSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol, out string? testContextMemberNameInScope)) + if (!HasOrCouldHaveTestContextInScope(context.ContainingSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol, out string? testContextMemberNameInScope, out TestContextState? testContextState)) { return; } @@ -93,13 +93,7 @@ private static void AnalyzeInvocation( invocationOperation.Arguments.FirstOrDefault(arg => SymbolEqualityComparer.Default.Equals(arg.Parameter, cancellationTokenParameter))?.ArgumentKind != ArgumentKind.Explicit) { // The called method has an optional CancellationToken parameter, but it was not explicitly provided. - ImmutableDictionary properties = ImmutableDictionary.Empty; - if (testContextMemberNameInScope is not null) - { - properties = properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope); - } - - context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope))); + context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState))); return; } @@ -108,16 +102,15 @@ private static void AnalyzeInvocation( if (cancellationTokenParameter is null && HasOverloadWithCancellationToken(method, cancellationTokenSymbol)) { - context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope))); + context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState))); } - static ImmutableDictionary GetPropertiesBag(string? testContextMemberNameInScope) + static ImmutableDictionary GetPropertiesBag(string? testContextMemberNameInScope, TestContextState? testContextState) { ImmutableDictionary properties = ImmutableDictionary.Empty; - if (testContextMemberNameInScope is not null) - { - properties = properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope); - } + properties = testContextMemberNameInScope is not null + ? properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope) + : properties.Add(nameof(TestContextState), testContextState.ToString()); return properties; } @@ -151,9 +144,11 @@ private static bool HasOrCouldHaveTestContextInScope( INamedTypeSymbol classCleanupAttributeSymbol, INamedTypeSymbol assemblyCleanupAttributeSymbol, INamedTypeSymbol testMethodAttributeSymbol, - out string? testContextMemberNameInScope) + out string? testContextMemberNameInScope, + [NotNullWhen(true)] out TestContextState? testContextState) { testContextMemberNameInScope = null; + testContextState = null; if (containingSymbol is not IMethodSymbol method) { @@ -164,6 +159,7 @@ private static bool HasOrCouldHaveTestContextInScope( if (method.Parameters.FirstOrDefault(p => testContextSymbol.Equals(p.Type, SymbolEqualityComparer.Default)) is { } testContextParameter) { testContextMemberNameInScope = testContextParameter.Name; + testContextState = TestContextState.InScope; return true; } @@ -178,6 +174,7 @@ private static bool HasOrCouldHaveTestContextInScope( testContextMemberNameInScope = testContextMember.Name.StartsWith('<') && testContextMember.Name.EndsWith(">P", StringComparison.Ordinal) ? testContextMember.Name.Substring(1, testContextMember.Name.Length - 3) : testContextMember.Name; + testContextState = TestContextState.InScope; return true; } @@ -191,11 +188,13 @@ private static bool HasOrCouldHaveTestContextInScope( (classCleanupAttributeSymbol.Equals(attribute.AttributeClass, SymbolEqualityComparer.Default) || assemblyCleanupAttributeSymbol.Equals(attribute.AttributeClass, SymbolEqualityComparer.Default))) { + testContextState = TestContextState.CouldBeInScopeAsParameter; return true; } if (attribute.AttributeClass?.Inherits(testMethodAttributeSymbol) == true) { + testContextState = TestContextState.CouldBeInScopeAsProperty; return true; } } @@ -228,4 +227,11 @@ private static bool IsCompatibleOverloadWithCancellationToken(IMethodSymbol orig IParameterSymbol lastParam = candidateParams[candidateParams.Length - 1]; return SymbolEqualityComparer.Default.Equals(lastParam.Type, cancellationTokenSymbol); } + + internal enum TestContextState + { + InScope, + CouldBeInScopeAsParameter, + CouldBeInScopeAsProperty, + } } diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/FlowTestContextCancellationTokenAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/FlowTestContextCancellationTokenAnalyzerTests.cs index 95022f5cb7..9ba0988c3a 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/FlowTestContextCancellationTokenAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/FlowTestContextCancellationTokenAnalyzerTests.cs @@ -296,11 +296,11 @@ public class MyTestClass public static async Task ClassCleanup() { await [|Task.Delay(1000)|]; + await [|Task.Delay(1000)|]; } } """; - // Codefix doesn't yet handle the addition of TestContext parameter. string fixedCode = """ using Microsoft.VisualStudio.TestTools.UnitTesting; using System.Threading; @@ -310,9 +310,10 @@ public static async Task ClassCleanup() public class MyTestClass { [ClassCleanup] - public static async Task ClassCleanup() + public static async Task ClassCleanup(TestContext testContext) { - await Task.Delay(1000, {|CS0103:testContext|}.CancellationTokenSource.Token); + await Task.Delay(1000, testContext.CancellationTokenSource.Token); + await Task.Delay(1000, testContext.CancellationTokenSource.Token); } } """; @@ -343,6 +344,7 @@ public async Task Test1() public async Task Test2(int _) { await [|Task.Delay(1000)|]; + await [|Task.Delay(1000)|]; } } """; @@ -358,7 +360,7 @@ public class MyTestClass [TestMethod] public async Task Test1() { - await Task.Delay(1000, {|CS0103:testContext|}.CancellationTokenSource.Token); + await Task.Delay(1000, TestContext.CancellationTokenSource.Token); } [TestMethod] @@ -366,8 +368,11 @@ public async Task Test1() [DataRow(1)] public async Task Test2(int _) { - await Task.Delay(1000, {|CS0103:testContext|}.CancellationTokenSource.Token); + await Task.Delay(1000, TestContext.CancellationTokenSource.Token); + await Task.Delay(1000, TestContext.CancellationTokenSource.Token); } + + public TestContext TestContext { get; set; } } """; @@ -531,4 +536,206 @@ public async Task MyTestMethod() await VerifyCS.VerifyCodeFixAsync(code, fixedCode); } + + [TestMethod] + public async Task WhenMultipleFixesInSameClassWithoutTestContext_ShouldAddPropertyOnce() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public async Task Test1() + { + await [|Task.Delay(1000)|]; + await [|Task.Delay(2000)|]; + } + + [TestMethod] + public async Task Test2() + { + await [|Task.Delay(3000)|]; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public async Task Test1() + { + await Task.Delay(1000, TestContext.CancellationTokenSource.Token); + await Task.Delay(2000, TestContext.CancellationTokenSource.Token); + } + + [TestMethod] + public async Task Test2() + { + await Task.Delay(3000, TestContext.CancellationTokenSource.Token); + } + + public TestContext TestContext { get; set; } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [TestMethod] + public async Task WhenMultipleFixesInSameClassMultiplePartialsWithoutTestContext_ShouldAddPropertyOnce() + { + var test = new VerifyCS.Test + { + TestState = + { + Sources = + { + """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public partial class MyTestClass + { + [TestMethod] + public async Task Test1() + { + await [|Task.Delay(1000)|]; + await [|Task.Delay(2000)|]; + } + + [TestMethod] + public async Task Test2() + { + await [|Task.Delay(3000)|]; + } + } + """, + """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading; + using System.Threading.Tasks; + + public partial class MyTestClass + { + [TestMethod] + public async Task Test3() + { + await [|Task.Delay(1000)|]; + await [|Task.Delay(2000)|]; + } + + [TestMethod] + public async Task Test4() + { + await [|Task.Delay(3000)|]; + } + } + """, + }, + }, + FixedState = + { + Sources = + { + """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public partial class MyTestClass + { + [TestMethod] + public async Task Test1() + { + await Task.Delay(1000, TestContext.CancellationTokenSource.Token); + await Task.Delay(2000, TestContext.CancellationTokenSource.Token); + } + + [TestMethod] + public async Task Test2() + { + await Task.Delay(3000, TestContext.CancellationTokenSource.Token); + } + + public TestContext TestContext { get; set; } + } + """, + """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading; + using System.Threading.Tasks; + + public partial class MyTestClass + { + [TestMethod] + public async Task Test3() + { + await Task.Delay(1000, TestContext.CancellationTokenSource.Token); + await Task.Delay(2000, TestContext.CancellationTokenSource.Token); + } + + [TestMethod] + public async Task Test4() + { + await Task.Delay(3000, TestContext.CancellationTokenSource.Token); + } + } + """, + }, + }, + }; + + await test.RunAsync(); + } + + [TestMethod] + public async Task WhenInAssemblyCleanupWithoutTestContextParameter_Diagnostic() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public class MyTestClass + { + [AssemblyCleanup] + public static async Task AssemblyCleanup() + { + await [|Task.Delay(1000)|]; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public class MyTestClass + { + [AssemblyCleanup] + public static async Task AssemblyCleanup(TestContext testContext) + { + await Task.Delay(1000, testContext.CancellationTokenSource.Token); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } }