-
Notifications
You must be signed in to change notification settings - Fork 271
/
function.go
478 lines (399 loc) · 14 KB
/
function.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
// Copyright 2021 - 2022 Matrix Origin
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package function
import (
"context"
"fmt"
"github.com/matrixorigin/matrixone/pkg/sql/colexec/agg"
"github.com/matrixorigin/matrixone/pkg/common/moerr"
"github.com/matrixorigin/matrixone/pkg/container/types"
"github.com/matrixorigin/matrixone/pkg/container/vector"
"github.com/matrixorigin/matrixone/pkg/pb/plan"
"github.com/matrixorigin/matrixone/pkg/vm/process"
)
// allSupportedFunctions is the function table, indexed by function id.
// initAllSupportedFunctions stores every registered FuncNew at index fn.functionId.
var allSupportedFunctions [1000]FuncNew
// register all supported functions.
func initAllSupportedFunctions() {
for _, fn := range supportedOperators {
allSupportedFunctions[fn.functionId] = fn
}
for _, fn := range supportedStringBuiltIns {
allSupportedFunctions[fn.functionId] = fn
}
for _, fn := range supportedDateAndTimeBuiltIns {
allSupportedFunctions[fn.functionId] = fn
}
for _, fn := range supportedMathBuiltIns {
allSupportedFunctions[fn.functionId] = fn
}
for _, fn := range supportedArrayOperations {
allSupportedFunctions[fn.functionId] = fn
}
for _, fn := range supportedControlBuiltIns {
allSupportedFunctions[fn.functionId] = fn
}
for _, fn := range supportedOthersBuiltIns {
allSupportedFunctions[fn.functionId] = fn
}
for _, fn := range supportedAggregateFunctions {
allSupportedFunctions[fn.functionId] = fn
}
for _, fn := range supportedWindowFunctions {
allSupportedFunctions[fn.functionId] = fn
}
agg.InitAggFramework(generateAggExecutorWithoutConfig, generateAggExecutor, GetFunctionIsWinOrderFunById)
}
// GetFunctionIsAggregateByName reports whether name is a registered function whose
// table entry satisfies isAggregate. Unknown names yield false.
func GetFunctionIsAggregateByName(name string) bool {
	if fid, found := getFunctionIdByNameWithoutErr(name); found {
		return allSupportedFunctions[fid].isAggregate()
	}
	return false
}
// GetFunctionIsWinFunByName reports whether name is a registered function whose
// table entry satisfies isWindow. Unknown names yield false.
func GetFunctionIsWinFunByName(name string) bool {
	if fid, found := getFunctionIdByNameWithoutErr(name); found {
		return allSupportedFunctions[fid].isWindow()
	}
	return false
}
// GetFunctionIsWinOrderFunByName reports whether name is a registered function whose
// table entry satisfies isWindowOrder. Unknown names yield false.
func GetFunctionIsWinOrderFunByName(name string) bool {
	if fid, found := getFunctionIdByNameWithoutErr(name); found {
		return allSupportedFunctions[fid].isWindowOrder()
	}
	return false
}
// GetFunctionIsWinOrderFunById reports whether the function addressed by overloadID
// satisfies isWindowOrder.
//
// Only the function-id half of overloadID is used. A function id that is negative
// or lies outside allSupportedFunctions yields false; previously such an id caused
// an index-out-of-range panic.
func GetFunctionIsWinOrderFunById(overloadID int64) bool {
	fid, _ := DecodeOverloadID(overloadID)
	if fid < 0 || int(fid) >= len(allSupportedFunctions) {
		return false
	}
	return allSupportedFunctions[fid].isWindowOrder()
}
// GetFunctionIsMonotonicById reports whether the overload addressed by overloadID
// belongs to a function flagged Function_MONOTONIC. A volatile overload is never
// reported as monotonic.
//
// An error created by moerr.NewInvalidInput is returned when the function id is
// negative, lies outside allSupportedFunctions, does not match the functionId stored
// in its slot, or when the overload index is outside that function's Overloads.
// Negative ids and out-of-range overload indexes previously caused an index panic.
func GetFunctionIsMonotonicById(ctx context.Context, overloadID int64) (bool, error) {
	fid, oIndex := DecodeOverloadID(overloadID)
	// fid < 0 must be tested first: the last operand indexes the table with fid.
	if fid < 0 || int(fid) >= len(allSupportedFunctions) || int(fid) != allSupportedFunctions[fid].functionId {
		return false, moerr.NewInvalidInput(ctx, "function overload id not found")
	}

	f := &allSupportedFunctions[fid]
	if oIndex < 0 || int(oIndex) >= len(f.Overloads) {
		return false, moerr.NewInvalidInput(ctx, "function overload id not found")
	}
	if f.Overloads[oIndex].volatile {
		return false, nil
	}
	return f.testFlag(plan.Function_MONOTONIC), nil
}
// GetFunctionById returns the overload addressed by overloadID.
//
// overloadID is split by DecodeOverloadID into a function id and an overload index.
// An error created by moerr.NewInvalidInput (and a zero overload) is returned when
// the function id is negative, lies outside allSupportedFunctions, does not match
// the functionId stored in its slot, or when the overload index is outside that
// function's Overloads. Negative ids and out-of-range overload indexes previously
// caused an index panic.
func GetFunctionById(ctx context.Context, overloadID int64) (f overload, err error) {
	fid, oIndex := DecodeOverloadID(overloadID)
	// fid < 0 must be tested first: the last operand indexes the table with fid.
	if fid < 0 || int(fid) >= len(allSupportedFunctions) || int(fid) != allSupportedFunctions[fid].functionId {
		return overload{}, moerr.NewInvalidInput(ctx, "function overload id not found")
	}

	overloads := allSupportedFunctions[fid].Overloads
	if oIndex < 0 || int(oIndex) >= len(overloads) {
		return overload{}, moerr.NewInvalidInput(ctx, "function overload id not found")
	}
	return overloads[oIndex], nil
}
// GetLayoutById returns the explain layout of the function addressed by overloadID.
// Only the function-id half of overloadID is used.
//
// An error created by moerr.NewInvalidInput (and layout 0) is returned when the
// function id is negative, lies outside allSupportedFunctions, or does not match the
// functionId stored in its slot. A negative id previously caused an index panic.
func GetLayoutById(ctx context.Context, overloadID int64) (FuncExplainLayout, error) {
	fid, _ := DecodeOverloadID(overloadID)
	// fid < 0 must be tested first: the last operand indexes the table with fid.
	if fid < 0 || int(fid) >= len(allSupportedFunctions) || int(fid) != allSupportedFunctions[fid].functionId {
		return 0, moerr.NewInvalidInput(ctx, "function overload id not found")
	}
	return allSupportedFunctions[fid].layout, nil
}
// GetFunctionByIdWithoutError returns the overload addressed by overloadID and
// whether it exists.
//
// exists is false (with a zero overload) when the function id is negative, lies
// outside allSupportedFunctions, does not match the functionId stored in its slot,
// or when the overload index is outside that function's Overloads. Negative ids and
// out-of-range overload indexes previously caused an index panic.
func GetFunctionByIdWithoutError(overloadID int64) (f overload, exists bool) {
	fid, oIndex := DecodeOverloadID(overloadID)
	// fid < 0 must be tested first: the last operand indexes the table with fid.
	if fid < 0 || int(fid) >= len(allSupportedFunctions) || int(fid) != allSupportedFunctions[fid].functionId {
		return overload{}, false
	}

	overloads := allSupportedFunctions[fid].Overloads
	if oIndex < 0 || int(oIndex) >= len(overloads) {
		return overload{}, false
	}
	return overloads[oIndex], true
}
// GetFunctionByName resolves name and the argument types args to one overload.
//
// The function id is looked up by name, then the function's checkFn picks an
// overload for args. On an exact match the result carries the overload index, the
// return type computed from args and the overload's cannotParallel flag. On a match
// that needs casting it additionally records needCast and the target types, and the
// return type is computed from those target types instead.
//
// Errors: the lookup error for an unknown name, moerr.NewNYI when the function has
// no overloads or no checkFn, and moerr.NewInvalidArg for the failed check statuses.
func GetFunctionByName(ctx context.Context, name string, args []types.Type) (r FuncGetResult, err error) {
	if r.fid, err = getFunctionIdByName(ctx, name); err != nil {
		return r, err
	}

	fn := &allSupportedFunctions[r.fid]
	if len(fn.Overloads) == 0 || fn.checkFn == nil {
		return r, moerr.NewNYI(ctx, "should implement the function %s", name)
	}

	switch outcome := fn.checkFn(fn.Overloads, args); outcome.status {
	case succeedMatched, succeedWithCast:
		r.overloadId = int32(outcome.idx)
		chosen := &fn.Overloads[r.overloadId]

		// The return type is derived from the types the overload will really see.
		effectiveTypes := args
		if outcome.status == succeedWithCast {
			r.needCast = true
			r.targetTypes = outcome.finalType
			effectiveTypes = r.targetTypes
		}
		r.retType = chosen.retType(effectiveTypes)
		r.cannotRunInParallel = chosen.cannotParallel

	case failedFunctionParametersWrong:
		label := fmt.Sprintf("operator %s", name)
		if fn.isFunction() {
			label = fmt.Sprintf("function %s", name)
		}
		err = moerr.NewInvalidArg(ctx, label, args)

	case failedAggParametersWrong:
		err = moerr.NewInvalidArg(ctx, fmt.Sprintf("aggregate function %s", name), args)

	case failedTooManyFunctionMatched:
		err = moerr.NewInvalidArg(ctx, fmt.Sprintf("too many overloads matched %s", name), args)
	}
	return r, err
}
// RunFunctionDirectly runs a function directly without any protections.
// It is dangerous and should be used only when you are sure that the overloadID is correct and the inputs are valid.
//
// length is the number of rows to evaluate. The call is folded only when the
// overload may be folded (it is neither volatile nor real-time related) and every
// input is a constant vector; in that case a single row is evaluated and the result
// is returned as a constant vector of `length` rows. In every other case all
// `length` rows are evaluated and the result vector is returned as produced.
//
// Previously a volatile or real-time related overload was evaluated over all rows
// but its result was still collapsed to a constant built from row 0.
//
// The error of GetFunctionById, of the result pre-extension, of the overload's
// execution, or of duplicating the constant result is returned as is; the result
// wrapper is freed on each of those failure paths after it has been created.
func RunFunctionDirectly(proc *process.Process, overloadID int64, inputs []*vector.Vector, length int) (*vector.Vector, error) {
	f, err := GetFunctionById(proc.Ctx, overloadID)
	if err != nil {
		return nil, err
	}

	mp := proc.Mp()
	inputTypes := make([]types.Type, len(inputs))
	for i, input := range inputs {
		inputTypes[i] = *input.GetType()
	}

	result := vector.NewFunctionResultWrapper(proc.GetVector, proc.PutVector, f.retType(inputTypes), mp)

	// Folding requires a foldable overload AND nothing but constant inputs.
	fold := !f.CannotFold() && !f.IsRealTimeRelated()
	if fold {
		for _, param := range inputs {
			if !param.IsConst() {
				fold = false
				break
			}
		}
	}

	evaluateLength := length
	if fold {
		evaluateLength = 1
	}

	if err = result.PreExtendAndReset(evaluateLength); err != nil {
		result.Free()
		return nil, err
	}
	exec := f.GetExecuteMethod()
	if err = exec(inputs, result, proc, evaluateLength); err != nil {
		result.Free()
		return nil, err
	}

	vec := result.GetResultVector()
	if fold {
		// ToConst is a confusing method: it just returns a new pointer to the same memory,
		// so the constant vector is duplicated before the result wrapper is freed.
		cvec, er := vec.ToConst(0, length, mp).Dup(mp)
		result.Free()
		if er != nil {
			return nil, er
		}
		return cvec, nil
	}
	return vec, nil
}
// generateAggExecutor builds the aggregation executor for the overload addressed by
// overloadID. The output type is computed by the overload's retType from inputTypes
// and passed, together with all other arguments, to the overload's aggNew.
// An error from moerr.NewInvalidInputNoCtx is returned when the overload is not found.
func generateAggExecutor(
	overloadID int64, isDistinct bool, inputTypes []types.Type, config any, partialresult any) (agg.Agg[any], error) {
	ov, found := GetFunctionByIdWithoutError(overloadID)
	if found {
		return ov.aggFramework.aggNew(overloadID, isDistinct, inputTypes, ov.retType(inputTypes), config, partialresult)
	}
	return nil, moerr.NewInvalidInputNoCtx("function id '%d' not found", overloadID)
}
// generateAggExecutorWithoutConfig is generateAggExecutor called with neither a
// config nor a partial result (both nil).
func generateAggExecutorWithoutConfig(overloadID int64, isDistinct bool, inputTypes []types.Type) (agg.Agg[any], error) {
	var noConfig, noPartialResult any
	return generateAggExecutor(overloadID, isDistinct, inputTypes, noConfig, noPartialResult)
}
// GetAggFunctionNameByID returns the aggFramework.str of the overload addressed by
// overloadID, or "unknown function" when no such overload exists.
func GetAggFunctionNameByID(overloadID int64) string {
	if ov, found := GetFunctionByIdWithoutError(overloadID); found {
		return ov.aggFramework.str
	}
	return "unknown function"
}
// DeduceNotNullable deduces whether the result of a function call is not nullable,
// which sometimes helps optimization.
//
// The result is true when the function carries the Function_PRODUCE_NO_NULL flag,
// or when every argument is marked NotNullable (which includes an empty args list).
//
// For example, given
//
//	create table t1(c1 int not null, c2 int, c3 int not null, c4 int);
//	select c1+1, abs(c2), cast(c3 as varchar(10)) from t1 where c1=c3;
//
// c1+1, the cast of c3 and c1=c3 are deduced to be not nullable, while abs(c2) is nullable.
func DeduceNotNullable(overloadID int64, args []*plan.Expr) bool {
	fid, _ := DecodeOverloadID(overloadID)
	if !allSupportedFunctions[fid].testFlag(plan.Function_PRODUCE_NO_NULL) {
		// Without the flag, a single nullable argument makes the result nullable.
		for i := range args {
			if !args[i].Typ.NotNullable {
				return false
			}
		}
	}
	return true
}
// FuncGetResult is the value filled in by GetFunctionByName: the resolved function
// id and overload index, plus return-type and cast information for the given arguments.
type FuncGetResult struct {
// id of the function, as looked up by its name.
fid int32
// index of the selected overload, taken from the checkFn result.
overloadId int32
// return type computed by the selected overload's retType.
retType types.Type
// copy of the selected overload's cannotParallel flag.
cannotRunInParallel bool
// true when the overload matched only with a cast (status succeedWithCast).
needCast bool
// types the arguments should be cast to; set from checkResult.finalType together with needCast.
targetTypes []types.Type
}
// GetReturnType returns the return type recorded for the resolved overload.
func (fr *FuncGetResult) GetReturnType() types.Type {
	return fr.retType
}

