Skip to content

Commit

Permalink
Query: Add condition when accessing property on optional dependent wh…
Browse files Browse the repository at this point in the history
…ich is shared (#25949)

Resolves #23230
  • Loading branch information
smitpatel committed Sep 10, 2021
1 parent 31afd74 commit 2da38e0
Show file tree
Hide file tree
Showing 21 changed files with 397 additions and 225 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -923,8 +923,63 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
return null;
}
var entityProjectionExpression = (EntityProjectionExpression)valueBufferExpression;
var propertyAccess = entityProjectionExpression.BindProperty(property);

return ((EntityProjectionExpression)valueBufferExpression).BindProperty(property);
var entityType = entityReferenceExpression.EntityType;
var table = entityType.GetViewOrTableMappings().FirstOrDefault()?.Table;
if ((table?.IsOptional(entityType)) != true)
{
return propertyAccess;
}

// this is optional dependent sharing table
var nonPrincipalSharedNonPkProperties = entityType.GetNonPrincipalSharedNonPkProperties(table);
if (nonPrincipalSharedNonPkProperties.Contains(property))
{
// The column is not being shared with principal side so we can always use directly
return propertyAccess;
}

SqlExpression? condition = null;
// Property is being shared with principal side, so we need to make it conditional access
var allRequiredNonPkPropertiesCondition = entityType.GetProperties().Where(p => !p.IsNullable && !p.IsPrimaryKey()).ToList();
if (allRequiredNonPkPropertiesCondition.Count > 0)
{
condition = allRequiredNonPkPropertiesCondition.Select(p => entityProjectionExpression.BindProperty(p))
.Select(c => (SqlExpression)_sqlExpressionFactory.NotEqual(c, _sqlExpressionFactory.Constant(null)))
.Aggregate((a, b) => _sqlExpressionFactory.AndAlso(a, b));
}

if (nonPrincipalSharedNonPkProperties.Count != 0
&& nonPrincipalSharedNonPkProperties.All(p => p.IsNullable))
{
// If all non principal shared properties are nullable then we need additional condition
var atLeastOneNonNullValueInNullableColumnsCondition = nonPrincipalSharedNonPkProperties
.Select(p => entityProjectionExpression.BindProperty(p))
.Select(c => (SqlExpression)_sqlExpressionFactory.NotEqual(c, _sqlExpressionFactory.Constant(null)))
.Aggregate((a, b) => _sqlExpressionFactory.OrElse(a, b));

condition = condition == null
? atLeastOneNonNullValueInNullableColumnsCondition
: _sqlExpressionFactory.AndAlso(condition, atLeastOneNonNullValueInNullableColumnsCondition);
}

if (condition == null)
{
// if we cannot compute condition then we just return property access (and hope for the best)
return propertyAccess;
}

return _sqlExpressionFactory.Case(
new List<CaseWhenClause>
{
new CaseWhenClause(condition, propertyAccess)
},
elseResult: null);

// We don't do above processing for subquery entity since it comes from after subquery which has been
// single result so either it is regular entity or a collection which always have their own table.
}

