Skip to content

Commit

Permalink
Add support for adding null checks to records.
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi committed Feb 9, 2022
1 parent d45083f commit 6a169b2
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeStyle;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.InitializeParameter;
using Microsoft.CodeAnalysis.CSharp.Shared.Extensions;
using Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions;
using Microsoft.CodeAnalysis.Test.Utilities;
using Microsoft.CodeAnalysis.Testing;
using Roslyn.Test.Utilities;
using Xunit;

using VerifyCS = Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions.CSharpCodeRefactoringVerifier<
Microsoft.CodeAnalysis.CSharp.InitializeParameter.CSharpAddParameterCheckCodeRefactoringProvider>;

namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.InitializeParameter
{
using VerifyCS = CSharpCodeRefactoringVerifier<
CSharpAddParameterCheckCodeRefactoringProvider>;

public class AddParameterCheckTests
{
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
Expand Down Expand Up @@ -76,27 +76,6 @@ public C([||]string s!!)
}.RunAsync();
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
public async Task TestRecordPrimaryConstructor()
{
// https://github.com/dotnet/roslyn/issues/58779
// Note: we declare a field within the record to work around missing IsExternalInit errors
await new VerifyCS.Test
{
LanguageVersion = LanguageVersionExtensions.CSharpNext,
TestCode = @"
using System;
record Rec([||]string s) { public string s = s; }
",
FixedCode = @"
using System;
record Rec(string s) { public string s = s; }
"
}.RunAsync();
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
public async Task TestSimpleReferenceType_AlreadyNullChecked2()
{
Expand Down Expand Up @@ -2859,5 +2838,48 @@ class C
}";
await VerifyCS.VerifyRefactoringAsync(source, source);
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
[WorkItem(58779, "https://github.com/dotnet/roslyn/issues/58779")]
public async Task TestNotInRecordBeforeCSharp11()
{
var code = @"
record C([||]string s) { public string s; }";
await new VerifyCS.Test
{
LanguageVersion = LanguageVersion.CSharp10,
TestCode = code,
FixedCode = code,
}.RunAsync();
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
[WorkItem(58779, "https://github.com/dotnet/roslyn/issues/58779")]
public async Task TestInRecordAfterCSharp11()
{
await new VerifyCS.Test
{
LanguageVersion = LanguageVersionExtensions.CSharpNext,
TestCode = @"
record C([||]string s) { public string s; }",
FixedCode = @"
record C(string s!!) { public string s; }",
}.RunAsync();
}

[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsInitializeParameter)]
[WorkItem(58779, "https://github.com/dotnet/roslyn/issues/58779")]
public async Task TestInRecordWithMultipleParametersAfterCSharp11()
{
await new VerifyCS.Test
{
LanguageVersion = LanguageVersionExtensions.CSharpNext,
TestCode = @"
record C([||]string s, string t) { public string s, t; }",
FixedCode = @"
record C(string s!!, string t!!) { public string s, t; }",
CodeActionIndex = 1,
}.RunAsync();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@

using System;
using System.Composition;
using System.Threading.Tasks;
using System.Threading;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.CSharp.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.LanguageServices;
using Microsoft.CodeAnalysis.CSharp.Shared.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.InitializeParameter;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Options;
using Microsoft.CodeAnalysis.Shared.Extensions;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Microsoft.CodeAnalysis.CSharp.InitializeParameter
Expand All @@ -35,6 +36,13 @@ public CSharpAddParameterCheckCodeRefactoringProvider()
{
}

protected override ISyntaxFacts SyntaxFacts
=> CSharpSyntaxFacts.Instance;

// We need to be at least on c# 11 to support using !! with records.
protected override bool SupportsRecords(ParseOptions options)
=> options.LanguageVersion().IsCSharp11OrAbove();

protected override bool IsFunctionDeclaration(SyntaxNode node)
=> InitializeParameterHelpers.IsFunctionDeclaration(node);

Expand Down Expand Up @@ -97,11 +105,8 @@ protected override StatementSyntax CreateParameterCheckIfStatement(DocumentOptio
protected override Document? TryAddNullCheckToParameterDeclaration(Document document, ParameterSyntax parameterSyntax, CancellationToken cancellationToken)
{
var tree = parameterSyntax.SyntaxTree;
var options = (CSharpParseOptions)tree.Options;
if (options.LanguageVersion < LanguageVersionExtensions.CSharpNext)
{
if (!tree.Options.LanguageVersion().IsCSharp11OrAbove())
return null;
}

// We expect the syntax tree to already be in memory since we already have a node from the tree
var syntaxRoot = tree.GetRoot(cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.CSharp.LanguageServices;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.InitializeParameter;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Operations;

namespace Microsoft.CodeAnalysis.CSharp.InitializeParameter
Expand All @@ -28,6 +30,12 @@ public CSharpInitializeMemberFromParameterCodeRefactoringProvider()
{
}

protected override ISyntaxFacts SyntaxFacts
=> CSharpSyntaxFacts.Instance;

protected override bool SupportsRecords(ParseOptions options)
=> false;

protected override bool IsFunctionDeclaration(SyntaxNode node)
=> InitializeParameterHelpers.IsFunctionDeclaration(node);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Options;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
Expand All @@ -37,22 +38,20 @@ internal abstract partial class AbstractAddParameterCheckCodeRefactoringProvider
where TExpressionSyntax : SyntaxNode
where TBinaryExpressionSyntax : TExpressionSyntax
{
private readonly Func<SyntaxNode, bool> _isFunctionDeclarationFunc;

protected AbstractAddParameterCheckCodeRefactoringProvider()
{
_isFunctionDeclarationFunc = IsFunctionDeclaration;
}

protected abstract bool CanOffer(SyntaxNode body);
protected abstract bool PrefersThrowExpression(DocumentOptionSet options);
protected abstract string EscapeResourceString(string input);
protected abstract TStatementSyntax CreateParameterCheckIfStatement(DocumentOptionSet options, TExpressionSyntax condition, TStatementSyntax ifTrueStatement);
protected abstract Document? TryAddNullCheckToParameterDeclaration(Document document, TParameterSyntax parameterSyntax, CancellationToken cancellationToken);

protected override async Task<ImmutableArray<CodeAction>> GetRefactoringsForAllParametersAsync(
Document document, SyntaxNode functionDeclaration, IMethodSymbol methodSymbol,
IBlockOperation? blockStatementOpt, ImmutableArray<SyntaxNode> listOfParameterNodes, TextSpan parameterSpan, CancellationToken cancellationToken)
Document document,
SyntaxNode funcOrRecord,
IMethodSymbol methodSymbol,
IBlockOperation? blockStatementOpt,
ImmutableArray<SyntaxNode> listOfParameterNodes,
TextSpan parameterSpan,
CancellationToken cancellationToken)
{
// List to keep track of the valid parameters
var listOfParametersOrdinals = new List<int>();
Expand All @@ -62,16 +61,12 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()
{
var parameter = (IParameterSymbol)semanticModel.GetRequiredDeclaredSymbol(parameterNode, cancellationToken);
if (ParameterValidForNullCheck(document, parameter, semanticModel, blockStatementOpt, cancellationToken))
{
listOfParametersOrdinals.Add(parameter.Ordinal);
}
}

// Min 2 parameters to offer the refactoring
if (listOfParametersOrdinals.Count < 2)
{
return ImmutableArray<CodeAction>.Empty;
}

// Great. The list has parameters that need null checks. Offer to add null checks for all.
return ImmutableArray.Create<CodeAction>(new MyCodeAction(
Expand All @@ -84,7 +79,7 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()
Document document,
TParameterSyntax parameterSyntax,
IParameterSymbol parameter,
SyntaxNode functionDeclaration,
SyntaxNode funcOrRecord,
IMethodSymbol methodSymbol,
IBlockOperation? blockStatementOpt,
CancellationToken cancellationToken)
Expand All @@ -93,33 +88,32 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()

// Only should provide null-checks for reference types and nullable types.
if (!ParameterValidForNullCheck(document, parameter, semanticModel, blockStatementOpt, cancellationToken))
{
return ImmutableArray<CodeAction>.Empty;
}

// Great. There was no null check. Offer to add one.
using var _ = ArrayBuilder<CodeAction>.GetInstance(out var result);
using var result = TemporaryArray<CodeAction>.Empty;
result.Add(new MyCodeAction(
FeaturesResources.Add_null_check,
c => AddNullCheckAsync(document, parameterSyntax, parameter, functionDeclaration, methodSymbol, blockStatementOpt, c),
c => AddNullCheckAsync(document, parameterSyntax, parameter, funcOrRecord, methodSymbol, blockStatementOpt, c),
nameof(FeaturesResources.Add_null_check)));

// Also, if this was a string, offer to add the special checks to
// string.IsNullOrEmpty and string.IsNullOrWhitespace.
if (parameter.Type.SpecialType == SpecialType.System_String)
// Also, if this was a string, offer to add the special checks to string.IsNullOrEmpty and
// string.IsNullOrWhitespace. We cannot do this for records though as they have no location
// to place the checks.
if (parameter.Type.SpecialType == SpecialType.System_String && !IsRecordDeclaration(funcOrRecord))
{
result.Add(new MyCodeAction(
FeaturesResources.Add_string_IsNullOrEmpty_check,
c => AddStringCheckAsync(document, parameter, functionDeclaration, methodSymbol, blockStatementOpt, nameof(string.IsNullOrEmpty), c),
c => AddStringCheckAsync(document, parameter, funcOrRecord, methodSymbol, blockStatementOpt, nameof(string.IsNullOrEmpty), c),
nameof(FeaturesResources.Add_string_IsNullOrEmpty_check)));

result.Add(new MyCodeAction(
FeaturesResources.Add_string_IsNullOrWhiteSpace_check,
c => AddStringCheckAsync(document, parameter, functionDeclaration, methodSymbol, blockStatementOpt, nameof(string.IsNullOrWhiteSpace), c),
c => AddStringCheckAsync(document, parameter, funcOrRecord, methodSymbol, blockStatementOpt, nameof(string.IsNullOrWhiteSpace), c),
nameof(FeaturesResources.Add_string_IsNullOrWhiteSpace_check)));
}

return result.ToImmutable();
return result.ToImmutableAndClear();
}

private async Task<Document> UpdateDocumentForRefactoringAsync(
Expand All @@ -135,12 +129,13 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

var firstParameterNode = (TParameterSyntax)root.FindNode(parameterSpan);
var functionDeclaration = firstParameterNode.FirstAncestorOrSelf(_isFunctionDeclarationFunc);
if (functionDeclaration == null)
var funcOrRecord = firstParameterNode.FirstAncestorOrSelf(_isFunctionDeclarationFunc) ??
firstParameterNode.FirstAncestorOrSelf(_isRecordDeclarationFunc);
if (funcOrRecord == null)
continue;

var generator = SyntaxGenerator.GetGenerator(document);
var parameterNodes = (IReadOnlyList<TParameterSyntax>)generator.GetParameters(functionDeclaration);
var parameterNodes = (IReadOnlyList<TParameterSyntax>)generator.GetParameters(funcOrRecord);
var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var (parameterSyntax, parameter) = GetParameterAtOrdinal(index, parameterNodes, semanticModel, cancellationToken);
if (parameter == null)
Expand All @@ -149,20 +144,19 @@ protected AbstractAddParameterCheckCodeRefactoringProvider()

var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();

if (!CanOfferRefactoring(functionDeclaration, semanticModel, syntaxFacts, cancellationToken, out blockStatementOpt))
{
if (!CanOfferRefactoring(funcOrRecord, semanticModel, syntaxFacts, cancellationToken, out blockStatementOpt))
continue;
}

// If parameter is a string, default check would be IsNullOrEmpty. This is because IsNullOrEmpty is more commonly used in this regard according to telemetry and UX testing.
if (parameter.Type.SpecialType == SpecialType.System_String)
// If parameter is a string, default check would be IsNullOrEmpty. This is because IsNullOrEmpty is more
// commonly used in this regard according to telemetry and UX testing.
if (parameter.Type.SpecialType == SpecialType.System_String && !IsRecordDeclaration(funcOrRecord))
{
document = await AddStringCheckAsync(document, parameter, functionDeclaration, (IMethodSymbol)parameter.ContainingSymbol, blockStatementOpt, nameof(string.IsNullOrEmpty), cancellationToken).ConfigureAwait(false);
document = await AddStringCheckAsync(document, parameter, funcOrRecord, (IMethodSymbol)parameter.ContainingSymbol, blockStatementOpt, nameof(string.IsNullOrEmpty), cancellationToken).ConfigureAwait(false);
continue;
}

// For all other parameters, add null check - updates document
document = await AddNullCheckAsync(document, parameterSyntax, parameter, functionDeclaration,
document = await AddNullCheckAsync(document, parameterSyntax, parameter, funcOrRecord,
(IMethodSymbol)parameter.ContainingSymbol, blockStatementOpt, cancellationToken).ConfigureAwait(false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ internal abstract partial class AbstractInitializeMemberFromParameterCodeRefacto
where TExpressionSyntax : SyntaxNode
{
protected abstract SyntaxNode? TryGetLastStatement(IBlockOperation? blockStatementOpt);

protected abstract Accessibility DetermineDefaultFieldAccessibility(INamedTypeSymbol containingType);

protected abstract Accessibility DetermineDefaultPropertyAccessibility();

protected override bool SupportsRecords(ParseOptions options)
=> false;

protected override Task<ImmutableArray<CodeAction>> GetRefactoringsForAllParametersAsync(
Document document, SyntaxNode functionDeclaration, IMethodSymbol method, IBlockOperation? blockStatementOpt,
ImmutableArray<SyntaxNode> listOfParameterNodes, TextSpan parameterSpan, CancellationToken cancellationToken)
Expand Down

0 comments on commit 6a169b2

Please sign in to comment.