Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add collectWhile observable #945

Merged
merged 1 commit into from Jul 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -865,6 +865,16 @@ abstract class Observable[+A] extends Serializable { self =>
final def collect[B](pf: PartialFunction[A, B]): Observable[B] =
self.liftByOperator(new CollectOperator(pf))

/** Takes longest prefix of elements that satisfy the given partial function
* and returns a new Observable that emits those elements.
*
* @param pf the function that filters and maps the source
* @return an observable that emits the transformed items by the
* given partial function until it is contained in the function's domain
*/
final def collectWhile[B](pf: PartialFunction[A, B]): Observable[B] =
self.liftByOperator(new CollectWhileOperator(pf))

/** Creates a new observable from the source and another given
* observable, by emitting elements combined in pairs.
*
Expand Down
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2014-2019 by The Monix Project Developers.
* See the project homepage at: https://monix.io
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package monix.reactive.internal.operators

import monix.execution.Ack
import monix.execution.Ack.Stop

import scala.util.control.NonFatal
import monix.reactive.Observable.Operator
import monix.reactive.internal.operators.CollectOperator.{checkFallback, isDefined}
import monix.reactive.observers.Subscriber

import scala.concurrent.Future

private[reactive] final class CollectWhileOperator[-A, +B](pf: PartialFunction[A, B]) extends Operator[A, B] {

def apply(out: Subscriber[B]): Subscriber[A] =
new Subscriber[A] {
implicit val scheduler = out.scheduler
private[this] var isActive = true

def onNext(elem: A): Future[Ack] = {
if (!isActive) Stop
else {
// Protects calls to user code from within an operator
var streamError = true
try {
val next = pf.applyOrElse(elem, checkFallback[B])
if (isDefined(next)) {
streamError = false
out.onNext(next)
} else {
isActive = false
out.onComplete()
Stop
}
} catch {
case NonFatal(ex) if streamError =>
onError(ex)
Stop
}
}
}

def onComplete() =
if (isActive) {
isActive = false
out.onComplete()
}

def onError(ex: Throwable) =
if (isActive) {
isActive = false
out.onError(ex)
}
}
}

private object CollectWhileOperator extends (Any => Any) {
/** In the case a partial function is not defined, return a magic fallback value. */
def checkFallback[B]: Any => B = this.asInstanceOf[Any => B]

/** Indicates whether the result is the magic fallback value. */
def isDefined(result: Any): Boolean = result.asInstanceOf[AnyRef] ne this

/** Always returns `this`, used as the magic fallback value. */
override def apply(elem: Any): Any = this
}
@@ -0,0 +1,135 @@
/*
* Copyright (c) 2014-2019 by The Monix Project Developers.
* See the project homepage at: https://monix.io
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package monix.reactive.internal.operators

import cats.laws._
import cats.laws.discipline._
import monix.execution.Ack
import monix.execution.Ack.Continue
import monix.execution.exceptions.DummyException
import monix.reactive.{Observable, Observer}

import scala.concurrent.duration._
import scala.concurrent.duration.Duration.Zero
import scala.concurrent.{Future, Promise}

object CollectWhileSuite extends BaseOperatorSuite {
def sum(sourceCount: Int): Long =
sourceCount.toLong * (sourceCount + 1) / 2

def count(sourceCount: Int) =
sourceCount

def createObservable(sourceCount: Int) = {
require(sourceCount > 0, "sourceCount should be strictly positive")
Some {
val o =
if (sourceCount == 1)
Observable.range(1, 10).collectWhile { case x if x <= 1 => x } else
Observable.range(1, sourceCount * 2 + 1).collectWhile { case x if x <= sourceCount => x }

Sample(o, count(sourceCount), sum(sourceCount), Zero, Zero)
}
}

def observableInError(sourceCount: Int, ex: Throwable) = {
require(sourceCount > 0, "sourceCount should be strictly positive")
Some {
val ex = DummyException("dummy")
val o = createObservableEndingInError(Observable.range(1, sourceCount + 1), ex).collectWhile {
case x if x <= sourceCount * 2 => x
}

Sample(o, count(sourceCount), sum(sourceCount), Zero, Zero)
}
}

def brokenUserCodeObservable(sourceCount: Int, ex: Throwable) = {
require(sourceCount > 0, "sourceCount should be strictly positive")
Some {
val ex = DummyException("dummy")
val o = Observable.range(1, sourceCount * 2).collectWhile {
case x if x < sourceCount => x
case _ => throw ex
}

Sample(o, count(sourceCount - 1), sum(sourceCount - 1), Zero, Zero)
}
}

override def cancelableObservables(): Seq[Sample] = {
val s = Observable.range(1, 10).delayOnNext(1.second).collectWhile {
case x if x <= 1 => x
}
Seq(Sample(s, 0, 0, 0.seconds, 0.seconds))
}

test("should not call onComplete multiple times for 1 element") { implicit s =>
val p = Promise[Continue.type]()
var wasCompleted = 0

createObservable(1) match {
case ref @ Some(Sample(obs, count, sum, waitForFirst, waitForNext)) =>
var onNextReceived = false

obs.unsafeSubscribeFn(new Observer[Long] {
def onNext(elem: Long): Future[Ack] = { onNextReceived = true; p.future }
def onError(ex: Throwable): Unit = throw new IllegalStateException()
def onComplete(): Unit = wasCompleted += 1
})

s.tick(waitForFirst)
assert(onNextReceived)
p.success(Continue)
s.tick(waitForNext)
assertEquals(wasCompleted, 1)
}
}

test("should only invoke the partial function once per element") { implicit s =>
var invocationCount = 0
var result: Int = 0
var wasCompleted = false
val f: Int => Option[Int] = x => {
invocationCount += 1
if (x % 2 == 0) Some(x) else None
}
val pf = Function.unlift(f)
Observable
.now(2)
.collectWhile(pf)
.unsafeSubscribeFn(new Observer[Int] {
def onNext(elem: Int): Future[Ack] = { result = elem; Continue }
def onError(ex: Throwable): Unit = throw new IllegalStateException()
def onComplete(): Unit = wasCompleted = true
})
s.tick()
assert(wasCompleted)
assertEquals(result, 2)
assertEquals(invocationCount, 1)
}

test("Observable.collectWhile <=> Observable.takeWhile.collect ") { implicit s =>
check2 { (stream: Observable[Option[Int]], f: Int => Int) =>
val result = stream.collectWhile { case Some(x) => f(x) }.toListL
val expected = stream.takeWhile(_.isDefined).collect { case Some(x) => f(x) }.toListL

result <-> expected
}
}
}