Skip to content

Commit

Permalink
sql: support sequence and udt name rewriting in plpgsql
Browse files Browse the repository at this point in the history
CRDB rewrites sequence and UDT names as IDs in views and functions so
that if the sequence or UDT is renamed the views and functions using
them don't break. This PR adds support for this in PLpgSQL.

Epic: None
Fixes: #115627

Release note (sql change): Fixes a bug in PLpgSQL where altering the
name of a sequence or UDT that was used in a PLpgSQL function or
procedure could break them. This is only present in 23.2 alpha and beta
releases.
  • Loading branch information
rharding6373 committed Dec 14, 2023
1 parent 5459be7 commit 33d6c2d
Show file tree
Hide file tree
Showing 13 changed files with 1,181 additions and 266 deletions.
295 changes: 270 additions & 25 deletions pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,8 @@ go_library(
"//pkg/sql/sem/catconstants",
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/plpgsqltree",
"//pkg/sql/sem/plpgsqltree/utils",
"//pkg/sql/sem/semenumpb",
"//pkg/sql/sem/transform",
"//pkg/sql/sem/tree",
Expand Down
31 changes: 8 additions & 23 deletions pkg/sql/crdb_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/protoreflect"
"github.com/cockroachdb/cockroach/pkg/sql/roleoption"
Expand Down Expand Up @@ -3585,29 +3584,15 @@ func createRoutinePopulate(
for i := range treeNode.Options {
if body, ok := treeNode.Options[i].(tree.RoutineBodyStr); ok {
bodyStr := string(body)
switch fnDesc.GetLanguage() {
case catpb.Function_SQL:
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr)
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */)
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
case catpb.Function_PLPGSQL:
// TODO(drewk): integrate this with the SQL case above.
plpgsqlStmt, err := plpgsql.Parse(bodyStr)
if err != nil {
return err
}
fmtCtx := tree.NewFmtCtx(tree.FmtParsable)
fmtCtx.FormatNode(plpgsqlStmt.AST)
bodyStr = "\n" + fmtCtx.CloseAndGetString()
default:
return errors.AssertionFailedf("unexpected function language: %s", fnDesc.GetLanguage())
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
stmtStrs := strings.Split(bodyStr, "\n")
for i := range stmtStrs {
if stmtStrs[i] != "" {
Expand Down
15 changes: 4 additions & 11 deletions pkg/sql/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,26 +455,19 @@ func setFuncOptions(
}
}

switch lang {
case catpb.Function_SQL:
if lang != catpb.Function_UNKNOWN_LANGUAGE && body != "" {
// Replace any sequence names in the function body with IDs.
seqReplacedFuncBody, err := replaceSeqNamesWithIDs(params.ctx, params.p, body, true)
seqReplacedFuncBody, err := replaceSeqNamesWithIDsLang(params.ctx, params.p, body, true, lang)
if err != nil {
return err
}
typeReplacedFuncBody, err := serializeUserDefinedTypes(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs",
)
typeReplacedFuncBody, err := serializeUserDefinedTypesLang(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", lang)
if err != nil {
return err
}
udfDesc.SetFuncBody(typeReplacedFuncBody)
case catpb.Function_PLPGSQL:
// TODO(#115627): make replaceSeqNamesWithIDs and serializeUserDefinedTypes
// play nice with PL/pgSQL.
udfDesc.SetFuncBody(body)
}

return nil
}

Expand Down
201 changes: 150 additions & 51 deletions pkg/sql/create_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
"github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree/utils"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sqlerrors"
"github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry"
Expand Down Expand Up @@ -425,14 +428,31 @@ func makeViewTableDesc(
return desc, nil
}

// replaceSeqNamesWithIDs prepares to walk the given viewQuery by defining the
// function used to replace sequence names with IDs, and parsing the
// viewQuery into a statement.
// replaceSeqNamesWithIDsLang walks the query in queryStr, replacing any
// sequence names with their IDs and returning a new query string with the names
// replaced. It assumes that the query is in the SQL language.
// TODO (Chengxiong): move this to a better place.
func replaceSeqNamesWithIDs(
ctx context.Context, sc resolver.SchemaResolver, queryStr string, multiStmt bool,
) (string, error) {
return replaceSeqNamesWithIDsLang(ctx, sc, queryStr, multiStmt, catpb.Function_SQL)
}

// replaceSeqNamesWithIDsLang walks the query in queryStr, replacing any
// sequence names with their IDs and returning a new query string with the names
// replaced. Queries may be in either the SQL or PLpgSQL language, indicated by
// lang.
func replaceSeqNamesWithIDsLang(
ctx context.Context,
sc resolver.SchemaResolver,
queryStr string,
multiStmt bool,
lang catpb.Function_Language,
) (string, error) {
replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
if expr == nil {
return false, expr, nil
}
seqIdentifiers, err := seqexpr.GetUsedSequences(expr)
if err != nil {
return false, expr, err
Expand All @@ -452,48 +472,83 @@ func replaceSeqNamesWithIDs(
return false, newExpr, nil
}

var stmts tree.Statements
if multiStmt {
parsedStmtd, err := parser.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
switch lang {
case catpb.Function_SQL:
var stmts tree.Statements
if multiStmt {
parsedStmtd, err := parser.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
for _, s := range parsedStmtd {
stmts = append(stmts, s.AST)
}
} else {
stmt, err := parser.ParseOne(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = tree.Statements{stmt.AST}
}
for _, s := range parsedStmtd {
stmts = append(stmts, s.AST)

for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}
} else {
stmt, err := parser.ParseOne(queryStr)
case catpb.Function_PLPGSQL:
var stmts plpgsqltree.Statement
plstmt, err := plpgsql.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = tree.Statements{stmt.AST}
}
stmts = plstmt.AST

fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
v := utils.SQLStmtVisitor{Fn: replaceSeqFunc}
newStmt := plpgsqltree.Walk(&v, stmts)
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}

return fmtCtx.String(), nil
}

// serializeUserDefinedTypes will walk the given view query
// and serialize any user defined types, so that renaming the type
// does not corrupt the view.
// serializeUserDefinedTypes walks the given query and serializes any
// user defined types as IDs, so that renaming the type does not cause
// corruption, and returns a new query string containing the replacement IDs.
// It assumes that the query language is SQL.
func serializeUserDefinedTypes(
ctx context.Context, semaCtx *tree.SemaContext, queries string, multiStmt bool, parentType string,
) (string, error) {
return serializeUserDefinedTypesLang(ctx, semaCtx, queries, multiStmt, parentType, catpb.Function_SQL)
}

// serializeUserDefinedTypesLang walks the given query and serializes any
// user defined types as IDs, so that renaming the type does not cause
// corruption, and returns a new query string containing the replacement IDs.
// The query may be in either the SQL or PLpgSQL language, indicated by lang.
func serializeUserDefinedTypesLang(
ctx context.Context,
semaCtx *tree.SemaContext,
queries string,
multiStmt bool,
parentType string,
lang catpb.Function_Language,
) (string, error) {
// replaceFunc is a visitor function that replaces user defined types in SQL
// expressions with their IDs.
replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
if expr == nil {
return false, expr, nil
}
var innerExpr tree.Expr
var typRef tree.ResolvableTypeReference
switch n := expr.(type) {
Expand Down Expand Up @@ -539,39 +594,83 @@ func serializeUserDefinedTypes(
}
return false, parsedExpr, nil
}

var stmts tree.Statements
if multiStmt {
parsedStmts, err := parser.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
// replaceTypeFunc is a visitor function that replaces type annotations
// containing user defined types with their IDs. This is currently only
// necessary for some kinds of PLpgSQL statements.
replaceTypeFunc := func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error) {
if typ == nil {
return typ, nil
}
stmts = make(tree.Statements, len(parsedStmts))
for i, stmt := range parsedStmts {
stmts[i] = stmt.AST
// semaCtx may be nil if this is a virtual view being created at
// init time.
var typeResolver tree.TypeReferenceResolver
if semaCtx != nil {
typeResolver = semaCtx.TypeResolver
}
} else {
stmt, err := parser.ParseOne(queries)
var t *types.T
t, err = tree.ResolveType(ctx, typ, typeResolver)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
return typ, err
}
if !t.UserDefined() {
return typ, nil
}
stmts = tree.Statements{stmt.AST}
return &tree.OIDTypeReference{OID: t.Oid()}, nil
}

fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
if err != nil {
return "", err
switch lang {
case catpb.Function_SQL:
var stmts tree.Statements
if multiStmt {
parsedStmts, err := parser.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
}
stmts = make(tree.Statements, len(parsedStmts))
for i, stmt := range parsedStmts {
stmts[i] = stmt.AST
}
} else {
stmt, err := parser.ParseOne(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
}
stmts = tree.Statements{stmt.AST}
}
if i > 0 {
fmtCtx.WriteString("\n")

for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
case catpb.Function_PLPGSQL:
var stmts plpgsqltree.Statement
plstmt, err := plpgsql.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = plstmt.AST

v := utils.SQLStmtVisitor{Fn: replaceFunc}
newStmt := plpgsqltree.Walk(&v, stmts)
// Some PLpgSQL statements (i.e., declarations), may contain type
// annotations containing the UDT. We need to walk the AST to replace them,
// too.
v2 := utils.TypeRefVisitor{Fn: replaceTypeFunc}
newStmt = plpgsqltree.Walk(&v2, newStmt)
fmtCtx.FormatNode(newStmt)
fmtCtx.WriteString(";")
}

return fmtCtx.CloseAndGetString(), nil
}

Expand Down
7 changes: 4 additions & 3 deletions pkg/sql/opt/optbuilder/plpgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ func newRecordTypeVisitor(

var _ ast.StatementVisitor = &recordTypeVisitor{}

func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
func (r *recordTypeVisitor) Visit(stmt ast.Statement) (newStmt ast.Statement, changed bool) {
if retStmt, ok := stmt.(*ast.Return); ok {
desired := types.Any
if r.typ != types.Unknown {
Expand All @@ -1544,7 +1544,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
}
typ := typedExpr.ResolvedType()
if typ == types.Unknown {
return
return stmt, false
}
if typ.Family() != types.TupleFamily {
panic(pgerror.New(pgcode.DatatypeMismatch,
Expand All @@ -1553,7 +1553,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
}
if r.typ == types.Unknown {
r.typ = typ
return
return stmt, false
}
if !typ.Identical(r.typ) {
panic(errors.WithHint(
Expand All @@ -1564,4 +1564,5 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
))
}
}
return stmt, false
}
1 change: 1 addition & 0 deletions pkg/sql/plpgsql/parser/testdata/stmt_case
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ BEGIN
END CASE;
END
----
decl_stmt: 1
stmt_block: 1
stmt_call: 3
stmt_case: 1
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/sem/plpgsqltree/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/sql/sem/tree",
"//pkg/util/errorutil/unimplemented",
"@com_github_cockroachdb_errors//:errors",
],
)
Loading

0 comments on commit 33d6c2d

Please sign in to comment.