-
Notifications
You must be signed in to change notification settings - Fork 51
/
emit_match.py
274 lines (213 loc) · 10.3 KB
/
emit_match.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# Copyright 2017-present Kensho Technologies, LLC.
"""Convert lowered IR basic blocks to MATCH query strings."""
from collections import deque
import six
from .blocks import Filter, QueryRoot, Recurse, Traverse
from .expressions import TrueLiteral
from .helpers import validate_safe_string
def _get_vertex_location_name(location):
    """Return the mark name of a location that must refer to a vertex rather than a field.

    Args:
        location: object whose get_location_name() returns a (mark_name, field_name) pair

    Returns:
        the mark_name component of the location's name

    Raises:
        AssertionError: if the field_name component is anything other than None
    """
    name_components = location.get_location_name()
    vertex_mark, field_component = name_components

    if field_component is None:
        return vertex_mark

    # A non-None field component means the location addresses a property, not a vertex.
    raise AssertionError(u'Location unexpectedly pointed to a field: {}'.format(location))
def _first_step_to_match(match_step):
    """Render the opening step of a MATCH traversal as a MATCH query fragment.

    Args:
        match_step: MATCH step whose root_block is a QueryRoot with exactly one start class.
                    Its coerce_type_block must be None; where_block and as_block are optional.

    Returns:
        string of the form "{{ class: <start class>[, where: (<predicate>)][, as: <mark>] }}".
        The doubled braces are emitted literally, since %-formatting does not collapse them.
        NOTE(review): presumably a later str.format() pass turns them into single braces --
        confirm against the caller.

    Raises:
        AssertionError: if the root block is not a QueryRoot, if it does not have exactly one
                        start class, or if the step carries a coerce_type_block
    """
    root_block = match_step.root_block
    if not isinstance(root_block, QueryRoot):
        raise AssertionError(u'Expected QueryRoot root block, received: '
                             u'{} {}'.format(root_block, match_step))

    root_block.validate()

    start_classes = root_block.start_class
    if len(start_classes) != 1:
        raise AssertionError(u'Attempted to emit MATCH but did not have exactly one start class: '
                             u'{} {}'.format(start_classes, match_step))
    start_class = next(iter(start_classes))

    # MATCH steps with a QueryRoot root block shouldn't have a 'coerce_type_block'.
    if match_step.coerce_type_block is not None:
        raise AssertionError(u'Invalid MATCH step: {}'.format(match_step))

    clauses = [u'class: %s' % (start_class,)]

    where_block = match_step.where_block
    if where_block:
        where_block.validate()
        clauses.append(u'where: (%s)' % (where_block.predicate.to_match(),))

    as_block = match_step.as_block
    if as_block:
        as_block.validate()
        clauses.append(u'as: %s' % (_get_vertex_location_name(as_block.location),))

    return u'{{ %s }}' % (u', '.join(clauses),)
def _subsequent_step_to_match(match_step):
    """Transform any subsequent (non-first) MATCH step into a MATCH query string.

    Args:
        match_step: MATCH step whose root_block is a Traverse or a Recurse block. Its
                    coerce_type_block, where_block and as_block are all optional.

    Returns:
        string of the form ".<direction>('<edge_name>') {{ <clauses> }}", where the clauses
        are, in order and when applicable: "class", "while" (Recurse only), "where",
        "optional" (non-recursive steps whose root block is optional) and "as".
        The doubled braces are emitted literally, since %-formatting does not collapse them.

    Raises:
        AssertionError: if the root block is neither a Traverse nor a Recurse block, or if a
                        type coercion block does not have exactly one target class
    """
    root_block = match_step.root_block
    if not isinstance(root_block, (Traverse, Recurse)):
        # Both block types are accepted, so the message names both.
        raise AssertionError(u'Expected Traverse or Recurse root block, received: '
                             u'{} {}'.format(root_block, match_step))

    is_recursing = isinstance(root_block, Recurse)

    root_block.validate()

    traversal_command = u'.%s(\'%s\')' % (root_block.direction, root_block.edge_name)

    parts = []
    if match_step.coerce_type_block:
        coerce_type_set = match_step.coerce_type_block.target_class
        if len(coerce_type_set) != 1:
            # The check rejects empty sets as well as sets with several classes.
            raise AssertionError(u'Found MATCH type coercion block that did not have exactly one '
                                 u'target class: {} {}'.format(coerce_type_set, match_step))
        coerce_type_target = list(coerce_type_set)[0]
        parts.append(u'class: %s' % (coerce_type_target,))

    if is_recursing:
        # In MATCH, "$depth < 1" means "include the source vertex and its immediate neighbors."
        # Yes, the "<" is intentional -- it's not supposed to be a "<=".
        parts.append(u'while: ($depth < %d)' % (root_block.depth,))

    if match_step.where_block:
        match_step.where_block.validate()
        parts.append(u'where: (%s)' % (match_step.where_block.predicate.to_match(),))

    # The "optional" attribute is only consulted for non-recursive (Traverse) root blocks.
    if not is_recursing and root_block.optional:
        parts.append(u'optional: true')

    if match_step.as_block:
        match_step.as_block.validate()
        parts.append(u'as: %s' % (_get_vertex_location_name(match_step.as_block.location),))

    return u'%s {{ %s }}' % (traversal_command, u', '.join(parts))
