-
Notifications
You must be signed in to change notification settings - Fork 21
/
utils.py
369 lines (304 loc) · 13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
import copy
from typing import List
from mindsdb_sql.exceptions import PlanningException
from mindsdb_sql.parser.ast import (Identifier, Operation, Star, Select, BinaryOperation, Constant,
OrderBy, UnaryOperation, NullConstant, TypeCast, Parameter)
from mindsdb_sql.parser import ast
# def get_integration_path_from_identifier(identifier):
# parts = identifier.parts
# integration_name = parts[0]
# new_parts = parts[1:]
#
# if len(parts) == 1:
# raise PlanningException(f'No integration specified for table: {str(identifier)}')
# elif len(parts) > 4:
# raise PlanningException(f'Too many parts (dots) in table identifier: {str(identifier)}')
#
# new_identifier = copy.deepcopy(identifier)
# new_identifier.parts = new_parts
#
# return integration_name, new_identifier
def get_predictor_name_identifier(identifier):
    """Return a deep copy of *identifier* without its leading qualifier.

    When the identifier has more than one part, the first part is removed
    from the copy (e.g. ``integration.model`` -> ``model``). A single-part
    identifier is returned as an unchanged copy. The input is never mutated.
    """
    clone = copy.deepcopy(identifier)
    if len(clone.parts) <= 1:
        return clone
    clone.parts.pop(0)
    return clone
def disambiguate_predictor_column_identifier(identifier, predictor):
    """Strip a leading predictor reference from a column identifier.

    The reference is the predictor's alias (``predictor.alias.parts_to_str()``)
    when it has one, otherwise the predictor's own name
    (``predictor.parts_to_str()``). If the first part of *identifier* equals
    that reference it is dropped. No parts are ever added.

    :return: a new ``Identifier`` built from the remaining parts
    """
    if predictor.alias:
        table_ref = predictor.alias.parts_to_str()
    else:
        table_ref = predictor.parts_to_str()

    column_parts = list(identifier.parts)
    if column_parts[0] == table_ref:
        column_parts = column_parts[1:]
    return Identifier(parts=column_parts)
def recursively_extract_column_values(op, row_dict, predictor):
    """Collect ``column = value`` pairs from a predictor WHERE clause into *row_dict*.

    Only ``=`` comparisons joined by ``and`` are accepted. For every comparison
    the left side must be an ``Identifier`` and the right side a ``Constant``
    or ``Parameter``. Column names are normalised with
    ``disambiguate_predictor_column_identifier`` and used (as ``str``) as keys.
    Constants are stored unwrapped (``.value``); parameters are stored as-is.

    :param op: WHERE-clause AST node to walk
    :param row_dict: dict that is filled in place
    :param predictor: predictor identifier used to strip column prefixes
    :raises PlanningException: on any other operator or operand shape, or when
        the same column is assigned twice
    """
    kind = op.op if isinstance(op, BinaryOperation) else None

    if kind == 'and':
        recursively_extract_column_values(op.args[0], row_dict, predictor)
        recursively_extract_column_values(op.args[1], row_dict, predictor)
        return

    if kind != '=':
        raise PlanningException(f'Only \'and\' and \'=\' operations allowed in WHERE clause, found: {op.to_tree()}')

    column, value = op.args[0], op.args[1]
    # Anything other than a bare Constant/Parameter on the right-hand side
    # (for example an expression wrapping a constant) is rejected here.
    if not isinstance(column, Identifier) or not isinstance(value, (Constant, Parameter)):
        raise PlanningException(f'The WHERE clause for selecting from a predictor'
                                f' must contain pairs \'Identifier(...) = Constant(...)\','
                                f' found instead: {column.to_tree()}, {value.to_tree()}')

    key = str(disambiguate_predictor_column_identifier(column, predictor))
    if key in row_dict:
        raise PlanningException(f'Multiple values provided for {key}')
    row_dict[key] = value.value if isinstance(value, Constant) else value
def get_deepest_select(select):
    """Return the innermost ``Select`` reached by following ``from_table``.

    Starting from *select*, keep descending while ``from_table`` is itself a
    ``Select`` (i.e. a subquery in FROM). The first select whose ``from_table``
    is empty or not a ``Select`` is returned.
    """
    current = select
    while current.from_table and isinstance(current.from_table, Select):
        current = current.from_table
    return current
