diff --git a/mc2mc/internal/query/builder.go b/mc2mc/internal/query/builder.go index a2a9ac0..8678a0e 100644 --- a/mc2mc/internal/query/builder.go +++ b/mc2mc/internal/query/builder.go @@ -200,7 +200,7 @@ func (b *Builder) constructMergeQuery(hrs, vars, queries []string) string { if headers != "" { builder.WriteString(fmt.Sprintf("%s\n", headers)) } - if variables != "" { + if variables != "" && !IsDDL(q) { // skip variables if it's ddl builder.WriteString(fmt.Sprintf("%s\n", variables)) } builder.WriteString(fmt.Sprintf("%s\n;", q)) diff --git a/mc2mc/internal/query/builder_test.go b/mc2mc/internal/query/builder_test.go index 28862d2..8120e42 100644 --- a/mc2mc/internal/query/builder_test.go +++ b/mc2mc/internal/query/builder_test.go @@ -766,6 +766,62 @@ USING (SELECT * FROM @src2) source on append_test.id = source.id WHEN MATCHED THEN UPDATE SET append_test.id = 2 +;`, query) + }) + t.Run("returns query for merge load method with correct ddl ordering", func(t *testing.T) { + queryToExecute := `SET odps.table.append2.enable=true; +@src := SELECT 1 id; + +@src2 := SELECT id FROM append_tmp; +DROP TABLE IF EXISTS append_tmp; + +CREATE TABLE append_tmp AS SELECT * FROM @src; + +CREATE TABLE append_tmp2(id bigint); + +MERGE INTO append_test +USING (SELECT * FROM @src2) source +on append_test.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test.id = 2;` + odspClient := &mockOdpsClient{} + query, err := query.NewBuilder( + logger.NewDefaultLogger(), + odspClient, + query.WithQuery(queryToExecute), + query.WithMethod(query.MERGE), + ).Build() + assert.NoError(t, err) + assert.Equal(t, `SET odps.table.append2.enable=true +; +DROP TABLE IF EXISTS append_tmp +; +--*--optimus-break-marker--*-- +SET odps.table.append2.enable=true +; +@src := SELECT 1 id +; +@src2 := SELECT id FROM append_tmp +; +CREATE TABLE append_tmp AS SELECT * FROM @src +; +--*--optimus-break-marker--*-- +SET odps.table.append2.enable=true +; +CREATE TABLE append_tmp2(id bigint) +; 
+--*--optimus-break-marker--*-- +SET odps.table.append2.enable=true +; +@src := SELECT 1 id +; +@src2 := SELECT id FROM append_tmp +; +MERGE INTO append_test +USING (SELECT * FROM @src2) source +on append_test.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test.id = 2 ;`, query) }) } diff --git a/mc2mc/internal/query/helper.go b/mc2mc/internal/query/helper.go index db03183..050ecc2 100644 --- a/mc2mc/internal/query/helper.go +++ b/mc2mc/internal/query/helper.go @@ -11,15 +11,16 @@ const ( ) var ( - semicolonPattern = regexp.MustCompile(`;\s*(\n+|$)`) // regex to match semicolons - commentPattern = regexp.MustCompile(`--[^\n]*`) // regex to match comments - multiCommentPattern = regexp.MustCompile(`(?s)/\*.*?\*/`) // regex to match multi-line comments - headerPattern = regexp.MustCompile(`(?i)^set`) // regex to match header statements - variablePattern = regexp.MustCompile(`(?i)^@`) // regex to match variable statements - dropPattern = regexp.MustCompile(`(?i)^DROP\s+`) // regex to match DROP statements - udfPattern = regexp.MustCompile(`(?i)^function\s+`) // regex to match UDF statements - ddlPattern = regexp.MustCompile(`(?i)^CREATE\s+`) // regex to match DDL statements - stringPattern = regexp.MustCompile(`'[^']*'`) // regex to match SQL strings (anything inside single quotes) + semicolonPattern = regexp.MustCompile(`;\s*(\n+|$)`) // regex to match semicolons + commentPattern = regexp.MustCompile(`--[^\n]*`) // regex to match comments + multiCommentPattern = regexp.MustCompile(`(?s)/\*.*?\*/`) // regex to match multi-line comments + headerPattern = regexp.MustCompile(`(?i)^set`) // regex to match header statements + variablePattern = regexp.MustCompile(`(?i)^@`) // regex to match variable statements + dropPattern = regexp.MustCompile(`(?i)^DROP\s+`) // regex to match DROP statements + udfPattern = regexp.MustCompile(`(?i)^function\s+`) // regex to match UDF statements + ddlPattern = regexp.MustCompile(`(?i)^(ALTER|DROP|TRUNCATE)\s+`) // regex to match DDL 
statements + ddlCreatePattern = regexp.MustCompile(`(?i)^(CREATE\s+TABLE\s+[^\s]+\s*\()`) // regex to match CREATE TABLE statements with an explicit column list + stringPattern = regexp.MustCompile(`'[^']*'`) // regex to match SQL strings (anything inside single quotes) ) func SplitQueryComponents(query string) (headers []string, varsUDFs []string, queries []string) { @@ -227,6 +228,7 @@ func RestoreStringLiteral(query string, placeholders map[string]string) string { return query } -func IsDDL(query string) bool { - return ddlPattern.MatchString(query) +func IsDDL(stmt string) bool { + stmtWithoutComment := RemoveComments(stmt) + return ddlPattern.MatchString(strings.TrimSpace(stmtWithoutComment)) || ddlCreatePattern.MatchString(strings.TrimSpace(stmtWithoutComment)) }