Skip to content

Commit

Permalink
Fix to #5454 - Error when using Where with reference
Browse files Browse the repository at this point in the history
Problem was for happening for queries with optional navigations like so:

context.Orders.Where(o => o.Customer.IsVip)

In this case Customer can be nullable, so IsVip can also be nullable. We compensate for this by introducing the following:

context.Orders.Where(o => o.Customer != null ? (bool?)o.Customer.IsVip : (bool?)null)

However, we didn't convert it back to the original type (users had to introduce those casts themselves). Without the cast, those queries were throwing compile-time exceptions that were not very informative.

Fix is to cast back to the original type requested by user:

context.Orders.Where(o => (bool)(o.Customer != null ? (bool?)o.Customer.IsVip : (bool?)null))

This still may cause runtime errors if o.Customer is actually null, but those are much more understandable now.

Also fixed compilation errors for complex Skip/Take arguments. Those are evaluated on the client for the time being.

CR: Andrew, Smit
  • Loading branch information
maumar authored and AndriySvyryd committed Jul 1, 2016
1 parent 5a2372e commit c499000
Show file tree
Hide file tree
Showing 6 changed files with 539 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class RelationalResultOperatorHandler : IResultOperatorHandler
private sealed class HandlerContext
{
private readonly IResultOperatorHandler _resultOperatorHandler;
private readonly ISqlTranslatingExpressionVisitorFactory _sqlTranslatingExpressionVisitorFactory;

public HandlerContext(
IResultOperatorHandler resultOperatorHandler,
Expand All @@ -43,10 +44,10 @@ private sealed class HandlerContext
SelectExpression selectExpression)
{
_resultOperatorHandler = resultOperatorHandler;
_sqlTranslatingExpressionVisitorFactory = sqlTranslatingExpressionVisitorFactory;

Model = model;
RelationalAnnotationProvider = relationalAnnotationProvider;
SqlTranslatingExpressionVisitorFactory = sqlTranslatingExpressionVisitorFactory;
SelectExpressionFactory = selectExpressionFactory;
QueryModelVisitor = queryModelVisitor;
ResultOperator = resultOperator;
Expand All @@ -56,7 +57,6 @@ private sealed class HandlerContext

public IModel Model { get; }
public IRelationalAnnotationProvider RelationalAnnotationProvider { get; }
public ISqlTranslatingExpressionVisitorFactory SqlTranslatingExpressionVisitorFactory { get; }
public ISelectExpressionFactory SelectExpressionFactory { get; }
public ResultOperatorBase ResultOperator { get; }
public SelectExpression SelectExpression { get; }
Expand All @@ -71,6 +71,15 @@ public Expression EvalOnClient(bool requiresClientResultOperator = true)
return _resultOperatorHandler
.HandleResultOperator(QueryModelVisitor, ResultOperator, QueryModel);
}

public SqlTranslatingExpressionVisitor CreateSqlTranslatingVisitor(bool bindParentQueries = false)
{
return _sqlTranslatingExpressionVisitorFactory
.Create(
QueryModelVisitor,
SelectExpression,
bindParentQueries: bindParentQueries);
}
}

private static readonly Dictionary<Type, Func<HandlerContext, Expression>>
Expand Down Expand Up @@ -166,11 +175,7 @@ var handlerContext