// CannotRunInParallel returns the cannotRunInParallel flag recorded for the resolved overload.
func (fr *FuncGetResult) CannotRunInParallel() bool {
	return fr.cannotRunInParallel
}

// ShouldDoImplicitTypeCast returns the recorded target types and whether a cast is needed.
func (fr *FuncGetResult) ShouldDoImplicitTypeCast() (typs []types.Type, should bool) {
	typs, should = fr.targetTypes, fr.needCast
	return typs, should
}

// GetEncodedOverloadID packs the function id and the overload index into a single
// int64 by calling encodeOverloadID.
func (fr *FuncGetResult) GetEncodedOverloadID() (overloadID int64) {
	overloadID = encodeOverloadID(fr.fid, fr.overloadId)
	return overloadID
}
// encodeOverloadID packs a function id and an overload index into one int64:
// fid occupies the upper 32 bits and overloadId the lower 32 bits.
//
// The overload index is converted through uint32 so that a negative value only
// fills the lower half; a plain int64 conversion would sign-extend it and overwrite
// the function id. For non-negative overload indexes the result is unchanged.
func encodeOverloadID(fid, overloadId int32) (overloadID int64) {
	overloadID = int64(fid) << 32
	overloadID |= int64(uint32(overloadId))
	return overloadID
}
// DecodeOverloadID splits an encoded overload id into its two halves:
// fid is the upper 32 bits and oIndex the lower 32 bits, each converted to int32.
func DecodeOverloadID(overloadID int64) (fid int32, oIndex int32) {
	return int32(overloadID >> 32), int32(overloadID)
}
// getFunctionIdByName looks name up in functionIdRegister and returns its id.
// For an unregistered name it returns -1 and an error from moerr.NewNotSupported.
func getFunctionIdByName(ctx context.Context, name string) (int32, error) {
	fid, found := functionIdRegister[name]
	if !found {
		return -1, moerr.NewNotSupported(ctx, "function or operator '%s'", name)
	}
	return fid, nil
}
// getFunctionIdByNameWithoutErr looks name up in functionIdRegister and returns the
// id together with whether the name is registered; no error is produced.
func getFunctionIdByNameWithoutErr(name string) (int32, bool) {
	var (
		fid    int32
		exists bool
	)
	fid, exists = functionIdRegister[name]
	return fid, exists
}
// FuncNew stores all information about a function,
// including the unique id that marks the function, the class which the function belongs to,
// and all overloads of the function.
type FuncNew struct {
// unique id of the function; also its index in allSupportedFunctions.
functionId int
// function type, tested as bit flags by testFlag.
class plan.Function_FuncFlag
// All overloads of the function.
Overloads []overload
// checkFn is used to check whether the input types match the requirements of the function.
// If matched, it returns the index of the corresponding overload. If a type conversion is required,
// the required types are returned at the same time.
checkFn func(overloads []overload, inputs []types.Type) checkResult
// layout is used for `explain SQL`.
layout FuncExplainLayout
}
// executeLogicOfOverload is the signature of an overload's execution routine.
// It receives the parameter vectors, the wrapper the result is written to,
// the process and the number of rows to evaluate, and returns an error on failure.
type executeLogicOfOverload func(parameters []*vector.Vector,
result vector.FunctionResultWrapper,
proc *process.Process, length int) error
// aggregationLogicOfOverload holds the aggregation-specific parts of an overload.
type aggregationLogicOfOverload struct {
// agg related string for error message.
// It is also the value returned by GetAggFunctionNameByID.
str string
// aggNew is used to create a new aggregation structure for the agg framework.
aggNew func(overloadID int64, dist bool, inputTypes []types.Type, outputType types.Type, config any, partialresult any) (agg.Agg[any], error)
}
// overload is one overload of a function.
// It stores all information about the execution logic.
type overload struct {
// NOTE(review): presumably the position of this overload inside FuncNew.Overloads — confirm.
overloadId int
// args records some type information about this overload.
// In most cases it records, in order, which parameter types the overload requires.
// For example,
// args can be `{int64, int64}` for one overload of the `pow` function,
// which means that overload can accept {int64, int64} as its input.
// It is not necessarily the exact type list required by the overload, though;
// what it means depends on the logic of the function's checkFn.
args []types.T
// retType computes the return type of the overload.
// parameters are the parameter types actually received when the overload is executed.
retType func(parameters []types.Type) types.Type
// newOp returns the execution logic.
newOp func() executeLogicOfOverload
// In fact, the function framework does not directly run aggregate functions and window functions.
// These two flags mark whether the overload is one of them.
isAgg bool
isWin bool
// aggFramework holds the constructor and the display string used for aggregate overloads.
aggFramework aggregationLogicOfOverload
// if true, the overload is unable to run in parallel.
// For example,
// rand(1) cannot run in parallel because it should use the same rand seed.
//
// TODO: there is not a good place to use that in plan now. the attribute is not effective.
cannotParallel bool
// if true, the overload cannot be folded.
volatile bool
// if realTimeRelated, the overload cannot be folded when `Prepare`.
realTimeRelated bool
}
// IsAgg returns the overload's isAgg flag.
func (ov *overload) IsAgg() bool {
	return ov.isAgg
}

