diff --git a/client/src/main/scala/org/http4s/client/PoolManager.scala b/client/src/main/scala/org/http4s/client/PoolManager.scala index e1b1896bc27..30e4f520203 100644 --- a/client/src/main/scala/org/http4s/client/PoolManager.scala +++ b/client/src/main/scala/org/http4s/client/PoolManager.scala @@ -111,7 +111,7 @@ private final class PoolManager[F[_], A <: Connection[F]]( } private def addToWaitQueue(key: RequestKey, callback: Callback[NextConnection]): Unit = - if (waitQueue.length <= maxWaitQueueLimit) { + if (waitQueue.length < maxWaitQueueLimit) { waitQueue.enqueue(Waiting(key, callback, Instant.now())) } else { logger.error(s"Max wait length reached, not scheduling.") diff --git a/client/src/test/scala/org/http4s/client/PoolManagerSpec.scala b/client/src/test/scala/org/http4s/client/PoolManagerSpec.scala new file mode 100644 index 00000000000..dc07219cf0c --- /dev/null +++ b/client/src/test/scala/org/http4s/client/PoolManagerSpec.scala @@ -0,0 +1,82 @@ +package org.http4s +package client + +import cats.effect._ +import fs2.Stream +import scala.concurrent.duration._ + +class PoolManagerSpec(name: String) extends Http4sSpec { + val key = RequestKey(Uri.Scheme.http, Uri.Authority(host = Uri.IPv4("127.0.0.1"))) + class TestConnection extends Connection[IO] { + def runRequest(req: Request[IO]) = IO.never + def isClosed = false + def isRecyclable = true + def requestKey = key + def shutdown() = () + } + + def mkPool( + maxTotal: Int = 1, + maxWaitQueueLimit: Int = 2 + ) = + IO( + ConnectionManager.pool( + builder = _ => IO(new TestConnection()), + maxTotal = maxTotal, + maxWaitQueueLimit = maxWaitQueueLimit, + maxConnectionsPerRequestKey = _ => 5, + responseHeaderTimeout = Duration.Inf, + requestTimeout = Duration.Inf, + executionContext = testExecutionContext + )) + + "A pool manager" should { + "wait up to maxWaitQueueLimit" in { + (for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 2) + _ <- pool.borrow(key) + att <- Stream(Stream.eval(pool.borrow(key))).repeat + .take(2) + .covary[IO] + .joinUnbounded + .compile + .toList + .attempt + } yield att).unsafeRunTimed(2.seconds) must_== None + } + + "throw at maxWaitQueueLimit" in { + (for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 2) + _ <- pool.borrow(key) + att <- Stream(Stream.eval(pool.borrow(key))).repeat + .take(3) + .covary[IO] + .joinUnbounded + .compile + .toList + .attempt + } yield att).unsafeRunTimed(2.seconds) must_== Some(Left(WaitQueueFullFailure())) + } + + "wake up a waiting connection on release" in { + (for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 1) + conn <- pool.borrow(key) + fiber <- pool.borrow(key).start // Should be one waiting + _ <- pool.release(conn.connection) + _ <- fiber.join + } yield ()).unsafeRunTimed(2.seconds) must_== Some(()) + } + + "wake up a waiting connection on invalidate" in { + (for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 1) + conn <- pool.borrow(key) + fiber <- pool.borrow(key).start // Should be one waiting + _ <- pool.invalidate(conn.connection) + _ <- fiber.join + } yield ()).unsafeRunTimed(2.seconds) must_== Some(()) + } + } +}