Skip to content

Commit

Permalink
Added CREATE PROCEDURE + supporting statements
Browse files Browse the repository at this point in the history
  • Loading branch information
Hydrocharged committed Feb 16, 2021
1 parent 91341d5 commit f72720c
Show file tree
Hide file tree
Showing 5 changed files with 5,326 additions and 4,465 deletions.
148 changes: 147 additions & 1 deletion go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ func (*BeginEndBlock) iStatement() {}
func (*CaseStatement) iStatement() {}
func (*IfStatement) iStatement() {}
func (*Signal) iStatement() {}
func (*Call) iStatement() {}

// ParenSelect can actually not be a top level statement,
// but we have to allow it because it's a requirement
Expand Down Expand Up @@ -689,6 +690,38 @@ func (s *Signal) walkSubtree(visit Visit) error {
return nil
}

// Call represents the CALL statement
type Call struct {
FuncName string
Params []Expr
}

func (c *Call) Format(buf *TrackedBuffer) {
buf.Myprintf("call %s", c.FuncName)
if len(c.Params) > 0 {
buf.Myprintf("(")
for i, param := range c.Params {
if i > 0 {
buf.Myprintf(", ")
}
buf.Myprintf("%v", param)
}
buf.Myprintf(")")
}
}

func (c *Call) walkSubtree(visit Visit) error {
if c == nil {
return nil
}
for _, expr := range c.Params {
if err := Walk(visit, expr); err != nil {
return err
}
}
return nil
}

