Skip to content

Commit

Permalink
ExecuteUpdate: Allow using other tables in the query to generate resu…
Browse files Browse the repository at this point in the history
…lt set

Part of #795
  • Loading branch information
smitpatel committed Aug 15, 2022
1 parent adf8eb4 commit 26d4fe5
Show file tree
Hide file tree
Showing 8 changed files with 524 additions and 26 deletions.
Expand Up @@ -3,7 +3,6 @@

using System.Collections;
using System.Collections.Concurrent;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.Extensions.Caching.Memory;

namespace Microsoft.EntityFrameworkCore.Query.Internal;
Expand Down
53 changes: 48 additions & 5 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Expand Up @@ -1237,8 +1237,6 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression)
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Tables.Count == 1
&& selectExpression.Tables[0] == updateExpression.Table
&& selectExpression.Projection.Count == 0)
{
_relationalCommandBuilder.Append("UPDATE ");
Expand All @@ -1255,13 +1253,58 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression)
},
joinAction: e => e.AppendLine(","));
_relationalCommandBuilder.AppendLine();
}

if (selectExpression.Predicate != null)
var predicate = selectExpression.Predicate;
var firstTablePrinted = false;
if (selectExpression.Tables.Count > 1)
{
_relationalCommandBuilder.AppendLine().Append("FROM ");
for (var i = 0; i < selectExpression.Tables.Count; i++)
{
var table = selectExpression.Tables[i];
var joinExpression = table as JoinExpressionBase;

if (ReferenceEquals(updateExpression.Table, joinExpression?.Table ?? table))
{
LiftPredicate(table);
continue;
}

if (firstTablePrinted)
{
_relationalCommandBuilder.AppendLine();
}
else
{
firstTablePrinted = true;
LiftPredicate(table);
table = joinExpression?.Table ?? table;
}

Visit(table);

void LiftPredicate(TableExpressionBase joinTable)
{
if (joinTable is PredicateJoinExpressionBase predicateJoinExpression)
{
predicate = predicate == null
? predicateJoinExpression.JoinPredicate
: new SqlBinaryExpression(
ExpressionType.AndAlso,
predicateJoinExpression.JoinPredicate,
predicate,
typeof(bool),
predicate.TypeMapping);
}
}
}
}

if (predicate != null)
{
_relationalCommandBuilder.AppendLine().Append("WHERE ");
Visit(selectExpression.Predicate);
Visit(predicate);
}

return updateExpression;
Expand Down
Expand Up @@ -1105,6 +1105,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
foreach (var (propertyExpression, valueExpression) in propertyValueLambdaExpressions)
{
var left = RemapLambdaBody(source, propertyExpression);
left = left.UnwrapTypeConversion(out _);
if (!IsValidPropertyAccess(left, out var ese))
{
AddTranslationErrorDetails(RelationalStrings.InvalidPropertyInSetProperty(propertyExpression.Print()));
Expand All @@ -1123,6 +1124,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

var right = RemapLambdaBody(source, valueExpression);
if (right.Type != left.Type)
{
right = Expression.Convert(right, left.Type);
}
// We generate equality between property = value while translating sothat value infer tye type mapping from property correctly.
// Later we decompose it back into left/right components so that the equality is not in the tree which can get affected by
// null semantics or other visitor.
Expand Down Expand Up @@ -1280,7 +1285,7 @@ static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out
/// <param name="selectExpression">The select expression to validate.</param>
/// <param name="entityShaperExpression">The entity shaper expression on which the delete operation is being applied.</param>
/// <param name="tableExpression">The table expression from which rows are being deleted.</param>
/// <returns> das </returns>
/// <returns>Returns <see langword="true" /> if the current select expression can be used for delete as-is, <see langword="false" /> otherwise.</returns>
protected virtual bool IsValidSelectExpressionForExecuteDelete(
SelectExpression selectExpression,
EntityShaperExpression entityShaperExpression,
Expand All @@ -1305,13 +1310,12 @@ static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out
return false;
}

// TODO: Update this documentation.
/// <summary>
/// Validates if the current select expression can be used for execute update operation or it requires to be pushed into a subquery.
/// Validates if the current select expression can be used for execute update operation or it requires to be joined as a subquery.
/// </summary>
/// <remarks>
/// <para>
/// By default, only single-table select expressions are supported, and optionally with a predicate.
/// By default, only muli-table select expressions are supported, and optionally with a predicate.
/// </para>
/// <para>
/// Providers can override this to allow more select expression features to be supported without pushing down into a subquery.
Expand All @@ -1322,7 +1326,7 @@ static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out
/// <param name="selectExpression">The select expression to validate.</param>
/// <param name="entityShaperExpression">The entity shaper expression on which the update operation is being applied.</param>
/// <param name="tableExpression">The table expression from which rows are being deleted.</param>
/// <returns> das </returns>
/// <returns>Returns <see langword="true" /> if the current select expression can be used for update as-is, <see langword="false" /> otherwise.</returns>
protected virtual bool IsValidSelectExpressionForExecuteUpdate(
SelectExpression selectExpression,
EntityShaperExpression entityShaperExpression,
Expand All @@ -1334,13 +1338,30 @@ static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out
&& (!selectExpression.IsDistinct || entityShaperExpression.EntityType.FindPrimaryKey() != null)
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.Tables.Count == 1
&& selectExpression.Tables[0] is TableExpression expression)
&& selectExpression.Orderings.Count == 0)
{
tableExpression = expression;
TableExpressionBase table;
if (selectExpression.Tables.Count == 1)
{
table = selectExpression.Tables[0];
}
else
{
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetProjection(projectionBindingExpression);
var column = entityProjectionExpression.BindProperty(entityShaperExpression.EntityType.GetProperties().First());
table = column.Table;
if (table is JoinExpressionBase joinExpressionBase)
{
table = joinExpressionBase.Table;
}
}

return true;
if (table is TableExpression te)
{
tableExpression = te;
return true;
}
}

tableExpression = null;
Expand Down
Expand Up @@ -80,15 +80,14 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression)
var selectExpression = updateExpression.SelectExpression;

