From a56a72cfe1f041504b49b78d63f1af621a67fcb9 Mon Sep 17 00:00:00 2001 From: Dery Rahman Ahaddienata Date: Fri, 21 Feb 2025 15:33:49 +0700 Subject: [PATCH] feat: capability to run query with proper ordering --- mc2mc/internal/query/builder.go | 32 ++++---- mc2mc/internal/query/builder_test.go | 82 +++++++++++-------- mc2mc/internal/query/helper.go | 64 +++++++++++++++ mc2mc/internal/query/helper_test.go | 114 +++++++++++++++++++++++++++ 4 files changed, 242 insertions(+), 50 deletions(-) diff --git a/mc2mc/internal/query/builder.go b/mc2mc/internal/query/builder.go index 7834e73..a2a9ac0 100644 --- a/mc2mc/internal/query/builder.go +++ b/mc2mc/internal/query/builder.go @@ -60,20 +60,21 @@ func (b *Builder) Build() (string, error) { return "", errors.New("query is required") } - // separate headers, variables and udfs from the query - hr, query := SeparateHeadersAndQuery(b.query) - varsAndUDFs, query := SeparateVariablesUDFsAndQuery(query) - drops, query := SeparateDropsAndQuery(query) - if b.method == MERGE { - queries := semicolonPattern.Split(query, -1) + // split query components + hrs, vars, queries := SplitQueryComponents(b.query) if len(queries) <= 1 { return b.query, nil } - query = b.constructMergeQuery(hr, drops, varsAndUDFs, queries) + query := b.constructMergeQuery(hrs, vars, queries) return query, nil } + // separate headers, variables and udfs from the query + hr, query := SeparateHeadersAndQuery(b.query) + varsAndUDFs, query := SeparateVariablesUDFsAndQuery(query) + drops, query := SeparateDropsAndQuery(query) + // destination table is required for append and replace method if b.destinationTableID == "" { return "", errors.New("destination table is required") @@ -187,21 +188,20 @@ func (b *Builder) constructOverridedValues(query string) (string, error) { } // constructMergeQueries constructs merge queries with headers and variables -func (b *Builder) constructMergeQuery(hr, drops, varsAndUDFs string, queries []string) string { +func (b *Builder) constructMergeQuery(hrs, vars, queries []string) string { builder := strings.Builder{} - if drops != "" { - builder.WriteString(fmt.Sprintf("%s\n", hr)) - builder.WriteString(fmt.Sprintf("%s\n", drops)) - builder.WriteString(fmt.Sprintf("%s\n", BREAK_MARKER)) - } for i, q := range queries { q = strings.TrimSpace(q) if q == "" || strings.TrimSpace(RemoveComments(q)) == "" { continue } - builder.WriteString(fmt.Sprintf("%s\n", hr)) - if varsAndUDFs != "" { - builder.WriteString(fmt.Sprintf("%s\n", varsAndUDFs)) + headers := JoinSliceString(hrs[:i+1], "\n") + variables := JoinSliceString(vars[:i+1], "\n") + if headers != "" { + builder.WriteString(fmt.Sprintf("%s\n", headers)) + } + if variables != "" { + builder.WriteString(fmt.Sprintf("%s\n", variables)) } builder.WriteString(fmt.Sprintf("%s\n;", q)) if i < len(queries)-1 { diff --git a/mc2mc/internal/query/builder_test.go b/mc2mc/internal/query/builder_test.go index dfb1f0b..28862d2 100644 --- a/mc2mc/internal/query/builder_test.go +++ b/mc2mc/internal/query/builder_test.go @@ -540,16 +540,12 @@ SET append_test.id = 2;` assert.NoError(t, err) assert.Equal(t, `SET odps.table.append2.enable=true ; -@src := SELECT 1 id -; CREATE TABLE IF NOT EXISTS append_test (id bigint) TBLPROPERTIES('table.format.version'='2') ; --*--optimus-break-marker--*-- SET odps.table.append2.enable=true ; -@src := SELECT 1 id -; INSERT OVERWRITE TABLE append_test VALUES(0),(1) ; --*--optimus-break-marker--*-- @@ -601,19 +597,6 @@ SET append_test.id = 2;` assert.NoError(t, err) assert.Equal(t, `SET odps.table.append2.enable=true ; -FUNCTION castStringToBoolean (@field STRING) AS CASE -WHEN TOLOWER(@field) = '1.0' THEN true -WHEN TOLOWER(@field) = '0.0' THEN false -WHEN TOLOWER(@field) = '1' THEN true -WHEN TOLOWER(@field) = '0' THEN false -WHEN TOLOWER(@field) = 'true' THEN true -WHEN TOLOWER(@field) = 'false' THEN false -END -; -function my_add(@a BIGINT) as @a + 1 -; -@src := SELECT my_add(1) id -; CREATE TABLE IF NOT EXISTS append_test (id bigint) TBLPROPERTIES('table.format.version'='2') ; @@ -631,8 +614,6 @@ END ; function my_add(@a BIGINT) as @a + 1 ; -@src := SELECT my_add(1) id -; INSERT OVERWRITE TABLE append_test VALUES(0),(1) ; --*--optimus-break-marker--*-- @@ -697,19 +678,6 @@ SET append_test.id = 2;` assert.NoError(t, err) assert.Equal(t, `SET odps.table.append2.enable=true ; -FUNCTION castStringToBoolean (@field STRING) AS CASE -WHEN TOLOWER(@field) = '1.0' THEN true -WHEN TOLOWER(@field) = '0.0' THEN false -WHEN TOLOWER(@field) = '1' THEN true -WHEN TOLOWER(@field) = '0' THEN false -WHEN TOLOWER(@field) = 'true' THEN true -WHEN TOLOWER(@field) = 'false' THEN false -END -; -function my_add(@a BIGINT) as @a + 1 -; -@src := SELECT my_add(1) id -; CREATE TABLE IF NOT EXISTS append_test (id bigint) TBLPROPERTIES('table.format.version'='2') ; @@ -727,8 +695,6 @@ END ; function my_add(@a BIGINT) as @a + 1 ; -@src := SELECT my_add(1) id -; INSERT OVERWRITE TABLE append_test VALUES(0),(1) ; --*--optimus-break-marker--*-- @@ -752,6 +718,54 @@ USING (SELECT castStringToBoolean(id) FROM @src) source on append_test.id = source.id WHEN MATCHED THEN UPDATE SET append_test.id = 2 +;`, query) + }) + + t.Run("returns query for merge load method with proper variable ordering", func(t *testing.T) { + queryToExecute := `SET odps.table.append2.enable=true; +DROP TABLE IF EXISTS append_tmp; +@src := SELECT 1 id; + +CREATE TABLE append_tmp AS SELECT * FROM @src; + +@src2 := SELECT id FROM append_tmp; + +MERGE INTO append_test +USING (SELECT * FROM @src2) source +on append_test.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test.id = 2;` + odspClient := &mockOdpsClient{} + query, err := query.NewBuilder( + logger.NewDefaultLogger(), + odspClient, + query.WithQuery(queryToExecute), + query.WithMethod(query.MERGE), + ).Build() + assert.NoError(t, err) + assert.Equal(t, `SET odps.table.append2.enable=true +; +DROP TABLE IF EXISTS append_tmp +; +--*--optimus-break-marker--*-- +SET odps.table.append2.enable=true +; +@src := SELECT 1 id +; +CREATE TABLE append_tmp AS SELECT * FROM @src +; +--*--optimus-break-marker--*-- +SET odps.table.append2.enable=true +; +@src := SELECT 1 id +; +@src2 := SELECT id FROM append_tmp +; +MERGE INTO append_test +USING (SELECT * FROM @src2) source +on append_test.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test.id = 2 ;`, query) }) } diff --git a/mc2mc/internal/query/helper.go b/mc2mc/internal/query/helper.go index 40fd81b..db03183 100644 --- a/mc2mc/internal/query/helper.go +++ b/mc2mc/internal/query/helper.go @@ -22,6 +22,70 @@ var ( stringPattern = regexp.MustCompile(`'[^']*'`) // regex to match SQL strings (anything inside single quotes) ) +func SplitQueryComponents(query string) (headers []string, varsUDFs []string, queries []string) { + query = strings.TrimSpace(query) + + // extract all header, variable and query lines + stmts := semicolonPattern.Split(query, -1) + queryIndex := 0 + for _, stmt := range stmts { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + stmtWithoutComment := RemoveComments(stmt) + if headerPattern.MatchString(strings.TrimSpace(stmtWithoutComment)) { + for len(headers) <= queryIndex { + headers = append(headers, "") + } + headers[queryIndex] += strings.TrimSpace(stmt) + "\n;\n" + } else if variablePattern.MatchString(strings.TrimSpace(stmtWithoutComment)) || + udfPattern.MatchString(strings.TrimSpace(stmtWithoutComment)) { + for len(varsUDFs) <= queryIndex { + varsUDFs = append(varsUDFs, "") + } + varsUDFs[queryIndex] += strings.TrimSpace(stmt) + "\n;\n" + } else if strings.TrimSpace(stmtWithoutComment) == "" { + // if the statement is empty, it's a comment, then omit it + // since it doesn't make sense to execute this statement + } else { + queries = append(queries, stmt) + queryIndex++ + } + } + + // fill in empty headers and varsUDFs + clear whitespace + for i := range queries { + if len(headers) == i { + headers = append(headers, "") + } + if len(varsUDFs) == i { + varsUDFs = append(varsUDFs, "") + } + headers[i] = strings.TrimSpace(headers[i]) + varsUDFs[i] = strings.TrimSpace(varsUDFs[i]) + queries[i] = strings.TrimSpace(queries[i]) + } + + return headers, varsUDFs, queries +} + +// JoinSliceString joins a slice of strings with a delimiter +// and skips empty strings +func JoinSliceString(slice []string, delimiter string) string { + builder := strings.Builder{} + for i, s := range slice { + if s == "" { + continue + } + if i > 0 { + builder.WriteString(delimiter) + } + builder.WriteString(s) + } + return strings.TrimSpace(builder.String()) +} + func SeparateHeadersAndQuery(query string) (string, string) { headers := []string{} query = strings.TrimSpace(query) diff --git a/mc2mc/internal/query/helper_test.go b/mc2mc/internal/query/helper_test.go index 34892b2..655bf8c 100644 --- a/mc2mc/internal/query/helper_test.go +++ b/mc2mc/internal/query/helper_test.go @@ -8,6 +8,120 @@ import ( "github.com/goto/transformers/mc2mc/internal/query" ) +func TestSplitQueryComponents(t *testing.T) { + t.Run("returns query without headers and variables", func(t *testing.T) { + q1 := `select * from playground` + headers, varsUDFs, queries := query.SplitQueryComponents(q1) + assert.Len(t, headers, 1) + assert.Len(t, varsUDFs, 1) + assert.Len(t, queries, 1) + assert.Empty(t, headers[0]) + assert.Empty(t, varsUDFs[0]) + assert.Equal(t, q1, queries[0]) + }) + t.Run("returns headers, vars, and queries with proper order", func(t *testing.T) { + q1 := `set odps.sql.allow.fullscan=true; +set odps.sql.python.version=cp37; +DROP TABLE IF EXISTS append_test_tmp; + +@src := SELECT 1 id; +@src2 := SELECT id +FROM @src +WHERE id = 1; +CREATE TABLE append_test_tmp AS SELECT * FROM @src2; + +MERGE INTO append_test_tmp USING (SELECT * FROM @src) source +on append_test_tmp.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test_tmp.id = 2; + +@src3 := SELECT id FROM append_test_tmp WHERE id = 2; +MERGE INTO append_test USING (SELECT * FROM @src3) source +on append_test.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test.id = 3; + +MERGE INTO append_test USING (SELECT * FROM @src3) source +on append_test.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test.id = 3; +` + headers, varsUDFs, queries := query.SplitQueryComponents(q1) + assert.Len(t, headers, 5) + assert.Len(t, varsUDFs, 5) + assert.Len(t, queries, 5) + + // headers asserts + headersExpected := make([]string, 5) + headersExpected[0] = `set odps.sql.allow.fullscan=true +; +set odps.sql.python.version=cp37 +;` + headersExpected[1] = "" + headersExpected[2] = "" + headersExpected[3] = "" + headersExpected[4] = "" + + // vars asserts + varsExpected := make([]string, 5) + varsExpected[0] = "" + varsExpected[1] = `@src := SELECT 1 id +; +@src2 := SELECT id +FROM @src +WHERE id = 1 +;` + varsExpected[2] = "" + varsExpected[3] = `@src3 := SELECT id FROM append_test_tmp WHERE id = 2 +;` + varsExpected[4] = "" + + // queries asserts + queriesExpected := make([]string, 5) + queriesExpected[0] = "DROP TABLE IF EXISTS append_test_tmp" + queriesExpected[1] = "CREATE TABLE append_test_tmp AS SELECT * FROM @src2" + queriesExpected[2] = `MERGE INTO append_test_tmp USING (SELECT * FROM @src) source +on append_test_tmp.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test_tmp.id = 2` + queriesExpected[3] = `MERGE INTO append_test USING (SELECT * FROM @src3) source +on append_test.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test.id = 3` + queriesExpected[4] = `MERGE INTO append_test USING (SELECT * FROM @src3) source +on append_test.id = source.id +WHEN MATCHED THEN UPDATE +SET append_test.id = 3` + + for i := range queries { + assert.Equal(t, headersExpected[i], headers[i]) + assert.Equal(t, varsExpected[i], varsUDFs[i]) + assert.Equal(t, queriesExpected[i], queries[i]) + } + }) +} + +func TestJoinSliceString(t *testing.T) { + t.Run("returns empty string for empty slice", func(t *testing.T) { + slice := []string{} + delimiter := ";" + result := query.JoinSliceString(slice, delimiter) + assert.Empty(t, result) + }) + t.Run("returns joined string with delimiter", func(t *testing.T) { + slice := []string{"set odps.sql.allow.fullscan=true", "set odps.sql.python.version=cp37"} + delimiter := ";" + result := query.JoinSliceString(slice, delimiter) + assert.Equal(t, "set odps.sql.allow.fullscan=true;set odps.sql.python.version=cp37", result) + }) + t.Run("returns joined string with delimiter and skips empty strings", func(t *testing.T) { + slice := []string{"set odps.sql.allow.fullscan=true", "", "set odps.sql.python.version=cp37"} + delimiter := ";" + result := query.JoinSliceString(slice, delimiter) + assert.Equal(t, "set odps.sql.allow.fullscan=true;set odps.sql.python.version=cp37", result) + }) +} + func TestSeparateHeadersAndQuery(t *testing.T) { t.Run("returns query without macros", func(t *testing.T) { q1 := `select * from playground`