diff --git a/parser/simple_parser.go b/parser/simple_parser.go index aa15af82..c34081f2 100644 --- a/parser/simple_parser.go +++ b/parser/simple_parser.go @@ -109,10 +109,10 @@ func (p *simpleParser) eatDollarQuotedString(tag string) (string, bool) { if potentialTag == tag { return string(p.sql[startPos:posBeforeTag]), true } - // This is a nested dollar-tag. Nested dollar-quoted strings are allowed. - if _, ok = p.eatDollarQuotedString(potentialTag); !ok { - return "", false - } + // This looks like a nested dollar-tag. Nested dollar-quoted strings are allowed. + // These should just be ignored, as we don't know if it is actually a nested + // dollar-quoted string, or just a random dollar sign inside the current string. + p.pos = posBeforeTag } } return "", false diff --git a/parser/statement_parser.go b/parser/statement_parser.go index 81603442..37f554dd 100644 --- a/parser/statement_parser.go +++ b/parser/statement_parser.go @@ -788,3 +788,75 @@ func detectDmlKeyword(keyword string) DmlType { } return DmlTypeUnknown } + +// Split splits a SQL string that potentially contains multiple statements separated by +// semicolons into individual statements. +// +// Returns false if the SQL string only contains a single statement (including when the SQL string +// contains a single statement that is terminated by a semicolon). Whitespaces and comments after +// the last semicolon in the SQL string are ignored. +// +// Returns true and a slice containing the individual statements if the SQL string contains more +// than one statement. The slice always contains more than one element. +func (p *StatementParser) Split(sql string) (bool, []string, error) { + return p.split(sql, ';') +} + +func (p *StatementParser) split(sql string, sep byte) (bool, []string, error) { + // Return early if the string does not contain the separator. + firstIndex := strings.IndexByte(sql, sep) + if firstIndex == -1 { + return false, nil, nil + } + tokens := []byte(sql) + // Also return early if it is a single statement that is just terminated by a semicolon. + if firstIndex == len(tokens)-1 { + return false, nil, nil + } + + res := make([]string, 0) + parser := &simpleParser{sql: tokens, statementParser: p} + startPos := 0 + for parser.pos < len(parser.sql) { + parser.skipWhitespacesAndComments() + if parser.pos >= len(parser.sql) { + break + } + if parser.isMultibyte() { + parser.nextChar() + continue + } + c := parser.sql[parser.pos] + if c == sep { + res = append(res, sql[startPos:parser.pos]) + parser.nextChar() + startPos = parser.pos + // Skip whitespaces / comments etc. and check if we are at the end of the SQL string. + // This prevents that an empty statement is added to the end of the slice if there are only whitespaces + // after the last semicolon. + parser.skipWhitespacesAndComments() + if parser.pos >= len(parser.sql) { + if len(res) == 1 { + return false, nil, nil + } + return true, res, nil + } + continue + } + newPos, err := p.skip(parser.sql, parser.pos) + if err != nil { + return false, nil, err + } + parser.pos = newPos + } + if len(res) == 0 { + // This means that the SQL string contains one or more semicolons, but that all of them + // are inside quoted literals, quoted identifiers or comments. This again means that there + // is only one statement in the string. + return false, nil, nil + } + + // The last statement does not need to be terminated by a semicolon, so we add it here. + res = append(res, sql[startPos:]) + return true, res, nil +} diff --git a/parser/statement_parser_test.go b/parser/statement_parser_test.go index 2e5ba147..59fa39da 100644 --- a/parser/statement_parser_test.go +++ b/parser/statement_parser_test.go @@ -16,6 +16,7 @@ package parser import ( "fmt" + "reflect" "testing" "cloud.google.com/go/spanner" @@ -2306,8 +2307,8 @@ func TestEatDollarQuotedString(t *testing.T) { }, { input: "$outer$ outer string $inner$ mismatched tag $outer$ second part of outer string $inner$", - want: "", - wantErr: true, + want: " outer string $inner$ mismatched tag ", + wantErr: false, }, { input: "$outer$ outer string $outer $outer$", @@ -2329,6 +2330,10 @@ func TestEatDollarQuotedString(t *testing.T) { want: "value $tag/*not a comment*/", wantErr: false, }, + { + input: "$bar$Hello$$;World!$bar$", + want: "Hello$$;World!", + }, } statementParser, err := NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000) if err != nil { @@ -2695,6 +2700,257 @@ func TestEatIdentifier(t *testing.T) { } } +func TestSplit(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + dialect databasepb.DatabaseDialect + want bool + wantRes []string + wantErr bool + }{ + { + input: "", + }, + { + input: ";", + }, + { + input: ";;", + want: true, + wantRes: []string{"", ""}, + }, + { + input: "select * from my_table", + }, + { + input: "select * from my_table;", + }, + { + input: "select * from my_table\n;\n", + }, + { + input: "begin; select * from my_table", + want: true, + wantRes: []string{"begin", " select * from my_table"}, + }, + { + input: "begin; select * from my_table;", + want: true, + wantRes: []string{"begin", " select * from my_table"}, + }, + { + input: "begin;\n" + + "insert into my_table (id, value) values (1, 'One');\n" + + "select * from my_table;\n" + + "commit;", + want: true, + wantRes: []string{ + "begin", + "\ninsert into my_table (id, value) values (1, 'One')", + "\nselect * from my_table", + "\ncommit", + }, + }, + { + input: "select 'begin;' from my_table;", + }, + { + input: "select * from my_table where value = 'test;'", + }, + { + input: "@{hint=';value;'} select * from my_table where value = 'test;'", + }, + { + input: "-- Comment;\nselect * from my_table", + }, + { + input: "-- Comment;\nbegin; select * from my_table", + want: true, + wantRes: []string{"-- Comment;\nbegin", " select * from my_table"}, + }, + { + input: "-- Comment1;\nbegin--Comment2;\n; select * from my_table;--Comment3", + want: true, + wantRes: []string{"-- Comment1;\nbegin--Comment2;\n", " select * from my_table"}, + }, + { + input: "; begin; commit;", + want: true, + wantRes: []string{"", " begin", " commit"}, + }, + { + input: "begin; commit;;", + want: true, + wantRes: []string{"begin", " commit", ""}, + }, + { + input: "select 1 --; select 2\n;select 3", + want: true, + wantRes: []string{"select 1 --; select 2\n", "select 3"}, + }, + { + input: "-- select 1;\nselect 2;select 3", + want: true, + wantRes: []string{"-- select 1;\nselect 2", "select 3"}, + }, + { + input: "select 1 /* Comment */", + }, + { + input: " /* Comment */ select 1", + }, + { + input: "select 1 /*; select 2 */;select 3", + want: true, + wantRes: []string{"select 1 /*; select 2 */", "select 3"}, + }, + { + input: "/* select 1; */select 2;select 3", + want: true, + wantRes: []string{"/* select 1; */select 2", "select 3"}, + }, + { + input: "select 'Hello World!'", + }, + { + input: "select 'Hello;World!'", + }, + { + input: "select 'Hello;World!';", + }, + { + input: "select 'Hello;World!'; select 1", + want: true, + wantRes: []string{"select 'Hello;World!'", " select 1"}, + }, + { + // Semicolons inside brackets are not treated in any special way. + input: "select (select 1;); select 2;", + want: true, + wantRes: []string{"select (select 1", ")", " select 2"}, + }, + { + input: "select 'Hello''World!'", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: "select 'Hello'';World!'", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: "select 'Hello\"World!'", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: "select 'Hello\";World!'", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: `select * from "my""table"`, + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: `select * from "my"";table"`, + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: `select * from "my';table"; select 1`, + dialect: databasepb.DatabaseDialect_POSTGRESQL, + want: true, + wantRes: []string{`select * from "my';table"`, ` select 1`}, + }, + { + input: "/* This block comment surrounds a query which itself has a block comment...\n" + + "SELECT /* embedded single line */ 'embedded' AS x2;\n" + + "*/\n" + + "SELECT 1;", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: "/* This block comment surrounds a query which itself has a block comment...\n" + + "SELECT /* embedded single line */ 'embedded' AS x2;\n" + + "*/\n" + + "SELECT 1; SELECT 2;", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + want: true, + wantRes: []string{"/* This block comment surrounds a query which itself has a block comment...\n" + + "SELECT /* embedded single line */ 'embedded' AS x2;\n" + + "*/\n" + + "SELECT 1", " SELECT 2"}, + }, + { + input: "select $$Hello World!$$", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: "select $$Hello;World!$$", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: "select $$Hello;World!$$;", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + }, + { + input: "select $$Hello;World!$$; select 1", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + want: true, + wantRes: []string{"select $$Hello;World!$$", " select 1"}, + }, + { + input: "select $$Hello$;World!$$; select 1", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + want: true, + wantRes: []string{"select $$Hello$;World!$$", " select 1"}, + }, + { + input: "select $bar$Hello$;World!$bar$; select 1", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + want: true, + wantRes: []string{"select $bar$Hello$;World!$bar$", " select 1"}, + }, + { + input: "select $bar$Hello$$;World!$bar$; select 1", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + want: true, + wantRes: []string{"select $bar$Hello$$;World!$bar$", " select 1"}, + }, + { + input: "select $bar$Hello$baz$;World!$bar$; select 1", + dialect: databasepb.DatabaseDialect_POSTGRESQL, + want: true, + wantRes: []string{"select $bar$Hello$baz$;World!$bar$", " select 1"}, + }, + } + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + parser, err := NewStatementParser(dialect, 1000) + if err != nil { + t.Fatal(err) + } + for _, test := range tests { + if test.dialect == databasepb.DatabaseDialect_DATABASE_DIALECT_UNSPECIFIED || test.dialect == dialect { + t.Run(fmt.Sprintf("%v: %v", dialect, test.input), func(t *testing.T) { + got, res, err := parser.Split(test.input) + if g, w := got, test.want; g != w { + t.Errorf("result mismatch\n Got: %v\nWant: %v", g, w) + } + if test.want { + if g, w := res, test.wantRes; !reflect.DeepEqual(g, w) { + t.Errorf("result mismatch\n Got: %v\nWant: %v", g, w) + } + } else if test.wantErr && err == nil { + t.Error("missing expected error") + } + }) + } + } + } +} + func BenchmarkDetectStatementTypeWithoutCache(b *testing.B) { parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 0) if err != nil {