Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-23.2: sql: support sequence and udt name rewriting in plpgsql #116419

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
295 changes: 270 additions & 25 deletions pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pkg/sql/BUILD.bazel
Expand Up @@ -490,6 +490,8 @@ go_library(
"//pkg/sql/sem/catconstants",
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/plpgsqltree",
"//pkg/sql/sem/plpgsqltree/utils",
"//pkg/sql/sem/semenumpb",
"//pkg/sql/sem/transform",
"//pkg/sql/sem/tree",
Expand Down
31 changes: 8 additions & 23 deletions pkg/sql/crdb_internal.go
Expand Up @@ -70,7 +70,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/protoreflect"
"github.com/cockroachdb/cockroach/pkg/sql/roleoption"
Expand Down Expand Up @@ -3585,29 +3584,15 @@ func createRoutinePopulate(
for i := range treeNode.Options {
if body, ok := treeNode.Options[i].(tree.RoutineBodyStr); ok {
bodyStr := string(body)
switch fnDesc.GetLanguage() {
case catpb.Function_SQL:
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr)
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */)
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
case catpb.Function_PLPGSQL:
// TODO(drewk): integrate this with the SQL case above.
plpgsqlStmt, err := plpgsql.Parse(bodyStr)
if err != nil {
return err
}
fmtCtx := tree.NewFmtCtx(tree.FmtParsable)
fmtCtx.FormatNode(plpgsqlStmt.AST)
bodyStr = "\n" + fmtCtx.CloseAndGetString()
default:
return errors.AssertionFailedf("unexpected function language: %s", fnDesc.GetLanguage())
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
stmtStrs := strings.Split(bodyStr, "\n")
for i := range stmtStrs {
if stmtStrs[i] != "" {
Expand Down
15 changes: 4 additions & 11 deletions pkg/sql/create_function.go
Expand Up @@ -455,26 +455,19 @@ func setFuncOptions(
}
}

switch lang {
case catpb.Function_SQL:
if lang != catpb.Function_UNKNOWN_LANGUAGE && body != "" {
// Replace any sequence names in the function body with IDs.
seqReplacedFuncBody, err := replaceSeqNamesWithIDs(params.ctx, params.p, body, true)
seqReplacedFuncBody, err := replaceSeqNamesWithIDsLang(params.ctx, params.p, body, true, lang)
if err != nil {
return err
}
typeReplacedFuncBody, err := serializeUserDefinedTypes(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs",
)
typeReplacedFuncBody, err := serializeUserDefinedTypesLang(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", lang)
if err != nil {
return err
}
udfDesc.SetFuncBody(typeReplacedFuncBody)
case catpb.Function_PLPGSQL:
// TODO(#115627): make replaceSeqNamesWithIDs and serializeUserDefinedTypes
// play nice with PL/pgSQL.
udfDesc.SetFuncBody(body)
}

return nil
}

Expand Down
201 changes: 150 additions & 51 deletions pkg/sql/create_view.go
Expand Up @@ -30,8 +30,11 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
"github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree/utils"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sqlerrors"
"github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry"
Expand Down Expand Up @@ -425,14 +428,31 @@ func makeViewTableDesc(
return desc, nil
}

// replaceSeqNamesWithIDs prepares to walk the given viewQuery by defining the
// function used to replace sequence names with IDs, and parsing the
// viewQuery into a statement.
// replaceSeqNamesWithIDsLang walks the query in queryStr, replacing any
// sequence names with their IDs and returning a new query string with the names
// replaced. It assumes that the query is in the SQL language.
// TODO (Chengxiong): move this to a better place.
func replaceSeqNamesWithIDs(
ctx context.Context, sc resolver.SchemaResolver, queryStr string, multiStmt bool,
) (string, error) {
return replaceSeqNamesWithIDsLang(ctx, sc, queryStr, multiStmt, catpb.Function_SQL)
}

// replaceSeqNamesWithIDsLang walks the query in queryStr, replacing any
// sequence names with their IDs and returning a new query string with the names
// replaced. Queries may be in either the SQL or PLpgSQL language, indicated by
// lang.
func replaceSeqNamesWithIDsLang(
ctx context.Context,
sc resolver.SchemaResolver,
queryStr string,
multiStmt bool,
lang catpb.Function_Language,
) (string, error) {
replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
if expr == nil {
return false, expr, nil
}
seqIdentifiers, err := seqexpr.GetUsedSequences(expr)
if err != nil {
return false, expr, err
Expand All @@ -452,48 +472,83 @@ func replaceSeqNamesWithIDs(
return false, newExpr, nil
}

var stmts tree.Statements
if multiStmt {
parsedStmtd, err := parser.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
switch lang {
case catpb.Function_SQL:
var stmts tree.Statements
if multiStmt {
parsedStmtd, err := parser.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
for _, s := range parsedStmtd {
stmts = append(stmts, s.AST)
}
} else {
stmt, err := parser.ParseOne(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = tree.Statements{stmt.AST}
}
for _, s := range parsedStmtd {
stmts = append(stmts, s.AST)

for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}
} else {
stmt, err := parser.ParseOne(queryStr)
case catpb.Function_PLPGSQL:
var stmts plpgsqltree.Statement
plstmt, err := plpgsql.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = tree.Statements{stmt.AST}
}
stmts = plstmt.AST

fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
v := utils.SQLStmtVisitor{Fn: replaceSeqFunc}
newStmt := plpgsqltree.Walk(&v, stmts)
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}

return fmtCtx.String(), nil
}

// serializeUserDefinedTypes will walk the given view query
// and serialize any user defined types, so that renaming the type
// does not corrupt the view.
// serializeUserDefinedTypes walks the given query and serializes any
// user defined types as IDs, so that renaming the type does not cause
// corruption, and returns a new query string containing the replacement IDs.
// It assumes that the query language is SQL.
func serializeUserDefinedTypes(
ctx context.Context, semaCtx *tree.SemaContext, queries string, multiStmt bool, parentType string,
) (string, error) {
return serializeUserDefinedTypesLang(ctx, semaCtx, queries, multiStmt, parentType, catpb.Function_SQL)
}

// serializeUserDefinedTypesLang walks the given query and serializes any
// user defined types as IDs, so that renaming the type does not cause
// corruption, and returns a new query string containing the replacement IDs.
// The query may be in either the SQL or PLpgSQL language, indicated by lang.
func serializeUserDefinedTypesLang(
ctx context.Context,
semaCtx *tree.SemaContext,
queries string,
multiStmt bool,
parentType string,
lang catpb.Function_Language,
) (string, error) {
// replaceFunc is a visitor function that replaces user defined types in SQL
// expressions with their IDs.
replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
if expr == nil {
return false, expr, nil
}
var innerExpr tree.Expr
var typRef tree.ResolvableTypeReference
switch n := expr.(type) {
Expand Down Expand Up @@ -539,39 +594,83 @@ func serializeUserDefinedTypes(
}
return false, parsedExpr, nil
}

var stmts tree.Statements
if multiStmt {
parsedStmts, err := parser.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
// replaceTypeFunc is a visitor function that replaces type annotations
// containing user defined types with their IDs. This is currently only
// necessary for some kinds of PLpgSQL statements.
replaceTypeFunc := func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error) {
if typ == nil {
return typ, nil
}
stmts = make(tree.Statements, len(parsedStmts))
for i, stmt := range parsedStmts {
stmts[i] = stmt.AST
// semaCtx may be nil if this is a virtual view being created at
// init time.
var typeResolver tree.TypeReferenceResolver
if semaCtx != nil {
typeResolver = semaCtx.TypeResolver
}
} else {
stmt, err := parser.ParseOne(queries)
var t *types.T
t, err = tree.ResolveType(ctx, typ, typeResolver)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
return typ, err
}
if !t.UserDefined() {
return typ, nil
}
stmts = tree.Statements{stmt.AST}
return &tree.OIDTypeReference{OID: t.Oid()}, nil
}

fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
if err != nil {
return "", err
switch lang {
case catpb.Function_SQL:
var stmts tree.Statements
if multiStmt {
parsedStmts, err := parser.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
}
stmts = make(tree.Statements, len(parsedStmts))
for i, stmt := range parsedStmts {
stmts[i] = stmt.AST
}
} else {
stmt, err := parser.ParseOne(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
}
stmts = tree.Statements{stmt.AST}
}
if i > 0 {
fmtCtx.WriteString("\n")

for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
case catpb.Function_PLPGSQL:
var stmts plpgsqltree.Statement
plstmt, err := plpgsql.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = plstmt.AST

v := utils.SQLStmtVisitor{Fn: replaceFunc}
newStmt := plpgsqltree.Walk(&v, stmts)
// Some PLpgSQL statements (i.e., declarations), may contain type
// annotations containing the UDT. We need to walk the AST to replace them,
// too.
v2 := utils.TypeRefVisitor{Fn: replaceTypeFunc}
newStmt = plpgsqltree.Walk(&v2, newStmt)
fmtCtx.FormatNode(newStmt)
fmtCtx.WriteString(";")
}

return fmtCtx.CloseAndGetString(), nil
}

Expand Down
7 changes: 4 additions & 3 deletions pkg/sql/opt/optbuilder/plpgsql.go
Expand Up @@ -1531,7 +1531,7 @@ func newRecordTypeVisitor(

var _ ast.StatementVisitor = &recordTypeVisitor{}

func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
func (r *recordTypeVisitor) Visit(stmt ast.Statement) (newStmt ast.Statement, changed bool) {
if retStmt, ok := stmt.(*ast.Return); ok {
desired := types.Any
if r.typ != types.Unknown {
Expand All @@ -1544,7 +1544,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
}
typ := typedExpr.ResolvedType()
if typ == types.Unknown {
return
return stmt, false
}
if typ.Family() != types.TupleFamily {
panic(pgerror.New(pgcode.DatatypeMismatch,
Expand All @@ -1553,7 +1553,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
}
if r.typ == types.Unknown {
r.typ = typ
return
return stmt, false
}
if !typ.Identical(r.typ) {
panic(errors.WithHint(
Expand All @@ -1564,4 +1564,5 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
))
}
}
return stmt, false
}
1 change: 1 addition & 0 deletions pkg/sql/plpgsql/parser/testdata/stmt_case
Expand Up @@ -133,6 +133,7 @@ BEGIN
END CASE;
END
----
decl_stmt: 1
stmt_block: 1
stmt_call: 3
stmt_case: 1
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/sem/plpgsqltree/BUILD.bazel
Expand Up @@ -13,6 +13,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/sql/sem/tree",
"//pkg/util/errorutil/unimplemented",
"@com_github_cockroachdb_errors//:errors",
],
)