Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove nested listener locks #40

Merged
merged 4 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 48 additions & 45 deletions shared/src/main/scala/async/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.atomic.AtomicLong
import gears.async.Listener.{withLock, ListenerLockWrapper}
import gears.async.Listener.withLock
import gears.async.Listener.NumberedLock
import scala.util.boundary

Expand Down Expand Up @@ -136,17 +136,15 @@ object Async:
new Source[T]:
override def poll(k: Listener[T]): Boolean =
if q.isEmpty() then false
else if !k.acquireLock() then true
else
k.lockCompletely(this) match
case Listener.Gone => true
case Listener.Locked =>
val item = q.poll()
if item == null then
k.releaseLock(Listener.Locked)
false
else
k.complete(item, this)
true
val item = q.poll()
if item == null then
k.releaseLock()
false
else
k.complete(item, this)
true

override def onComplete(k: Listener[T]): Unit = poll(k)
override def dropListener(k: Listener[T]): Unit = ()
Expand All @@ -162,7 +160,7 @@ object Async:
selfSrc =>
def transform(k: Listener[U]) =
new Listener.ForwardingListener[T](selfSrc, k):
val lock = withLock(k) { inner => new ListenerLockWrapper(inner, selfSrc) }
val lock = k.lock
def complete(data: T, source: Async.Source[T]) =
k.complete(f(data), selfSrc)

Expand All @@ -185,7 +183,7 @@ object Async:
var found = false

val listener = new Listener.ForwardingListener[U](this, k):
val lock = withLock(k) { inner => new ListenerLockWrapper(inner, selfSrc) }
val lock = k.lock
def complete(data: U, source: Async.Source[U]) =
k.complete(map(data, source), selfSrc)
end listener
Expand All @@ -194,43 +192,48 @@ object Async:
found

