Skip to content

Commit

Permalink
sql: support sequence and udt name rewriting in plpgsql
Browse files Browse the repository at this point in the history
CRDB rewrites sequence and UDT names as IDs in views and functions so
that if the sequence or UDT is renamed the views and functions using
them don't break. This PR adds support for this in PLpgSQL.

Epic: None
Fixes: #115627

Release note (sql change): Fixes a bug in PLpgSQL where altering the
name of a sequence or UDT that was used in a PLpgSQL function or
procedure could break them. This is only present in 23.2 alpha and beta
releases.
  • Loading branch information
rharding6373 committed Dec 12, 2023
1 parent 08c0e5e commit e0f9c28
Show file tree
Hide file tree
Showing 13 changed files with 1,045 additions and 273 deletions.
189 changes: 156 additions & 33 deletions pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite
Original file line number Diff line number Diff line change
Expand Up @@ -37,63 +37,186 @@ SET use_declarative_schema_changer = 'on'

subtest rewrite_plpgsql

statement ok
DROP FUNCTION IF EXISTS f_rewrite

statement ok
CREATE OR REPLACE FUNCTION f_rewrite() RETURNS INT AS
$$
DECLARE
i INT := nextval('seq');
j INT := nextval('seq');
curs REFCURSOR := nextval('seq')::STRING;
curs2 CURSOR FOR SELECT nextval('seq');
BEGIN
SELECT nextval('seq');
RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq'));
RAISE NOTICE 'val1: %, val2: %', nextval('seq'), nextval('seq');
WHILE nextval('seq') < 10 LOOP
i = nextval('seq');
SELECT nextval('seq');
IF nextval('seq') = 1 THEN
SELECT nextval('seq');
SELECT nextval('seq');
CONTINUE;
ELSIF nextval('seq') = 2 THEN
SELECT v INTO i FROM nextval('seq') AS v(INT);
ELSIF nextval('seq') = 3 THEN
SELECT nextval('seq');
SELECT nextval('seq');
END IF;
END LOOP;
OPEN curs FOR SELECT nextval('seq');
RETURN nextval('seq');
EXCEPTION
WHEN division_by_zero THEN
RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq'));
WHEN not_null_violation THEN
SELECT nextval('seq');
SELECT nextval('seq');
RAISE NOTICE USING MESSAGE = format('next val: %d',nextval('seq'));
END
$$ LANGUAGE PLPGSQL
$$ LANGUAGE PLPGSQL;

query T
SELECT get_body_str('f_rewrite');
----
"BEGIN\nSELECT nextval('seq':::STRING);\nEND\n;"

statement ok
CREATE OR REPLACE FUNCTION f_rewrite() RETURNS INT AS
$$
BEGIN
INSERT INTO t_rewrite(v) VALUES (nextval('seq')) RETURNING v;
END
$$ LANGUAGE PLPGSQL
"DECLARE\ni INT8 := nextval(106:::REGCLASS);\nj INT8 := nextval(106:::REGCLASS);\ncurs REFCURSOR := nextval(106:::REGCLASS)::STRING;\ncurs2 CURSOR FOR SELECT nextval(106:::REGCLASS);\nBEGIN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nRAISE notice 'val1: %, val2: %', nextval(106:::REGCLASS), nextval(106:::REGCLASS);\nWHILE nextval(106:::REGCLASS) < 10:::INT8 LOOP\ni := nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nIF nextval(106:::REGCLASS) = 1:::INT8 THEN\n\tSELECT nextval(106:::REGCLASS);\n\tSELECT nextval(106:::REGCLASS);\n\tCONTINUE;\nELSIF nextval(106:::REGCLASS) = 2:::INT8 THEN\n\tSELECT v FROM ROWS FROM (nextval(106:::REGCLASS)) AS v (\"int\") INTO i;\nELSIF nextval(106:::REGCLASS) = 3:::INT8 THEN\n\tSELECT nextval(106:::REGCLASS);\n\tSELECT nextval(106:::REGCLASS);\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT nextval(106:::REGCLASS);\nRETURN nextval(106:::REGCLASS);\nEXCEPTION\nWHEN division_by_zero THEN\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nWHEN not_null_violation THEN\nSELECT nextval(106:::REGCLASS);\nSELECT nextval(106:::REGCLASS);\nRAISE notice\nUSING MESSAGE = format('next val: %d':::STRING, nextval(106:::REGCLASS));\nEND\n;"