def _represent_match_traversal(match_traversal):
    """Render a whole MATCH traversal (a sequence of MATCH steps) as one query fragment.

    Args:
        match_traversal: non-empty, sliceable sequence of MATCH steps. The first element is
                         rendered as the opening step, every later one as a subsequent step.

    Returns:
        the rendered steps concatenated with no separator between them
    """
    rendered_steps = [_first_step_to_match(match_traversal[0])]
    rendered_steps.extend(
        _subsequent_step_to_match(later_step)
        for later_step in match_traversal[1:]
    )
    return u''.join(rendered_steps)
def _represent_fold(fold_location, fold_ir_blocks):
    """Emit a LET clause corresponding to the IR blocks for a @fold scope.

    Args:
        fold_location: location of the @fold scope. This function reads its
                       relative_position (unpacked as edge direction and edge name),
                       get_location_name() (used as the LET variable name) and
                       base_location.get_location_name() (first component used as the
                       starting point of the expression).
        fold_ir_blocks: iterable of Filter and/or Traverse blocks within the fold scope

    Returns:
        string of the form '$<mark> = <base>.<direction>("<edge>")' followed by one
        '[<predicate>]' per Filter block and one '.<direction>("<edge>")' per Traverse block,
        terminated with '.asList()'

    Raises:
        AssertionError: if fold_ir_blocks contains a block that is neither Filter nor Traverse
        Whatever validate_safe_string raises for a string it rejects; that behavior is
        defined outside this function.
    """
    start_let_template = u'$%(mark_name)s = %(base_location)s'
    traverse_edge_template = u'.%(direction)s("%(edge_name)s")'
    base_template = start_let_template + traverse_edge_template

    edge_direction, edge_name = fold_location.relative_position
    mark_name = fold_location.get_location_name()
    base_location_name, _ = fold_location.base_location.get_location_name()

    validate_safe_string(mark_name)
    validate_safe_string(base_location_name)
    validate_safe_string(edge_direction)
    validate_safe_string(edge_name)

    # Collect the fragments and join them once at the end.
    fragments = [
        base_template % {
            'mark_name': mark_name,
            'base_location': base_location_name,
            'direction': edge_direction,
            'edge_name': edge_name,
        },
    ]

    for block in fold_ir_blocks:
        if isinstance(block, Filter):
            fragments.append(u'[' + block.predicate.to_match() + u']')
        elif isinstance(block, Traverse):
            # These values are interpolated into the query exactly like the first edge above,
            # so they get the same safety check before being emitted.
            validate_safe_string(block.direction)
            validate_safe_string(block.edge_name)
            fragments.append(traverse_edge_template % {
                'direction': block.direction,
                'edge_name': block.edge_name,
            })
        else:
            raise AssertionError(u'Found a non-Filter/Traverse IR block in the folded IR blocks: '
                                 u'{} {} {}'.format(type(block), block, fold_ir_blocks))

    # Workaround for OrientDB's inconsistent return type when filtering a list.
    # https://github.com/orientechnologies/orientdb/issues/7811
    fragments.append(u'.asList()')

    return u''.join(fragments)
def _construct_output_to_match(output_block):
    """Render a ConstructResult block as the "SELECT ... FROM" prefix of a MATCH query.

    Args:
        output_block: block providing validate() and a "fields" mapping from output name to
                      an expression object exposing to_match()

    Returns:
        string "SELECT <expr> AS `<name>`, ... FROM", with one entry per output field
    """
    output_block.validate()

    fields = output_block.fields
    select_items = []
    # Visit the output names in sorted order so the emitted query is deterministic.
    for output_name in sorted(fields.keys()):
        rendered_expression = fields[output_name].to_match()
        select_items.append(u'%s AS `%s`' % (rendered_expression, output_name))

    return u'SELECT %s FROM' % (u', '.join(select_items),)
def _construct_where_to_match(where_block):
    """Render a Filter block as the trailing WHERE clause of a MATCH query.

    Args:
        where_block: block whose "predicate" attribute exposes to_match()

    Returns:
        string "WHERE " immediately followed by the rendered predicate

    Raises:
        AssertionError: if the predicate compares equal to TrueLiteral
    """
    predicate_is_true_literal = where_block.predicate == TrueLiteral
    if predicate_is_true_literal:
        raise AssertionError(u'Received WHERE block with TrueLiteral predicate: {}'
                             .format(where_block))

    rendered_predicate = where_block.predicate.to_match()
    return u'WHERE ' + rendered_predicate
