Skip to content

Commit

Permalink
Simplify SqlExpressionSimplifyingExpressionVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Feb 10, 2023
1 parent 8664081 commit f082129
Showing 1 changed file with 115 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ protected override Expression VisitExtension(Expression extensionExpression)
}

// Only applies to 'CASE WHEN condition...' not 'CASE operand WHEN...'
if (extensionExpression is CaseExpression caseExpression
&& caseExpression.Operand == null
&& caseExpression.ElseResult is CaseExpression nestedCaseExpression
&& nestedCaseExpression.Operand == null)
if (extensionExpression is CaseExpression
{
Operand: null, ElseResult: CaseExpression { Operand: null } nestedCaseExpression
} caseExpression)
{
return VisitExtension(
_sqlExpressionFactory.Case(
Expand Down Expand Up @@ -93,37 +93,29 @@ protected override Expression VisitExtension(Expression extensionExpression)
return base.VisitExtension(extensionExpression);

static bool IsCoalesce(SqlExpression sqlExpression)
=> sqlExpression is SqlFunctionExpression sqlFunctionExpression
&& sqlFunctionExpression.IsBuiltIn
&& sqlFunctionExpression.Instance == null
=> sqlExpression is SqlFunctionExpression { IsBuiltIn: true, Instance: null } sqlFunctionExpression
&& string.Equals(sqlFunctionExpression.Name, "COALESCE", StringComparison.OrdinalIgnoreCase)
&& sqlFunctionExpression.Arguments?.Count > 1;
}

private static bool IsCompareTo([NotNullWhen(true)] CaseExpression? caseExpression)
{
if (caseExpression != null
&& caseExpression.Operand == null
&& caseExpression.ElseResult == null
&& caseExpression.WhenClauses.Count == 3
&& caseExpression.WhenClauses.All(
c => c.Test is SqlBinaryExpression
&& c.Result is SqlConstantExpression constant
&& constant.Value is int))
if (caseExpression is { Operand: null, ElseResult: null, WhenClauses.Count: 3 }
&& caseExpression.WhenClauses.All(c => c is { Test: SqlBinaryExpression, Result: SqlConstantExpression { Value: int } }))
{
var whenClauses = caseExpression.WhenClauses.Select(
c => new { test = (SqlBinaryExpression)c.Test, resultValue = (int)((SqlConstantExpression)c.Result).Value! }).ToList();

if (whenClauses[0].test.Left.Equals(whenClauses[1].test.Left)
&& whenClauses[1].test.Left.Equals(whenClauses[2].test.Left)
&& whenClauses[0].test.Right.Equals(whenClauses[1].test.Right)
&& whenClauses[1].test.Right.Equals(whenClauses[2].test.Right)
&& whenClauses[0].test.OperatorType == ExpressionType.Equal
&& whenClauses[1].test.OperatorType == ExpressionType.GreaterThan
&& whenClauses[2].test.OperatorType == ExpressionType.LessThan
&& whenClauses[0].resultValue == 0
&& whenClauses[1].resultValue == 1
&& whenClauses[2].resultValue == -1)
c => new { Test = (SqlBinaryExpression)c.Test, ResultValue = (int)((SqlConstantExpression)c.Result).Value! }).ToList();

if (whenClauses[0].Test.Left.Equals(whenClauses[1].Test.Left)
&& whenClauses[1].Test.Left.Equals(whenClauses[2].Test.Left)
&& whenClauses[0].Test.Right.Equals(whenClauses[1].Test.Right)
&& whenClauses[1].Test.Right.Equals(whenClauses[2].Test.Right)
&& whenClauses[0].Test.OperatorType == ExpressionType.Equal
&& whenClauses[1].Test.OperatorType == ExpressionType.GreaterThan
&& whenClauses[2].Test.OperatorType == ExpressionType.LessThan
&& whenClauses[0].ResultValue == 0
&& whenClauses[1].ResultValue == 1
&& whenClauses[2].ResultValue == -1)
{
return true;
}
Expand All @@ -150,69 +142,57 @@ private static bool IsCompareTo([NotNullWhen(true)] CaseExpression? caseExpressi
_ => sqlBinaryExpression.OperatorType
};

switch (operatorType)
return operatorType switch
{
// CompareTo(a, b) != 0 -> a != b
// CompareTo(a, b) != 1 -> a <= b
// CompareTo(a, b) != -1 -> a >= b
case ExpressionType.NotEqual:
return (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.NotEqual(testLeft, testRight),
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight)
});

ExpressionType.NotEqual => (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.NotEqual(testLeft, testRight),
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight)
}),
// CompareTo(a, b) > 0 -> a > b
// CompareTo(a, b) > 1 -> false
// CompareTo(a, b) > -1 -> a >= b
case ExpressionType.GreaterThan:
return (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
1 => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping),
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight)
});

