Skip to content

Commit

Permalink
Add support for nested navigations in filters
Browse files Browse the repository at this point in the history
Part of #12086
  • Loading branch information
AndriySvyryd committed Jun 28, 2019
1 parent a2ce460 commit 84860c8
Show file tree
Hide file tree
Showing 14 changed files with 302 additions and 55 deletions.
3 changes: 2 additions & 1 deletion src/EFCore.Cosmos/Infrastructure/CosmosModelValidator.cs
Expand Up @@ -143,7 +143,8 @@ public override void Validate(IModel model, IDiagnosticsLogger<DbLoggerCategory.
firstEntityType = entityType;
}

if (entityType.ClrType?.IsInstantiable() == true)
if (entityType.ClrType?.IsInstantiable() == true
&& entityType.GetCosmosContainingPropertyName() == null)
{
if (entityType.GetDiscriminatorProperty() == null)
{
Expand Down
Expand Up @@ -465,9 +465,7 @@ protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression so
}

private SqlExpression TranslateExpression(Expression expression)
{
return _sqlTranslator.Translate(expression);
}
=> _sqlTranslator.Translate(expression);

private SqlExpression TranslateLambdaExpression(
ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression)
Expand Down
Expand Up @@ -2,13 +2,12 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
using Microsoft.EntityFrameworkCore.Query.Pipeline;

Expand Down Expand Up @@ -55,7 +54,8 @@ private class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitExtension(Expression node)
{
if (node is SqlExpression sqlExpression)
if (node is SqlExpression sqlExpression
&& !(node is ObjectAccessExpression))
{
if (sqlExpression.TypeMapping == null)
{
Expand All @@ -69,42 +69,73 @@ protected override Expression VisitExtension(Expression node)

protected override Expression VisitMember(MemberExpression memberExpression)
{
if (memberExpression.Expression is EntityShaperExpression)
var innerExpression = Visit(memberExpression.Expression);

if (TryBindProperty(innerExpression, MemberIdentity.Create(memberExpression.Member), out var result))
{
return BindProperty(memberExpression.Expression, memberExpression.Member.GetSimpleMemberName());
return result;
}

var innerExpression = Visit(memberExpression.Expression);

return TranslationFailed(memberExpression.Expression, innerExpression)
? null
: _memberTranslatorProvider.Translate((SqlExpression)innerExpression, memberExpression.Member, memberExpression.Type);
}

private SqlExpression BindProperty(Expression source, string propertyName)
private bool TryBindProperty(Expression source, MemberIdentity member, out SqlExpression expression)
{
if (source is EntityShaperExpression entityShaper)
if (source is EntityShaperExpression entityShaperExpression)
{
var entityType = entityShaper.EntityType;
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
var selectExpression = ((SelectExpression)projectionBindingExpression.QueryExpression);

var entityType = entityShaperExpression.EntityType;
var property = member.MemberInfo != null
? entityType.FindProperty(member.MemberInfo)
: entityType.FindProperty(member.Name);
if (property != null)
{
expression = selectExpression.BindProperty(property, projectionBindingExpression);
return true;
}

return BindProperty(entityShaper, entityType.FindProperty(propertyName));
var navigation = member.MemberInfo != null
? entityType.FindNavigation(member.MemberInfo)
: entityType.FindNavigation(member.Name);
expression = selectExpression.BindNavigation(navigation, projectionBindingExpression);
return true;
}
else if (source is ObjectAccessExpression objectAccessExpression)
{
var entityType = objectAccessExpression.Navigation.GetTargetType();
var property = member.MemberInfo != null
? entityType.FindProperty(member.MemberInfo)
: entityType.FindProperty(member.Name);
if (property != null)
{
expression = new KeyAccessExpression(property, objectAccessExpression);
return true;
}

throw new InvalidOperationException();
}
var navigation = member.MemberInfo != null
? entityType.FindNavigation(member.MemberInfo)
: entityType.FindNavigation(member.Name);
expression = new ObjectAccessExpression(navigation, objectAccessExpression);
return true;
}

private SqlExpression BindProperty(EntityShaperExpression entityShaperExpression, IProperty property)
{
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
return ((SelectExpression)projectionBindingExpression.QueryExpression)
.BindProperty(projectionBindingExpression, property);
expression = null;
return false;
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
return BindProperty(source, propertyName);
if (!TryBindProperty(source, MemberIdentity.Create(propertyName), out var result))
{
throw new InvalidOperationException();
}
return result;
}

//if (methodCallExpression.Method.DeclaringType == typeof(Queryable))
Expand Down
19 changes: 19 additions & 0 deletions src/EFCore.Cosmos/Query/Pipeline/EntityProjectionExpression.cs
Expand Up @@ -14,6 +14,8 @@ public class EntityProjectionExpression : Expression
{
private readonly IDictionary<IProperty, KeyAccessExpression> _propertyExpressionsCache
= new Dictionary<IProperty, KeyAccessExpression>();
private readonly IDictionary<INavigation, ObjectAccessExpression> _navigationExpressionsCache
= new Dictionary<INavigation, ObjectAccessExpression>();
private readonly IEntityType _entityType;

public EntityProjectionExpression(IEntityType entityType, RootReferenceExpression accessExpression, string alias)
Expand Down Expand Up @@ -55,5 +57,22 @@ public KeyAccessExpression GetProperty(IProperty property)

return expression;
}

public ObjectAccessExpression GetNavigation(INavigation navigation)
{
if (!_entityType.GetTypesInHierarchy().Contains(navigation.DeclaringEntityType))
{
throw new InvalidOperationException(
$"Called EntityProjectionExpression.GetNavigation() with incorrect INavigation. EntityType:{_entityType.DisplayName()}, Navigation:{navigation.Name}");
}

if (!_navigationExpressionsCache.TryGetValue(navigation, out var expression))
{
expression = new ObjectAccessExpression(navigation, AccessExpression);
_navigationExpressionsCache[navigation] = expression;
}

return expression;
}
}
}
21 changes: 7 additions & 14 deletions src/EFCore.Cosmos/Query/Pipeline/KeyAccessExpression.cs
Expand Up @@ -11,9 +11,9 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Pipeline
public class KeyAccessExpression : SqlExpression
{
private readonly IProperty _property;
private readonly RootReferenceExpression _outerExpression;
private readonly Expression _outerExpression;

public KeyAccessExpression(IProperty property, RootReferenceExpression outerExpression)
public KeyAccessExpression(IProperty property, Expression outerExpression)
: base(property.ClrType, property.GetTypeMapping())
{
Name = property.GetCosmosPropertyName();
Expand All @@ -25,27 +25,20 @@ public KeyAccessExpression(IProperty property, RootReferenceExpression outerExpr

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var outerExpression = (RootReferenceExpression)visitor.Visit(_outerExpression);
var outerExpression = visitor.Visit(_outerExpression);

return Update(outerExpression);
}

public KeyAccessExpression Update(RootReferenceExpression outerExpression)
{
return outerExpression != _outerExpression
public KeyAccessExpression Update(Expression outerExpression)
=> outerExpression != _outerExpression
? new KeyAccessExpression(_property, outerExpression)
: this;
}

public override void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.StringBuilder.Append(ToString());
}
=> expressionPrinter.StringBuilder.Append(ToString());

public override string ToString()
{
return $"{_outerExpression}[\"{Name}\"]";
}
public override string ToString() => $"{_outerExpression}[\"{Name}\"]";

public override bool Equals(object obj)
=> obj != null
Expand Down
62 changes: 62 additions & 0 deletions src/EFCore.Cosmos/Query/Pipeline/ObjectAccessExpression.cs
@@ -0,0 +1,62 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Pipeline
{
public class ObjectAccessExpression : SqlExpression
{
private readonly Expression _outerExpression;

public ObjectAccessExpression(INavigation navigation, Expression outerExpression)
: base(navigation.ClrType, null)
{
Name = navigation.GetTargetType().GetCosmosContainingPropertyName();
if (Name == null)
{
throw new InvalidOperationException(
$"Navigation '{navigation.DeclaringEntityType.DisplayName()}.{navigation.Name}' doesn't point to a nested entity.");
}

Navigation = navigation;
_outerExpression = outerExpression;
}

public string Name { get; }

public INavigation Navigation { get; }

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var outerExpression = visitor.Visit(_outerExpression);

return Update(outerExpression);
}

public ObjectAccessExpression Update(Expression outerExpression)
=> outerExpression != _outerExpression
? new ObjectAccessExpression(Navigation, outerExpression)
: this;

public override void Print(ExpressionPrinter expressionPrinter) => expressionPrinter.StringBuilder.Append(ToString());

public override string ToString() => $"{_outerExpression}[\"{Name}\"]";

public override bool Equals(object obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is ObjectAccessExpression objectAccessExpression
&& Equals(objectAccessExpression));

private bool Equals(ObjectAccessExpression objectAccessExpression)
=> base.Equals(objectAccessExpression)
&& string.Equals(Name, objectAccessExpression.Name)
&& _outerExpression.Equals(objectAccessExpression._outerExpression);

public override int GetHashCode() => HashCode.Combine(base.GetHashCode(), Name, _outerExpression);
}
}
5 changes: 2 additions & 3 deletions src/EFCore.Cosmos/Query/Pipeline/ProjectionExpression.cs
Expand Up @@ -42,10 +42,9 @@ public void Print(ExpressionPrinter expressionPrinter)
}

private string GetName()
{
return (Expression as KeyAccessExpression)?.Name
=> (Expression as KeyAccessExpression)?.Name
?? (Expression as ObjectAccessExpression)?.Name
?? (Expression as EntityProjectionExpression)?.Alias;
}

public override bool Equals(object obj)
=> obj != null
Expand Down
7 changes: 7 additions & 0 deletions src/EFCore.Cosmos/Query/Pipeline/QuerySqlGenerator.cs
Expand Up @@ -82,6 +82,13 @@ protected override Expression VisitKeyAccess(KeyAccessExpression keyAccessExpres
return keyAccessExpression;
}

protected override Expression VisitObjectAccess(ObjectAccessExpression objectAccessExpression)
{
_sqlBuilder.Append(objectAccessExpression);

return objectAccessExpression;
}

protected override Expression VisitProjection(ProjectionExpression projectionExpression)
{
Visit(projectionExpression.Expression);
Expand Down
10 changes: 6 additions & 4 deletions src/EFCore.Cosmos/Query/Pipeline/SelectExpression.cs
Expand Up @@ -208,11 +208,13 @@ public void ReverseOrderings()
}
}

