Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

using MSTest.Analyzers.Helpers;

using Polyfills;

namespace MSTest.Analyzers;

/// <summary>
Expand All @@ -31,8 +33,8 @@ public sealed class FlowTestContextCancellationTokenFixer : CodeFixProvider

/// <inheritdoc />
public override FixAllProvider GetFixAllProvider()
// See https://github.com/dotnet/roslyn/blob/main/docs/analyzers/FixAllProvider.md for more information on Fix All Providers
=> WellKnownFixAllProviders.BatchFixer;
// Use custom FixAllProvider to handle adding TestContext property when needed
=> FlowTestContextCancellationTokenFixAllProvider.Instance;

/// <inheritdoc />
public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
Expand All @@ -49,37 +51,213 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
}

diagnostic.Properties.TryGetValue(FlowTestContextCancellationTokenAnalyzer.TestContextMemberNamePropertyKey, out string? testContextMemberName);
diagnostic.Properties.TryGetValue(nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState), out string? testContextState);

// Register a code action that will invoke the fix
context.RegisterCodeFix(
CodeAction.Create(
title: CodeFixResources.PassCancellationTokenFix,
createChangedDocument: c => AddCancellationTokenParameterAsync(context.Document, invocationExpression, testContextMemberName, c),
equivalenceKey: "AddTestContextCancellationToken"),
createChangedDocument: async c =>
{
DocumentEditor editor = await DocumentEditor.CreateAsync(context.Document, context.CancellationToken).ConfigureAwait(false);
return ApplyFix(editor, invocationExpression, testContextMemberName, testContextState, adjustedSymbols: null, c);
},
equivalenceKey: nameof(FlowTestContextCancellationTokenFixer)),
diagnostic);
}

private static async Task<Document> AddCancellationTokenParameterAsync(
Document document,
internal static Document ApplyFix(
DocumentEditor editor,
InvocationExpressionSyntax invocationExpression,
string? testContextMemberName,
string? testContextState,
HashSet<ISymbol>? adjustedSymbols,
CancellationToken cancellationToken)
{
DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
if (testContextState == nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState.CouldBeInScopeAsProperty))
{
Debug.Assert(testContextMemberName is null, "TestContext member name should be null when state is CouldBeInScopeAsProperty");
AddCancellationTokenArgument(editor, invocationExpression, "TestContext");
TypeDeclarationSyntax? containingTypeDeclaration = invocationExpression.FirstAncestorOrSelf<TypeDeclarationSyntax>();
if (containingTypeDeclaration is not null)
{
// adjustedSymbols is null meaning we are only applying a single fix (in that case we add the property).
// If we are in fix all, we then verify if a previous fix has already added the property.
// We only add the property if it wasn't added by a previous fix.
// NOTE: We don't expect GetDeclaredSymbol to return null, but if it did (e.g, error scenario), we add the property.
if (adjustedSymbols is null ||
editor.SemanticModel.GetDeclaredSymbol(containingTypeDeclaration, cancellationToken) is not { } symbol ||
adjustedSymbols.Add(symbol))
{
editor.ReplaceNode(containingTypeDeclaration, (containingTypeDeclaration, _) => AddTestContextProperty((TypeDeclarationSyntax)containingTypeDeclaration));
}
}
}
else if (testContextState == nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState.CouldBeInScopeAsParameter))
{
Debug.Assert(testContextMemberName is null, "TestContext member name should be null when state is CouldBeInScopeAsParameter");
AddCancellationTokenArgument(editor, invocationExpression, "testContext");
MethodDeclarationSyntax? containingMethodDeclaration = invocationExpression.FirstAncestorOrSelf<MethodDeclarationSyntax>();

if (containingMethodDeclaration is not null)
{
// adjustedSymbols is null meaning we are only applying a single fix (in that case we add the parameter).
// If we are in fix all, we then verify if a previous fix has already added the parameter.
// We only add the parameter if it wasn't added by a previous fix.
// NOTE: We don't expect GetDeclaredSymbol to return null, but if it did (e.g, error scenario), we add the property.
if (adjustedSymbols is null ||
editor.SemanticModel.GetDeclaredSymbol(containingMethodDeclaration, cancellationToken) is not { } symbol ||
adjustedSymbols.Add(symbol))
{
editor.ReplaceNode(containingMethodDeclaration, (containingMethodDeclaration, _) => AddTestContextParameterToMethod((MethodDeclarationSyntax)containingMethodDeclaration));
}
}
}
else
{
Guard.NotNull(testContextMemberName);
AddCancellationTokenArgument(editor, invocationExpression, testContextMemberName);
}

