Skip to content

Commit

Permalink
Add WaitGroup synchronization primitive (#14167)
Browse files Browse the repository at this point in the history
Co-authored-by: Johannes Müller <straightshoota@gmail.com>
Co-authored-by: Jason Frey <fryguy9@gmail.com>
Co-authored-by: Sijawusz Pur Rahnama <sija@sija.pl>
  • Loading branch information
4 people committed Apr 13, 2024
1 parent 0efbf53 commit c14fc89
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 0 deletions.
185 changes: 185 additions & 0 deletions spec/std/wait_group_spec.cr
@@ -0,0 +1,185 @@
require "spec"
require "wait_group"

# Spec helper: keeps yielding the current fiber until the wait group's internal
# waiting list is no longer empty, i.e. until some fiber has registered itself
# in `WaitGroup#wait`.
private def block_until_pending_waiter(wg)
  loop do
    break unless wg.@waiting.empty?
    Fiber.yield
  end
end

# Spec helper: overwrites the wait group's internal counter with *value*,
# bypassing `#add`/`#done`, so that a spec can start from a state such as a
# negative counter.
private def forge_counter(wg, value)
# call `set` directly on the instance variable, not on a local copy of it
wg.@counter.set(value)
end

describe WaitGroup do
describe "#add" do
  # A negative delta larger than the current counter must be rejected.
  it "can't decrement to a negative counter" do
    group = WaitGroup.new
    group.add(5)
    group.add(-3)

    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.add(-5)
    end
  end

  # A fiber suspended in #wait must be woken up (and see the error) when
  # another fiber drives the counter below zero.
  it "resumes waiters when reaching negative counter" do
    group = WaitGroup.new(1)

    spawn do
      begin
        block_until_pending_waiter(group)
        group.add(-2)
      rescue RuntimeError
        # #add raises in this fiber as well; only the waiter is asserted on
      end
    end

    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.wait
    end
  end

  it "can't increment after reaching negative counter" do
    group = WaitGroup.new
    forge_counter(group, -1)

    # check twice, to make sure the waitgroup counter wasn't incremented back
    # to a positive value!
    {5, 3}.each do |delta|
      expect_raises(RuntimeError, "Negative WaitGroup counter") do
        group.add(delta)
      end
    end
  end
end

describe "#done" do
  # One #add followed by two #done calls: the second must raise.
  it "can't decrement to a negative counter" do
    group = WaitGroup.new
    group.add(1)
    group.done

    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.done
    end
  end

  # The counter is forged to zero while a fiber is parked in #wait, so the
  # following #done drives it negative and must wake that waiter up.
  it "resumes waiters when reaching negative counter" do
    group = WaitGroup.new(1)

    spawn do
      begin
        block_until_pending_waiter(group)
        forge_counter(group, 0)
        group.done
      rescue RuntimeError
        # #done raises in this fiber as well; only the waiter is asserted on
      end
    end

    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.wait
    end
  end
end

describe "#wait" do
  # With a zero counter there is nothing to wait for: #wait must return
  # without suspending the fiber.
  it "immediately returns when counter is zero" do
    channel = Channel(Nil).new(1)

    spawn do
      wg = WaitGroup.new(0)
      wg.wait
      channel.send(nil)
    end

    select
    when channel.receive
      # success
    when timeout(1.second)
      fail "expected #wait to not block the fiber"
    end
  end

  it "immediately raises when counter is negative" do
    wg = WaitGroup.new(0)
    # #done on a zero counter drives it to -1; pin the message here too so the
    # spec cannot pass on an unrelated RuntimeError
    expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.done }
    expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait }
  end

  # The waiter is re-enqueued by hand while the counter is still 1, which
  # #wait must report as an error instead of returning normally.
  it "raises when counter is positive after wake up" do
    wg = WaitGroup.new(1)
    waiter = Fiber.current

    spawn do
      block_until_pending_waiter(wg)
      waiter.enqueue
    end

    expect_raises(RuntimeError, "Positive WaitGroup counter (early wake up?)") { wg.wait }
  end
end

# Uses two wait groups per round: `wg1` tells the spec that all 16 fibers have
# started, `wg2` holds those fibers back until the spec has called #done 16
# times. The round is repeated 8 times on the same wait groups.
it "waits until concurrent executions are finished" do
  wg1 = WaitGroup.new
  wg2 = WaitGroup.new

  8.times do
    wg1.add(16)
    wg2.add(16)
    exited = Channel(Bool).new(16)

    16.times do
      spawn do
        wg1.done
        wg2.wait
        exited.send(true)
      end
    end

    # returns once each of the 16 fibers has called wg1.done
    wg1.wait

    16.times do
      # non-blocking receive: no fiber may have passed wg2.wait yet
      select
      when exited.receive
        fail "WaitGroup released group too soon"
      else
      end
      wg2.done
    end

    16.times do
      select
      when x = exited.receive
        x.should eq(true)
      when timeout(1.second)
        # generous bound that is only reached on failure; a 1ms bound could
        # expire before the released fibers get a chance to run
        fail "Expected channel to receive value"
      end
    end
  end
end

# Each of the 16 outer fibers registers two more fibers with #add before
# calling #done itself; #wait must only return after all 32 inner fibers ran.
it "increments the counter from executing fibers" do
  group = WaitGroup.new(16)
  nested_runs = Atomic(Int32).new(0)

  16.times do
    spawn do
      # must happen before this fiber's own #done (see WaitGroup#add docs)
      group.add(2)

      2.times do
        spawn do
          nested_runs.add(1)
          group.done
        end
      end

      group.done
    end
  end

  group.wait
  nested_runs.get.should eq(32)
end

