diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers.CSharp/VSTHRD003UseJtfRunAsyncAnalyzer.cs b/src/Microsoft.VisualStudio.Threading.Analyzers.CSharp/VSTHRD003UseJtfRunAsyncAnalyzer.cs index 884881890..897fde8fa 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers.CSharp/VSTHRD003UseJtfRunAsyncAnalyzer.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers.CSharp/VSTHRD003UseJtfRunAsyncAnalyzer.cs @@ -195,15 +195,30 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context) } // Do not report a warning if the task is a member of an object that was created in this method. - if (memberAccessExpression.Expression is IdentifierNameSyntax identifier && - semanticModel.GetSymbolInfo(identifier, cancellationToken).Symbol is ILocalSymbol local) + if (memberAccessExpression.Expression is IdentifierNameSyntax identifier) { - // Search for assignments to the local and see if it was to a new object. - containingFunc ??= CSharpUtils.GetContainingFunction(focusedExpression); - if (containingFunc.Value.BlockOrExpression is not null && - CSharpUtils.FindAssignedValuesWithin(containingFunc.Value.BlockOrExpression, semanticModel, local, cancellationToken).Any(v => v is ObjectCreationExpressionSyntax)) + ISymbol? symbol = semanticModel.GetSymbolInfo(identifier, cancellationToken).Symbol; + switch (symbol) { - return null; + case ILocalSymbol local: + // Search for assignments to the local and see if it was to a new object or the result of an invocation. + containingFunc ??= CSharpUtils.GetContainingFunction(focusedExpression); + if (containingFunc.Value.BlockOrExpression is not null && + CSharpUtils.FindAssignedValuesWithin(containingFunc.Value.BlockOrExpression, semanticModel, local, cancellationToken).Any( + v => v is ObjectCreationExpressionSyntax or ImplicitObjectCreationExpressionSyntax or InvocationExpressionSyntax)) + { + return null; + } + + break; + case IParameterSymbol parameter: + // We allow returning members of a parameter in a lambda, to support `.Select(x => x.Completion)` syntax. + if (parameter.ContainingSymbol is IMethodSymbol method && method.MethodKind == MethodKind.AnonymousFunction) + { + return null; + } + + break; } } } diff --git a/test/Microsoft.VisualStudio.Threading.Analyzers.Tests/Helpers/CSharpCodeFixVerifier`2+Test.cs b/test/Microsoft.VisualStudio.Threading.Analyzers.Tests/Helpers/CSharpCodeFixVerifier`2+Test.cs index 298425a59..18ae7d018 100644 --- a/test/Microsoft.VisualStudio.Threading.Analyzers.Tests/Helpers/CSharpCodeFixVerifier`2+Test.cs +++ b/test/Microsoft.VisualStudio.Threading.Analyzers.Tests/Helpers/CSharpCodeFixVerifier`2+Test.cs @@ -73,7 +73,7 @@ where resourceName.StartsWith(additionalFilePrefix, StringComparison.Ordinal) protected override ParseOptions CreateParseOptions() { - return ((CSharpParseOptions)base.CreateParseOptions()).WithLanguageVersion(LanguageVersion.CSharp8); + return ((CSharpParseOptions)base.CreateParseOptions()).WithLanguageVersion(LanguageVersion.CSharp11); } private static string ReadManifestResource(Assembly assembly, string resourceName) diff --git a/test/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD003UseJtfRunAsyncAnalyzerTests.cs b/test/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD003UseJtfRunAsyncAnalyzerTests.cs index 1003c7c24..4a410ae02 100644 --- a/test/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD003UseJtfRunAsyncAnalyzerTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD003UseJtfRunAsyncAnalyzerTests.cs @@ -1310,6 +1310,29 @@ static async Task GetTask() await CSVerify.VerifyAnalyzerAsync(test); } + [Fact] + public async Task DoNotReportWarningWhenAwaitingTaskPropertyOfObjectCreatedInContext_TargetTypeCreation() + { + string test = """ + using System.Threading.Tasks; + + class Test + { + static Task Exec2Async(string executable, params string[] args) + { + Process p = new(); + return p.Task; + } + } + + class Process + { + public Task Task { get; } + } + """; + await CSVerify.VerifyAnalyzerAsync(test); + } + /// /// This is important to allow folks to return jtf.RunAsync(...).Task from a method. /// @@ -1334,6 +1357,31 @@ static async Task GetTask() await CSVerify.VerifyAnalyzerAsync(test); } + [Fact] + public async Task DoNotReportWarningWhenAwaitingTaskPropertyOfObjectReturnedFromMethodViaLocal() + { + var test = """ + using System.Threading.Tasks; + + class JsonRpc + { + internal static JsonRpc Attach() => throw new System.NotImplementedException(); + + internal Task Completion { get; } + } + + class Tests + { + static async Task ListenAndWait() + { + var jsonRpc = JsonRpc.Attach(); + await jsonRpc.Completion; + } + } + """; + await CSVerify.VerifyAnalyzerAsync(test); + } + [Fact] public async Task ReportWarningWhenAwaitingTaskPropertyThatWasNotSetInContext() { @@ -1354,6 +1402,32 @@ async Task GetTask() await CSVerify.VerifyAnalyzerAsync(test); } + [Fact] + public async Task DoNotReportWarningWhenReturningTaskFromLambdaArgument() + { + var test = """ + using System.Linq; + using System.Threading.Tasks; + + class JsonRpc + { + internal static JsonRpc Attach() => throw new System.NotImplementedException(); + + internal Task Completion { get; } + } + + class Tests + { + static async Task ListenAndWait() + { + JsonRpc[] rpcs = new [] { JsonRpc.Attach(), JsonRpc.Attach() }; + await Task.WhenAll(rpcs.Select(r => r.Completion)); + } + } + """; + await CSVerify.VerifyAnalyzerAsync(test); + } + private DiagnosticResult CreateDiagnostic(int line, int column, int length) => CSVerify.Diagnostic().WithSpan(line, column, line, column + length); }