From 71c8e32f3a5e90e047b54b2b9a866a3c9a574a47 Mon Sep 17 00:00:00 2001 From: Marcus Gartner Date: Wed, 20 Jul 2022 12:09:22 -0400 Subject: [PATCH] opt: build UDF expressions This commit adds basic support for building UDFs in optbuilder. Only scalar, nullary (arity of zero) functions with a single statement in the body are supported. Support for more types of UDFs will follow in future commits. Note that this commit does not add support for execution of UDFs, only building them within an optimizer expression. Release note: None --- pkg/sql/opt/memo/expr_format.go | 3 ++ pkg/sql/opt/norm/decorrelate_funcs.go | 4 ++ pkg/sql/opt/norm/testdata/rules/udf | 21 +++++++++ pkg/sql/opt/ops/scalar.opt | 15 ++++++ pkg/sql/opt/optbuilder/scalar.go | 34 ++++++++++++++ pkg/sql/opt/optbuilder/testdata/udf | 66 ++++++++++++++++++++++++++- 6 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 pkg/sql/opt/norm/testdata/rules/udf diff --git a/pkg/sql/opt/memo/expr_format.go b/pkg/sql/opt/memo/expr_format.go index 578df2821779..66b89e3e9665 100644 --- a/pkg/sql/opt/memo/expr_format.go +++ b/pkg/sql/opt/memo/expr_format.go @@ -1512,6 +1512,9 @@ func FormatPrivate(f *ExprFmtCtx, private interface{}, physProps *physical.Requi case *FunctionPrivate: fmt.Fprintf(f.Buffer, " %s", t.Name) + case *UserDefinedFunctionPrivate: + fmt.Fprintf(f.Buffer, " %s", t.Name) + case *WindowsItemPrivate: fmt.Fprintf(f.Buffer, " frame=%q", &t.Frame) diff --git a/pkg/sql/opt/norm/decorrelate_funcs.go b/pkg/sql/opt/norm/decorrelate_funcs.go index 746aa48606c8..3701732275e4 100644 --- a/pkg/sql/opt/norm/decorrelate_funcs.go +++ b/pkg/sql/opt/norm/decorrelate_funcs.go @@ -62,6 +62,10 @@ func (c *CustomFuncs) deriveHasHoistableSubquery(scalar opt.ScalarExpr) bool { // WHERE clause, it will be transformed to an Exists operator, so this case // only occurs when the Any is nested, in a projection, etc. return !t.Input.Relational().OuterCols.Empty() + + case *memo.UserDefinedFunctionExpr: + // Do not attempt to hoist UDFs. + return false } // If HasHoistableSubquery is true for any child, then it's true for this diff --git a/pkg/sql/opt/norm/testdata/rules/udf b/pkg/sql/opt/norm/testdata/rules/udf new file mode 100644 index 000000000000..8548518c233b --- /dev/null +++ b/pkg/sql/opt/norm/testdata/rules/udf @@ -0,0 +1,21 @@ +exec-ddl +CREATE FUNCTION one() RETURNS INT LANGUAGE SQL AS 'SELECT 1'; +---- + +# Do not attempt to hoist UDFs. +norm +SELECT one() +---- +values + ├── columns: one:2 + ├── cardinality: [1 - 1] + ├── key: () + ├── fd: ()-->(2) + └── tuple + └── user-defined-function: one + └── values + ├── columns: "?column?":1!null + ├── cardinality: [1 - 1] + ├── key: () + ├── fd: ()-->(1) + └── (1,) diff --git a/pkg/sql/opt/ops/scalar.opt b/pkg/sql/opt/ops/scalar.opt index dcf85083cc5e..7abbc35d6c81 100644 --- a/pkg/sql/opt/ops/scalar.opt +++ b/pkg/sql/opt/ops/scalar.opt @@ -1214,6 +1214,21 @@ define NthValue { Nth ScalarExpr } +# UserDefinedFunction invokes a user-defined function. The +# UserDefinedFunctionPrivate field contains the name of the function and a +# pointer to its type. +[Scalar] +define UserDefinedFunction { + Body RelExpr + _ UserDefinedFunctionPrivate +} + +[Private] +define UserDefinedFunctionPrivate { + Name string + Typ Type +} + # KVOptions is a set of KVOptionItems that specify arbitrary keys and values # that are used as modifiers for various statements (see tree.KVOptions). The # key is a constant string but the value can be a scalar expression. diff --git a/pkg/sql/opt/optbuilder/scalar.go b/pkg/sql/opt/optbuilder/scalar.go index c6b50b81571f..41b8eb0bd115 100644 --- a/pkg/sql/opt/optbuilder/scalar.go +++ b/pkg/sql/opt/optbuilder/scalar.go @@ -21,6 +21,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/opt/cat" "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" "github.com/cockroachdb/cockroach/pkg/sql/opt/norm" + "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/privilege" @@ -529,6 +530,10 @@ func (b *Builder) buildFunction( panic(err) } + if f.ResolvedOverload().Body != "" { + return b.buildUDF(f, def, inScope, outScope, outCol) + } + if isAggregate(def) { panic(errors.AssertionFailedf("aggregate function should have been replaced")) } @@ -583,6 +588,35 @@ func (b *Builder) buildFunction( return b.finishBuildScalar(f, out, inScope, outScope, outCol) } +// buildUDF builds a set of memo groups that represents a user-defined function +// invocation. +// TODO(mgartner): Support multi-statement UDFs. +// TODO(mgartner): Support UDFs with arguments. +func (b *Builder) buildUDF( + f *tree.FuncExpr, def *tree.FunctionDefinition, inScope, outScope *scope, outCol *scopeColumn, +) (out opt.ScalarExpr) { + stmt, err := parser.ParseOne(f.ResolvedOverload().Body) + if err != nil { + panic(err) + } + + // A statement inside a UDF body cannot refer to anything from the outer + // expression calling the function, so we use an empty scope. + // TODO(mgartner): We may need to set bodyScope.atRoot=true to prevent CTEs + // that mutate and are not at the top-level. + bodyScope := b.allocScope() + bodyScope = b.buildStmt(stmt.AST, nil /* desiredTypes */, bodyScope) + + out = b.factory.ConstructUserDefinedFunction( + bodyScope.expr, + &memo.UserDefinedFunctionPrivate{ + Name: def.Name, + Typ: f.ResolvedType(), + }, + ) + return b.finishBuildScalar(f, out, inScope, outScope, outCol) +} + // buildRangeCond builds a RANGE clause as a simpler expression. Examples: // x BETWEEN a AND b -> x >= a AND x <= b // x NOT BETWEEN a AND b -> NOT (x >= a AND x <= b) diff --git a/pkg/sql/opt/optbuilder/testdata/udf b/pkg/sql/opt/optbuilder/testdata/udf index 7d3ca7b03f86..a370e8ee357a 100644 --- a/pkg/sql/opt/optbuilder/testdata/udf +++ b/pkg/sql/opt/optbuilder/testdata/udf @@ -1,4 +1,68 @@ +exec-ddl +CREATE TABLE abc ( + a INT PRIMARY KEY, + b INT, + c INT +) +---- + build SELECT foo() ---- -error (42883): unknown function: foo() +error (42883): unknown function: foo + +exec-ddl +CREATE FUNCTION one() RETURNS INT LANGUAGE SQL AS 'SELECT 1'; +---- + +build +SELECT one() +---- +project + ├── columns: one:2 + ├── values + │ └── () + └── projections + └── user-defined-function: one [as=one:2] + └── project + ├── columns: "?column?":1!null + ├── values + │ └── () + └── projections + └── 1 [as="?column?":1] + +build +SELECT *, one() FROM abc +---- +project + ├── columns: a:1!null b:2 c:3 one:7 + ├── scan abc + │ └── columns: a:1!null b:2 c:3 crdb_internal_mvcc_timestamp:4 tableoid:5 + └── projections + └── user-defined-function: one [as=one:7] + └── project + ├── columns: "?column?":6!null + ├── values + │ └── () + └── projections + └── 1 [as="?column?":6] + +build +SELECT * FROM abc WHERE one() = c +---- +project + ├── columns: a:1!null b:2 c:3 + └── select + ├── columns: a:1!null b:2 c:3 crdb_internal_mvcc_timestamp:4 tableoid:5 + ├── scan abc + │ └── columns: a:1!null b:2 c:3 crdb_internal_mvcc_timestamp:4 tableoid:5 + └── filters + └── eq + ├── user-defined-function: one + │ └── project + │ ├── columns: "?column?":6!null + │ ├── values + │ │ └── () + │ └── projections + │ └── 1 [as="?column?":6] + └── c:3