private static Expression HandleAll(HandlerContext handlerContext)
{
var filteringVisitor
= handlerContext.SqlTranslatingExpressionVisitorFactory
.Create(
handlerContext.QueryModelVisitor,
handlerContext.SelectExpression);
var filteringVisitor = handlerContext.CreateSqlTranslatingVisitor();

var predicate
= filteringVisitor.Visit(
Expand Down Expand Up @@ -229,17 +234,11 @@ private static Expression HandleCast(HandlerContext handlerContext)

private static Expression HandleContains(HandlerContext handlerContext)
{
var filteringVisitor
= handlerContext.SqlTranslatingExpressionVisitorFactory
.Create(
handlerContext.QueryModelVisitor,
handlerContext.SelectExpression,
bindParentQueries: true);
var filteringVisitor = handlerContext.CreateSqlTranslatingVisitor(bindParentQueries: true);

var itemResultOperator = (ContainsResultOperator)handlerContext.ResultOperator;

var item = filteringVisitor.Visit(itemResultOperator.Item);

if (item != null)
{
var itemSelectExpression = item as SelectExpression;
Expand Down Expand Up @@ -431,11 +430,7 @@ private static Expression HandleFirst(HandlerContext handlerContext)

private static Expression HandleGroup(HandlerContext handlerContext)
{
var sqlTranslatingExpressionVisitor
= handlerContext.SqlTranslatingExpressionVisitorFactory
.Create(
handlerContext.QueryModelVisitor,
handlerContext.SelectExpression);
var sqlTranslatingExpressionVisitor = handlerContext.CreateSqlTranslatingVisitor();

var groupResultOperator = (GroupResultOperator)handlerContext.ResultOperator;

Expand Down Expand Up @@ -650,9 +645,17 @@ private static Expression HandleSkip(HandlerContext handlerContext)
{
var skipResultOperator = (SkipResultOperator)handlerContext.ResultOperator;

handlerContext.SelectExpression.Offset = skipResultOperator.Count;
var sqlTranslatingExpressionVisitor = handlerContext.CreateSqlTranslatingVisitor(bindParentQueries: true);

return handlerContext.EvalOnServer;
var offset = sqlTranslatingExpressionVisitor.Visit(skipResultOperator.Count);
if (offset != null)
{
handlerContext.SelectExpression.Offset = offset;

return handlerContext.EvalOnServer;
}

return handlerContext.EvalOnClient();
}

private static Expression HandleSum(HandlerContext handlerContext)
Expand All @@ -676,9 +679,17 @@ private static Expression HandleTake(HandlerContext handlerContext)
{
var takeResultOperator = (TakeResultOperator)handlerContext.ResultOperator;

handlerContext.SelectExpression.Limit = takeResultOperator.Count;
var sqlTranslatingExpressionVisitor = handlerContext.CreateSqlTranslatingVisitor(bindParentQueries: true);

return handlerContext.EvalOnServer;
var limit = sqlTranslatingExpressionVisitor.Visit(takeResultOperator.Count);
if (limit != null)
{
handlerContext.SelectExpression.Limit = takeResultOperator.Count;

return handlerContext.EvalOnServer;
}

return handlerContext.EvalOnClient();
}

private static void SetProjectionConditionalExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,125 @@ public virtual void Coalesce_operator_in_projection_with_other_conditions()
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_predicate()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A." && t.Gear.HasSoulPatch);
var result = query.ToList();

Assert.Equal(2, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_projection()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").Select(t => t.Gear.SquadId);
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_projection_into_anonymous_type()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").Select(t => new { t.Gear.SquadId });
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_orderby()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").OrderBy(t => t.Gear.SquadId);
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_groupby()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").GroupBy(t => t.Gear.SquadId);
var result = query.ToList();

Assert.Equal(2, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_all()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").All(t => t.Gear.HasSoulPatch);

Assert.False(query);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_contains()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A." && context.Gears.Select(g => g.SquadId).Contains(t.Gear.SquadId));
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_skip()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").Select(t => context.Gears.Skip(t.Gear.SquadId));
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_take()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").Select(t => context.Gears.Take(t.Gear.SquadId));
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_throws_rasonable_exception_for_nullable_values()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Gear.HasSoulPatch);

//Nullable object must have a value
Assert.Throws<InvalidOperationException>(() => query.ToList());
}
}

protected GearsOfWarContext CreateContext() => Fixture.CreateContext(TestStore);

protected GearsOfWarQueryTestBase(TFixture fixture)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public Gear()
public string LeaderNickname { get; set; }
public int LeaderSquadId { get; set; }

public bool HasSoulPatch { get; set; }

[NotMapped]
public bool IsMarcus => Nickname == "Marcus";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ public static void Seed(GearsOfWarContext context)
{
Nickname = "Dom",
FullName = "Dominic Santiago",
HasSoulPatch = false,
SquadId = deltaSquad.Id,
Rank = MilitaryRank.Corporal,
AssignedCity = ephyra,
Expand All @@ -201,6 +202,7 @@ public static void Seed(GearsOfWarContext context)
{
Nickname = "Cole Train",
FullName = "Augustus Cole",
HasSoulPatch = false,
SquadId = deltaSquad.Id,
Rank = MilitaryRank.Private,
CityOrBirthName = hanover.Name,
Expand All @@ -213,6 +215,7 @@ public static void Seed(GearsOfWarContext context)
{
Nickname = "Paduk",
FullName = "Garron Paduk",
HasSoulPatch = false,
SquadId = kiloSquad.Id,
Rank = MilitaryRank.Private,
CityOrBirthName = unknown.Name,
Expand All @@ -224,6 +227,7 @@ public static void Seed(GearsOfWarContext context)
{
Nickname = "Baird",
FullName = "Damon Baird",
HasSoulPatch = true,
SquadId = deltaSquad.Id,
Rank = MilitaryRank.Corporal,
CityOrBirthName = unknown.Name,
Expand All @@ -237,6 +241,7 @@ public static void Seed(GearsOfWarContext context)
{
Nickname = "Marcus",
FullName = "Marcus Fenix",
HasSoulPatch = true,
SquadId = deltaSquad.Id,
Rank = MilitaryRank.Sergeant,
CityOrBirthName = jacinto.Name,
Expand Down
Loading

0 comments on commit c499000

Please sign in to comment.