diff --git a/mc2mc/internal/client/client.go b/mc2mc/internal/client/client.go index 6756e8f..e71ea4b 100644 --- a/mc2mc/internal/client/client.go +++ b/mc2mc/internal/client/client.go @@ -9,11 +9,14 @@ import ( "github.com/pkg/errors" ) +const ( + SqlScriptSequenceHint = "goto.sql.script.sequence" +) + type OdpsClient interface { - ExecSQL(ctx context.Context, query string) error + ExecSQL(ctx context.Context, query string, hints map[string]string) error SetDefaultProject(project string) SetLogViewRetentionInDays(days int) - SetAdditionalHints(hints map[string]string) SetDryRun(dryRun bool) } @@ -47,13 +50,21 @@ func (c *Client) Close() error { return errors.WithStack(err) } -func (c *Client) Execute(ctx context.Context, query string) error { - // execute query with odps client - c.logger.Info(fmt.Sprintf("query to execute:\n%s", query)) - if err := c.OdpsClient.ExecSQL(ctx, query); err != nil { - return errors.WithStack(err) - } +func (c *Client) ExecuteFn(id int) func(context.Context, string, map[string]string) error { + return func(ctx context.Context, query string, additionalHints map[string]string) error { + // execute query with odps client + c.logger.Info(fmt.Sprintf("[sequence: %d] query to execute:\n%s", id, query)) + // Merge additionalHints with the id + if additionalHints == nil { + additionalHints = make(map[string]string) + } + additionalHints[SqlScriptSequenceHint] = fmt.Sprintf("%d", id) - c.logger.Info("execution done") - return nil + if err := c.OdpsClient.ExecSQL(ctx, query, additionalHints); err != nil { + return errors.WithStack(err) + } + + c.logger.Info(fmt.Sprintf("[sequence: %d] execution done", id)) + return nil + } } diff --git a/mc2mc/internal/client/odps.go b/mc2mc/internal/client/odps.go index 387431a..a483f1b 100644 --- a/mc2mc/internal/client/odps.go +++ b/mc2mc/internal/client/odps.go @@ -17,7 +17,6 @@ type odpsClient struct { client *odps.Odps logViewRetentionInDays int - additionalHints map[string]string isDryRun bool } @@ -33,12 +32,14 @@ func NewODPSClient(logger *slog.Logger, client *odps.Odps) *odpsClient { // ExecSQL executes the given query in syncronous mode (blocking) // with capability to do graceful shutdown by terminating task instance // when context is cancelled. -func (c *odpsClient) ExecSQL(ctx context.Context, query string) error { +func (c *odpsClient) ExecSQL(ctx context.Context, query string, additionalHints map[string]string) error { if c.isDryRun { c.logger.Info("dry run mode, skipping execution") return nil } - hints := addHints(c.additionalHints, query) + + hints := addHints(additionalHints, query) + taskIns, err := c.client.ExecSQlWithHints(query, hints) if err != nil { return errors.WithStack(err) @@ -50,24 +51,19 @@ func (c *odpsClient) ExecSQL(ctx context.Context, query string) error { err = e.Join(err, taskIns.Terminate()) return errors.WithStack(err) } - c.logger.Info(fmt.Sprintf("taskId: %s, log view: %s", taskIns.Id(), url)) + c.logger.Info(fmt.Sprintf("taskId: %s, log view: %s, hints: (%s)", taskIns.Id(), url, getHintsString(hints))) // wait execution success select { case <-ctx.Done(): c.logger.Info("context cancelled, terminating task instance") - err := taskIns.Terminate() + err := c.terminate(taskIns) return e.Join(ctx.Err(), err) case err := <-c.wait(taskIns): return errors.WithStack(err) } } -// SetAdditionalHints sets the additional hints for the odps client -func (c *odpsClient) SetAdditionalHints(hints map[string]string) { - c.additionalHints = hints -} - // SetLogViewRetentionInDays sets the log view retention in days func (c *odpsClient) SetLogViewRetentionInDays(days int) { c.logViewRetentionInDays = days @@ -217,3 +213,14 @@ func retry(l *slog.Logger, retryMax int, retryBackoffMs int64, f func() error) e return err } + +func getHintsString(hints map[string]string) string { + if hints == nil { + return "" + } + var hintsStr []string + for k, v := range hints { + hintsStr = append(hintsStr, fmt.Sprintf("%s: %s", k, v)) + } + return strings.Join(hintsStr, ", ") +} diff --git a/mc2mc/internal/client/setup.go b/mc2mc/internal/client/setup.go index b60e841..0c0dfaa 100644 --- a/mc2mc/internal/client/setup.go +++ b/mc2mc/internal/client/setup.go @@ -20,18 +20,6 @@ func SetupDryRun(dryRun bool) SetupFn { } } -func SetupAdditionalHints(hints map[string]string) SetupFn { - return func(c *Client) error { - if c.OdpsClient == nil { - return errors.New("odps client is required") - } - if hints != nil { - c.OdpsClient.SetAdditionalHints(hints) - } - return nil - } -} - func SetUpLogViewRetentionInDays(days int) SetupFn { return func(c *Client) error { if c.OdpsClient == nil { diff --git a/mc2mc/mc2mc.go b/mc2mc/mc2mc.go index b4b9ae0..6e41877 100644 --- a/mc2mc/mc2mc.go +++ b/mc2mc/mc2mc.go @@ -44,7 +44,6 @@ func mc2mc(envs []string) error { client.SetupODPSClient(cfg.GenOdps()), client.SetupDefaultProject(cfg.ExecutionProject), client.SetUpLogViewRetentionInDays(cfg.LogViewRetentionInDays), - client.SetupAdditionalHints(cfg.AdditionalHints), client.SetupDryRun(cfg.DryRun), ) if err != nil { @@ -162,23 +161,24 @@ func mc2mc(envs []string) error { // only support concurrent execution for REPLACE method if cfg.LoadMethod == "REPLACE" { - return executeConcurrently(ctx, c, cfg.Concurrency, queriesToExecute) + return executeConcurrently(ctx, c, cfg.Concurrency, queriesToExecute, cfg.AdditionalHints) } // otherwise execute sequentially - return execute(ctx, c, queriesToExecute) + return execute(ctx, c, queriesToExecute, cfg.AdditionalHints) } -func executeConcurrently(ctx context.Context, c *client.Client, concurrency int, queriesToExecute []string) error { +func executeConcurrently(ctx context.Context, c *client.Client, concurrency int, queriesToExecute []string, additionalHints map[string]string) error { // execute query concurrently sem := make(chan uint8, concurrency) wg := sync.WaitGroup{} wg.Add(len(queriesToExecute)) errChan := make(chan error, len(queriesToExecute)) - for _, queryToExecute := range queriesToExecute { + for i, queryToExecute := range queriesToExecute { sem <- 0 + executeFn := c.ExecuteFn(i + 1) go func(queryToExecute string, errChan chan error) { - err := c.Execute(ctx, queryToExecute) + err := executeFn(ctx, queryToExecute, additionalHints) if err != nil { errChan <- errors.WithStack(err) } @@ -200,9 +200,10 @@ func executeConcurrently(ctx context.Context, c *client.Client, concurrency int, return errs } -func execute(ctx context.Context, c *client.Client, queriesToExecute []string) error { - for _, queryToExecute := range queriesToExecute { - err := c.Execute(ctx, queryToExecute) +func execute(ctx context.Context, c *client.Client, queriesToExecute []string, additionalHints map[string]string) error { + for i, queryToExecute := range queriesToExecute { + executeFn := c.ExecuteFn(i + 1) + err := executeFn(ctx, queryToExecute, additionalHints) if err != nil { return errors.WithStack(err) }