diff --git a/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs index da3dafd0e2c..35f566a35e9 100644 --- a/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs @@ -44,10 +44,10 @@ protected override Expression VisitExtension(Expression extensionExpression) } // Only applies to 'CASE WHEN condition...' not 'CASE operand WHEN...' - if (extensionExpression is CaseExpression caseExpression - && caseExpression.Operand == null - && caseExpression.ElseResult is CaseExpression nestedCaseExpression - && nestedCaseExpression.Operand == null) + if (extensionExpression is CaseExpression + { + Operand: null, ElseResult: CaseExpression { Operand: null } nestedCaseExpression + } caseExpression) { return VisitExtension( _sqlExpressionFactory.Case( @@ -93,37 +93,29 @@ protected override Expression VisitExtension(Expression extensionExpression) return base.VisitExtension(extensionExpression); static bool IsCoalesce(SqlExpression sqlExpression) - => sqlExpression is SqlFunctionExpression sqlFunctionExpression - && sqlFunctionExpression.IsBuiltIn - && sqlFunctionExpression.Instance == null + => sqlExpression is SqlFunctionExpression { IsBuiltIn: true, Instance: null } sqlFunctionExpression && string.Equals(sqlFunctionExpression.Name, "COALESCE", StringComparison.OrdinalIgnoreCase) && sqlFunctionExpression.Arguments?.Count > 1; } private static bool IsCompareTo([NotNullWhen(true)] CaseExpression? caseExpression) { - if (caseExpression != null - && caseExpression.Operand == null - && caseExpression.ElseResult == null - && caseExpression.WhenClauses.Count == 3 - && caseExpression.WhenClauses.All( - c => c.Test is SqlBinaryExpression - && c.Result is SqlConstantExpression constant - && constant.Value is int)) + if (caseExpression is { Operand: null, ElseResult: null, WhenClauses.Count: 3 } + && caseExpression.WhenClauses.All(c => c is { Test: SqlBinaryExpression, Result: SqlConstantExpression { Value: int } })) { var whenClauses = caseExpression.WhenClauses.Select( - c => new { test = (SqlBinaryExpression)c.Test, resultValue = (int)((SqlConstantExpression)c.Result).Value! }).ToList(); - - if (whenClauses[0].test.Left.Equals(whenClauses[1].test.Left) - && whenClauses[1].test.Left.Equals(whenClauses[2].test.Left) - && whenClauses[0].test.Right.Equals(whenClauses[1].test.Right) - && whenClauses[1].test.Right.Equals(whenClauses[2].test.Right) - && whenClauses[0].test.OperatorType == ExpressionType.Equal - && whenClauses[1].test.OperatorType == ExpressionType.GreaterThan - && whenClauses[2].test.OperatorType == ExpressionType.LessThan - && whenClauses[0].resultValue == 0 - && whenClauses[1].resultValue == 1 - && whenClauses[2].resultValue == -1) + c => new { Test = (SqlBinaryExpression)c.Test, ResultValue = (int)((SqlConstantExpression)c.Result).Value! }).ToList(); + + if (whenClauses[0].Test.Left.Equals(whenClauses[1].Test.Left) + && whenClauses[1].Test.Left.Equals(whenClauses[2].Test.Left) + && whenClauses[0].Test.Right.Equals(whenClauses[1].Test.Right) + && whenClauses[1].Test.Right.Equals(whenClauses[2].Test.Right) + && whenClauses[0].Test.OperatorType == ExpressionType.Equal + && whenClauses[1].Test.OperatorType == ExpressionType.GreaterThan + && whenClauses[2].Test.OperatorType == ExpressionType.LessThan + && whenClauses[0].ResultValue == 0 + && whenClauses[1].ResultValue == 1 + && whenClauses[2].ResultValue == -1) { return true; } @@ -150,69 +142,57 @@ private static bool IsCompareTo([NotNullWhen(true)] CaseExpression? caseExpressi _ => sqlBinaryExpression.OperatorType }; - switch (operatorType) + return operatorType switch { // CompareTo(a, b) != 0 -> a != b // CompareTo(a, b) != 1 -> a <= b // CompareTo(a, b) != -1 -> a >= b - case ExpressionType.NotEqual: - return (SqlExpression)Visit( - intValue switch - { - 0 => _sqlExpressionFactory.NotEqual(testLeft, testRight), - 1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), - _ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight) - }); - + ExpressionType.NotEqual => (SqlExpression)Visit( + intValue switch + { + 0 => _sqlExpressionFactory.NotEqual(testLeft, testRight), + 1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), + _ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight) + }), // CompareTo(a, b) > 0 -> a > b // CompareTo(a, b) > 1 -> false // CompareTo(a, b) > -1 -> a >= b - case ExpressionType.GreaterThan: - return (SqlExpression)Visit( - intValue switch - { - 0 => _sqlExpressionFactory.GreaterThan(testLeft, testRight), - 1 => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping), - _ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight) - }); - + ExpressionType.GreaterThan => (SqlExpression)Visit( + intValue switch + { + 0 => _sqlExpressionFactory.GreaterThan(testLeft, testRight), + 1 => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping), + _ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight) + }), // CompareTo(a, b) >= 0 -> a >= b // CompareTo(a, b) >= 1 -> a > b // CompareTo(a, b) >= -1 -> true - case ExpressionType.GreaterThanOrEqual: - return (SqlExpression)Visit( - intValue switch - { - 0 => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight), - 1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight), - _ => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping) - }); - + ExpressionType.GreaterThanOrEqual => (SqlExpression)Visit( + intValue switch + { + 0 => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight), + 1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight), + _ => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping) + }), // CompareTo(a, b) < 0 -> a < b // CompareTo(a, b) < 1 -> a <= b // CompareTo(a, b) < -1 -> false - case ExpressionType.LessThan: - return (SqlExpression)Visit( - intValue switch - { - 0 => _sqlExpressionFactory.LessThan(testLeft, testRight), - 1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), - _ => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping) - }); - - // operatorType == ExpressionType.LessThanOrEqual - // CompareTo(a, b) <= 0 -> a <= b - // CompareTo(a, b) <= 1 -> true - // CompareTo(a, b) <= -1 -> a < b - default: - return (SqlExpression)Visit( - intValue switch - { - 0 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), - 1 => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping), - _ => _sqlExpressionFactory.LessThan(testLeft, testRight) - }); - } + ExpressionType.LessThan => (SqlExpression)Visit( + intValue switch + { + 0 => _sqlExpressionFactory.LessThan(testLeft, testRight), + 1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), + _ => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping) + }), + + _ => (SqlExpression)Visit( + intValue switch + { + 0 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), + 1 => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping), + _ => _sqlExpressionFactory.LessThan(testLeft, testRight) + }) + }; } private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression) @@ -228,11 +208,8 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression) // WHEN ... // WHEN conditionN THEN resultN) == result1 -> condition1 if (sqlBinaryExpression.OperatorType == ExpressionType.Equal - && sqlConstantComponent != null - && sqlConstantComponent.Value != null - && caseComponent != null - && caseComponent.Operand == null - && caseComponent.ElseResult == null) + && sqlConstantComponent?.Value is not null + && caseComponent is { Operand: null, ElseResult: null }) { var matchingCaseBlock = caseComponent.WhenClauses.FirstOrDefault(wc => sqlConstantComponent.Equals(wc.Result)); if (matchingCaseBlock != null) @@ -244,13 +221,13 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression) // CompareTo specific optimizations if (sqlConstantComponent != null && IsCompareTo(caseComponent) - && sqlConstantComponent.Value is int intValue - && (intValue > -2 && intValue < 2) - && (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual - || sqlBinaryExpression.OperatorType == ExpressionType.GreaterThan - || sqlBinaryExpression.OperatorType == ExpressionType.GreaterThanOrEqual - || sqlBinaryExpression.OperatorType == ExpressionType.LessThan - || sqlBinaryExpression.OperatorType == ExpressionType.LessThanOrEqual)) + && sqlConstantComponent.Value is int intValue and > -2 and < 2 + && sqlBinaryExpression.OperatorType + is ExpressionType.NotEqual + or ExpressionType.GreaterThan + or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan + or ExpressionType.LessThanOrEqual) { return OptimizeCompareTo( sqlBinaryExpression, @@ -261,8 +238,7 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression) var left = (SqlExpression)Visit(sqlBinaryExpression.Left); var right = (SqlExpression)Visit(sqlBinaryExpression.Right); - if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso - || sqlBinaryExpression.OperatorType == ExpressionType.OrElse) + if (sqlBinaryExpression.OperatorType is ExpressionType.AndAlso or ExpressionType.OrElse) { if (TryGetInExpressionCandidateInfo(left, out var leftCandidateInfo) && TryGetInExpressionCandidateInfo(right, out var rightCandidateInfo) @@ -286,47 +262,49 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression) object rightValue; List resultArray; - if (!leftConstantIsEnumerable && !rightConstantIsEnumerable) - { - // comparison + comparison - leftValue = leftCandidateInfo.ConstantValue; - rightValue = rightCandidateInfo.ConstantValue; - - // for relational nulls we can't combine comparisons that contain null - // a != 1 && a != null would be converted to a NOT IN (1, null), which never returns any results - // we need to keep it in the original form so that a != null gets converted to a IS NOT NULL instead - // for c# null semantics it's fine because null semantics visitor extracts null back into proper null checks - if (_useRelationalNulls && (leftValue == null || rightValue == null)) - { - return sqlBinaryExpression.Update(left, right); - } - - resultArray = ConstructCollection(leftValue, rightValue); - } - else if (leftConstantIsEnumerable && rightConstantIsEnumerable) + switch ((leftConstantIsEnumerable, rightConstantIsEnumerable)) { - // in + in - leftValue = leftCandidateInfo.ConstantValue; - rightValue = rightCandidateInfo.ConstantValue; - resultArray = UnionCollections((IEnumerable)leftValue, (IEnumerable)rightValue); - } - else - { - // in + comparison - leftValue = leftConstantIsEnumerable - ? leftCandidateInfo.ConstantValue - : rightCandidateInfo.ConstantValue; - - rightValue = leftConstantIsEnumerable - ? rightCandidateInfo.ConstantValue - : leftCandidateInfo.ConstantValue; - - if (_useRelationalNulls && rightValue == null) - { - return sqlBinaryExpression.Update(left, right); - } - - resultArray = AddToCollection((IEnumerable)leftValue, rightValue); + case (false, false): + // comparison + comparison + leftValue = leftCandidateInfo.ConstantValue; + rightValue = rightCandidateInfo.ConstantValue; + + // for relational nulls we can't combine comparisons that contain null + // a != 1 && a != null would be converted to a NOT IN (1, null), which never returns any results + // we need to keep it in the original form so that a != null gets converted to a IS NOT NULL instead + // for c# null semantics it's fine because null semantics visitor extracts null back into proper null checks + if (_useRelationalNulls && (leftValue == null || rightValue == null)) + { + return sqlBinaryExpression.Update(left, right); + } + + resultArray = ConstructCollection(leftValue, rightValue); + break; + + case (true, true): + // in + in + leftValue = leftCandidateInfo.ConstantValue; + rightValue = rightCandidateInfo.ConstantValue; + resultArray = UnionCollections((IEnumerable)leftValue, (IEnumerable)rightValue); + break; + + default: + // in + comparison + leftValue = leftConstantIsEnumerable + ? leftCandidateInfo.ConstantValue + : rightCandidateInfo.ConstantValue; + + rightValue = leftConstantIsEnumerable + ? rightCandidateInfo.ConstantValue + : leftCandidateInfo.ConstantValue; + + if (_useRelationalNulls && rightValue == null) + { + return sqlBinaryExpression.Update(left, right); + } + + resultArray = AddToCollection((IEnumerable)leftValue, rightValue); + break; } return _sqlExpressionFactory.In( @@ -422,8 +400,7 @@ private static List BuildListFromEnumerable(IEnumerable collection) out (ColumnExpression ColumnExpression, object ConstantValue, RelationalTypeMapping TypeMapping, ExpressionType OperationType) candidateInfo) { - if (sqlExpression is SqlUnaryExpression sqlUnaryExpression - && sqlUnaryExpression.OperatorType == ExpressionType.Not) + if (sqlExpression is SqlUnaryExpression { OperatorType: ExpressionType.Not } sqlUnaryExpression) { if (TryGetInExpressionCandidateInfo(sqlUnaryExpression.Operand, out var inner)) { @@ -433,9 +410,7 @@ private static List BuildListFromEnumerable(IEnumerable collection) return true; } } - else if (sqlExpression is SqlBinaryExpression sqlBinaryExpression - && (sqlBinaryExpression.OperatorType == ExpressionType.Equal - || sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)) + else if (sqlExpression is SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } sqlBinaryExpression) { var column = (sqlBinaryExpression.Left as ColumnExpression ?? sqlBinaryExpression.Right as ColumnExpression); var constant = (sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression); @@ -446,10 +421,10 @@ private static List BuildListFromEnumerable(IEnumerable collection) return true; } } - else if (sqlExpression is InExpression inExpression - && inExpression.Item is ColumnExpression column - && inExpression.Subquery == null - && inExpression.Values is SqlConstantExpression valuesConstant) + else if (sqlExpression is InExpression + { + Item: ColumnExpression column, Subquery: null, Values: SqlConstantExpression valuesConstant + } inExpression) { candidateInfo = (column, valuesConstant.Value!, valuesConstant.TypeMapping!, inExpression.IsNegated ? ExpressionType.NotEqual : ExpressionType.Equal);