forked from elastic/elasticsearch
/
ScriptWeaver.java
113 lines (98 loc) · 4.62 KB
/
ScriptWeaver.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.gen.script;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.literal.IntervalDayTime;
import org.elasticsearch.xpack.sql.expression.literal.IntervalYearMonth;
import org.elasticsearch.xpack.sql.type.DataType;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;
/**
* Mixin-like interface for customizing the default script generation.
*/
public interface ScriptWeaver {
default ScriptTemplate asScript(Expression exp) {
if (exp.foldable()) {
return scriptWithFoldable(exp);
}
Attribute attr = Expressions.attribute(exp);
if (attr != null) {
if (attr instanceof ScalarFunctionAttribute) {
return scriptWithScalar((ScalarFunctionAttribute) attr);
}
if (attr instanceof AggregateFunctionAttribute) {
return scriptWithAggregate((AggregateFunctionAttribute) attr);
}
if (attr instanceof GroupingFunctionAttribute) {
return scriptWithGrouping((GroupingFunctionAttribute) attr);
}
if (attr instanceof FieldAttribute) {
return scriptWithField((FieldAttribute) attr);
}
}
throw new SqlIllegalArgumentException("Cannot evaluate script for expression {}", exp);
}
DataType dataType();
default ScriptTemplate scriptWithFoldable(Expression foldable) {
Object fold = foldable.fold();
// wrap intervals with dedicated methods for serialization
if (fold instanceof IntervalYearMonth) {
IntervalYearMonth iym = (IntervalYearMonth) fold;
return new ScriptTemplate(processScript("{sql}.intervalYearMonth({},{})"),
paramsBuilder().variable(iym.interval().toString()).variable(iym.dataType().name()).build(),
dataType());
} else if (fold instanceof IntervalDayTime) {
IntervalDayTime idt = (IntervalDayTime) fold;
return new ScriptTemplate(processScript("{sql}.intervalDayTime({},{})"),
paramsBuilder().variable(idt.interval().toString()).variable(idt.dataType().name()).build(),
dataType());
}
return new ScriptTemplate(processScript("{}"),
paramsBuilder().variable(fold).build(),
dataType());
}
default ScriptTemplate scriptWithScalar(ScalarFunctionAttribute scalar) {
ScriptTemplate nested = scalar.script();
return new ScriptTemplate(processScript(nested.template()),
paramsBuilder().script(nested.params()).build(),
dataType());
}
default ScriptTemplate scriptWithAggregate(AggregateFunctionAttribute aggregate) {
String template = "{}";
if (aggregate.dataType() == DataType.DATE) {
template = "{sql}.asDateTime({})";
}
return new ScriptTemplate(processScript(template),
paramsBuilder().agg(aggregate).build(),
dataType());
}
default ScriptTemplate scriptWithGrouping(GroupingFunctionAttribute grouping) {
String template = "{}";
if (grouping.dataType() == DataType.DATE) {
template = "{sql}.asDateTime({})";
}
return new ScriptTemplate(processScript(template),
paramsBuilder().grouping(grouping).build(),
dataType());
}
default ScriptTemplate scriptWithField(FieldAttribute field) {
return new ScriptTemplate(processScript("doc[{}].value"),
paramsBuilder().variable(field.name()).build(),
dataType());
}
default String processScript(String script) {
return formatTemplate(script);
}
default String formatTemplate(String template) {
return Scripts.formatTemplate(template);
}
}