Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions mc2mc/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ import (
"github.com/pkg/errors"
)

const (
// SqlScriptSequenceHint is the hint key under which ExecuteFn records the
// sequence id of the query it executes (formatted as a decimal string).
SqlScriptSequenceHint = "goto.sql.script.sequence"
)

type OdpsClient interface {
ExecSQL(ctx context.Context, query string) error
ExecSQL(ctx context.Context, query string, hints map[string]string) error
SetDefaultProject(project string)
SetLogViewRetentionInDays(days int)
SetAdditionalHints(hints map[string]string)
SetDryRun(dryRun bool)
}

Expand Down Expand Up @@ -47,13 +50,21 @@ func (c *Client) Close() error {
return errors.WithStack(err)
}

func (c *Client) Execute(ctx context.Context, query string) error {
// execute query with odps client
c.logger.Info(fmt.Sprintf("query to execute:\n%s", query))
if err := c.OdpsClient.ExecSQL(ctx, query); err != nil {
return errors.WithStack(err)
}
func (c *Client) ExecuteFn(id int) func(context.Context, string, map[string]string) error {
return func(ctx context.Context, query string, additionalHints map[string]string) error {
// execute query with odps client
c.logger.Info(fmt.Sprintf("[sequence: %d] query to execute:\n%s", id, query))
// Merge additionalHints with the id
if additionalHints == nil {
additionalHints = make(map[string]string)
}
additionalHints[SqlScriptSequenceHint] = fmt.Sprintf("%d", id)

c.logger.Info("execution done")
return nil
if err := c.OdpsClient.ExecSQL(ctx, query, additionalHints); err != nil {
return errors.WithStack(err)
}

c.logger.Info(fmt.Sprintf("[sequence: %d] execution done", id))
return nil
}
}
27 changes: 17 additions & 10 deletions mc2mc/internal/client/odps.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ type odpsClient struct {
client *odps.Odps

logViewRetentionInDays int
additionalHints map[string]string
isDryRun bool
}

Expand All @@ -33,12 +32,14 @@ func NewODPSClient(logger *slog.Logger, client *odps.Odps) *odpsClient {
// ExecSQL executes the given query in synchronous mode (blocking)
// with capability to do graceful shutdown by terminating task instance
// when context is cancelled.
func (c *odpsClient) ExecSQL(ctx context.Context, query string) error {
func (c *odpsClient) ExecSQL(ctx context.Context, query string, additionalHints map[string]string) error {
if c.isDryRun {
c.logger.Info("dry run mode, skipping execution")
return nil
}
hints := addHints(c.additionalHints, query)

hints := addHints(additionalHints, query)

taskIns, err := c.client.ExecSQlWithHints(query, hints)
if err != nil {
return errors.WithStack(err)
Expand All @@ -50,24 +51,19 @@ func (c *odpsClient) ExecSQL(ctx context.Context, query string) error {
err = e.Join(err, taskIns.Terminate())
return errors.WithStack(err)
}
c.logger.Info(fmt.Sprintf("taskId: %s, log view: %s", taskIns.Id(), url))
c.logger.Info(fmt.Sprintf("taskId: %s, log view: %s, hints: (%s)", taskIns.Id(), url, getHintsString(hints)))

// wait execution success
select {
case <-ctx.Done():
c.logger.Info("context cancelled, terminating task instance")
err := taskIns.Terminate()
err := c.terminate(taskIns)
return e.Join(ctx.Err(), err)
case err := <-c.wait(taskIns):
return errors.WithStack(err)
}
}

// SetAdditionalHints sets the additional hints for the odps client
func (c *odpsClient) SetAdditionalHints(hints map[string]string) {
c.additionalHints = hints
}

// SetLogViewRetentionInDays sets the log view retention in days
func (c *odpsClient) SetLogViewRetentionInDays(days int) {
c.logViewRetentionInDays = days
Expand Down Expand Up @@ -217,3 +213,14 @@ func retry(l *slog.Logger, retryMax int, retryBackoffMs int64, f func() error) e

return err
}

// getHintsString renders hints as a single string of "key: value" pairs
// separated by ", ". A nil or empty map produces the empty string. Pairs are
// emitted in map iteration order, so the ordering is not stable across calls.
func getHintsString(hints map[string]string) string {
	var out strings.Builder
	for name, value := range hints {
		// Every pair writes at least ": ", so a non-empty builder means
		// a previous pair exists and a separator is needed.
		if out.Len() > 0 {
			out.WriteString(", ")
		}
		fmt.Fprintf(&out, "%s: %s", name, value)
	}
	return out.String()
}
12 changes: 0 additions & 12 deletions mc2mc/internal/client/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@ func SetupDryRun(dryRun bool) SetupFn {
}
}

// SetupAdditionalHints returns a SetupFn that passes hints to the client's
// OdpsClient through SetAdditionalHints. The returned function reports
// "odps client is required" when the OdpsClient is nil, and does nothing
// when hints is nil.
func SetupAdditionalHints(hints map[string]string) SetupFn {
	return func(c *Client) error {
		switch {
		case c.OdpsClient == nil:
			return errors.New("odps client is required")
		case hints != nil:
			c.OdpsClient.SetAdditionalHints(hints)
		}
		return nil
	}
}

func SetUpLogViewRetentionInDays(days int) SetupFn {
return func(c *Client) error {
if c.OdpsClient == nil {
Expand Down
19 changes: 10 additions & 9 deletions mc2mc/mc2mc.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ func mc2mc(envs []string) error {
client.SetupODPSClient(cfg.GenOdps()),
client.SetupDefaultProject(cfg.ExecutionProject),
client.SetUpLogViewRetentionInDays(cfg.LogViewRetentionInDays),
client.SetupAdditionalHints(cfg.AdditionalHints),
client.SetupDryRun(cfg.DryRun),
)
if err != nil {
Expand Down Expand Up @@ -162,23 +161,24 @@ func mc2mc(envs []string) error {

// only support concurrent execution for REPLACE method
if cfg.LoadMethod == "REPLACE" {
return executeConcurrently(ctx, c, cfg.Concurrency, queriesToExecute)
return executeConcurrently(ctx, c, cfg.Concurrency, queriesToExecute, cfg.AdditionalHints)
}
// otherwise execute sequentially
return execute(ctx, c, queriesToExecute)
return execute(ctx, c, queriesToExecute, cfg.AdditionalHints)
}

func executeConcurrently(ctx context.Context, c *client.Client, concurrency int, queriesToExecute []string) error {
func executeConcurrently(ctx context.Context, c *client.Client, concurrency int, queriesToExecute []string, additionalHints map[string]string) error {
// execute query concurrently
sem := make(chan uint8, concurrency)
wg := sync.WaitGroup{}
wg.Add(len(queriesToExecute))
errChan := make(chan error, len(queriesToExecute))

for _, queryToExecute := range queriesToExecute {
for i, queryToExecute := range queriesToExecute {
sem <- 0
executeFn := c.ExecuteFn(i + 1)
go func(queryToExecute string, errChan chan error) {
err := c.Execute(ctx, queryToExecute)
err := executeFn(ctx, queryToExecute, additionalHints)
if err != nil {
errChan <- errors.WithStack(err)
}
Expand All @@ -200,9 +200,10 @@ func executeConcurrently(ctx context.Context, c *client.Client, concurrency int,
return errs
}

func execute(ctx context.Context, c *client.Client, queriesToExecute []string) error {
for _, queryToExecute := range queriesToExecute {
err := c.Execute(ctx, queryToExecute)
func execute(ctx context.Context, c *client.Client, queriesToExecute []string, additionalHints map[string]string) error {
for i, queryToExecute := range queriesToExecute {
executeFn := c.ExecuteFn(i + 1)
err := executeFn(ctx, queryToExecute, additionalHints)
if err != nil {
return errors.WithStack(err)
}
Expand Down
Loading