ExpressionType.GreaterThan => (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
1 => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping),
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight)
}),
// CompareTo(a, b) >= 0 -> a >= b
// CompareTo(a, b) >= 1 -> a > b
// CompareTo(a, b) >= -1 -> true
case ExpressionType.GreaterThanOrEqual:
return (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight),
1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
_ => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping)
});

ExpressionType.GreaterThanOrEqual => (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight),
1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
_ => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping)
}),
// CompareTo(a, b) < 0 -> a < b
// CompareTo(a, b) < 1 -> a <= b
// CompareTo(a, b) < -1 -> false
case ExpressionType.LessThan:
return (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.LessThan(testLeft, testRight),
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
_ => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping)
});

// operatorType == ExpressionType.LessThanOrEqual
// CompareTo(a, b) <= 0 -> a <= b
// CompareTo(a, b) <= 1 -> true
// CompareTo(a, b) <= -1 -> a < b
default:
return (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
1 => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping),
_ => _sqlExpressionFactory.LessThan(testLeft, testRight)
});
}
ExpressionType.LessThan => (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.LessThan(testLeft, testRight),
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
_ => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping)
}),

_ => (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
1 => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping),
_ => _sqlExpressionFactory.LessThan(testLeft, testRight)
})
};
}

