Skip to content

Commit

Permalink
#53 - workaround that works with trampoline - safety
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Jan 9, 2024
1 parent 4042c5f commit 15f92ab
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ trait IterateeImplicits {
def apply[A] = {
def go(xs: Iterator[E])(s: StepT[E, F, A]): IterateeT[E, F, A] =
if (xs.isEmpty) s.pointI
else
else {
s mapCont { k =>
val next = xs.next()
k(elInput(next)) >>== go(xs)
}
}

go(iter)
}
Expand Down
116 changes: 79 additions & 37 deletions scales-xml/src/test/scala/scales/utils/IterateeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ import scalaz.EphemeralStream.emptyEphemeralStream
import scalaz.Free.Trampoline
import scalaz.iteratee.{EnumeratorT, IterateeT, StepT}
import scalaz.iteratee.Input.{Element, Empty, Eof, emptyInput, eofInput}
import scalaz.iteratee.Iteratee.{done, enumIterator, foldM, iterateeT => siteratee}
import scalaz.iteratee.Iteratee.{done, empty, enumEofT, enumIterator, eofInput, foldM, iterateeT}
import scalaz.iteratee.StepT.{Cont, Done, scont}
import scalaz.Free._
import scalaz.{Bind, EphemeralStream, Monad}
import scalaz.{Applicative, Bind, EphemeralStream, Monad}
import scalaz.Scalaz._
import scalaz.effect.IO
import scales.utils.IterateeTests.{isDoneT, maxIterations}
import scales.utils.IterateeTests.{enumWeakStreamF, isDoneT, maxIterations}
import scales.utils.iteratee.Eval

import scala.annotation.tailrec

Expand Down Expand Up @@ -41,6 +42,19 @@ object IterateeTests {
}
}

def enumEphemeralStreamF[E, F[_] : Monad](state: EphemeralStream[E] => Unit)(xs: EphemeralStream[E]): EnumeratorT[E, F] = {
import EphemeralStream.##::

new EnumeratorT[E, F] {
def apply[A] = (s: StepT[E, F, A]) =>
xs match {
case h ##:: t => s.mapCont(k => k(scalaz.iteratee.Iteratee.elInput(h)) >>== enumEphemeralStreamF[E, F]({state(t);state})(t).apply[A])
case _ => s.pointI
}

}
}

def enumWeakStream[E, F[_] : Monad](xs: WeakStream[E]): EnumeratorT[E, F] = {
import WeakStream.##::

Expand All @@ -64,11 +78,33 @@ object IterateeTests {
} else
s.pointI
)
}
}


def enumWeakStreamF[E, F[_] : Monad](state: WeakStream[E] => Unit)(xs: WeakStream[E]): EnumeratorT[E, F] = {
import WeakStream.##::

/*xs match {
case h ##:: t => s.mapCont(k => k(scalaz.iteratee.Iteratee.elInput(h)) >>== enumWeakStream[E, F](t).apply[A])
case _ => s.pointI
}*/
new EnumeratorT[E, F] {
def apply[A] = (s: StepT[E, F, A]) =>
s(done = (x,y) => {
val isEof = y.isEof
val isEL = y.isEl
val isEmpty = y.isEmpty
if (xs.isEmpty)
done(x, Eof[E])
else
s.pointI
}
,
cont =
k =>
if (xs.nonEmpty) xs match {
case h ##:: t => s.mapCont(k => k(scalaz.iteratee.Iteratee.elInput(h)) >>== enumWeakStreamF[E, F]({state(t); state})(t).apply[A])
case _ => s.pointI
} else
s.pointI
)
}
}