def query_traversal(node, callback, is_table=False, is_target=False, parent_query=None):
    '''
    Walk a query AST depth-first, letting ``callback`` inspect or replace nodes.

    The callback is invoked on ``node`` first. If it returns anything other than
    None, that value is returned immediately as the replacement and the children
    of ``node`` are NOT visited. Otherwise the children are visited recursively
    and every replacement the callback produces for them is written back onto
    ``node`` (the tree is mutated in place).

    :param node: element to visit (AST node, list of nodes, or any other value)
    :param callback: function applied to every element, called as
        ``callback(node, is_table=..., is_target=..., parent_query=...)``
    :param is_table: it is table in query
    :param is_target: it is the target in select
    :param parent_query: current query (select/update/create/...) where we are now
    :return:
        new element if it is needed to be replaced
        or None to keep element and traverse over it.
        When ``node`` is a list (and the callback did not replace it), a new
        list of the possibly-replaced elements is returned.
    '''
    res = callback(node, is_table=is_table, is_target=is_target, parent_query=parent_query)
    if res is not None:
        # node is going to be replaced; do not descend into it
        return res

    if isinstance(node, ast.Select):
        if node.from_table is not None:
            node_out = query_traversal(node.from_table, callback, is_table=True, parent_query=node)
            if node_out is not None:
                node.from_table = node_out

        array = []
        for node2 in node.targets:
            node_out = query_traversal(node2, callback, parent_query=node, is_target=True) or node2
            # a callback may return a list to expand one target into several
            if isinstance(node_out, list):
                array.extend(node_out)
            else:
                array.append(node_out)
        node.targets = array

        if node.cte is not None:
            for cte in node.cte:
                # The callback is applied to the CTE's query, so its replacement
                # belongs in ``cte.query``; the CTE wrapper object itself must
                # stay in ``node.cte`` (it must not be swapped for the query).
                node_out = query_traversal(cte.query, callback, parent_query=node)
                if node_out is not None:
                    cte.query = node_out

        if node.where is not None:
            node_out = query_traversal(node.where, callback, parent_query=node)
            if node_out is not None:
                node.where = node_out

        if node.group_by is not None:
            node.group_by = [
                query_traversal(node2, callback, parent_query=node) or node2
                for node2 in node.group_by
            ]

        if node.having is not None:
            node_out = query_traversal(node.having, callback, parent_query=node)
            if node_out is not None:
                node.having = node_out

        if node.order_by is not None:
            node.order_by = [
                query_traversal(node2, callback, parent_query=node) or node2
                for node2 in node.order_by
            ]

    elif isinstance(node, ast.Union):
        node_out = query_traversal(node.left, callback, parent_query=node)
        if node_out is not None:
            node.left = node_out
        node_out = query_traversal(node.right, callback, parent_query=node)
        if node_out is not None:
            node.right = node_out

    elif isinstance(node, ast.Join):
        # NOTE: the right table is visited before the left one
        node_out = query_traversal(node.right, callback, is_table=True, parent_query=parent_query)
        if node_out is not None:
            node.right = node_out
        node_out = query_traversal(node.left, callback, is_table=True, parent_query=parent_query)
        if node_out is not None:
            node.left = node_out
        if node.condition is not None:
            node_out = query_traversal(node.condition, callback, parent_query=parent_query)
            if node_out is not None:
                node.condition = node_out

    elif isinstance(node, (ast.Function, ast.BinaryOperation, ast.UnaryOperation, ast.BetweenOperation)):
        node.args = [
            query_traversal(arg, callback, parent_query=parent_query) or arg
            for arg in node.args
        ]

    elif isinstance(node, ast.WindowFunction):
        # keep the callback's replacement of the wrapped function (if any)
        node_out = query_traversal(node.function, callback, parent_query=parent_query)
        if node_out is not None:
            node.function = node_out
        if node.partition is not None:
            node.partition = [
                query_traversal(node2, callback, parent_query=parent_query) or node2
                for node2 in node.partition
            ]
        if node.order_by is not None:
            node.order_by = [
                query_traversal(node2, callback, parent_query=parent_query) or node2
                for node2 in node.order_by
            ]

    elif isinstance(node, ast.TypeCast):
        node_out = query_traversal(node.arg, callback, parent_query=parent_query)
        if node_out is not None:
            node.arg = node_out

    elif isinstance(node, ast.Tuple):
        node.items = [
            query_traversal(node2, callback, parent_query=parent_query) or node2
            for node2 in node.items
        ]

    elif isinstance(node, ast.Insert):
        if node.table is not None:
            node_out = query_traversal(node.table, callback, is_table=True, parent_query=node)
            if node_out is not None:
                node.table = node_out
        if node.values is not None:
            node.values = [
                [query_traversal(item, callback, parent_query=node) or item for item in row]
                for row in node.values
            ]
        if node.from_select is not None:
            node_out = query_traversal(node.from_select, callback, parent_query=node)
            if node_out is not None:
                node.from_select = node_out

    elif isinstance(node, ast.Update):
        if node.table is not None:
            node_out = query_traversal(node.table, callback, is_table=True, parent_query=node)
            if node_out is not None:
                node.table = node_out
        if node.where is not None:
            node_out = query_traversal(node.where, callback, parent_query=node)
            if node_out is not None:
                node.where = node_out
        if node.update_columns is not None:
            # collect first, apply after: the dict must not change while iterating it
            changes = {}
            for k, v in node.update_columns.items():
                v2 = query_traversal(v, callback, parent_query=node)
                if v2 is not None:
                    changes[k] = v2
            if changes:
                node.update_columns.update(changes)
        if node.from_select is not None:
            node_out = query_traversal(node.from_select, callback, parent_query=node)
            if node_out is not None:
                node.from_select = node_out

    elif isinstance(node, ast.CreateTable):
        if node.columns is not None:
            node.columns = [
                query_traversal(node2, callback, parent_query=node) or node2
                for node2 in node.columns
            ]
        if node.name is not None:
            node_out = query_traversal(node.name, callback, is_table=True, parent_query=node)
            if node_out is not None:
                node.name = node_out
        if node.from_select is not None:
            node_out = query_traversal(node.from_select, callback, parent_query=node)
            if node_out is not None:
                node.from_select = node_out

    elif isinstance(node, ast.Delete):
        if node.where is not None:
            node_out = query_traversal(node.where, callback, parent_query=node)
            if node_out is not None:
                node.where = node_out

    elif isinstance(node, ast.OrderBy):
        if node.field is not None:
            node_out = query_traversal(node.field, callback, parent_query=parent_query)
            if node_out is not None:
                node.field = node_out

    elif isinstance(node, ast.Case):
        rules = []
        for condition, result in node.rules:
            condition2 = query_traversal(condition, callback, parent_query=parent_query)
            result2 = query_traversal(result, callback, parent_query=parent_query)
            rules.append([
                condition if condition2 is None else condition2,
                result if result2 is None else result2,
            ])
        node.rules = rules
        # default may be None (presumably a CASE without ELSE); like every other
        # optional field, don't hand None to the callback
        if node.default is not None:
            default = query_traversal(node.default, callback, parent_query=parent_query)
            if default is not None:
                node.default = default

    elif isinstance(node, list):
        return [
            query_traversal(node2, callback, parent_query=parent_query) or node2
            for node2 in node
        ]

    # keep original node
    return None
