-
Notifications
You must be signed in to change notification settings - Fork 485
/
ChainFlatteningRewriter.cs
118 lines (101 loc) · 6.01 KB
/
ChainFlatteningRewriter.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// -------------------------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
// -------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Linq;
using EnsureThat;
using Microsoft.Health.Fhir.Core.Features.Search.Expressions;
using Microsoft.Health.Fhir.SqlServer.Features.Search.Expressions.Visitors.QueryGenerators;
namespace Microsoft.Health.Fhir.SqlServer.Features.Search.Expressions.Visitors
{
/// <summary>
/// Flattens chained expressions into <see cref="SqlRootExpression"/>'s <see cref="SqlRootExpression.TableExpressions"/> list.
/// The expression within a chained expression is promoted to a top-level table expression, but we keep track of the height
/// via the <see cref="TableExpression.ChainLevel"/>.
/// </summary>
/// <remarks>
/// The visitor context is a pair:
/// <list type="bullet">
/// <item><description>the table expression to reuse as the chain anchor at the current depth, or <c>null</c> when a nested
/// chain must create its own anchor;</description></item>
/// <item><description>the chain level to stamp on any table expression created at that depth.</description></item>
/// </list>
/// </remarks>
internal class ChainFlatteningRewriter : SqlExpressionRewriterWithInitialContext<(TableExpression containingTableExpression, int chainLevel)>
{
    private readonly NormalizedSearchParameterQueryGeneratorFactory _normalizedSearchParameterQueryGeneratorFactory;

    /// <summary>
    /// Initializes a new instance of the <see cref="ChainFlatteningRewriter"/> class.
    /// </summary>
    /// <param name="normalizedSearchParameterQueryGeneratorFactory">
    /// Visitor applied to the innermost expression of a chain to obtain a <see cref="NormalizedSearchParameterQueryGenerator"/>,
    /// or <c>null</c> when that expression has no normalized generator.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="normalizedSearchParameterQueryGeneratorFactory"/> is <c>null</c>.</exception>
    public ChainFlatteningRewriter(NormalizedSearchParameterQueryGeneratorFactory normalizedSearchParameterQueryGeneratorFactory)
    {
        EnsureArg.IsNotNull(normalizedSearchParameterQueryGeneratorFactory, nameof(normalizedSearchParameterQueryGeneratorFactory));
        _normalizedSearchParameterQueryGeneratorFactory = normalizedSearchParameterQueryGeneratorFactory;
    }

