Skip to content

Commit

Permalink
Merge pull request #116419 from rharding6373/backport23.2-115809
Browse files Browse the repository at this point in the history
release-23.2: sql: support sequence and udt name rewriting in plpgsql
  • Loading branch information
rharding6373 committed Dec 15, 2023
2 parents 4361347 + 33d6c2d commit 5150b93
Show file tree
Hide file tree
Showing 13 changed files with 1,181 additions and 266 deletions.
295 changes: 270 additions & 25 deletions pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pkg/sql/BUILD.bazel
Expand Up @@ -490,6 +490,8 @@ go_library(
"//pkg/sql/sem/catconstants",
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/plpgsqltree",
"//pkg/sql/sem/plpgsqltree/utils",
"//pkg/sql/sem/semenumpb",
"//pkg/sql/sem/transform",
"//pkg/sql/sem/tree",
Expand Down
31 changes: 8 additions & 23 deletions pkg/sql/crdb_internal.go
Expand Up @@ -70,7 +70,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/protoreflect"
"github.com/cockroachdb/cockroach/pkg/sql/roleoption"
Expand Down Expand Up @@ -3585,29 +3584,15 @@ func createRoutinePopulate(
for i := range treeNode.Options {
if body, ok := treeNode.Options[i].(tree.RoutineBodyStr); ok {
bodyStr := string(body)
switch fnDesc.GetLanguage() {
case catpb.Function_SQL:
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr)
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */)
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
case catpb.Function_PLPGSQL:
// TODO(drewk): integrate this with the SQL case above.
plpgsqlStmt, err := plpgsql.Parse(bodyStr)
if err != nil {
return err
}
fmtCtx := tree.NewFmtCtx(tree.FmtParsable)
fmtCtx.FormatNode(plpgsqlStmt.AST)
bodyStr = "\n" + fmtCtx.CloseAndGetString()
default:
return errors.AssertionFailedf("unexpected function language: %s", fnDesc.GetLanguage())
bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */, fnDesc.GetLanguage())
if err != nil {
return err
}
bodyStr = "\n" + bodyStr + "\n"
stmtStrs := strings.Split(bodyStr, "\n")
for i := range stmtStrs {
if stmtStrs[i] != "" {
Expand Down
15 changes: 4 additions & 11 deletions pkg/sql/create_function.go
Expand Up @@ -455,26 +455,19 @@ func setFuncOptions(
}
}

switch lang {
case catpb.Function_SQL:
if lang != catpb.Function_UNKNOWN_LANGUAGE && body != "" {
// Replace any sequence names in the function body with IDs.
seqReplacedFuncBody, err := replaceSeqNamesWithIDs(params.ctx, params.p, body, true)
seqReplacedFuncBody, err := replaceSeqNamesWithIDsLang(params.ctx, params.p, body, true, lang)
if err != nil {
return err
}
typeReplacedFuncBody, err := serializeUserDefinedTypes(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs",
)
typeReplacedFuncBody, err := serializeUserDefinedTypesLang(
params.ctx, params.p.SemaCtx(), seqReplacedFuncBody, true /* multiStmt */, "UDFs", lang)
if err != nil {
return err
}
udfDesc.SetFuncBody(typeReplacedFuncBody)
case catpb.Function_PLPGSQL:
// TODO(#115627): make replaceSeqNamesWithIDs and serializeUserDefinedTypes
// play nice with PL/pgSQL.
udfDesc.SetFuncBody(body)
}

return nil
}

Expand Down
201 changes: 150 additions & 51 deletions pkg/sql/create_view.go
Expand Up @@ -30,8 +30,11 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice"
plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
"github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/plpgsqltree/utils"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sqlerrors"
"github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry"
Expand Down Expand Up @@ -425,14 +428,31 @@ func makeViewTableDesc(
return desc, nil
}

// replaceSeqNamesWithIDs prepares to walk the given viewQuery by defining the
// function used to replace sequence names with IDs, and parsing the
// viewQuery into a statement.
// replaceSeqNamesWithIDs walks the query in queryStr, replacing any
// sequence names with their IDs and returning a new query string with the names
// replaced. It assumes that the query is in the SQL language: it simply
// delegates to replaceSeqNamesWithIDsLang with catpb.Function_SQL. multiStmt
// is forwarded unchanged and selects whether queryStr is parsed as multiple
// statements or as exactly one.
// TODO (Chengxiong): move this to a better place.
func replaceSeqNamesWithIDs(
ctx context.Context, sc resolver.SchemaResolver, queryStr string, multiStmt bool,
) (string, error) {
return replaceSeqNamesWithIDsLang(ctx, sc, queryStr, multiStmt, catpb.Function_SQL)
}

// replaceSeqNamesWithIDsLang walks the query in queryStr, replacing any
// sequence names with their IDs and returning a new query string with the names
// replaced. Queries may be in either the SQL or PLpgSQL language, indicated by
// lang.
func replaceSeqNamesWithIDsLang(
ctx context.Context,
sc resolver.SchemaResolver,
queryStr string,
multiStmt bool,
lang catpb.Function_Language,
) (string, error) {
replaceSeqFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
if expr == nil {
return false, expr, nil
}
seqIdentifiers, err := seqexpr.GetUsedSequences(expr)
if err != nil {
return false, expr, err
Expand All @@ -452,48 +472,83 @@ func replaceSeqNamesWithIDs(
return false, newExpr, nil
}

var stmts tree.Statements
if multiStmt {
parsedStmtd, err := parser.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
switch lang {
case catpb.Function_SQL:
var stmts tree.Statements
if multiStmt {
parsedStmtd, err := parser.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
for _, s := range parsedStmtd {
stmts = append(stmts, s.AST)
}
} else {
stmt, err := parser.ParseOne(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = tree.Statements{stmt.AST}
}
for _, s := range parsedStmtd {
stmts = append(stmts, s.AST)

for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}
} else {
stmt, err := parser.ParseOne(queryStr)
case catpb.Function_PLPGSQL:
var stmts plpgsqltree.Statement
plstmt, err := plpgsql.Parse(queryStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = tree.Statements{stmt.AST}
}
stmts = plstmt.AST

fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceSeqFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
v := utils.SQLStmtVisitor{Fn: replaceSeqFunc}
newStmt := plpgsqltree.Walk(&v, stmts)
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}

return fmtCtx.String(), nil
}

// serializeUserDefinedTypes will walk the given view query
// and serialize any user defined types, so that renaming the type
// does not corrupt the view.
// serializeUserDefinedTypes walks the given query and serializes any
// user defined types as IDs, so that renaming the type does not cause
// corruption, and returns a new query string containing the replacement IDs.
// It assumes that the query language is SQL: it delegates to
// serializeUserDefinedTypesLang with catpb.Function_SQL, forwarding all other
// arguments unchanged. multiStmt selects whether queries is parsed as multiple
// statements or as exactly one.
func serializeUserDefinedTypes(
ctx context.Context, semaCtx *tree.SemaContext, queries string, multiStmt bool, parentType string,
) (string, error) {
return serializeUserDefinedTypesLang(ctx, semaCtx, queries, multiStmt, parentType, catpb.Function_SQL)
}

// serializeUserDefinedTypesLang walks the given query and serializes any
// user defined types as IDs, so that renaming the type does not cause
// corruption, and returns a new query string containing the replacement IDs.
// The query may be in either the SQL or PLpgSQL language, indicated by lang.
func serializeUserDefinedTypesLang(
ctx context.Context,
semaCtx *tree.SemaContext,
queries string,
multiStmt bool,
parentType string,
lang catpb.Function_Language,
) (string, error) {
// replaceFunc is a visitor function that replaces user defined types in SQL
// expressions with their IDs.
replaceFunc := func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
if expr == nil {
return false, expr, nil
}
var innerExpr tree.Expr
var typRef tree.ResolvableTypeReference
switch n := expr.(type) {
Expand Down Expand Up @@ -539,39 +594,83 @@ func serializeUserDefinedTypes(
}
return false, parsedExpr, nil
}

var stmts tree.Statements
if multiStmt {
parsedStmts, err := parser.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
// replaceTypeFunc is a visitor function that replaces type annotations
// containing user defined types with their IDs. This is currently only
// necessary for some kinds of PLpgSQL statements.
replaceTypeFunc := func(typ tree.ResolvableTypeReference) (newTyp tree.ResolvableTypeReference, err error) {
if typ == nil {
return typ, nil
}
stmts = make(tree.Statements, len(parsedStmts))
for i, stmt := range parsedStmts {
stmts[i] = stmt.AST
// semaCtx may be nil if this is a virtual view being created at
// init time.
var typeResolver tree.TypeReferenceResolver
if semaCtx != nil {
typeResolver = semaCtx.TypeResolver
}
} else {
stmt, err := parser.ParseOne(queries)
var t *types.T
t, err = tree.ResolveType(ctx, typ, typeResolver)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
return typ, err
}
if !t.UserDefined() {
return typ, nil
}
stmts = tree.Statements{stmt.AST}
return &tree.OIDTypeReference{OID: t.Oid()}, nil
}

fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
if err != nil {
return "", err
switch lang {
case catpb.Function_SQL:
var stmts tree.Statements
if multiStmt {
parsedStmts, err := parser.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
}
stmts = make(tree.Statements, len(parsedStmts))
for i, stmt := range parsedStmts {
stmts[i] = stmt.AST
}
} else {
stmt, err := parser.ParseOne(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query")
}
stmts = tree.Statements{stmt.AST}
}
if i > 0 {
fmtCtx.WriteString("\n")

for i, stmt := range stmts {
newStmt, err := tree.SimpleStmtVisit(stmt, replaceFunc)
if err != nil {
return "", err
}
if i > 0 {
fmtCtx.WriteString("\n")
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
}
}
fmtCtx.FormatNode(newStmt)
if multiStmt {
fmtCtx.WriteString(";")
case catpb.Function_PLPGSQL:
var stmts plpgsqltree.Statement
plstmt, err := plpgsql.Parse(queries)
if err != nil {
return "", errors.Wrap(err, "failed to parse query string")
}
stmts = plstmt.AST

v := utils.SQLStmtVisitor{Fn: replaceFunc}
newStmt := plpgsqltree.Walk(&v, stmts)
// Some PLpgSQL statements (i.e., declarations) may carry type annotations
// that reference a UDT. Walk the AST a second time to replace those
// annotations, too.
v2 := utils.TypeRefVisitor{Fn: replaceTypeFunc}
newStmt = plpgsqltree.Walk(&v2, newStmt)
fmtCtx.FormatNode(newStmt)
fmtCtx.WriteString(";")
}

return fmtCtx.CloseAndGetString(), nil
}

Expand Down
7 changes: 4 additions & 3 deletions pkg/sql/opt/optbuilder/plpgsql.go
Expand Up @@ -1531,7 +1531,7 @@ func newRecordTypeVisitor(

var _ ast.StatementVisitor = &recordTypeVisitor{}

func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
func (r *recordTypeVisitor) Visit(stmt ast.Statement) (newStmt ast.Statement, changed bool) {
if retStmt, ok := stmt.(*ast.Return); ok {
desired := types.Any
if r.typ != types.Unknown {
Expand All @@ -1544,7 +1544,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
}
typ := typedExpr.ResolvedType()
if typ == types.Unknown {
return
return stmt, false
}
if typ.Family() != types.TupleFamily {
panic(pgerror.New(pgcode.DatatypeMismatch,
Expand All @@ -1553,7 +1553,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
}
if r.typ == types.Unknown {
r.typ = typ
return
return stmt, false
}
if !typ.Identical(r.typ) {
panic(errors.WithHint(
Expand All @@ -1564,4 +1564,5 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) {
))
}
}
return stmt, false
}
1 change: 1 addition & 0 deletions pkg/sql/plpgsql/parser/testdata/stmt_case
Expand Up @@ -133,6 +133,7 @@ BEGIN
END CASE;
END
----
decl_stmt: 1
stmt_block: 1
stmt_call: 3
stmt_case: 1
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/sem/plpgsqltree/BUILD.bazel
Expand Up @@ -13,6 +13,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/sql/sem/tree",
"//pkg/util/errorutil/unimplemented",
"@com_github_cockroachdb_errors//:errors",
],
)

0 comments on commit 5150b93

Please sign in to comment.