Skip to content

Commit

Permalink
Normalize Any to Contains instead of vice versa
Browse files Browse the repository at this point in the history
E.g. for easier pattern-matching of Contains
  • Loading branch information
roji committed May 22, 2023
1 parent c983103 commit 09f4c85
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,6 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
? nestedOperand
: _sqlExpressionFactory.Not(translation));

if (TrySimplifyValuesToInExpression(source, isNegated: true, out var simplifiedQuery))
{
return simplifiedQuery;
}

selectExpression.ReplaceProjection(new List<Expression>());
selectExpression.ApplyProjection();
if (selectExpression.Limit == null
Expand Down Expand Up @@ -452,24 +447,19 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
}

source = translatedSource;

if (TrySimplifyValuesToInExpression(source, isNegated: false, out var simplifiedQuery))
{
return simplifiedQuery;
}
}

var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.ReplaceProjection(new List<Expression>());
selectExpression.ApplyProjection();
if (selectExpression.Limit == null
&& selectExpression.Offset == null)
var subquery = (SelectExpression)source.QueryExpression;
subquery.ReplaceProjection(new List<Expression>());
subquery.ApplyProjection();
if (subquery.Limit == null
&& subquery.Offset == null)
{
selectExpression.ClearOrdering();
subquery.ClearOrdering();
}

var translation = _sqlExpressionFactory.Exists(selectExpression, false);
selectExpression = _sqlExpressionFactory.Select(translation);
var translation = _sqlExpressionFactory.Exists(subquery, false);
var selectExpression = _sqlExpressionFactory.Select(translation);

return source.Update(
selectExpression,
Expand Down Expand Up @@ -501,47 +491,65 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
/// <inheritdoc />
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
{
var selectExpression = (SelectExpression)source.QueryExpression;
var translation = TranslateExpression(item);
if (translation == null)
{
return null;
}

if (selectExpression.Limit == null
&& selectExpression.Offset == null)
{
selectExpression.ClearOrdering();
}

var shaperExpression = source.ShaperExpression;
// No need to check ConvertChecked since this is convert node which we may have added during projection
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shaperExpression is ProjectionBindingExpression projectionBindingExpression)
// Pattern-match Contains over ValuesExpression, translating to simplified 'item IN (1, 2, 3)' with constant elements
if (source.QueryExpression is SelectExpression
{
Tables:
[
ValuesExpression
{
RowValues: [{ Values.Count: 2 }, ..],
ColumnNames: [ValuesOrderingColumnName, ValuesValueColumnName]
} valuesExpression
],
Predicate: null,
GroupBy: [],
Having: null,
IsDistinct: false,
Limit: null,
Offset: null,
// Note that in the context of Contains we don't care about orderings
}
// Make sure that the source projects the column from the ValuesExpression directly, i.e. no projection out with some expression
&& TryGetProjection(source, out var projection)
&& projection is ColumnExpression projectedColumn
&& projectedColumn.Table == valuesExpression)
{
var projection = selectExpression.GetProjection(projectionBindingExpression);
if (projection is SqlExpression sqlExpression)
if (TranslateExpression(item) is not SqlExpression translatedItem)
{
selectExpression.ReplaceProjection(new List<Expression> { sqlExpression });
selectExpression.ApplyProjection();
return null;
}

translation = _sqlExpressionFactory.In(translation, selectExpression, false);
selectExpression = _sqlExpressionFactory.Select(translation);
var values = new object?[valuesExpression.RowValues.Count];
for (var i = 0; i < values.Length; i++)
{
// Skip the first value (_ord), which is irrelevant for Contains
if (valuesExpression.RowValues[i].Values[1] is SqlConstantExpression { Value: var constantValue })
{
values[i] = constantValue;
}
else
{
// We only support constants for now
values = null;
break;
}
}

return source.Update(
selectExpression,
Expression.Convert(
new ProjectionBindingExpression(selectExpression, new ProjectionMember(), typeof(bool?)), typeof(bool)));
if (values is not null)
{
var inExpression = _sqlExpressionFactory.In(translatedItem, _sqlExpressionFactory.Constant(values), negated: false);
return source.Update(_sqlExpressionFactory.Select(inExpression), source.ShaperExpression);
}
}

return null;
// TODO: This generates an EXISTS subquery. Translate to IN instead: #30955
var anyLambdaParameter = Expression.Parameter(item.Type, "p");
var anyLambda = Expression.Lambda(
Infrastructure.ExpressionExtensions.CreateEqualsExpression(anyLambdaParameter, item),
anyLambdaParameter);

return TranslateAny(source, anyLambda);
}

