diff --git a/mc2mc/internal/client/client.go b/mc2mc/internal/client/client.go
index ff21ce4..25a70ff 100644
--- a/mc2mc/internal/client/client.go
+++ b/mc2mc/internal/client/client.go
@@ -19,6 +19,7 @@ type Loader interface {
 }
 
 type OdpsClient interface {
+	GetOrderedColumns(tableID string) ([]string, error)
 	GetPartitionNames(ctx context.Context, tableID string) ([]string, error)
 	ExecSQL(ctx context.Context, query string) error
 }
@@ -65,6 +66,17 @@ func (c *Client) Execute(ctx context.Context, tableID, queryFilePath string) err
 	if err != nil {
 		return errors.WithStack(err)
 	}
+
+	// fetch the column names of tableID from its schema
+	if tableID != "" {
+		columnNames, err := c.OdpsClient.GetOrderedColumns(tableID)
+		if err != nil {
+			return errors.WithStack(err)
+		}
+		// wrap the query so it selects exactly those columns, in that order
+		queryRaw = constructQueryWithOrderedColumns(queryRaw, columnNames)
+	}
+
 	if c.enablePartitionValue && !c.enableAutoPartition {
 		queryRaw = addPartitionValueColumn(queryRaw)
 	}
@@ -98,3 +110,11 @@ func addPartitionValueColumn(rawQuery []byte) []byte {
 	header, qr := loader.SeparateHeadersAndQuery(string(rawQuery))
 	return []byte(fmt.Sprintf("%s SELECT *, STRING(CURRENT_DATE()) as __partitionvalue FROM (%s)", header, qr))
 }
+
+// constructQueryWithOrderedColumns splits query into its headers and statement,
+// wraps the statement in a SELECT over orderedColumns, and re-attaches the
+// headers in front of the result.
+func constructQueryWithOrderedColumns(query []byte, orderedColumns []string) []byte {
+	header, qr := loader.SeparateHeadersAndQuery(string(query))
+	return []byte(fmt.Sprintf("%s %s", header, loader.ConstructQueryWithOrderedColumns(qr, orderedColumns)))
+}
diff --git a/mc2mc/internal/client/client_test.go b/mc2mc/internal/client/client_test.go
index 7d64252..de64f99 100644
--- a/mc2mc/internal/client/client_test.go
+++ b/mc2mc/internal/client/client_test.go
@@ -23,11 +23,30 @@ func TestExecute(t *testing.T) {
 		// assert
 		assert.Error(t, err)
 	})