query T
SELECT get_body_str('f_rewrite');
query TT
SHOW CREATE FUNCTION f_rewrite;
----
"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval('seq':::STRING)) RETURNING v;\nEND\n;"
f_rewrite CREATE FUNCTION public.f_rewrite()
RETURNS INT8
VOLATILE
NOT LEAKPROOF
CALLED ON NULL INPUT
LANGUAGE plpgsql
AS $$
DECLARE
i INT8 := nextval('public.seq'::REGCLASS);
j INT8 := nextval('public.seq'::REGCLASS);
curs REFCURSOR := nextval('public.seq'::REGCLASS)::STRING;
curs2 CURSOR FOR SELECT nextval('public.seq'::REGCLASS);
BEGIN
RAISE notice
USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS));
RAISE notice 'val1: %, val2: %', nextval('public.seq'::REGCLASS), nextval('public.seq'::REGCLASS);
WHILE nextval('public.seq'::REGCLASS) < 10:::INT8 LOOP
i := nextval('public.seq'::REGCLASS);
SELECT nextval('public.seq'::REGCLASS);
IF nextval('public.seq'::REGCLASS) = 1:::INT8 THEN
SELECT nextval('public.seq'::REGCLASS);
SELECT nextval('public.seq'::REGCLASS);
CONTINUE;
ELSIF nextval('public.seq'::REGCLASS) = 2:::INT8 THEN
SELECT v FROM ROWS FROM (nextval('public.seq'::REGCLASS)) AS v ("int") INTO i;
ELSIF nextval('public.seq'::REGCLASS) = 3:::INT8 THEN
SELECT nextval('public.seq'::REGCLASS);
SELECT nextval('public.seq'::REGCLASS);
END IF;
END LOOP;
OPEN curs FOR SELECT nextval('public.seq'::REGCLASS);
RETURN nextval('public.seq'::REGCLASS);
EXCEPTION
WHEN division_by_zero THEN
RAISE notice
USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS));
WHEN not_null_violation THEN
SELECT nextval('public.seq'::REGCLASS);
SELECT nextval('public.seq'::REGCLASS);
RAISE notice
USING MESSAGE = format('next val: %d':::STRING, nextval('public.seq'::REGCLASS));
END
$$

statement ok
DROP FUNCTION f_rewrite();

statement ok
CREATE OR REPLACE FUNCTION f_rewrite() RETURNS weekday AS
$$
DECLARE
day weekday := 'wednesday'::weekday;
today weekday := 'thursday'::weekday;
curs REFCURSOR := 'monday'::weekday::STRING;
curs2 CURSOR FOR SELECT 'tuesday'::weekday;
BEGIN
SELECT 'wednesday'::weekday;
RAISE NOTICE USING MESSAGE = format('val: %d','wednesday'::weekday);
RAISE NOTICE 'val1: %, val2: %', 'wednesday'::weekday, 'thursday'::weekday;
WHILE day != 'wednesday'::weekday LOOP
day = 'friday'::weekday;
SELECT 'wednesday'::weekday;
IF day = 'wednesday'::weekday THEN
day = 'thursday'::weekday;
SELECT 'tuesday'::weekday;
CONTINUE;
ELSIF day = 'monday'::weekday THEN
SELECT 'tuesday'::weekday INTO day;
ELSIF day = 'tuesday'::weekday THEN
SELECT 'wednesday'::weekday INTO day;
SELECT 'wednesday'::weekday;
END IF;
END LOOP;
OPEN curs FOR SELECT 'wednesday'::weekday;
RETURN 'wednesday'::weekday;
EXCEPTION
WHEN division_by_zero THEN
RAISE NOTICE USING MESSAGE = format('val: %d','wednesday'::weekday);
WHEN not_null_violation THEN
SELECT 'wednesday'::weekday;
RAISE NOTICE 'val: %', 'wednesday'::weekday;
END
$$ LANGUAGE PLPGSQL
$$ LANGUAGE PLPGSQL;