def onComplete(k: Listener[T]): Unit =
val listener = new Listener.ForwardingListener[U](this, k)
with NumberedLock
with Listener.ListenerLock
with Listener.PartialLock { self =>
val lock = self
val listener = new Listener.ForwardingListener[U](this, k) { self =>
inline def lockIsOurs = k.lock == null
val lock =
if k.lock != null then
// if the upstream listener holds a lock already, we can utilize it.
new Listener.ListenerLock:
val selfNumber = k.lock.selfNumber
override def acquire() =
if found then false // already completed
else if !k.lock.acquire() then
if !found && !synchronized { // getAndSet alternative, avoid racing only with self here.
val old = found
found = true
old
}
then sources.foreach(_.dropListener(self)) // same as dropListener(k), but avoids an allocation
false
else if found then
k.lock.release()
false
else true
override def release() = k.lock.release()
else
new Listener.ListenerLock with NumberedLock:
val selfNumber: Long = number
def acquire() =
if found then false
else
acquireLock()
if found then
releaseLock()
// no cleanup needed here, since we have done this by an earlier `complete` or `lockNext`
false
else true
def release() =
releaseLock()

var found = false
def heldLock = if k.lock == null then Listener.Locked else this

/* == PartialLock implementation == */
// Note that this is bogus if k.lock is null, but we'll never use it if it is.
val nextNumber = if k.lock == null then -1 else k.lock.selfNumber
def lockNext() =
val res = k.lock.lockSelf(selfSrc)
if res == Listener.Gone then
found = true // This is always false before this, since PartialLock is only returned when found is false
sources.foreach(_.dropListener(this)) // same as dropListener(k), but avoids an allocation
res

/* == ListenerLock implementation == */
val selfNumber = self.number
def lockSelf(src: Async.Source[?]) =
if found then Listener.Gone
else
self.acquireLock()
if found then
self.releaseLock()
// no cleanup needed here, since we have done this by an earlier `complete` or `lockNext`
Listener.Gone
else heldLock
def release(until: Listener.LockMarker) =
self.releaseLock()
if until == heldLock then null else k.lock

def complete(item: U, src: Async.Source[U]) =
found = true
self.releaseLock()
if lockIsOurs then lock.release()
sources.foreach(s => if s != src then s.dropListener(self))
k.complete(map(item, src), selfSrc)
} // end listener
Expand Down
86 changes: 16 additions & 70 deletions shared/src/main/scala/async/Listener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,23 @@ trait Listener[-T]:
*/
val lock: Listener.ListenerLock | Null

/** Attempts to acquire all locks and then calling [[complete]] with the given item and source. If locking fails,
* [[releaseAll]] is automatically called.
/** Attempts to acquire locks and then calling [[complete]] with the given item and source. If locking fails,
* [[release]] is automatically called.
*/
def completeNow(data: T, source: Async.Source[T]): Boolean =
lockCompletely(source) match
case Locked =>
this.complete(data, source)
true
case Gone => false
if acquireLock() then
this.complete(data, source)
true
else false

/** Release the listener lock up to the given [[Listener.LockMarker]], if it exists. */
inline final def releaseLock(to: Listener.LockMarker): Unit = if lock != null then lock.releaseAll(to)
inline final def releaseLock(): Unit = if lock != null then lock.release()

/** Attempts to completely lock the listener, if such a lock exists. Succeeds with [[Listener.Locked]] immediately if
* there is no [[Listener.ListenerLock]]. If locking fails, [[releaseAll]] is automatically called.
/** Attempts to completely lock the listener, if such a lock exists. Succeeds with [[true]] immediately if there is no
* [[Listener.ListenerLock]]. If locking fails, [[release]] is automatically called.
*/
inline final def lockCompletely(source: Async.Source[T]): Locked.type | Gone.type =
if lock != null then lock.lockAll(source) else Locked
inline final def acquireLock(): Boolean =
if lock != null then lock.acquire() else true

object Listener:
/** A simple [[Listener]] that always accepts the item and sends it to the consumer. */
Expand All @@ -72,29 +71,6 @@ object Listener:
val lock = null
override def complete(data: T, source: Async.Source[T]) = ???

/** The result of locking a single listener lock. */
sealed trait LockResult

/** We have completed locking the listener. It can now be `complete`d. */
case object Locked extends LockResult

/** The listener is no longer available. It should be removed from the source, and any acquired locks by the source
* must be manually `release`d by the source itself.
*/
case object Gone extends LockResult

/** Locking is successful; however, there are more locks to be acquired. */
trait PartialLock extends LockResult:
/** The number of the next lock. */
val nextNumber: Long

/** Attempt to lock the next lock. */
def lockNext(): LockResult

/** Points to a position on the lock chain, whose lock up until this point has been acquired, but no further.
*/
type LockMarker = PartialLock | Locked.type

/** A lock required by a listener to be acquired before accepting values. Should there be multiple listeners that
* needs to be locked at the same time, they should be locked by larger-number-first.
*
Expand All @@ -111,47 +87,16 @@ object Listener:
*/
val selfNumber: Long

/** Attempt to lock the current [[ListenerLock]]. To try to lock all possible nesting locks, see
* [[Listener.lockCompletely]]. Locks are guaranteed to be held as short as possible.
/** Attempt to lock the current [[ListenerLock]]. Locks are guaranteed to be held as short as possible.
*/
def lockSelf(source: Async.Source[?]): LockResult
def acquire(): Boolean

/** Release the current lock without resolving the listener with any items, if the current listener lock is before
* or the same as the current [[Listener.LockMarker]]. Returns the next lock to be released, `null` if there are no
* more.
* or the same as the current [[Listener.LockMarker]].
*/
protected def release(to: Listener.LockMarker): ListenerLock | Null

/** Attempt to release all locks up to and including the given [[Listener.LockMarker]]. */
@tailrec
final def releaseAll(to: Listener.LockMarker): Unit =
val rest = release(to)
if rest != null then rest.releaseAll(to)

/** Attempt to lock all layers of this listener lock. */
private[Listener] final def lockAll(source: Async.Source[?]): Locked.type | Gone.type =
lockSelf(source) match
case Locked => Locked
case Gone => Gone
case inner: PartialLock => lockRecursively(inner)

@tailrec
private final def lockRecursively(l: Listener.PartialLock): Locked.type | Gone.type =
l.lockNext() match
case Locked => Locked
case Gone =>
this.releaseAll(l)
Gone
case inner: PartialLock => lockRecursively(inner)
def release(): Unit
end ListenerLock

/** A special wrapper for [[ListenerLock]] that just passes the source through. */
class ListenerLockWrapper(inner: ListenerLock, src: Async.Source[?]) extends ListenerLock:
val selfNumber: Long = inner.selfNumber
def lockSelf(_src: Async.Source[?]) =
inner.lockSelf(src)
def release(to: LockMarker): ListenerLock | Null = inner

/** Maps the lock of a listener, if it exists. */
inline def withLock[T](listener: Listener[?])(inline body: ListenerLock => T): T | Null =
listener.lock match
Expand All @@ -167,5 +112,6 @@ object Listener:

protected def acquireLock() = lock0.lock()
protected def releaseLock() = lock0.unlock()

object NumberedLock:
private val listenerNumber = java.util.concurrent.atomic.AtomicLong()
4 changes: 2 additions & 2 deletions shared/src/main/scala/async/channels.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ object Channel:
false

private inline def tryComplete(src: CanSend, s: Sender)(r: Reader): s.type | r.type | Unit =
lockBoth(readSource, src)(r, s) match
case Listener.Locked =>
lockBoth(r, s) match
case true =>
Impl.this.complete(src, r, s)
dequeue() // drop completed reader/sender from queue
()
Expand Down
85 changes: 19 additions & 66 deletions shared/src/main/scala/async/listeners/locking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
package gears.async.listeners

import gears.async._
import Listener.{Locked, ListenerLock, Gone, PartialLock, LockMarker, LockResult}
import Listener.ListenerLock
import scala.annotation.tailrec

/** Two listeners being locked at the same time, while holding the same lock on their listener chains. This happens if
* you attempt to lockBoth two listeners with a common downstream listener, e.g., two derived listeners of the same
* race.
*/
case class ConflictingLocksException(
base: (Listener[?], Listener[?]),
conflict: ((ListenerLock | PartialLock), (ListenerLock | PartialLock))
listeners: (Listener[?], Listener[?])
) extends Exception

/** Attempt to lock both listeners belonging to possibly different sources at the same time. Lock orders are respected
Expand All @@ -22,71 +21,25 @@ case class ConflictingLocksException(
* In the case that two locks sharing the same number is encountered, [[ConflictingLocksException]] is thrown with the
* base listeners and conflicting listeners.
*/
def lockBoth[T, U](st: Async.Source[T], su: Async.Source[U])(
def lockBoth[T, U](
lt: Listener[T],
lu: Listener[U]
): lt.type | lu.type | Locked.type =
/* Step 1: weed out non-locking listeners */
inline def lockedOr[V >: Locked.type](cause: lt.type | lu.type)(inline body: V) =
if body == Locked then Locked else cause
val tlt = lt.lock match
case tl: ListenerLock => tl
case null => return lockedOr(lu) { lu.lockCompletely(su) }
val tlu = lu.lock match
case tl: ListenerLock => tl
case null => return lockedOr(lt) { lt.lockCompletely(st) }
): lt.type | lu.type | true =
val lockT = if lt.lock == null then return (if lu.acquireLock() then true else lu) else lt.lock
val lockU = if lu.lock == null then return (if lt.acquireLock() then true else lt) else lu.lock

/* Attempts to advance locking one by one. */
@tailrec
def loop(mt: LockMarker, mu: LockMarker): lt.type | lu.type | Locked.type =
inline def advanceSu(su: PartialLock): lt.type | lu.type | Locked.type = su.lockNext() match
case Gone => { tlt.releaseAll(mt); tlu.releaseAll(mu); lu }
case v: LockMarker => loop(mt, v)
(mt, mu) match
case (Locked, Locked) => Locked
case (Locked, su: PartialLock) => advanceSu(su)
case (st: PartialLock, su: PartialLock) if st.nextNumber == su.nextNumber =>
tlt.releaseAll(mt); tlu.releaseAll(mu)
throw ConflictingLocksException((lt, lu), (st, su))
case (st: PartialLock, su: PartialLock) if st.nextNumber < su.nextNumber => advanceSu(su)
case (st: PartialLock, _) =>
st.lockNext() match
case Gone => { tlt.releaseAll(mt); tlu.releaseAll(mu); lt }
case v: LockMarker => loop(v, mu)
inline def doLock[T, U](lt: Listener[T], lu: Listener[U])(
lockT: ListenerLock,
lockU: ListenerLock
): lt.type | lu.type | true =
// assert(lockT.number > lockU.number)
if !lockT.acquire() then lt
else if !lockU.acquire() then
lockT.release()
lu
else true

/* Attempt to lock the ListenerLock and advance until we start needing to lock the other one. */
inline def lockUntilLessThan(other: ListenerLock)(src: Async.Source[?], tl: ListenerLock): LockResult =
@tailrec def loop(v: LockMarker): LockResult =
v match
case Locked => Locked
case v: PartialLock if v.nextNumber == other.selfNumber =>
tl.releaseAll(v)
throw ConflictingLocksException((lt, lu), if lt.lock == other then (other, v) else (v, other))
case v: PartialLock if v.nextNumber < other.selfNumber => v
case v: PartialLock =>
v.lockNext() match
case Gone => tl.releaseAll(v); Gone
case m: LockMarker => loop(m)
tl.lockSelf(src) match
case Gone => Gone
case m: LockMarker => loop(m)

/* We have to do the first locking step manually. */
if tlt.selfNumber == tlu.selfNumber then throw ConflictingLocksException((lt, lu), (tlt, tlu))
else if tlt.selfNumber > tlu.selfNumber then
val mt = lockUntilLessThan(tlu)(st, tlt) match
case Gone => return lt
case v: LockMarker => v
val mu = tlu.lockSelf(su) match
case Gone => { tlt.releaseAll(mt); return lu }
case v: LockMarker => v
loop(mt, mu)
else
val mu = lockUntilLessThan(tlt)(su, tlu) match
case Gone => return lu
case v: LockMarker => v
val mt = tlt.lockSelf(st) match
case Gone => { tlu.releaseAll(mu); return lt }
case v: LockMarker => v
loop(mt, mu)
if lockT.selfNumber == lockU.selfNumber then throw ConflictingLocksException((lt, lu))
else if lockT.selfNumber > lockU.selfNumber then doLock(lt, lu)(lockT, lockU)
else doLock(lu, lt)(lockU, lockT)
end lockBoth
Loading