Skip to content

Commit

Permalink
Merge pull request #59404 from CyrusNajmabadi/recordNullCheck
Browse files Browse the repository at this point in the history
Add support for adding null checks to records.
  • Loading branch information
CyrusNajmabadi committed Feb 9, 2022
2 parents 4489bfe + 6a169b2 commit 82094cf
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 87 deletions.
Expand Up @@ -2,23 +2,23 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeStyle;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.InitializeParameter;
using Microsoft.CodeAnalysis.CSharp.Shared.Extensions;
using Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions;
using Microsoft.CodeAnalysis.Test.Utilities;
using Microsoft.CodeAnalysis.Testing;
using Roslyn.Test.Utilities;
using Xunit;

using VerifyCS = Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions.CSharpCodeRefactoringVerifier<
Microsoft.CodeAnalysis.CSharp.InitializeParameter.CSharpAddParameterCheckCodeRefactoringProvider>;

namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.InitializeParameter
{
using VerifyCS = CSharpCodeRefactoringVerifier<
CSharpAddParameterCheckCodeRefactoringProvider>;

public class AddParameterCheckTests
{
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
Expand Down Expand Up @@ -76,27 +76,6 @@ public C([||]string s!!)
}.RunAsync();
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
public async Task TestRecordPrimaryConstructor()
{
// https://github.com/dotnet/roslyn/issues/58779
// Note: we declare a field within the record to work around missing IsExternalInit errors
await new VerifyCS.Test
{
LanguageVersion = LanguageVersionExtensions.CSharpNext,
TestCode = @"
using System;
record Rec([||]string s) { public string s = s; }
",
FixedCode = @"
using System;
record Rec(string s) { public string s = s; }
"
}.RunAsync();
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
public async Task TestSimpleReferenceType_AlreadyNullChecked2()
{
Expand Down Expand Up @@ -2859,5 +2838,48 @@ class C
}";
await VerifyCS.VerifyRefactoringAsync(source, source);
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
[WorkItem(58779, "https://github.com/dotnet/roslyn/issues/58779")]
public async Task TestNotInRecordBeforeCSharp11()
{
var code = @"
record C([||]string s) { public string s; }";
await new VerifyCS.Test
{
LanguageVersion = LanguageVersion.CSharp10,
TestCode = code,
FixedCode = code,
}.RunAsync();
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
[WorkItem(58779, "https://github.com/dotnet/roslyn/issues/58779")]
public async Task TestInRecordAfterCSharp11()
{
await new VerifyCS.Test
{
LanguageVersion = LanguageVersionExtensions.CSharpNext,
TestCode = @"
record C([||]string s) { public string s; }",
FixedCode = @"
record C(string s!!) { public string s; }",
}.RunAsync();
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
[WorkItem(58779, "https://github.com/dotnet/roslyn/issues/58779")]
public async Task TestInRecordWithMultipleParametersAfterCSharp11()
{
await new VerifyCS.Test
{
LanguageVersion = LanguageVersionExtensions.CSharpNext,
TestCode = @"
record C([||]string s, string t) { public string s, t; }",
FixedCode = @"
record C(string s!!, string t!!) { public string s, t; }",
CodeActionIndex = 1,
}.RunAsync();
}
}
}
Expand Up @@ -4,17 +4,18 @@

using System;
using System.Composition;
using System.Threading.Tasks;
using System.Threading;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.CSharp.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.LanguageServices;
using Microsoft.CodeAnalysis.CSharp.Shared.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.InitializeParameter;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Options;
using Microsoft.CodeAnalysis.Shared.Extensions;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Microsoft.CodeAnalysis.CSharp.InitializeParameter
Expand All @@ -35,6 +36,13 @@ public CSharpAddParameterCheckCodeRefactoringProvider()
{
}

protected override ISyntaxFacts SyntaxFacts
=> CSharpSyntaxFacts.Instance;

// We need to be at least on c# 11 to support using !! with records.
protected override bool SupportsRecords(ParseOptions options)
=> options.LanguageVersion().IsCSharp11OrAbove();

protected override bool IsFunctionDeclaration(SyntaxNode node)
=> InitializeParameterHelpers.IsFunctionDeclaration(node);

Expand Down Expand Up @@ -97,11 +105,8 @@ protected override StatementSyntax CreateParameterCheckIfStatement(DocumentOptio
protected override Document? TryAddNullCheckToParameterDeclaration(Document document, ParameterSyntax parameterSyntax, CancellationToken cancellationToken)
{
var tree = parameterSyntax.SyntaxTree;
var options = (CSharpParseOptions)tree.Options;
if (options.LanguageVersion < LanguageVersionExtensions.CSharpNext)
{
if (!tree.Options.LanguageVersion().IsCSharp11OrAbove())
return null;
}

// We expect the syntax tree to already be in memory since we already have a node from the tree
var syntaxRoot = tree.GetRoot(cancellationToken);
Expand Down
Expand Up @@ -5,9 +5,11 @@
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.CSharp.LanguageServices;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.InitializeParameter;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Operations;

namespace Microsoft.CodeAnalysis.CSharp.InitializeParameter
Expand All @@ -28,6 +30,12 @@ public CSharpInitializeMemberFromParameterCodeRefactoringProvider()
{
}

protected override ISyntaxFacts SyntaxFacts
=> CSharpSyntaxFacts.Instance;

protected override bool SupportsRecords(ParseOptions options)
=> false;

protected override bool IsFunctionDeclaration(SyntaxNode node)
=> InitializeParameterHelpers.IsFunctionDeclaration(node);

Expand Down
Expand Up @@ -15,6 +15,7 @@
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Options;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
Expand All @@ -37,22 +38,20 @@ internal abstract partial class AbstractAddParameterCheckCodeRefactoringProvider
where TExpressionSyntax : SyntaxNode
where TBinaryExpressionSyntax : TExpressionSyntax
{
private readonly Func<SyntaxNode, bool> _isFunctionDeclarationFunc;

protected AbstractAddParameterCheckCodeRefactoringProvider()
{
_isFunctionDeclarationFunc = IsFunctionDeclaration;
}

protected abstract bool CanOffer(SyntaxNode body);
protected abstract bool PrefersThrowExpression(DocumentOptionSet options);
protected abstract string EscapeResourceString(string input);
protected abstract TStatementSyntax CreateParameterCheckIfStatement(DocumentOptionSet options, TExpressionSyntax condition, TStatementSyntax ifTrueStatement);
protected abstract Document? TryAddNullCheckToParameterDeclaration(Document document, TParameterSyntax parameterSyntax, CancellationToken cancellationToken);

protected override async Task<ImmutableArray<CodeAction>> GetRefactoringsForAllParametersAsync(
Document document, SyntaxNode functionDeclaration, IMethodSymbol methodSymbol,
IBlockOperation? blockStatementOpt, ImmutableArray<SyntaxNode> listOfParameterNodes, TextSpan parameterSpan, CancellationToken cancellationToken)
Document document,
SyntaxNode funcOrRecord,
IMethodSymbol methodSymbol,
IBlockOperation? blockStatementOpt,
ImmutableArray<SyntaxNode> listOfParameterNodes,
TextSpan parameterSpan,
CancellationToken cancellationToken)
{
// List to keep track of the valid parameters
var listOfParametersOrdinals = new List<int>();
Expand All @@ -62,16 +61,12 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()
{
var parameter = (IParameterSymbol)semanticModel.GetRequiredDeclaredSymbol(parameterNode, cancellationToken);
if (ParameterValidForNullCheck(document, parameter, semanticModel, blockStatementOpt, cancellationToken))
{
listOfParametersOrdinals.Add(parameter.Ordinal);
}
}

// Min 2 parameters to offer the refactoring
if (listOfParametersOrdinals.Count < 2)
{
return ImmutableArray<CodeAction>.Empty;
}

// Great. The list has parameters that need null checks. Offer to add null checks for all.
return ImmutableArray.Create<CodeAction>(new MyCodeAction(
Expand All @@ -84,7 +79,7 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()
Document document,
TParameterSyntax parameterSyntax,
IParameterSymbol parameter,
SyntaxNode functionDeclaration,
SyntaxNode funcOrRecord,
IMethodSymbol methodSymbol,
IBlockOperation? blockStatementOpt,
CancellationToken cancellationToken)
Expand All @@ -93,33 +88,32 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()

// Only should provide null-checks for reference types and nullable types.
if (!ParameterValidForNullCheck(document, parameter, semanticModel, blockStatementOpt, cancellationToken))
{
return ImmutableArray<CodeAction>.Empty;
}

// Great. There was no null check. Offer to add one.
using var _ = ArrayBuilder<CodeAction>.GetInstance(out var result);
using var result = TemporaryArray<CodeAction>.Empty;
result.Add(new MyCodeAction(
FeaturesResources.Add_null_check,
c => AddNullCheckAsync(document, parameterSyntax, parameter, functionDeclaration, methodSymbol, blockStatementOpt, c),
c => AddNullCheckAsync(document, parameterSyntax, parameter, funcOrRecord, methodSymbol, blockStatementOpt, c),
nameof(FeaturesResources.Add_null_check)));

// Also, if this was a string, offer to add the special checks to
// string.IsNullOrEmpty and string.IsNullOrWhitespace.
if (parameter.Type.SpecialType == SpecialType.System_String)
// Also, if this was a string, offer to add the special checks to string.IsNullOrEmpty and
// string.IsNullOrWhitespace. We cannot do this for records though as they have no location
// to place the checks.
if (parameter.Type.SpecialType == SpecialType.System_String && !IsRecordDeclaration(funcOrRecord))
{
result.Add(new MyCodeAction(
FeaturesResources.Add_string_IsNullOrEmpty_check,
c => AddStringCheckAsync(document, parameter, functionDeclaration, methodSymbol, blockStatementOpt, nameof(string.IsNullOrEmpty), c),
c => AddStringCheckAsync(document, parameter, funcOrRecord, methodSymbol, blockStatementOpt, nameof(string.IsNullOrEmpty), c),
nameof(FeaturesResources.Add_string_IsNullOrEmpty_check)));

result.Add(new MyCodeAction(
FeaturesResources.Add_string_IsNullOrWhiteSpace_check,
c => AddStringCheckAsync(document, parameter, functionDeclaration, methodSymbol, blockStatementOpt, nameof(string.IsNullOrWhiteSpace), c),
c => AddStringCheckAsync(document, parameter, funcOrRecord, methodSymbol, blockStatementOpt, nameof(string.IsNullOrWhiteSpace), c),
nameof(FeaturesResources.Add_string_IsNullOrWhiteSpace_check)));
}

return result.ToImmutable();
return result.ToImmutableAndClear();
}

private async Task<Document> UpdateDocumentForRefactoringAsync(
Expand All @@ -135,12 +129,13 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

var firstParameterNode = (TParameterSyntax)root.FindNode(parameterSpan);
var functionDeclaration = firstParameterNode.FirstAncestorOrSelf(_isFunctionDeclarationFunc);
if (functionDeclaration == null)
var funcOrRecord = firstParameterNode.FirstAncestorOrSelf(_isFunctionDeclarationFunc) ??
firstParameterNode.FirstAncestorOrSelf(_isRecordDeclarationFunc);
if (funcOrRecord == null)
continue;

var generator = SyntaxGenerator.GetGenerator(document);
var parameterNodes = (IReadOnlyList<TParameterSyntax>)generator.GetParameters(functionDeclaration);
var parameterNodes = (IReadOnlyList<TParameterSyntax>)generator.GetParameters(funcOrRecord);
var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var (parameterSyntax, parameter) = GetParameterAtOrdinal(index, parameterNodes, semanticModel, cancellationToken);
if (parameter == null)
Expand All @@ -149,20 +144,19 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()

var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();

if (!CanOfferRefactoring(functionDeclaration, semanticModel, syntaxFacts, cancellationToken, out blockStatementOpt))
{
if (!CanOfferRefactoring(funcOrRecord, semanticModel, syntaxFacts, cancellationToken, out blockStatementOpt))
continue;
}

// If parameter is a string, default check would be IsNullOrEmpty. This is because IsNullOrEmpty is more commonly used in this regard according to telemetry and UX testing.
if (parameter.Type.SpecialType == SpecialType.System_String)
// If parameter is a string, default check would be IsNullOrEmpty. This is because IsNullOrEmpty is more
// commonly used in this regard according to telemetry and UX testing.
if (parameter.Type.SpecialType == SpecialType.System_String && !IsRecordDeclaration(funcOrRecord))
{
document = await AddStringCheckAsync(document, parameter, functionDeclaration, (IMethodSymbol)parameter.ContainingSymbol, blockStatementOpt, nameof(string.IsNullOrEmpty), cancellationToken).ConfigureAwait(false);
document = await AddStringCheckAsync(document, parameter, funcOrRecord, (IMethodSymbol)parameter.ContainingSymbol, blockStatementOpt, nameof(string.IsNullOrEmpty), cancellationToken).ConfigureAwait(false);
continue;
}

// For all other parameters, add null check - updates document
document = await AddNullCheckAsync(document, parameterSyntax, parameter, functionDeclaration,
document = await AddNullCheckAsync(document, parameterSyntax, parameter, funcOrRecord,
(IMethodSymbol)parameter.ContainingSymbol, blockStatementOpt, cancellationToken).ConfigureAwait(false);
}

Expand Down
Expand Up @@ -40,11 +40,12 @@ internal abstract partial class AbstractInitializeMemberFromParameterCodeRefacto
where TExpressionSyntax : SyntaxNode
{
protected abstract SyntaxNode? TryGetLastStatement(IBlockOperation? blockStatementOpt);

protected abstract Accessibility DetermineDefaultFieldAccessibility(INamedTypeSymbol containingType);

protected abstract Accessibility DetermineDefaultPropertyAccessibility();

protected override bool SupportsRecords(ParseOptions options)
=> false;

protected override Task<ImmutableArray<CodeAction>> GetRefactoringsForAllParametersAsync(
Document document, SyntaxNode functionDeclaration, IMethodSymbol method, IBlockOperation? blockStatementOpt,
ImmutableArray<SyntaxNode> listOfParameterNodes, TextSpan parameterSpan, CancellationToken cancellationToken)
Expand Down

0 comments on commit 82094cf

Please sign in to comment.