return editor.GetChangedDocument();
}

internal static void AddCancellationTokenArgument(
DocumentEditor editor,
InvocationExpressionSyntax invocationExpression,
string testContextMemberName)
{
// Find the containing method to determine the context
MethodDeclarationSyntax? containingMethod = invocationExpression.FirstAncestorOrSelf<MethodDeclarationSyntax>();

// Create the TestContext.CancellationTokenSource.Token expression
MemberAccessExpressionSyntax testContextExpression = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName(testContextMemberName ?? "testContext"),
SyntaxFactory.IdentifierName(testContextMemberName),
SyntaxFactory.IdentifierName("CancellationTokenSource")),
SyntaxFactory.IdentifierName("Token"));

ArgumentListSyntax currentArguments = invocationExpression.ArgumentList;
SeparatedSyntaxList<ArgumentSyntax> newArguments = currentArguments.Arguments.Add(SyntaxFactory.Argument(testContextExpression));
InvocationExpressionSyntax newInvocation = invocationExpression.WithArgumentList(currentArguments.WithArguments(newArguments));
editor.ReplaceNode(invocationExpression, newInvocation);
return editor.GetChangedDocument();
editor.ReplaceNode(invocationExpression, (node, _) =>
{
var invocationExpression = (InvocationExpressionSyntax)node;
ArgumentListSyntax currentArguments = invocationExpression.ArgumentList;
SeparatedSyntaxList<ArgumentSyntax> newArguments = currentArguments.Arguments.Add(SyntaxFactory.Argument(testContextExpression));
return invocationExpression.WithArgumentList(currentArguments.WithArguments(newArguments));
});
}

internal static MethodDeclarationSyntax AddTestContextParameterToMethod(MethodDeclarationSyntax method)
{
// Create TestContext parameter
ParameterSyntax testContextParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier("testContext"))
.WithType(SyntaxFactory.IdentifierName("TestContext"));

// Add the parameter to the method
SeparatedSyntaxList<ParameterSyntax> updatedParameterList = method.ParameterList.Parameters.Count == 0
? SyntaxFactory.SingletonSeparatedList(testContextParameter)
: method.ParameterList.Parameters.Add(testContextParameter);

return method.WithParameterList(method.ParameterList.WithParameters(updatedParameterList));
}

internal static TypeDeclarationSyntax AddTestContextProperty(TypeDeclarationSyntax typeDeclaration)
{
PropertyDeclarationSyntax testContextProperty = SyntaxFactory.PropertyDeclaration(
SyntaxFactory.IdentifierName("TestContext"),
"TestContext")
.WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword)))
.WithAccessorList(SyntaxFactory.AccessorList(
SyntaxFactory.List(new[]
{
SyntaxFactory.AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
.WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.SemicolonToken)),
SyntaxFactory.AccessorDeclaration(SyntaxKind.SetAccessorDeclaration)
.WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.SemicolonToken)),
})));

return typeDeclaration.AddMembers(testContextProperty);
}
}