private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
Expand All @@ -228,10 +208,8 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
// WHEN ...
// WHEN conditionN THEN resultN) == result1 -> condition1
if (sqlBinaryExpression.OperatorType == ExpressionType.Equal
&& sqlConstantComponent != null
&& sqlConstantComponent.Value != null
&& caseComponent != null
&& caseComponent.Operand == null
&& sqlConstantComponent?.Value is not null
&& caseComponent?.Operand is not null
&& caseComponent.ElseResult == null)
{
var matchingCaseBlock = caseComponent.WhenClauses.FirstOrDefault(wc => sqlConstantComponent.Equals(wc.Result));
Expand All @@ -244,13 +222,13 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
// CompareTo specific optimizations
if (sqlConstantComponent != null
&& IsCompareTo(caseComponent)
&& sqlConstantComponent.Value is int intValue
&& (intValue > -2 && intValue < 2)
&& (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThan
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThanOrEqual
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThan
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThanOrEqual))
&& sqlConstantComponent.Value is int intValue and > -2 and < 2
&& sqlBinaryExpression.OperatorType
is ExpressionType.NotEqual
or ExpressionType.GreaterThan
or ExpressionType.GreaterThanOrEqual
or ExpressionType.LessThan
or ExpressionType.LessThanOrEqual)
{
return OptimizeCompareTo(
sqlBinaryExpression,
Expand All @@ -261,8 +239,7 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
var left = (SqlExpression)Visit(sqlBinaryExpression.Left);
var right = (SqlExpression)Visit(sqlBinaryExpression.Right);

if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
|| sqlBinaryExpression.OperatorType == ExpressionType.OrElse)
if (sqlBinaryExpression.OperatorType is ExpressionType.AndAlso or ExpressionType.OrElse)
{
if (TryGetInExpressionCandidateInfo(left, out var leftCandidateInfo)
&& TryGetInExpressionCandidateInfo(right, out var rightCandidateInfo)
Expand All @@ -286,47 +263,49 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
object rightValue;
List<object> resultArray;

if (!leftConstantIsEnumerable && !rightConstantIsEnumerable)
{
// comparison + comparison
leftValue = leftCandidateInfo.ConstantValue;
rightValue = rightCandidateInfo.ConstantValue;

// for relational nulls we can't combine comparisons that contain null
// a != 1 && a != null would be converted to a NOT IN (1, null), which never returns any results
// we need to keep it in the original form so that a != null gets converted to a IS NOT NULL instead
// for c# null semantics it's fine because null semantics visitor extracts null back into proper null checks
if (_useRelationalNulls && (leftValue == null || rightValue == null))
{
return sqlBinaryExpression.Update(left, right);
}

resultArray = ConstructCollection(leftValue, rightValue);
}
else if (leftConstantIsEnumerable && rightConstantIsEnumerable)
switch ((leftConstantIsEnumerable, rightConstantIsEnumerable))
{
// in + in
leftValue = leftCandidateInfo.ConstantValue;
rightValue = rightCandidateInfo.ConstantValue;
resultArray = UnionCollections((IEnumerable)leftValue, (IEnumerable)rightValue);
}
else
{
// in + comparison
leftValue = leftConstantIsEnumerable
? leftCandidateInfo.ConstantValue
: rightCandidateInfo.ConstantValue;

rightValue = leftConstantIsEnumerable
? rightCandidateInfo.ConstantValue
: leftCandidateInfo.ConstantValue;

if (_useRelationalNulls && rightValue == null)
{
return sqlBinaryExpression.Update(left, right);
}

resultArray = AddToCollection((IEnumerable)leftValue, rightValue);
case (false, false):
// comparison + comparison
leftValue = leftCandidateInfo.ConstantValue;
rightValue = rightCandidateInfo.ConstantValue;

// for relational nulls we can't combine comparisons that contain null
// a != 1 && a != null would be converted to a NOT IN (1, null), which never returns any results
// we need to keep it in the original form so that a != null gets converted to a IS NOT NULL instead
// for c# null semantics it's fine because null semantics visitor extracts null back into proper null checks
if (_useRelationalNulls && (leftValue == null || rightValue == null))
{
return sqlBinaryExpression.Update(left, right);
}

resultArray = ConstructCollection(leftValue, rightValue);
break;

case (true, true):
// in + in
leftValue = leftCandidateInfo.ConstantValue;
rightValue = rightCandidateInfo.ConstantValue;
resultArray = UnionCollections((IEnumerable)leftValue, (IEnumerable)rightValue);
break;

default:
// in + comparison
leftValue = leftConstantIsEnumerable
? leftCandidateInfo.ConstantValue
: rightCandidateInfo.ConstantValue;

rightValue = leftConstantIsEnumerable
? rightCandidateInfo.ConstantValue
: leftCandidateInfo.ConstantValue;

if (_useRelationalNulls && rightValue == null)
{
return sqlBinaryExpression.Update(left, right);
}

resultArray = AddToCollection((IEnumerable)leftValue, rightValue);
break;
}

return _sqlExpressionFactory.In(
Expand Down Expand Up @@ -422,8 +401,7 @@ private static List<object> BuildListFromEnumerable(IEnumerable collection)
out (ColumnExpression ColumnExpression, object ConstantValue, RelationalTypeMapping TypeMapping, ExpressionType OperationType)
candidateInfo)
{
if (sqlExpression is SqlUnaryExpression sqlUnaryExpression
&& sqlUnaryExpression.OperatorType == ExpressionType.Not)
if (sqlExpression is SqlUnaryExpression { OperatorType: ExpressionType.Not } sqlUnaryExpression)
{
if (TryGetInExpressionCandidateInfo(sqlUnaryExpression.Operand, out var inner))
{
Expand All @@ -433,9 +411,7 @@ private static List<object> BuildListFromEnumerable(IEnumerable collection)
return true;
}
}
else if (sqlExpression is SqlBinaryExpression sqlBinaryExpression
&& (sqlBinaryExpression.OperatorType == ExpressionType.Equal
|| sqlBinaryExpression.OperatorType == ExpressionType.NotEqual))
else if (sqlExpression is SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } sqlBinaryExpression)
{
var column = (sqlBinaryExpression.Left as ColumnExpression ?? sqlBinaryExpression.Right as ColumnExpression);
var constant = (sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression);
Expand All @@ -446,10 +422,10 @@ private static List<object> BuildListFromEnumerable(IEnumerable collection)
return true;
}
}
else if (sqlExpression is InExpression inExpression
&& inExpression.Item is ColumnExpression column
&& inExpression.Subquery == null
&& inExpression.Values is SqlConstantExpression valuesConstant)
else if (sqlExpression is InExpression
{
Item: ColumnExpression column, Subquery: null, Values: SqlConstantExpression valuesConstant
} inExpression)
{
candidateInfo = (column, valuesConstant.Value!, valuesConstant.TypeMapping!,
inExpression.IsNegated ? ExpressionType.NotEqual : ExpressionType.Equal);
Expand Down

0 comments on commit f082129

Please sign in to comment.