##############
# Public API #
##############
def emit_code_from_single_match_query(match_query):
    """Return a MATCH query string for a single match query.

    Args:
        match_query: object providing match_traversals (non-empty sequence of MATCH
                     traversals), folds (mapping of fold location -> list of IR blocks),
                     output_block, and where_block (may be None)

    Returns:
        the query tokens joined with single spaces, in this order: the SELECT clause, the
        parenthesized "MATCH ... RETURN $matches" section, an optional LET section with one
        clause per fold, and an optional WHERE clause

    Raises:
        AssertionError: if match_query.match_traversals is falsy
    """
    traversals = match_query.match_traversals
    if not traversals:
        raise AssertionError(u'Unexpected falsy value for match_query.match_traversals received: '
                             u'{} {}'.format(match_query.match_traversals, match_query))

    tokens = deque()

    # The MATCH keyword, followed by the comma-separated traversals.
    tokens.append(u'MATCH ')
    for position, traversal in enumerate(traversals):
        if position > 0:
            tokens.append(u', ')
        tokens.append(_represent_match_traversal(traversal))

    # Wrap the MATCH in parentheses so that it can be the target of the outer SELECT.
    tokens.appendleft(u' (')
    tokens.append(u'RETURN $matches)')

    # One LET clause per @fold scope; sorted so the clause order is deterministic.
    fold_clauses = sorted(
        _represent_fold(location, ir_blocks)
        for location, ir_blocks in six.iteritems(match_query.folds)
    )
    for position, fold_clause in enumerate(fold_clauses):
        # The first clause is introduced by the LET keyword, later ones by a comma.
        tokens.append(u' LET ' if position == 0 else u', ')
        tokens.append(fold_clause)

    # The SELECT clause with the output data goes in front of everything else.
    tokens.appendleft(_construct_output_to_match(match_query.output_block))

    # The WHERE clause with the remaining filters, if any, goes last.
    where_block = match_query.where_block
    if where_block is not None:
        tokens.append(_construct_where_to_match(where_block))

    return u' '.join(tokens)
def emit_code_from_multiple_match_queries(match_queries):
    """Return one MATCH query string that unions the results of several match queries.

    Each query is rendered with emit_code_from_single_match_query and bound to its own
    "$optional__<index>" LET variable. The "$result" variable is then set to the UNIONALL
    of those variables and expanded by the outer SELECT.

    Args:
        match_queries: iterable of match queries, each accepted by
                       emit_code_from_single_match_query

    Returns:
        the query tokens joined with single spaces
    """
    optional_variable_base_name = '$optional__'
    union_variable_name = '$result'

    rendered_queries = [
        emit_code_from_single_match_query(match_query)
        for match_query in match_queries
    ]
    variable_names = [
        optional_variable_base_name + str(index)
        for index in range(len(rendered_queries))
    ]

    tokens = [u'SELECT EXPAND(', union_variable_name, u')', u' LET ']

    # "$optional__<index> = ( <query> )," for each sub-query, in input order.
    for variable_name, rendered_query in zip(variable_names, rendered_queries):
        tokens.extend((variable_name + u' = (', rendered_query, u'),'))

    # "$result = UNIONALL($optional__0, $optional__1, ...)"
    tokens.extend((
        union_variable_name,
        u' = UNIONALL(',
        u', '.join(variable_names),
        u')',
    ))

    return u' '.join(tokens)
def emit_code_from_ir(compound_match_query):
    """Return a MATCH query string from a CompoundMatchQuery.

    Args:
        compound_match_query: object whose match_queries attribute is a non-empty sequence
                              of match queries

    Returns:
        the query string for the only match query if there is exactly one, otherwise the
        combined query string for all of them

    Raises:
        AssertionError: if match_queries is empty
    """
    match_queries = compound_match_query.match_queries
    query_count = len(match_queries)

    if query_count < 1:
        raise AssertionError(u'Received CompoundMatchQuery with an empty list of MatchQueries: '
                             u'{}'.format(match_queries))

    # A lone match query needs no wrapping: emit it directly.
    if query_count == 1:
        return emit_code_from_single_match_query(match_queries[0])

    # With several match queries, the query string of each individual query is constructed
    # and they are combined as follows.
    #
    # SELECT EXPAND($result)
    # LET
    #     $optional__0 = (
    #         <query_string_0>
    #     ),
    #     $optional__1 = (
    #         <query_string_1>
    #     ),
    #     $optional__2 = (
    #         <query_string_2>
    #     ),
    #
    #     . . .
    #
    #     $result = UNIONALL($optional__0, $optional__1, . . . )
    return emit_code_from_multiple_match_queries(match_queries)