Skip to content

Commit

Permalink
Introduce Retry as an alternative to Tasks (#48)
Browse files Browse the repository at this point in the history
* Add a `Retry` module that performs operations with delays

* Add documentation

* Use a proper power for calculating exponential backoff, pass last delay always

Tests are included

* Move withTimeout and friends to AsyncOperations

* Move implementation of Retry to one without boxing Try

* Use FiniteDurations in tests too
  • Loading branch information
natsukagami committed Mar 11, 2024
1 parent f66a696 commit dc313af
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 0 deletions.
23 changes: 23 additions & 0 deletions shared/src/main/scala/async/AsyncOperations.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package gears.async

import scala.concurrent.duration.FiniteDuration
import java.util.concurrent.TimeoutException
import gears.async.AsyncOperations.sleep

trait AsyncOperations:
/** Suspends the current [[Async]] context for at least [[millis]] milliseconds. */
Expand All @@ -20,3 +22,24 @@ object AsyncOperations:
*/
inline def sleep(duration: FiniteDuration)(using AsyncOperations, Async): Unit =
sleep(duration.toMillis)

/** Runs [[op]] with a timeout. When the timeout occurs, [[op]] is cancelled through the given [[Async]] context, and
* [[TimeoutException]] is thrown.
*/
def withTimeout[T](timeout: FiniteDuration)(op: Async ?=> T)(using AsyncOperations, Async): T =
Async.group:
Async.select(
Future(op).handle(_.get),
Future(sleep(timeout)).handle: _ =>
throw TimeoutException()
)

/** Runs [[op]] with a timeout. When the timeout occurs, [[op]] is cancelled through the given [[Async]] context, and
* [[None]] is returned.
*/
def withTimeoutOption[T](timeout: FiniteDuration)(op: Async ?=> T)(using AsyncOperations, Async): Option[T] =
Async.group:
Async.select(
Future(op).handle(v => Some(v.get)),
Future(sleep(timeout)).handle(_ => None)
)
154 changes: 154 additions & 0 deletions shared/src/main/scala/async/retry.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package gears.async

import scala.util.Random
import scala.util.{Try, Success, Failure}
import scala.util.boundary
import scala.util.control.NonFatal
import scala.concurrent.duration._

import gears.async.Async
import gears.async.AsyncOperations.sleep
import gears.async.Retry.Delay

/** Utility class to perform asynchronous actions with retrying policies on exceptions.
*
* See [[Retry]] companion object for common policies as a starting point.
*/
case class Retry(
retryOnSuccess: Boolean = false,
maximumFailures: Option[Int] = None,
delay: Delay = Delay.none
):
/** Runs [[body]] with the current policy in its own scope, returning the result or the last failure as an exception.
*/
def apply[T](op: => T)(using Async, AsyncOperations): T =
var failures = 0
var lastDelay: FiniteDuration = 0.second
boundary:
while true do
try
val value = op
if retryOnSuccess then
failures = 0
lastDelay = delay.delayFor(failures, lastDelay)
sleep(lastDelay)
else boundary.break(value)
catch
case b: boundary.Break[?] => throw b // handle this manually as it will be otherwise caught by NonFatal
case NonFatal(exception) =>
if maximumFailures.exists(_ == failures) then // maximum failure count reached
throw exception
else
failures = failures + 1
lastDelay = delay.delayFor(failures, lastDelay)
sleep(lastDelay)
end while
???

/** Set the maximum failure count. */
def withMaximumFailures(max: Int) =
assert(max >= 0)
this.copy(maximumFailures = Some(max))

/** Set the delay policy between runs. See [[Delay]]. */
def withDelay(delay: Delay) = this.copy(delay = delay)

object Retry:
/** Ignores the result and attempt the action in an infinite loop. [[Retry.withMaximumFailures]] can be useful for
* bailing on multiple failures. [[scala.util.boundary]] can be used for manually breaking.
*/
val forever = Retry(retryOnSuccess = true)

/** Returns the result, or attempt to retry if an exception is raised. */
val untilSuccess = Retry(retryOnSuccess = false)

/** Attempt to retry the operation *until* an exception is raised. In this mode, [[Retry]] always throws an exception
* on return.
*/
val untilFailure = Retry(retryOnSuccess = true).withMaximumFailures(0)

/** Defines a delay policy based on the number of successive failures and the duration of the last delay. See
* [[Delay]] companion object for some provided delay policies.
*/
trait Delay:
/** Return the expected duration to delay the next attempt from the current attempt.
*
* @param failuresCount
* The number of successive failures until the current attempt. Note that if the last attempt was a success,
* [[failuresCount]] is `0`.
* @param lastDelay
* The duration of the last delay.
*/
def delayFor(failuresCount: Int, lastDelay: FiniteDuration): FiniteDuration

object Delay:
/** No delays. */
val none = constant(0.second)

/** A fixed amount of delays, whether the last attempt was a success or failure. */
def constant(duration: FiniteDuration) = new Delay:
def delayFor(failuresCount: Int, lastDelay: FiniteDuration): FiniteDuration = duration

/** Returns a delay policy for exponential backoff.
* @param maximum
* The maximum duration possible for a delay.
* @param starting
* The delay duration between successful attempts, and after the first failures.
* @param multiplier
* Scale the delay duration by this multiplier for each successive failure. Defaults to `2`.
* @param jitter
* An additional jitter to randomize the delay duration. Defaults to none. See [[Jitter]].
*/
def backoff(maximum: Duration, starting: FiniteDuration, multiplier: Double = 2, jitter: Jitter = Jitter.none) =
new Delay:
def delayFor(failuresCount: Int, lastDelay: FiniteDuration): FiniteDuration =
val sleep = jitter
.jitterDelay(
lastDelay,
if failuresCount <= 1 then starting
else (starting.toMillis * scala.math.pow(multiplier, failuresCount - 1)).millis
)
maximum match
case max: FiniteDuration => sleep.min(max)
case _ => sleep /* infinite maximum */

/** Decorrelated exponential backoff: randomize between the last delay duration and a multiple of that duration. */
def deccorelated(maximum: Duration, starting: Duration, multiplier: Double = 3) =
new Delay:
def delayFor(failuresCount: Int, lastDelay: FiniteDuration): FiniteDuration =
val lowerBound =
if failuresCount <= 1 then 0.second else lastDelay
val upperBound =
(if failuresCount <= 1 then starting
else multiplier * lastDelay).min(maximum)
Random.between(lowerBound.toMillis, upperBound.toMillis + 1).millis

/** A randomizer for the delay duration, to avoid accidental coordinated DoS on failures. See [[Jitter]] companion
* objects for some provided jitter implementations.
*/
trait Jitter:
/** Returns the *actual* duration to delay between attempts, given the theoretical given delay and actual last delay
* duration, possibly with some randomization.
* @param last
* The last delay duration performed, with jitter applied.
* @param maximum
* The theoretical amount of delay governed by the [[Delay]] policy, serving as an upper bound.
*/
def jitterDelay(last: FiniteDuration, maximum: FiniteDuration): FiniteDuration

object Jitter:
import FiniteDuration as Duration

/** No jitter, always return the exact duration suggested by the policy. */
val none = new Jitter:
def jitterDelay(last: Duration, maximum: Duration): Duration = maximum

/** Full jitter: randomize between 0 and the suggested delay duration. */
val full = new Jitter:
def jitterDelay(last: Duration, maximum: Duration): Duration = Random.between(0, maximum.toMillis + 1).millis

/** Equal jitter: randomize between the last delay duration and the suggested delay duration. */
val equal = new Jitter:
def jitterDelay(last: Duration, maximum: Duration): Duration =
val base = maximum.toMillis / 2
(base + Random.between(0, maximum.toMillis - base + 1)).millis
99 changes: 99 additions & 0 deletions shared/src/test/scala/RetryBehavior.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import gears.async.{Async, Future, Task, TaskSchedule, Retry}
import Retry.Delay
import scala.concurrent.duration.*
import FiniteDuration as Duration
import gears.async.default.given
import Future.{*:, zip}

import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success, Try}
import scala.util.Random

