Skip to content

Commit

Permalink
Add support for using existing NotNullIfNotNullAttribute
Browse files Browse the repository at this point in the history
  • Loading branch information
dgrunwald committed Jun 13, 2020
1 parent e58a738 commit 07b2d78
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
20 changes: 20 additions & 0 deletions NullabilityInference.Tests/FlowTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -259,5 +259,25 @@ public bool TryGet(int i, [Attr] out string? name)
expectedProgram: program.Replace("[Attr]", "[NotNullWhen(true)]").Replace("[using]", "using System.Diagnostics.CodeAnalysis;"),
inputProgram: program.Replace("[Attr] ", "").Replace("[using]", ""));
}

[Fact]
public void UseNotNullIfNotNull()
{
string program = @"
using System.Diagnostics.CodeAnalysis;
class Program
{
public void Test()
{
string a = Identitity(string.Empty);
string? b = Identitity(null);
}
#nullable enable
[return: NotNullIfNotNull(""input"")]
public static string? Identitity(string? input) => input;
}";
AssertNullabilityInference(program, program);
}
}
}
12 changes: 10 additions & 2 deletions NullabilityInference/EdgeBuildingOperationVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,14 @@ public override TypeWithNode VisitInvocation(IInvocationOperation operation, Edg
}
methodTypeArgNodes = ExtendMethodTypeArguments(targetMethod, methodTypeArgNodes);
var substitution = new TypeSubstitution(classTypeArgNodes, methodTypeArgNodes);
HandleArguments(substitution, operation.Arguments, invocationContext: argument);
var argumentTypes = HandleArguments(substitution, operation.Arguments, invocationContext: argument);
var returnType = typeSystem.GetSymbolType(targetMethod.OriginalDefinition);
returnType = returnType.WithSubstitution(targetMethod.ReturnType, substitution);

if (typeSystem.GetNotNullIfNotNullParam(operation.TargetMethod.OriginalDefinition) is { } notNullParam) {
returnType = returnType.WithNode(argumentTypes[notNullParam.Ordinal].Node);
}

return returnType;
}

Expand Down Expand Up @@ -967,8 +972,9 @@ private void HandleMethodGroup(IMethodReferenceOperation operation, TypeWithNode
_ => null,
};

private void HandleArguments(TypeSubstitution substitution, ImmutableArray<IArgumentOperation> arguments, EdgeBuildingContext invocationContext)
private List<TypeWithNode> HandleArguments(TypeSubstitution substitution, ImmutableArray<IArgumentOperation> arguments, EdgeBuildingContext invocationContext)
{
List<TypeWithNode> argumentTypes = new List<TypeWithNode>();
Action? afterCall = null;
FlowState? flowStateOnTrue = null;
FlowState? flowStateOnFalse = null;
Expand All @@ -977,6 +983,7 @@ private void HandleArguments(TypeSubstitution substitution, ImmutableArray<IArgu
var parameterType = typeSystem.GetSymbolType(param);
bool isLValue = param.RefKind == RefKind.Ref || param.RefKind == RefKind.Out;
var argumentType = Visit(arg.Value, isLValue ? EdgeBuildingContext.LValue : EdgeBuildingContext.Normal);
argumentTypes.Add(argumentType);
// Create an assignment edge from argument to parameter.
// We use the parameter's original type + substitution so that a type parameter `T` appearing in
// multiple parameters uses the same nullability nodes for all occurrences.
Expand Down Expand Up @@ -1024,6 +1031,7 @@ private void HandleArguments(TypeSubstitution substitution, ImmutableArray<IArgu
Debug.Assert(flowStateOnTrue == null && flowStateOnFalse == null);
Debug.Assert(flowStateReturnedOnTrue == null && flowStateReturnedOnFalse == null);
}
return argumentTypes;
}

private TypeWithNode currentObjectCreationType;
Expand Down
15 changes: 15 additions & 0 deletions NullabilityInference/TypeSystem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -608,5 +608,20 @@ internal void CreateTypeEdge(TypeWithNode source, TypeWithNode target, TypeSubst
return edge;
}
}

internal IParameterSymbol? GetNotNullIfNotNullParam(IMethodSymbol method)
{
foreach (var attr in method.GetReturnTypeAttributes()) {
if (attr.AttributeClass?.GetFullName() == "System.Diagnostics.CodeAnalysis.NotNullIfNotNullAttribute"
&& attr.ConstructorArguments.Length == 1
&& attr.ConstructorArguments[0].Value is string parameterName) {
foreach (var p in method.Parameters) {
if (p.Name == parameterName)
return p;
}
}
}
return null;
}
}
}

0 comments on commit 07b2d78

Please sign in to comment.