query T
SELECT get_body_str('f_rewrite');
----
"BEGIN\nSELECT 'wednesday'::@100107;\nEND\n;"
"DECLARE\nday @100107 := b'\\x80':::@100107;\ntoday @100107 := b'\\xa0':::@100107;\ncurs REFCURSOR := b' ':::@100107::STRING;\ncurs2 CURSOR FOR SELECT b'@':::@100107;\nBEGIN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nRAISE notice 'val1: %, val2: %', b'\\x80':::@100107, b'\\xa0':::@100107;\nWHILE day != b'\\x80':::@100107 LOOP\nday := b'\\xc0':::@100107;\nSELECT b'\\x80':::@100107;\nIF day = b'\\x80':::@100107 THEN\n\tday := b'\\xa0':::@100107;\n\tSELECT b'@':::@100107;\n\tCONTINUE;\nELSIF day = b' ':::@100107 THEN\n\tSELECT b'@':::@100107 INTO day;\nELSIF day = b'@':::@100107 THEN\n\tSELECT b'\\x80':::@100107 INTO day;\n\tSELECT b'\\x80':::@100107;\nEND IF;\nEND LOOP;\nOPEN curs FOR SELECT b'\\x80':::@100107;\nRETURN b'\\x80':::@100107;\nEXCEPTION\nWHEN division_by_zero THEN\nRAISE notice\nUSING MESSAGE = format('val: %d':::STRING, b'\\x80':::@100107);\nWHEN not_null_violation THEN\nSELECT b'\\x80':::@100107;\nRAISE notice 'val: %', b'\\x80':::@100107;\nEND\n;"

statement ok
CREATE OR REPLACE FUNCTION f_rewrite() RETURNS weekday AS
$$
BEGIN
UPDATE t_rewrite SET w = 'thursday'::weekday WHERE w = 'wednesday'::weekday RETURNING w;
END
$$ LANGUAGE PLPGSQL

query T
SELECT get_body_str('f_rewrite');
query TT
SHOW CREATE FUNCTION f_rewrite;
----
"BEGIN\nUPDATE test.public.t_rewrite SET w = 'thursday'::@100107 WHERE w = 'wednesday'::@100107 RETURNING w;\nEND\n;"
f_rewrite CREATE FUNCTION public.f_rewrite()
RETURNS test.public.weekday
VOLATILE
NOT LEAKPROOF
CALLED ON NULL INPUT
LANGUAGE plpgsql
AS $$
DECLARE
day test.public.weekday := 'wednesday':::test.public.weekday;
today test.public.weekday := 'thursday':::test.public.weekday;
curs REFCURSOR := 'monday':::test.public.weekday::STRING;
curs2 CURSOR FOR SELECT 'tuesday':::test.public.weekday;
BEGIN
RAISE notice
USING MESSAGE = format('val: %d':::STRING, 'wednesday':::test.public.weekday);
RAISE notice 'val1: %, val2: %', 'wednesday':::test.public.weekday, 'thursday':::test.public.weekday;
WHILE day != 'wednesday':::test.public.weekday LOOP
day := 'friday':::test.public.weekday;
SELECT 'wednesday':::test.public.weekday;
IF day = 'wednesday':::test.public.weekday THEN
day := 'thursday':::test.public.weekday;
SELECT 'tuesday':::test.public.weekday;
CONTINUE;
ELSIF day = 'monday':::test.public.weekday THEN
SELECT 'tuesday':::test.public.weekday INTO day;
ELSIF day = 'tuesday':::test.public.weekday THEN
SELECT 'wednesday':::test.public.weekday INTO day;
SELECT 'wednesday':::test.public.weekday;
END IF;
END LOOP;
OPEN curs FOR SELECT 'wednesday':::test.public.weekday;
RETURN 'wednesday':::test.public.weekday;
EXCEPTION
WHEN division_by_zero THEN
RAISE notice
USING MESSAGE = format('val: %d':::STRING, 'wednesday':::test.public.weekday);
WHEN not_null_violation THEN
SELECT 'wednesday':::test.public.weekday;
RAISE notice 'val: %', 'wednesday':::test.public.weekday;
END
$$

