diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel
index 77b910fb7458..6cde76cbbb93 100644
--- a/pkg/BUILD.bazel
+++ b/pkg/BUILD.bazel
@@ -825,6 +825,7 @@ ALL_TESTS = [
     "//pkg/util/syncutil:syncutil_test",
     "//pkg/util/sysutil:sysutil_test",
     "//pkg/util/taskpacer:taskpacer_test",
+    "//pkg/util/taskset:taskset_test",
     "//pkg/util/timeofday:timeofday_test",
     "//pkg/util/timetz:timetz_test",
     "//pkg/util/timeutil/pgdate:pgdate_test",
@@ -2820,6 +2821,8 @@ GO_TARGETS = [
     "//pkg/util/sysutil:sysutil_test",
     "//pkg/util/taskpacer:taskpacer",
     "//pkg/util/taskpacer:taskpacer_test",
+    "//pkg/util/taskset:taskset",
+    "//pkg/util/taskset:taskset_test",
     "//pkg/util/timeofday:timeofday",
     "//pkg/util/timeofday:timeofday_test",
     "//pkg/util/timetz:timetz",
diff --git a/pkg/util/taskset/BUILD.bazel b/pkg/util/taskset/BUILD.bazel
new file mode 100644
index 000000000000..2bca4d7e367d
--- /dev/null
+++ b/pkg/util/taskset/BUILD.bazel
@@ -0,0 +1,22 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+    name = "taskset",
+    srcs = [
+        "task_set.go",
+        "task_span.go",
+    ],
+    importpath = "github.com/cockroachdb/cockroach/pkg/util/taskset",
+    visibility = ["//visibility:public"],
+)
+
+go_test(
+    name = "taskset_test",
+    srcs = ["task_set_test.go"],
+    embed = [":taskset"],
+    deps = [
+        "//pkg/util/leaktest",
+        "//pkg/util/log",
+        "@com_github_stretchr_testify//require",
+    ],
+)
diff --git a/pkg/util/taskset/task_set.go b/pkg/util/taskset/task_set.go
new file mode 100644
index 000000000000..4cd2d24bb457
--- /dev/null
+++ b/pkg/util/taskset/task_set.go
@@ -0,0 +1,179 @@
+// Copyright 2025 The Cockroach Authors.
+//
+// Use of this software is governed by the CockroachDB Software License
+// included in the /LICENSE file.
+
+// Package taskset provides a generic work distribution mechanism for
+// coordinating parallel workers. TaskSet hands out integer identifiers
+// (TaskIDs) that workers can claim and process. The TaskIDs themselves have no
+// inherent meaning - it's up to the caller to map each TaskID to actual work
+// (e.g., file indices, key ranges, batch numbers, etc.).
+//
+// Example usage:
+//
+//	tasks := taskset.MakeTaskSet(100, 4) // 100 work items, 4 workers
+//
+//	// Worker goroutine
+//	for taskID := tasks.ClaimFirst(); !taskID.IsDone(); taskID = tasks.ClaimNext(taskID) {
+//		// Map taskID to actual work
+//		processFile(files[taskID])
+//		// or: processKeyRange(splits[taskID], splits[taskID+1])
+//		// or: processBatch(taskID*batchSize, (taskID+1)*batchSize)
+//	}
+package taskset
+
+// TaskID is an abstract integer identifier for a unit of work. The TaskID
+// itself has no inherent meaning - callers decide what each TaskID represents
+// (e.g., which file to process, which key range to handle, etc.).
+type TaskID int64
+
+// taskIDDone is an internal sentinel value indicating no more tasks are available.
+// Use TaskID.IsDone() to check if a task is done.
+const taskIDDone = TaskID(-1)
+
+// IsDone reports whether there are no more tasks to claim.
+func (t TaskID) IsDone() bool {
+	return t == taskIDDone
+}
+
+// MakeTaskSet creates a new TaskSet with taskCount work items numbered 0
+// through taskCount-1, pre-split for the expected number of workers.
+//
+// The TaskIDs are abstract identifiers with no inherent meaning - the caller
+// decides what each TaskID represents. For example:
+// - File processing: MakeTaskSet(100, 4) with TaskID N → files[N]
+// - Key ranges: MakeTaskSet(100, 4) with TaskID N → range [splits[N], splits[N+1])
+// - Row batches: MakeTaskSet(100, 4) with TaskID N → rows [N*1000, (N+1)*1000)
+//
+// The numWorkers parameter enables better initial load balancing by dividing the
+// task range into numWorkers equal spans upfront. For example, with 100 tasks
+// and 4 workers:
+// - Worker 1: starts with task 0 from range [0, 25)
+// - Worker 2: starts with task 25 from range [25, 50)
+// - Worker 3: starts with task 50 from range [50, 75)
+// - Worker 4: starts with task 75 from range [75, 100)
+//
+// Each worker continues claiming sequential tasks from their region (maintaining
+// locality), and can steal from other regions if they finish early.
+//
+// If the number of workers is unknown, use numWorkers=1 for a single span.
+func MakeTaskSet(taskCount, numWorkers int64) TaskSet {
+	if numWorkers <= 0 {
+		numWorkers = 1
+	}
+	if taskCount <= 0 {
+		return TaskSet{unassigned: nil}
+	}
+
+	// Pre-split the task range into numWorkers equal spans.
+	spans := make([]taskSpan, 0, numWorkers)
+	tasksPerWorker := taskCount / numWorkers
+	remainder := taskCount % numWorkers
+
+	start := TaskID(0)
+	for i := int64(0); i < numWorkers; i++ {
+		// Distribute the remainder evenly by giving the first 'remainder' workers one extra task.
+		spanSize := tasksPerWorker
+		if i < remainder {
+			spanSize++
+		}
+		if spanSize > 0 {
+			end := start + TaskID(spanSize)
+			spans = append(spans, taskSpan{start: start, end: end})
+			start = end
+		}
+	}
+
+	return TaskSet{unassigned: spans}
+}
+
+// TaskSet is a generic work distribution coordinator that manages a collection
+// of abstract task identifiers (TaskIDs) that can be claimed by workers.
+//
+// TaskSet implements a work-stealing algorithm optimized for task locality:
+// - When a worker completes task N, it tries to claim task N+1 (sequential locality)
+// - If task N+1 is unavailable, it falls back to round-robin claiming from the first span
+// - This balances load across workers while maintaining locality within each worker
+//
+// The TaskIDs themselves are just integers (0 through taskCount-1) with no
+// inherent meaning. Callers map these identifiers to actual work units such as:
+// - File indices (TaskID 5 → process files[5])
+// - Key ranges (TaskID 5 → process range [splits[5], splits[6]))
+// - Batch numbers (TaskID 5 → process rows [5000, 6000))
+//
+// TaskSet is NOT safe for concurrent use. Callers must ensure external
+// synchronization if the TaskSet is accessed from multiple goroutines.
+type TaskSet struct {
+	unassigned []taskSpan
+}
+
+// ClaimFirst should be called when a worker claims its first task. It returns
+// an abstract TaskID to process. The caller decides what this TaskID represents
+// (e.g., which file to process, which key range to handle). Returns a TaskID
+// where .IsDone() is true if no tasks are available.
+//
+// ClaimFirst is distinct from ClaimNext because ClaimFirst will always take
+// from the first span and rotate it to the end (round-robin), whereas ClaimNext
+// tries to claim the next sequential task for locality.
+func (t *TaskSet) ClaimFirst() TaskID {
+	if len(t.unassigned) == 0 {
+		return taskIDDone
+	}
+
+	// Take the first task from the first span, then rotate that span to the end.
+	// This provides round-robin distribution, ensuring each worker gets tasks
+	// from different regions initially for better load balancing.
+	span := t.unassigned[0]
+	if span.size() == 0 {
+		return taskIDDone
+	}
+
+	task := span.start
+	span.start++
+
+	if span.size() == 0 {
+		// The span is exhausted; remove it.
+		t.removeSpan(0)
+	} else {
+		// Move the span to the end for round-robin distribution.
+		t.unassigned = append(t.unassigned[1:], span)
+	}
+
+	return task
+}
+
+// ClaimNext should be called when a worker has completed its current task. It
+// returns the next abstract TaskID to process. The caller decides what this
+// TaskID represents. Returns a TaskID where .IsDone() is true if no tasks are
+// available.
+//
+// ClaimNext optimizes for locality by attempting to claim lastTask+1 first. If
+// that task is unavailable, it falls back to ClaimFirst behavior (round-robin
+// from the first span).
+func (t *TaskSet) ClaimNext(lastTask TaskID) TaskID {
+	next := lastTask + 1
+
+	for i, span := range t.unassigned {
+		if span.start != next {
+			continue
+		}
+
+		span.start++
+
+		if span.size() == 0 {
+			t.removeSpan(i)
+			return next
+		}
+
+		t.unassigned[i] = span
+		return next
+	}
+
+	// If we didn't find the next task in the unassigned set, then we've
+	// exhausted the span and need to claim from a different span.
+	return t.ClaimFirst()
+}
+
+func (t *TaskSet) removeSpan(index int) {
+	t.unassigned = append(t.unassigned[:index], t.unassigned[index+1:]...)
+}
diff --git a/pkg/util/taskset/task_set_test.go b/pkg/util/taskset/task_set_test.go
new file mode 100644
index 000000000000..dab64cf26bb3
--- /dev/null
+++ b/pkg/util/taskset/task_set_test.go
@@ -0,0 +1,254 @@
+// Copyright 2025 The Cockroach Authors.
+//
+// Use of this software is governed by the CockroachDB Software License
+// included in the /LICENSE file.
+
+package taskset
+
+import (
+	"math/rand"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+	"github.com/stretchr/testify/require"
+)
+
+func TestTaskSetSingleWorker(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	// When a single worker claims tasks from a TaskSet, it should claim all
+	// tasks in sequential order, since claiming within a span is FIFO.
+	tasks := MakeTaskSet(10, 1)
+	var found []TaskID
+
+	for next := tasks.ClaimFirst(); !next.IsDone(); next = tasks.ClaimNext(next) {
+		found = append(found, next)
+	}
+
+	// Verify that tasks are claimed sequentially.
+	require.Equal(t, []TaskID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, found)
+}
+
+func TestTaskSetParallel(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	taskCount := max(rand.Int63n(10000), 16)
+	tasks := MakeTaskSet(taskCount, 16)
+	workers := make([]TaskID, 16)
+	var found []TaskID
+
+	for i := range workers {
+		workers[i] = tasks.ClaimFirst()
+		if !workers[i].IsDone() {
+			found = append(found, workers[i])
+		}
+	}
+
+	for {
+		// Check if all workers are done.
+		allDone := true
+		for _, w := range workers {
+			if !w.IsDone() {
+				allDone = false
+				break
+			}
+		}
+		if allDone {
+			break
+		}
+
+		// Pick a random worker to claim the next task.
+		workerIndex := rand.Intn(len(workers))
+		prevTask := workers[workerIndex]
+		// Skip workers that have no tasks.
+		if prevTask.IsDone() {
+			continue
+		}
+		next := tasks.ClaimNext(prevTask)
+		workers[workerIndex] = next
+		if !next.IsDone() {
+			found = append(found, next)
+		}
+	}
+
+	// Build a map of the found tasks to ensure they are unique.
+	taskMap := make(map[TaskID]struct{})
+	for _, task := range found {
+		taskMap[task] = struct{}{}
+	}
+	require.Len(t, taskMap, int(taskCount))
+}
+
+func TestMakeTaskSet(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	// Test with evenly divisible tasks - simulate 4 workers each calling ClaimFirst once.
+	tasks := MakeTaskSet(100, 4)
+	var claimed []TaskID
+	for i := 0; i < 4; i++ {
+		task := tasks.ClaimFirst()
+		require.False(t, task.IsDone())
+		claimed = append(claimed, task)
+	}
+	// Each worker should get the first task from their region (round-robin).
+	require.Equal(t, []TaskID{0, 25, 50, 75}, claimed)
+
+	// Test with tasks that don't divide evenly - simulate 3 workers.
+	tasks = MakeTaskSet(100, 3)
+	claimed = nil
+	for i := 0; i < 3; i++ {
+		task := tasks.ClaimFirst()
+		require.False(t, task.IsDone())
+		claimed = append(claimed, task)
+	}
+	// First span gets 34 tasks [0,34), second gets 33 [34,67), third gets 33 [67,100).
+	require.Equal(t, []TaskID{0, 34, 67}, claimed)
+
+	// Test with more workers than tasks: 10 workers, but only 5 tasks available.
+	tasks = MakeTaskSet(5, 10)
+	claimed = nil
+	for i := 0; i < 5; i++ {
+		task := tasks.ClaimFirst()
+		require.False(t, task.IsDone())
+		claimed = append(claimed, task)
+	}
+	require.Equal(t, []TaskID{0, 1, 2, 3, 4}, claimed)
+	// 6th worker should get nothing.
+	require.True(t, tasks.ClaimFirst().IsDone())
+
+	// Test edge cases.
+	tasks = MakeTaskSet(0, 4)
+	require.True(t, tasks.ClaimFirst().IsDone())
+
+	tasks = MakeTaskSet(10, 0)
+	require.False(t, tasks.ClaimFirst().IsDone()) // Should default to 1 worker.
+
+	tasks = MakeTaskSet(10, -1)
+	require.False(t, tasks.ClaimFirst().IsDone()) // Should default to 1 worker.
+}
+
+func TestTaskSetLoadBalancing(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	// Simulate 4 workers processing 100 tasks.
+	tasks := MakeTaskSet(100, 4)
+
+	type worker struct {
+		id    int
+		tasks []TaskID
+	}
+	workers := make([]worker, 4)
+
+	// Each worker claims their first task.
+	for i := range workers {
+		workers[i].id = i
+		task := tasks.ClaimFirst()
+		require.False(t, task.IsDone())
+		workers[i].tasks = append(workers[i].tasks, task)
+	}
+
+	// Verify the initial distribution is balanced across regions.
+	require.Equal(t, TaskID(0), workers[0].tasks[0])  // Region [0, 25)
+	require.Equal(t, TaskID(25), workers[1].tasks[0]) // Region [25, 50)
+	require.Equal(t, TaskID(50), workers[2].tasks[0]) // Region [50, 75)
+	require.Equal(t, TaskID(75), workers[3].tasks[0]) // Region [75, 100)
+
+	// Simulate concurrent-like processing by round-robining through workers.
+	// This prevents one worker from stealing all the work.
+	for {
+		claimed := false
+		for i := range workers {
+			lastTask := workers[i].tasks[len(workers[i].tasks)-1]
+			next := tasks.ClaimNext(lastTask)
+			if !next.IsDone() {
+				workers[i].tasks = append(workers[i].tasks, next)
+				claimed = true
+			}
+		}
+		if !claimed {
+			break
+		}
+	}
+
+	// Verify all tasks were claimed exactly once.
+	allTasks := make(map[TaskID]bool)
+	for _, w := range workers {
+		for _, task := range w.tasks {
+			require.False(t, allTasks[task], "task %d claimed multiple times", task)
+			allTasks[task] = true
+		}
+	}
+	require.Len(t, allTasks, 100)
+
+	// With round-robin processing, each worker should get approximately equal work.
+	for i, w := range workers {
+		require.InDelta(t, 25, len(w.tasks), 2, "worker %d got %d tasks", i, len(w.tasks))
+	}
+}
+
+func TestTaskSetMoreWorkersThanTasks(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	// Simulate scenario with more workers than tasks: 10 tasks, 64 workers.
+	tasks := MakeTaskSet(10, 64)
+
+	type worker struct {
+		id    int
+		tasks []TaskID
+	}
+	workers := make([]worker, 64)
+
+	// Each worker tries to claim their first task.
+	workersWithTasks := 0
+	workersWithoutTasks := 0
+	for i := range workers {
+		workers[i].id = i
+		task := tasks.ClaimFirst()
+		if !task.IsDone() {
+			workers[i].tasks = append(workers[i].tasks, task)
+			workersWithTasks++
+		} else {
+			workersWithoutTasks++
+		}
+	}
+
+	// Only 10 workers should get tasks (one per task).
+	require.Equal(t, 10, workersWithTasks, "expected 10 workers to get tasks")
+	require.Equal(t, 54, workersWithoutTasks, "expected 54 workers to get no tasks")
+
+	// Verify the workers that got tasks received unique tasks.
+	seenTasks := make(map[TaskID]bool)
+	for _, w := range workers {
+		if len(w.tasks) > 0 {
+			require.Len(t, w.tasks, 1, "worker %d should have exactly 1 task initially", w.id)
+			task := w.tasks[0]
+			require.False(t, seenTasks[task], "task %d assigned to multiple workers", task)
+			seenTasks[task] = true
+		}
+	}
+	require.Len(t, seenTasks, 10, "all 10 tasks should be assigned")
+
+	// Verify the tasks are distributed round-robin (0-9).
+	expectedTasks := []TaskID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+	actualTasks := make([]TaskID, 0, 10)
+	for _, w := range workers {
+		if len(w.tasks) > 0 {
+			actualTasks = append(actualTasks, w.tasks[0])
+		}
+	}
+	require.Equal(t, expectedTasks, actualTasks, "tasks should be assigned round-robin")
+
+	// Simulate workers trying to claim more tasks (all should fail).
+	for i := range workers {
+		if len(workers[i].tasks) > 0 {
+			next := tasks.ClaimNext(workers[i].tasks[0])
+			require.True(t, next.IsDone(), "no more tasks should be available")
+		}
+	}
+}
diff --git a/pkg/util/taskset/task_span.go b/pkg/util/taskset/task_span.go
new file mode 100644
index 000000000000..4d06a6c9baa0
--- /dev/null
+++ b/pkg/util/taskset/task_span.go
@@ -0,0 +1,15 @@
+// Copyright 2025 The Cockroach Authors.
+//
+// Use of this software is governed by the CockroachDB Software License
+// included in the /LICENSE file.
+
+package taskset
+
+type taskSpan struct {
+	start TaskID
+	end   TaskID
+}
+
+func (t *taskSpan) size() int64 {
+	return int64(t.end) - int64(t.start)
+}
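
A minimal usage sketch (not part of the diff above): since TaskSet is NOT safe for concurrent use, goroutines sharing one TaskSet must serialize the claim calls themselves, for example with a sync.Mutex. The file names, worker count, and claim helpers below are hypothetical, chosen only to illustrate mapping abstract TaskIDs onto real work.

package main

import (
	"fmt"
	"sync"

	"github.com/cockroachdb/cockroach/pkg/util/taskset"
)

func main() {
	files := []string{"a.sst", "b.sst", "c.sst", "d.sst", "e.sst", "f.sst"}

	// One TaskID per file, pre-split for 3 workers.
	tasks := taskset.MakeTaskSet(int64(len(files)), 3)

	// TaskSet is not concurrency-safe, so serialize all claims with a mutex.
	var mu sync.Mutex
	claimFirst := func() taskset.TaskID {
		mu.Lock()
		defer mu.Unlock()
		return tasks.ClaimFirst()
	}
	claimNext := func(last taskset.TaskID) taskset.TaskID {
		mu.Lock()
		defer mu.Unlock()
		return tasks.ClaimNext(last)
	}

	var wg sync.WaitGroup
	for w := 0; w < 3; w++ {
		wg.Add(1)
		go func(worker int) {
			defer wg.Done()
			for id := claimFirst(); !id.IsDone(); id = claimNext(id) {
				// Map the abstract TaskID to real work: here, a file index.
				fmt.Printf("worker %d processing %s\n", worker, files[id])
			}
		}(w)
	}
	wg.Wait()
}

Keeping the lock inside thin claim helpers limits the critical section to the claim itself, so the actual work runs outside the mutex and workers still benefit from the locality-preserving ClaimNext path.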