From dd96eaab4fcbf00920cb28e8cb50870497e30331 Mon Sep 17 00:00:00 2001 From: Maurycy Markowski Date: Thu, 16 Jun 2016 16:42:22 -0700 Subject: [PATCH] fd --- .../RelationalResultOperatorHandler.cs | 43 ++++++------------- .../NavigationRewritingExpressionVisitor.cs | 18 ++++---- 2 files changed, 22 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs b/src/Microsoft.EntityFrameworkCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs index 120628bf7c6..04a74620bad 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs @@ -71,6 +71,15 @@ public Expression EvalOnClient(bool requiresClientResultOperator = true) return _resultOperatorHandler .HandleResultOperator(QueryModelVisitor, ResultOperator, QueryModel); } + + public SqlTranslatingExpressionVisitor CreateSqlTranslatingVisitor(bool bindParentQueries = false) + { + return SqlTranslatingExpressionVisitorFactory + .Create( + QueryModelVisitor, + SelectExpression, + bindParentQueries: bindParentQueries); + } } private static readonly Dictionary> @@ -166,11 +175,7 @@ var handlerContext private static Expression HandleAll(HandlerContext handlerContext) { - var filteringVisitor - = handlerContext.SqlTranslatingExpressionVisitorFactory - .Create( - handlerContext.QueryModelVisitor, - handlerContext.SelectExpression); + var filteringVisitor = handlerContext.CreateSqlTranslatingVisitor(); var predicate = filteringVisitor.Visit( @@ -229,17 +234,11 @@ private static Expression HandleCast(HandlerContext handlerContext) private static Expression HandleContains(HandlerContext handlerContext) { - var filteringVisitor - = handlerContext.SqlTranslatingExpressionVisitorFactory - .Create( - handlerContext.QueryModelVisitor, - handlerContext.SelectExpression, - bindParentQueries: true); + var filteringVisitor = handlerContext.CreateSqlTranslatingVisitor(bindParentQueries: true); var itemResultOperator = (ContainsResultOperator)handlerContext.ResultOperator; var item = filteringVisitor.Visit(itemResultOperator.Item); - if (item != null) { var itemSelectExpression = item as SelectExpression; @@ -431,11 +430,7 @@ private static Expression HandleFirst(HandlerContext handlerContext) private static Expression HandleGroup(HandlerContext handlerContext) { - var sqlTranslatingExpressionVisitor - = handlerContext.SqlTranslatingExpressionVisitorFactory - .Create( - handlerContext.QueryModelVisitor, - handlerContext.SelectExpression); + var sqlTranslatingExpressionVisitor = handlerContext.CreateSqlTranslatingVisitor(); var groupResultOperator = (GroupResultOperator)handlerContext.ResultOperator; @@ -650,12 +645,7 @@ private static Expression HandleSkip(HandlerContext handlerContext) { var skipResultOperator = (SkipResultOperator)handlerContext.ResultOperator; - var sqlTranslatingExpressionVisitor - = handlerContext.SqlTranslatingExpressionVisitorFactory - .Create( - handlerContext.QueryModelVisitor, - handlerContext.SelectExpression, - bindParentQueries: true); + var sqlTranslatingExpressionVisitor = handlerContext.CreateSqlTranslatingVisitor(bindParentQueries: true); var offset = sqlTranslatingExpressionVisitor.Visit(skipResultOperator.Count); if (offset != null) @@ -689,12 +679,7 @@ private static Expression HandleTake(HandlerContext handlerContext) { var takeResultOperator = (TakeResultOperator)handlerContext.ResultOperator; - var sqlTranslatingExpressionVisitor - = handlerContext.SqlTranslatingExpressionVisitorFactory - .Create( - handlerContext.QueryModelVisitor, - handlerContext.SelectExpression, - bindParentQueries: true); + var sqlTranslatingExpressionVisitor = handlerContext.CreateSqlTranslatingVisitor(bindParentQueries: true); var limit = sqlTranslatingExpressionVisitor.Visit(takeResultOperator.Count); if (limit != null) diff --git a/src/Microsoft.EntityFrameworkCore/Query/ExpressionVisitors/Internal/NavigationRewritingExpressionVisitor.cs b/src/Microsoft.EntityFrameworkCore/Query/ExpressionVisitors/Internal/NavigationRewritingExpressionVisitor.cs index 027e5a04a53..5c500378f33 100644 --- a/src/Microsoft.EntityFrameworkCore/Query/ExpressionVisitors/Internal/NavigationRewritingExpressionVisitor.cs +++ b/src/Microsoft.EntityFrameworkCore/Query/ExpressionVisitors/Internal/NavigationRewritingExpressionVisitor.cs @@ -964,11 +964,9 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { - var originalType = whereClause.Predicate.Type; - base.VisitWhereClause(whereClause, queryModel, index); - if (originalType == typeof(bool) && whereClause.Predicate.Type == typeof(bool?)) + if (whereClause.Predicate.Type == typeof(bool?)) { whereClause.Predicate = Expression.Convert(whereClause.Predicate, typeof(bool)); } @@ -984,12 +982,12 @@ public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel Debug.Assert(originalTypes.Count == newTypes.Count); - for (int i = 0; i < newTypes.Count; i++) + for (var i = 0; i < newTypes.Count; i++) { - if (originalTypes[i] != newTypes[i] + if ((originalTypes[i] != newTypes[i]) && !originalTypes[i].IsNullableType() && newTypes[i].IsNullableType() - && originalTypes[i].UnwrapNullableType() == newTypes[i].UnwrapNullableType()) + && (originalTypes[i].UnwrapNullableType() == newTypes[i].UnwrapNullableType())) { orderByClause.Orderings[i].Expression = Expression.Convert(orderByClause.Orderings[i].Expression, originalTypes[i]); } @@ -1041,10 +1039,10 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que base.VisitSelectClause(selectClause, queryModel); var newType = selectClause.Selector.Type; - if (originalType != newType + if ((originalType != newType) && !originalType.IsNullableType() && newType.IsNullableType() - && originalType.UnwrapNullableType() == newType.UnwrapNullableType()) + && (originalType.UnwrapNullableType() == newType.UnwrapNullableType())) { selectClause.Selector = Expression.Convert(selectClause.Selector, originalType); } @@ -1139,10 +1137,10 @@ public override void VisitResultOperator(ResultOperatorBase resultOperator, Quer var translatedExpression = TransformingVisitor.Visit(originalExpression); var newType = translatedExpression.Type; - if (originalType != newType + if ((originalType != newType) && !originalType.IsNullableType() && newType.IsNullableType() - && originalType.UnwrapNullableType() == newType.UnwrapNullableType()) + && (originalType.UnwrapNullableType() == newType.UnwrapNullableType())) { translatedExpression = Expression.Convert(translatedExpression, originalType); }