class RetryBehavior extends munit.FunSuite {
test("Exponential backoff(2) 50ms, 5 times total schedule"):
val start = System.currentTimeMillis()
Async.blocking:
var i = 0
Retry.untilSuccess.withDelay(Delay.backoff(1.second, 50.millis)):
i += 1
if i < 5 then throw Exception("try again!")
val end = System.currentTimeMillis()
assert(end - start >= 50 + 100 + 200 + 400)
assert(end - start < 50 + 100 + 200 + 400 + 800)

test("UntilSuccess 150ms"):
val start = System.currentTimeMillis()
Async.blocking:
var i = 0
val ret = Retry.untilSuccess.withDelay(Delay.constant(150.millis)):
if (i < 4) then
i += 1
throw AssertionError()
else i
assertEquals(ret, 4)
val end = System.currentTimeMillis()
assert(end - start >= 4 * 150)
assert(end - start < 5 * 150)

test("UntilFailure 150ms") {
val start = System.currentTimeMillis()
val ex = AssertionError()
Async.blocking:
var i = 0
val ret = Try(Retry.untilFailure.withDelay(Delay.constant(150.millis)):
if (i < 4) then
i += 1
i
else throw ex
)
assertEquals(ret, Failure(ex))
val end = System.currentTimeMillis()
assert(end - start >= 4 * 150)
assert(end - start < 5 * 150)
}

test("delay policies") {
// start with wave1.length of failures, one success, and then wave2.length of failures
def expectDurations(policy: Delay, wave1: Seq[Duration], success: Duration, wave2: Seq[Duration]) =
var lastDelay: Duration = 0.second
for (len, i) <- wave1.iterator.zipWithIndex do
assertEquals(policy.delayFor(i + 1, lastDelay), len, clue = s"$policy $len $i")
lastDelay = len
assertEquals(policy.delayFor(0, lastDelay), success)
lastDelay = success
for (len, i) <- wave2.iterator.zipWithIndex do
assertEquals(policy.delayFor(i + 1, lastDelay), len)
lastDelay = len

expectDurations(
Delay.none,
Seq(0.second, 0.second),
0.second,
Seq(0.second, 0.second)
)
expectDurations(
Delay.constant(1.second),
Seq(1.second, 1.second),
1.second,
Seq(1.second, 1.second)
)

expectDurations(
Delay.backoff(1.minute, 1.second, multiplier = 5),
Seq(1.second, 5.seconds, 25.seconds, 1.minute),
1.second,
Seq(1.second, 5.seconds, 25.seconds, 1.minute)
)

val decor = Delay.deccorelated(1.minute, 1.second, multiplier = 5)
def decorLoop(i: Int, last: Duration, max: Duration): Unit =
if last == max then assertEquals(decor.delayFor(i, max), max)
else
val delay = decor.delayFor(i, last)
if i > 1 then assert(last <= delay)
assert(delay <= max)
decorLoop(i + 1, delay, max)
decorLoop(1, 0.second, 1.minute)
decorLoop(0, 5.second, 1.minute)
}
}

0 comments on commit dc313af

Please sign in to comment.