// IsWin returns the overload's isWin flag.
func (ov *overload) IsWin() bool {
	return ov.isWin
}

// CannotFold reports whether the overload is volatile and therefore cannot be folded.
func (ov *overload) CannotFold() bool {
	return ov.volatile
}

// IsRealTimeRelated returns the overload's realTimeRelated flag.
func (ov *overload) IsRealTimeRelated() bool {
	return ov.realTimeRelated
}

// CannotExecuteInParallel returns the overload's cannotParallel flag.
func (ov *overload) CannotExecuteInParallel() bool {
	return ov.cannotParallel
}

// GetReturnTypeMethod returns the function that computes the overload's return type.
func (ov *overload) GetReturnTypeMethod() func(parameters []types.Type) types.Type {
	return ov.retType
}

// GetExecuteMethod calls the overload's newOp and returns the execution logic it produces.
func (ov *overload) GetExecuteMethod() executeLogicOfOverload {
	return ov.newOp()
}
// testFlag reports whether the function's class has any bit of funcFlag set.
func (fn *FuncNew) testFlag(funcFlag plan.Function_FuncFlag) bool {
	return (fn.class & funcFlag) != 0
}

// isAggregate reports whether the function carries the Function_AGG flag.
func (fn *FuncNew) isAggregate() bool {
	return fn.testFlag(plan.Function_AGG)
}