+	t.Run("should return error when getting ordered columns fails", func(t *testing.T) {
+		// arrange
+		client, err := client.NewClient(context.TODO(), client.SetupLogger("error"))
+		require.NoError(t, err)
+		client.OdpsClient = &mockOdpsClient{
+			orderedColumns: func() ([]string, error) {
+				return nil, fmt.Errorf("error get ordered columns")
+			},
+		}
+		require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644))
+		// act
+		err = client.Execute(context.TODO(), "project_test.table_test", "/tmp/query.sql")
+		// assert
+		assert.Error(t, err)
+		assert.ErrorContains(t, err, "error get ordered columns")
+	})
 	t.Run("should return error when getting partition name fails", func(t *testing.T) {
 		// arrange
 		client, err := client.NewClient(context.TODO(), client.SetupLogger("error"))
 		require.NoError(t, err)
 		client.OdpsClient = &mockOdpsClient{
+			orderedColumns: func() ([]string, error) {
+				return []string{"col1", "col2"}, nil
+			},
 			partitionResult: func() ([]string, error) {
 				return nil, fmt.Errorf("error get partition name")
 			},
@@ -44,6 +63,9 @@ func TestExecute(t *testing.T) {
 		client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("APPEND"))
 		require.NoError(t, err)
 		client.OdpsClient = &mockOdpsClient{
+			orderedColumns: func() ([]string, error) {
+				return []string{"col1", "col2"}, nil
+			},
 			partitionResult: func() ([]string, error) {
 				return nil, nil
 			},
@@ -63,6 +85,9 @@ func TestExecute(t *testing.T) {
 		client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("REPLACE"))
 		require.NoError(t, err)
 		client.OdpsClient = &mockOdpsClient{
+			orderedColumns: func() ([]string, error) {
+				return []string{"col1", "col2"}, nil
+			},
 			partitionResult: func() ([]string, error) {
 				return []string{"event_date"}, nil
 			},
@@ -72,11 +97,11 @@ func TestExecute(t *testing.T) {
 		}
 		client.Loader = &mockLoader{
 			getQueryFunc: func(tableID, query string) string {
-				return "INSERT OVERWRITE TABLE project_test.table_test SELECT * FROM table;"
+				return "INSERT OVERWRITE TABLE project_test.table_test SELECT col1, col2 FROM (SELECT * FROM table);"
 			},
 			getPartitionedQueryFunc: func(tableID, query string, partitionNames []string) string {
 				assert.True(t, true, "should be called")
-				return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(event_date) SELECT * FROM table;"
+				return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(event_date) SELECT col1, col2 FROM (SELECT * FROM table);"
 			},
 		}
 		require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644))
@@ -90,6 +115,9 @@ func TestExecute(t *testing.T) {
 		client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("REPLACE"), client.EnableAutoPartition(true))
 		require.NoError(t, err)
 		client.OdpsClient = &mockOdpsClient{
+			orderedColumns: func() ([]string, error) {
+				return []string{"col1", "col2"}, nil
+			},
 			partitionResult: func() ([]string, error) {
 				return []string{"_partition_value"}, nil
 			},
@@ -99,11 +127,11 @@ func TestExecute(t *testing.T) {
 		}
 		client.Loader = &mockLoader{
 			getQueryFunc: func(tableID, query string) string {
-				return "INSERT OVERWRITE TABLE project_test.table_test SELECT * FROM table;"
+				return "INSERT OVERWRITE TABLE project_test.table_test SELECT col1, col2 FROM (SELECT * FROM table);"
 			},
 			getPartitionedQueryFunc: func(tableID, query string, partitionNames []string) string {
 				assert.False(t, true, "should not be called")
-				return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(event_date) SELECT * FROM table;"
+				return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(_partition_value) SELECT col1, col2 FROM (SELECT * FROM table);"
 			},
 		}
 		require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644))
@@ -117,6 +145,7 @@ func TestExecute(t *testing.T) {
 type mockOdpsClient struct {
 	partitionResult func() ([]string, error)
 	execSQLResult   func() error
+	orderedColumns  func() ([]string, error)
 }
 
 func (m *mockOdpsClient) GetPartitionNames(ctx context.Context, tableID string) ([]string, error) {
@@ -127,6 +156,10 @@ func (m *mockOdpsClient) ExecSQL(ctx context.Context, query string) error {
 	return m.execSQLResult()
 }
 
+func (m *mockOdpsClient) GetOrderedColumns(tableID string) ([]string, error) {
+	return m.orderedColumns()
+}
+
 type mockLoader struct {
 	getQueryFunc            func(tableID, query string) string
 	getPartitionedQueryFunc func(tableID, query string, partitionNames []string) string
diff --git a/mc2mc/internal/client/odps.go b/mc2mc/internal/client/odps.go
index 320e20c..95d1d68 100644
--- a/mc2mc/internal/client/odps.go
+++ b/mc2mc/internal/client/odps.go
@@ -70,9 +70,33 @@ func (c *odpsClient) GetPartitionNames(_ context.Context, tableID string) ([]str
 	for _, partition := range table.Schema().PartitionColumns {
 		partitionNames = append(partitionNames, partition.Name)
 	}
+
 	return partitionNames, nil
 }
 
+// GetOrderedColumns loads the table identified by tableID, which must have the
+// form project.schema.table, and returns its column names in the order they
+// are listed by table.Schema().Columns.
+// NOTE(review): partition columns are read from Schema().PartitionColumns in
+// GetPartitionNames; confirm whether they also need to be projected here.
+func (c *odpsClient) GetOrderedColumns(tableID string) ([]string, error) {
+	splittedTableID := strings.Split(tableID, ".")
+	if len(splittedTableID) != 3 {
+		return nil, errors.Errorf("invalid tableID (tableID should be in format project.schema.table): %s", tableID)
+	}
+	project, schema, name := splittedTableID[0], splittedTableID[1], splittedTableID[2]
+	table := odps.NewTable(c.client, project, schema, name)
+	if err := table.Load(); err != nil {
+		return nil, errors.WithStack(err)
+	}
+	var columnNames []string
+	for _, column := range table.Schema().Columns {
+		columnNames = append(columnNames, column.Name)
+	}
+
+	return columnNames, nil
+}
+
 // wait waits for the task instance to finish on a separate goroutine
 func wait(taskIns *odps.Instance) <-chan error {
 	errChan := make(chan error)
diff --git a/mc2mc/internal/loader/helper.go b/mc2mc/internal/loader/helper.go
index 7d574e2..bc69d64 100644
--- a/mc2mc/internal/loader/helper.go
+++ b/mc2mc/internal/loader/helper.go
@@ -1,6 +1,7 @@
 package loader
 
 import (
+	"fmt"
 	"strings"
 )
 
@@ -23,3 +24,14 @@ func SeparateHeadersAndQuery(query string) (string, string) {
 	}
 	return headers, last
 }
+
+// ConstructQueryWithOrderedColumns wraps query as a subquery of a SELECT that
+// lists orderedColumns, comma-separated, in the given order. When
+// orderedColumns is empty it returns query unchanged, because an empty
+// projection would produce the invalid statement "SELECT  FROM (...)".
+func ConstructQueryWithOrderedColumns(query string, orderedColumns []string) string {
+	if len(orderedColumns) == 0 {
+		return query
+	}
+	return fmt.Sprintf("SELECT %s FROM (%s)", strings.Join(orderedColumns, ", "), query)
+}
diff --git a/mc2mc/internal/loader/helper_test.go b/mc2mc/internal/loader/helper_test.go
index a54aaf0..294d50d 100644
--- a/mc2mc/internal/loader/helper_test.go
+++ b/mc2mc/internal/loader/helper_test.go
@@ -56,3 +56,18 @@ where CAST(event_timestamp as DATE) = '{{ .DSTART | Date }}'
 		assert.Contains(t, query, expectedQuery)
 	})
 }
+
+func TestConstructQueryWithOrderedColumns(t *testing.T) {
+	t.Run("returns query with ordered columns", func(t *testing.T) {
+		q1 := `select col_2 as col2, col_3 as col3, col_1 as col1 from project.schema.table`
+		orderedColumns := []string{"col1", "col2", "col3"}
+		query := loader.ConstructQueryWithOrderedColumns(q1, orderedColumns)
+		expected := "SELECT col1, col2, col3 FROM (select col_2 as col2, col_3 as col3, col_1 as col1 from project.schema.table)"
+		assert.Equal(t, expected, query)
+	})
+	t.Run("returns query unchanged when there are no ordered columns", func(t *testing.T) {
+		q1 := `select col_1 as col1 from project.schema.table`
+		query := loader.ConstructQueryWithOrderedColumns(q1, nil)
+		assert.Equal(t, q1, query)
+	})
+}