From ac82256426af7ada02605f4ad22abd47843580b8 Mon Sep 17 00:00:00 2001 From: Matt Spilchen Date: Tue, 4 Nov 2025 16:23:36 -0400 Subject: [PATCH] util/taskset: add generic work distribution mechanism Adds a new taskset package that provides task coordination for a number of workers. TaskSet hands out integer identifiers (TaskIDs) that workers claim and process, with the caller responsible for mapping TaskIDs to actual work (file indices, key ranges, batches, etc.). Features round-robin initial distribution across workers and locality when getting a task ID. Not thread-safe; callers provide external synchronization. This will be used by the new distributed merge pipeline. Closes #156578 Epic: CRDB-48845 Release note: none Co-authored by: @jeffswenson --- pkg/BUILD.bazel | 3 + pkg/util/taskset/BUILD.bazel | 22 +++ pkg/util/taskset/task_set.go | 178 +++++++++++++++++++++ pkg/util/taskset/task_set_test.go | 254 ++++++++++++++++++++++++++++++ pkg/util/taskset/task_span.go | 15 ++ 5 files changed, 472 insertions(+) create mode 100644 pkg/util/taskset/BUILD.bazel create mode 100644 pkg/util/taskset/task_set.go create mode 100644 pkg/util/taskset/task_set_test.go create mode 100644 pkg/util/taskset/task_span.go diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 77b910fb7458..6cde76cbbb93 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -825,6 +825,7 @@ ALL_TESTS = [ "//pkg/util/syncutil:syncutil_test", "//pkg/util/sysutil:sysutil_test", "//pkg/util/taskpacer:taskpacer_test", + "//pkg/util/taskset:taskset_test", "//pkg/util/timeofday:timeofday_test", "//pkg/util/timetz:timetz_test", "//pkg/util/timeutil/pgdate:pgdate_test", @@ -2820,6 +2821,8 @@ GO_TARGETS = [ "//pkg/util/sysutil:sysutil_test", "//pkg/util/taskpacer:taskpacer", "//pkg/util/taskpacer:taskpacer_test", + "//pkg/util/taskset:taskset", + "//pkg/util/taskset:taskset_test", "//pkg/util/timeofday:timeofday", "//pkg/util/timeofday:timeofday_test", "//pkg/util/timetz:timetz", diff --git a/pkg/util/taskset/BUILD.bazel b/pkg/util/taskset/BUILD.bazel new file mode 100644 index 000000000000..2bca4d7e367d --- /dev/null +++ b/pkg/util/taskset/BUILD.bazel @@ -0,0 +1,22 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "taskset", + srcs = [ + "task_set.go", + "task_span.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/util/taskset", + visibility = ["//visibility:public"], +) + +go_test( + name = "taskset_test", + srcs = ["task_set_test.go"], + embed = [":taskset"], + deps = [ + "//pkg/util/leaktest", + "//pkg/util/log", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/util/taskset/task_set.go b/pkg/util/taskset/task_set.go new file mode 100644 index 000000000000..4cd2d24bb457 --- /dev/null +++ b/pkg/util/taskset/task_set.go @@ -0,0 +1,178 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +// Package taskset provides a generic work distribution mechanism for +// coordinating parallel workers. TaskSet hands out integer identifiers +// (TaskIDs) that workers can claim and process. The TaskIDs themselves have no +// inherent meaning - it's up to the caller to map each TaskID to actual work +// (e.g., file indices, key ranges, batch numbers, etc.). +// +// Example usage: +// +// tasks := taskset.MakeTaskSet(100, 4) // 100 work items, 4 workers +// +// // Worker goroutine +// for taskID := tasks.ClaimFirst(); !taskID.IsDone(); taskID = tasks.ClaimNext(taskID) { +// // Map taskID to actual work +// processFile(files[taskID]) +// // or: processKeyRange(splits[taskID], splits[taskID+1]) +// // or: processBatch(taskID*batchSize, (taskID+1)*batchSize) +// } +package taskset + +// TaskID is an abstract integer identifier for a unit of work. The TaskID +// itself has no inherent meaning - callers decide what each TaskID represents +// (e.g., which file to process, which key range to handle, etc.). +type TaskID int64 + +// taskIDDone is an internal sentinel value indicating no more tasks are available. +// Use TaskID.IsDone() to check if a task is done. +const taskIDDone = TaskID(-1) + +func (t TaskID) IsDone() bool { + return t == taskIDDone +} + +// MakeTaskSet creates a new TaskSet with taskCount work items numbered 0 +// through taskCount-1, pre-split for the expected number of workers. +// +// The TaskIDs are abstract identifiers with no inherent meaning - the caller +// decides what each TaskID represents. For example: +// - File processing: MakeTaskSet(100, 4) with TaskID N → files[N] +// - Key ranges: MakeTaskSet(100, 4) with TaskID N → range [splits[N-1], splits[N]) +// - Row batches: MakeTaskSet(100, 4) with TaskID N → rows [N*1000, (N+1)*1000) +// +// The numWorkers parameter enables better initial load balancing by dividing the +// task range into numWorkers equal spans upfront. For example, with 100 tasks +// and 4 workers: +// - Worker 1: starts with task 0 from range [0, 25) +// - Worker 2: starts with task 25 from range [25, 50) +// - Worker 3: starts with task 50 from range [50, 75) +// - Worker 4: starts with task 75 from range [75, 100) +// +// Each worker continues claiming sequential tasks from their region (maintaining +// locality), and can steal from other regions if they finish early. +// +// If the number of workers is unknown, use numWorkers=1 for a single span. +func MakeTaskSet(taskCount, numWorkers int64) TaskSet { + if numWorkers <= 0 { + numWorkers = 1 + } + if taskCount <= 0 { + return TaskSet{unassigned: nil} + } + + // Pre-split the task range into numWorkers equal spans + spans := make([]taskSpan, 0, numWorkers) + tasksPerWorker := taskCount / numWorkers + remainder := taskCount % numWorkers + + start := TaskID(0) + for i := int64(0); i < numWorkers; i++ { + // Distribute remainder evenly by giving first 'remainder' workers one extra task + spanSize := tasksPerWorker + if i < remainder { + spanSize++ + } + if spanSize > 0 { + end := start + TaskID(spanSize) + spans = append(spans, taskSpan{start: start, end: end}) + start = end + } + } + + return TaskSet{unassigned: spans} +} + +// TaskSet is a generic work distribution coordinator that manages a collection +// of abstract task identifiers (TaskIDs) that can be claimed by workers. +// +// TaskSet implements a work-stealing algorithm optimized for task locality: +// - When a worker completes task N, it tries to claim task N+1 (sequential locality) +// - If task N+1 is unavailable, it falls back to round-robin claiming from the first span +// - This balances load across workers while maintaining locality within each worker +// +// The TaskIDs themselves are just integers (0 through taskCount-1) with no +// inherent meaning. Callers map these identifiers to actual work units such as: +// - File indices (TaskID 5 → process files[5]) +// - Key ranges (TaskID 5 → process range [splits[4], splits[5])) +// - Batch numbers (TaskID 5 → process rows [5000, 6000)) +// +// TaskSet is NOT safe for concurrent use. Callers must ensure external +// synchronization if the TaskSet is accessed from multiple goroutines. +type TaskSet struct { + unassigned []taskSpan +} + +// ClaimFirst should be called when a worker claims its first task. It returns +// an abstract TaskID to process. The caller decides what this TaskID represents +// (e.g., which file to process, which key range to handle). Returns a TaskID +// where .IsDone() is true if no tasks are available. +// +// ClaimFirst is distinct from ClaimNext because ClaimFirst will always take +// from the first span and rotate it to the end (round-robin), whereas ClaimNext +// tries to claim the next sequential task for locality. +func (t *TaskSet) ClaimFirst() TaskID { + if len(t.unassigned) == 0 { + return taskIDDone + } + + // Take the first task from the first span, then rotate that span to the end. + // This provides round-robin distribution, ensuring each worker gets tasks + // from different regions initially for better load balancing. + span := t.unassigned[0] + if span.size() == 0 { + return taskIDDone + } + + task := span.start + span.start += 1 + + if span.size() == 0 { + // Span is exhausted, remove it + t.removeSpan(0) + } else { + // Move the span to the end for round-robin distribution + t.unassigned = append(t.unassigned[1:], span) + } + + return task +} + +// ClaimNext should be called when a worker has completed its current task. It +// returns the next abstract TaskID to process. The caller decides what this +// TaskID represents. Returns a TaskID where .IsDone() is true if no tasks are +// available. +// +// ClaimNext optimizes for locality by attempting to claim lastTask+1 first. If +// that task is unavailable, it falls back to ClaimFirst behavior (round-robin +// from the first span). +func (t *TaskSet) ClaimNext(lastTask TaskID) TaskID { + next := lastTask + 1 + + for i, span := range t.unassigned { + if span.start != next { + continue + } + + span.start += 1 + + if span.size() == 0 { + t.removeSpan(i) + return next + } + + t.unassigned[i] = span + return next + } + + // If we didn't find the next task in the unassigned set, then we've + // exhausted the span and need to claim from a different span. + return t.ClaimFirst() +} + +func (t *TaskSet) removeSpan(index int) { + t.unassigned = append(t.unassigned[:index], t.unassigned[index+1:]...) +} diff --git a/pkg/util/taskset/task_set_test.go b/pkg/util/taskset/task_set_test.go new file mode 100644 index 000000000000..dab64cf26bb3 --- /dev/null +++ b/pkg/util/taskset/task_set_test.go @@ -0,0 +1,254 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package taskset + +import ( + "math/rand" + "testing" + + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestTaskSetSingleWorker(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // When a single worker claims tasks from a taskSet, it should claim all + // tasks in sequential order now that we use FIFO claiming. + tasks := MakeTaskSet(10, 1) + var found []TaskID + + for next := tasks.ClaimFirst(); !next.IsDone(); next = tasks.ClaimNext(next) { + found = append(found, next) + } + + // Verify that tasks are claimed sequentially. + require.Equal(t, []TaskID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, found) +} + +func TestTaskSetParallel(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + taskCount := min(rand.Int63n(10000), 16) + tasks := MakeTaskSet(taskCount, 16) + workers := make([]TaskID, 16) + var found []TaskID + + for i := range workers { + workers[i] = tasks.ClaimFirst() + if !workers[i].IsDone() { + found = append(found, workers[i]) + } + } + + for { + // Check if all workers are done. + allDone := true + for _, w := range workers { + if !w.IsDone() { + allDone = false + break + } + } + if allDone { + break + } + + // Pick a random worker to claim the next task. + workerIndex := rand.Intn(len(workers)) + prevTask := workers[workerIndex] + // Skip workers that have no tasks. + if prevTask.IsDone() { + continue + } + next := tasks.ClaimNext(prevTask) + workers[workerIndex] = next + if !next.IsDone() { + found = append(found, next) + } + } + + // Build a map of the found tasks to ensure they are unique. + taskMap := make(map[TaskID]struct{}) + for _, task := range found { + taskMap[task] = struct{}{} + } + require.Len(t, taskMap, int(taskCount)) +} + +func TestMakeTaskSet(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Test with evenly divisible tasks - simulate 4 workers each calling ClaimFirst once. + tasks := MakeTaskSet(100, 4) + var claimed []TaskID + for i := 0; i < 4; i++ { + task := tasks.ClaimFirst() + require.False(t, task.IsDone()) + claimed = append(claimed, task) + } + // Each worker should get the first task from their region (round-robin). + require.Equal(t, []TaskID{0, 25, 50, 75}, claimed) + + // Test with tasks that don't divide evenly - simulate 3 workers. + tasks = MakeTaskSet(100, 3) + claimed = nil + for i := 0; i < 3; i++ { + task := tasks.ClaimFirst() + require.False(t, task.IsDone()) + claimed = append(claimed, task) + } + // First span gets 34 tasks [0,34), second gets 33 [34,67), third gets 33 [67,100). + require.Equal(t, []TaskID{0, 34, 67}, claimed) + + // Test with more workers than tasks - simulate 5 workers (only 5 tasks available). + tasks = MakeTaskSet(5, 10) + claimed = nil + for i := 0; i < 5; i++ { + task := tasks.ClaimFirst() + require.False(t, task.IsDone()) + claimed = append(claimed, task) + } + require.Equal(t, []TaskID{0, 1, 2, 3, 4}, claimed) + // 6th worker should get nothing. + require.True(t, tasks.ClaimFirst().IsDone()) + + // Test edge cases. + tasks = MakeTaskSet(0, 4) + require.True(t, tasks.ClaimFirst().IsDone()) + + tasks = MakeTaskSet(10, 0) + require.False(t, tasks.ClaimFirst().IsDone()) // Should default to 1 worker + + tasks = MakeTaskSet(10, -1) + require.False(t, tasks.ClaimFirst().IsDone()) // Should default to 1 worker +} + +func TestTaskSetLoadBalancing(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Simulate 4 workers processing 100 tasks. + tasks := MakeTaskSet(100, 4) + + type worker struct { + id int + tasks []TaskID + } + workers := make([]worker, 4) + + // Each worker claims their first task. + for i := range workers { + workers[i].id = i + task := tasks.ClaimFirst() + require.False(t, task.IsDone()) + workers[i].tasks = append(workers[i].tasks, task) + } + + // Verify initial distribution is balanced across regions + require.Equal(t, TaskID(0), workers[0].tasks[0]) // Region [0, 25) + require.Equal(t, TaskID(25), workers[1].tasks[0]) // Region [25, 50) + require.Equal(t, TaskID(50), workers[2].tasks[0]) // Region [50, 75) + require.Equal(t, TaskID(75), workers[3].tasks[0]) // Region [75, 100) + + // Simulate concurrent-like processing: round-robin through workers + // This prevents one worker from stealing all the work. + for { + claimed := false + for i := range workers { + lastTask := workers[i].tasks[len(workers[i].tasks)-1] + next := tasks.ClaimNext(lastTask) + if !next.IsDone() { + workers[i].tasks = append(workers[i].tasks, next) + claimed = true + } + } + if !claimed { + break + } + } + + // Verify all tasks were claimed exactly once. + allTasks := make(map[TaskID]bool) + for _, w := range workers { + for _, task := range w.tasks { + require.False(t, allTasks[task], "task %d claimed multiple times", task) + allTasks[task] = true + } + } + require.Len(t, allTasks, 100) + + // With round-robin processing, each worker should get approximately equal work. + for i, w := range workers { + require.InDelta(t, 25, len(w.tasks), 2, "worker %d got %d tasks", i, len(w.tasks)) + } +} + +func TestTaskSetMoreWorkersThanTasks(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Simulate scenario with more workers than tasks: 10 tasks, 64 workers. + tasks := MakeTaskSet(10, 64) + + type worker struct { + id int + tasks []TaskID + } + workers := make([]worker, 64) + + // Each worker tries to claim their first task + workersWithTasks := 0 + workersWithoutTasks := 0 + for i := range workers { + workers[i].id = i + task := tasks.ClaimFirst() + if !task.IsDone() { + workers[i].tasks = append(workers[i].tasks, task) + workersWithTasks++ + } else { + workersWithoutTasks++ + } + } + + // Only 10 workers should get tasks (one per task). + require.Equal(t, 10, workersWithTasks, "expected 10 workers to get tasks") + require.Equal(t, 54, workersWithoutTasks, "expected 54 workers to get no tasks") + + // Verify the workers that got tasks received unique tasks. + seenTasks := make(map[TaskID]bool) + for _, w := range workers { + if len(w.tasks) > 0 { + require.Len(t, w.tasks, 1, "worker %d should have exactly 1 task initially", w.id) + task := w.tasks[0] + require.False(t, seenTasks[task], "task %d assigned to multiple workers", task) + seenTasks[task] = true + } + } + require.Len(t, seenTasks, 10, "all 10 tasks should be assigned") + + // Verify the tasks are distributed round-robin (0-9). + expectedTasks := []TaskID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + actualTasks := make([]TaskID, 0, 10) + for _, w := range workers { + if len(w.tasks) > 0 { + actualTasks = append(actualTasks, w.tasks[0]) + } + } + require.Equal(t, expectedTasks, actualTasks, "tasks should be assigned round-robin") + + // Simulate workers trying to claim more tasks (all should fail). + for i := range workers { + if len(workers[i].tasks) > 0 { + next := tasks.ClaimNext(workers[i].tasks[0]) + require.True(t, next.IsDone(), "no more tasks should be available") + } + } +} diff --git a/pkg/util/taskset/task_span.go b/pkg/util/taskset/task_span.go new file mode 100644 index 000000000000..4d06a6c9baa0 --- /dev/null +++ b/pkg/util/taskset/task_span.go @@ -0,0 +1,15 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package taskset + +type taskSpan struct { + start TaskID + end TaskID +} + +func (t *taskSpan) size() int64 { + return int64(t.end) - int64(t.start) +}