Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iterant fold left operators need to handle errors thrown in F[_] context #569

Merged
merged 3 commits into from
Jan 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ package monix.tail.internal

import cats.effect.Sync
import cats.syntax.all._
import monix.execution.misc.NonFatal
import monix.tail.Iterant
import monix.tail.Iterant.{Halt, Last, Next, NextBatch, NextCursor, Suspend}

import scala.collection.mutable
import monix.execution.misc.NonFatal
import scala.runtime.ObjectRef

private[tail] object IterantFoldLeft {
/**
Expand All @@ -31,37 +33,51 @@ private[tail] object IterantFoldLeft {
final def apply[F[_], S, A](source: Iterant[F, A], seed: => S)(op: (S,A) => S)
(implicit F: Sync[F]): F[S] = {

def loop(self: Iterant[F, A], state: S): F[S] = {
def loop(stopRef: ObjectRef[F[Unit]], state: S)(self: Iterant[F, A]): F[S] = {
try self match {
case Next(a, rest, _) =>
case Next(a, rest, stop) =>
stopRef.elem = stop
val newState = op(state, a)
rest.flatMap(loop(_, newState))
case NextCursor(cursor, rest, _) =>
rest.flatMap(loop(stopRef, newState))
case NextCursor(cursor, rest, stop) =>
stopRef.elem = stop
val newState = cursor.foldLeft(state)(op)
rest.flatMap(loop(_, newState))
case NextBatch(gen, rest, _) =>
rest.flatMap(loop(stopRef, newState))
case NextBatch(gen, rest, stop) =>
stopRef.elem = stop
val newState = gen.foldLeft(state)(op)
rest.flatMap(loop(_, newState))
case Suspend(rest, _) =>
rest.flatMap(loop(_, state))
rest.flatMap(loop(stopRef, newState))
case Suspend(rest, stop) =>
stopRef.elem = stop
rest.flatMap(loop(stopRef, state))
case Last(item) =>
F.pure(op(state,item))
case Halt(None) =>
F.pure(state)
case Halt(Some(ex)) =>
stopRef.elem = null.asInstanceOf[F[Unit]]
F.raiseError(ex)
} catch {
case ex if NonFatal(ex) =>
source.earlyStop *> F.raiseError(ex)
F.raiseError(ex)
}
}

F.suspend {
var catchErrors = true
try {
// handle exception in the seed
val init = seed
catchErrors = false
loop(source, init)
// Reference to keep track of latest `earlyStop` value
val stopRef = ObjectRef.create(null.asInstanceOf[F[Unit]])
// Catch-all exceptions, ensuring latest `earlyStop` gets called
F.handleErrorWith(loop(stopRef, init)(source)) { ex =>
stopRef.elem match {
case null => F.raiseError(ex)
case stop => stop *> F.raiseError(ex)
}
}
} catch {
case NonFatal(e) if catchErrors =>
source.earlyStop *> F.raiseError(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ import monix.execution.misc.NonFatal
import monix.tail.Iterant.{Halt, Last, Next, NextBatch, NextCursor, Suspend}
import monix.tail.batches.BatchCursor

import scala.runtime.ObjectRef

private[tail] object IterantFoldWhileLeft {
/** Implementation for `Iterant.foldWhileLeftL`. */
def strict[F[_], A, S](self: Iterant[F, A], seed: => S, f: (S, A) => Either[S, S])
(implicit F: Sync[F]): F[S] = {

def process(state: S, cursor: BatchCursor[A], rest: F[Iterant[F, A]], stop: F[Unit]) = {
def process(stopRef: ObjectRef[F[Unit]])(state: S, cursor: BatchCursor[A], rest: F[Iterant[F, A]], stop: F[Unit]) = {
var hasResult = false
var s = state

Expand All @@ -45,103 +47,136 @@ private[tail] object IterantFoldWhileLeft {
if (hasResult)
stop.map(_ => s)
else
rest.flatMap(loop(s))
rest.flatMap(loop(stopRef, s))
}

def loop(state: S)(self: Iterant[F, A]): F[S] = {
def loop(stopRef: ObjectRef[F[Unit]], state: S)(self: Iterant[F, A]): F[S] = {
try self match {
case Next(a, rest, stop) =>
stopRef.elem = stop
f(state, a) match {
case Left(s) => rest.flatMap(loop(s))
case Left(s) => rest.flatMap(loop(stopRef, s))
case Right(s) => stop.map(_ => s)
}

case NextCursor(cursor, rest, stop) =>
process(state, cursor, rest, stop)
stopRef.elem = stop
process(stopRef)(state, cursor, rest, stop)

case NextBatch(batch, rest, stop) =>
stopRef.elem = stop
val cursor = batch.cursor()
process(state, cursor, rest, stop)
process(stopRef)(state, cursor, rest, stop)

case Suspend(rest, _) =>
rest.flatMap(loop(state))
case Suspend(rest, stop) =>
stopRef.elem = stop
rest.flatMap(loop(stopRef, state))

case Last(a) =>
stopRef.elem = null.asInstanceOf[F[Unit]]
F.pure(f(state, a) match {
case Left(s) => s
case Right(s) => s
})

case Halt(optE) =>
optE match {
case None => F.pure(state)
case Some(e) => F.raiseError(e)
case None =>
F.pure(state)
case Some(e) =>
stopRef.elem = null.asInstanceOf[F[Unit]]
F.raiseError(e)
}
}
catch {
case e if NonFatal(e) =>
self.earlyStop *> F.raiseError(e)
F.raiseError(e)
}
}


F.suspend(loop(seed)(self))
F.suspend {
// Reference to keep track of latest `earlyStop` value
val stopRef = ObjectRef.create(null.asInstanceOf[F[Unit]])
// Catch-all exceptions, ensuring latest `earlyStop` gets called
F.handleErrorWith(loop(stopRef, seed)(self)) { ex =>
stopRef.elem match {
case null => F.raiseError(ex)
case stop => stop *> F.raiseError(ex)
}
}
}
}

/** Implementation for `Iterant.foldWhileLeftEvalL`. */
def eval[F[_], A, S](self: Iterant[F, A], seed: F[S], f: (S, A) => F[Either[S, S]])
(implicit F: Sync[F]): F[S] = {

def process(state: S, stop: F[Unit], rest: F[Iterant[F, A]], a: A): F[S] = {
val fs = f(state, a).handleErrorWith { e =>
stop.flatMap(_ => F.raiseError(e))
}
def process(stopRef: ObjectRef[F[Unit]])(state: S, stop: F[Unit], rest: F[Iterant[F, A]], a: A): F[S] = {
val fs = f(state, a)

fs.flatMap {
case Left(s) => rest.flatMap(loop(s))
case Left(s) => rest.flatMap(loop(stopRef, s))
case Right(s) => stop.map(_ => s)
}
}

def loop(state: S)(self: Iterant[F, A]): F[S] = {
def loop(stopRef: ObjectRef[F[Unit]], state: S)(self: Iterant[F, A]): F[S] = {
try self match {
case Next(a, rest, stop) =>
process(state, stop, rest, a)
stopRef.elem = stop
process(stopRef)(state, stop, rest, a)

case NextCursor(cursor, rest, stop) =>
if (!cursor.hasNext()) rest.flatMap(loop(state)) else {
stopRef.elem = stop
if (!cursor.hasNext()) rest.flatMap(loop(stopRef, state)) else {
val a = cursor.next()
process(state, stop, F.pure(self), a)
process(stopRef)(state, stop, F.pure(self), a)
}

case NextBatch(batch, rest, stop) =>
stopRef.elem = stop
val cursor = batch.cursor()
if (!cursor.hasNext()) rest.flatMap(loop(state)) else {
if (!cursor.hasNext()) rest.flatMap(loop(stopRef, state)) else {
val a = cursor.next()
process(state, stop, F.pure(NextCursor(cursor, rest, stop)), a)
process(stopRef)(state, stop, F.pure(NextCursor(cursor, rest, stop)), a)
}

case Suspend(rest, _) =>
rest.flatMap(loop(state))
case Suspend(rest, stop) =>
stopRef.elem = stop
rest.flatMap(loop(stopRef, state))

case Last(a) =>
stopRef.elem = null.asInstanceOf[F[Unit]]
f(state, a).map {
case Left(s) => s
case Right(s) => s
}

case Halt(optE) =>
optE match {
case None => F.pure(state)
case Some(e) => F.raiseError(e)
case None =>
F.pure(state)
case Some(e) =>
stopRef.elem = null.asInstanceOf[F[Unit]]
F.raiseError(e)
}
}
catch {
case e if NonFatal(e) =>
self.earlyStop *> F.raiseError(e)
F.raiseError(e)
}
}

F.suspend(seed.flatMap(s => loop(s)(self)))
F.suspend {
// Reference to keep track of latest `earlyStop` value
val stopRef = ObjectRef.create(null.asInstanceOf[F[Unit]])
// Catch-all exceptions, ensuring latest `earlyStop` gets called
F.handleErrorWith(seed.flatMap(s => loop(stopRef, s)(self))) { ex =>
stopRef.elem match {
case null => F.raiseError(ex)
case stop => stop *> F.raiseError(ex)
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,18 @@ object IterantFoldLeftSuite extends BaseTestSuite {
i.countL <-> Coeval(list.length)
}
}
}

test("earlyStop gets called for failing `rest` on Next node") { implicit s =>
var effect = 0

def stop(i: Int): Coeval[Unit] = Coeval { effect = i}
val dummy = DummyException("dummy")
val node3 = Iterant[Coeval].nextS(3, Coeval.raiseError(dummy), stop(3))
val node2 = Iterant[Coeval].nextS(2, Coeval(node3), stop(2))
val node1 = Iterant[Coeval].nextS(1, Coeval(node2), stop(1))

assertEquals(node1.toListL.runTry, Failure(dummy))
assertEquals(effect, 3)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ object IterantFoldWhileLeftSuite extends BaseTestSuite {
assertEquals(ref.runTry, Failure(dummy))
assertEquals(effect, 1)
}

test("foldWhileLeftEvalL protects against broken seed") { implicit s =>
var effect = 0
val dummy = DummyException("dummy")
Expand Down Expand Up @@ -304,7 +304,7 @@ object IterantFoldWhileLeftSuite extends BaseTestSuite {
assertEquals(r, Failure(dummy))
assertEquals(effect, 1)
}

test("findL is consistent with List.find") { implicit s =>
check3 { (list: List[Int], idx: Int, p: Int => Boolean) =>
val fa = arbitraryListToIterant[Coeval, Int](list, idx, allowErrors = false)
Expand Down Expand Up @@ -343,4 +343,57 @@ object IterantFoldWhileLeftSuite extends BaseTestSuite {
assertEquals(r, Failure(dummy))
assertEquals(effect, 1)
}
}

test("foldWhileLeft earlyStop gets called for failing `rest` on Next node") { implicit s =>
var effect = 0

def stop(i: Int): Coeval[Unit] = Coeval { effect = i}
val dummy = DummyException("dummy")
val node3 = Iterant[Coeval].nextS(3, Coeval.raiseError(dummy), stop(3))
val node2 = Iterant[Coeval].nextS(2, Coeval(node3), stop(2))
val node1 = Iterant[Coeval].nextS(1, Coeval(node2), stop(1))

assertEquals(node1.foldWhileLeftL(0)((_, _) => Left(0)).runTry, Failure(dummy))
assertEquals(effect, 3)
}

test("foldWhileLeft earlyStop doesn't get called for Last node") { implicit s =>
var effect = 0

def stop(i: Int): Coeval[Unit] = Coeval { effect = i}
val dummy = DummyException("dummy")
val node3 = Iterant[Coeval].lastS(3)
val node2 = Iterant[Coeval].nextS(2, Coeval(node3), stop(2))
val node1 = Iterant[Coeval].nextS(1, Coeval(node2), stop(1))

assertEquals(node1.foldWhileLeftL(0)((_, el) => if (el == 3) throw dummy else Left(0)).runTry, Failure(dummy))
assertEquals(effect, 0)
}

test("foldWhileLeftEvalL earlyStop gets called for failing `rest` on Next node") { implicit s =>
var effect = 0

def stop(i: Int): Coeval[Unit] = Coeval { effect = i}
val dummy = DummyException("dummy")
val node3 = Iterant[Coeval].nextS(3, Coeval.raiseError(dummy), stop(3))
val node2 = Iterant[Coeval].nextS(2, Coeval(node3), stop(2))
val node1 = Iterant[Coeval].nextS(1, Coeval(node2), stop(1))

assertEquals(node1.foldWhileLeftEvalL(Coeval(0))((_, _) => Coeval(Left(0))).runTry, Failure(dummy))
assertEquals(effect, 3)
}

test("foldWhileLeftEvalL earlyStop doesn't get called for Last node") { implicit s =>
var effect = 0

def stop(i: Int): Coeval[Unit] = Coeval { effect = i}
val dummy = DummyException("dummy")
val node3 = Iterant[Coeval].lastS(3)
val node2 = Iterant[Coeval].nextS(2, Coeval(node3), stop(2))
val node1 = Iterant[Coeval].nextS(1, Coeval(node2), stop(1))

assertEquals(node1.foldWhileLeftEvalL(Coeval(0))((_, el) => if (el == 3) throw dummy else Coeval(Left(0))).runTry, Failure(dummy))
assertEquals(effect, 0)
}

}