From d8a68ba2fe2bfc739953cf818fb30d18e0c61067 Mon Sep 17 00:00:00 2001 From: rharding6373 Date: Tue, 5 Dec 2023 22:15:49 -0800 Subject: [PATCH] sql: support sequence and udt name rewriting in plpgsql CRDB rewrites sequence and UDT names as IDs in views and functions so that if the sequence or UDT is renamed the views and functions using them don't break. This PR adds support for this in PLpgSQL. Epic: None Fixes: #115627 Release note (sql change): Fixes a bug in PLpgSQL where altering the name of a sequence or UDT that was used in a PLpgSQL function or procedure could break them. This is only present in 23.2 alpha and beta releases. --- .../testdata/logic_test/udf_rewrite | 295 +++++++++++- pkg/sql/BUILD.bazel | 2 + pkg/sql/crdb_internal.go | 31 +- pkg/sql/create_function.go | 15 +- pkg/sql/create_view.go | 201 ++++++--- pkg/sql/opt/optbuilder/plpgsql.go | 7 +- pkg/sql/plpgsql/parser/testdata/stmt_case | 1 + pkg/sql/sem/plpgsqltree/BUILD.bazel | 1 + pkg/sql/sem/plpgsqltree/exception.go | 23 +- pkg/sql/sem/plpgsqltree/statements.go | 427 +++++++++++++----- pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go | 270 ++++++++++- pkg/sql/sem/plpgsqltree/visitor.go | 7 +- pkg/sql/show_create_clauses.go | 167 +++++-- 13 files changed, 1181 insertions(+), 266 deletions(-) diff --git a/pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite b/pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite index d6579e3df3be..59832480faca 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite +++ b/pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite @@ -37,34 +37,153 @@ SET use_declarative_schema_changer = 'on' subtest rewrite_plpgsql -statement ok -DROP FUNCTION IF EXISTS f_rewrite - statement ok CREATE OR REPLACE FUNCTION f_rewrite() RETURNS INT AS $$ + DECLARE + i INT := nextval('seq'); + j INT := nextval('seq'); + curs REFCURSOR := nextval('seq')::STRING; + curs2 CURSOR FOR SELECT nextval('seq'); BEGIN - SELECT nextval('seq'); + RAISE NOTICE USING MESSAGE = format('next val: 
%d',nextval('seq')); + RAISE NOTICE 'val1: %, val2: %', nextval('seq'), nextval('seq'); + WHILE nextval('seq') < 10 LOOP + i = nextval('seq'); + SELECT nextval('seq'); + IF nextval('seq') = 1 THEN + SELECT nextval('seq'); + SELECT nextval('seq'); + CONTINUE; + ELSIF nextval('seq') = 2 THEN + SELECT v INTO i FROM nextval('seq') AS v(INT); + ELSIF nextval('seq') = 3 THEN + SELECT nextval('seq'); + SELECT nextval('seq'); + END IF; + END LOOP; + OPEN curs FOR SELECT nextval('seq'); + RETURN nextval('seq'); + EXCEPTION + WHEN division_by_zero THEN + RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq')); + WHEN not_null_violation THEN + SELECT nextval('seq'); + SELECT nextval('seq'); + RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq')); END -$$ LANGUAGE PLPGSQL +$$ LANGUAGE PLPGSQL; query T SELECT get_body_str('f_rewrite'); ---- -"BEGIN\nSELECT nextval('seq':::STRING);\nEND\n;" +"DECLARE\ni INT8 := nextval(106:::REGCLASS);\nj INT8 := nextval(106:::REGCLASS);\ncurs REFCURSOR := nextval(106:::REGCLASS)::STRING;\ncurs2 CURSOR FOR SELECT nextval(106:::REGCLASS);\nBEGIN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nRAISE notice 'val1: %, val2: %', nextval(106:::REGCLASS), nextval(106:::REGCLASS);\nWHILE nextval(106:::REGCLASS) < 10:::INT8 LOOP\ni := nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nIF nextval(106:::REGCLASS) = 1:::INT8 THEN\n\tSELECT nextval(106:::REGCLASS);\n\tSELECT nextval(106:::REGCLASS);\n\tCONTINUE;\nELSIF nextval(106:::REGCLASS) = 2:::INT8 THEN\n\tSELECT v FROM ROWS FROM (nextval(106:::REGCLASS)) AS v (\"int\") INTO i;\nELSIF nextval(106:::REGCLASS) = 3:::INT8 THEN\n\tSELECT nextval(106:::REGCLASS);\n\tSELECT nextval(106:::REGCLASS);\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT nextval(106:::REGCLASS);\nRETURN nextval(106:::REGCLASS);\nEXCEPTION\nWHEN division_by_zero THEN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nWHEN 
not_null_violation THEN\nSELECT nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nEND\n;" + +query TT +SHOW CREATE FUNCTION f_rewrite; +---- +f_rewrite CREATE FUNCTION public.f_rewrite() + RETURNS INT8 + VOLATILE + NOT LEAKPROOF + CALLED ON NULL INPUT + LANGUAGE plpgsql + AS $$ + DECLARE + i INT8 := nextval('public.seq'::REGCLASS); + j INT8 := nextval('public.seq'::REGCLASS); + curs REFCURSOR := nextval('public.seq'::REGCLASS)::STRING; + curs2 CURSOR FOR SELECT nextval('public.seq'::REGCLASS); + BEGIN + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS)); + RAISE notice 'val1: %, val2: %', nextval('public.seq'::REGCLASS), nextval('public.seq'::REGCLASS); + WHILE nextval('public.seq'::REGCLASS) < 10:::INT8 LOOP + i := nextval('public.seq'::REGCLASS); + SELECT nextval('public.seq'::REGCLASS); + IF nextval('public.seq'::REGCLASS) = 1:::INT8 THEN + SELECT nextval('public.seq'::REGCLASS); + SELECT nextval('public.seq'::REGCLASS); + CONTINUE; + ELSIF nextval('public.seq'::REGCLASS) = 2:::INT8 THEN + SELECT v FROM ROWS FROM (nextval('public.seq'::REGCLASS)) AS v ("int") INTO i; + ELSIF nextval('public.seq'::REGCLASS) = 3:::INT8 THEN + SELECT nextval('public.seq'::REGCLASS); + SELECT nextval('public.seq'::REGCLASS); + END IF; + END LOOP; + OPEN curs FOR SELECT nextval('public.seq'::REGCLASS); + RETURN nextval('public.seq'::REGCLASS); + EXCEPTION + WHEN division_by_zero THEN + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS)); + WHEN not_null_violation THEN + SELECT nextval('public.seq'::REGCLASS); + SELECT nextval('public.seq'::REGCLASS); + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS)); + END + $$ statement ok -CREATE OR REPLACE FUNCTION f_rewrite() RETURNS INT AS -$$ - BEGIN - INSERT INTO t_rewrite(v) VALUES (nextval('seq')) RETURNING v; 
- END -$$ LANGUAGE PLPGSQL +ALTER SEQUENCE seq RENAME TO renamed; query T SELECT get_body_str('f_rewrite'); ---- -"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval('seq':::STRING)) RETURNING v;\nEND\n;" +"DECLARE\ni INT8 := nextval(106:::REGCLASS);\nj INT8 := nextval(106:::REGCLASS);\ncurs REFCURSOR := nextval(106:::REGCLASS)::STRING;\ncurs2 CURSOR FOR SELECT nextval(106:::REGCLASS);\nBEGIN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nRAISE notice 'val1: %, val2: %', nextval(106:::REGCLASS), nextval(106:::REGCLASS);\nWHILE nextval(106:::REGCLASS) < 10:::INT8 LOOP\ni := nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nIF nextval(106:::REGCLASS) = 1:::INT8 THEN\n\tSELECT nextval(106:::REGCLASS);\n\tSELECT nextval(106:::REGCLASS);\n\tCONTINUE;\nELSIF nextval(106:::REGCLASS) = 2:::INT8 THEN\n\tSELECT v FROM ROWS FROM (nextval(106:::REGCLASS)) AS v (\"int\") INTO i;\nELSIF nextval(106:::REGCLASS) = 3:::INT8 THEN\n\tSELECT nextval(106:::REGCLASS);\n\tSELECT nextval(106:::REGCLASS);\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT nextval(106:::REGCLASS);\nRETURN nextval(106:::REGCLASS);\nEXCEPTION\nWHEN division_by_zero THEN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nWHEN not_null_violation THEN\nSELECT nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nEND\n;" + +query TT +SHOW CREATE FUNCTION f_rewrite; +---- +f_rewrite CREATE FUNCTION public.f_rewrite() + RETURNS INT8 + VOLATILE + NOT LEAKPROOF + CALLED ON NULL INPUT + LANGUAGE plpgsql + AS $$ + DECLARE + i INT8 := nextval('public.renamed'::REGCLASS); + j INT8 := nextval('public.renamed'::REGCLASS); + curs REFCURSOR := nextval('public.renamed'::REGCLASS)::STRING; + curs2 CURSOR FOR SELECT nextval('public.renamed'::REGCLASS); + BEGIN + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, 
nextval('public.renamed'::REGCLASS)); + RAISE notice 'val1: %, val2: %', nextval('public.renamed'::REGCLASS), nextval('public.renamed'::REGCLASS); + WHILE nextval('public.renamed'::REGCLASS) < 10:::INT8 LOOP + i := nextval('public.renamed'::REGCLASS); + SELECT nextval('public.renamed'::REGCLASS); + IF nextval('public.renamed'::REGCLASS) = 1:::INT8 THEN + SELECT nextval('public.renamed'::REGCLASS); + SELECT nextval('public.renamed'::REGCLASS); + CONTINUE; + ELSIF nextval('public.renamed'::REGCLASS) = 2:::INT8 THEN + SELECT v FROM ROWS FROM (nextval('public.renamed'::REGCLASS)) AS v ("int") INTO i; + ELSIF nextval('public.renamed'::REGCLASS) = 3:::INT8 THEN + SELECT nextval('public.renamed'::REGCLASS); + SELECT nextval('public.renamed'::REGCLASS); + END IF; + END LOOP; + OPEN curs FOR SELECT nextval('public.renamed'::REGCLASS); + RETURN nextval('public.renamed'::REGCLASS); + EXCEPTION + WHEN division_by_zero THEN + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, nextval('public.renamed'::REGCLASS)); + WHEN not_null_violation THEN + SELECT nextval('public.renamed'::REGCLASS); + SELECT nextval('public.renamed'::REGCLASS); + RAISE notice + USING MESSAGE = format('next val: %d':::STRING, nextval('public.renamed'::REGCLASS)); + END + $$ + +# Reset sequence name for subtest. 
+statement ok +ALTER SEQUENCE renamed RENAME TO seq; statement ok DROP FUNCTION f_rewrite(); @@ -72,28 +191,154 @@ DROP FUNCTION f_rewrite(); statement ok CREATE OR REPLACE FUNCTION f_rewrite() RETURNS weekday AS $$ + DECLARE + day weekday := 'wednesday'::weekday; + today weekday := 'thursday'::weekday; + curs REFCURSOR := 'monday'::weekday::STRING; + curs2 CURSOR FOR SELECT 'tuesday'::weekday; BEGIN - SELECT 'wednesday'::weekday; + RAISE NOTICE USING MESSAGE = format('val: %d','wednesday'::weekday); + RAISE NOTICE 'val1: %, val2: %', 'wednesday'::weekday, 'thursday'::weekday; + WHILE day != 'wednesday'::weekday LOOP + day = 'friday'::weekday; + SELECT 'wednesday'::weekday; + IF day = 'wednesday'::weekday THEN + day = 'thursday'::weekday; + SELECT 'tuesday'::weekday; + CONTINUE; + ELSIF day = 'monday'::weekday THEN + SELECT 'tuesday'::weekday INTO day; + ELSIF day = 'tuesday'::weekday THEN + SELECT 'wednesday'::weekday INTO day; + SELECT 'wednesday'::weekday; + END IF; + END LOOP; + OPEN curs FOR SELECT 'wednesday'::weekday; + RETURN 'wednesday'::weekday; + EXCEPTION + WHEN division_by_zero THEN + RAISE NOTICE USING MESSAGE = format('val: %d','wednesday'::weekday); + WHEN not_null_violation THEN + SELECT 'wednesday'::weekday; + RAISE NOTICE 'val: %', 'wednesday'::weekday; END -$$ LANGUAGE PLPGSQL +$$ LANGUAGE PLPGSQL; query T SELECT get_body_str('f_rewrite'); ---- -"BEGIN\nSELECT 'wednesday'::@100107;\nEND\n;" +"DECLARE\nday @100107 := b'\\x80':::@100107;\ntoday @100107 := b'\\xa0':::@100107;\ncurs REFCURSOR := b' ':::@100107::STRING;\ncurs2 CURSOR FOR SELECT b'@':::@100107;\nBEGIN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nRAISE notice 'val1: %, val2: %', b'\\x80':::@100107, b'\\xa0':::@100107;\nWHILE day != b'\\x80':::@100107 LOOP\nday := b'\\xc0':::@100107;\nSELECT b'\\x80':::@100107;\nIF day = b'\\x80':::@100107 THEN\n\tday := b'\\xa0':::@100107;\n\tSELECT b'@':::@100107;\n\tCONTINUE;\nELSIF day = b' ':::@100107 
THEN\n\tSELECT b'@':::@100107 INTO day;\nELSIF day = b'@':::@100107 THEN\n\tSELECT b'\\x80':::@100107 INTO day;\n\tSELECT b'\\x80':::@100107;\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT b'\\x80':::@100107;\nRETURN b'\\x80':::@100107;\nEXCEPTION\nWHEN division_by_zero THEN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nWHEN not_null_violation THEN\nSELECT b'\\x80':::@100107;\nRAISE notice 'val: %', b'\\x80':::@100107;\nEND\n;" + +query TT +SHOW CREATE FUNCTION f_rewrite; +---- +f_rewrite CREATE FUNCTION public.f_rewrite() + RETURNS test.public.weekday + VOLATILE + NOT LEAKPROOF + CALLED ON NULL INPUT + LANGUAGE plpgsql + AS $$ + DECLARE + day test.public.weekday := 'wednesday':::test.public.weekday; + today test.public.weekday := 'thursday':::test.public.weekday; + curs REFCURSOR := 'monday':::test.public.weekday::STRING; + curs2 CURSOR FOR SELECT 'tuesday':::test.public.weekday; + BEGIN + RAISE notice + USING MESSAGE = format('val: %d':::STRING, 'wednesday':::test.public.weekday); + RAISE notice 'val1: %, val2: %', 'wednesday':::test.public.weekday, 'thursday':::test.public.weekday; + WHILE day != 'wednesday':::test.public.weekday LOOP + day := 'friday':::test.public.weekday; + SELECT 'wednesday':::test.public.weekday; + IF day = 'wednesday':::test.public.weekday THEN + day := 'thursday':::test.public.weekday; + SELECT 'tuesday':::test.public.weekday; + CONTINUE; + ELSIF day = 'monday':::test.public.weekday THEN + SELECT 'tuesday':::test.public.weekday INTO day; + ELSIF day = 'tuesday':::test.public.weekday THEN + SELECT 'wednesday':::test.public.weekday INTO day; + SELECT 'wednesday':::test.public.weekday; + END IF; + END LOOP; + OPEN curs FOR SELECT 'wednesday':::test.public.weekday; + RETURN 'wednesday':::test.public.weekday; + EXCEPTION + WHEN division_by_zero THEN + RAISE notice + USING MESSAGE = format('val: %d':::STRING, 'wednesday':::test.public.weekday); + WHEN not_null_violation THEN + SELECT 
'wednesday':::test.public.weekday; + RAISE notice 'val: %', 'wednesday':::test.public.weekday; + END + $$ statement ok -CREATE OR REPLACE FUNCTION f_rewrite() RETURNS weekday AS -$$ - BEGIN - UPDATE t_rewrite SET w = 'thursday'::weekday WHERE w = 'wednesday'::weekday RETURNING w; - END -$$ LANGUAGE PLPGSQL +ALTER TYPE weekday RENAME VALUE 'wednesday' TO 'humpday'; + +statement ok +ALTER TYPE weekday RENAME TO workday; query T SELECT get_body_str('f_rewrite'); ---- -"BEGIN\nUPDATE test.public.t_rewrite SET w = 'thursday'::@100107 WHERE w = 'wednesday'::@100107 RETURNING w;\nEND\n;" +"DECLARE\nday @100107 := b'\\x80':::@100107;\ntoday @100107 := b'\\xa0':::@100107;\ncurs REFCURSOR := b' ':::@100107::STRING;\ncurs2 CURSOR FOR SELECT b'@':::@100107;\nBEGIN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nRAISE notice 'val1: %, val2: %', b'\\x80':::@100107, b'\\xa0':::@100107;\nWHILE day != b'\\x80':::@100107 LOOP\nday := b'\\xc0':::@100107;\nSELECT b'\\x80':::@100107;\nIF day = b'\\x80':::@100107 THEN\n\tday := b'\\xa0':::@100107;\n\tSELECT b'@':::@100107;\n\tCONTINUE;\nELSIF day = b' ':::@100107 THEN\n\tSELECT b'@':::@100107 INTO day;\nELSIF day = b'@':::@100107 THEN\n\tSELECT b'\\x80':::@100107 INTO day;\n\tSELECT b'\\x80':::@100107;\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT b'\\x80':::@100107;\nRETURN b'\\x80':::@100107;\nEXCEPTION\nWHEN division_by_zero THEN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nWHEN not_null_violation THEN\nSELECT b'\\x80':::@100107;\nRAISE notice 'val: %', b'\\x80':::@100107;\nEND\n;" + +query TT +SHOW CREATE FUNCTION f_rewrite; +---- +f_rewrite CREATE FUNCTION public.f_rewrite() + RETURNS test.public.workday + VOLATILE + NOT LEAKPROOF + CALLED ON NULL INPUT + LANGUAGE plpgsql + AS $$ + DECLARE + day test.public.workday := 'humpday':::test.public.workday; + today test.public.workday := 'thursday':::test.public.workday; + curs REFCURSOR := 
'monday':::test.public.workday::STRING; + curs2 CURSOR FOR SELECT 'tuesday':::test.public.workday; + BEGIN + RAISE notice + USING MESSAGE = format('val: %d':::STRING, 'humpday':::test.public.workday); + RAISE notice 'val1: %, val2: %', 'humpday':::test.public.workday, 'thursday':::test.public.workday; + WHILE day != 'humpday':::test.public.workday LOOP + day := 'friday':::test.public.workday; + SELECT 'humpday':::test.public.workday; + IF day = 'humpday':::test.public.workday THEN + day := 'thursday':::test.public.workday; + SELECT 'tuesday':::test.public.workday; + CONTINUE; + ELSIF day = 'monday':::test.public.workday THEN + SELECT 'tuesday':::test.public.workday INTO day; + ELSIF day = 'tuesday':::test.public.workday THEN + SELECT 'humpday':::test.public.workday INTO day; + SELECT 'humpday':::test.public.workday; + END IF; + END LOOP; + OPEN curs FOR SELECT 'humpday':::test.public.workday; + RETURN 'humpday':::test.public.workday; + EXCEPTION + WHEN division_by_zero THEN + RAISE notice + USING MESSAGE = format('val: %d':::STRING, 'humpday':::test.public.workday); + WHEN not_null_violation THEN + SELECT 'humpday':::test.public.workday; + RAISE notice 'val: %', 'humpday':::test.public.workday; + END + $$ + +# Reset types for subtest. 
+statement ok +ALTER TYPE workday RENAME TO weekday; + +statement ok +ALTER TYPE weekday RENAME VALUE 'humpday' TO 'wednesday'; + +statement ok +DROP FUNCTION f_rewrite(); subtest end @@ -110,7 +355,7 @@ $$ LANGUAGE PLPGSQL query T SELECT get_body_str('p_rewrite'); ---- -"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval('seq':::STRING)) RETURNING v;\nEND\n;" +"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval(106:::REGCLASS)) RETURNING v;\nEND\n;" statement ok DROP PROCEDURE p_rewrite(); @@ -126,6 +371,6 @@ $$ LANGUAGE PLPGSQL query T SELECT get_body_str('p_rewrite'); ---- -"BEGIN\nUPDATE test.public.t_rewrite SET w = 'thursday'::@100107 WHERE w = 'wednesday'::@100107 RETURNING w;\nEND\n;" +"BEGIN\nUPDATE test.public.t_rewrite SET w = b'\\xa0':::@100107 WHERE w = b'\\x80':::@100107 RETURNING w;\nEND\n;" subtest end diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index 6bd121b0dd03..9f9090a57bfe 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -490,6 +490,8 @@ go_library( "//pkg/sql/sem/catconstants", "//pkg/sql/sem/catid", "//pkg/sql/sem/eval", + "//pkg/sql/sem/plpgsqltree", + "//pkg/sql/sem/plpgsqltree/utils", "//pkg/sql/sem/semenumpb", "//pkg/sql/sem/transform", "//pkg/sql/sem/tree", diff --git a/pkg/sql/crdb_internal.go b/pkg/sql/crdb_internal.go index de5e8ac0944a..44785a8aaa4a 100644 --- a/pkg/sql/crdb_internal.go +++ b/pkg/sql/crdb_internal.go @@ -70,7 +70,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" - plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/protoreflect" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" @@ -3585,29 +3584,15 @@ func createRoutinePopulate( for i := range treeNode.Options { if body, ok := treeNode.Options[i].(tree.RoutineBodyStr); ok { 
bodyStr := string(body) - switch fnDesc.GetLanguage() { - case catpb.Function_SQL: - bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr) - if err != nil { - return err - } - bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */) - if err != nil { - return err - } - bodyStr = "\n" + bodyStr + "\n" - case catpb.Function_PLPGSQL: - // TODO(drewk): integrate this with the SQL case above. - plpgsqlStmt, err := plpgsql.Parse(bodyStr) - if err != nil { - return err - } - fmtCtx := tree.NewFmtCtx(tree.FmtParsable) - fmtCtx.FormatNode(plpgsqlStmt.AST) - bodyStr = "\n" + fmtCtx.CloseAndGetString() - default: - return errors.AssertionFailedf("unexpected function language: %s", fnDesc.GetLanguage()) + bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr, fnDesc.GetLanguage()) + if err != nil { + return err + } + bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */, fnDesc.GetLanguage()) + if err != nil { + return err } + bodyStr = "\n" + bodyStr + "\n" stmtStrs := strings.Split(bodyStr, "\n") for i := range stmtStrs { if stmtStrs[i] != "" { diff --git a/pkg/sql/create_function.go b/pkg/sql/create_function.go index 4d2fe98b81c6..193cc7add44e 100644 --- a/pkg/sql/create_function.go +++ b/pkg/sql/create_function.go @@ -455,26 +455,19 @@ func setFuncOptions( } } - switch lang { - case catpb.Function_SQL: + if lang != catpb.Function_UNKNOWN_LANGUAGE && body != "" { // Replace any sequence names in the function body with IDs. 
- seqReplacedFuncBody, err := replaceSeqNamesWithIDs(params.ctx, params.p, body, true) + seqReplacedFuncBody, err := replaceSeqNamesWithIDsLang(params.ctx, params.p, body, true, lang) if err != nil { return err } - typeReplacedFuncBody, err := serializeUserDefinedTypes( - params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", - ) + typeReplacedFuncBody, err := serializeUserDefinedTypesLang( + params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", lang) if err != nil { return err } udfDesc.SetFuncBody(typeReplacedFuncBody) - case catpb.Function_PLPGSQL: - // TODO(#115627): make replaceSeqNamesWithIDs and serializeUserDefinedTypes - // play nice with PL/pgSQL. - udfDesc.SetFuncBody(body) } - return nil } diff --git a/pkg/sql/create_view.go b/pkg/sql/create_view.go index 888939e475b5..f8b87be715fc 100644 --- a/pkg/sql/create_view.go +++ b/pkg/sql/create_view.go @@ -30,8 +30,11 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" + plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" + "github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree" + "github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree/utils" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" @@ -425,14 +428,31 @@ func makeViewTableDesc( return desc, nil } -// replaceSeqNamesWithIDs prepares to walk the given viewQuery by defining the -// function used to replace sequence names with IDs, and parsing the -// viewQuery into a statement. 
+// replaceSeqNamesWithIDs walks the query in queryStr, replacing any +// sequence names with their IDs and returning a new query string with the names +// replaced. It assumes that the query is in the SQL language. // TODO (Chengxiong): move this to a better place. func replaceSeqNamesWithIDs( ctx context.Context, sc resolver.SchemaResolver, queryStr string, multiStmt bool, +) (string, error) { + return replaceSeqNamesWithIDsLang(ctx, sc, queryStr, multiStmt, catpb.Function_SQL) +} + +// replaceSeqNamesWithIDsLang walks the query in queryStr, replacing any +// sequence names with their IDs and returning a new query string with the names +// replaced. Queries may be in either the SQL or PLpgSQL language, indicated by +// lang. +func replaceSeqNamesWithIDsLang( + ctx context.Context, + sc resolver.SchemaResolver, + queryStr string, + multiStmt bool, + lang catpb.Function_Language, ) (string, error) { replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { + if expr == nil { + return false, expr, nil + } seqIdentifiers, err := seqexpr.GetUsedSequences(expr) if err != nil { return false, expr, err @@ -452,48 +472,83 @@ func replaceSeqNamesWithIDs( return false, newExpr, nil } - var stmts tree.Statements - if multiStmt { - parsedStmtd, err := parser.Parse(queryStr) - if err != nil { - return "", errors.Wrap(err, "failed to parse query string") + fmtCtx := tree.NewFmtCtx(tree.FmtSimple) + switch lang { + case catpb.Function_SQL: + var stmts tree.Statements + if multiStmt { + parsedStmtd, err := parser.Parse(queryStr) + if err != nil { + return "", errors.Wrap(err, "failed to parse query string") + } + for _, s := range parsedStmtd { + stmts = append(stmts, s.AST) + } + } else { + stmt, err := parser.ParseOne(queryStr) + if err != nil { + return "", errors.Wrap(err, "failed to parse query string") + } + stmts = tree.Statements{stmt.AST} } - for _, s := range parsedStmtd { - stmts = append(stmts, s.AST) + + for i, stmt := range stmts { + 
newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc) + if err != nil { + return "", err + } + if i > 0 { + fmtCtx.WriteString("\n") + } + fmtCtx.FormatNode(newStmt) + if multiStmt { + fmtCtx.WriteString(";") + } } - } else { - stmt, err := parser.ParseOne(queryStr) + case catpb.Function_PLPGSQL: + var stmts plpgsqltree.Statement + plstmt, err := plpgsql.Parse(queryStr) if err != nil { return "", errors.Wrap(err, "failed to parse query string") } - stmts = tree.Statements{stmt.AST} - } + stmts = plstmt.AST - fmtCtx := tree.NewFmtCtx(tree.FmtSimple) - for i, stmt := range stmts { - newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc) - if err != nil { - return "", err - } - if i > 0 { - fmtCtx.WriteString("\n") - } + v := utils.SQLStmtVisitor{Fn: replaceSeqFunc} + newStmt := plpgsqltree.Walk(&v, stmts) fmtCtx.FormatNode(newStmt) - if multiStmt { - fmtCtx.WriteString(";") - } } return fmtCtx.String(), nil } -// serializeUserDefinedTypes will walk the given view query -// and serialize any user defined types, so that renaming the type -// does not corrupt the view. +// serializeUserDefinedTypes walks the given query and serializes any +// user defined types as IDs, so that renaming the type does not cause +// corruption, and returns a new query string containing the replacement IDs. +// It assumes that the query language is SQL. func serializeUserDefinedTypes( ctx context.Context, semaCtx *tree.SemaContext, queries string, multiStmt bool, parentType string, ) (string, error) { + return serializeUserDefinedTypesLang(ctx, semaCtx, queries, multiStmt, parentType, catpb.Function_SQL) +} + +// serializeUserDefinedTypesLang walks the given query and serializes any +// user defined types as IDs, so that renaming the type does not cause +// corruption, and returns a new query string containing the replacement IDs. +// The query may be in either the SQL or PLpgSQL language, indicated by lang. 
+func serializeUserDefinedTypesLang( + ctx context.Context, + semaCtx *tree.SemaContext, + queries string, + multiStmt bool, + parentType string, + lang catpb.Function_Language, +) (string, error) { + // replaceFunc is a visitor function that replaces user defined types in SQL + // expressions with their IDs. replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) { + if expr == nil { + return false, expr, nil + } var innerExpr tree.Expr var typRef tree.ResolvableTypeReference switch n := expr.(type) { @@ -539,39 +594,83 @@ func serializeUserDefinedTypes( } return false, parsedExpr, nil } - - var stmts tree.Statements - if multiStmt { - parsedStmts, err := parser.Parse(queries) - if err != nil { - return "", errors.Wrap(err, "failed to parse query") + // replaceTypeFunc is a visitor function that replaces type annotations + // containing user defined types with their IDs. This is currently only + // necessary for some kinds of PLpgSQL statements. + replaceTypeFunc := func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error) { + if typ == nil { + return typ, nil } - stmts = make(tree.Statements, len(parsedStmts)) - for i, stmt := range parsedStmts { - stmts[i] = stmt.AST + // semaCtx may be nil if this is a virtual view being created at + // init time. 
+ var typeResolver tree.TypeReferenceResolver + if semaCtx != nil { + typeResolver = semaCtx.TypeResolver } - } else { - stmt, err := parser.ParseOne(queries) + var t *types.T + t, err = tree.ResolveType(ctx, typ, typeResolver) if err != nil { - return "", errors.Wrap(err, "failed to parse query") + return typ, err + } + if !t.UserDefined() { + return typ, nil } - stmts = tree.Statements{stmt.AST} + return &tree.OIDTypeReference{OID: t.Oid()}, nil } fmtCtx := tree.NewFmtCtx(tree.FmtSimple) - for i, stmt := range stmts { - newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) - if err != nil { - return "", err + switch lang { + case catpb.Function_SQL: + var stmts tree.Statements + if multiStmt { + parsedStmts, err := parser.Parse(queries) + if err != nil { + return "", errors.Wrap(err, "failed to parse query") + } + stmts = make(tree.Statements, len(parsedStmts)) + for i, stmt := range parsedStmts { + stmts[i] = stmt.AST + } + } else { + stmt, err := parser.ParseOne(queries) + if err != nil { + return "", errors.Wrap(err, "failed to parse query") + } + stmts = tree.Statements{stmt.AST} } - if i > 0 { - fmtCtx.WriteString("\n") + + for i, stmt := range stmts { + newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc) + if err != nil { + return "", err + } + if i > 0 { + fmtCtx.WriteString("\n") + } + fmtCtx.FormatNode(newStmt) + if multiStmt { + fmtCtx.WriteString(";") + } } - fmtCtx.FormatNode(newStmt) - if multiStmt { - fmtCtx.WriteString(";") + case catpb.Function_PLPGSQL: + var stmts plpgsqltree.Statement + plstmt, err := plpgsql.Parse(queries) + if err != nil { + return "", errors.Wrap(err, "failed to parse query string") } + stmts = plstmt.AST + + v := utils.SQLStmtVisitor{Fn: replaceFunc} + newStmt := plpgsqltree.Walk(&v, stmts) + // Some PLpgSQL statements (i.e., declarations), may contain type + // annotations containing the UDT. We need to walk the AST to replace them, + // too. 
+ v2 := utils.TypeRefVisitor{Fn: replaceTypeFunc} + newStmt = plpgsqltree.Walk(&v2, newStmt) + fmtCtx.FormatNode(newStmt) + fmtCtx.WriteString(";") } + return fmtCtx.CloseAndGetString(), nil } diff --git a/pkg/sql/opt/optbuilder/plpgsql.go b/pkg/sql/opt/optbuilder/plpgsql.go index 002033abbef0..ffda0efedd37 100644 --- a/pkg/sql/opt/optbuilder/plpgsql.go +++ b/pkg/sql/opt/optbuilder/plpgsql.go @@ -1532,7 +1532,7 @@ func newRecordTypeVisitor( var _ ast.StatementVisitor = &recordTypeVisitor{} -func (r *recordTypeVisitor) Visit(stmt ast.Statement) { +func (r *recordTypeVisitor) Visit(stmt ast.Statement) (newStmt ast.Statement, changed bool) { if retStmt, ok := stmt.(*ast.Return); ok { desired := types.Any if r.typ != types.Unknown { @@ -1545,7 +1545,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) { } typ := typedExpr.ResolvedType() if typ == types.Unknown { - return + return stmt, false } if typ.Family() != types.TupleFamily { panic(pgerror.New(pgcode.DatatypeMismatch, @@ -1554,7 +1554,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) { } if r.typ == types.Unknown { r.typ = typ - return + return stmt, false } if !typ.Identical(r.typ) { panic(errors.WithHint( @@ -1565,4 +1565,5 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) { )) } } + return stmt, false } diff --git a/pkg/sql/plpgsql/parser/testdata/stmt_case b/pkg/sql/plpgsql/parser/testdata/stmt_case index e8da2026e9d9..adfd3d304de1 100644 --- a/pkg/sql/plpgsql/parser/testdata/stmt_case +++ b/pkg/sql/plpgsql/parser/testdata/stmt_case @@ -133,6 +133,7 @@ BEGIN END CASE; END ---- +decl_stmt: 1 stmt_block: 1 stmt_call: 3 stmt_case: 1 diff --git a/pkg/sql/sem/plpgsqltree/BUILD.bazel b/pkg/sql/sem/plpgsqltree/BUILD.bazel index d51ee48b3cc0..4548391ad4aa 100644 --- a/pkg/sql/sem/plpgsqltree/BUILD.bazel +++ b/pkg/sql/sem/plpgsqltree/BUILD.bazel @@ -13,6 +13,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sql/sem/tree", + "//pkg/util/errorutil/unimplemented", 
"@com_github_cockroachdb_errors//:errors", ], ) diff --git a/pkg/sql/sem/plpgsqltree/exception.go b/pkg/sql/sem/plpgsqltree/exception.go index 7b77216598c2..9dbc83df8a8f 100644 --- a/pkg/sql/sem/plpgsqltree/exception.go +++ b/pkg/sql/sem/plpgsqltree/exception.go @@ -22,6 +22,13 @@ type Exception struct { Action []Statement } +func (s *Exception) CopyNode() *Exception { + copyNode := *s + copyNode.Conditions = append([]Condition(nil), copyNode.Conditions...) + copyNode.Action = append([]Statement(nil), copyNode.Action...) + return &copyNode +} + func (s *Exception) Format(ctx *tree.FmtCtx) { ctx.WriteString("WHEN ") for i, cond := range s.Conditions { @@ -44,11 +51,19 @@ func (s *Exception) PlpgSQLStatementTag() string { return "proc_exception" } -func (s *Exception) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Action { - stmt.WalkStmt(visitor) +func (s *Exception) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + for i, stmt := range s.Action { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Exception).Action[i] = ns + } } + return newStmt, changed } type Condition struct { diff --git a/pkg/sql/sem/plpgsqltree/statements.go b/pkg/sql/sem/plpgsqltree/statements.go index 02dc14c36787..c953313311af 100644 --- a/pkg/sql/sem/plpgsqltree/statements.go +++ b/pkg/sql/sem/plpgsqltree/statements.go @@ -16,6 +16,7 @@ import ( "strings" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" ) type Expr = tree.Expr @@ -25,7 +26,7 @@ type Statement interface { GetLineNo() int GetStmtID() uint plpgsqlStmt() - WalkStmt(StatementVisitor) + WalkStmt(StatementVisitor) (newStmt Statement, changed bool) } type TaggedStatement interface { @@ -63,6 +64,14 @@ type Block struct { Exceptions []Exception } +func (s *Block) CopyNode() *Block { + copyNode := *s + 
copyNode.Decls = append([]Statement(nil), copyNode.Decls...) + copyNode.Body = append([]Statement(nil), copyNode.Body...) + copyNode.Exceptions = append([]Exception(nil), copyNode.Exceptions...) + return &copyNode +} + // TODO(drewk): format Label and Exceptions fields. func (s *Block) Format(ctx *tree.FmtCtx) { if s.Decls != nil { @@ -90,11 +99,39 @@ func (s *Block) PlpgSQLStatementTag() string { return "stmt_block" } -func (s *Block) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) +func (s *Block) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + for i, stmt := range s.Decls { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Block).Decls[i] = ns + } + } + for i, stmt := range s.Body { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Block).Body[i] = ns + } + } + for i, stmt := range s.Exceptions { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Block).Exceptions[i] = *(ns.(*Exception)) + } } + return newStmt, changed } // decl_stmt @@ -108,6 +145,11 @@ type Declaration struct { Expr Expr } +func (s *Declaration) CopyNode() *Declaration { + copyNode := *s + return &copyNode +} + func (s *Declaration) Format(ctx *tree.FmtCtx) { ctx.WriteString(string(s.Var)) if s.Constant { @@ -131,8 +173,9 @@ func (s *Declaration) PlpgSQLStatementTag() string { return "decl_stmt" } -func (s *Declaration) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Declaration) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } type CursorDeclaration struct { @@ -142,6 +185,11 @@ type CursorDeclaration struct { Query tree.Statement } +func (s 
*CursorDeclaration) CopyNode() *CursorDeclaration { + copyNode := *s + return &copyNode +} + func (s *CursorDeclaration) Format(ctx *tree.FmtCtx) { ctx.WriteString(string(s.Name)) switch s.Scroll { @@ -159,8 +207,9 @@ func (s *CursorDeclaration) PlpgSQLStatementTag() string { return "decl_cursor_stmt" } -func (s *CursorDeclaration) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *CursorDeclaration) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_assign @@ -170,6 +219,11 @@ type Assignment struct { Value Expr } +func (s *Assignment) CopyNode() *Assignment { + copyNode := *s + return &copyNode +} + func (s *Assignment) PlpgSQLStatementTag() string { return "stmt_assign" } @@ -178,8 +232,9 @@ func (s *Assignment) Format(ctx *tree.FmtCtx) { ctx.WriteString(fmt.Sprintf("%s := %s;\n", s.Var, s.Value)) } -func (s *Assignment) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Assignment) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_if @@ -191,6 +246,18 @@ type If struct { ElseBody []Statement } +func (s *If) CopyNode() *If { + copyNode := *s + copyNode.ThenBody = append([]Statement(nil), copyNode.ThenBody...) + copyNode.ElseBody = append([]Statement(nil), copyNode.ElseBody...) + copyNode.ElseIfList = make([]ElseIf, len(s.ElseIfList)) + for i, ei := range s.ElseIfList { + copyNode.ElseIfList[i] = ei + copyNode.ElseIfList[i].Stmts = append([]Statement(nil), copyNode.ElseIfList[i].Stmts...) 
+ } + return &copyNode +} + func (s *If) Format(ctx *tree.FmtCtx) { ctx.WriteString("IF ") s.Condition.Format(ctx) @@ -217,21 +284,43 @@ func (s *If) PlpgSQLStatementTag() string { return "stmt_if" } -func (s *If) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *If) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) - for _, thenStmt := range s.ThenBody { - thenStmt.WalkStmt(visitor) + for i, thenStmt := range s.ThenBody { + ns, ch := thenStmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*If).ThenBody[i] = ns + } } - for _, elseIf := range s.ElseIfList { - elseIf.WalkStmt(visitor) + for i, elseIf := range s.ElseIfList { + ns, ch := elseIf.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*If).ElseIfList[i] = *ns.(*ElseIf) + } } - for _, elseStmt := range s.ElseBody { - elseStmt.WalkStmt(visitor) + for i, elseStmt := range s.ElseBody { + ns, ch := elseStmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*If).ElseBody[i] = ns + } } + return newStmt, changed } type ElseIf struct { @@ -240,6 +329,12 @@ type ElseIf struct { Stmts []Statement } +func (s *ElseIf) CopyNode() *ElseIf { + copyNode := *s + copyNode.Stmts = append([]Statement(nil), copyNode.Stmts...) 
+ return &copyNode +} + func (s *ElseIf) Format(ctx *tree.FmtCtx) { ctx.WriteString("ELSIF ") s.Condition.Format(ctx) @@ -254,12 +349,20 @@ func (s *ElseIf) PlpgSQLStatementTag() string { return "stmt_if_else_if" } -func (s *ElseIf) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *ElseIf) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) - for _, stmt := range s.Stmts { - stmt.WalkStmt(visitor) + for i, stmt := range s.Stmts { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*ElseIf).Stmts[i] = ns + } } + return newStmt, changed } // stmt_case @@ -273,6 +376,18 @@ type Case struct { ElseStmts []Statement } +func (s *Case) CopyNode() *Case { + copyNode := *s + copyNode.ElseStmts = append([]Statement(nil), copyNode.ElseStmts...) + copyNode.CaseWhenList = make([]*CaseWhen, len(s.CaseWhenList)) + caseWhens := make([]CaseWhen, len(s.CaseWhenList)) + for i, cw := range s.CaseWhenList { + caseWhens[i] = *cw + copyNode.CaseWhenList[i] = &caseWhens[i] + } + return &copyNode +} + // TODO(drewk): fix the whitespace/newline formatting for CASE (see the // stmt_case test file). 
func (s *Case) Format(ctx *tree.FmtCtx) { @@ -298,18 +413,33 @@ func (s *Case) PlpgSQLStatementTag() string { return "stmt_case" } -func (s *Case) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Case) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) - for _, when := range s.CaseWhenList { - when.WalkStmt(visitor) + for i, when := range s.CaseWhenList { + ns, ch := when.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Case).CaseWhenList[i] = ns.(*CaseWhen) + } } if s.HaveElse { - for _, stmt := range s.ElseStmts { - stmt.WalkStmt(visitor) + for i, stmt := range s.ElseStmts { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Case).ElseStmts[i] = ns + } } } + return newStmt, changed } type CaseWhen struct { @@ -319,6 +449,12 @@ type CaseWhen struct { Stmts []Statement } +func (s *CaseWhen) CopyNode() *CaseWhen { + copyNode := *s + copyNode.Stmts = append([]Statement(nil), copyNode.Stmts...) 
+ return &copyNode +} + func (s *CaseWhen) Format(ctx *tree.FmtCtx) { ctx.WriteString(fmt.Sprintf("WHEN %s THEN\n", s.Expr)) for i, stmt := range s.Stmts { @@ -334,12 +470,20 @@ func (s *CaseWhen) PlpgSQLStatementTag() string { return "stmt_when" } -func (s *CaseWhen) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *CaseWhen) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) - for _, stmt := range s.Stmts { - stmt.WalkStmt(visitor) + for i, stmt := range s.Stmts { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*CaseWhen).Stmts[i] = ns + } } + return newStmt, changed } // stmt_loop @@ -349,6 +493,12 @@ type Loop struct { Body []Statement } +func (s *Loop) CopyNode() *Loop { + copyNode := *s + copyNode.Body = append([]Statement(nil), copyNode.Body...) + return &copyNode +} + func (s *Loop) PlpgSQLStatementTag() string { return "stmt_simple_loop" } @@ -365,11 +515,19 @@ func (s *Loop) Format(ctx *tree.FmtCtx) { ctx.WriteString(";\n") } -func (s *Loop) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) +func (s *Loop) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + for i, stmt := range s.Body { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*Loop).Body[i] = ns + } } + return newStmt, changed } // stmt_while @@ -380,6 +538,12 @@ type While struct { Body []Statement } +func (s *While) CopyNode() *While { + copyNode := *s + copyNode.Body = append([]Statement(nil), copyNode.Body...) 
+ return &copyNode +} + func (s *While) Format(ctx *tree.FmtCtx) { ctx.WriteString("WHILE ") s.Condition.Format(ctx) @@ -398,11 +562,19 @@ func (s *While) PlpgSQLStatementTag() string { return "stmt_while" } -func (s *While) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) +func (s *While) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + for i, stmt := range s.Body { + ns, ch := stmt.WalkStmt(visitor) + if ch { + changed = true + if newStmt == s { + newStmt = s.CopyNode() + } + newStmt.(*While).Body[i] = ns + } } + return newStmt, changed } // stmt_for @@ -424,11 +596,8 @@ func (s *ForInt) PlpgSQLStatementTag() string { return "stmt_for_int_loop" } -func (s *ForInt) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) - } +func (s *ForInt) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ForQuery struct { @@ -445,11 +614,8 @@ func (s *ForQuery) PlpgSQLStatementTag() string { return "stmt_for_query_loop" } -func (s *ForQuery) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) - } +func (s *ForQuery) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ForSelect struct { @@ -464,9 +630,8 @@ func (s *ForSelect) PlpgSQLStatementTag() string { return "stmt_query_select_loop" } -func (s *ForSelect) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - s.ForQuery.WalkStmt(visitor) +func (s *ForSelect) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ForCursor struct { @@ -482,9 +647,8 @@ 
func (s *ForCursor) PlpgSQLStatementTag() string { return "stmt_for_query_cursor_loop" } -func (s *ForCursor) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - s.ForQuery.WalkStmt(visitor) +func (s *ForCursor) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ForDynamic struct { @@ -500,9 +664,8 @@ func (s *ForDynamic) PlpgSQLStatementTag() string { return "stmt_for_dyn_loop" } -func (s *ForDynamic) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - s.ForQuery.WalkStmt(visitor) +func (s *ForDynamic) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } // stmt_foreach_a @@ -522,12 +685,8 @@ func (s *ForEachArray) PlpgSQLStatementTag() string { return "stmt_for_each_a" } -func (s *ForEachArray) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) - - for _, stmt := range s.Body { - stmt.WalkStmt(visitor) - } +func (s *ForEachArray) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } // stmt_exit @@ -537,6 +696,11 @@ type Exit struct { Condition Expr } +func (s *Exit) CopyNode() *Exit { + copyNode := *s + return &copyNode +} + func (s *Exit) Format(ctx *tree.FmtCtx) { ctx.WriteString("EXIT") if s.Label != "" { @@ -554,8 +718,9 @@ func (s *Exit) PlpgSQLStatementTag() string { return "stmt_exit" } -func (s *Exit) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Exit) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_continue @@ -565,6 +730,11 @@ type Continue struct { Condition Expr } +func (s *Continue) CopyNode() *Continue { + copyNode := *s + return &copyNode +} + func (s *Continue) Format(ctx *tree.FmtCtx) { 
ctx.WriteString("CONTINUE") if s.Label != "" { @@ -581,8 +751,9 @@ func (s *Continue) PlpgSQLStatementTag() string { return "stmt_continue" } -func (s *Continue) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Continue) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_return @@ -592,6 +763,11 @@ type Return struct { RetVar Variable } +func (s *Return) CopyNode() *Return { + copyNode := *s + return &copyNode +} + func (s *Return) Format(ctx *tree.FmtCtx) { ctx.WriteString("RETURN ") if s.Expr == nil { @@ -606,8 +782,9 @@ func (s *Return) PlpgSQLStatementTag() string { return "stmt_return" } -func (s *Return) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Return) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } type ReturnNext struct { @@ -623,8 +800,8 @@ func (s *ReturnNext) PlpgSQLStatementTag() string { return "stmt_return_next" } -func (s *ReturnNext) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *ReturnNext) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } type ReturnQuery struct { @@ -641,8 +818,8 @@ func (s *ReturnQuery) PlpgSQLStatementTag() string { return "stmt_return_query" } -func (s *ReturnQuery) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *ReturnQuery) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } // stmt_raise @@ -656,6 +833,13 @@ type Raise struct { Options []RaiseOption } +func (s *Raise) CopyNode() *Raise { + copyNode := *s + copyNode.Params = append([]Expr(nil), s.Params...) + copyNode.Options = append([]RaiseOption(nil), s.Options...) 
+ return &copyNode +} + func (s *Raise) Format(ctx *tree.FmtCtx) { ctx.WriteString("RAISE") if s.LogLevel != "" { @@ -700,8 +884,9 @@ func (s *Raise) PlpgSQLStatementTag() string { return "stmt_raise" } -func (s *Raise) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Raise) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_assert @@ -711,6 +896,11 @@ type Assert struct { Message Expr } +func (s *Assert) CopyNode() *Assert { + copyNode := *s + return &copyNode +} + func (s *Assert) Format(ctx *tree.FmtCtx) { // TODO(drewk): Pretty print the assert condition and message ctx.WriteString("ASSERT\n") @@ -720,8 +910,9 @@ func (s *Assert) PlpgSQLStatementTag() string { return "stmt_assert" } -func (s *Assert) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Assert) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_execsql @@ -732,6 +923,12 @@ type Execute struct { Target []Variable } +func (s *Execute) CopyNode() *Execute { + copyNode := *s + copyNode.Target = append([]Variable(nil), copyNode.Target...) + return &copyNode +} + func (s *Execute) Format(ctx *tree.FmtCtx) { s.SqlStmt.Format(ctx) if s.Target != nil { @@ -753,8 +950,9 @@ func (s *Execute) PlpgSQLStatementTag() string { return "stmt_exec_sql" } -func (s *Execute) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Execute) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_dynexecute @@ -768,6 +966,12 @@ type DynamicExecute struct { Params []Expr } +func (s *DynamicExecute) CopyNode() *DynamicExecute { + copyNode := *s + copyNode.Params = append([]Expr(nil), s.Params...) 
+ return &copyNode +} + func (s *DynamicExecute) Format(ctx *tree.FmtCtx) { // TODO(drewk): Pretty print the original command ctx.WriteString("EXECUTE a dynamic command") @@ -787,8 +991,9 @@ func (s *DynamicExecute) PlpgSQLStatementTag() string { return "stmt_dyn_exec" } -func (s *DynamicExecute) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *DynamicExecute) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_perform @@ -804,8 +1009,8 @@ func (s *Perform) PlpgSQLStatementTag() string { return "stmt_perform" } -func (s *Perform) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Perform) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + panic(unimplemented.New("plpgsql visitor", "Unimplemented PLpgSQL visitor pattern")) } // stmt_call @@ -816,6 +1021,11 @@ type Call struct { Target Variable } +func (s *Call) CopyNode() *Call { + copyNode := *s + return &copyNode +} + func (s *Call) Format(ctx *tree.FmtCtx) { // TODO(drewk): Correct the Call field and print the Expr and Target. 
if s.IsCall { @@ -829,8 +1039,9 @@ func (s *Call) PlpgSQLStatementTag() string { return "stmt_call" } -func (s *Call) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Call) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_getdiag @@ -872,8 +1083,9 @@ func (s *GetDiagnostics) PlpgSQLStatementTag() string { return "stmt_get_diag" } -func (s *GetDiagnostics) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *GetDiagnostics) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_open @@ -884,6 +1096,11 @@ type Open struct { Query tree.Statement } +func (s *Open) CopyNode() *Open { + copyNode := *s + return &copyNode +} + func (s *Open) Format(ctx *tree.FmtCtx) { ctx.WriteString("OPEN ") s.CurVar.Format(ctx) @@ -904,8 +1121,9 @@ func (s *Open) PlpgSQLStatementTag() string { return "stmt_open" } -func (s *Open) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Open) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_fetch @@ -952,8 +1170,9 @@ func (s *Fetch) PlpgSQLStatementTag() string { return "stmt_fetch" } -func (s *Fetch) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Fetch) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_close @@ -972,8 +1191,9 @@ func (s *Close) PlpgSQLStatementTag() string { return "stmt_close" } -func (s *Close) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Close) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_commit @@ -989,8 +1209,9 @@ func (s *Commit) PlpgSQLStatementTag() string { return 
"stmt_commit" } -func (s *Commit) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Commit) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_rollback @@ -1006,8 +1227,9 @@ func (s *Rollback) PlpgSQLStatementTag() string { return "stmt_rollback" } -func (s *Rollback) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Rollback) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } // stmt_null @@ -1023,6 +1245,7 @@ func (s *Null) PlpgSQLStatementTag() string { return "stmt_null" } -func (s *Null) WalkStmt(visitor StatementVisitor) { - visitor.Visit(s) +func (s *Null) WalkStmt(visitor StatementVisitor) (newStmt Statement, changed bool) { + newStmt, changed = visitor.Visit(s) + return newStmt, changed } diff --git a/pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go b/pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go index 0e775e864efb..437419dd8611 100644 --- a/pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go +++ b/pkg/sql/sem/plpgsqltree/utils/plpg_visitor.go @@ -58,7 +58,9 @@ type telemetryVisitor struct { var _ plpgsqltree.StatementVisitor = &telemetryVisitor{} // Visit implements the StatementVisitor interface -func (v *telemetryVisitor) Visit(stmt plpgsqltree.Statement) { +func (v *telemetryVisitor) Visit( + stmt plpgsqltree.Statement, +) (newStmt plpgsqltree.Statement, changed bool) { taggedStmt, ok := stmt.(plpgsqltree.TaggedStatement) if !ok { v.Err = errors.AssertionFailedf("no tag found for stmt %q", stmt) @@ -75,6 +77,7 @@ func (v *telemetryVisitor) Visit(stmt plpgsqltree.Statement) { } v.Err = nil + return stmt, false } // MakePLpgSQLTelemetryVisitor makes a plpgsql telemetry visitor, for capturing @@ -114,3 +117,268 @@ func ParseAndCollectTelemetryForPLpgSQLFunc(stmt *tree.CreateRoutine) error { } return unimp.New("plpgsql", "plpgsql not supported in 
user-defined functions")
 }
