Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add WaitGroup synchronization primitive (#14167)
Co-authored-by: Johannes Müller <straightshoota@gmail.com> Co-authored-by: Jason Frey <fryguy9@gmail.com> Co-authored-by: Sijawusz Pur Rahnama <sija@sija.pl>
- Loading branch information
1 parent
0efbf53
commit c14fc89
Showing
3 changed files
with
311 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
require "spec" | ||
require "wait_group" | ||
|
||
# Spec helper: keeps yielding to other fibers for as long as the wait group's
# internal waiting list is empty, i.e. returns once at least one fiber has
# registered itself through `WaitGroup#wait`.
private def block_until_pending_waiter(wg)
  loop do
    break unless wg.@waiting.empty?
    Fiber.yield
  end
end
|
||
# Spec helper: overwrites the internal counter of *wg* with *value* directly, | ||
# without going through `#add` or `#done`, so a spec can start from a counter | ||
# state (e.g. negative) that the public API would refuse to produce silently. | ||
private def forge_counter(wg, value) | ||
wg.@counter.set(value) | ||
end | ||
|
||
# Specs for `WaitGroup`: counter bookkeeping (`#add`, `#done`), blocking and
# error behavior of `#wait`, and a few concurrent scenarios.
describe WaitGroup do
  describe "#add" do
    it "can't decrement to a negative counter" do
      wg = WaitGroup.new
      wg.add(5)
      wg.add(-3)
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(-5) }
    end

    it "resumes waiters when reaching negative counter" do
      wg = WaitGroup.new(1)
      spawn do
        block_until_pending_waiter(wg)
        wg.add(-2)
      rescue RuntimeError
        # #add raises in this fiber too; only the waiter's error is asserted
      end
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait }
    end

    it "can't increment after reaching negative counter" do
      wg = WaitGroup.new
      forge_counter(wg, -1)

      # check twice, to make sure the waitgroup counter wasn't incremented back
      # to a positive value!
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(5) }
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(3) }
    end
  end

  describe "#done" do
    it "can't decrement to a negative counter" do
      wg = WaitGroup.new
      wg.add(1)
      wg.done
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.done }
    end

    it "resumes waiters when reaching negative counter" do
      wg = WaitGroup.new(1)
      spawn do
        block_until_pending_waiter(wg)
        # reset to zero behind the waiter's back so that #done goes negative
        forge_counter(wg, 0)
        wg.done
      rescue RuntimeError
        # #done raises in this fiber too; only the waiter's error is asserted
      end
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait }
    end
  end

  describe "#wait" do
    it "immediately returns when counter is zero" do
      channel = Channel(Nil).new(1)

      spawn do
        wg = WaitGroup.new(0)
        wg.wait
        channel.send(nil)
      end

      select
      when channel.receive
        # success
      when timeout(1.second)
        fail "expected #wait to not block the fiber"
      end
    end

    it "immediately raises when counter is negative" do
      wg = WaitGroup.new(0)
      # pin the message like every other negative-counter expectation does
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.done }
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait }
    end

    it "raises when counter is positive after wake up" do
      wg = WaitGroup.new(1)
      waiter = Fiber.current

      spawn do
        block_until_pending_waiter(wg)
        # resume the waiter by hand while the counter is still 1
        waiter.enqueue
      end

      expect_raises(RuntimeError, "Positive WaitGroup counter (early wake up?)") { wg.wait }
    end
  end

  it "waits until concurrent executions are finished" do
    wg1 = WaitGroup.new
    wg2 = WaitGroup.new

    # both groups are reused across the 8 rounds
    8.times do
      wg1.add(16)
      wg2.add(16)
      exited = Channel(Bool).new(16)

      16.times do
        spawn do
          wg1.done
          wg2.wait
          exited.send(true)
        end
      end

      wg1.wait

      # no fiber may get past wg2.wait before the 16th wg2.done
      16.times do
        select
        when exited.receive
          fail "WaitGroup released group too soon"
        else
        end
        wg2.done
      end

      16.times do
        select
        when x = exited.receive
          x.should eq(true)
        when timeout(1.second)
          # same budget as the other timeout in this file; a 1ms budget can
          # expire before the resumed fibers got a chance to run
          fail "Expected channel to receive value"
        end
      end
    end
  end

  it "increments the counter from executing fibers" do
    wg = WaitGroup.new(16)
    extra = Atomic(Int32).new(0)

    16.times do
      spawn do
        # register the nested fibers before reporting this fiber as done
        wg.add(2)

        2.times do
          spawn do
            extra.add(1)
            wg.done
          end
        end

        wg.done
      end
    end

    wg.wait
    extra.get.should eq(32)
  end

  # the test takes far too much time for the interpreter to complete
  {% unless flag?(:interpreted) %}
    it "stress add/done/wait" do
      wg = WaitGroup.new

      1000.times do
        counter = Atomic(Int32).new(0)

        2.times do
          wg.add(1)

          spawn do
            counter.add(1)
            wg.done
          end
        end

        wg.wait
        counter.get.should eq(2)
      end
    end
  {% end %}
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
require "fiber" | ||
require "crystal/spin_lock" | ||
require "crystal/pointer_linked_list" | ||
|
||
# Lets fibers suspend themselves until a set of other fibers has finished.
#
# A wait group keeps a counter of the concurrent fibers that have been started.
# The counter is set through `.new` and `#add`, and every counted fiber is
# expected to call `#done` once its work is finished. When the counter reaches
# zero, the fibers suspended in `#wait` are resumed.
#
# This is a simpler and more efficient alternative to a `Channel(Nil)` on which
# N messages are received in a loop before execution continues.
#
# Basic example:
#
# ```
# require "wait_group"
# wg = WaitGroup.new(5)
#
# 5.times do
#   spawn do
#     do_something
#   ensure
#     wg.done # the fiber has finished
#   end
# end
#
# # suspend the current fiber until the 5 fibers are done
# wg.wait
# ```
class WaitGroup
  # Entry of the waiting list; remembers the fiber to resume.
  private struct Waiting
    include Crystal::PointerLinkedList::Node

    def initialize(@fiber : Fiber)
    end

    # Enqueues the remembered fiber so that it gets resumed.
    def enqueue : Nil
      @fiber.enqueue
    end
  end

  # Creates a wait group whose counter starts at *n* (zero by default).
  def initialize(n : Int32 = 0)
    @counter = Atomic(Int32).new(n)
    @lock = Crystal::SpinLock.new
    @waiting = Crystal::PointerLinkedList(Waiting).new
  end

  # Adds *n* to the counter, i.e. the number of additional fibers to wait for.
  #
  # *n* may be negative, in which case the counter is decremented. All waiting
  # fibers are resumed when the counter reaches zero.
  # Raises `RuntimeError` if the counter is already negative or becomes
  # negative.
  #
  # May be called at any time, including from counted fibers that want to add
  # more fibers to wait for. Such fibers must call `#add` before their own
  # `#done`, otherwise the counter could inadvertently reach zero before all
  # fibers are done.
  def add(n : Int32 = 1) : Nil
    current = @counter.get(:acquire)

    # retry the compare-and-set until no other fiber changed the counter in
    # between; a failed attempt hands back the value actually stored
    loop do
      if current < 0
        raise RuntimeError.new("Negative WaitGroup counter")
      end

      current, swapped = @counter.compare_and_set(current, current + n, :acquire_release, :acquire)
      break if swapped
    end

    updated = current + n

    if updated <= 0
      # zero or below: resume every fiber registered by #wait (they raise
      # themselves when they observe a negative counter)
      @lock.sync do
        @waiting.consume_each(&.value.enqueue)
      end

      raise RuntimeError.new("Negative WaitGroup counter") if updated < 0
    end
  end

  # Decrements the counter by one; shortcut for `add(-1)`. To be called by each
  # counted fiber when it has finished processing. All waiting fibers are
  # resumed when the counter reaches zero.
  def done : Nil
    add(-1)
  end

  # Suspends the current fiber until the counter reaches zero; returns
  # immediately if it already is zero.
  #
  # Raises `RuntimeError` if the counter is negative, or if the fiber is resumed
  # while the counter is still positive.
  #
  # Can be called from different fibers.
  def wait : Nil
    return if done?

    node = Waiting.new(Fiber.current)

    @lock.sync do
      # check again under the lock: #done may have brought the counter to zero
      # after the check above but before the lock was acquired, in which case
      # the fiber would be pushed to the list and never be resumed
      return if done?

      @waiting.push(pointerof(node))
    end

    Crystal::Scheduler.reschedule

    unless done?
      # NOTE(review): this method does not unlink `node` from @waiting before
      # raising -- confirm the list is never consumed afterwards
      raise RuntimeError.new("Positive WaitGroup counter (early wake up?)")
    end
  end

  # Tells whether the counter is zero. Raises `RuntimeError` when it is
  # negative.
  private def done?
    value = @counter.get(:acquire)

    if value < 0
      raise RuntimeError.new("Negative WaitGroup counter")
    end

    value == 0
  end
end