diff --git a/mc2mc/mc2mc.go b/mc2mc/mc2mc.go index d42c4f8..98ebb69 100644 --- a/mc2mc/mc2mc.go +++ b/mc2mc/mc2mc.go @@ -4,6 +4,7 @@ import ( "context" e "errors" "fmt" + "log/slog" "os" "os/signal" "strings" @@ -162,35 +163,53 @@ func mc2mc(envs []string) error { // only support concurrent execution for REPLACE method if cfg.LoadMethod == "REPLACE" { - return executeConcurrently(ctx, c, cfg.Concurrency, queriesToExecute, cfg.AdditionalHints) + return executeConcurrently(ctx, l, c, cfg.Concurrency, queriesToExecute, cfg.AdditionalHints) } // otherwise execute sequentially - return execute(ctx, c, queriesToExecute, cfg.AdditionalHints) + return execute(ctx, l, c, queriesToExecute, cfg.AdditionalHints) } -func executeConcurrently(ctx context.Context, c *client.Client, concurrency int, queriesToExecute []string, additionalHints map[string]string) error { +func executeConcurrently(ctx context.Context, l *slog.Logger, c *client.Client, concurrency int, queriesToExecute []string, additionalHints map[string]string) error { // execute query concurrently sem := make(chan uint8, concurrency) wg := sync.WaitGroup{} wg.Add(len(queriesToExecute)) errChan := make(chan error, len(queriesToExecute)) + ids := sync.Map{} // id to boolean map to track running ids for i, queryToExecute := range queriesToExecute { sem <- 0 - executeFn := c.ExecuteFn(i + 1) - go func(queryToExecute string, errChan chan error) { + id := i + 1 + ids.Store(id, false) + executeFn := c.ExecuteFn(id) + go func(id int, queryToExecute string, errChan chan error) { + defer func() { + wg.Done() + <-sem + ids.Delete(id) + // logs the remaining running ids + remainingIds := []int{} + ids.Range(func(key, value any) bool { + remainingIds = append(remainingIds, key.(int)) + return true + }) + if len(remainingIds) > 0 { + l.Info(fmt.Sprintf("remaining running ids: %v", remainingIds)) + l.Info(fmt.Sprintf("waiting for %d other queries to finish...", len(remainingIds))) + } + }() err := executeFn(ctx, queryToExecute, additionalHints) if err != nil { errChan <- errors.WithStack(err) } - wg.Done() - <-sem - }(queryToExecute, errChan) + }(id, queryToExecute, errChan) } wg.Wait() close(errChan) + l.Info("all queries have been processed") + // check error var errs error for err := range errChan { @@ -201,14 +220,17 @@ func executeConcurrently(ctx context.Context, c *client.Client, concurrency int, return errs } -func execute(ctx context.Context, c *client.Client, queriesToExecute []string, additionalHints map[string]string) error { +func execute(ctx context.Context, l *slog.Logger, c *client.Client, queriesToExecute []string, additionalHints map[string]string) error { for i, queryToExecute := range queriesToExecute { + l.Info(fmt.Sprintf("processing query %d of %d", i+1, len(queriesToExecute))) executeFn := c.ExecuteFn(i + 1) err := executeFn(ctx, queryToExecute, additionalHints) if err != nil { return errors.WithStack(err) } } + + l.Info("all queries have been processed") return nil }