statement ok
DROP FUNCTION f_rewrite();

subtest end

Expand All @@ -110,7 +233,7 @@ $$ LANGUAGE PLPGSQL
query T
SELECT get_body_str('p_rewrite');
----
"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval('seq':::STRING)) RETURNING v;\nEND\n;"
"BEGIN\nINSERT INTO test.public.t_rewrite(v) VALUES (nextval(106:::REGCLASS)) RETURNING v;\nEND\n;"

statement ok
DROP PROCEDURE p_rewrite();
Expand All @@ -126,6 +249,6 @@ $$ LANGUAGE PLPGSQL
query T
SELECT get_body_str('p_rewrite');
----
"BEGIN\nUPDATE test.public.t_rewrite SET w = 'thursday'::@100107 WHERE w = 'wednesday'::@100107 RETURNING w;\nEND\n;"
"BEGIN\nUPDATE test.public.t_rewrite SET w = b'\\xa0':::@100107 WHERE w = b'\\x80':::@100107 RETURNING w;\nEND\n;"

subtest end
2 changes: 2 additions & 0 deletions pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ go_library(
"//pkg/sql/sem/catconstants",
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/plpgsqltree",
"//pkg/sql/sem/plpgsqltree/utils",
"//pkg/sql/sem/semenumpb",
"//pkg/sql/sem/transform",
"//pkg/sql/sem/tree",
Expand Down
31 changes: 8 additions & 23 deletions pkg/sql/crdb_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/protoreflect"
"github.com/cockroachdb/cockroach/pkg/sql/roleoption"
Expand Down Expand Up @@ -3579,29 +3578,15 @@ func createRoutinePopulate(
for i := range treeNode.Options {
if body, ok := treeNode.Options[i].(tree.RoutineBodyStr); ok {
bodyStr := string(body)
switch fnDesc.GetLanguage() {
case catpb.Function_SQL:
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr)
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */)
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
case catpb.Function_PLPGSQL:
// TODO(drewk): integrate this with the SQL case above.
plpgsqlStmt, err := plpgsql.Parse(bodyStr)
if err != nil {
return err
}
fmtCtx := tree.NewFmtCtx(tree.FmtParsable)
fmtCtx.FormatNode(plpgsqlStmt.AST)
bodyStr = "\n" + fmtCtx.CloseAndGetString()
default:
return errors.AssertionFailedf("unexpected function language: %s", fnDesc.GetLanguage())
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
stmtStrs := strings.Split(bodyStr, "\n")
for i := range stmtStrs {
if stmtStrs[i] != "" {
Expand Down
15 changes: 4 additions & 11 deletions pkg/sql/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,26 +455,19 @@ func setFuncOptions(
}
}

switch lang {
case catpb.Function_SQL:
if lang != catpb.Function_UNKNOWN_LANGUAGE && body != "" {
// Replace any sequence names in the function body with IDs.
seqReplacedFuncBody, err := replaceSeqNamesWithIDs(params.ctx, params.p, body, true)
seqReplacedFuncBody, err := replaceSeqNamesWithIDsLang(params.ctx, params.p, body, true, lang)
if err != nil {
return err
}
typeReplacedFuncBody, err := serializeUserDefinedTypes(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs",
)
typeReplacedFuncBody, err := serializeUserDefinedTypesLang(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", lang)
if err != nil {
return err
}
udfDesc.SetFuncBody(typeReplacedFuncBody)
case catpb.Function_PLPGSQL:
// TODO(#115627): make replaceSeqNamesWithIDs and serializeUserDefinedTypes
// play nice with PL/pgSQL.
udfDesc.SetFuncBody(body)
}

return nil
}

Expand Down
Loading

0 comments on commit e0f9c28

Please sign in to comment.