Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 92 additions & 27 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ private readonly IDictionary<EntityProjectionExpression, IDictionary<IProperty,
private readonly List<TableExpressionBase> _tables = new List<TableExpressionBase>();
private readonly List<SqlExpression> _groupBy = new List<SqlExpression>();
private readonly List<OrderingExpression> _orderings = new List<OrderingExpression>();

private readonly List<SqlExpression> _identifier = new List<SqlExpression>();
private readonly List<SqlExpression> _childIdentifiers = new List<SqlExpression>();
private readonly List<SelectExpression> _pendingCollections = new List<SelectExpression>();
Expand Down Expand Up @@ -1244,10 +1243,33 @@ when _subquery.ContainsTableReference(columnExpression.Table):

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
// If we're visiting this expression a second time, the result of the previous visitation will be stored
// here - simply return it to avoid full double visitation.
if (VisitedExpression != null)
{
return VisitedExpression;
}

// TODO: Only needed with nested visitors
var previousTableVisitedExpressions = new TableExpressionBase[_tables.Count];

// We have to do in-place mutation till we have applied pending collections because of shaper references
// This is pseudo finalization phase for select expression.
if (_pendingCollections.Any(e => e != null))
{
var tables = _tables.ToList();
_tables.Clear();
for (var i = 0; i < tables.Count; i++)
{
var table = tables[i];
var newTable = VisitTable(table, i);
_tables.Add(newTable);
if (table is SelectExpression subqueryExpression)
{
subqueryExpression.VisitedExpression = newTable;
}
}

if (Projection.Any())
{
var projections = _projection.ToList();
Expand All @@ -1267,10 +1289,6 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
_projectionMapping = projectionMapping;
}

var tables = _tables.ToList();
_tables.Clear();
_tables.AddRange(tables.Select(e => (TableExpressionBase)visitor.Visit(e)));

Predicate = (SqlExpression)visitor.Visit(Predicate);

var groupBy = _groupBy.ToList();
Expand All @@ -1288,12 +1306,54 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
Offset = (SqlExpression)visitor.Visit(Offset);
Limit = (SqlExpression)visitor.Visit(Limit);

// Restore the previous VisitedExpression references for nested visitors
for (var i = 0; i < _tables.Count; i++)
{
UnwrapJoin(_tables[i]).VisitedExpression = previousTableVisitedExpressions[i];
}

return this;
}
else
{
var changed = false;

var newTables = _tables;
for (var i = 0; i < _tables.Count; i++)
{
var table = _tables[i];
var newTable = VisitTable(table, i);

// // Set the table's VisitedExpression to null so that we can actually visit it. Unwrap joins since the contained
// // table is the one actually referenced from other SelectExpression components. We restore the previous visitation
// // before leaving the method in case more than one visitor is active.
// var referencedTable = UnwrapJoin(table);
// previousTableVisitedExpressions[i] = referencedTable.VisitedExpression;
// referencedTable.VisitedExpression = null;
//
// var newTable = (TableExpressionBase)visitor.Visit(table);
//
// // Reference the visited table from the un-visited one so that subsequent visitations (e.g. from projections) will
// // rewire to it and avoid multiple deep visitations.
// referencedTable.VisitedExpression = UnwrapJoin(newTable);

if (newTable != table
&& newTables == _tables)
{
newTables = new List<TableExpressionBase>(_tables.Count);
for (var j = 0; j < i; j++)
{
newTables.Add(_tables[j]);
}
changed = true;
}

if (newTables != _tables)
{
newTables.Add(newTable);
}
}

var newProjections = _projection;
var newProjectionMapping = _projectionMapping;
if (_projection.Any())
Expand Down Expand Up @@ -1338,28 +1398,6 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
}
}

var newTables = _tables;
for (var i = 0; i < _tables.Count; i++)
{
var table = _tables[i];
var newTable = (TableExpressionBase)visitor.Visit(table);
if (newTable != table
&& newTables == _tables)
{
newTables = new List<TableExpressionBase>(_tables.Count);
for (var j = 0; j < i; j++)
{
newTables.Add(_tables[j]);
}
changed = true;
}

if (newTables != _tables)
{
newTables.Add(newTable);
}
}

var predicate = (SqlExpression)visitor.Visit(Predicate);
changed |= predicate != Predicate;