Expand All @@ -78,7 +114,10 @@ object IterateeTests {
done = (a, y) => {
val (x, cont) = a
test(x)
assertTrue("should have been Empty " + i, y.isEmpty)
if (i == maxIterations)
assertTrue("should have been Eof " + i, y.isEof)
else
assertTrue("should have been Empty " + i, y.isEmpty)
},
cont = _ => fail("was not done " + i)
)
Expand Down Expand Up @@ -123,17 +162,17 @@ class IterateeTest extends junit.framework.TestCase {

def f(i: Int): ResumableIter[Int, Trampoline, EphemeralStream[Int]] = {
def step(i: Int)(s: Input[Int]): ResumableIter[Int, Trampoline, EphemeralStream[Int]] =
siteratee( F.point(
iterateeT( F.point(
s( el = e => {
//println("i "+i+" e "+e)
Done((iTo(i, e), siteratee( F.point( Cont(step(i + 1))) ) ), Input.Empty[Int])
Done((iTo(i, e), iterateeT( F.point( Cont(step(i + 1))) ) ), Input.Empty[Int])
},
empty = Cont(step(i)),
eof = Done((emptyEphemeralStream, siteratee( F.point( Cont(error("Shouldn't call cont on eof")).asInstanceOf[ResumableStep[Int, Trampoline, EphemeralStream[Int]]]) )), Eof[Int])
eof = Done((emptyEphemeralStream, iterateeT( F.point( Cont(error("Shouldn't call cont on eof")).asInstanceOf[ResumableStep[Int, Trampoline, EphemeralStream[Int]]]) )), Eof[Int])
))
)

siteratee( F.point( Cont(step(i)) ))
iterateeT( F.point( Cont(step(i)) ))
}

val enum = (i: Int) => enumToMany(usum[Int, Trampoline])(f(i))
Expand Down Expand Up @@ -231,20 +270,20 @@ class IterateeTest extends junit.framework.TestCase {
// force a restart at magic number 1
val sum: ResumableIter[Long,Trampoline, Long] = {
def step(acc: Long)( s : Input[Long] ) : ResumableIter[Long, Trampoline, Long] =
siteratee(F.point(
iterateeT(F.point(
s( el = e => {
val nacc = acc + e
if (nacc == 25000050001L)
Done((nacc, siteratee( F.point( Cont(step(nacc)) )) ), Empty[Long])
Done((nacc, iterateeT( F.point( Cont(step(nacc)) )) ), Empty[Long])
else
Cont(step(nacc))
},
empty = Cont(step(acc)),
eof = Done((acc, siteratee( F.point( Done(acc, Empty[Long]).asInstanceOf[ResumableStep[Long, Trampoline, Long]] ))), Eof[Long])
eof = Done((acc, iterateeT( F.point( Done(acc, Empty[Long]).asInstanceOf[ResumableStep[Long, Trampoline, Long]] ))), Eof[Long])
)
))

siteratee( F.point( Cont(step(0))) )
iterateeT( F.point( Cont(step(0))) )
}

val p =
Expand Down Expand Up @@ -275,22 +314,22 @@ class IterateeTest extends junit.framework.TestCase {
// force a restart every entry
val echo: ResumableIter[Long,TheF, Long] = {
def step( s : Input[Long] ) : ResumableIter[Long, TheF, Long] =
siteratee( F.point(
iterateeT( F.point(
s( el = e => {
// println("got "+e)
//if (e > 2) {
Done((e, siteratee( F.point( Cont(step) ))), Empty[Long])
Done((e, iterateeT( F.point( Cont(step) ))), Empty[Long])
//} else Cont(step)
//
},
empty = Cont(step),
eof = //Done((0, siteratee( F.point( Done(0, Eof[Long])))), Eof[Long])
eof = //Done((0, iterateeT( F.point( Done(0, Eof[Long])))), Eof[Long])
resumableEOFDone(0)
//Done((0, siteratee( F.point( Cont(step) ))), Eof[Long])
//Done((0, iterateeT( F.point( Cont(step) ))), Eof[Long])
)
))

siteratee( F.point( Cont(step) ) )
iterateeT( F.point( Cont(step) ) )
}

val oitr = enumToMany(echo)(mapTo[Long, TheF, Long] {
Expand Down Expand Up @@ -353,19 +392,20 @@ class IterateeTest extends junit.framework.TestCase {
p run
}


/**
* Normal iters can't maintain state if they return Done, since
* we pass back a new iter as well we can keep state
*/
def testResumableIterFolds(): Unit = {
//val liter = (1 to maxIterations)
val liter = WeakStream.iterTo( 1 to maxIterations iterator )//WeakStream.iTo(1, maxIterations)
//val liter = (1 to maxIterations).toIterator
var liter = WeakStream.iterTo( 1 to maxIterations iterator )//WeakStream.iTo(1, maxIterations)

type F[X] = IO[X]

def enum(i: WeakStream[Int]) =
IterateeTests.enumWeakStream[Int, F](i)
val func = (s : WeakStream[Int]) => {liter = s}

//def enum(i: Iterator[Int]) = iteratorEnumerator[Int, F](i)
def enum(i: WeakStream[Int]) = enumWeakStreamF[Int, F](func)(i)
/*
def enum(i: WeakStream[Int]) =
//enumIndexedSeq2[Int, F](i.toIterator.toIndexedSeq)
Expand Down Expand Up @@ -419,7 +459,7 @@ class IterateeTest extends junit.framework.TestCase {
val (x, cont) = a

assertTrue("should have been EOF", y.isEof)
assertEquals(maxIterations + 1, x)
assertEquals(maxIterations, x)
},
cont = _ => fail("was not done")
)
Expand All @@ -433,27 +473,29 @@ class IterateeTest extends junit.framework.TestCase {
* and another that is done only each three.
*/
def testResumableOnDone():Unit = {
val liter = (1 to maxIterations).iterator

val F = implicitly[ Monad[Trampoline] ]
val counter = runningCount[Int, Trampoline]

val F = implicitly[ Monad[Trampoline] ]
var liter = WeakStream.iterTo((1 to maxIterations).iterator)
val func = (s : WeakStream[Int]) => {liter = s}

def enum(i: WeakStream[Int]) = enumWeakStreamF[Int, Trampoline](func)(i)

def step( list : List[Long])( s : Input[Int] ) : ResumableIter[Int, Trampoline, Long] =
siteratee(F.point(
iterateeT(F.point(
s(el = {e =>
val next = e.longValue :: list
if (next.size == 3)
Done((e, siteratee(F.point( Cont(step(List()))))), Input.Empty[Int])
Done((e, iterateeT(F.point( Cont(step(List()))))), Input.Empty[Int])
else
Cont(step(next))
},
empty = Cont(step(list)),
eof = Done((list.last, siteratee(F.point( Cont(step(List()) )))), Eof[Int])
eof = Done((list.last, iterateeT(F.point( Cont(step(List()) )))), Eof[Int])
)
))

val inThrees = siteratee( F.point( Cont(step(List())) ) )
val inThrees = iterateeT( F.point( Cont(step(List())) ) )

val ionDone = onDone[Int, Trampoline, Long](List(counter, inThrees))

Expand All @@ -474,23 +516,23 @@ class IterateeTest extends junit.framework.TestCase {
cont = _ => fail("was not done "+i))
}

val starter = (ionDone &= iteratorEnumerator(liter)).eval
val starter = (ionDone &= enum(liter)).eval

val p =
for {
_ <- isDone(1, starter)

// check it does not blow up the stack and/or mem
r <- (foldM[Int, Trampoline, ResumableIterList[Int,Trampoline,Long]](starter){ (itr, i) =>
val nitr = (extractCont(itr) &= iteratorEnumerator(liter)).eval
val nitr = (extractCont(itr) &= enum(liter)).eval
Monad[Trampoline].map(nitr.value){
_ =>
isDone(i, nitr)
nitr
}
} &= iteratorEnumerator((2 to maxIterations iterator))) run

res = (extractCont(r) &= iteratorEnumerator(liter)).eval
res = (extractCont(r) &= enum(liter)).eval

step <- res.value
} yield {
Expand Down
6 changes: 5 additions & 1 deletion scales-xml/src/test/scala/scales/utils/WeakStream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ object WeakStream {
if (lower > upper) empty else WeakStream.cons(lower, iTo(lower + 1, upper))

def iterTo[A](iterator: Iterator[A]): WeakStream[A] =
if (iterator.isEmpty) empty else WeakStream.cons(iterator.next(), iterTo(iterator))
if (iterator.isEmpty) empty else {
val next = iterator.next()
val after = iterTo(iterator)
WeakStream.cons(next, after)
}

def empty[A] = new WeakStream[A]{
val empty = true
Expand Down
Loading

0 comments on commit 15f92ab

Please sign in to comment.