Skip to content

Commit

Permalink
sql: support sequence and udt name rewriting in plpgsql
Browse files Browse the repository at this point in the history
CRDB rewrites sequence and UDT names as IDs in views and functions so
that if the sequence or UDT is renamed the views and functions using
them don't break. This PR adds support for this in PLpgSQL.

Epic: None
Fixes: #115627

Release note: Fixes a bug in PLpgSQL where altering the name of a
sequence or UDT that was used in a PLpgSQL function or procedure could
break them. This is only present in 23.2 alpha and beta releases.
  • Loading branch information
rharding6373 committed Dec 8, 2023
1 parent 37ad01a commit 015c973
Show file tree
Hide file tree
Showing 11 changed files with 1,179 additions and 276 deletions.
139 changes: 106 additions & 33 deletions pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite
Original file line number Diff line number Diff line change
Expand Up @@ -37,63 +37,136 @@ SET use_declarative_schema_changer = 'on'

subtest rewrite_plpgsql

statement ok
DROP FUNCTION IF EXISTS f_rewrite

statement ok
CREATE OR REPLACE FUNCTION f_rewrite() RETURNS INT AS
$$
DECLARE
i INT := nextval('seq');
curs REFCURSOR := nextval('seq')::STRING;
curs2 CURSOR FOR SELECT nextval('seq');
BEGIN
SELECT nextval('seq');
RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq'));
RAISE NOTICE 'val1: %, val2: %', nextval('seq'), nextval('seq');
WHILE nextval('seq') < 10 LOOP
i = nextval('seq');
SELECT nextval('seq');
IF nextval('seq') = 1 THEN
CONTINUE;
ELSIF nextval('seq') = 2 THEN
SELECT v INTO i FROM nextval('seq') AS v(INT);
END IF;
END LOOP;
OPEN curs FOR SELECT nextval('seq');
RETURN nextval('seq');
END
$$ LANGUAGE PLPGSQL
$$ LANGUAGE PLPGSQL;

query T
SELECT get_body_str('f_rewrite');
----
"BEGIN\nSELECT nextval('seq':::STRING);\nEND\n;"

statement ok
CREATE OR REPLACE FUNCTION f_rewrite() RETURNS INT AS
$$
BEGIN
INSERT INTO t_rewrite(v) VALUES (nextval('seq')) RETURNING v;
END
$$ LANGUAGE PLPGSQL
"DECLARE\ni INT8 := nextval(106:::REGCLASS);\ncurs REFCURSOR := nextval(106:::REGCLASS)::STRING;\ncurs2 CURSOR FOR SELECT nextval(106:::REGCLASS);\nBEGIN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nRAISE notice 'val1: %, val2: %', nextval(106:::REGCLASS), nextval(106:::REGCLASS);\nWHILE nextval(106:::REGCLASS) < 10:::INT8 LOOP\ni := nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nIF nextval(106:::REGCLASS) = 1:::INT8 THEN\n\tCONTINUE;\nELSIF nextval(106:::REGCLASS) = 2:::INT8 THEN\n\tSELECT v FROM ROWS FROM (nextval(106:::REGCLASS)) AS v (\"int\") INTO i;\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT nextval(106:::REGCLASS);\nRETURN nextval(106:::REGCLASS);\nEND\n;"

query T
SELECT get_body_str('f_rewrite');
query TT
SHOW CREATE FUNCTION f_rewrite;
----
"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval('seq':::STRING)) RETURNING v;\nEND\n;"
f_rewrite CREATE FUNCTION public.f_rewrite()
RETURNS INT8
VOLATILE
NOT LEAKPROOF
CALLED ON NULL INPUT
LANGUAGE plpgsql
AS $$
DECLARE
i INT8 := nextval('public.seq'::REGCLASS);
curs REFCURSOR := nextval('public.seq'::REGCLASS)::STRING;
curs2 CURSOR FOR SELECT nextval('public.seq'::REGCLASS);
BEGIN
RAISE notice
USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS));
RAISE notice 'val1: %, val2: %', nextval('public.seq'::REGCLASS), nextval('public.seq'::REGCLASS);
WHILE nextval('public.seq'::REGCLASS) < 10:::INT8 LOOP
i := nextval('public.seq'::REGCLASS);
SELECT nextval('public.seq'::REGCLASS);
IF nextval('public.seq'::REGCLASS) = 1:::INT8 THEN
CONTINUE;
ELSIF nextval('public.seq'::REGCLASS) = 2:::INT8 THEN
SELECT v FROM ROWS FROM (nextval('public.seq'::REGCLASS)) AS v ("int") INTO i;
END IF;
END LOOP;
OPEN curs FOR SELECT nextval('public.seq'::REGCLASS);
RETURN nextval('public.seq'::REGCLASS);
END
;
$$