if (entityReferenceExpression.SubqueryEntity != null)
Expand Down
97 changes: 60 additions & 37 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2136,18 +2136,18 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
&& inner.Offset == null
&& inner.Predicate != null)
{
var columnExpressions = new List<ColumnExpression>();
var outerColumnExpressions = new List<SqlExpression>();
var joinPredicate = TryExtractJoinKey(
outer,
inner,
inner.Predicate,
columnExpressions,
outerColumnExpressions,
allowNonEquality,
out var predicate);

if (joinPredicate != null)
{
joinPredicate = RemoveRedundantNullChecks(joinPredicate, columnExpressions);
joinPredicate = RemoveRedundantNullChecks(joinPredicate, outerColumnExpressions);
}
// TODO: verify the case for GroupBy. See issue#24474
// We extract join predicate from Predicate part but GroupBy would have last Having. Changing predicate can change groupings
Expand Down Expand Up @@ -2188,13 +2188,13 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
SelectExpression outer,
SelectExpression inner,
SqlExpression predicate,
List<ColumnExpression> columnExpressions,
List<SqlExpression> outerColumnExpressions,
bool allowNonEquality,
out SqlExpression? updatedPredicate)
{
if (predicate is SqlBinaryExpression sqlBinaryExpression)
{
var joinPredicate = ValidateKeyComparison(outer, inner, sqlBinaryExpression, columnExpressions, allowNonEquality);
var joinPredicate = ValidateKeyComparison(outer, inner, sqlBinaryExpression, outerColumnExpressions, allowNonEquality);
if (joinPredicate != null)
{
updatedPredicate = null;
Expand All @@ -2205,9 +2205,9 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso)
{
var leftJoinKey = TryExtractJoinKey(
outer, inner, sqlBinaryExpression.Left, columnExpressions, allowNonEquality, out var leftPredicate);
outer, inner, sqlBinaryExpression.Left, outerColumnExpressions, allowNonEquality, out var leftPredicate);
var rightJoinKey = TryExtractJoinKey(
outer, inner, sqlBinaryExpression.Right, columnExpressions, allowNonEquality, out var rightPredicate);
outer, inner, sqlBinaryExpression.Right, outerColumnExpressions, allowNonEquality, out var rightPredicate);

updatedPredicate = CombineNonNullExpressions(leftPredicate, rightPredicate);

Expand All @@ -2224,7 +2224,7 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
SelectExpression outer,
SelectExpression inner,
SqlBinaryExpression sqlBinaryExpression,
List<ColumnExpression> columnExpressions,
List<SqlExpression> outerColumnExpressions,
bool allowNonEquality)
{
if (sqlBinaryExpression.OperatorType == ExpressionType.Equal
Expand All @@ -2235,45 +2235,39 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThan
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThanOrEqual)))
{
if (sqlBinaryExpression.Left is ColumnExpression leftColumn
&& sqlBinaryExpression.Right is ColumnExpression rightColumn)
if (IsContainedColumn(outer, sqlBinaryExpression.Left)
&& IsContainedColumn(inner, sqlBinaryExpression.Right))
{
if (outer.ContainsTableReference(leftColumn)
&& inner.ContainsTableReference(rightColumn))
{
columnExpressions.Add(leftColumn);
outerColumnExpressions.Add(sqlBinaryExpression.Left);

return sqlBinaryExpression;
}
return sqlBinaryExpression;
}

if (outer.ContainsTableReference(rightColumn)
&& inner.ContainsTableReference(leftColumn))
{
columnExpressions.Add(rightColumn);

return new SqlBinaryExpression(
_mirroredOperationMap[sqlBinaryExpression.OperatorType],
sqlBinaryExpression.Right,
sqlBinaryExpression.Left,
sqlBinaryExpression.Type,
sqlBinaryExpression.TypeMapping);
}
if (IsContainedColumn(outer, sqlBinaryExpression.Right)
&& IsContainedColumn(inner, sqlBinaryExpression.Left))
{
outerColumnExpressions.Add(sqlBinaryExpression.Right);

return new SqlBinaryExpression(
_mirroredOperationMap[sqlBinaryExpression.OperatorType],
sqlBinaryExpression.Right,
sqlBinaryExpression.Left,
sqlBinaryExpression.Type,
sqlBinaryExpression.TypeMapping);
}
}

