diff --git a/src/main/scala/dev/kovstas/fs2throttler/Throttler.scala b/src/main/scala/dev/kovstas/fs2throttler/Throttler.scala index 0380556..3154317 100644 --- a/src/main/scala/dev/kovstas/fs2throttler/Throttler.scala +++ b/src/main/scala/dev/kovstas/fs2throttler/Throttler.scala @@ -22,11 +22,10 @@ package dev.kovstas.fs2throttler import cats.Applicative +import cats.effect.Temporal import cats.effect.kernel.Clock -import cats.effect.{Ref, Temporal} import cats.implicits._ import fs2.{Pipe, Pull, Stream} - import scala.concurrent.duration._ object Throttler { @@ -88,75 +87,64 @@ object Throttler { mode: ThrottleMode, burst: Long, fnCost: O => F[Long] - ): Pipe[F, O, O] = { + ): Pipe[F, O, O] = { in => + val capacity = if (elements + burst <= 0) Long.MaxValue else elements + burst + val interval = duration.toNanos / capacity def go( s: Stream[F, O], - bucket: Ref[F, (Long, FiniteDuration)], - capacity: Long, - interval: Long + tokens: => Long, + time: => Long ): Pull[F, O, Unit] = { s.pull.uncons1.flatMap { case Some((head, tail)) => - Pull - .eval(for { - cost <- fnCost(head) - now <- Clock[F].monotonic - delay <- bucket.modify { case (tokens, lastUpdate) => - if (interval == 0) { - ((0, now), Duration.Zero) - } else { - val elapsed = (now - lastUpdate).toNanos - val tokensArrived = - if (elapsed >= interval) { - elapsed / interval - } else 0 - - val nextTime = lastUpdate + (tokensArrived * interval).nanos - val available = math.min(tokens + tokensArrived, capacity) - - if (cost <= available) { - ((available - cost, nextTime), Duration.Zero) - } else { - val timePassed = now.toNanos - nextTime.toNanos - val waitingTime = (cost - available) * interval - val delay = (waitingTime - timePassed).nanos - - ((0, now + delay), delay) - } - } + Pull.eval(fnCost(head) product Clock[F].monotonic.map(_.toNanos)).flatMap { case (cost, now) => + val (remainingTokens, nextTime, delay) = { + val elapsed = now - time + + val tokensArrived = + if (elapsed >= interval) { + elapsed / interval + } else 0 + val nextTime = time + tokensArrived * interval + val available = math.min(tokens + tokensArrived, capacity) + + if (cost <= available) { + (available - cost, nextTime, 0L) + } else { + val timePassed = now - nextTime + val waitingTime = (cost - available) * interval + val delay = waitingTime - timePassed + + (0L, now + delay, delay) + } + } + + if (delay == 0) { + Pull.output1(head) >> go(tail, remainingTokens, nextTime) + } else + mode match { + case Enforcing => + go(tail, remainingTokens, nextTime) + case Shaping => + Pull.sleep(delay.nanos) >> Pull.output1(head) >> go(tail, remainingTokens, nextTime) } - continueF = Pull.output1(head) >> go(tail, bucket, capacity, interval) - result <- - if (delay == Duration.Zero) { - Applicative[F].pure(continueF) - } else { - mode match { - case Enforcing => - Applicative[F].pure(go(tail, bucket, capacity, interval)) - case Shaping => - Clock[F].delayBy(Applicative[F].pure(continueF), delay) - } - } - } yield result) - .flatMap(identity) + } case None => Pull.done } } - in => - val capacity = if (elements + burst <= 0) Long.MaxValue else elements + burst - - for { - bucket <- Stream.eval( - Ref.ofEffect( - Clock[F].monotonic.map((capacity, _)) - ) - ) - stream <- go(in, bucket, capacity, duration.toNanos / capacity).stream - } yield stream + if (interval == 0) { + in + } else { + Stream + .eval(Clock[F].monotonic) + .flatMap { time => + go(in, elements, time.toNanos).stream + } + } } diff --git a/src/test/scala/dev/kovstas/fs2throttler/ThrottlerSpec.scala b/src/test/scala/dev/kovstas/fs2throttler/ThrottlerSpec.scala index 4874e24..b6b83bd 100644 --- a/src/test/scala/dev/kovstas/fs2throttler/ThrottlerSpec.scala +++ b/src/test/scala/dev/kovstas/fs2throttler/ThrottlerSpec.scala @@ -121,6 +121,7 @@ class ThrottlerSpec extends munit.FunSuite { .unsafeToFuture()(runtime) ctx.tick() + ctx.advanceAndTick(500.millis) assertEquals(elements.toList, List(0, 1)) ctx.advanceAndTick(2.seconds) assertEquals(elements.toList, List(0, 1, 2, 3, 4))