Skip to content

Commit

Permalink
Properly implement VisitChildren on PG TVF expressions (#2959)
Browse files Browse the repository at this point in the history
Fixes #2957
  • Loading branch information
roji committed Nov 17, 2023
1 parent 7be1f48 commit 9792445
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 1 deletion.
127 changes: 127 additions & 0 deletions src/EFCore.PG/Extensions/ExpressionVisitorExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
using System.Runtime.CompilerServices;

// ReSharper disable once CheckNamespace
namespace System.Linq.Expressions;

#nullable enable

[DebuggerStepThrough]
internal static class ExpressionVisitorExtensions
{
/// <summary>
/// Dispatches the list of expressions to one of the more specialized visit methods in this class.
/// </summary>
/// <param name="visitor">The expression visitor.</param>
/// <param name="nodes">The expressions to visit.</param>
/// <returns>
/// The modified expression list, if any of the elements were modified; otherwise, returns the original expression list.
/// </returns>
public static IReadOnlyList<Expression> Visit(this ExpressionVisitor visitor, IReadOnlyList<Expression> nodes)
{
Expression[]? newNodes = null;
for (int i = 0, n = nodes.Count; i < n; i++)
{
var node = visitor.Visit(nodes[i]);

if (newNodes is not null)
{
newNodes[i] = node;
}
else if (!ReferenceEquals(node, nodes[i]))
{
newNodes = new Expression[n];
for (var j = 0; j < i; j++)
{
newNodes[j] = nodes[j];
}

newNodes[i] = node;
}
}

return newNodes ?? nodes;
}

/// <summary>
/// Visits an expression, casting the result back to the original expression type.
/// </summary>
/// <typeparam name="T">The type of the expression.</typeparam>
/// <param name="visitor">The expression visitor.</param>
/// <param name="nodes">The expression to visit.</param>
/// <param name="callerName">The name of the calling method; used to report to report a better error message.</param>
/// <returns>
/// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression.
/// </returns>
/// <exception cref="InvalidOperationException">The visit method for this node returned a different type.</exception>
public static IReadOnlyList<T> VisitAndConvert<T>(
this ExpressionVisitor visitor,
IReadOnlyList<T> nodes,
[CallerMemberName] string? callerName = null)
where T : Expression
{
T[]? newNodes = null;
for (int i = 0, n = nodes.Count; i < n; i++)
{
if (visitor.Visit(nodes[i]) is not T node)
{
throw new InvalidOperationException(CoreStrings.MustRewriteToSameNode(callerName, typeof(T).Name));
}

if (newNodes is not null)
{
newNodes[i] = node;
}
else if (!ReferenceEquals(node, nodes[i]))
{
newNodes = new T[n];
for (var j = 0; j < i; j++)
{
newNodes[j] = nodes[j];
}

newNodes[i] = node;
}
}

return newNodes ?? nodes;
}

/// <summary>
/// Visits all nodes in the collection using a specified element visitor.
/// </summary>
/// <typeparam name="T">The type of the nodes.</typeparam>
/// <param name="visitor">The expression visitor.</param>
/// <param name="nodes">The nodes to visit.</param>
/// <param name="elementVisitor">
/// A delegate that visits a single element,
/// optionally replacing it with a new element.
/// </param>
/// <returns>
/// The modified node list, if any of the elements were modified;
/// otherwise, returns the original node list.
/// </returns>
public static IReadOnlyList<T> Visit<T>(this ExpressionVisitor visitor, IReadOnlyList<T> nodes, Func<T, T> elementVisitor)
{
T[]? newNodes = null;
for (int i = 0, n = nodes.Count; i < n; i++)
{
var node = elementVisitor(nodes[i]);
if (newNodes is not null)
{
newNodes[i] = node;
}
else if (!ReferenceEquals(node, nodes[i]))
{
newNodes = new T[n];
for (var j = 0; j < i; j++)
{
newNodes[j] = nodes[j];
}

newNodes[i] = node;
}
}

return newNodes ?? nodes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ public class PgTableValuedFunctionExpression : TableValuedFunctionExpression, IE
WithOrdinality = withOrdinality;
}

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> visitor.VisitAndConvert(Arguments) is var visitedArguments && visitedArguments == Arguments
? this
: new PgTableValuedFunctionExpression(Alias, Name, visitedArguments, ColumnInfos, WithOrdinality);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
/// </summary>
/// <remarks>
/// <para>
/// This expression is just a <see cref="TableValuedFunctionExpression" />, adding the ability to provide an explicit column name
/// This expression is just a <see cref="PgTableValuedFunctionExpression" />, adding the ability to provide an explicit column name
/// for its output (<c>SELECT * FROM unnest(array) AS f(foo)</c>). This is necessary since when the column name isn't explicitly
/// specified, it is automatically identical to the table alias (<c>f</c> above); since the table alias may get uniquified by
/// EF, this would break queries.
Expand Down Expand Up @@ -58,6 +58,12 @@ public PgUnnestExpression(string alias, SqlExpression array, string columnName,
{
}

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> visitor.Visit(Array) is var visitedArray && visitedArray == Array
? this
: new PgUnnestExpression(Alias, (SqlExpression)visitedArray, ColumnName, WithOrdinality);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down

0 comments on commit 9792445

Please sign in to comment.