public SqlExpression BindProperty(ProjectionBindingExpression projectionBindingExpression, IProperty property)
{
return ((EntityProjectionExpression)_projectionMapping[projectionBindingExpression.ProjectionMember])
public SqlExpression BindProperty(IProperty property, ProjectionBindingExpression projectionBindingExpression)
=> ((EntityProjectionExpression)_projectionMapping[projectionBindingExpression.ProjectionMember])
.GetProperty(property);
}

public SqlExpression BindNavigation(INavigation navigation, ProjectionBindingExpression projectionBindingExpression)
=> ((EntityProjectionExpression)_projectionMapping[projectionBindingExpression.ProjectionMember])
.GetNavigation(navigation);

public override Type Type => typeof(JObject);
public override ExpressionType NodeType => ExpressionType.Extension;
Expand Down
1 change: 1 addition & 0 deletions src/EFCore.Cosmos/Query/Pipeline/SqlExpressionVisitor.cs
Expand Up @@ -63,6 +63,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
protected abstract Expression VisitSqlConstant(SqlConstantExpression sqlConstantExpression);
protected abstract Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpression);
protected abstract Expression VisitKeyAccess(KeyAccessExpression keyAccessExpression);
protected abstract Expression VisitObjectAccess(ObjectAccessExpression objectAccessExpression);
protected abstract Expression VisitRootReference(RootReferenceExpression rootReferenceExpression);
protected abstract Expression VisitEntityProjection(EntityProjectionExpression entityProjectionExpression);
protected abstract Expression VisitProjection(ProjectionExpression projectionExpression);
Expand Down
13 changes: 7 additions & 6 deletions src/EFCore/Query/Pipeline/ShapedQueryExpressionVisitor.cs
Expand Up @@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -412,28 +411,30 @@ var discriminatorValue
{
var navigation = nestedShaper.ParentNavigation;
var memberInfo = navigation.GetMemberInfo(forConstruction: true, forSet: true);
var convertedInstanceVariable = memberInfo.DeclaringType.IsAssignableFrom(instanceVariable.Type)
? (Expression)instanceVariable
: Expression.Convert(instanceVariable, memberInfo.DeclaringType);

Expression navigationExpression;
if (navigation.IsCollection())
{
var accessorExpression = Expression.Constant(new ClrCollectionAccessorFactory().Create(navigation));
navigationExpression = Expression.Call(accessorExpression, _accessorAddRangeMethodInfo,
instanceVariable, new CollectionShaperExpression(null, nestedShaper, navigation));
convertedInstanceVariable, new CollectionShaperExpression(null, nestedShaper, navigation));
}
else
{
navigationExpression = Expression.Assign(Expression.MakeMemberAccess(
instanceVariable,
convertedInstanceVariable,
memberInfo),
nestedShaper);
}

var nestedMaterializer = Expression.Condition(
var nestedMaterializer = Expression.IfThen(
Expression.Call(_isAssignableFromMethodInfo,
Expression.Constant(navigation.DeclaringEntityType),
concreteEntityTypeVariable),
navigationExpression,
Expression.Constant(null, navigationExpression.Type));
navigationExpression);

expressions.Add(nestedMaterializer);
}
Expand Down

0 comments on commit 84860c8

Please sign in to comment.