# Not compiled for the interpreter, where this spec takes far too much time to
# complete.
{% unless flag?(:interpreted) %}
  # Reuses a single wait group for 1000 rounds of two fibers each.
  it "stress add/done/wait" do
    group = WaitGroup.new

    1000.times do
      completed = Atomic(Int32).new(0)

      2.times do
        group.add

        spawn do
          completed.add(1)
          group.done
        end
      end

      group.wait
      completed.get.should eq(2)
    end
  end
{% end %}
end
6 changes: 6 additions & 0 deletions src/crystal/pointer_linked_list.cr
Expand Up @@ -80,4 +80,10 @@ struct Crystal::PointerLinkedList(T)
node = _next
end
end

# Yields every node of the list through `#each`, then clears the list by
# resetting `@head` to the null pointer.
def consume_each(&) : Nil
  each do |node|
    yield node
  end

  @head = Pointer(T).null
end
end
120 changes: 120 additions & 0 deletions src/wait_group.cr
@@ -0,0 +1,120 @@
require "fiber"
require "crystal/spin_lock"
require "crystal/pointer_linked_list"

# Suspends execution until a collection of fibers has finished.
#
# The wait group is a declarative counter of how many concurrent fibers have
# been started. Each such fiber is expected to call `#done` to report that it
# has finished its work. Whenever the counter reaches zero, the waiters are
# resumed.
#
# This is a simpler and more efficient alternative to using a `Channel(Nil)`
# and then looping until N messages have been received before resuming
# execution.
#
# Misuse is reported with `RuntimeError`: `#add`, `#done` and `#wait` raise
# when the counter is (or becomes) negative, and `#wait` raises when it is
# woken up while the counter is still positive.
#
# Basic example:
#
# ```
# require "wait_group"
# wg = WaitGroup.new(5)
#
# 5.times do
#   spawn do
#     do_something
#   ensure
#     wg.done # the fiber has finished
#   end
# end
#
# # suspend the current fiber until the 5 fibers are done
# wg.wait
# ```
class WaitGroup
# List node (for `Crystal::PointerLinkedList`) wrapping a fiber that is
# suspended in `#wait`.
private struct Waiting
include Crystal::PointerLinkedList::Node

def initialize(@fiber : Fiber)
end

# Calls `Fiber#enqueue` on the wrapped fiber (used to wake up a waiter).
def enqueue : Nil
@fiber.enqueue
end
end

# Creates a wait group whose counter starts at *n* (zero by default), with an
# empty list of waiting fibers.
def initialize(n : Int32 = 0)
@waiting = Crystal::PointerLinkedList(Waiting).new
@lock = Crystal::SpinLock.new
@counter = Atomic(Int32).new(n)
end

# Increments the counter by how many fibers we want to wait for.
#
# A negative value decrements the counter. When the counter reaches zero,
# all waiting fibers will be resumed.
# Raises `RuntimeError` if the counter is already negative or reaches a
# negative value; in the latter case the waiting fibers are resumed first.
#
# Can be called at any time, allowing concurrent fibers to add more fibers to
# wait for, but they must always do so before calling `#done` that would
# decrement the counter, to make sure that the counter may never inadvertently
# reach zero before all fibers are done.
def add(n : Int32 = 1) : Nil
counter = @counter.get(:acquire)

# compare-and-set loop: a failed attempt hands back the value currently
# stored, which is checked again and used for the next attempt; a counter
# that is already negative is never modified
loop do
raise RuntimeError.new("Negative WaitGroup counter") if counter < 0

counter, success = @counter.compare_and_set(counter, counter + n, :acquire_release, :acquire)
break if success
end

# `counter` is the value that was replaced, so this is the value now stored
new_counter = counter + n
# still waiting for some fibers: nobody to wake up
return if new_counter > 0

# the counter reached zero or went negative: wake up every waiting fiber and
# clear the waiting list, under the same lock `#wait` takes to register
@lock.sync do
@waiting.consume_each do |node|
node.value.enqueue
end
end

# report the misuse to the caller only after the waiters were woken up
raise RuntimeError.new("Negative WaitGroup counter") if new_counter < 0
end

# Decrements the counter by one. Must be called by concurrent fibers once they
# have finished processing. When the counter reaches zero, all waiting fibers
# will be resumed.
#
# Equivalent to `add(-1)`; raises `RuntimeError` under the same conditions.
def done : Nil
add(-1)
end

# Suspends the current fiber until the counter reaches zero, at which point
# the fiber will be resumed.
#
# Returns immediately if the counter is already zero. Raises `RuntimeError`
# if the counter is negative, or if the fiber is resumed while the counter is
# still positive.
#
# Can be called from different fibers.
def wait : Nil
# fast path: nothing to wait for (raises if the counter is negative)
return if done?

# the waiting list only stores a pointer to this local variable
waiting = Waiting.new(Fiber.current)

@lock.sync do
# must check again to avoid a race condition where #done may have
# decremented the counter to zero between the above check and #wait
# acquiring the lock; we'd push the current fiber to the wait list that
# would never be resumed (oops)
return if done?

@waiting.push(pointerof(waiting))
end

# presumably suspends the current fiber here until it is enqueued again
# (normally by #add) — confirm against Crystal::Scheduler
Crystal::Scheduler.reschedule

return if done?
# NOTE(review): if the wake-up did not come from #add (which clears the
# list), the pointer pushed above is still in @waiting when this raises —
# confirm this is acceptable.
raise RuntimeError.new("Positive WaitGroup counter (early wake up?)")
end

# Returns whether the counter is zero. Raises `RuntimeError` if it is negative.
private def done?
counter = @counter.get(:acquire)
raise RuntimeError.new("Negative WaitGroup counter") if counter < 0
counter == 0
end
end

0 comments on commit c14fc89

Please sign in to comment.