if (selectExpression.Offset == null
&& selectExpression.Limit == null
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Tables.Count == 1
&& selectExpression.Tables[0] == updateExpression.Table
&& selectExpression.Projection.Count == 0)
{
Sql.Append("UPDATE ");
GenerateTop(selectExpression);

Sql.AppendLine($"{Dependencies.SqlGenerationHelper.DelimitIdentifier(updateExpression.Table.Alias)}");
using (Sql.Indent())
{
Expand Down
Expand Up @@ -146,6 +146,52 @@ protected override Expression VisitExtension(Expression extensionExpression)
return false;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool IsValidSelectExpressionForExecuteUpdate(
SelectExpression selectExpression,
EntityShaperExpression entityShaperExpression,
[NotNullWhen(true)] out TableExpression? tableExpression)
{
if (selectExpression.Offset == null
// If entity type has primary key then Distinct is no-op
&& (!selectExpression.IsDistinct || entityShaperExpression.EntityType.FindPrimaryKey() != null)
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0)
{
TableExpressionBase table;
if (selectExpression.Tables.Count == 1)
{
table = selectExpression.Tables[0];
}
else
{
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetProjection(projectionBindingExpression);
var column = entityProjectionExpression.BindProperty(entityShaperExpression.EntityType.GetProperties().First());
table = column.Table;
if (table is JoinExpressionBase joinExpressionBase)
{
table = joinExpressionBase.Table;
}
}

if (table is TableExpression te)
{
tableExpression = te;
return true;
}
}

tableExpression = null;
return false;
}

private sealed class TemporalAnnotationApplyingExpressionVisitor : ExpressionVisitor
{
private readonly Func<TableExpression, TableExpressionBase> _annotationApplyingFunc;
Expand Down
Expand Up @@ -343,6 +343,29 @@ public virtual Task Update_where_constant(bool async)
rowsAffectedCount: 8,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Update_where_parameter_in_predicate(bool async)
{
var customer = "ALFKI";
await AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == customer),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

customer = null;
await AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == customer),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 0,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_parameter(bool async)
Expand All @@ -357,6 +380,113 @@ public virtual Task Update_where_parameter(bool async)
(b, a) => a.ForEach(c => Assert.Equal("Abc", c.ContactName)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_take_constant(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F")).Take(4),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 4,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_group_by_aggregate_constant(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>()
.Where(c => c.CustomerID == ss.Set<Order>()
.GroupBy(e => e.CustomerID).Where(g => g.Count() > 11).Select(e => e.Key).FirstOrDefault()),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_group_by_first_constant(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>()
.Where(c => c.CustomerID == ss.Set<Order>()
.GroupBy(e => e.CustomerID).Where(g => g.Count() > 11).Select(e => e.First().CustomerID).FirstOrDefault()),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory(Skip = "Issue#26753")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_group_by_first_constant_2(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>()
.Where(c => c == ss.Set<Order>()
.GroupBy(e => e.CustomerID).Where(g => g.Count() > 11).Select(e => e.First().Customer).FirstOrDefault()),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory(Skip = "Issue#28524")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_group_by_first_constant_3(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>()
.Where(c => ss.Set<Order>()
.GroupBy(e => e.CustomerID).Where(g => g.Count() > 11).Select(e => e.First().Customer).Contains(c)),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_distinct_constant(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F")).Distinct(),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 8,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_using_navigation(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Order>().Where(o => o.Customer.City == "Seattle"),
e => e,
s => s.SetProperty(c => c.OrderDate, c => null),
rowsAffectedCount: 14,
(b, a) => a.ForEach(c => Assert.Null(c.OrderDate)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_using_navigation_2(bool async)
=> AssertUpdate(
async,
ss => ss.Set<OrderDetail>().Where(od => od.Order.Customer.City == "Seattle"),
e => e,
s => s.SetProperty(c => c.Quantity, c => 1),
rowsAffectedCount: 40,
(b, a) => a.ForEach(c => Assert.Equal(1, c.Quantity)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_select_many(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F")).SelectMany(c => c.Orders),
e => e,
s => s.SetProperty(c => c.OrderDate, c => null),
rowsAffectedCount: 63,
(b, a) => a.ForEach(c => Assert.Null(c.OrderDate)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_using_property_plus_constant(bool async)
Expand Down

0 comments on commit 26d4fe5

Please sign in to comment.