    /// <summary>
    /// Converts a <see cref="ChainedExpression"/> into one or more <see cref="TableExpression"/>s.
    /// </summary>
    /// <param name="expression">The chained expression to flatten.</param>
    /// <param name="context">
    /// The chain table expression to reuse as this level's anchor (may be <c>null</c>), and the chain level for this depth.
    /// </param>
    /// <returns>
    /// Either a single <see cref="TableExpressionKind.Chain"/> table expression, when the innermost predicate has no
    /// normalized query generator and is therefore attached to the anchor as a denormalized predicate; or an AND whose
    /// operands are the anchor followed by the table expressions produced for the inner expression.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// A nested chain was rewritten into something other than a <see cref="TableExpression"/> or an AND expression.
    /// </exception>
    public override Expression VisitChained(ChainedExpression expression, (TableExpression containingTableExpression, int chainLevel) context)
    {
        TableExpression thisTableExpression;
        if (expression.Expression is ChainedExpression)
        {
            // Nested chain: this level contributes only its chain anchor (reused from the context when available).
            thisTableExpression = context.containingTableExpression ??
                new TableExpression(
                    ChainAnchorQueryGenerator.Instance,
                    expression,
                    null,
                    TableExpressionKind.Chain,
                    context.chainLevel);

            // The inner chain is flattened one level deeper. Passing a null containing expression makes it create its own anchor.
            Expression visitedExpression = expression.Expression.AcceptVisitor(this, (null, context.chainLevel + 1));
            switch (visitedExpression)
            {
                case TableExpression child:
                    return Expression.And(thisTableExpression, child);
                case MultiaryExpression multiary when multiary.MultiaryOperation == MultiaryOperator.And:
                    // Splice the inner AND's operands after this level's anchor so the result stays a single flat AND.
                    var tableExpressions = new List<TableExpression> { thisTableExpression };
                    tableExpressions.AddRange(multiary.Expressions.Cast<TableExpression>());
                    return Expression.And(tableExpressions);
                default:
                    throw new InvalidOperationException("Unexpected return type");
            }
        }

        // Innermost chain link: check whether its predicate can be served by a normalized search-parameter generator.
        NormalizedSearchParameterQueryGenerator normalizedParameterQueryGenerator = expression.Expression.AcceptVisitor(_normalizedSearchParameterQueryGeneratorFactory);
        thisTableExpression = context.containingTableExpression;
        if (thisTableExpression == null || normalizedParameterQueryGenerator == null)
        {
            // A fresh anchor is needed in two cases:
            //   - there is no anchor to reuse;
            //   - the predicate has no normalized generator, so it must be applied as the anchor's denormalized predicate.
            // Any DenormalizedPredicateOnChainRoot from the anchor being replaced is carried over.
            thisTableExpression = new TableExpression(
                ChainAnchorQueryGenerator.Instance,
                expression,
                denormalizedPredicate: normalizedParameterQueryGenerator == null ? expression.Expression : null,
                TableExpressionKind.Chain,
                context.chainLevel,
                thisTableExpression?.DenormalizedPredicateOnChainRoot);
        }

        if (normalizedParameterQueryGenerator == null)
        {
            return thisTableExpression;
        }

        // The predicate gets its own Normal table expression at the same chain level as the anchor.
        var childTableExpression = new TableExpression(normalizedParameterQueryGenerator, expression.Expression, null, TableExpressionKind.Normal, context.chainLevel);
        return Expression.And(thisTableExpression, childTableExpression);
    }

    /// <summary>
    /// Replaces every <see cref="TableExpressionKind.Chain"/> table expression in the root with the flattened table
    /// expressions produced by visiting its normalized predicate. All other table expressions keep their relative order.
    /// </summary>
    /// <param name="expression">The root expression to rewrite.</param>
    /// <param name="context">Unused. Each chain table expression supplies its own context (itself and its chain level).</param>
    /// <returns>
    /// The original <paramref name="expression"/> when no chain table expression was rewritten; otherwise a new
    /// <see cref="SqlRootExpression"/> with the rewritten table list and the same denormalized expressions.
    /// </returns>
    public override Expression VisitSqlRoot(SqlRootExpression expression, (TableExpression containingTableExpression, int chainLevel) context)
    {
        // Allocated lazily, on the first rewritten entry. EnsureAllocatedAndPopulated presumably copies the
        // entries preceding index i into the new list — confirm against the base class.
        List<TableExpression> newTableExpressions = null;
        for (var i = 0; i < expression.TableExpressions.Count; i++)
        {
            TableExpression tableExpression = expression.TableExpressions[i];
            if (tableExpression.Kind != TableExpressionKind.Chain)
            {
                newTableExpressions?.Add(tableExpression);
                continue;
            }

            Expression visitedNormalizedPredicate = tableExpression.NormalizedPredicate.AcceptVisitor(this, (tableExpression, tableExpression.ChainLevel));
            switch (visitedNormalizedPredicate)
            {
                case TableExpression convertedExpression:
                    EnsureAllocatedAndPopulated(ref newTableExpressions, expression.TableExpressions, i);
                    newTableExpressions.Add(convertedExpression);
                    break;
                case MultiaryExpression multiary when multiary.MultiaryOperation == MultiaryOperator.And:
                    EnsureAllocatedAndPopulated(ref newTableExpressions, expression.TableExpressions, i);
                    newTableExpressions.AddRange(multiary.Expressions.Cast<TableExpression>());
                    break;
                default:
                    // Keep the original entry rather than silently dropping it. Without this arm the entry was lost
                    // only when an earlier chain had already allocated the list, which made the result order dependent.
                    newTableExpressions?.Add(tableExpression);
                    break;
            }
        }

        if (newTableExpressions == null)
        {
            return expression;
        }

        return new SqlRootExpression(newTableExpressions, expression.DenormalizedExpressions);
    }
}
}