/// <inheritdoc />
Expand Down Expand Up @@ -1812,89 +1820,6 @@ static Expression GetEntitySource(IModel model, Expression propertyAccessExpress
protected virtual bool IsOrdered(SelectExpression selectExpression)
=> selectExpression.Orderings.Count > 0;

/// <summary>
/// Attempts to pattern-match for Contains over <see cref="ValuesExpression" />, which corresponds to
/// <c>Where(b => new[] { 1, 2, 3 }.Contains(b.Id))</c>. Simplifies this to the tighter <c>[b].[Id] IN (1, 2, 3)</c> instead of the
/// full subquery with VALUES.
/// </summary>
private bool TrySimplifyValuesToInExpression(
ShapedQueryExpression source,
bool isNegated,
[NotNullWhen(true)] out ShapedQueryExpression? simplifiedQuery)
{
if (source.QueryExpression is SelectExpression
{
Tables: [ValuesExpression
{
RowValues: [{ Values.Count: 2 }, ..],
ColumnNames: [ ValuesOrderingColumnName, ValuesValueColumnName ]
} valuesExpression],
GroupBy: [],
Having: null,
IsDistinct: false,
Limit: null,
Offset: null,
// Note that we don't care about orderings, they get elided anyway by Any/All
Predicate: SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: var left, Right: var right },
} selectExpression)
{
// The table is a ValuesExpression, and the predicate is an equality - this is a possible simplifiable Contains.
// Get the projection column pointing to the ValuesExpression, and check that it's compared to on one side of the predicate
// equality.
var shaperExpression = source.ShaperExpression;
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shaperExpression is ProjectionBindingExpression projectionBindingExpression
&& selectExpression.GetProjection(projectionBindingExpression) is ColumnExpression projectionColumn)
{
SqlExpression item;

if (left is ColumnExpression leftColumn
&& (leftColumn.Table, leftColumn.Name) == (projectionColumn.Table, projectionColumn.Name))
{
item = right;
}
else if (right is ColumnExpression rightColumn
&& (rightColumn.Table, rightColumn.Name) == (projectionColumn.Table, projectionColumn.Name))
{
item = left;
}
else
{
simplifiedQuery = null;
return false;
}

var values = new object?[valuesExpression.RowValues.Count];
for (var i = 0; i < values.Length; i++)
{
// Skip the first value (_ord), which is irrelevant for Contains
if (valuesExpression.RowValues[i].Values[1] is SqlConstantExpression { Value: var constantValue })
{
values[i] = constantValue;
}
else
{
simplifiedQuery = null;
return false;
}
}

var inExpression = _sqlExpressionFactory.In(item, _sqlExpressionFactory.Constant(values), isNegated);
simplifiedQuery = source.Update(_sqlExpressionFactory.Select(inExpression), source.ShaperExpression);
return true;
}
}

simplifiedQuery = null;
return false;
}

private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression)
{
var lambdaBody = ReplacingExpressionVisitor.Replace(
Expand Down Expand Up @@ -2569,6 +2494,29 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
return source.UpdateShaperExpression(shaper);
}

private bool TryGetProjection(ShapedQueryExpression shapedQueryExpression, [NotNullWhen(true)] out SqlExpression? projection)
{
var shaperExpression = shapedQueryExpression.ShaperExpression;
// No need to check ConvertChecked since this is convert node which we may have added during projection
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shapedQueryExpression.QueryExpression is SelectExpression selectExpression
&& shaperExpression is ProjectionBindingExpression projectionBindingExpression
&& selectExpression.GetProjection(projectionBindingExpression) is SqlExpression sqlExpression)
{
projection = sqlExpression;
return true;
}

projection = null;
return false;
}

/// <summary>
/// A visitor which scans an expression tree and attempts to find columns for which we were missing type mappings (projected out
/// of queryable constant/parameter), and those type mappings have been inferred.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1133,9 +1133,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
var operand = Visit(unaryExpression.Operand);