statement ok
DROP FUNCTION f_rewrite();

statement ok
CREATE OR REPLACE FUNCTION f_rewrite() RETURNS weekday AS
$$
DECLARE
day weekday := 'wednesday'::weekday;
curs REFCURSOR := 'monday'::weekday::STRING;
curs2 CURSOR FOR SELECT 'tuesday'::weekday;
BEGIN
SELECT 'wednesday'::weekday;
RAISE NOTICE USING MESSAGE = format('val: %d','wednesday'::weekday);
RAISE NOTICE 'val1: %, val2: %', 'wednesday'::weekday, 'thursday'::weekday;
WHILE day != 'wednesday'::weekday LOOP
day = 'friday'::weekday;
SELECT 'wednesday'::weekday;
IF day = 'wednesday'::weekday THEN
CONTINUE;
ELSIF day = 'monday'::weekday THEN
SELECT 'tuesday'::weekday INTO day;
END IF;
END LOOP;
OPEN curs FOR SELECT 'wednesday'::weekday;
RETURN 'wednesday'::weekday;
END
$$ LANGUAGE PLPGSQL
$$ LANGUAGE PLPGSQL;

query T
SELECT get_body_str('f_rewrite');
----
"BEGIN\nSELECT 'wednesday'::@100107;\nEND\n;"

statement ok
CREATE OR REPLACE FUNCTION f_rewrite() RETURNS weekday AS
$$
BEGIN
UPDATE t_rewrite SET w = 'thursday'::weekday WHERE w = 'wednesday'::weekday RETURNING w;
END
$$ LANGUAGE PLPGSQL
"DECLARE\nday @100107 := b'\\x80':::@100107;\ncurs REFCURSOR := b' ':::@100107::STRING;\ncurs2 CURSOR FOR SELECT b'@':::@100107;\nBEGIN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nRAISE notice 'val1: %, val2: %', b'\\x80':::@100107, b'\\xa0':::@100107;\nWHILE day != b'\\x80':::@100107 LOOP\nday := b'\\xc0':::@100107;\nSELECT b'\\x80':::@100107;\nIF day = b'\\x80':::@100107 THEN\n\tCONTINUE;\nELSIF day = b' ':::@100107 THEN\n\tSELECT b'@':::@100107 INTO day;\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT b'\\x80':::@100107;\nRETURN b'\\x80':::@100107;\nEND\n;"

query T
SELECT get_body_str('f_rewrite');
query TT
SHOW CREATE FUNCTION f_rewrite;
----
"BEGIN\nUPDATE test.public.t_rewrite SET w = 'thursday'::@100107 WHERE w = 'wednesday'::@100107 RETURNING w;\nEND\n;"
f_rewrite CREATE FUNCTION public.f_rewrite()
RETURNS test.public.weekday
VOLATILE
NOT LEAKPROOF
CALLED ON NULL INPUT
LANGUAGE plpgsql
AS $$
DECLARE
day weekday := 'wednesday':::test.public.weekday;
curs REFCURSOR := 'monday':::test.public.weekday::STRING;
curs2 CURSOR FOR SELECT 'tuesday':::test.public.weekday;
BEGIN
RAISE notice
USING MESSAGE = format('val: %d':::STRING, 'wednesday':::test.public.weekday);
RAISE notice 'val1: %, val2: %', 'wednesday':::test.public.weekday, 'thursday':::test.public.weekday;
WHILE day != 'wednesday':::test.public.weekday LOOP
day := 'friday':::test.public.weekday;
SELECT 'wednesday':::test.public.weekday;
IF day = 'wednesday':::test.public.weekday THEN
CONTINUE;
ELSIF day = 'monday':::test.public.weekday THEN
SELECT 'tuesday':::test.public.weekday INTO day;
END IF;
END LOOP;
OPEN curs FOR SELECT 'wednesday':::test.public.weekday;
RETURN 'wednesday':::test.public.weekday;
END
;
$$

