Skip to content

Commit

Permalink
fix: deadlock in Monitor, pass 2
Browse files Browse the repository at this point in the history
  • Loading branch information
oleg-py committed May 8, 2019
1 parent 77a85ae commit aeb146c
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 62 deletions.
8 changes: 8 additions & 0 deletions js/src/test/scala/com/olegpy/stm/SingleThreadECImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.olegpy.stm

import scala.concurrent.ExecutionContext


trait SingleThreadECImpl extends SingleThreadEC {
override val singleThread: ExecutionContext = ExecutionContext.global
}
10 changes: 10 additions & 0 deletions jvm/src/test/scala/com/olegpy/stm/SingleThreadECImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.olegpy.stm

import scala.concurrent.ExecutionContext

import java.util.concurrent.Executors


trait SingleThreadECImpl extends SingleThreadEC {
override val singleThread: ExecutionContext = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor())
}
81 changes: 36 additions & 45 deletions shared/src/main/scala/com/olegpy/stm/internal/Monitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@ class Monitor private[stm] () {
private[this] val store: Store = /*_*/Store.forPlatform()/*_*/
private[this] val rightUnit = Right(())

private[this] class RetryCallback (catsCb: Callback, keys: Iterable[AnyRef]) {
keys.foreach(addToSet(_, this))
private[this] class RetryCallback(keys: Iterable[AnyRef]) {
@volatile var catsCb: Callback = _
keys.foreach(listenTo)
def invoke(): Unit = {
catsCb(rightUnit)
}

def listenTo(key: AnyRef): Unit = {
addToSet(key, this)
addToSet(this, key)
}

private[this] def addToSet(key: AnyRef, value: Any): Unit = {
val j = store.current()
j.update(key, j.read(key) match {
Expand All @@ -34,66 +40,51 @@ class Monitor private[stm] () {
}

def removeAllKeys(): Unit = {
keys.foreach(removeFromSet(_, this))
}
}

def waitOn[F[_]](keys: Iterable[AnyRef])(implicit F: Concurrent[F]): F[Unit] = Concurrent.cancelableF[F, Unit] { cb =>
store.transact {
val ks = keys.toSet
store.current().read(this) match {
case l: List[Set[AnyRef] @unchecked] if l.exists(_.exists(ks)) =>
null
val j = store.current()
j.read(this) match {
case set: Set[AnyRef @unchecked] => set.foreach(removeFromSet(_, this))
// This might not be hit in a single test run, avoid fluctuating coverage
// $COVERAGE-OFF$
case _ =>
val retryCallback = new RetryCallback(cb, keys)
F.delay { store.transact(retryCallback.removeAllKeys()) }
// $COVERAGE-ON$
}
} match {
case null =>
F.delay { cb(rightUnit) }.start.map(_.cancel)
case token: F[Unit] @unchecked => F.pure(token)
j.update(this, null)
}
}

def notifyOn[F[_]](keys: Iterable[AnyRef])(implicit F: Concurrent[F]): F[Unit] = F.suspend {
val ks = keys.toSet
store.transact(store.current().update(this, 0))

def unregister(): Unit = {
store.transact {
val j = store.current()
j.read(this) match {
case l: List[Set[AnyRef] @unchecked] => j.update(this, l.filterNot(_ == ks))
case _ =>
def lastNotify: Int = store.transact { lastNotifyT }
private def lastNotifyT = store.current().read(this).asInstanceOf[Int]

def waitOn[F[_]](ln: Int, keys: Iterable[AnyRef])(implicit F: Concurrent[F]): F[Unit] =
F.suspend { store.transact {
if (ln == lastNotifyT) {
val rc = new RetryCallback(keys)
F.cancelable[Unit] { cb =>
rc.catsCb = cb
F.delay { store.transact(rc.removeAllKeys()) }
}
}
}
} else F.unit
}}

val jobs = store.transact {
def notifyOn[F[_]](keys: Iterable[AnyRef])(implicit F: Concurrent[F]): F[Unit] = F.suspend {
store.transact {
val j = store.current()
j.update(this, j.read(this) match {
case l: List[Set[AnyRef] @unchecked] => ks :: l
case _ => ks :: Nil
})
j.update(this, lastNotifyT + 1)
val cbs = Set.newBuilder[RetryCallback]
keys.foreach { key =>
j.read(key) match {
case s: Set[RetryCallback @unchecked] => cbs ++= s
case _ =>
}
}

val allJobs = cbs.result()
allJobs.foreach(_.removeAllKeys())
allJobs
}

if (jobs.isEmpty) {
unregister()
F.unit
if (allJobs.isEmpty) F.unit
else {
allJobs.foreach(_.removeAllKeys())
F.delay(allJobs.foreach(_.invoke())).start.void
}
}
else F.delay {
unregister()
jobs.foreach(_.invoke())
}.start.void
}
}
3 changes: 2 additions & 1 deletion shared/src/main/scala/com/olegpy/stm/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ package object stm {

private[this] def atomicallyImpl[F[_]: Concurrent, A](stm: STM[A]): F[A] =
Concurrent[F].suspend {
val waitId = globalLock.lastNotify
var journal: Store.Journal = null
try {
val result = store.transact {
Expand All @@ -108,7 +109,7 @@ package object stm {
} catch { case Retry =>
val rk = journal.readKeys
if (rk.isEmpty) throw new PotentialDeadlockException
globalLock.waitOn[F](rk) >> atomicallyImpl[F, A](stm)
globalLock.waitOn[F](waitId, rk) >> atomicallyImpl[F, A](stm)
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions shared/src/test/scala/com/olegpy/stm/BaseIOSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import cats.implicits._
import cats.effect.{ContextShift, IO, Timer}
import utest._

trait BaseIOSuite extends TestSuite {
def ec: ExecutionContext = ExecutionContext.global//.fromExecutor(Executors.newSingleThreadExecutor())
implicit val cs: ContextShift[IO] = IO.contextShift(ec)
implicit val timer: Timer[IO] = IO.timer(ec)
trait BaseIOSuite extends TestSuite with SingleThreadEC with SingleThreadECImpl {
def ec: ExecutionContext = ExecutionContext.global
implicit def cs: ContextShift[IO] = IO.contextShift(ec)
implicit def timer: Timer[IO] = IO.timer(ec)

val number = 42

Expand All @@ -26,9 +26,9 @@ trait BaseIOSuite extends TestSuite {
def ioTestTimed[A](timeout: FiniteDuration)(io: IO[A]): Future[A] =
io.timeout(timeout).unsafeToFuture()

def nap: IO[Unit] = IO.sleep(10.millis)
def nap(implicit t: Timer[IO]): IO[Unit] = t.sleep(10.millis)

def longNap: IO[Unit] = nap.replicateA(10).void
def longNap(implicit t: Timer[IO]): IO[Unit] = nap(t).replicateA(10).void

def fail[A]: IO[A] = IO.suspend {
assert(false)
Expand Down
23 changes: 16 additions & 7 deletions shared/src/test/scala/com/olegpy/stm/RetryTests.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.olegpy.stm

import cats.effect.{ExitCase, IO, SyncIO}
import cats.effect.{ContextShift, ExitCase, IO, SyncIO, Timer}
import utest._
import cats.implicits._

Expand Down Expand Up @@ -70,6 +70,9 @@ object RetryTests extends TestSuite with BaseIOSuite {
}

"retries are not triggered by writes to independent variables" - {
implicit val cs: ContextShift[IO] = IO.contextShift(singleThread)
implicit val timer: Timer[IO] = IO.timer(singleThread)

@volatile var count = 0
val r1, r2, r3 = TRef.in[SyncIO](0).unsafeRunSync()
val txn: STM[Unit] = for {
Expand All @@ -80,17 +83,23 @@ object RetryTests extends TestSuite with BaseIOSuite {
_ <- r3.get // after-check gets should not affect anything
} yield ()

def later(block: => Unit): IO[Unit] = nap >> nap >> IO(block)
val isJS = ().toString != "()"

// Use lax checking for JVM, where CPU black magic is more prominent
def later(expect: Int): IO[Unit] = nap(timer) >> {
if (isJS) IO(assert(count == expect))
else IO(assert((expect - 2).to(expect + 2) contains count))
}

for {
f <- txn.commit[IO].start
_ <- later { count ==> 1 } // Tried once, but failed
f <- txn.commit[IO].start(cs)
_ <- later(1) // Tried once, but failed
_ <- r1.set(number).commit[IO]
_ <- later { count ==> 2 } // Tried twice, as we modified r1
_ <- later(2) // Tried twice, as we modified r1
_ <- r3.set(number).commit[IO]
_ <- later { count ==> 2 } // Didn't try again, as we didn't touch r1 or r2
_ <- later(2) // Didn't try again, as we didn't touch r1 or r2
_ <- r2.set(number + 1).commit[IO]
_ <- later { count ==> 3 } // Tried again, and should complete at this point
_ <- later(3) // Tried again, and should complete at this point
_ <- f.join
} yield ()
}
Expand Down
8 changes: 8 additions & 0 deletions shared/src/test/scala/com/olegpy/stm/SingleThreadEC.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.olegpy.stm

import scala.concurrent.ExecutionContext


trait SingleThreadEC {
def singleThread: ExecutionContext = sys.error("Not overriden")
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import cats.implicits._


object CigaretteSmokersProblem extends TestSuite with BaseIOSuite {
override def ioTimeout: FiniteDuration = 10.seconds
override def ioTimeout: FiniteDuration = 3.seconds

val tests = Tests {
"Cigarette smokers problem" - {
Expand Down Expand Up @@ -41,14 +41,14 @@ object CigaretteSmokersProblem extends TestSuite with BaseIOSuite {

class Table(queue: TQueue[Ingredient]) {
def put(ingredient: Ingredient): STM[Unit] = queue.enqueue(ingredient)
def things: STM[Set[Ingredient]] = queue.dequeue.replicateA(2).map(_.toSet)
def takeThings: STM[Set[Ingredient]] = queue.dequeue.replicateA(2).map(_.toSet)
}

def mkTable: IO[Table] = TQueue.boundedIn[IO, Ingredient](2).map(new Table(_))

class Smoker (ingredient: Ingredient, table: Table) {
def buildACig(puff: IO[Unit]): IO[Unit] =
table.things.filterNot(_(ingredient)).commit[IO] >> puff
table.takeThings.filterNot(_ contains ingredient).commit[IO] >> puff
}

class Dealer(table: Table) {
Expand Down

0 comments on commit aeb146c

Please sign in to comment.