From e0f9c2827889ccd5350ef327365cc6cb86136cb5 Mon Sep 17 00:00:00 2001 From: rharding6373 Date: Tue, 5 Dec 2023 22:15:49 -0800 Subject: [PATCH] sql: support sequence and udt name rewriting in plpgsql CRDB rewrites sequence and UDT names as IDs in views and functions so that if the sequence or UDT is renamed the views and functions using them don't break. This PR adds support for this in PLpgSQL. Epic: None Fixes: #115627 Release note (sql change): Fixes a bug in PLpgSQL where altering the name of a sequence or UDT that was used in a PLpgSQL function or procedure could break them. This is only present in 23.2 alpha and beta releases. --- .../testdata/logic_test/udf_rewrite | 189 ++++++-- pkg/sql/BUILD.bazel | 2 + pkg/sql/crdb_internal.go | 31 +- pkg/sql/create_function.go | 15 +- pkg/sql/create_view.go | 186 ++++++-- pkg/sql/opt/optbuilder/plpgsql.go | 7 +- pkg/sql/plpgsql/parser/testdata/stmt_case | 1 + pkg/sql/sem/plpgsqltree/BUILD.bazel | 1 + pkg/sql/sem/plpgsqltree/exception.go | 23 +- pkg/sql/sem/plpgsqltree/statements.go | 427 +++++++++++++----- pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go | 270 ++++++++++- pkg/sql/sem/plpgsqltree/visitor.go | 7 +- pkg/sql/show_create_clauses.go | 159 +++++-- 13 files changed, 1045 insertions(+), 273 deletions(-) diff --git a/pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite b/pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite index d6579e3df3be..d5ffc7460cd5 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite +++ b/pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite @@ -37,34 +37,94 @@ SET use_declarative_schema_changer = 'on' subtest rewrite_plpgsql -statement ok -DROP FUNCTION IF EXISTS f_rewrite - statement ok CREATE OR REPLACE FUNCTION f_rewrite() RETURNS INT AS $$ + DECLARE + i INT := nextval('seq'); + j INT := nextval('seq'); + curs REFCURSOR := nextval('seq')::STRING; + curs2 CURSOR FOR SELECT nextval('seq'); BEGIN - SELECT nextval('seq'); + RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq')); + RAISE NOTICE 'val1: %, val2: %', nextval('seq'), nextval('seq'); + WHILE nextval('seq') < 10 LOOP + i = nextval('seq'); + SELECT nextval('seq'); + IF nextval('seq') = 1 THEN + SELECT nextval('seq'); + SELECT nextval('seq'); + CONTINUE; + ELSIF nextval('seq') = 2 THEN + SELECT v INTO i FROM nextval('seq') AS v(INT); + ELSIF nextval('seq') = 3 THEN + SELECT nextval('seq'); + SELECT nextval('seq'); + END IF; + END LOOP; + OPEN curs FOR SELECT nextval('seq'); + RETURN nextval('seq'); + EXCEPTION + WHEN division_by_zero THEN + RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq')); + WHEN not_null_violation THEN + SELECT nextval('seq'); + SELECT nextval('seq'); + RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq')); END -$$ LANGUAGE PLPGSQL +$$ LANGUAGE PLPGSQL; query T SELECT get_body_str('f_rewrite'); ---- -"BEGIN\nSELECT nextval('seq':::STRING);\nEND\n;" - -statement ok -CREATE OR REPLACE FUNCTION f_rewrite() RETURNS INT AS -$$ - BEGIN - INSERT INTO t_rewrite(v) VALUES (nextval('seq')) RETURNING v; - END -$$ LANGUAGE PLPGSQL +"DECLARE\ni INT8 := nextval(106:::REGCLASS);\nj INT8 := nextval(106:::REGCLASS);\ncurs REFCURSOR := nextval(106:::REGCLASS)::STRING;\ncurs2 CURSOR FOR SELECT nextval(106:::REGCLASS);\nBEGIN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nRAISE notice 'val1: %, val2: %', nextval(106:::REGCLASS), nextval(106:::REGCLASS);\nWHILE nextval(106:::REGCLASS) < 10:::INT8 LOOP\ni := nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nIF nextval(106:::REGCLASS) = 1:::INT8 THEN\n\tSELECT nextval(106:::REGCLASS);\n\tSELECT nextval(106:::REGCLASS);\n\tCONTINUE;\nELSIF nextval(106:::REGCLASS) = 2:::INT8 THEN\n\tSELECT v FROM ROWS FROM (nextval(106:::REGCLASS)) AS v (\"int\") INTO i;\nELSIF nextval(106:::REGCLASS) = 3:::INT8 THEN\n\tSELECT nextval(106:::REGCLASS);\n\tSELECT nextval(106:::REGCLASS);\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT nextval(106:::REGCLASS);\nRETURN nextval(106:::REGCLASS);\nEXCEPTION\nWHEN division_by_zero THEN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nWHEN not_null_violation THEN\nSELECT nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nEND\n;" -query T -SELECT get_body_str('f_rewrite'); +query TT +SHOW CREATE FUNCTION f_rewrite; ---- -"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval('seq':::STRING)) RETURNING v;\nEND\n;" +f_rewrite CREATE FUNCTION public.f_rewrite() + RETURNS INT8 + VOLATILE + NOT LEAKPROOF + CALLED ON NULL INPUT + LANGUAGE plpgsql + AS $$ + DECLARE + i INT8 := nextval('public.seq'::REGCLASS); + j INT8 := nextval('public.seq'::REGCLASS); + curs REFCURSOR := nextval('public.seq'::REGCLASS)::STRING; + curs2 CURSOR FOR SELECT nextval('public.seq'::REGCLASS); + BEGIN + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS)); + RAISE notice 'val1: %, val2: %', nextval('public.seq'::REGCLASS), nextval('public.seq'::REGCLASS); + WHILE nextval('public.seq'::REGCLASS) < 10:::INT8 LOOP + i := nextval('public.seq'::REGCLASS); + SELECT nextval('public.seq'::REGCLASS); + IF nextval('public.seq'::REGCLASS) = 1:::INT8 THEN + SELECT nextval('public.seq'::REGCLASS); + SELECT nextval('public.seq'::REGCLASS); + CONTINUE; + ELSIF nextval('public.seq'::REGCLASS) = 2:::INT8 THEN + SELECT v FROM ROWS FROM (nextval('public.seq'::REGCLASS)) AS v ("int") INTO i; + ELSIF nextval('public.seq'::REGCLASS) = 3:::INT8 THEN + SELECT nextval('public.seq'::REGCLASS); + SELECT nextval('public.seq'::REGCLASS); + END IF; + END LOOP; + OPEN curs FOR SELECT nextval('public.seq'::REGCLASS); + RETURN nextval('public.seq'::REGCLASS); + EXCEPTION + WHEN division_by_zero THEN + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS)); + WHEN not_null_violation THEN + SELECT nextval('public.seq'::REGCLASS); + SELECT nextval('public.seq'::REGCLASS); + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS)); + END + $$ statement ok DROP FUNCTION f_rewrite(); @@ -72,28 +132,91 @@ DROP FUNCTION f_rewrite(); statement ok CREATE OR REPLACE FUNCTION f_rewrite() RETURNS weekday AS $$ + DECLARE + day weekday := 'wednesday'::weekday; + today weekday := 'thursday'::weekday; + curs REFCURSOR := 'monday'::weekday::STRING; + curs2 CURSOR FOR SELECT 'tuesday'::weekday; BEGIN - SELECT 'wednesday'::weekday; + RAISE NOTICE USING MESSAGE = format('val: %d','wednesday'::weekday); + RAISE NOTICE 'val1: %, val2: %', 'wednesday'::weekday, 'thursday'::weekday; + WHILE day != 'wednesday'::weekday LOOP + day = 'friday'::weekday; + SELECT 'wednesday'::weekday; + IF day = 'wednesday'::weekday THEN + day = 'thursday'::weekday; + SELECT 'tuesday'::weekday; + CONTINUE; + ELSIF day = 'monday'::weekday THEN + SELECT 'tuesday'::weekday INTO day; + ELSIF day = 'tuesday'::weekday THEN + SELECT 'wednesday'::weekday INTO day; + SELECT 'wednesday'::weekday; + END IF; + END LOOP; + OPEN curs FOR SELECT 'wednesday'::weekday; + RETURN 'wednesday'::weekday; + EXCEPTION + WHEN division_by_zero THEN + RAISE NOTICE USING MESSAGE = format('val: %d','wednesday'::weekday); + WHEN not_null_violation THEN + SELECT 'wednesday'::weekday; + RAISE NOTICE 'val: %', 'wednesday'::weekday; END -$$ LANGUAGE PLPGSQL +$$ LANGUAGE PLPGSQL; query T SELECT get_body_str('f_rewrite'); ---- -"BEGIN\nSELECT 'wednesday'::@100107;\nEND\n;" +"DECLARE\nday @100107 := b'\\x80':::@100107;\ntoday @100107 := b'\\xa0':::@100107;\ncurs REFCURSOR := b' ':::@100107::STRING;\ncurs2 CURSOR FOR SELECT b'@':::@100107;\nBEGIN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nRAISE notice 'val1: %, val2: %', b'\\x80':::@100107, b'\\xa0':::@100107;\nWHILE day != b'\\x80':::@100107 LOOP\nday := b'\\xc0':::@100107;\nSELECT b'\\x80':::@100107;\nIF day = b'\\x80':::@100107 THEN\n\tday := b'\\xa0':::@100107;\n\tSELECT b'@':::@100107;\n\tCONTINUE;\nELSIF day = b' ':::@100107 THEN\n\tSELECT b'@':::@100107 INTO day;\nELSIF day = b'@':::@100107 THEN\n\tSELECT b'\\x80':::@100107 INTO day;\n\tSELECT b'\\x80':::@100107;\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT b'\\x80':::@100107;\nRETURN b'\\x80':::@100107;\nEXCEPTION\nWHEN division_by_zero THEN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nWHEN not_null_violation THEN\nSELECT b'\\x80':::@100107;\nRAISE notice 'val: %', b'\\x80':::@100107;\nEND\n;" -statement ok -CREATE OR REPLACE FUNCTION f_rewrite() RETURNS weekday AS -$$ - BEGIN - UPDATE t_rewrite SET w = 'thursday'::weekday WHERE w = 'wednesday'::weekday RETURNING w; - END -$$ LANGUAGE PLPGSQL - -query T -SELECT get_body_str('f_rewrite'); +query TT +SHOW CREATE FUNCTION f_rewrite; ---- -"BEGIN\nUPDATE test.public.t_rewrite SET w = 'thursday'::@100107 WHERE w = 'wednesday'::@100107 RETURNING w;\nEND\n;" +f_rewrite CREATE FUNCTION public.f_rewrite() + RETURNS test.public.weekday + VOLATILE + NOT LEAKPROOF + CALLED ON NULL INPUT + LANGUAGE plpgsql + AS $$ + DECLARE + day test.public.weekday := 'wednesday':::test.public.weekday; + today test.public.weekday := 'thursday':::test.public.weekday; + curs REFCURSOR := 'monday':::test.public.weekday::STRING; + curs2 CURSOR FOR SELECT 'tuesday':::test.public.weekday; + BEGIN + RAISE notice + USING MESSAGE = format('val: %d':::STRING, 'wednesday':::test.public.weekday); + RAISE notice 'val1: %, val2: %', 'wednesday':::test.public.weekday, 'thursday':::test.public.weekday; + WHILE day != 'wednesday':::test.public.weekday LOOP + day := 'friday':::test.public.weekday; + SELECT 'wednesday':::test.public.weekday; + IF day = 'wednesday':::test.public.weekday THEN + day := 'thursday':::test.public.weekday; + SELECT 'tuesday':::test.public.weekday; + CONTINUE; + ELSIF day = 'monday':::test.public.weekday THEN + SELECT 'tuesday':::test.public.weekday INTO day; + ELSIF day = 'tuesday':::test.public.weekday THEN + SELECT 'wednesday':::test.public.weekday INTO day; + SELECT 'wednesday':::test.public.weekday; + END IF; + END LOOP; + OPEN curs FOR SELECT 'wednesday':::test.public.weekday; + RETURN 'wednesday':::test.public.weekday; + EXCEPTION + WHEN division_by_zero THEN + RAISE notice + USING MESSAGE = format('val: %d':::STRING, 'wednesday':::test.public.weekday); + WHEN not_null_violation THEN + SELECT 'wednesday':::test.public.weekday; + RAISE notice 'val: %', 'wednesday':::test.public.weekday; + END + $$ + +statement ok +DROP FUNCTION f_rewrite(); subtest end @@ -110,7 +233,7 @@ $$ LANGUAGE PLPGSQL query T SELECT get_body_str('p_rewrite'); ---- -"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval('seq':::STRING)) RETURNING v;\nEND\n;" +"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval(106:::REGCLASS)) RETURNING v;\nEND\n;" statement ok DROP PROCEDURE p_rewrite(); @@ -126,6 +249,6 @@ $$ LANGUAGE PLPGSQL query T SELECT get_body_str('p_rewrite'); ---- -"BEGIN\nUPDATE test.public.t_rewrite SET w = 'thursday'::@100107 WHERE w = 'wednesday'::@100107 RETURNING w;\nEND\n;" +"BEGIN\nUPDATE test.public.t_rewrite SET w = b'\\xa0':::@100107 WHERE w = b'\\x80':::@100107 RETURNING w;\nEND\n;" subtest end diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index 0f9455c4814f..af74183a1a38 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -491,6 +491,8 @@ go_library( "//pkg/sql/sem/catconstants", "//pkg/sql/sem/catid", "//pkg/sql/sem/eval", + "//pkg/sql/sem/plpgsqltree", + "//pkg/sql/sem/plpgsqltree/utils", "//pkg/sql/sem/semenumpb", "//pkg/sql/sem/transform", "//pkg/sql/sem/tree", diff --git a/pkg/sql/crdb_internal.go b/pkg/sql/crdb_internal.go index fc8e2213ca0d..b70f90623475 100644 --- a/pkg/sql/crdb_internal.go +++ b/pkg/sql/crdb_internal.go @@ -68,7 +68,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" - plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/protoreflect" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" @@ -3579,29 +3578,15 @@ func createRoutinePopulate( for i := range treeNode.Options { if body, ok := treeNode.Options[i].(tree.RoutineBodyStr); ok { bodyStr := string(body) - switch fnDesc.GetLanguage() { - case catpb.Function_SQL: - bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr) - if err != nil { - return err - } - bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */) - if err != nil { - return err - } - bodyStr = "\n" + bodyStr + "\n" - case catpb.Function_PLPGSQL: - // TODO(drewk): integrate this with the SQL case above. - plpgsqlStmt, err := plpgsql.Parse(bodyStr) - if err != nil { - return err - } - fmtCtx := tree.NewFmtCtx(tree.FmtParsable) - fmtCtx.FormatNode(plpgsqlStmt.AST) - bodyStr = "\n" + fmtCtx.CloseAndGetString() - default: - return errors.AssertionFailedf("unexpected function language: %s", fnDesc.GetLanguage()) + bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr, fnDesc.GetLanguage()) + if err != nil { + return err + } + bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */, fnDesc.GetLanguage()) + if err != nil { + return err } + bodyStr = "\n" + bodyStr + "\n" stmtStrs := strings.Split(bodyStr, "\n") for i := range stmtStrs { if stmtStrs[i] != "" { diff --git a/pkg/sql/create_function.go b/pkg/sql/create_function.go index 4d2fe98b81c6..193cc7add44e 100644 --- a/pkg/sql/create_function.go +++ b/pkg/sql/create_function.go @@ -455,26 +455,19 @@ func setFuncOptions( } } - switch lang { - case catpb.Function_SQL: + if lang != catpb.Function_UNKNOWN_LANGUAGE && body != "" { // Replace any sequence names in the function body with IDs. - seqReplacedFuncBody, err := replaceSeqNamesWithIDs(params.ctx, params.p, body, true) + seqReplacedFuncBody, err := replaceSeqNamesWithIDsLang(params.ctx, params.p, body, true, lang) if err != nil { return err } - typeReplacedFuncBody, err := serializeUserDefinedTypes( - params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", - ) + typeReplacedFuncBody, err := serializeUserDefinedTypesLang( + params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", lang) if err != nil { return err } udfDesc.SetFuncBody(typeReplacedFuncBody) - case catpb.Function_PLPGSQL: - // TODO(#115627): make replaceSeqNamesWithIDs and serializeUserDefinedTypes - // play nice with PL/pgSQL. - udfDesc.SetFuncBody(body) } - return nil } diff --git a/pkg/sql/create_view.go b/pkg/sql/create_view.go index 888939e475b5..a1765c0c8fc5 100644 --- a/pkg/sql/create_view.go +++ b/pkg/sql/create_view.go @@ -30,8 +30,11 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" + plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" + "github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree" + "github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree/utils" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" @@ -425,14 +428,30 @@ func makeViewTableDesc( return desc, nil } -// replaceSeqNamesWithIDs prepares to walk the given viewQuery by defining the +// replaceSeqNamesWithIDs prepares to walk the given query by defining the // function used to replace sequence names with IDs, and parsing the -// viewQuery into a statement. +// queryStr into a statement. // TODO (Chengxiong): move this to a better place. func replaceSeqNamesWithIDs( ctx context.Context, sc resolver.SchemaResolver, queryStr string, multiStmt bool, +) (string, error) { + return replaceSeqNamesWithIDsLang(ctx, sc, queryStr, multiStmt, catpb.Function_SQL) +} + +// replaceSeqNamesWithIDsLang prepares to walk the given query by defining the +// function used to replace sequence names with IDs, and parsing the +// queryStr into a statement. +func replaceSeqNamesWithIDsLang( + ctx context.Context, + sc resolver.SchemaResolver, + queryStr string, + multiStmt bool, + lang catpb.Function_Language, ) (string, error) { replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { + if expr == nil { + return false, expr, nil + } seqIdentifiers, err := seqexpr.GetUsedSequences(expr) if err != nil { return false, expr, err @@ -452,48 +471,77 @@ func replaceSeqNamesWithIDs( return false, newExpr, nil } - var stmts tree.Statements - if multiStmt { - parsedStmtd, err := parser.Parse(queryStr) - if err != nil { - return "", errors.Wrap(err, "failed to parse query string") + fmtCtx := tree.NewFmtCtx(tree.FmtSimple) + switch lang { + case catpb.Function_SQL: + var stmts tree.Statements + if multiStmt { + parsedStmtd, err := parser.Parse(queryStr) + if err != nil { + return "", errors.Wrap(err, "failed to parse query string") + } + for _, s := range parsedStmtd { + stmts = append(stmts, s.AST) + } + } else { + stmt, err := parser.ParseOne(queryStr) + if err != nil { + return "", errors.Wrap(err, "failed to parse query string") + } + stmts = tree.Statements{stmt.AST} } - for _, s := range parsedStmtd { - stmts = append(stmts, s.AST) + + for i, stmt := range stmts { + newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc) + if err != nil { + return "", err + } + if i > 0 { + fmtCtx.WriteString("\n") + } + fmtCtx.FormatNode(newStmt) + if multiStmt { + fmtCtx.WriteString(";") + } } - } else { - stmt, err := parser.ParseOne(queryStr) + case catpb.Function_PLPGSQL: + var stmts plpgsqltree.Statement + plstmt, err := plpgsql.Parse(queryStr) if err != nil { return "", errors.Wrap(err, "failed to parse query string") } - stmts = tree.Statements{stmt.AST} - } + stmts = plstmt.AST - fmtCtx := tree.NewFmtCtx(tree.FmtSimple) - for i, stmt := range stmts { - newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc) - if err != nil { - return "", err - } - if i > 0 { - fmtCtx.WriteString("\n") - } + v := utils.SQLStmtVisitor{Fn: replaceSeqFunc} + newStmt := plpgsqltree.Walk(&v, stmts) fmtCtx.FormatNode(newStmt) - if multiStmt { - fmtCtx.WriteString(";") - } } return fmtCtx.String(), nil } -// serializeUserDefinedTypes will walk the given view query -// and serialize any user defined types, so that renaming the type -// does not corrupt the view. +// serializeUserDefinedTypes will walk the given query and serialize any user +// defined types, so that renaming the type does not cause corruption. func serializeUserDefinedTypes( ctx context.Context, semaCtx *tree.SemaContext, queries string, multiStmt bool, parentType string, +) (string, error) { + return serializeUserDefinedTypesLang(ctx, semaCtx, queries, multiStmt, parentType, catpb.Function_SQL) +} + +// serializeUserDefinedTypesLang will walk the given query and serialize any +// user defined types, so that renaming the type does not cause corruption. +func serializeUserDefinedTypesLang( + ctx context.Context, + semaCtx *tree.SemaContext, + queries string, + multiStmt bool, + parentType string, + lang catpb.Function_Language, ) (string, error) { replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { + if expr == nil { + return false, expr, nil + } var innerExpr tree.Expr var typRef tree.ResolvableTypeReference switch n := expr.(type) { @@ -539,39 +587,77 @@ func serializeUserDefinedTypes( } return false, parsedExpr, nil } - - var stmts tree.Statements - if multiStmt { - parsedStmts, err := parser.Parse(queries) - if err != nil { - return "", errors.Wrap(err, "failed to parse query") + replaceTypeFunc := func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error) { + if typ == nil { + return typ, nil } - stmts = make(tree.Statements, len(parsedStmts)) - for i, stmt := range parsedStmts { - stmts[i] = stmt.AST + // semaCtx may be nil if this is a virtual view being created at + // init time. + var typeResolver tree.TypeReferenceResolver + if semaCtx != nil { + typeResolver = semaCtx.TypeResolver } - } else { - stmt, err := parser.ParseOne(queries) + var t *types.T + t, err = tree.ResolveType(ctx, typ, typeResolver) if err != nil { - return "", errors.Wrap(err, "failed to parse query") + return typ, err + } + if !t.UserDefined() { + return typ, nil } - stmts = tree.Statements{stmt.AST} + return &tree.OIDTypeReference{OID: t.Oid()}, nil } fmtCtx := tree.NewFmtCtx(tree.FmtSimple) - for i, stmt := range stmts { - newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) - if err != nil { - return "", err + switch lang { + case catpb.Function_SQL: + var stmts tree.Statements + if multiStmt { + parsedStmts, err := parser.Parse(queries) + if err != nil { + return "", errors.Wrap(err, "failed to parse query") + } + stmts = make(tree.Statements, len(parsedStmts)) + for i, stmt := range parsedStmts { + stmts[i] = stmt.AST + } + } else { + stmt, err := parser.ParseOne(queries) + if err != nil { + return "", errors.Wrap(err, "failed to parse query") + } + stmts = tree.Statements{stmt.AST} } - if i > 0 { - fmtCtx.WriteString("\n") + + for i, stmt := range stmts { + newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) + if err != nil { + return "", err + } + if i > 0 { + fmtCtx.WriteString("\n") + } + fmtCtx.FormatNode(newStmt) + if multiStmt { + fmtCtx.WriteString(";") + } } - fmtCtx.FormatNode(newStmt) - if multiStmt { - fmtCtx.WriteString(";") + case catpb.Function_PLPGSQL: + var stmts plpgsqltree.Statement + plstmt, err := plpgsql.Parse(queries) + if err != nil { + return "", errors.Wrap(err, "failed to parse query string") } + stmts = plstmt.AST + + v := utils.SQLStmtVisitor{Fn: replaceFunc} + newStmt := plpgsqltree.Walk(&v, stmts) + v2 := utils.TypeRefVisitor{Fn: replaceTypeFunc} + newStmt = plpgsqltree.Walk(&v2, newStmt) + fmtCtx.FormatNode(newStmt) + fmtCtx.WriteString(";") } + return fmtCtx.CloseAndGetString(), nil } diff --git a/pkg/sql/opt/optbuilder/plpgsql.go b/pkg/sql/opt/optbuilder/plpgsql.go index 002033abbef0..ffda0efedd37 100644 --- a/pkg/sql/opt/optbuilder/plpgsql.go +++ b/pkg/sql/opt/optbuilder/plpgsql.go @@ -1532,7 +1532,7 @@ func newRecordTypeVisitor( var _ ast.StatementVisitor = &recordTypeVisitor{} -func (r *recordTypeVisitor) Visit(stmt ast.Statement) { +func (r *recordTypeVisitor) Visit(stmt ast.Statement) (newStmt ast.Statement, changed bool) { if retStmt, ok := stmt.(*ast.Return); ok { desired := types.Any if r.typ != types.Unknown { @@ -1545,7 +1545,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) { } typ := typedExpr.ResolvedType() if typ == types.Unknown { - return + return stmt, false } if typ.Family() != types.TupleFamily { panic(pgerror.New(pgcode.DatatypeMismatch, @@ -1554,7 +1554,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) { } if r.typ == types.Unknown { r.typ = typ - return + return stmt, false } if !typ.Identical(r.typ) { panic(errors.WithHint( @@ -1565,4 +1565,5 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) { )) } } + return stmt, false } diff --git a/pkg/sql/plpgsql/parser/testdata/stmt_case b/pkg/sql/plpgsql/parser/testdata/stmt_case index e8da2026e9d9..adfd3d304de1 100644 --- a/pkg/sql/plpgsql/parser/testdata/stmt_case +++ b/pkg/sql/plpgsql/parser/testdata/stmt_case @@ -133,6 +133,7 @@ BEGIN END CASE; END ---- +decl_stmt: 1 stmt_block: 1 stmt_call: 3 stmt_case: 1 diff --git a/pkg/sql/sem/plpgsqltree/BUILD.bazel b/pkg/sql/sem/plpgsqltree/BUILD.bazel index d51ee48b3cc0..4548391ad4aa 100644 --- a/pkg/sql/sem/plpgsqltree/BUILD.bazel +++ b/pkg/sql/sem/plpgsqltree/BUILD.bazel @@ -13,6 +13,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sql/sem/tree", + "//pkg/util/errorutil/unimplemented", "@com_github_cockroachdb_errors//:errors", ], ) diff --git a/pkg/sql/sem/plpgsqltree/exception.go b/pkg/sql/sem/plpgsqltree/exception.go index 7b77216598c2..9dbc83df8a8f 100644 --- a/pkg/sql/sem/plpgsqltree/exception.go +++ b/pkg/sql/sem/plpgsqltree/exception.go @@ -22,6 +22,13 @@ type Exception struct { Action []Statement } +func (s *Exception) CopyNode() *Exception { + copyNode := *s + copyNode.Conditions = append([]Condition(nil), copyNode.Conditions...) + copyNode.Action = append([]Statement(nil), copyNode.Action...) + return ©Node +} + func (s *Exception) Format(ctx *tree.FmtCtx) { ctx.WriteString("WHEN ") for i, cond := range s.Conditions { @@ -44,11 +51,19 @@ func (s *Exception) PlpgSQLStatementTag() string { return "proc_exception" } -func (s *Exception) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Action { - stmt.WalkStmt(visitor) +func (s *Exception) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + for i, stmt := range s.Action { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Exception).Action[i] = ns + } } + return newStmt, changed } type Condition struct { diff --git a/pkg/sql/sem/plpgsqltree/statements.go b/pkg/sql/sem/plpgsqltree/statements.go index 9f747fc2a4eb..2f4ad0cff442 100644 --- a/pkg/sql/sem/plpgsqltree/statements.go +++ b/pkg/sql/sem/plpgsqltree/statements.go @@ -16,6 +16,7 @@ import ( "strings" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" ) type Expr = tree.Expr @@ -25,7 +26,7 @@ type Statement interface { GetLineNo() int GetStmtID() uint plpgsqlStmt() - WalkStmt(StatementVisitor) + WalkStmt(StatementVisitor) (newStmt Statement, changed bool) } type TaggedStatement interface { @@ -63,6 +64,14 @@ type Block struct { Exceptions []Exception } +func (s *Block) CopyNode() *Block { + copyNode := *s + copyNode.Decls = append([]Statement(nil), copyNode.Decls...) + copyNode.Body = append([]Statement(nil), copyNode.Body...) + copyNode.Exceptions = append([]Exception(nil), copyNode.Exceptions...) + return ©Node +} + // TODO(drewk): format Label and Exceptions fields. func (s *Block) Format(ctx *tree.FmtCtx) { if s.Decls != nil { @@ -90,11 +99,39 @@ func (s *Block) PlpgSQLStatementTag() string { return "stmt_block" } -func (s *Block) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) +func (s *Block) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + for i, stmt := range s.Decls { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Block).Decls[i] = ns + } + } + for i, stmt := range s.Body { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Block).Body[i] = ns + } + } + for i, stmt := range s.Exceptions { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Block).Exceptions[i] = *(ns.(*Exception)) + } } + return newStmt, changed } // decl_stmt @@ -108,6 +145,11 @@ type Declaration struct { Expr Expr } +func (s *Declaration) CopyNode() *Declaration { + copyNode := *s + return ©Node +} + func (s *Declaration) Format(ctx *tree.FmtCtx) { ctx.WriteString(string(s.Var)) if s.Constant { @@ -131,8 +173,9 @@ func (s *Declaration) PlpgSQLStatementTag() string { return "decl_stmt" } -func (s *Declaration) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Declaration) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } type CursorDeclaration struct { @@ -142,6 +185,11 @@ type CursorDeclaration struct { Query tree.Statement } +func (s *CursorDeclaration) CopyNode() *CursorDeclaration { + copyNode := *s + return ©Node +} + func (s *CursorDeclaration) Format(ctx *tree.FmtCtx) { ctx.WriteString(string(s.Name)) switch s.Scroll { @@ -159,8 +207,9 @@ func (s *CursorDeclaration) PlpgSQLStatementTag() string { return "decl_cursor_stmt" } -func (s *CursorDeclaration) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *CursorDeclaration) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_assign @@ -170,6 +219,11 @@ type Assignment struct { Value Expr } +func (s *Assignment) CopyNode() *Assignment { + copyNode := *s + return ©Node +} + func (s *Assignment) PlpgSQLStatementTag() string { return "stmt_assign" } @@ -178,8 +232,9 @@ func (s *Assignment) Format(ctx *tree.FmtCtx) { ctx.WriteString(fmt.Sprintf("%s := %s;\n", s.Var, s.Value)) } -func (s *Assignment) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Assignment) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_if @@ -191,6 +246,18 @@ type If struct { ElseBody []Statement } +func (s *If) CopyNode() *If { + copyNode := *s + copyNode.ThenBody = append([]Statement(nil), copyNode.ThenBody...) + copyNode.ElseBody = append([]Statement(nil), copyNode.ElseBody...) + copyNode.ElseIfList = make([]ElseIf, len(s.ElseIfList)) + for i, ei := range s.ElseIfList { + copyNode.ElseIfList[i] = ei + copyNode.ElseIfList[i].Stmts = append([]Statement(nil), copyNode.ElseIfList[i].Stmts...) + } + return ©Node +} + func (s *If) Format(ctx *tree.FmtCtx) { ctx.WriteString("IF ") s.Condition.Format(ctx) @@ -217,21 +284,43 @@ func (s *If) PlpgSQLStatementTag() string { return "stmt_if" } -func (s *If) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *If) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) - for _, thenStmt := range s.ThenBody { - thenStmt.WalkStmt(visitor) + for i, thenStmt := range s.ThenBody { + ns, ch := thenStmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*If).ThenBody[i] = ns + } } - for _, elseIf := range s.ElseIfList { - elseIf.WalkStmt(visitor) + for i, elseIf := range s.ElseIfList { + ns, ch := elseIf.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*If).ElseIfList[i] = *ns.(*ElseIf) + } } - for _, elseStmt := range s.ElseBody { - elseStmt.WalkStmt(visitor) + for i, elseStmt := range s.ElseBody { + ns, ch := elseStmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*If).ElseBody[i] = ns + } } + return newStmt, changed } type ElseIf struct { @@ -240,6 +329,12 @@ type ElseIf struct { Stmts []Statement } +func (s *ElseIf) CopyNode() *ElseIf { + copyNode := *s + copyNode.Stmts = append([]Statement(nil), copyNode.Stmts...) + return ©Node +} + func (s *ElseIf) Format(ctx *tree.FmtCtx) { ctx.WriteString("ELSIF ") s.Condition.Format(ctx) @@ -254,12 +349,20 @@ func (s *ElseIf) PlpgSQLStatementTag() string { return "stmt_if_else_if" } -func (s *ElseIf) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *ElseIf) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) - for _, stmt := range s.Stmts { - stmt.WalkStmt(visitor) + for i, stmt := range s.Stmts { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*ElseIf).Stmts[i] = ns + } } + return newStmt, changed } // stmt_case @@ -273,6 +376,18 @@ type Case struct { ElseStmts []Statement } +func (s *Case) CopyNode() *Case { + copyNode := *s + copyNode.ElseStmts = append([]Statement(nil), copyNode.ElseStmts...) + copyNode.CaseWhenList = make([]*CaseWhen, len(s.CaseWhenList)) + caseWhens := make([]CaseWhen, len(s.CaseWhenList)) + for i, cw := range s.CaseWhenList { + caseWhens[i] = *cw + copyNode.CaseWhenList[i] = &caseWhens[i] + } + return ©Node +} + // TODO(drewk): fix the whitespace/newline formatting for CASE (see the // stmt_case test file). func (s *Case) Format(ctx *tree.FmtCtx) { @@ -298,18 +413,33 @@ func (s *Case) PlpgSQLStatementTag() string { return "stmt_case" } -func (s *Case) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Case) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) - for _, when := range s.CaseWhenList { - when.WalkStmt(visitor) + for i, when := range s.CaseWhenList { + ns, ch := when.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Case).CaseWhenList[i] = ns.(*CaseWhen) + } } if s.HaveElse { - for _, stmt := range s.ElseStmts { - stmt.WalkStmt(visitor) + for i, stmt := range s.ElseStmts { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Case).ElseStmts[i] = ns + } } } + return newStmt, changed } type CaseWhen struct { @@ -319,6 +449,12 @@ type CaseWhen struct { Stmts []Statement } +func (s *CaseWhen) CopyNode() *CaseWhen { + copyNode := *s + copyNode.Stmts = append([]Statement(nil), copyNode.Stmts...) + return ©Node +} + func (s *CaseWhen) Format(ctx *tree.FmtCtx) { ctx.WriteString(fmt.Sprintf("WHEN %s THEN\n", s.Expr)) for i, stmt := range s.Stmts { @@ -334,12 +470,20 @@ func (s *CaseWhen) PlpgSQLStatementTag() string { return "stmt_when" } -func (s *CaseWhen) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *CaseWhen) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) - for _, stmt := range s.Stmts { - stmt.WalkStmt(visitor) + for i, stmt := range s.Stmts { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*CaseWhen).Stmts[i] = ns + } } + return newStmt, changed } // stmt_loop @@ -349,6 +493,12 @@ type Loop struct { Body []Statement } +func (s *Loop) CopyNode() *Loop { + copyNode := *s + copyNode.Body = append([]Statement(nil), copyNode.Body...) + return ©Node +} + func (s *Loop) PlpgSQLStatementTag() string { return "stmt_simple_loop" } @@ -365,11 +515,19 @@ func (s *Loop) Format(ctx *tree.FmtCtx) { ctx.WriteString(";\n") } -func (s *Loop) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) +func (s *Loop) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + for i, stmt := range s.Body { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Loop).Body[i] = ns + } } + return newStmt, changed } // stmt_while @@ -380,6 +538,12 @@ type While struct { Body []Statement } +func (s *While) CopyNode() *While { + copyNode := *s + copyNode.Body = append([]Statement(nil), copyNode.Body...) + return ©Node +} + func (s *While) Format(ctx *tree.FmtCtx) { ctx.WriteString("WHILE ") s.Condition.Format(ctx) @@ -398,11 +562,19 @@ func (s *While) PlpgSQLStatementTag() string { return "stmt_while" } -func (s *While) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) +func (s *While) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + for i, stmt := range s.Body { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*While).Body[i] = ns + } } + return newStmt, changed } // stmt_for @@ -424,11 +596,8 @@ func (s *ForInt) PlpgSQLStatementTag() string { return "stmt_for_int_loop" } -func (s *ForInt) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) - } +func (s *ForInt) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ForQuery struct { @@ -445,11 +614,8 @@ func (s *ForQuery) PlpgSQLStatementTag() string { return "stmt_for_query_loop" } -func (s *ForQuery) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) - } +func (s *ForQuery) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ForSelect struct { @@ -464,9 +630,8 @@ func (s *ForSelect) PlpgSQLStatementTag() string { return "stmt_query_select_loop" } -func (s *ForSelect) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - s.ForQuery.WalkStmt(visitor) +func (s *ForSelect) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ForCursor struct { @@ -482,9 +647,8 @@ func (s *ForCursor) PlpgSQLStatementTag() string { return "stmt_for_query_cursor_loop" } -func (s *ForCursor) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - s.ForQuery.WalkStmt(visitor) +func (s *ForCursor) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ForDynamic struct { @@ -500,9 +664,8 @@ func (s *ForDynamic) PlpgSQLStatementTag() string { return "stmt_for_dyn_loop" } -func (s *ForDynamic) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - s.ForQuery.WalkStmt(visitor) +func (s *ForDynamic) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } // stmt_foreach_a @@ -522,12 +685,8 @@ func (s *ForEachArray) PlpgSQLStatementTag() string { return "stmt_for_each_a" } -func (s *ForEachArray) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) - } +func (s *ForEachArray) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } // stmt_exit @@ -537,6 +696,11 @@ type Exit struct { Condition Expr } +func (s *Exit) CopyNode() *Exit { + copyNode := *s + return ©Node +} + func (s *Exit) Format(ctx *tree.FmtCtx) { ctx.WriteString("EXIT") if s.Label != "" { @@ -554,8 +718,9 @@ func (s *Exit) PlpgSQLStatementTag() string { return "stmt_exit" } -func (s *Exit) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Exit) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_continue @@ -565,6 +730,11 @@ type Continue struct { Condition Expr } +func (s *Continue) CopyNode() *Continue { + copyNode := *s + return ©Node +} + func (s *Continue) Format(ctx *tree.FmtCtx) { ctx.WriteString("CONTINUE") if s.Label != "" { @@ -581,8 +751,9 @@ func (s *Continue) PlpgSQLStatementTag() string { return "stmt_continue" } -func (s *Continue) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Continue) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_return @@ -592,6 +763,11 @@ type Return struct { RetVar Variable } +func (s *Return) CopyNode() *Return { + copyNode := *s + return ©Node +} + func (s *Return) Format(ctx *tree.FmtCtx) { ctx.WriteString("RETURN ") if s.Expr == nil { @@ -606,8 +782,9 @@ func (s *Return) PlpgSQLStatementTag() string { return "stmt_return" } -func (s *Return) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Return) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } type ReturnNext struct { @@ -623,8 +800,8 @@ func (s *ReturnNext) PlpgSQLStatementTag() string { return "stmt_return_next" } -func (s *ReturnNext) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *ReturnNext) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ReturnQuery struct { @@ -641,8 +818,8 @@ func (s *ReturnQuery) PlpgSQLStatementTag() string { return "stmt_return_query" } -func (s *ReturnQuery) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *ReturnQuery) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } // stmt_raise @@ -656,6 +833,13 @@ type Raise struct { Options []RaiseOption } +func (s *Raise) CopyNode() *Raise { + copyNode := *s + copyNode.Params = append([]Expr(nil), s.Params...) + copyNode.Options = append([]RaiseOption(nil), s.Options...) + return ©Node +} + func (s *Raise) Format(ctx *tree.FmtCtx) { ctx.WriteString("RAISE") if s.LogLevel != "" { @@ -700,8 +884,9 @@ func (s *Raise) PlpgSQLStatementTag() string { return "stmt_raise" } -func (s *Raise) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Raise) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_assert @@ -711,6 +896,11 @@ type Assert struct { Message Expr } +func (s *Assert) CopyNode() *Assert { + copyNode := *s + return ©Node +} + func (s *Assert) Format(ctx *tree.FmtCtx) { // TODO(drewk): Pretty print the assert condition and message ctx.WriteString("ASSERT\n") @@ -720,8 +910,9 @@ func (s *Assert) PlpgSQLStatementTag() string { return "stmt_assert" } -func (s *Assert) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Assert) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_execsql @@ -732,6 +923,12 @@ type Execute struct { Target []Variable } +func (s *Execute) CopyNode() *Execute { + copyNode := *s + copyNode.Target = append([]Variable(nil), copyNode.Target...) + return ©Node +} + func (s *Execute) Format(ctx *tree.FmtCtx) { s.SqlStmt.Format(ctx) if s.Target != nil { @@ -753,8 +950,9 @@ func (s *Execute) PlpgSQLStatementTag() string { return "stmt_exec_sql" } -func (s *Execute) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Execute) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_dynexecute @@ -768,6 +966,12 @@ type DynamicExecute struct { Params []Expr } +func (s *DynamicExecute) CopyNode() *DynamicExecute { + copyNode := *s + copyNode.Params = append([]Expr(nil), s.Params...) + return ©Node +} + func (s *DynamicExecute) Format(ctx *tree.FmtCtx) { // TODO(drewk): Pretty print the original command ctx.WriteString("EXECUTE a dynamic command") @@ -787,8 +991,9 @@ func (s *DynamicExecute) PlpgSQLStatementTag() string { return "stmt_dyn_exec" } -func (s *DynamicExecute) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *DynamicExecute) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_perform @@ -804,8 +1009,8 @@ func (s *Perform) PlpgSQLStatementTag() string { return "stmt_perform" } -func (s *Perform) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Perform) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } // stmt_call @@ -816,6 +1021,11 @@ type Call struct { Target Variable } +func (s *Call) CopyNode() *Call { + copyNode := *s + return ©Node +} + func (s *Call) Format(ctx *tree.FmtCtx) { // TODO(drewk): Correct the Call field and print the Expr and Target. if s.IsCall { @@ -829,8 +1039,9 @@ func (s *Call) PlpgSQLStatementTag() string { return "stmt_call" } -func (s *Call) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Call) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_getdiag @@ -872,8 +1083,9 @@ func (s *GetDiagnostics) PlpgSQLStatementTag() string { return "stmt_get_diag" } -func (s *GetDiagnostics) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *GetDiagnostics) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_open @@ -884,6 +1096,11 @@ type Open struct { Query tree.Statement } +func (s *Open) CopyNode() *Open { + copyNode := *s + return ©Node +} + func (s *Open) Format(ctx *tree.FmtCtx) { ctx.WriteString("OPEN ") s.CurVar.Format(ctx) @@ -904,8 +1121,9 @@ func (s *Open) PlpgSQLStatementTag() string { return "stmt_open" } -func (s *Open) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Open) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_fetch @@ -952,8 +1170,9 @@ func (s *Fetch) PlpgSQLStatementTag() string { return "stmt_fetch" } -func (s *Fetch) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Fetch) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_close @@ -972,8 +1191,9 @@ func (s *Close) PlpgSQLStatementTag() string { return "stmt_close" } -func (s *Close) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Close) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_commit @@ -989,8 +1209,9 @@ func (s *Commit) PlpgSQLStatementTag() string { return "stmt_commit" } -func (s *Commit) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Commit) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_rollback @@ -1006,8 +1227,9 @@ func (s *Rollback) PlpgSQLStatementTag() string { return "stmt_rollback" } -func (s *Rollback) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Rollback) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_null @@ -1023,6 +1245,7 @@ func (s *Null) PlpgSQLStatementTag() string { return "stmt_null" } -func (s *Null) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Null) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } diff --git a/pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go b/pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go index 0e775e864efb..fd690ffbebd6 100644 --- a/pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go +++ b/pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go @@ -58,7 +58,9 @@ type telemetryVisitor struct { var _ plpgsqltree.StatementVisitor = &telemetryVisitor{} // Visit implements the StatementVisitor interface -func (v *telemetryVisitor) Visit(stmt plpgsqltree.Statement) { +func (v *telemetryVisitor) Visit( + stmt plpgsqltree.Statement, +) (newStmt plpgsqltree.Statement, changed bool) { taggedStmt, ok := stmt.(plpgsqltree.TaggedStatement) if !ok { v.Err = errors.AssertionFailedf("no tag found for stmt %q", stmt) @@ -75,6 +77,7 @@ func (v *telemetryVisitor) Visit(stmt plpgsqltree.Statement) { } v.Err = nil + return stmt, false } // MakePLpgSQLTelemetryVisitor makes a plpgsql telemetry visitor, for capturing @@ -114,3 +117,268 @@ func ParseAndCollectTelemetryForPLpgSQLFunc(stmt *tree.CreateRoutine) error { } return unimp.New("plpgsql", "plpgsql not supported in user-defined functions") } + +// SQLStmtVisitor calls Fn for every SQL statement and expression found while +// walking the PLpgSQL AST. Since PLpgSQL nodes may have statement and +// expression fields that are nil, Fn should handle the nil case. +type SQLStmtVisitor struct { + Fn tree.SimpleVisitFn + Err error +} + +var _ plpgsqltree.StatementVisitor = &SQLStmtVisitor{} + +func (v *SQLStmtVisitor) Visit( + stmt plpgsqltree.Statement, +) (newStmt plpgsqltree.Statement, changed bool) { + if v.Err != nil { + return stmt, false + } + newStmt = stmt + var s tree.Statement + var e tree.Expr + switch t := stmt.(type) { + case *plpgsqltree.CursorDeclaration: + s, v.Err = tree.SimpleStmtVisit(t.Query, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Query != s + if changed { + cpy := t.CopyNode() + cpy.Query = s + newStmt = cpy + } + case *plpgsqltree.Execute: + s, v.Err = tree.SimpleStmtVisit(t.SqlStmt, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.SqlStmt != s + if changed { + cpy := t.CopyNode() + cpy.SqlStmt = s + newStmt = cpy + } + case *plpgsqltree.Open: + s, v.Err = tree.SimpleStmtVisit(t.Query, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Query != s + if changed { + cpy := t.CopyNode() + cpy.Query = s + newStmt = cpy + } + case *plpgsqltree.Declaration: + e, v.Err = tree.SimpleVisit(t.Expr, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Expr != e + if changed { + cpy := t.CopyNode() + cpy.Expr = e + newStmt = cpy + } + + case *plpgsqltree.Assignment: + e, v.Err = tree.SimpleVisit(t.Value, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Value != e + if changed { + cpy := t.CopyNode() + cpy.Value = e + newStmt = cpy + } + case *plpgsqltree.If: + e, v.Err = tree.SimpleVisit(t.Condition, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Condition != e + if changed { + cpy := t.CopyNode() + cpy.Condition = e + newStmt = cpy + } + case *plpgsqltree.ElseIf: + e, v.Err = tree.SimpleVisit(t.Condition, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Condition != e + if changed { + cpy := t.CopyNode() + cpy.Condition = e + newStmt = cpy + } + case *plpgsqltree.While: + e, v.Err = tree.SimpleVisit(t.Condition, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Condition != e + if changed { + cpy := t.CopyNode() + cpy.Condition = e + newStmt = cpy + } + e, v.Err = tree.SimpleVisit(t.Condition, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Condition != e + if changed { + cpy := t.CopyNode() + cpy.Condition = e + newStmt = cpy + } + case *plpgsqltree.Exit: + e, v.Err = tree.SimpleVisit(t.Condition, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Condition != e + if changed { + cpy := t.CopyNode() + cpy.Condition = e + newStmt = cpy + } + case *plpgsqltree.Continue: + e, v.Err = tree.SimpleVisit(t.Condition, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Condition != e + if changed { + cpy := t.CopyNode() + cpy.Condition = e + newStmt = cpy + } + case *plpgsqltree.Return: + e, v.Err = tree.SimpleVisit(t.Expr, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Expr != e + if changed { + cpy := t.CopyNode() + cpy.Expr = e + newStmt = cpy + } + case *plpgsqltree.Raise: + for i, p := range t.Params { + e, v.Err = tree.SimpleVisit(p, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = changed || (t.Params[i] != e) + if changed { + if newStmt != stmt { + cpy := t.CopyNode() + newStmt = cpy + } + newStmt.(*plpgsqltree.Raise).Params[i] = e + } + } + for i, p := range t.Options { + e, v.Err = tree.SimpleVisit(p.Expr, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = changed || (t.Options[i].Expr != e) + if changed { + if newStmt != stmt { + cpy := t.CopyNode() + newStmt = cpy + } + newStmt.(*plpgsqltree.Raise).Options[i].Expr = e + } + } + case *plpgsqltree.Assert: + e, v.Err = tree.SimpleVisit(t.Condition, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Condition != e + if changed { + cpy := t.CopyNode() + cpy.Condition = e + newStmt = cpy + } + + case *plpgsqltree.DynamicExecute: + for i, p := range t.Params { + e, v.Err = tree.SimpleVisit(p, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = changed || (t.Params[i] != e) + if changed { + if newStmt != stmt { + cpy := t.CopyNode() + newStmt = cpy + } + newStmt.(*plpgsqltree.DynamicExecute).Params[i] = e + } + } + case *plpgsqltree.Call: + e, v.Err = tree.SimpleVisit(t.Expr, v.Fn) + if v.Err != nil { + return stmt, false + } + changed = t.Expr != e + if changed { + cpy := t.CopyNode() + cpy.Expr = e + newStmt = cpy + } + + case *plpgsqltree.ForInt, *plpgsqltree.ForSelect, *plpgsqltree.ForCursor, + *plpgsqltree.ForDynamic, *plpgsqltree.ForEachArray, *plpgsqltree.ReturnNext, + *plpgsqltree.ReturnQuery, *plpgsqltree.Perform: + panic(unimp.New("plpgsql visitor", "Unimplemented PLpgSQL visitor")) + } + if v.Err != nil { + return stmt, false + } + return newStmt, changed +} + +// TypeRefVisitor calls the given replace function on each type reference contained in the visited PLpgSQL +// statements. Note that this currently only includes `Declaration`, and SQL statements and expressions +// are not visited. +type TypeRefVisitor struct { + Fn func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error) + Err error +} + +var _ plpgsqltree.StatementVisitor = &TypeRefVisitor{} + +func (v *TypeRefVisitor) Visit( + stmt plpgsqltree.Statement, +) (newStmt plpgsqltree.Statement, changed bool) { + if v.Err != nil { + return stmt, false + } + newStmt = stmt + if t, ok := stmt.(*plpgsqltree.Declaration); ok { + var newTyp tree.ResolvableTypeReference + newTyp, v.Err = v.Fn(t.Typ) + if v.Err != nil { + return stmt, false + } + changed = t.Typ != newTyp + if changed { + if newStmt == stmt { + newStmt = t.CopyNode() + newStmt.(*plpgsqltree.Declaration).Typ = newTyp + } + } + } + return newStmt, changed +} diff --git a/pkg/sql/sem/plpgsqltree/visitor.go b/pkg/sql/sem/plpgsqltree/visitor.go index cbad8e13fe51..9526d480f27c 100644 --- a/pkg/sql/sem/plpgsqltree/visitor.go +++ b/pkg/sql/sem/plpgsqltree/visitor.go @@ -14,10 +14,11 @@ package plpgsqltree // a statement walk. type StatementVisitor interface { // Visit is called during a statement walk. - Visit(stmt Statement) + Visit(stmt Statement) (newStmt Statement, changed bool) } // Walk traverses the plpgsql statement. -func Walk(v StatementVisitor, stmt Statement) { - stmt.WalkStmt(v) +func Walk(v StatementVisitor, stmt Statement) Statement { + newStmt, _ := stmt.WalkStmt(v) + return newStmt } diff --git a/pkg/sql/show_create_clauses.go b/pkg/sql/show_create_clauses.go index ff83d960db72..c810dd080fd3 100644 --- a/pkg/sql/show_create_clauses.go +++ b/pkg/sql/show_create_clauses.go @@ -25,8 +25,11 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/schemaexpr" "github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc" "github.com/cockroachdb/cockroach/pkg/sql/parser" + plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser" "github.com/cockroachdb/cockroach/pkg/sql/rowenc" "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" + "github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree" + "github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree/utils" "github.com/cockroachdb/cockroach/pkg/sql/sem/semenumpb" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" @@ -193,7 +196,7 @@ func formatViewQueryForDisplay( } // Convert sequences referenced by ID in the view back to their names. - sequenceReplacedViewQuery, err := formatQuerySequencesForDisplay(ctx, semaCtx, typeReplacedViewQuery, false /* multiStmt */) + sequenceReplacedViewQuery, err := formatQuerySequencesForDisplay(ctx, semaCtx, typeReplacedViewQuery, false /* multiStmt */, catpb.Function_SQL) if err != nil { log.Warningf(ctx, "error converting sequence IDs to names for view %s (%v): %+v", desc.GetName(), desc.GetID(), err) @@ -207,9 +210,16 @@ func formatViewQueryForDisplay( // looks for sequence IDs in the statement. If it finds any, // it will replace the IDs with the descriptor's fully qualified name. func formatQuerySequencesForDisplay( - ctx context.Context, semaCtx *tree.SemaContext, queries string, multiStmt bool, + ctx context.Context, + semaCtx *tree.SemaContext, + queries string, + multiStmt bool, + lang catpb.Function_Language, ) (string, error) { replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { + if expr == nil { + return false, expr, nil + } newExpr, err = schemaexpr.ReplaceSequenceIDsWithFQNames(ctx, expr, semaCtx) if err != nil { return false, expr, err @@ -217,37 +227,51 @@ func formatQuerySequencesForDisplay( return false, newExpr, nil } - var stmts tree.Statements - if multiStmt { - parsedStmts, err := parser.Parse(queries) - if err != nil { - return "", err + fmtCtx := tree.NewFmtCtx(tree.FmtSimple) + switch lang { + case catpb.Function_SQL: + var stmts tree.Statements + if multiStmt { + parsedStmts, err := parser.Parse(queries) + if err != nil { + return "", err + } + stmts = make(tree.Statements, len(parsedStmts)) + for i, stmt := range parsedStmts { + stmts[i] = stmt.AST + } + } else { + stmt, err := parser.ParseOne(queries) + if err != nil { + return "", err + } + stmts = tree.Statements{stmt.AST} } - stmts = make(tree.Statements, len(parsedStmts)) - for i, stmt := range parsedStmts { - stmts[i] = stmt.AST + + for i, stmt := range stmts { + newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) + if err != nil { + return "", err + } + if i > 0 { + fmtCtx.WriteString("\n") + } + fmtCtx.FormatNode(newStmt) + if multiStmt { + fmtCtx.WriteString(";") + } } - } else { - stmt, err := parser.ParseOne(queries) + case catpb.Function_PLPGSQL: + var stmts plpgsqltree.Statement + plstmt, err := plpgsql.Parse(queries) if err != nil { - return "", err + return "", errors.Wrap(err, "failed to parse query string") } - stmts = tree.Statements{stmt.AST} - } + stmts = plstmt.AST - fmtCtx := tree.NewFmtCtx(tree.FmtSimple) - for i, stmt := range stmts { - newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) - if err != nil { - return "", err - } - if i > 0 { - fmtCtx.WriteString("\n") - } + v := utils.SQLStmtVisitor{Fn: replaceFunc} + newStmt := plpgsqltree.Walk(&v, stmts) fmtCtx.FormatNode(newStmt) - if multiStmt { - fmtCtx.WriteString(";") - } } return fmtCtx.CloseAndGetString(), nil } @@ -308,7 +332,7 @@ func formatViewQueryTypesForDisplay( } // formatFunctionQueryTypesForDisplay is similar to -// formatViewQueryTypesForDisplay but can only be used for function. +// formatViewQueryTypesForDisplay but can only be used for functions. // nil is used as the table descriptor for schemaexpr.FormatExprForDisplay call. // This is fine assuming that UDFs cannot be created with expression casting a // column/var to an enum in function body. This is super rare case for now, and @@ -319,8 +343,12 @@ func formatFunctionQueryTypesForDisplay( semaCtx *tree.SemaContext, sessionData *sessiondata.SessionData, queries string, + lang catpb.Function_Language, ) (string, error) { replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { + if expr == nil { + return false, expr, nil + } // We need to resolve the type to check if it's user-defined. If not, // no other work is needed. var typRef tree.ResolvableTypeReference @@ -352,29 +380,74 @@ func formatFunctionQueryTypesForDisplay( } return false, newExpr, nil } - - var stmts tree.Statements - parsedStmts, err := parser.Parse(queries) - if err != nil { - return "", errors.Wrap(err, "failed to parse query") - } - stmts = make(tree.Statements, len(parsedStmts)) - for i, stmt := range parsedStmts { - stmts[i] = stmt.AST + replaceTypeFunc := func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error) { + if typ == nil { + return typ, nil + } + // semaCtx may be nil if this is a virtual view being created at + // init time. + var typeResolver tree.TypeReferenceResolver + if semaCtx != nil { + typeResolver = semaCtx.TypeResolver + } + var t *types.T + t, err = tree.ResolveType(ctx, typ, typeResolver) + if err != nil { + return typ, err + } + if !t.UserDefined() { + return typ, nil + } + name := t.TypeMeta.Name + typname := tree.MakeTypeNameWithPrefix(tree.ObjectNamePrefix{ + CatalogName: tree.Name(name.Catalog), + SchemaName: tree.Name(name.Schema), + ExplicitCatalog: name.Catalog != "", + ExplicitSchema: name.ExplicitSchema, + }, name.Name) + ref := typname.ToUnresolvedObjectName() + return ref, nil } fmtCtx := tree.NewFmtCtx(tree.FmtSimple) - for i, stmt := range stmts { - newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) + switch lang { + case catpb.Function_SQL: + var stmts tree.Statements + parsedStmts, err := parser.Parse(queries) if err != nil { - return "", err + return "", errors.Wrap(err, "failed to parse query") } - if i > 0 { - fmtCtx.WriteString("\n") + stmts = make(tree.Statements, len(parsedStmts)) + for i, stmt := range parsedStmts { + stmts[i] = stmt.AST + } + + for i, stmt := range stmts { + newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) + if err != nil { + return "", err + } + if i > 0 { + fmtCtx.WriteString("\n") + } + fmtCtx.FormatNode(newStmt) + fmtCtx.WriteString(";") } + case catpb.Function_PLPGSQL: + var stmts plpgsqltree.Statement + plstmt, err := plpgsql.Parse(queries) + if err != nil { + return "", errors.Wrap(err, "failed to parse query string") + } + stmts = plstmt.AST + + v := utils.SQLStmtVisitor{Fn: replaceFunc} + newStmt := plpgsqltree.Walk(&v, stmts) + v2 := utils.TypeRefVisitor{Fn: replaceTypeFunc} + newStmt = plpgsqltree.Walk(&v2, newStmt) fmtCtx.FormatNode(newStmt) - fmtCtx.WriteString(";") } + return fmtCtx.CloseAndGetString(), nil }