Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions parser/simple_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ func (p *simpleParser) eatDollarQuotedString(tag string) (string, bool) {
if potentialTag == tag {
return string(p.sql[startPos:posBeforeTag]), true
}
// This is a nested dollar-tag. Nested dollar-quoted strings are allowed.
if _, ok = p.eatDollarQuotedString(potentialTag); !ok {
return "", false
}
// This looks like a nested dollar-tag. Nested dollar-quoted strings are allowed.
// These should just be ignored, as we don't know if it is actually a nested
// dollar-quoted string, or just a random dollar sign inside the current string.
p.pos = posBeforeTag
}
}
return "", false
Expand Down
72 changes: 72 additions & 0 deletions parser/statement_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -788,3 +788,75 @@ func detectDmlKeyword(keyword string) DmlType {
}
return DmlTypeUnknown
}

// Split splits a SQL string that potentially contains multiple statements separated by
// semicolons into individual statements.
//
// Returns false if the SQL string only contains a single statement (including when the SQL string
// contains a single statement that is terminated by a semicolon). Whitespaces and comments after
// the last semicolon in the SQL string are ignored.
//
// Returns true and a slice containing the individual statements if the SQL string contains more
// than one statement. The slice always contains more than one element.
func (p *StatementParser) Split(sql string) (bool, []string, error) {
return p.split(sql, ';')
}

func (p *StatementParser) split(sql string, sep byte) (bool, []string, error) {
// Return early if the string does not contain the separator.
firstIndex := strings.IndexByte(sql, sep)
if firstIndex == -1 {
return false, nil, nil
}
tokens := []byte(sql)
// Also return early if it is a single statement that is just terminated by a semicolon.
if firstIndex == len(tokens)-1 {
return false, nil, nil
}

res := make([]string, 0)
parser := &simpleParser{sql: tokens, statementParser: p}
startPos := 0
for parser.pos < len(parser.sql) {
parser.skipWhitespacesAndComments()
if parser.pos >= len(parser.sql) {
break
}
if parser.isMultibyte() {
parser.nextChar()
continue
}
c := parser.sql[parser.pos]
if c == sep {
res = append(res, sql[startPos:parser.pos])
parser.nextChar()
startPos = parser.pos
// Skip whitespaces / comments etc. and check if we are at the end of the SQL string.
// This prevents that an empty statement is added to the end of the slice if there are only whitespaces
// after the last semicolon.
parser.skipWhitespacesAndComments()
if parser.pos >= len(parser.sql) {
if len(res) == 1 {
return false, nil, nil
}
return true, res, nil
}
continue
}
newPos, err := p.skip(parser.sql, parser.pos)
if err != nil {
return false, nil, err
}
parser.pos = newPos
}
if len(res) == 0 {
// This means that the SQL string contains one or more semicolons, but that all of them
// are inside quoted literals, quoted identifiers or comments. This again means that there
// is only one statement in the string.
return false, nil, nil
}

// The last statement does not need to be terminated by a semicolon, so we add it here.
res = append(res, sql[startPos:])
return true, res, nil
}
260 changes: 258 additions & 2 deletions parser/statement_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package parser

import (
"fmt"
"reflect"
"testing"

"cloud.google.com/go/spanner"
Expand Down Expand Up @@ -2306,8 +2307,8 @@ func TestEatDollarQuotedString(t *testing.T) {
},
{
input: "$outer$ outer string $inner$ mismatched tag $outer$ second part of outer string $inner$",
want: "",
wantErr: true,
want: " outer string $inner$ mismatched tag ",
wantErr: false,
},
{
input: "$outer$ outer string $outer $outer$",
Expand All @@ -2329,6 +2330,10 @@ func TestEatDollarQuotedString(t *testing.T) {
want: "value $tag/*not a comment*/",
wantErr: false,
},
{
input: "$bar$Hello$$;World!$bar$",
want: "Hello$$;World!",
},
}
statementParser, err := NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000)
if err != nil {
Expand Down Expand Up @@ -2695,6 +2700,257 @@ func TestEatIdentifier(t *testing.T) {
}
}