+
+// SQLStmtVisitor calls Fn for every SQL statement and expression found while
+// walking the PLpgSQL AST. Since PLpgSQL nodes may have statement and
+// expression fields that are nil, Fn should handle the nil case.
+type SQLStmtVisitor struct {
+	Fn  tree.SimpleVisitFn
+	Err error
+}
+
+var _ plpgsqltree.StatementVisitor = &SQLStmtVisitor{}
+
+// Visit implements the StatementVisitor interface. It applies Fn to the SQL
+// statement or expression(s) held directly by stmt. If any of them is
+// replaced, the replacement is stored in a copy of stmt and that copy is
+// returned with changed set to true; stmt itself is not modified by Visit.
+// Once v.Err is set, Visit returns every statement unchanged.
+func (v *SQLStmtVisitor) Visit(
+	stmt plpgsqltree.Statement,
+) (newStmt plpgsqltree.Statement, changed bool) {
+	if v.Err != nil {
+		return stmt, false
+	}
+	newStmt = stmt
+	var s tree.Statement
+	var e tree.Expr
+	switch t := stmt.(type) {
+	case *plpgsqltree.CursorDeclaration:
+		s, v.Err = tree.SimpleStmtVisit(t.Query, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Query != s
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Query = s
+			newStmt = cpy
+		}
+	case *plpgsqltree.Execute:
+		s, v.Err = tree.SimpleStmtVisit(t.SqlStmt, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.SqlStmt != s
+		if changed {
+			cpy := t.CopyNode()
+			cpy.SqlStmt = s
+			newStmt = cpy
+		}
+	case *plpgsqltree.Open:
+		s, v.Err = tree.SimpleStmtVisit(t.Query, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Query != s
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Query = s
+			newStmt = cpy
+		}
+	case *plpgsqltree.Declaration:
+		e, v.Err = tree.SimpleVisit(t.Expr, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Expr != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Expr = e
+			newStmt = cpy
+		}
+
+	case *plpgsqltree.Assignment:
+		e, v.Err = tree.SimpleVisit(t.Value, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Value != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Value = e
+			newStmt = cpy
+		}
+	case *plpgsqltree.If:
+		e, v.Err = tree.SimpleVisit(t.Condition, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Condition != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Condition = e
+			newStmt = cpy
+		}
+	case *plpgsqltree.ElseIf:
+		e, v.Err = tree.SimpleVisit(t.Condition, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Condition != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Condition = e
+			newStmt = cpy
+		}
+	case *plpgsqltree.While:
+		e, v.Err = tree.SimpleVisit(t.Condition, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Condition != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Condition = e
+			newStmt = cpy
+		}
+	case *plpgsqltree.Exit:
+		e, v.Err = tree.SimpleVisit(t.Condition, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Condition != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Condition = e
+			newStmt = cpy
+		}
+	case *plpgsqltree.Continue:
+		e, v.Err = tree.SimpleVisit(t.Condition, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Condition != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Condition = e
+			newStmt = cpy
+		}
+	case *plpgsqltree.Return:
+		e, v.Err = tree.SimpleVisit(t.Expr, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Expr != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Expr = e
+			newStmt = cpy
+		}
+	case *plpgsqltree.Raise:
+		for i, p := range t.Params {
+			e, v.Err = tree.SimpleVisit(p, v.Fn)
+			if v.Err != nil {
+				return stmt, false
+			}
+			changed = changed || (t.Params[i] != e)
+			if changed {
+				// Copy the node on the first change so that the original
+				// statement is never mutated.
+				if newStmt == stmt {
+					cpy := t.CopyNode()
+					newStmt = cpy
+				}
+				newStmt.(*plpgsqltree.Raise).Params[i] = e
+			}
+		}
+		for i, p := range t.Options {
+			e, v.Err = tree.SimpleVisit(p.Expr, v.Fn)
+			if v.Err != nil {
+				return stmt, false
+			}
+			changed = changed || (t.Options[i].Expr != e)
+			if changed {
+				if newStmt == stmt {
+					cpy := t.CopyNode()
+					newStmt = cpy
+				}
+				newStmt.(*plpgsqltree.Raise).Options[i].Expr = e
+			}
+		}
+	case *plpgsqltree.Assert:
+		e, v.Err = tree.SimpleVisit(t.Condition, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Condition != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Condition = e
+			newStmt = cpy
+		}
+
+	case *plpgsqltree.DynamicExecute:
+		for i, p := range t.Params {
+			e, v.Err = tree.SimpleVisit(p, v.Fn)
+			if v.Err != nil {
+				return stmt, false
+			}
+			changed = changed || (t.Params[i] != e)
+			if changed {
+				// Copy the node on the first change so that the original
+				// statement is never mutated.
+				if newStmt == stmt {
+					cpy := t.CopyNode()
+					newStmt = cpy
+				}
+				newStmt.(*plpgsqltree.DynamicExecute).Params[i] = e
+			}
+		}
+	case *plpgsqltree.Call:
+		e, v.Err = tree.SimpleVisit(t.Expr, v.Fn)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Expr != e
+		if changed {
+			cpy := t.CopyNode()
+			cpy.Expr = e
+			newStmt = cpy
+		}
+
+	case *plpgsqltree.ForInt, *plpgsqltree.ForSelect, *plpgsqltree.ForCursor,
+		*plpgsqltree.ForDynamic, *plpgsqltree.ForEachArray, *plpgsqltree.ReturnNext,
+		*plpgsqltree.ReturnQuery, *plpgsqltree.Perform:
+		panic(unimp.New("plpgsql visitor", "Unimplemented PLpgSQL visitor"))
+	}
+	if v.Err != nil {
+		return stmt, false
+	}
+	return newStmt, changed
+}
+
+// TypeRefVisitor calls the given replace function on each type reference
+// contained in the visited PLpgSQL statements. Note that this currently only
+// includes `Declaration`. SQL statements and expressions are not visited.
+type TypeRefVisitor struct {
+	Fn  func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error)
+	Err error
+}
+
+var _ plpgsqltree.StatementVisitor = &TypeRefVisitor{}
+
+// Visit implements the StatementVisitor interface.
+func (v *TypeRefVisitor) Visit(
+	stmt plpgsqltree.Statement,
+) (newStmt plpgsqltree.Statement, changed bool) {
+	if v.Err != nil {
+		return stmt, false
+	}
+	newStmt = stmt
+	if t, ok := stmt.(*plpgsqltree.Declaration); ok {
+		var newTyp tree.ResolvableTypeReference
+		newTyp, v.Err = v.Fn(t.Typ)
+		if v.Err != nil {
+			return stmt, false
+		}
+		changed = t.Typ != newTyp
+		if changed {
+			if newStmt == stmt {
+				newStmt = t.CopyNode()
+				newStmt.(*plpgsqltree.Declaration).Typ = newTyp
+			}
+		}
+	}
+	return newStmt, changed
+}
diff --git a/pkg/sql/sem/plpgsqltree/visitor.go b/pkg/sql/sem/plpgsqltree/visitor.go
index cbad8e13fe51..9526d480f27c 100644
--- a/pkg/sql/sem/plpgsqltree/visitor.go
+++ b/pkg/sql/sem/plpgsqltree/visitor.go
@@ -14,10 +14,11 @@ package plpgsqltree
 // a statement walk.
 type StatementVisitor interface {
 	// Visit is called during a statement walk.
-	Visit(stmt Statement)
+	Visit(stmt Statement) (newStmt Statement, changed bool)
 }
 
 // Walk traverses the plpgsql statement.