statement ok
DROP FUNCTION f_rewrite();

subtest end

Expand All @@ -110,7 +183,7 @@ $$ LANGUAGE PLPGSQL
query T
SELECT get_body_str('p_rewrite');
----
"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval('seq':::STRING)) RETURNING v;\nEND\n;"
"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval(106:::REGCLASS)) RETURNING v;\nEND\n;"

statement ok
DROP PROCEDURE p_rewrite();
Expand All @@ -126,6 +199,6 @@ $$ LANGUAGE PLPGSQL
query T
SELECT get_body_str('p_rewrite');
----
"BEGIN\nUPDATE test.public.t_rewrite SET w = 'thursday'::@100107 WHERE w = 'wednesday'::@100107 RETURNING w;\nEND\n;"
"BEGIN\nUPDATE test.public.t_rewrite SET w = b'\\xa0':::@100107 WHERE w = b'\\x80':::@100107 RETURNING w;\nEND\n;"

subtest end
2 changes: 2 additions & 0 deletions pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ go_library(
"//pkg/sql/sem/catconstants",
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/plpgsqltree",
"//pkg/sql/sem/plpgsqltree/utils",
"//pkg/sql/sem/semenumpb",
"//pkg/sql/sem/transform",
"//pkg/sql/sem/tree",
Expand Down
31 changes: 8 additions & 23 deletions pkg/sql/crdb_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/protoreflect"
"github.com/cockroachdb/cockroach/pkg/sql/roleoption"
Expand Down Expand Up @@ -3579,29 +3578,15 @@ func createRoutinePopulate(
for i := range treeNode.Options {
if body, ok := treeNode.Options[i].(tree.RoutineBodyStr); ok {
bodyStr := string(body)
switch fnDesc.GetLanguage() {
case catpb.Function_SQL:
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr)
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */)
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
case catpb.Function_PLPGSQL:
// TODO(drewk): integrate this with the SQL case above.
plpgsqlStmt, err := plpgsql.Parse(bodyStr)
if err != nil {
return err
}
fmtCtx := tree.NewFmtCtx(tree.FmtParsable)
fmtCtx.FormatNode(plpgsqlStmt.AST)
bodyStr = "\n" + fmtCtx.CloseAndGetString()
default:
return errors.AssertionFailedf("unexpected function language: %s", fnDesc.GetLanguage())
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
stmtStrs := strings.Split(bodyStr, "\n")
for i := range stmtStrs {
if stmtStrs[i] != "" {
Expand Down
29 changes: 10 additions & 19 deletions pkg/sql/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,26 +455,17 @@ func setFuncOptions(
}
}

switch lang {
case catpb.Function_SQL:
// Replace any sequence names in the function body with IDs.
seqReplacedFuncBody, err := replaceSeqNamesWithIDs(params.ctx, params.p, body, true)
if err != nil {
return err
}
typeReplacedFuncBody, err := serializeUserDefinedTypes(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs",
)
if err != nil {
return err
}
udfDesc.SetFuncBody(typeReplacedFuncBody)
case catpb.Function_PLPGSQL:
// TODO(#115627): make replaceSeqNamesWithIDs and serializeUserDefinedTypes
// play nice with PL/pgSQL.
udfDesc.SetFuncBody(body)
// Replace any sequence names in the function body with IDs.
seqReplacedFuncBody, err := replaceSeqNamesWithIDsLang(params.ctx, params.p, body, true, lang)
if err != nil {
return err
}

typeReplacedFuncBody, err := serializeUserDefinedTypesLang(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", lang)
if err != nil {
return err
}
udfDesc.SetFuncBody(typeReplacedFuncBody)
return nil
}

Expand Down
Loading

0 comments on commit 015c973

Please sign in to comment.