/// <summary>
/// Custom FixAllProvider for <see cref="FlowTestContextCancellationTokenFixer"/> that can add TestContext property when needed.
/// This ensures that when multiple fixes are applied to the same class, the TestContext property is added only once.
/// </summary>
internal sealed class FlowTestContextCancellationTokenFixAllProvider : FixAllProvider
{
public static readonly FlowTestContextCancellationTokenFixAllProvider Instance = new();

private FlowTestContextCancellationTokenFixAllProvider()
{
}

public override Task<CodeAction?> GetFixAsync(FixAllContext fixAllContext)
=> Task.FromResult<CodeAction?>(new FixAllCodeAction(fixAllContext));

private sealed class FixAllCodeAction : CodeAction
{
private readonly FixAllContext _fixAllContext;

public FixAllCodeAction(FixAllContext fixAllContext)
=> _fixAllContext = fixAllContext;

public override string Title => CodeFixResources.PassCancellationTokenFix;

public override string? EquivalenceKey => nameof(FlowTestContextCancellationTokenFixer);

protected override async Task<Solution?> GetChangedSolutionAsync(CancellationToken cancellationToken)
{
FixAllContext fixAllContext = _fixAllContext;
var editor = new SolutionEditor(fixAllContext.Solution);
var fixedSymbols = new HashSet<ISymbol>(SymbolEqualityComparer.Default);

if (fixAllContext.Scope == FixAllScope.Document)
{
DocumentEditor documentEditor = await editor.GetDocumentEditorAsync(fixAllContext.Document!.Id, cancellationToken).ConfigureAwait(false);
foreach (Diagnostic diagnostic in await fixAllContext.GetDocumentDiagnosticsAsync(fixAllContext.Document!).ConfigureAwait(false))
{
FixOneDiagnostic(documentEditor, diagnostic, fixedSymbols, cancellationToken);
}
}
else if (fixAllContext.Scope == FixAllScope.Project)
{
await FixAllInProjectAsync(fixAllContext, fixAllContext.Project, editor, fixedSymbols, cancellationToken).ConfigureAwait(false);
}
else if (fixAllContext.Scope == FixAllScope.Solution)
{
foreach (Project project in fixAllContext.Solution.Projects)
{
await FixAllInProjectAsync(fixAllContext, project, editor, fixedSymbols, cancellationToken).ConfigureAwait(false);
}
}

return editor.GetChangedSolution();
}

private static async Task FixAllInProjectAsync(FixAllContext fixAllContext, Project project, SolutionEditor editor, HashSet<ISymbol> fixedSymbols, CancellationToken cancellationToken)
{
foreach (Diagnostic diagnostic in await fixAllContext.GetAllDiagnosticsAsync(project).ConfigureAwait(false))
{
DocumentId documentId = editor.OriginalSolution.GetDocumentId(diagnostic.Location.SourceTree)!;
DocumentEditor documentEditor = await editor.GetDocumentEditorAsync(documentId, cancellationToken).ConfigureAwait(false);
FixOneDiagnostic(documentEditor, diagnostic, fixedSymbols, cancellationToken);
}
}

private static void FixOneDiagnostic(DocumentEditor documentEditor, Diagnostic diagnostic, HashSet<ISymbol> fixedSymbols, CancellationToken cancellationToken)
{
SyntaxNode node = documentEditor.OriginalRoot.FindNode(diagnostic.Location.SourceSpan, getInnermostNodeForTie: true);
if (node is not InvocationExpressionSyntax invocationExpression)
{
return;
}

diagnostic.Properties.TryGetValue(FlowTestContextCancellationTokenAnalyzer.TestContextMemberNamePropertyKey, out string? testContextMemberName);
diagnostic.Properties.TryGetValue(nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState), out string? testContextState);

FlowTestContextCancellationTokenFixer.ApplyFix(documentEditor, invocationExpression, testContextMemberName, testContextState, fixedSymbols, cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private static void AnalyzeInvocation(
IMethodSymbol method = invocationOperation.TargetMethod;

// Check if we're in a context where a TestContext is already available or could be made available.
if (!HasOrCouldHaveTestContextInScope(context.ContainingSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol, out string? testContextMemberNameInScope))
if (!HasOrCouldHaveTestContextInScope(context.ContainingSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol, out string? testContextMemberNameInScope, out TestContextState? testContextState))
{
return;
}
Expand All @@ -93,13 +93,7 @@ private static void AnalyzeInvocation(
invocationOperation.Arguments.FirstOrDefault(arg => SymbolEqualityComparer.Default.Equals(arg.Parameter, cancellationTokenParameter))?.ArgumentKind != ArgumentKind.Explicit)
{
// The called method has an optional CancellationToken parameter, but it was not explicitly provided.
ImmutableDictionary<string, string?> properties = ImmutableDictionary<string, string?>.Empty;
if (testContextMemberNameInScope is not null)
{
properties = properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope);
}

context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope)));
context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState)));
return;
}