-func Walk(v StatementVisitor, stmt Statement) {
-	stmt.WalkStmt(v)
+func Walk(v StatementVisitor, stmt Statement) Statement {
+	newStmt, _ := stmt.WalkStmt(v) // the changed flag is not needed here; only the resulting statement is returned
+	return newStmt
 }
diff --git a/pkg/sql/show_create_clauses.go b/pkg/sql/show_create_clauses.go
index 06f46e79e521..9acec2e15ac8 100644
--- a/pkg/sql/show_create_clauses.go
+++ b/pkg/sql/show_create_clauses.go
@@ -25,8 +25,11 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/sql/catalog/schemaexpr"
 	"github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc"
 	"github.com/cockroachdb/cockroach/pkg/sql/parser"
+	plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
 	"github.com/cockroachdb/cockroach/pkg/sql/rowenc"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants"
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree"
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree/utils"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/semenumpb"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
 	"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
@@ -193,7 +196,7 @@ func formatViewQueryForDisplay(
 	}
 
 	// Convert sequences referenced by ID in the view back to their names.
-	sequenceReplacedViewQuery, err := formatQuerySequencesForDisplay(ctx, semaCtx, typeReplacedViewQuery, false /* multiStmt */)
+	sequenceReplacedViewQuery, err := formatQuerySequencesForDisplay(ctx, semaCtx, typeReplacedViewQuery, false /* multiStmt */, catpb.Function_SQL /* lang */)
 	if err != nil {
 		log.Warningf(ctx, "error converting sequence IDs to names for view %s (%v): %+v",
 			desc.GetName(), desc.GetID(), err)
@@ -207,9 +210,16 @@ func formatViewQueryForDisplay(
 // looks for sequence IDs in the statement. If it finds any,
 // it will replace the IDs with the descriptor's fully qualified name.
 func formatQuerySequencesForDisplay(
-	ctx context.Context, semaCtx *tree.SemaContext, queries string, multiStmt bool,
+	ctx context.Context,
+	semaCtx *tree.SemaContext,
+	queries string,
+	multiStmt bool,
+	lang catpb.Function_Language,
 ) (string, error) {
 	replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
+		if expr == nil {
+			return false, expr, nil
+		}
 		newExpr, err = schemaexpr.ReplaceSequenceIDsWithFQNames(ctx, expr, semaCtx)
 		if err != nil {
 			return false, expr, err
@@ -217,37 +227,51 @@ func formatQuerySequencesForDisplay(
 		return false, newExpr, nil
 	}
 
-	var stmts tree.Statements
-	if multiStmt {
-		parsedStmts, err := parser.Parse(queries)
-		if err != nil {
-			return "", err
+	fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
+	switch lang {
+	case catpb.Function_SQL:
+		var stmts tree.Statements
+		if multiStmt {
+			parsedStmts, err := parser.Parse(queries)
+			if err != nil {
+				return "", err
+			}
+			stmts = make(tree.Statements, len(parsedStmts))
+			for i, stmt := range parsedStmts {
+				stmts[i] = stmt.AST
+			}
+		} else {
+			stmt, err := parser.ParseOne(queries)
+			if err != nil {
+				return "", err
+			}
+			stmts = tree.Statements{stmt.AST}
 		}
-		stmts = make(tree.Statements, len(parsedStmts))
-		for i, stmt := range parsedStmts {
-			stmts[i] = stmt.AST
+
+		for i, stmt := range stmts {
+			newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
+			if err != nil {
+				return "", err
+			}
+			if i > 0 {
+				fmtCtx.WriteString("\n")
+			}
+			fmtCtx.FormatNode(newStmt)
+			if multiStmt {
+				fmtCtx.WriteString(";")
+			}
 		}
-	} else {
-		stmt, err := parser.ParseOne(queries)
+	case catpb.Function_PLPGSQL:
+		var stmts plpgsqltree.Statement
+		plstmt, err := plpgsql.Parse(queries)
 		if err != nil {
-			return "", err
+			return "", errors.Wrap(err, "failed to parse query string")
 		}
-		stmts = tree.Statements{stmt.AST}
-	}
+		stmts = plstmt.AST
 
-	fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
-	for i, stmt := range stmts {
-		newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
-		if err != nil {
-			return "", err
-		}
-		if i > 0 {
-			fmtCtx.WriteString("\n")
-		}
+		v := utils.SQLStmtVisitor{Fn: replaceFunc}
+		newStmt := plpgsqltree.Walk(&v, stmts) // NOTE(review): if SQLStmtVisitor also records errors, they are dropped here — confirm.
 		fmtCtx.FormatNode(newStmt)
-		if multiStmt {
-			fmtCtx.WriteString(";")
-		}
 	}
 	return fmtCtx.CloseAndGetString(), nil
 }
@@ -308,7 +332,7 @@ func formatViewQueryTypesForDisplay(
 }
 
 // formatFunctionQueryTypesForDisplay is similar to
-// formatViewQueryTypesForDisplay but can only be used for function.
+// formatViewQueryTypesForDisplay but can only be used for functions.
 // nil is used as the table descriptor for schemaexpr.FormatExprForDisplay call.
 // This is fine assuming that UDFs cannot be created with expression casting a
 // column/var to an enum in function body. This is super rare case for now, and
@@ -319,8 +343,14 @@ func formatFunctionQueryTypesForDisplay(
 	semaCtx *tree.SemaContext,
 	sessionData *sessiondata.SessionData,
 	queries string,
+	lang catpb.Function_Language,
 ) (string, error) {
+	// replaceFunc is a visitor function that replaces user defined type IDs in
+	// SQL expressions with their names.
 	replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
+		if expr == nil {
+			return false, expr, nil
+		}
 		// We need to resolve the type to check if it's user-defined. If not,
 		// no other work is needed.
 		var typRef tree.ResolvableTypeReference
@@ -352,29 +382,85 @@ func formatFunctionQueryTypesForDisplay(
 		}
 		return false, newExpr, nil
 	}
-
-	var stmts tree.Statements
-	parsedStmts, err := parser.Parse(queries)
-	if err != nil {
-		return "", errors.Wrap(err, "failed to parse query")
-	}
-	stmts = make(tree.Statements, len(parsedStmts))
-	for i, stmt := range parsedStmts {
-		stmts[i] = stmt.AST
+	// replaceTypeFunc is a visitor function that replaces type annotations
+	// containing user-defined type IDs with their names. This is currently only
+	// necessary for some kinds of PLpgSQL statements.
+	replaceTypeFunc := func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error) {
+		if typ == nil {
+			return typ, nil
+		}
+		// semaCtx may be nil if this is a virtual view being created at
+		// init time.
+		var typeResolver tree.TypeReferenceResolver
+		if semaCtx != nil {
+			typeResolver = semaCtx.TypeResolver
+		}
+		var t *types.T
+		t, err = tree.ResolveType(ctx, typ, typeResolver)
+		if err != nil {
+			return typ, err
+		}
+		if !t.UserDefined() {
+			return typ, nil
+		}
+		name := t.TypeMeta.Name
+		typname := tree.MakeTypeNameWithPrefix(tree.ObjectNamePrefix{
+			CatalogName:     tree.Name(name.Catalog),
+			SchemaName:      tree.Name(name.Schema),
+			ExplicitCatalog: name.Catalog != "",
+			ExplicitSchema:  name.ExplicitSchema,
+		}, name.Name)
+		ref := typname.ToUnresolvedObjectName()
+		return ref, nil
 	}
 
 	fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
-	for i, stmt := range stmts {
-		newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
+	switch lang {
+	case catpb.Function_SQL:
+		var stmts tree.Statements
+		parsedStmts, err := parser.Parse(queries)
 		if err != nil {
-			return "", err
+			return "", errors.Wrap(err, "failed to parse query")
 		}
-		if i > 0 {
-			fmtCtx.WriteString("\n")
+		stmts = make(tree.Statements, len(parsedStmts))
+		for i, stmt := range parsedStmts {
+			stmts[i] = stmt.AST
 		}
+
+		for i, stmt := range stmts {
+			newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
+			if err != nil {
+				return "", err
+			}
+			if i > 0 {
+				fmtCtx.WriteString("\n")
+			}
+			fmtCtx.FormatNode(newStmt)
+			fmtCtx.WriteString(";")
+		}
+	case catpb.Function_PLPGSQL:
+		var stmts plpgsqltree.Statement
+		plstmt, err := plpgsql.Parse(queries)
+		if err != nil {
+			return "", errors.Wrap(err, "failed to parse query string")
+		}
+		stmts = plstmt.AST
+
+		v := utils.SQLStmtVisitor{Fn: replaceFunc}
+		newStmt := plpgsqltree.Walk(&v, stmts) // NOTE(review): if SQLStmtVisitor also records errors, they are dropped here — confirm.
+		// Some PLpgSQL statements (i.e., declarations) may contain type
+		// annotations containing the UDT. We need to walk the AST to replace them,
+		// too.
+		v2 := utils.TypeRefVisitor{Fn: replaceTypeFunc}
+		newStmt = plpgsqltree.Walk(&v2, newStmt)
+		// TypeRefVisitor stores a failure from replaceTypeFunc in Err rather than
+		// returning it, so it has to be checked explicitly.
+		if v2.Err != nil {
+			return "", v2.Err
+		}
 		fmtCtx.FormatNode(newStmt)
-		fmtCtx.WriteString(";")
 	}
+
 	return fmtCtx.CloseAndGetString(), nil
 }