Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions mc2mc/internal/query/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,21 @@ func (b *Builder) Build() (string, error) {
return "", errors.New("query is required")
}

// separate headers, variables and udfs from the query
hr, query := SeparateHeadersAndQuery(b.query)
varsAndUDFs, query := SeparateVariablesUDFsAndQuery(query)
drops, query := SeparateDropsAndQuery(query)

if b.method == MERGE {
queries := semicolonPattern.Split(query, -1)
// split query components
hrs, vars, queries := SplitQueryComponents(b.query)
if len(queries) <= 1 {
return b.query, nil
}
query = b.constructMergeQuery(hr, drops, varsAndUDFs, queries)
query := b.constructMergeQuery(hrs, vars, queries)
return query, nil
}

// separate headers, variables and udfs from the query
hr, query := SeparateHeadersAndQuery(b.query)
varsAndUDFs, query := SeparateVariablesUDFsAndQuery(query)
drops, query := SeparateDropsAndQuery(query)

// destination table is required for append and replace method
if b.destinationTableID == "" {
return "", errors.New("destination table is required")
Expand Down Expand Up @@ -187,21 +188,20 @@ func (b *Builder) constructOverridedValues(query string) (string, error) {
}

// constructMergeQueries constructs merge queries with headers and variables
func (b *Builder) constructMergeQuery(hr, drops, varsAndUDFs string, queries []string) string {
func (b *Builder) constructMergeQuery(hrs, vars, queries []string) string {
builder := strings.Builder{}
if drops != "" {
builder.WriteString(fmt.Sprintf("%s\n", hr))
builder.WriteString(fmt.Sprintf("%s\n", drops))
builder.WriteString(fmt.Sprintf("%s\n", BREAK_MARKER))
}
for i, q := range queries {
q = strings.TrimSpace(q)
if q == "" || strings.TrimSpace(RemoveComments(q)) == "" {
continue
}
builder.WriteString(fmt.Sprintf("%s\n", hr))
if varsAndUDFs != "" {
builder.WriteString(fmt.Sprintf("%s\n", varsAndUDFs))
headers := JoinSliceString(hrs[:i+1], "\n")
variables := JoinSliceString(vars[:i+1], "\n")
if headers != "" {
builder.WriteString(fmt.Sprintf("%s\n", headers))
}
if variables != "" {
builder.WriteString(fmt.Sprintf("%s\n", variables))
}
builder.WriteString(fmt.Sprintf("%s\n;", q))
if i < len(queries)-1 {
Expand Down
82 changes: 48 additions & 34 deletions mc2mc/internal/query/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,16 +540,12 @@ SET append_test.id = 2;`
assert.NoError(t, err)
assert.Equal(t, `SET odps.table.append2.enable=true
;
@src := SELECT 1 id
;
CREATE TABLE IF NOT EXISTS append_test (id bigint)
TBLPROPERTIES('table.format.version'='2')
;
--*--optimus-break-marker--*--
SET odps.table.append2.enable=true
;
@src := SELECT 1 id
;
INSERT OVERWRITE TABLE append_test VALUES(0),(1)
;
--*--optimus-break-marker--*--
Expand Down Expand Up @@ -601,19 +597,6 @@ SET append_test.id = 2;`
assert.NoError(t, err)
assert.Equal(t, `SET odps.table.append2.enable=true
;
FUNCTION castStringToBoolean (@field STRING) AS CASE
WHEN TOLOWER(@field) = '1.0' THEN true
WHEN TOLOWER(@field) = '0.0' THEN false
WHEN TOLOWER(@field) = '1' THEN true
WHEN TOLOWER(@field) = '0' THEN false
WHEN TOLOWER(@field) = 'true' THEN true
WHEN TOLOWER(@field) = 'false' THEN false
END
;
function my_add(@a BIGINT) as @a + 1
;
@src := SELECT my_add(1) id
;
CREATE TABLE IF NOT EXISTS append_test (id bigint)
TBLPROPERTIES('table.format.version'='2')
;
Expand All @@ -631,8 +614,6 @@ END
;
function my_add(@a BIGINT) as @a + 1
;
@src := SELECT my_add(1) id
;
INSERT OVERWRITE TABLE append_test VALUES(0),(1)
;
--*--optimus-break-marker--*--
Expand Down Expand Up @@ -697,19 +678,6 @@ SET append_test.id = 2;`
assert.NoError(t, err)
assert.Equal(t, `SET odps.table.append2.enable=true
;
FUNCTION castStringToBoolean (@field STRING) AS CASE
WHEN TOLOWER(@field) = '1.0' THEN true
WHEN TOLOWER(@field) = '0.0' THEN false
WHEN TOLOWER(@field) = '1' THEN true
WHEN TOLOWER(@field) = '0' THEN false
WHEN TOLOWER(@field) = 'true' THEN true
WHEN TOLOWER(@field) = 'false' THEN false
END
;
function my_add(@a BIGINT) as @a + 1
;
@src := SELECT my_add(1) id
;
CREATE TABLE IF NOT EXISTS append_test (id bigint)
TBLPROPERTIES('table.format.version'='2')
;
Expand All @@ -727,8 +695,6 @@ END
;
function my_add(@a BIGINT) as @a + 1
;
@src := SELECT my_add(1) id
;
INSERT OVERWRITE TABLE append_test VALUES(0),(1)
;
--*--optimus-break-marker--*--
Expand All @@ -752,6 +718,54 @@ USING (SELECT castStringToBoolean(id) FROM @src) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2
;`, query)
})

t.Run("returns query for merge load method with proper variable ordering", func(t *testing.T) {
queryToExecute := `SET odps.table.append2.enable=true;
DROP TABLE IF EXISTS append_tmp;
@src := SELECT 1 id;

CREATE TABLE append_tmp AS SELECT * FROM @src;

@src2 := SELECT id FROM append_tmp;

MERGE INTO append_test
USING (SELECT * FROM @src2) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;`
odspClient := &mockOdpsClient{}
query, err := query.NewBuilder(
logger.NewDefaultLogger(),
odspClient,
query.WithQuery(queryToExecute),
query.WithMethod(query.MERGE),
).Build()
assert.NoError(t, err)
assert.Equal(t, `SET odps.table.append2.enable=true
;
DROP TABLE IF EXISTS append_tmp
;
--*--optimus-break-marker--*--
SET odps.table.append2.enable=true
;
@src := SELECT 1 id
;
CREATE TABLE append_tmp AS SELECT * FROM @src
;
--*--optimus-break-marker--*--
SET odps.table.append2.enable=true
;
@src := SELECT 1 id
;
@src2 := SELECT id FROM append_tmp
;
MERGE INTO append_test
USING (SELECT * FROM @src2) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2
;`, query)
})
}
Expand Down
64 changes: 64 additions & 0 deletions mc2mc/internal/query/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,70 @@ var (
stringPattern = regexp.MustCompile(`'[^']*'`) // regex to match SQL strings (anything inside single quotes)
)

func SplitQueryComponents(query string) (headers []string, varsUDFs []string, queries []string) {
query = strings.TrimSpace(query)

// extract all header, variable and query lines
stmts := semicolonPattern.Split(query, -1)
queryIndex := 0
for _, stmt := range stmts {
stmt = strings.TrimSpace(stmt)
if stmt == "" {
continue
}
stmtWithoutComment := RemoveComments(stmt)
if headerPattern.MatchString(strings.TrimSpace(stmtWithoutComment)) {
for len(headers) <= queryIndex {
headers = append(headers, "")
}
headers[queryIndex] += strings.TrimSpace(stmt) + "\n;\n"
} else if variablePattern.MatchString(strings.TrimSpace(stmtWithoutComment)) ||
udfPattern.MatchString(strings.TrimSpace(stmtWithoutComment)) {
for len(varsUDFs) <= queryIndex {
varsUDFs = append(varsUDFs, "")
}
varsUDFs[queryIndex] += strings.TrimSpace(stmt) + "\n;\n"
} else if strings.TrimSpace(stmtWithoutComment) == "" {
// if the statement is empty, it's a comment, then omit it
// since it doesn't make sense to execute this statement
} else {
queries = append(queries, stmt)
queryIndex++
}
}

// fill in empty headers and varsUDFs + clear whitespace
for i := range queries {
if len(headers) == i {
headers = append(headers, "")
}
if len(varsUDFs) == i {
varsUDFs = append(varsUDFs, "")
}
headers[i] = strings.TrimSpace(headers[i])
varsUDFs[i] = strings.TrimSpace(varsUDFs[i])
queries[i] = strings.TrimSpace(queries[i])
}

return headers, varsUDFs, queries
}

// JoinSliceString joins a slice of strings with a delimiter
// and skips empty strings
func JoinSliceString(slice []string, delimiter string) string {
builder := strings.Builder{}
for i, s := range slice {
if s == "" {
continue
}
if i > 0 {
builder.WriteString(delimiter)
}
builder.WriteString(s)
}
return strings.TrimSpace(builder.String())
}

func SeparateHeadersAndQuery(query string) (string, string) {
headers := []string{}
query = strings.TrimSpace(query)
Expand Down
114 changes: 114 additions & 0 deletions mc2mc/internal/query/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,120 @@ import (
"github.com/goto/transformers/mc2mc/internal/query"
)

func TestSplitQueryComponents(t *testing.T) {
t.Run("returns query without headers and variables", func(t *testing.T) {
q1 := `select * from playground`
headers, varsUDFs, queries := query.SplitQueryComponents(q1)
assert.Len(t, headers, 1)
assert.Len(t, varsUDFs, 1)
assert.Len(t, queries, 1)
assert.Empty(t, headers[0])
assert.Empty(t, varsUDFs[0])
assert.Equal(t, q1, queries[0])
})
t.Run("returns headers, vars, and queries with proper order", func(t *testing.T) {
q1 := `set odps.sql.allow.fullscan=true;
set odps.sql.python.version=cp37;
DROP TABLE IF EXISTS append_test_tmp;

@src := SELECT 1 id;
@src2 := SELECT id
FROM @src
WHERE id = 1;
CREATE TABLE append_test_tmp AS SELECT * FROM @src2;

MERGE INTO append_test_tmp USING (SELECT * FROM @src) source
on append_test_tmp.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test_tmp.id = 2;

@src3 := SELECT id FROM append_test_tmp WHERE id = 2;
MERGE INTO append_test USING (SELECT * FROM @src3) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 3;

MERGE INTO append_test USING (SELECT * FROM @src3) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 3;
`
headers, varsUDFs, queries := query.SplitQueryComponents(q1)
assert.Len(t, headers, 5)
assert.Len(t, varsUDFs, 5)
assert.Len(t, queries, 5)

// headers asserts
headersExpected := make([]string, 5)
headersExpected[0] = `set odps.sql.allow.fullscan=true
;
set odps.sql.python.version=cp37
;`
headersExpected[1] = ""
headersExpected[2] = ""
headersExpected[3] = ""
headersExpected[4] = ""

// vars asserts
varsExpected := make([]string, 5)
varsExpected[0] = ""
varsExpected[1] = `@src := SELECT 1 id
;
@src2 := SELECT id
FROM @src
WHERE id = 1
;`
varsExpected[2] = ""
varsExpected[3] = `@src3 := SELECT id FROM append_test_tmp WHERE id = 2
;`
varsExpected[4] = ""

// queries asserts
queriesExpected := make([]string, 5)
queriesExpected[0] = "DROP TABLE IF EXISTS append_test_tmp"
queriesExpected[1] = "CREATE TABLE append_test_tmp AS SELECT * FROM @src2"
queriesExpected[2] = `MERGE INTO append_test_tmp USING (SELECT * FROM @src) source
on append_test_tmp.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test_tmp.id = 2`
queriesExpected[3] = `MERGE INTO append_test USING (SELECT * FROM @src3) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 3`
queriesExpected[4] = `MERGE INTO append_test USING (SELECT * FROM @src3) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 3`

for i := range queries {
assert.Equal(t, headersExpected[i], headers[i])
assert.Equal(t, varsExpected[i], varsUDFs[i])
assert.Equal(t, queriesExpected[i], queries[i])
}
})
}

func TestJoinSliceString(t *testing.T) {
t.Run("returns empty string for empty slice", func(t *testing.T) {
slice := []string{}
delimiter := ";"
result := query.JoinSliceString(slice, delimiter)
assert.Empty(t, result)
})
t.Run("returns joined string with delimiter", func(t *testing.T) {
slice := []string{"set odps.sql.allow.fullscan=true", "set odps.sql.python.version=cp37"}
delimiter := ";"
result := query.JoinSliceString(slice, delimiter)
assert.Equal(t, "set odps.sql.allow.fullscan=true;set odps.sql.python.version=cp37", result)
})
t.Run("returns joined string with delimiter and skips empty strings", func(t *testing.T) {
slice := []string{"set odps.sql.allow.fullscan=true", "", "set odps.sql.python.version=cp37"}
delimiter := ";"
result := query.JoinSliceString(slice, delimiter)
assert.Equal(t, "set odps.sql.allow.fullscan=true;set odps.sql.python.version=cp37", result)
})
}

func TestSeparateHeadersAndQuery(t *testing.T) {
t.Run("returns query without macros", func(t *testing.T) {
q1 := `select * from playground`
Expand Down
Loading