// Stream represents a SELECT statement.
type Stream struct {
Comments Comments
Expand Down Expand Up @@ -921,6 +954,52 @@ type TriggerOrder struct {
OtherTriggerName string
}

type ProcedureSpec struct {
Name string
Definer string
Params []ProcedureParam
Characteristics []Characteristic
Body Statement
}

type ProcedureParamDirection string
const (
ProcedureParamDirection_In ProcedureParamDirection = "in"
ProcedureParamDirection_Inout ProcedureParamDirection = "inout"
ProcedureParamDirection_Out ProcedureParamDirection = "out"
)

type ProcedureParam struct {
Direction ProcedureParamDirection
Name string
Type ColumnType
}

type CharacteristicValue string
const (
CharacteristicValue_Comment CharacteristicValue = "comment"
CharacteristicValue_LanguageSql CharacteristicValue = "language sql"
CharacteristicValue_Deterministic CharacteristicValue = "deterministic"
CharacteristicValue_NotDeterministic CharacteristicValue = "not deterministic"
CharacteristicValue_ContainsSql CharacteristicValue = "contains sql"
CharacteristicValue_NoSql CharacteristicValue = "no sql"
CharacteristicValue_ReadsSqlData CharacteristicValue = "reads sql data"
CharacteristicValue_ModifiesSqlData CharacteristicValue = "modifies sql data"
CharacteristicValue_SqlSecurityDefiner CharacteristicValue = "sql security definer"
CharacteristicValue_SqlSecurityInvoker CharacteristicValue = "sql security invoker"
)

type Characteristic struct {
Type CharacteristicValue
Comment string
}
func (c Characteristic) String() string {
if c.Type == CharacteristicValue_Comment {
return fmt.Sprintf("comment '%s'", c.Comment)
}
return string(c.Type)
}

// DDL represents a CREATE, ALTER, DROP, RENAME, TRUNCATE or ANALYZE statement.
type DDL struct {
Action string
Expand Down Expand Up @@ -985,6 +1064,9 @@ type DDL struct {

// TriggerSpec is set for CREATE / ALTER / DROP trigger operations
TriggerSpec *TriggerSpec

// ProcedureSpec is set for CREATE PROCEDURE operations
ProcedureSpec *ProcedureSpec
}

// ColumnOrder is used in some DDL statements to specify or change the order of a column in a schema.
Expand Down Expand Up @@ -1049,6 +1131,26 @@ func (node *DDL) Format(buf *TrackedBuffer) {
}
buf.Myprintf("%s trigger %s %s %s on %v for each row %s%v",
node.Action, trigger.Name, trigger.Time, trigger.Event, node.Table, triggerOrder, trigger.Body)
} else if node.ProcedureSpec != nil {
proc := node.ProcedureSpec
sb := strings.Builder{}
sb.WriteString("create ")
if proc.Definer != "" {
sb.WriteString(fmt.Sprintf("definer = %s ", proc.Definer))
}
sb.WriteString(fmt.Sprintf("procedure %s (", proc.Name))
for i, param := range proc.Params {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(string(param.Direction)+" ")
sb.WriteString(fmt.Sprintf("%s %s", param.Name, param.Type.String()))
}
sb.WriteString(")")
for _, characteristic := range proc.Characteristics {
sb.WriteString(" "+characteristic.String())
}
buf.Myprintf("%s %v", sb.String(), proc.Body)
} else {
notExists := ""
if node.IfNotExists {
Expand All @@ -1075,6 +1177,12 @@ func (node *DDL) Format(buf *TrackedBuffer) {
exists = " if exists"
}
buf.Myprintf(fmt.Sprintf("%s trigger%s %v", node.Action, exists, node.TriggerSpec.Name))
} else if node.ProcedureSpec != nil {
exists := ""
if node.IfExists {
exists = " if exists"
}
buf.Myprintf(fmt.Sprintf("%s procedure%s %v", node.Action, exists, node.ProcedureSpec.Name))
} else {
buf.Myprintf("%s table%s %v", node.Action, exists, node.FromTables)
}
Expand Down Expand Up @@ -1508,6 +1616,13 @@ func (ct *ColumnType) Format(buf *TrackedBuffer) {
}
}

// String returns a canonical string representation of the type and all relevant options
func (ct *ColumnType) String() string {
buf := NewTrackedBuffer(nil)
ct.Format(buf)
return buf.String()
}

// DescribeType returns the abbreviated type information as required for
// describe table
func (ct *ColumnType) DescribeType() string {
Expand Down Expand Up @@ -2045,6 +2160,7 @@ type Show struct {
Scope string
ShowCollationFilterOpt *Expr
ShowIndexFilterOpt Expr
ProcFuncFilter *ShowFilter
}

// Format formats the node.
Expand Down Expand Up @@ -2078,6 +2194,20 @@ func (node *Show) Format(buf *TrackedBuffer) {
buf.Myprintf("show create trigger %v", node.Table)
return
}
if node.Type == "procedure status" {
buf.Myprintf("show procedure status")
if node.ProcFuncFilter != nil {
buf.Myprintf("%v", node.ProcFuncFilter)
}
return
}
if node.Type == "function status" {
buf.Myprintf("show function status")
if node.ProcFuncFilter != nil {
buf.Myprintf("%v", node.ProcFuncFilter)
}
return
}
if node.Database != "" {
notExistsOpt := ""
if node.IfNotExists {
Expand Down Expand Up @@ -2115,7 +2245,16 @@ func (node *Show) HasTable() bool {
}

func (node *Show) walkSubtree(visit Visit) error {
return nil
if node == nil {
return nil
}
return Walk(
visit,
node.OnTable,
node.Table,
node.ShowIndexFilterOpt,
node.ProcFuncFilter,
)
}

// ShowTablesOpt is show tables option
Expand Down Expand Up @@ -3186,6 +3325,13 @@ func (node *SQLVal) Format(buf *TrackedBuffer) {
}
}

// String returns the node as a string, similar to Format.
func (node *SQLVal) String() string {
buf := NewTrackedBuffer(nil)
node.Format(buf)
return buf.String()
}

func (node *SQLVal) walkSubtree(visit Visit) error {
return nil
}
Expand Down
77 changes: 71 additions & 6 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1265,10 +1265,13 @@ var (
output: "show events",
}, {
input: "show function code func",
output: "show function",
output: "show function code",
}, {
input: "show function status",
output: "show function",
input: "show function status",
}, {
input: "show function status where Name = 'hi'",
}, {
input: "show function status like 'hi'",
}, {
input: "show grants for 'root@localhost'",
output: "show grants",
Expand Down Expand Up @@ -1322,10 +1325,13 @@ var (
output: "show privileges",
}, {
input: "show procedure code p",
output: "show procedure",
output: "show procedure code",
}, {
input: "show procedure status",
}, {
input: "show procedure status",
output: "show procedure",
input: "show procedure status where Name = 'hi'",
}, {
input: "show procedure status like 'hi'",
}, {
input: "show processlist",
output: "show processlist",
Expand Down Expand Up @@ -1679,11 +1685,69 @@ var (
}, {
input: "delete a.*, b.* from tbl_a a, tbl_b b where a.id = b.id and b.name = 'test'",
output: "delete a, b from tbl_a as a, tbl_b as b where a.id = b.id and b.name = 'test'",
}, {
input: "call f1",
}, {
input: "call f1()",
output: "call f1",
}, {
input: "call f1 ()",
output: "call f1",
}, {
input: "call f1(x)",
}, {
input: "call f1(@x, @y)",
}, {
input: "call f1(now(), rand())",
}, {
input: "drop procedure p1",
}, {
input: "drop procedure if exists p1",
}, {
input: "create procedure p1() select rand()",
output: "create procedure p1 () select rand() from dual",
}, {
input: "create procedure p1() language sql deterministic sql security invoker select 1+1",
output: "create procedure p1 () language sql deterministic sql security invoker select 1 + 1 from dual",
},{
input: "create definer = me procedure p1(v1 int) select now()",
output: "create definer = me procedure p1 (in v1 int) select now() from dual",
}, {
input: "create definer = me procedure p1(v1 int) comment 'some_comment' not deterministic select now()",
output: "create definer = me procedure p1 (in v1 int) comment 'some_comment' not deterministic select now() from dual",
},
}
// Any tests that contain multiple statements within the body (such as BEGIN/END blocks) should go here.
// validSQL is used by TestParseNextValid, which expects a semicolon to mean the end of a full statement.
// Multi-statement bodies do not follow this expectation, hence they are excluded from TestParseNextValid.
validMultiStatementSql = []parseTest{
{
input: "create procedure p1 (in v1 int, inout v2 char(2), out v3 datetime) begin select rand() * 10; end",
output: "create procedure p1 (in v1 int, inout v2 char(2), out v3 datetime) begin\nselect rand() * 10 from dual;\nend",
}, {
input: "create procedure p1(v1 datetime)\nif rand() < 1 then select rand();\nend if",
output: "create procedure p1 (in v1 datetime) if rand() < 1 then select rand() from dual;\nend if",
}, {
input: `create procedure p1(n double, m double)
begin
set @s = '';
if n = m then set @s = 'equals';
else
if n > m then set @s = 'greater';
else set @s = 'less';
end if;
set @s = concat('is ', @s, ' than');
end if;
set @s = concat(n, ' ', @s, ' ', m, '.');
select @s;
end`,
output: "create procedure p1 (in n double, in m double) begin\nset @s = '';\nif n = m then set @s = 'equals';\nelse if n > m then set @s = 'greater';\nelse set @s = 'less';\nend if; set @s = concat('is ', @s, ' than');\nend if;\nset @s = concat(n, ' ', @s, ' ', m, '.');\nselect @s from dual;\nend",
},
}
)

func TestValid(t *testing.T) {
validSQL = append(validSQL, validMultiStatementSql...)
for _, tcase := range validSQL {
t.Run(tcase.input, func(t *testing.T) {
if tcase.output == "" {
Expand Down Expand Up @@ -1858,6 +1922,7 @@ func TestCreateViewSelectPosition(t *testing.T) {

// Ensure there is no corruption from using a pooled yyParserImpl in Parse.
func TestValidParallel(t *testing.T) {
validSQL = append(validSQL, validMultiStatementSql...)
parallelism := 100
numIters := 1000

Expand Down

0 comments on commit f72720c

Please sign in to comment.