Expand Down Expand Up @@ -1422,6 +1460,12 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
var limit = (SqlExpression)visitor.Visit(Limit);
changed |= limit != Limit;

// Restore the previous VisitedExpression references for nested visitors
for (var i = 0; i < _tables.Count; i++)
{
UnwrapJoin(_tables[i]).VisitedExpression = previousTableVisitedExpressions[i];
}

if (changed)
{
var newSelectExpression = new SelectExpression(Alias, newProjections, newTables, newGroupBy, newOrderings)
Expand All @@ -1442,6 +1486,27 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)

return this;
}

static TableExpressionBase UnwrapJoin(TableExpressionBase table) =>
table is JoinExpressionBase join ? join.Table : table;

TableExpressionBase VisitTable(TableExpressionBase table, int tableIndex)
{
// Set the table's VisitedExpression to null so that we can actually visit it. Unwrap joins since the contained
// table is the one actually referenced from other SelectExpression components. We restore the previous visitation
// before leaving the method in case more than one visitor is active.
var referencedTable = UnwrapJoin(table);
previousTableVisitedExpressions[tableIndex] = referencedTable.VisitedExpression;
referencedTable.VisitedExpression = null;

var newTable = (TableExpressionBase)visitor.Visit(table);

// Reference the visited table from the un-visited one so that subsequent visitations (e.g. from projections) will
// rewire to it and avoid multiple deep visitations.
referencedTable.VisitedExpression = UnwrapJoin(newTable);

return newTable;
}
}

public override bool Equals(object obj)
Expand Down
10 changes: 10 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressions/TableExpressionBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ protected TableExpressionBase([CanBeNull] string alias)

public virtual string Alias { get; internal set; }

/// <summary>
/// Populated after this expression is first visited from its containing <see cref="SelectExpression"/>, to make sure that
/// subsequent visits return the same instance, and to prevent needless multiple deep visits. Used only by visitors.
/// </summary>
/// <remarks>
/// If you implement an expression visitor which contains specific logic for visiting <see cref="TableExpressionBase"/>,
/// you should properly populate and check this field (<see cref="SelectExpression.VisitChildren"/> for an example).
/// </remarks>
public TableExpressionBase VisitedExpression { get; set; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No public surface like this. This is just plain wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered this - but we need something that will work for completely external visitors, e.g. provider-added postprocessors (such as the one in the test).

Any specific better ideas? Any feedback on the general approach and on the rest?

Copy link
Member Author

@roji roji Sep 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PS How was this solved in the old pipeline? Seems like in the old pipeline SelectExpression.VisitChildren could neither mutate the instance nor return a modified copy...


protected override Expression VisitChildren(ExpressionVisitor visitor) => this;

public override Type Type => typeof(object);
Expand Down
61 changes: 61 additions & 0 deletions test/EFCore.Relational.Tests/Query/SelectExpressionTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;

namespace Microsoft.EntityFrameworkCore.Query
{
public class SelectExpressionTest
{
[ConditionalFact]
public void Table_referential_integrity_is_preserved()
{
var model = CreateModel();
var property = model.FindEntityType(typeof(Foo)).FindProperty("Id");
var table = new TableExpression("SomeTable", null, "t");
var ordering = new OrderingExpression(new ColumnExpression(property, table, false), true);

var select = new SelectExpression(
"s",
new List<ProjectionExpression>(),
new List<TableExpressionBase> { table },
new List<SqlExpression>(),
new List<OrderingExpression> { ordering });

var visitor = new TableSwitchingExpressionVisitor();
var visitedSelect = (SelectExpression)visitor.Visit(select);
Assert.Same(visitedSelect.Tables[0], ((ColumnExpression)visitedSelect.Orderings[0].Expression).Table);
}

private class TableSwitchingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitExtension(Expression node)
{
if (node is TableExpression tableExpression)
{
return tableExpression.VisitedExpression ?? new TableExpression(
tableExpression.Name + "2", tableExpression.Schema, tableExpression.Alias);
}
return base.VisitExtension(node);
}
}

protected IMutableModel CreateModel()
{
var builder = RelationalTestHelpers.Instance.CreateConventionBuilder();
builder.Entity<Foo>();
builder.FinalizeModel();
return builder.Model;
}

private class Foo
{
public int Id { get; set; }
}
}
}