def convert_join_to_list(join):
    """Flatten a left-deep join tree into an ordered list of table descriptors.

    The first entry is ``{'table': <leftmost table>}``. Every following entry
    describes one joined table, in join order, as a dict with keys ``table``,
    ``join_type``, ``is_implicit`` and ``condition`` taken from the join node
    that introduced it.

    :raises NotImplementedError: if any join node has another join on its
        right-hand side (only left-deep trees are supported)
    """
    # Walk down the left spine, remembering each join node (top first).
    levels = []
    current = join
    while True:
        if isinstance(current.right, ast.Join):
            raise NotImplementedError('Wrong join AST')
        levels.append(current)
        if not isinstance(current.left, ast.Join):
            break
        current = current.left

    # ``current`` is now the bottom join; its left side is the first table.
    items = [dict(table=current.left)]
    for level in reversed(levels):
        items.append(dict(
            table=level.right,
            join_type=level.join_type,
            is_implicit=level.implicit,
            condition=level.condition,
        ))
    return items
def get_query_params(query):
    """Return every ``ast.Parameter`` node in *query*, in traversal order.

    The collecting callback returns the parameter itself, which makes
    ``query_traversal`` treat it as a (no-op) replacement and stop descending
    into it.
    """
    found = []

    def _collect(node, **kwargs):
        if not isinstance(node, ast.Parameter):
            return None
        found.append(node)
        return node

    query_traversal(query, _collect)
    return found
def fill_query_params(query, params):
    """Substitute ``ast.Parameter`` placeholders in *query* with constants.

    Placeholders are replaced in ``query_traversal`` order, each by an
    ``ast.Constant`` wrapping the next value of *params*. *params* is deep-copied
    first, so the caller's list is left untouched. Nodes inside *query* are
    replaced in place and *query* is returned. If *query* itself is a
    Parameter, its replacement is not returned.

    Surplus values are ignored. Too few values raise ``IndexError`` (from
    ``list.pop``).
    """
    remaining = copy.deepcopy(params)

    def _substitute(node, **kwargs):
        if not isinstance(node, ast.Parameter):
            return None
        return ast.Constant(remaining.pop(0))

    query_traversal(query, _substitute)
    return query
def filters_to_bin_op(filters: List[BinaryOperation]):
    """Combine *filters* into a single left-associative ``and`` expression.

    ``[a, b, c]`` becomes ``(a and b) and c``. A single filter is returned
    as-is; an empty list gives None.
    """
    combined = None
    for condition in filters:
        if combined is not None:
            combined = BinaryOperation(op='and', args=[combined, condition])
        else:
            combined = condition
    return combined