// isWindowOrder reports whether the function carries the Function_WIN_ORDER flag.
func (fn *FuncNew) isWindowOrder() bool {
	return fn.testFlag(plan.Function_WIN_ORDER)
}

// isWindow reports whether the function carries at least one of the
// Function_WIN_ORDER, Function_WIN_VALUE or Function_AGG flags.
func (fn *FuncNew) isWindow() bool {
	switch {
	case fn.testFlag(plan.Function_WIN_ORDER):
		return true
	case fn.testFlag(plan.Function_WIN_VALUE):
		return true
	default:
		return fn.testFlag(plan.Function_AGG)
	}
}

// isFunction reports whether the explain layout is STANDARD_FUNCTION or is not
// below NOPARAMETER_FUNCTION.
func (fn *FuncNew) isFunction() bool {
	if fn.layout == STANDARD_FUNCTION {
		return true
	}
	return fn.layout >= NOPARAMETER_FUNCTION
}
// overloadCheckSituation is the status a checkFn reports in its checkResult.
type overloadCheckSituation int
const (
// an overload matched the input types directly.
succeedMatched overloadCheckSituation = 0
// an overload matched, but the inputs must be cast to checkResult.finalType.
succeedWithCast overloadCheckSituation = -1
// no overload matched; GetFunctionByName reports it as an invalid function/operator argument.
failedFunctionParametersWrong overloadCheckSituation = -2
// no overload matched; GetFunctionByName reports it as an invalid aggregate function argument.
failedAggParametersWrong overloadCheckSituation = -3
// GetFunctionByName reports it as "too many overloads matched".
failedTooManyFunctionMatched overloadCheckSituation = -4
)
// checkResult is what a function's checkFn returns for a list of input types.
type checkResult struct {
// outcome of the check.
status overloadCheckSituation
// if matched, the index of the matched overload.
idx int
// types the inputs should be cast to; used when status is succeedWithCast.
finalType []types.Type
}
// newCheckResultWithSuccess builds the result for a direct match of the overload at index overloadId.
func newCheckResultWithSuccess(overloadId int) checkResult {
	var res checkResult
	res.status = succeedMatched
	res.idx = overloadId
	return res
}

// newCheckResultWithFailure builds a result that carries only the given status.
func newCheckResultWithFailure(status overloadCheckSituation) checkResult {
	var res checkResult
	res.status = status
	return res
}

// newCheckResultWithCast builds the result for a match of the overload at index
// overloadId that requires the inputs to be cast to castType.
func newCheckResultWithCast(overloadId int, castType []types.Type) checkResult {
	var res checkResult
	res.status = succeedWithCast
	res.idx = overloadId
	res.finalType = castType
	return res
}