diff --git a/NullabilityInference.Tests/FlowTests.cs b/NullabilityInference.Tests/FlowTests.cs index f6c62d7..7cc60be 100644 --- a/NullabilityInference.Tests/FlowTests.cs +++ b/NullabilityInference.Tests/FlowTests.cs @@ -259,5 +259,25 @@ public bool TryGet(int i, [Attr] out string? name) expectedProgram: program.Replace("[Attr]", "[NotNullWhen(true)]").Replace("[using]", "using System.Diagnostics.CodeAnalysis;"), inputProgram: program.Replace("[Attr] ", "").Replace("[using]", "")); } + + [Fact] + public void UseNotNullIfNotNull() + { + string program = @" +using System.Diagnostics.CodeAnalysis; +class Program +{ + public void Test() + { + string a = Identitity(string.Empty); + string? b = Identitity(null); + } + +#nullable enable + [return: NotNullIfNotNull(""input"")] + public static string? Identitity(string? input) => input; +}"; + AssertNullabilityInference(program, program); + } } } diff --git a/NullabilityInference/EdgeBuildingOperationVisitor.cs b/NullabilityInference/EdgeBuildingOperationVisitor.cs index e35e782..c92ff37 100644 --- a/NullabilityInference/EdgeBuildingOperationVisitor.cs +++ b/NullabilityInference/EdgeBuildingOperationVisitor.cs @@ -889,9 +889,14 @@ public override TypeWithNode VisitInvocation(IInvocationOperation operation, Edg } methodTypeArgNodes = ExtendMethodTypeArguments(targetMethod, methodTypeArgNodes); var substitution = new TypeSubstitution(classTypeArgNodes, methodTypeArgNodes); - HandleArguments(substitution, operation.Arguments, invocationContext: argument); + var argumentTypes = HandleArguments(substitution, operation.Arguments, invocationContext: argument); var returnType = typeSystem.GetSymbolType(targetMethod.OriginalDefinition); returnType = returnType.WithSubstitution(targetMethod.ReturnType, substitution); + + if (typeSystem.GetNotNullIfNotNullParam(operation.TargetMethod.OriginalDefinition) is { } notNullParam) { + returnType = returnType.WithNode(argumentTypes[notNullParam.Ordinal].Node); + } + return returnType; } @@ -967,8 +972,9 @@ private void HandleMethodGroup(IMethodReferenceOperation operation, TypeWithNode _ => null, }; - private void HandleArguments(TypeSubstitution substitution, ImmutableArray arguments, EdgeBuildingContext invocationContext) + private List HandleArguments(TypeSubstitution substitution, ImmutableArray arguments, EdgeBuildingContext invocationContext) { + List argumentTypes = new List(); Action? afterCall = null; FlowState? flowStateOnTrue = null; FlowState? flowStateOnFalse = null; @@ -977,6 +983,7 @@ private void HandleArguments(TypeSubstitution substitution, ImmutableArray