Expand All @@ -108,16 +102,15 @@ private static void AnalyzeInvocation(
if (cancellationTokenParameter is null &&
HasOverloadWithCancellationToken(method, cancellationTokenSymbol))
{
context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope)));
context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState)));
}

static ImmutableDictionary<string, string?> GetPropertiesBag(string? testContextMemberNameInScope)
static ImmutableDictionary<string, string?> GetPropertiesBag(string? testContextMemberNameInScope, TestContextState? testContextState)
{
ImmutableDictionary<string, string?> properties = ImmutableDictionary<string, string?>.Empty;
if (testContextMemberNameInScope is not null)
{
properties = properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope);
}
properties = testContextMemberNameInScope is not null
? properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope)
: properties.Add(nameof(TestContextState), testContextState.ToString());

return properties;
}
Expand Down Expand Up @@ -151,9 +144,11 @@ private static bool HasOrCouldHaveTestContextInScope(
INamedTypeSymbol classCleanupAttributeSymbol,
INamedTypeSymbol assemblyCleanupAttributeSymbol,
INamedTypeSymbol testMethodAttributeSymbol,
out string? testContextMemberNameInScope)
out string? testContextMemberNameInScope,
[NotNullWhen(true)] out TestContextState? testContextState)
{
testContextMemberNameInScope = null;
testContextState = null;

if (containingSymbol is not IMethodSymbol method)
{
Expand All @@ -164,6 +159,7 @@ private static bool HasOrCouldHaveTestContextInScope(
if (method.Parameters.FirstOrDefault(p => testContextSymbol.Equals(p.Type, SymbolEqualityComparer.Default)) is { } testContextParameter)
{
testContextMemberNameInScope = testContextParameter.Name;
testContextState = TestContextState.InScope;
return true;
}

Expand All @@ -178,6 +174,7 @@ private static bool HasOrCouldHaveTestContextInScope(
testContextMemberNameInScope = testContextMember.Name.StartsWith('<') && testContextMember.Name.EndsWith(">P", StringComparison.Ordinal)
? testContextMember.Name.Substring(1, testContextMember.Name.Length - 3)
: testContextMember.Name;
testContextState = TestContextState.InScope;
return true;
}

Expand All @@ -191,11 +188,13 @@ private static bool HasOrCouldHaveTestContextInScope(
(classCleanupAttributeSymbol.Equals(attribute.AttributeClass, SymbolEqualityComparer.Default) ||
assemblyCleanupAttributeSymbol.Equals(attribute.AttributeClass, SymbolEqualityComparer.Default)))
{
testContextState = TestContextState.CouldBeInScopeAsParameter;
return true;
}

if (attribute.AttributeClass?.Inherits(testMethodAttributeSymbol) == true)
{
testContextState = TestContextState.CouldBeInScopeAsProperty;
return true;
}
}
Expand Down Expand Up @@ -228,4 +227,11 @@ private static bool IsCompatibleOverloadWithCancellationToken(IMethodSymbol orig
IParameterSymbol lastParam = candidateParams[candidateParams.Length - 1];
return SymbolEqualityComparer.Default.Equals(lastParam.Type, cancellationTokenSymbol);
}

internal enum TestContextState
{
InScope,
CouldBeInScopeAsParameter,
CouldBeInScopeAsProperty,
}
}
Loading