func TestSplit(t *testing.T) {
t.Parallel()

tests := []struct {
input string
dialect databasepb.DatabaseDialect
want bool
wantRes []string
wantErr bool
}{
{
input: "",
},
{
input: ";",
},
{
input: ";;",
want: true,
wantRes: []string{"", ""},
},
{
input: "select * from my_table",
},
{
input: "select * from my_table;",
},
{
input: "select * from my_table\n;\n",
},
{
input: "begin; select * from my_table",
want: true,
wantRes: []string{"begin", " select * from my_table"},
},
{
input: "begin; select * from my_table;",
want: true,
wantRes: []string{"begin", " select * from my_table"},
},
{
input: "begin;\n" +
"insert into my_table (id, value) values (1, 'One');\n" +
"select * from my_table;\n" +
"commit;",
want: true,
wantRes: []string{
"begin",
"\ninsert into my_table (id, value) values (1, 'One')",
"\nselect * from my_table",
"\ncommit",
},
},
{
input: "select 'begin;' from my_table;",
},
{
input: "select * from my_table where value = 'test;'",
},
{
input: "@{hint=';value;'} select * from my_table where value = 'test;'",
},
{
input: "-- Comment;\nselect * from my_table",
},
{
input: "-- Comment;\nbegin; select * from my_table",
want: true,
wantRes: []string{"-- Comment;\nbegin", " select * from my_table"},
},
{
input: "-- Comment1;\nbegin--Comment2;\n; select * from my_table;--Comment3",
want: true,
wantRes: []string{"-- Comment1;\nbegin--Comment2;\n", " select * from my_table"},
},
{
input: "; begin; commit;",
want: true,
wantRes: []string{"", " begin", " commit"},
},
{
input: "begin; commit;;",
want: true,
wantRes: []string{"begin", " commit", ""},
},
{
input: "select 1 --; select 2\n;select 3",
want: true,
wantRes: []string{"select 1 --; select 2\n", "select 3"},
},
{
input: "-- select 1;\nselect 2;select 3",
want: true,
wantRes: []string{"-- select 1;\nselect 2", "select 3"},
},
{
input: "select 1 /* Comment */",
},
{
input: " /* Comment */ select 1",
},
{
input: "select 1 /*; select 2 */;select 3",
want: true,
wantRes: []string{"select 1 /*; select 2 */", "select 3"},
},
{
input: "/* select 1; */select 2;select 3",
want: true,
wantRes: []string{"/* select 1; */select 2", "select 3"},
},
{
input: "select 'Hello World!'",
},
{
input: "select 'Hello;World!'",
},
{
input: "select 'Hello;World!';",
},
{
input: "select 'Hello;World!'; select 1",
want: true,
wantRes: []string{"select 'Hello;World!'", " select 1"},
},
{
// Semicolons inside brackets are not treated in any special way.
input: "select (select 1;); select 2;",
want: true,
wantRes: []string{"select (select 1", ")", " select 2"},
},
{
input: "select 'Hello''World!'",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: "select 'Hello'';World!'",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: "select 'Hello\"World!'",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: "select 'Hello\";World!'",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: `select * from "my""table"`,
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: `select * from "my"";table"`,
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: `select * from "my';table"; select 1`,
dialect: databasepb.DatabaseDialect_POSTGRESQL,
want: true,
wantRes: []string{`select * from "my';table"`, ` select 1`},
},
{
input: "/* This block comment surrounds a query which itself has a block comment...\n" +
"SELECT /* embedded single line */ 'embedded' AS x2;\n" +
"*/\n" +
"SELECT 1;",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: "/* This block comment surrounds a query which itself has a block comment...\n" +
"SELECT /* embedded single line */ 'embedded' AS x2;\n" +
"*/\n" +
"SELECT 1; SELECT 2;",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
want: true,
wantRes: []string{"/* This block comment surrounds a query which itself has a block comment...\n" +
"SELECT /* embedded single line */ 'embedded' AS x2;\n" +
"*/\n" +
"SELECT 1", " SELECT 2"},
},
{
input: "select $$Hello World!$$",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: "select $$Hello;World!$$",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: "select $$Hello;World!$$;",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
},
{
input: "select $$Hello;World!$$; select 1",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
want: true,
wantRes: []string{"select $$Hello;World!$$", " select 1"},
},
{
input: "select $$Hello$;World!$$; select 1",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
want: true,
wantRes: []string{"select $$Hello$;World!$$", " select 1"},
},
{
input: "select $bar$Hello$;World!$bar$; select 1",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
want: true,
wantRes: []string{"select $bar$Hello$;World!$bar$", " select 1"},
},
{
input: "select $bar$Hello$$;World!$bar$; select 1",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
want: true,
wantRes: []string{"select $bar$Hello$$;World!$bar$", " select 1"},
},
{
input: "select $bar$Hello$baz$;World!$bar$; select 1",
dialect: databasepb.DatabaseDialect_POSTGRESQL,
want: true,
wantRes: []string{"select $bar$Hello$baz$;World!$bar$", " select 1"},
},
}
for _, dialect := range []databasepb.DatabaseDialect{
databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL,
databasepb.DatabaseDialect_POSTGRESQL,
} {
parser, err := NewStatementParser(dialect, 1000)
if err != nil {
t.Fatal(err)
}
for _, test := range tests {
if test.dialect == databasepb.DatabaseDialect_DATABASE_DIALECT_UNSPECIFIED || test.dialect == dialect {
t.Run(fmt.Sprintf("%v: %v", dialect, test.input), func(t *testing.T) {
got, res, err := parser.Split(test.input)
if g, w := got, test.want; g != w {
t.Errorf("result mismatch\n Got: %v\nWant: %v", g, w)
}
if test.want {
if g, w := res, test.wantRes; !reflect.DeepEqual(g, w) {
t.Errorf("result mismatch\n Got: %v\nWant: %v", g, w)
}
} else if test.wantErr && err == nil {
t.Error("missing expected error")
}
})
}
}
}
}

func BenchmarkDetectStatementTypeWithoutCache(b *testing.B) {
parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 0)
if err != nil {
Expand Down
Loading