// null checks are considered part of join key
if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)
{
if (sqlBinaryExpression.Left is ColumnExpression leftNullCheckColumn
&& outer.ContainsTableReference(leftNullCheckColumn)
if (IsContainedColumn(outer, sqlBinaryExpression.Left)
&& sqlBinaryExpression.Right is SqlConstantExpression rightConstant
&& rightConstant.Value == null)
{
return sqlBinaryExpression;
}

if (sqlBinaryExpression.Right is ColumnExpression rightNullCheckColumn
&& outer.ContainsTableReference(rightNullCheckColumn)
if (IsContainedColumn(outer, sqlBinaryExpression.Right)
&& sqlBinaryExpression.Left is SqlConstantExpression leftConstant
&& leftConstant.Value == null)
{
Expand All @@ -2286,6 +2280,36 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
return null;
}

static bool IsContainedColumn(SelectExpression selectExpression, SqlExpression sqlExpression)
{
switch (sqlExpression)
{
case ColumnExpression columnExpression:
return selectExpression.ContainsTableReference(columnExpression);

case SqlConstantExpression sqlConstantExpression
when sqlConstantExpression.Value == null:
return true;

case SqlBinaryExpression sqlBinaryExpression
when sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
|| sqlBinaryExpression.OperatorType == ExpressionType.OrElse
|| sqlBinaryExpression.OperatorType == ExpressionType.NotEqual:
return IsContainedColumn(selectExpression, sqlBinaryExpression.Left)
&& IsContainedColumn(selectExpression, sqlBinaryExpression.Right);

case CaseExpression caseExpression
when caseExpression.ElseResult == null
&& caseExpression.Operand == null
&& caseExpression.WhenClauses.Count == 1:
return IsContainedColumn(selectExpression, caseExpression.WhenClauses[0].Test)
&& IsContainedColumn(selectExpression, caseExpression.WhenClauses[0].Result);

default:
return false;
}
}

static void InnerKeyColumns(IEnumerable<TableExpressionBase> tables, SqlExpression joinPredicate, List<ColumnExpression> resultColumns)
{
if (joinPredicate is SqlBinaryExpression sqlBinaryExpression)
Expand Down Expand Up @@ -2334,13 +2358,12 @@ static List<ColumnExpression> ExtractColumnsFromProjectionMapping(IDictionary<Pr
: left
: right;

static SqlExpression? RemoveRedundantNullChecks(SqlExpression predicate, List<ColumnExpression> columnExpressions)
static SqlExpression? RemoveRedundantNullChecks(SqlExpression predicate, List<SqlExpression> outerColumnExpressions)
{
if (predicate is SqlBinaryExpression sqlBinaryExpression)
{
if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual
&& sqlBinaryExpression.Left is ColumnExpression leftColumn
&& columnExpressions.Contains(leftColumn)
&& outerColumnExpressions.Contains(sqlBinaryExpression.Left)
&& sqlBinaryExpression.Right is SqlConstantExpression sqlConstantExpression
&& sqlConstantExpression.Value == null)
{
Expand All @@ -2349,8 +2372,8 @@ static List<ColumnExpression> ExtractColumnsFromProjectionMapping(IDictionary<Pr

if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso)
{
var leftPredicate = RemoveRedundantNullChecks(sqlBinaryExpression.Left, columnExpressions);
var rightPredicate = RemoveRedundantNullChecks(sqlBinaryExpression.Right, columnExpressions);
var leftPredicate = RemoveRedundantNullChecks(sqlBinaryExpression.Left, outerColumnExpressions);
var rightPredicate = RemoveRedundantNullChecks(sqlBinaryExpression.Right, outerColumnExpressions);

return CombineNonNullExpressions(leftPredicate, rightPredicate);
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Xunit;

namespace Microsoft.EntityFrameworkCore.Query
{
public abstract class ComplexNavigationsCollectionsSharedTypeQueryRelationalTestBase<TFixture> : ComplexNavigationsCollectionsSharedTypeQueryTestBase<TFixture>
where TFixture : ComplexNavigationsSharedTypeQueryRelationalFixtureBase, new()
{
protected ComplexNavigationsCollectionsSharedTypeQueryRelationalTestBase(TFixture fixture)
: base(fixture)
{
}

public override async Task SelectMany_with_navigation_and_Distinct_projecting_columns_including_join_key(bool async)
{
Assert.Equal(
RelationalStrings.InsufficientInformationToIdentifyOuterElementOfCollectionJoin,
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.SelectMany_with_navigation_and_Distinct_projecting_columns_including_join_key(async))).Message);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Xunit;

namespace Microsoft.EntityFrameworkCore.Query
{
public abstract class ComplexNavigationsCollectionsSplitSharedQueryTypeRelationalTestBase<TFixture> : ComplexNavigationsCollectionsSharedTypeQueryTestBase<TFixture>
public abstract class ComplexNavigationsCollectionsSplitSharedTypeQueryRelationalTestBase<TFixture> : ComplexNavigationsCollectionsSharedTypeQueryTestBase<TFixture>
where TFixture : ComplexNavigationsSharedTypeQueryRelationalFixtureBase, new()
{
protected ComplexNavigationsCollectionsSplitSharedQueryTypeRelationalTestBase(TFixture fixture)
protected ComplexNavigationsCollectionsSplitSharedTypeQueryRelationalTestBase(TFixture fixture)
: base(fixture)
{
}

public override async Task SelectMany_with_navigation_and_Distinct_projecting_columns_including_join_key(bool async)
{
Assert.Equal(
RelationalStrings.InsufficientInformationToIdentifyOuterElementOfCollectionJoin,
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.SelectMany_with_navigation_and_Distinct_projecting_columns_including_join_key(async))).Message);
}

protected override Expression RewriteServerQueryExpression(Expression serverQueryExpression)
=> new SplitQueryRewritingExpressionVisitor().Visit(serverQueryExpression);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

namespace Microsoft.EntityFrameworkCore.Query
{
public abstract class ComplexNavigationsSharedQueryTypeRelationalTestBase<TFixture> : ComplexNavigationsSharedTypeQueryTestBase<TFixture>
public abstract class ComplexNavigationsSharedTypeQueryRelationalTestBase<TFixture> : ComplexNavigationsSharedTypeQueryTestBase<TFixture>
where TFixture : ComplexNavigationsSharedTypeQueryRelationalFixtureBase, new()
{
protected ComplexNavigationsSharedQueryTypeRelationalTestBase(TFixture fixture)
protected ComplexNavigationsSharedTypeQueryRelationalTestBase(TFixture fixture)
: base(fixture)
{
}
Expand All @@ -19,6 +19,7 @@ public override Task Complex_query_with_optional_navigations_and_client_side_eva
return AssertTranslationFailed(() => base.Complex_query_with_optional_navigations_and_client_side_evaluation(async));
}


protected virtual bool CanExecuteQueryString
=> false;

Expand Down

0 comments on commit 2da38e0

Please sign in to comment.