if (operand is EntityReferenceExpression entityReferenceExpression
&& (unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked
|| unaryExpression.NodeType == ExpressionType.TypeAs))
&& unaryExpression.NodeType is ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs)
{
return entityReferenceExpression.Convert(unaryExpression.Type);
}
Expand All @@ -1148,7 +1146,13 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
switch (unaryExpression.NodeType)
{
case ExpressionType.Not:
return _sqlExpressionFactory.Not(sqlOperand!);
return sqlOperand switch
{
ExistsExpression e => e.Negate(),
InExpression e => e.Negate(),

_ => _sqlExpressionFactory.Not(sqlOperand!)
};

case ExpressionType.Negate:
case ExpressionType.NegateChecked:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ public class ExistsExpression : SqlExpression
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> Update((SelectExpression)visitor.Visit(Subquery));

/// <summary>
/// Negates this expression by changing presence/absence state indicated by <see cref="IsNegated" />.
/// </summary>
/// <returns>An expression which is negated form of this expression.</returns>
public virtual ExistsExpression Negate()
=> new(Subquery, !IsNegated, TypeMapping);

/// <summary>
/// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will
/// return this expression.
Expand Down
59 changes: 26 additions & 33 deletions src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,44 +210,38 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
result);
}

// Normalize x.Any(i => i == foo) to x.Contains(foo)
// And x.All(i => i != foo) to !x.Contains(foo)
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo methodInfo
&& (methodInfo.Equals(EnumerableMethods.AnyWithPredicate) || methodInfo.Equals(EnumerableMethods.All))
&& methodCallExpression.Arguments[0].NodeType is ExpressionType nodeType
&& (nodeType == ExpressionType.Parameter || nodeType == ExpressionType.Constant)
&& methodCallExpression.Arguments[1] is LambdaExpression lambda
&& TryExtractEqualityOperands(lambda.Body, out var left, out var right, out var negated)
&& (left is ParameterExpression || right is ParameterExpression))
&& (methodInfo == EnumerableMethods.AnyWithPredicate || methodInfo == EnumerableMethods.All || methodInfo == QueryableMethods.AnyWithPredicate || methodInfo == QueryableMethods.All)
&& methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() is var lambda
&& TryExtractEqualityOperands(lambda.Body, out var left, out var right, out var negated))
{
var nonParameterExpression = left is ParameterExpression ? right : left;
var itemExpression = left == lambda.Parameters[0]
? right
: right == lambda.Parameters[0]
? left
: null;

if (methodInfo.Equals(EnumerableMethods.AnyWithPredicate)
&& !negated)
if (itemExpression is not null)
{
var containsMethod = EnumerableMethods.Contains.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]);
return Expression.Call(null, containsMethod, methodCallExpression.Arguments[0], nonParameterExpression);
}
var containsMethodDefinition = methodInfo.DeclaringType == typeof(Enumerable)
? EnumerableMethods.Contains
: QueryableMethods.Contains;

if (methodInfo.Equals(EnumerableMethods.All) && negated)
{
var containsMethod = EnumerableMethods.Contains.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]);
return Expression.Not(Expression.Call(null, containsMethod, methodCallExpression.Arguments[0], nonParameterExpression));
}
}

if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo containsMethodInfo
&& containsMethodInfo.Equals(QueryableMethods.Contains))
{
var typeArgument = methodCallExpression.Method.GetGenericArguments()[0];
var anyMethod = QueryableMethods.AnyWithPredicate.MakeGenericMethod(typeArgument);

var anyLambdaParameter = Expression.Parameter(typeArgument, "p");
var anyLambda = Expression.Lambda(
ExpressionExtensions.CreateEqualsExpression(anyLambdaParameter, methodCallExpression.Arguments[1]),
anyLambdaParameter);
if ((methodInfo == EnumerableMethods.AnyWithPredicate || methodInfo == QueryableMethods.AnyWithPredicate) && !negated)
{
var containsMethod = containsMethodDefinition.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]);
return Expression.Call(null, containsMethod, methodCallExpression.Arguments[0], itemExpression);
}

return Expression.Call(null, anyMethod, new[] { Visit(methodCallExpression.Arguments[0]), anyLambda });
if ((methodInfo == EnumerableMethods.All || methodInfo == QueryableMethods.All) && negated)
{
var containsMethod = containsMethodDefinition.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]);
return Expression.Not(Expression.Call(null, containsMethod, methodCallExpression.Arguments[0], itemExpression));
}
}
}

var @object = default(Expression);
Expand Down Expand Up @@ -409,8 +403,7 @@ private static Expression MatchExpressionType(Expression expression, Type typeTo
(left, right) = (binaryExpression.Left, binaryExpression.Right);
return true;

case MethodCallExpression methodCallExpression
when methodCallExpression.Method.Name == nameof(object.Equals):
case MethodCallExpression { Method.Name: nameof(object.Equals) } methodCallExpression:
{
negated = false;
if (methodCallExpression.Arguments.Count == 1
Expand Down
Loading

0 comments on commit 09f4c85

Please sign in to comment.