Skip to content

Commit

Permalink
Fix race condition in FlatMapObservable completion handler
Browse files Browse the repository at this point in the history
  • Loading branch information
tkroman authored and rozza committed Jul 29, 2021
1 parent 05fdd7e commit f5c00d7
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 28 deletions.
Expand Up @@ -18,72 +18,99 @@ package org.mongodb.scala.internal

import org.mongodb.scala._

import java.util.concurrent.atomic.AtomicReference

sealed trait State
case object Init extends State
case class WaitingOnChild(s: Subscription) extends State
case object LastChildNotified extends State
case object LastChildResponded extends State
case object Done extends State
case object Error extends State

private[scala] case class FlatMapObservable[T, S](observable: Observable[T], f: T => Observable[S])
extends Observable[S] {

// scalastyle:off cyclomatic.complexity method.length
override def subscribe(observer: Observer[_ >: S]): Unit = {
observable.subscribe(
SubscriptionCheckingObserver(
new Observer[T] {

@volatile
private var outerSubscription: Option[Subscription] = None
@volatile
private var nestedSubscription: Option[Subscription] = None
@volatile
private var demand: Long = 0
@volatile
private var onCompleteCalled: Boolean = false
@volatile private var outerSubscription: Option[Subscription] = None
@volatile private var demand: Long = 0
private val state = new AtomicReference[State](Init)

override def onSubscribe(subscription: Subscription): Unit = {
val masterSub = new Subscription() {
override def isUnsubscribed: Boolean = subscription.isUnsubscribed

def request(n: Long): Unit = {
override def unsubscribe(): Unit = subscription.unsubscribe()
override def request(n: Long): Unit = {
require(n > 0L, s"Number requested must be greater than zero: $n")
val localDemand = addDemand(n)
val (sub, num) = nestedSubscription.map((_, localDemand)).getOrElse((subscription, 1L))
sub.request(num)
state.get() match {
case Init => subscription.request(1L)
case WaitingOnChild(s) => s.request(localDemand)
case _ => // noop
}
}

override def unsubscribe(): Unit = subscription.unsubscribe()
}

outerSubscription = Some(masterSub)
state.set(Init)
observer.onSubscribe(masterSub)
}

override def onComplete(): Unit = {
if (!onCompleteCalled) {
onCompleteCalled = true
if (nestedSubscription.isEmpty) observer.onComplete()
state.get() match {
case Done => // ok
case Error => // ok
case Init if state.compareAndSet(Init, Done) =>
observer.onComplete()
case w @ WaitingOnChild(_) if state.compareAndSet(w, LastChildNotified) =>
// letting the child know that we delegate onComplete call to it
case LastChildNotified =>
// wait for the child to do the delegated onCompleteCall
case LastChildResponded if state.compareAndSet(LastChildResponded, Done) =>
observer.onComplete()
case other =>
// state machine is broken, let's fail
// normally this won't happen
throw new IllegalStateException(s"Unexpected state in FlatMapObservable `onComplete` handler: ${other}")
}
}

override def onError(throwable: Throwable): Unit = observer.onError(throwable)
override def onError(throwable: Throwable): Unit = {
observer.onError(throwable)
}

override def onNext(tResult: T): Unit = {
f(tResult).subscribe(
new Observer[S]() {
override def onError(throwable: Throwable): Unit = {
nestedSubscription = None
state.set(Error)
observer.onError(throwable)
}

override def onSubscribe(subscription: Subscription): Unit = {
nestedSubscription = Some(subscription)
state.set(WaitingOnChild(subscription))
if (demand > 0) subscription.request(demand)
}

override def onComplete(): Unit = {
nestedSubscription = None
onCompleteCalled match {
case true => observer.onComplete()
case false if demand > 0 =>
state.get() match {
case Done => // no need to call parent's onComplete
case Error => // no need to call parent's onComplete
case LastChildNotified if state.compareAndSet(LastChildNotified, LastChildResponded) =>
// parent told us to call onComplete
observer.onComplete()
case _ if demand > 0 =>
// otherwise we are not the last child, let's tell the parent
// it's not dealing with us anymore.
// Init -> * will be handled by possible later items in the stream
state.set(Init)
addDemand(-1) // reduce demand by 1 as it will be incremented by the outerSubscription
outerSubscription.foreach(_.request(1))
case false => // No more demand
case _ =>
// no demand
state.set(Init)
}
}

Expand Down
@@ -0,0 +1,38 @@
package org.mongodb.scala.internal

import org.mongodb.scala.{ BaseSpec, Observable, Observer }
import org.scalatest.concurrent.{ Eventually, Futures }

import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{ Future, Promise }

class FlatMapObservableTest extends BaseSpec with Futures with Eventually {
"FlatMapObservable" should "only complete once" in {
val p = Promise[Unit]()
val completedCounter = new AtomicInteger(0)
Observable(1 to 100)
.flatMap(
x =>
(observer: Observer[_ >: Int]) => {
Future(()).onComplete(_ => {
observer.onNext(x)
observer.onComplete()
})
}
)
.subscribe(
_ => (),
p.failure,
() => {
completedCounter.incrementAndGet()
Thread.sleep(100)
p.trySuccess(())
}
)
eventually(assert(completedCounter.get() == 1, s"${completedCounter.get()}"))
Thread.sleep(200)
assert(completedCounter.get() == 1, s"${completedCounter.get()}")
Thread.sleep(1000)
}
}

0 comments on commit f5c00d7

Please sign in to comment.