diff --git a/driver-scala/src/main/scala/org/mongodb/scala/internal/FlatMapObservable.scala b/driver-scala/src/main/scala/org/mongodb/scala/internal/FlatMapObservable.scala index c8fba7d95d3..75aa5add735 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/internal/FlatMapObservable.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/internal/FlatMapObservable.scala @@ -18,72 +18,99 @@ package org.mongodb.scala.internal import org.mongodb.scala._ +import java.util.concurrent.atomic.AtomicReference + +sealed trait State +case object Init extends State +case class WaitingOnChild(s: Subscription) extends State +case object LastChildNotified extends State +case object LastChildResponded extends State +case object Done extends State +case object Error extends State + private[scala] case class FlatMapObservable[T, S](observable: Observable[T], f: T => Observable[S]) extends Observable[S] { - // scalastyle:off cyclomatic.complexity method.length override def subscribe(observer: Observer[_ >: S]): Unit = { observable.subscribe( SubscriptionCheckingObserver( new Observer[T] { - - @volatile - private var outerSubscription: Option[Subscription] = None - @volatile - private var nestedSubscription: Option[Subscription] = None - @volatile - private var demand: Long = 0 - @volatile - private var onCompleteCalled: Boolean = false + @volatile private var outerSubscription: Option[Subscription] = None + @volatile private var demand: Long = 0 + private val state = new AtomicReference[State](Init) override def onSubscribe(subscription: Subscription): Unit = { val masterSub = new Subscription() { override def isUnsubscribed: Boolean = subscription.isUnsubscribed - - def request(n: Long): Unit = { + override def unsubscribe(): Unit = subscription.unsubscribe() + override def request(n: Long): Unit = { require(n > 0L, s"Number requested must be greater than zero: $n") val localDemand = addDemand(n) - val (sub, num) = nestedSubscription.map((_, localDemand)).getOrElse((subscription, 1L)) - sub.request(num) + state.get() match { + case Init => subscription.request(1L) + case WaitingOnChild(s) => s.request(localDemand) + case _ => // noop + } } - - override def unsubscribe(): Unit = subscription.unsubscribe() } - outerSubscription = Some(masterSub) + state.set(Init) observer.onSubscribe(masterSub) } override def onComplete(): Unit = { - if (!onCompleteCalled) { - onCompleteCalled = true - if (nestedSubscription.isEmpty) observer.onComplete() + state.get() match { + case Done => // ok + case Error => // ok + case Init if state.compareAndSet(Init, Done) => + observer.onComplete() + case w @ WaitingOnChild(_) if state.compareAndSet(w, LastChildNotified) => + // letting the child know that we delegate onComplete call to it + case LastChildNotified => + // wait for the child to do the delegated onCompleteCall + case LastChildResponded if state.compareAndSet(LastChildResponded, Done) => + observer.onComplete() + case other => + // state machine is broken, let's fail + // normally this won't happen + throw new IllegalStateException(s"Unexpected state in FlatMapObservable `onComplete` handler: ${other}") } } - override def onError(throwable: Throwable): Unit = observer.onError(throwable) + override def onError(throwable: Throwable): Unit = { + observer.onError(throwable) + } override def onNext(tResult: T): Unit = { f(tResult).subscribe( new Observer[S]() { override def onError(throwable: Throwable): Unit = { - nestedSubscription = None + state.set(Error) observer.onError(throwable) } override def onSubscribe(subscription: Subscription): Unit = { - nestedSubscription = Some(subscription) + state.set(WaitingOnChild(subscription)) if (demand > 0) subscription.request(demand) } override def onComplete(): Unit = { - nestedSubscription = None - onCompleteCalled match { - case true => observer.onComplete() - case false if demand > 0 => + state.get() match { + case Done => // no need to call parent's onComplete + case Error => // no need to call parent's onComplete + case LastChildNotified if state.compareAndSet(LastChildNotified, LastChildResponded) => + // parent told us to call onComplete + observer.onComplete() + case _ if demand > 0 => + // otherwise we are not the last child, let's tell the parent + // it's not dealing with us anymore. + // Init -> * will be handled by possible later items in the stream + state.set(Init) addDemand(-1) // reduce demand by 1 as it will be incremented by the outerSubscription outerSubscription.foreach(_.request(1)) - case false => // No more demand + case _ => + // no demand + state.set(Init) } } diff --git a/driver-scala/src/test/scala/org/mongodb/scala/internal/FlatMapObservableTest.scala b/driver-scala/src/test/scala/org/mongodb/scala/internal/FlatMapObservableTest.scala new file mode 100644 index 00000000000..32509075475 --- /dev/null +++ b/driver-scala/src/test/scala/org/mongodb/scala/internal/FlatMapObservableTest.scala @@ -0,0 +1,38 @@ +package org.mongodb.scala.internal + +import org.mongodb.scala.{ BaseSpec, Observable, Observer } +import org.scalatest.concurrent.{ Eventually, Futures } + +import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{ Future, Promise } + +class FlatMapObservableTest extends BaseSpec with Futures with Eventually { + "FlatMapObservable" should "only complete once" in { + val p = Promise[Unit]() + val completedCounter = new AtomicInteger(0) + Observable(1 to 100) + .flatMap( + x => + (observer: Observer[_ >: Int]) => { + Future(()).onComplete(_ => { + observer.onNext(x) + observer.onComplete() + }) + } + ) + .subscribe( + _ => (), + p.failure, + () => { + completedCounter.incrementAndGet() + Thread.sleep(100) + p.trySuccess(()) + } + ) + eventually(assert(completedCounter.get() == 1, s"${completedCounter.get()}")) + Thread.sleep(200) + assert(completedCounter.get() == 1, s"${completedCounter.get()}") + Thread.sleep(1000) + } +}