From ecb76a8123c79f2b4e1d247b0253393f39990348 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Thu, 30 Nov 2017 17:04:01 +0100 Subject: [PATCH] +htc #1312 new host connection pool implementation A new implementation of the HostConnectionPool. The basic idea is to replace the former complicated streaming pipeline of `PoolFlow`, `PoolConductor`, and `PoolSlot` into a single stage that handles all aspects to get rid of all the small race condition issues that exist in the current ("legacy") pool implementation. The new pool implementation is split into two basic classes * akka.http.impl.engine.client.pool.NewHostConnectionPool that provides all the infrastructure and event handling to drive the pool * akka.http.impl.engine.client.pool.SlotState that contains only the logic to handle state changes of a single pool slot --- .../mima-filters/10.0.10.backwards.excludes | 5 + .../src/main/resources/reference.conf | 10 + .../http/impl/engine/client/PoolFlow.scala | 5 +- .../engine/client/PoolInterfaceActor.scala | 9 +- .../client/pool/NewHostConnectionPool.scala | 409 ++++++++++++ .../impl/engine/client/pool/SlotState.scala | 221 +++++++ .../settings/ConnectionPoolSettingsImpl.scala | 45 +- .../akka/http/impl/util/JavaMapping.scala | 1 + .../settings/ConnectionPoolSettings.scala | 23 +- .../settings/ConnectionPoolSettings.scala | 23 +- .../engine/client/ConnectionPoolSpec.scala | 43 +- .../client/HostConnectionPoolSpec.scala | 615 ++++++++++++++++++ 12 files changed, 1376 insertions(+), 33 deletions(-) create mode 100644 akka-http-core/src/main/scala/akka/http/impl/engine/client/pool/NewHostConnectionPool.scala create mode 100644 akka-http-core/src/main/scala/akka/http/impl/engine/client/pool/SlotState.scala create mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/client/HostConnectionPoolSpec.scala diff --git a/akka-http-core/src/main/mima-filters/10.0.10.backwards.excludes b/akka-http-core/src/main/mima-filters/10.0.10.backwards.excludes index 355134969ea..48c98b06368 100644 --- a/akka-http-core/src/main/mima-filters/10.0.10.backwards.excludes +++ b/akka-http-core/src/main/mima-filters/10.0.10.backwards.excludes @@ -10,3 +10,8 @@ ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.javadsl.settings # New settings in `@DoNotInherit` classes ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.javadsl.settings.ParserSettings.getModeledHeaderParsing") ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.settings.ParserSettings.modeledHeaderParsing") + +# New poolImplementation setting on @DoNotInherit class +ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.settings.ConnectionPoolSettings.poolImplementation") +# New responseEntitySubscriptionTimeout setting +ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.settings.ConnectionPoolSettings.responseEntitySubscriptionTimeout") diff --git a/akka-http-core/src/main/resources/reference.conf b/akka-http-core/src/main/resources/reference.conf index a4a74091b8e..bc340224c1d 100644 --- a/akka-http-core/src/main/resources/reference.conf +++ b/akka-http-core/src/main/resources/reference.conf @@ -308,6 +308,16 @@ akka.http { # will automatically terminate itself. Set to `infinite` to completely disable idle timeouts. idle-timeout = 30 s + # The pool implementation to use. Currently supported are: + # - legacy: the original, still default, pool implementation + # - new: the new still-evolving pool implementation, that will receive fixes and new features + pool-implementation = legacy + + # The "new" pool implementation will fail a connection early and clear the slot if a response entity was not + # subscribed during the given time period after the response was dispatched. In busy systems the timeout might be + # too tight if a response is not picked up quick enough after it was dispatched by the pool. + response-entity-subscription-timeout = 1.second + # Modify this section to tweak client settings only for host connection pools APIs like `Http().superPool` or # `Http().singleRequest`. client = { diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolFlow.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolFlow.scala index a3ef2208351..6c382da2f5b 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolFlow.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolFlow.scala @@ -5,6 +5,7 @@ package akka.http.impl.engine.client import akka.NotUsed +import akka.annotation.InternalApi import akka.http.impl.engine.client.PoolConductor.PoolSlotsSetting import akka.http.scaladsl.settings.ConnectionPoolSettings @@ -16,7 +17,9 @@ import akka.stream.scaladsl._ import akka.http.scaladsl.model._ import akka.http.scaladsl.Http -private object PoolFlow { +/** Internal API */ +@InternalApi +private[client] object PoolFlow { case class RequestContext(request: HttpRequest, responsePromise: Promise[HttpResponse], retriesLeft: Int) { require(retriesLeft >= 0) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolInterfaceActor.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolInterfaceActor.scala index 20069089b42..6285a186232 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolInterfaceActor.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolInterfaceActor.scala @@ -7,9 +7,11 @@ package akka.http.impl.engine.client import akka.actor._ import akka.event.{ LogSource, Logging, LoggingAdapter } import akka.http.impl.engine.client.PoolFlow._ +import akka.http.impl.engine.client.pool.NewHostConnectionPool import akka.http.impl.util.RichHttpRequest import akka.http.scaladsl.model._ import akka.http.scaladsl.Http +import akka.http.scaladsl.settings.PoolImplementation import akka.macros.LogHelper import akka.stream.actor.ActorPublisherMessage._ import akka.stream.actor.ActorSubscriberMessage._ @@ -98,7 +100,12 @@ private class PoolInterfaceActor(gateway: PoolGateway)(implicit fm: Materializer val connectionFlow = Http().outgoingConnectionUsingTransport(host, port, settings.transport, connectionContext, settings.connectionSettings, setup.log) - val poolFlow = PoolFlow(connectionFlow, settings, log).named("PoolFlow") + val poolFlow = + settings.poolImplementation match { + case PoolImplementation.Legacy ⇒ PoolFlow(connectionFlow, settings, log).named("PoolFlow") + case PoolImplementation.New ⇒ NewHostConnectionPool(connectionFlow, settings, log).named("PoolFlow") + } + Source.fromPublisher(ActorPublisher(self)).via(poolFlow).runWith(Sink.fromSubscriber(ActorSubscriber[ResponseContext](self))) } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/pool/NewHostConnectionPool.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/pool/NewHostConnectionPool.scala new file mode 100644 index 00000000000..07fe3bd9b2b --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/pool/NewHostConnectionPool.scala @@ -0,0 +1,409 @@ +/* + * Copyright (C) 2009-2017 Lightbend Inc. + */ + +package akka.http.impl.engine.client.pool + +import java.util + +import akka.NotUsed +import akka.actor.Cancellable +import akka.annotation.InternalApi +import akka.dispatch.ExecutionContexts +import akka.event.LoggingAdapter +import akka.http.impl.engine.client.PoolFlow.{ RequestContext, ResponseContext } +import akka.http.impl.engine.client.pool.SlotState.Unconnected +import akka.http.impl.util.{ StageLoggingWithOverride, StreamUtils } +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.{ HttpEntity, HttpRequest, HttpResponse, headers } +import akka.http.scaladsl.settings.ConnectionPoolSettings +import akka.stream.scaladsl.{ Flow, Keep, Sink, Source } +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } +import akka.stream._ + +import scala.concurrent.Future +import scala.concurrent.duration.{ Duration, FiniteDuration } +import scala.util.{ Failure, Success } + +/** + * Internal API + * + * New host connection pool implementation. + * + * Backpressure logic of the external interface: + * + * * pool pulls if there's a free slot + * * pool buffers responses until they are pulled. The buffer is unlimited in theory with the reasoning that + * reasonable behavior can be expected from downstream consumers, i.e. at least one pull for every request sent in. + * + * It's hard to say if that's reasonable enough. As we can only ever receive a single pull we will always need a + * buffer of at least `max-connections` elements to allow for any parallelism. So, an alternative strategy could be + * to leave a response in its slot until it is fetched. + * + * (The old implementation may or may not have implemented similar behavior. It probably had a buffer because of + * the merge in the involved graph structure.) + * + * The implementation is split up into this class which does all the stream-based wiring. It contains a vector of + * slots that contain the mutable slot state for every slot. + * + * The actual state machine logic is handled in separate [[SlotState]] subclasses that interface with the logic through + * the clean [[SlotContext]] interface. + */ +@InternalApi +private[client] object NewHostConnectionPool { + def apply( + connectionFlow: Flow[HttpRequest, HttpResponse, Future[Http.OutgoingConnection]], + settings: ConnectionPoolSettings, log: LoggingAdapter): Flow[RequestContext, ResponseContext, NotUsed] = + Flow.fromGraph(new HostConnectionPoolStage(connectionFlow, settings, log)) + + private final class HostConnectionPoolStage( + connectionFlow: Flow[HttpRequest, HttpResponse, Future[Http.OutgoingConnection]], + _settings: ConnectionPoolSettings, _log: LoggingAdapter + ) extends GraphStage[FlowShape[RequestContext, ResponseContext]] { + val requestsIn = Inlet[RequestContext]("HostConnectionPoolStage.requestsIn") + val responsesOut = Outlet[ResponseContext]("HostConnectionPoolStage.responsesOut") + + override val shape = FlowShape(requestsIn, responsesOut) + def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with StageLoggingWithOverride with InHandler with OutHandler { logic ⇒ + override def logOverride: LoggingAdapter = _log + + setHandlers(requestsIn, responsesOut, this) + + private[this] var lastTimeoutId = 0L + + val slots = Vector.tabulate(_settings.maxConnections)(new Slot(_)) + val outBuffer: util.Deque[ResponseContext] = new util.ArrayDeque[ResponseContext] + val retryBuffer: util.Deque[RequestContext] = new util.ArrayDeque[RequestContext] + + override def preStart(): Unit = { + pull(requestsIn) + slots.foreach(_.initialize()) + } + + def onPush(): Unit = { + dispatchRequest(grab(requestsIn)) + pullIfNeeded() + } + def onPull(): Unit = + if (!outBuffer.isEmpty) + push(responsesOut, outBuffer.pollFirst()) + + def manageState(): Unit = { + pullIfNeeded() + } + + def pullIfNeeded(): Unit = + if (hasIdleSlots) + if (!retryBuffer.isEmpty) { + log.debug("Dispatching request from retryBuffer") + dispatchRequest(retryBuffer.pollFirst()) + } else if (!hasBeenPulled(requestsIn)) + pull(requestsIn) + + def hasIdleSlots: Boolean = + // TODO: optimize by keeping track of idle connections? + slots.exists(_.isIdle) + + def dispatchResponse(req: RequestContext, res: HttpResponse): Unit = + dispatchResponseContext(ResponseContext(req, Success(res))) + + def dispatchFailure(req: RequestContext, cause: Throwable): Unit = { + if (req.retriesLeft > 0) { + log.debug("Request has {} retries left, retrying...", req.retriesLeft) + retryBuffer.addLast(req.copy(retriesLeft = req.retriesLeft - 1)) + } else + dispatchResponseContext(ResponseContext(req, Failure(cause))) + } + + def dispatchResponseContext(resCtx: ResponseContext): Unit = + if (outBuffer.isEmpty && isAvailable(responsesOut)) + push(responsesOut, resCtx) + else + outBuffer.addLast(resCtx) + + def dispatchRequest(req: RequestContext): Unit = { + val slot = + slots.find(_.isIdle) + .getOrElse(throw new IllegalStateException("Tried to dispatch request when no slot is idle")) + + slot.debug("Dispatching request") // FIXME: add abbreviation + slot.dispatchRequest(req) + } + + def numConnectedSlots: Int = slots.count(_.isConnected) + + private class Slot(val slotId: Int) extends SlotContext { + private[this] var state: SlotState = SlotState.Unconnected + private[this] var currentTimeoutId: Long = -1 + private[this] var currentTimeout: Cancellable = _ + + private[this] var connection: SlotConnection = _ + def isIdle: Boolean = state.isIdle + def isConnected: Boolean = state.isConnected + def shutdown(): Unit = { + // TODO: should we offer errors to the connection? + closeConnection() + + state.onShutdown(this) + } + + def initialize(): Unit = + if (slotId < settings.minConnections) { + debug("Preconnecting") + updateState(_.onPreConnect(this)) + } + + def onConnected(outgoing: Http.OutgoingConnection): Unit = + updateState(_.onConnectedAttemptSucceeded(this, outgoing)) + + def onConnectFailed(cause: Throwable): Unit = + updateState(_.onConnectionAttemptFailed(this, cause)) + + def dispatchRequest(req: RequestContext): Unit = + updateState(_.onNewRequest(this, req)) + + def onRequestEntityCompleted(): Unit = + updateState(_.onRequestEntityCompleted(this)) + def onRequestEntityFailed(cause: Throwable): Unit = + updateState(_.onRequestEntityFailed(this, cause)) + + def onResponseReceived(response: HttpResponse): Unit = + updateState(_.onResponseReceived(this, response)) + def onResponseEntitySubscribed(): Unit = + updateState(_.onResponseEntitySubscribed(this)) + def onResponseEntityCompleted(): Unit = + updateState(_.onResponseEntityCompleted(this)) + def onResponseEntityFailed(cause: Throwable): Unit = + updateState(_.onResponseEntityFailed(this, cause)) + + def onConnectionCompleted(): Unit = updateState(_.onConnectionCompleted(this)) + def onConnectionFailed(cause: Throwable): Unit = updateState(_.onConnectionFailed(this, cause)) + + protected def updateState(f: SlotState ⇒ SlotState): Unit = { + if (currentTimeout ne null) { + currentTimeout.cancel() + currentTimeout = null + currentTimeoutId = -1 + } + + val previousState = state + state = f(state) + debug(s"State change [${previousState.name}] -> [${state.name}]") + + state.stateTimeout match { + case Duration.Inf ⇒ + case d: FiniteDuration ⇒ + val myTimeoutId = createNewTimeoutId() + currentTimeoutId = myTimeoutId + currentTimeout = + materializer.scheduleOnce(d, safeRunnable { + if (myTimeoutId == currentTimeoutId) { // timeout may race with state changes, ignore if timeout isn't current any more + debug(s"Slot timeout after $d") + updateState(_.onTimeout(this)) + } + }) + } + + if (!previousState.isIdle && state.isIdle) { + debug("Slot became idle... Trying to pull") + pullIfNeeded() + } + + if (state == Unconnected && numConnectedSlots < settings.minConnections) { + debug(s"Preconnecting because number of connected slots fell down to $numConnectedSlots") + updateState(_.onPreConnect(this)) + } + + // put additional bookkeeping here (like keeping track of idle connections) + } + protected def setState(newState: SlotState): Unit = + updateState(_ ⇒ newState) + + def debug(msg: String): Unit = + log.debug("[{} ({})] {}", slotId, state.productPrefix, msg) + + def debug(msg: String, arg1: AnyRef): Unit = + log.debug(s"[{} ({})] $msg", slotId, state.productPrefix, arg1) + + def warning(msg: String): Unit = + log.warning("[{} ({})] {}", slotId, state.productPrefix, msg) + + def warning(msg: String, arg1: AnyRef): Unit = + log.warning(s"[{} ({})] $msg", slotId, state.productPrefix, arg1) + + def settings: ConnectionPoolSettings = _settings + + def openConnection(): Future[Http.OutgoingConnection] = { + if (connection ne null) throw new IllegalStateException("Cannot open connection when slot still has an open connection") + + connection = logic.openConnection(this) + connection.outgoingConnection + } + def pushRequestToConnectionAndThen(request: HttpRequest, nextState: SlotState): SlotState = { + if (connection eq null) throw new IllegalStateException("Cannot open push request to connection when there's no connection") + + // bit of a HACK to make sure onRequestEntityCompleted will end up in the right place + state = nextState + + connection.pushRequest(request) + state + } + def closeConnection(): Unit = + if (connection ne null) { + connection.close() + connection = null + } + def isCurrentConnection(conn: SlotConnection): Boolean = connection eq conn + def isConnectionClosed: Boolean = (connection eq null) || connection.isClosed + + def dispatchResponse(req: RequestContext, res: HttpResponse): Unit = logic.dispatchResponse(req, res) + def dispatchFailure(req: RequestContext, cause: Throwable): Unit = logic.dispatchFailure(req, cause) + def willCloseAfter(res: HttpResponse): Boolean = logic.willClose(res) + } + final class SlotConnection( + _slot: Slot, + requestOut: SubSourceOutlet[HttpRequest], + responseIn: SubSinkInlet[HttpResponse], + val outgoingConnection: Future[Http.OutgoingConnection] + ) extends InHandler with OutHandler { + var ongoingResponseEntity: Option[HttpEntity] = None + + /** Will only be executed if this connection is still the current connection for its slot */ + def withSlot(f: Slot ⇒ Unit): Unit = + if (_slot.isCurrentConnection(this)) f(_slot) + + // FIXME: is this safe? I.e. always pulled? + def pushRequest(request: HttpRequest): Unit = { + val newRequest = + request.entity match { + case _: HttpEntity.Strict ⇒ + withSlot(_.onRequestEntityCompleted()) + request + case e ⇒ + val (newEntity, entityComplete) = HttpEntity.captureTermination(request.entity) + entityComplete.onComplete(safely { + case Success(_) ⇒ withSlot(_.onRequestEntityCompleted()) + case Failure(cause) ⇒ withSlot(_.onRequestEntityFailed(cause)) + })(ExecutionContexts.sameThreadExecutionContext) + + request.withEntity(newEntity) + } + + requestOut.push(newRequest) + } + def close(): Unit = { + requestOut.complete() + responseIn.cancel() + + // FIXME: or should we use discardEntity which does Sink.ignore? + ongoingResponseEntity.foreach(_.dataBytes.runWith(Sink.cancelled)(subFusingMaterializer)) + } + def isClosed: Boolean = requestOut.isClosed || responseIn.isClosed + + def onPush(): Unit = { + val response = responseIn.grab() + + withSlot(_.debug("Received response")) // FIXME: add abbreviated info + + response.entity match { + case _: HttpEntity.Strict ⇒ + withSlot(_.onResponseReceived(response)) + withSlot(_.onResponseEntitySubscribed()) + withSlot(_.onResponseEntityCompleted()) + case e ⇒ + ongoingResponseEntity = Some(e) + + val (newEntity, (entitySubscribed, entityComplete)) = + StreamUtils.transformEntityStream(response.entity, StreamUtils.CaptureMaterializationAndTerminationOp) + + entitySubscribed.onComplete(safely { + case Success(()) ⇒ + withSlot(_.onResponseEntitySubscribed()) + + entityComplete.onComplete(safely { + case Success(_) ⇒ withSlot(_.onResponseEntityCompleted()) + case Failure(cause) ⇒ withSlot(_.onResponseEntityFailed(cause)) + })(ExecutionContexts.sameThreadExecutionContext) + })(ExecutionContexts.sameThreadExecutionContext) + + withSlot(_.onResponseReceived(response.withEntity(newEntity))) + } + + if (!responseIn.isClosed) responseIn.pull() + } + + override def onUpstreamFinish(): Unit = + withSlot { slot ⇒ + slot.debug("Connection completed") + slot.onConnectionCompleted() + } + override def onUpstreamFailure(ex: Throwable): Unit = + withSlot { slot ⇒ + slot.debug("Connection failed") + slot.onConnectionFailed(ex) + } + + def onPull(): Unit = () // FIXME: do we need push / pull handling? + + override def onDownstreamFinish(): Unit = + withSlot(_.debug("Connection cancelled")) + + } + def openConnection(slot: Slot): SlotConnection = { + val requestOut = new SubSourceOutlet[HttpRequest](s"PoolSlot[${slot.slotId}].requestOut") + + val responseIn = new SubSinkInlet[HttpResponse](s"PoolSlot[${slot.slotId}].responseIn") + responseIn.pull() + + slot.debug("Establishing connection") + val connection = + Source.fromGraph(requestOut.source) + .viaMat(connectionFlow)(Keep.right) + .toMat(responseIn.sink)(Keep.left) + .run()(subFusingMaterializer) + + connection.onComplete(safely { + case Success(outgoingConnection) ⇒ slot.onConnected(outgoingConnection) + case Failure(cause) ⇒ slot.onConnectFailed(cause) + })(ExecutionContexts.sameThreadExecutionContext) + + val slotCon = new SlotConnection(slot, requestOut, responseIn, connection) + requestOut.setHandler(slotCon) + responseIn.setHandler(slotCon) + slotCon + } + + override def onUpstreamFinish(): Unit = { + log.debug("Pool upstream was completed") + super.onDownstreamFinish() + } + override def onUpstreamFailure(ex: Throwable): Unit = { + log.debug("Pool upstream failed with {}", ex) + super.onUpstreamFailure(ex) + } + override def onDownstreamFinish(): Unit = { + log.debug("Pool downstream cancelled") + super.onDownstreamFinish() + } + override def postStop(): Unit = { + log.debug("Pool stopped") + slots.foreach(_.shutdown()) + } + + private def willClose(response: HttpResponse): Boolean = + response.header[headers.Connection].exists(_.hasClose) + + private val safeCallback = getAsyncCallback[() ⇒ Unit](f ⇒ f()) + private def safely[T, U](f: T ⇒ Unit): T ⇒ Unit = t ⇒ safeCallback.invoke(() ⇒ f(t)) + private def safeRunnable(body: ⇒ Unit): Runnable = + new Runnable { + def run(): Unit = safeCallback.invoke(() ⇒ body) + } + private def createNewTimeoutId(): Long = { + lastTimeoutId += 1 + lastTimeoutId + } + } + } +} diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/pool/SlotState.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/pool/SlotState.scala new file mode 100644 index 00000000000..3c8aea9285b --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/pool/SlotState.scala @@ -0,0 +1,221 @@ +/* + * Copyright (C) 2009-2017 Lightbend Inc. + */ + +package akka.http.impl.engine.client.pool + +import akka.annotation.InternalApi +import akka.http.impl.engine.client.PoolFlow.RequestContext +import akka.http.impl.util._ +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.{ HttpRequest, HttpResponse } +import akka.http.scaladsl.settings.ConnectionPoolSettings + +import scala.concurrent.Future +import scala.concurrent.duration._ + +/** + * Internal API + * + * Interface between slot states and the actual slot. + */ +@InternalApi +private[pool] abstract class SlotContext { + def openConnection(): Future[Http.OutgoingConnection] + def pushRequestToConnectionAndThen(request: HttpRequest, nextState: SlotState): SlotState + def closeConnection(): Unit + def isConnectionClosed: Boolean + + def dispatchFailure(req: RequestContext, cause: Throwable): Unit + def dispatchResponse(req: RequestContext, res: HttpResponse): Unit + + def willCloseAfter(res: HttpResponse): Boolean + + def debug(msg: String): Unit + def debug(msg: String, arg1: AnyRef): Unit + + def warning(msg: String): Unit + def warning(msg: String, arg1: AnyRef): Unit + + def settings: ConnectionPoolSettings +} + +/* Internal API */ +@InternalApi +private[pool] sealed abstract class SlotState extends Product { + def isIdle: Boolean + def isConnected: Boolean + + def onPreConnect(ctx: SlotContext): SlotState = illegalState(ctx, "preConnect") + def onConnectedAttemptSucceeded(ctx: SlotContext, outgoingConnection: Http.OutgoingConnection): SlotState = illegalState(ctx, "connected attempt succeeded") + def onConnectionAttemptFailed(ctx: SlotContext, cause: Throwable): SlotState = illegalState(ctx, "connection attempt failed") + + def onNewRequest(ctx: SlotContext, requestContext: RequestContext): SlotState = illegalState(ctx, "new request") + + /** Will be called either immediately if the request entity is strict or otherwise later */ + def onRequestEntityCompleted(ctx: SlotContext): SlotState = illegalState(ctx, "request entity completed") + def onRequestEntityFailed(ctx: SlotContext, cause: Throwable): SlotState = illegalState(ctx, "request entity failed") + + def onResponseReceived(ctx: SlotContext, response: HttpResponse): SlotState = illegalState(ctx, "receive response") + def onResponseEntitySubscribed(ctx: SlotContext): SlotState = illegalState(ctx, "responseEntitySubscribed") + + /** Will be called either immediately if the response entity is strict or otherwise later */ + def onResponseEntityCompleted(ctx: SlotContext): SlotState = illegalState(ctx, "response entity completed") + def onResponseEntityFailed(ctx: SlotContext, cause: Throwable): SlotState = illegalState(ctx, "response entity failed") + + def onConnectionCompleted(ctx: SlotContext): SlotState = illegalState(ctx, "connection completed") + def onConnectionFailed(ctx: SlotContext, cause: Throwable): SlotState = illegalState(ctx, "connection failed") + + def onTimeout(ctx: SlotContext): SlotState = illegalState(ctx, "timeout") + + def onShutdown(ctx: SlotContext): Unit = () + + /** A slot can define a timeout for that state after which onTimeout will be called. */ + def stateTimeout: Duration = Duration.Inf + + protected def illegalState(ctx: SlotContext, what: String): SlotState = { + ctx.debug(s"Got unexpected event [$what] in state [$name]]") + throw new IllegalStateException(s"Cannot [$what] when in state [$name]") + } + + def name: String = productPrefix +} + +/** + * Internal API + * + * Implementation of slot logic that is completed decoupled from the machinery bits which are implemented in the GraphStageLogic + * and exposed only through [[SlotContext]]. + */ +@InternalApi +private[pool] object SlotState { + sealed abstract class ConnectedState extends SlotState { + def isConnected: Boolean = true + } + sealed trait IdleState extends SlotState { + final override def isIdle = true + } + sealed trait BusyState extends SlotState { + final override def isIdle = false // no HTTP pipelining right now + def ongoingRequest: RequestContext + } + + case object Unconnected extends SlotState with IdleState { + def isConnected: Boolean = false + + override def onPreConnect(ctx: SlotContext): SlotState = { + ctx.openConnection() + PreConnecting + } + + override def onNewRequest(ctx: SlotContext, requestContext: RequestContext): SlotState = { + ctx.openConnection() + Connecting(requestContext) + } + } + case object Idle extends ConnectedState with IdleState with WithRequestDispatching { + override def onNewRequest(ctx: SlotContext, requestContext: RequestContext): SlotState = + dispatchRequestToConnection(ctx, requestContext) + + override def onConnectionCompleted(ctx: SlotContext): SlotState = Unconnected + override def onConnectionFailed(ctx: SlotContext, cause: Throwable): SlotState = Unconnected + } + sealed trait WithRequestDispatching { _: ConnectedState ⇒ + def dispatchRequestToConnection(ctx: SlotContext, ongoingRequest: RequestContext): SlotState = { + val r = ongoingRequest.request + ctx.pushRequestToConnectionAndThen(r, WaitingForEndOfRequestEntity(ongoingRequest)) + } + } + + final case class Connecting(ongoingRequest: RequestContext) extends ConnectedState with BusyState with WithRequestDispatching { + override def onConnectedAttemptSucceeded(ctx: SlotContext, outgoingConnection: Http.OutgoingConnection): SlotState = { + ctx.debug("Slot connection was established") + dispatchRequestToConnection(ctx, ongoingRequest) + } + override def onConnectionAttemptFailed(ctx: SlotContext, cause: Throwable): SlotState = { + ctx.debug("Connection attempt failed.") + // FIXME: register failed connection attempt, schedule request for rerun, backoff new connection attempts + ctx.dispatchFailure(ongoingRequest, cause) + Unconnected + } + } + + case object PreConnecting extends ConnectedState with IdleState with WithRequestDispatching { + override def onConnectedAttemptSucceeded(ctx: SlotContext, outgoingConnection: Http.OutgoingConnection): SlotState = { + ctx.debug("Slot connection was (pre-)established") + Idle + } + override def onConnectionAttemptFailed(ctx: SlotContext, cause: Throwable): SlotState = { + ctx.debug("Connection attempt failed.") + // FIXME: register failed connection attempt, schedule request for rerun, backoff new connection attempts + Unconnected + } + + override def onNewRequest(ctx: SlotContext, requestContext: RequestContext): SlotState = + Connecting(requestContext) + } + final case class WaitingForEndOfRequestEntity(ongoingRequest: RequestContext) extends ConnectedState with BusyState { + override def onRequestEntityCompleted(ctx: SlotContext): SlotState = WaitingForResponse(ongoingRequest) + + override def onConnectionFailed(ctx: SlotContext, cause: Throwable): SlotState = { + ctx.dispatchFailure(ongoingRequest, cause) + ctx.closeConnection() + + Unconnected + } + + } + final case class WaitingForResponse(ongoingRequest: RequestContext) extends ConnectedState with BusyState { + override def onResponseReceived(ctx: SlotContext, response: HttpResponse): SlotState = { + ctx.dispatchResponse(ongoingRequest, response) + + WaitingForResponseEntitySubscription(ongoingRequest, response, ctx.settings.responseEntitySubscriptionTimeout) + } + + override def onConnectionFailed(ctx: SlotContext, cause: Throwable): SlotState = { + ctx.dispatchFailure(ongoingRequest, cause) + ctx.closeConnection() + + Unconnected + } + } + final case class WaitingForResponseEntitySubscription( + ongoingRequest: RequestContext, + ongoingResponse: HttpResponse, override val stateTimeout: Duration) extends ConnectedState with BusyState { + + override def onResponseEntitySubscribed(ctx: SlotContext): SlotState = + WaitingForEndOfResponseEntity(ongoingRequest, ongoingResponse) + + override def onTimeout(ctx: SlotContext): SlotState = { + ctx.warning( + s"Response entity was not subscribed after $stateTimeout. Make sure to read the response entity body or call `discardBytes()` on it. " + + s"${ongoingRequest.request.debugString} -> ${ongoingResponse.debugString}") + ctx.closeConnection() + Unconnected + } + } + final case class WaitingForEndOfResponseEntity( + ongoingRequest: RequestContext, + ongoingResponse: HttpResponse) extends ConnectedState with BusyState { + + override def onResponseEntityCompleted(ctx: SlotContext): SlotState = + if (ctx.willCloseAfter(ongoingResponse) || ctx.isConnectionClosed) { + ctx.closeConnection() + Unconnected + } else + Idle + + override def onResponseEntityFailed(ctx: SlotContext, cause: Throwable): SlotState = { + ctx.debug("Response entity failed with {}", cause) + // we cannot fail the response at this point, the response has already been dispatched + ctx.closeConnection() + Unconnected + } + + // we ignore these signals here and expect that it will also be flagged on the entity stream + // FIXME: should we still add timeouts for these cases? + override def onConnectionFailed(ctx: SlotContext, cause: Throwable): SlotState = this + override def onConnectionCompleted(ctx: SlotContext): SlotState = this + } + +} diff --git a/akka-http-core/src/main/scala/akka/http/impl/settings/ConnectionPoolSettingsImpl.scala b/akka-http-core/src/main/scala/akka/http/impl/settings/ConnectionPoolSettingsImpl.scala index fc998be6b48..700e436e9df 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/settings/ConnectionPoolSettingsImpl.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/settings/ConnectionPoolSettingsImpl.scala @@ -7,7 +7,7 @@ package akka.http.impl.settings import akka.annotation.InternalApi import akka.http.impl.util.{ SettingsCompanion, _ } import akka.http.scaladsl.ClientTransport -import akka.http.scaladsl.settings.{ ClientConnectionSettings, ConnectionPoolSettings } +import akka.http.scaladsl.settings.{ ClientConnectionSettings, ConnectionPoolSettings, PoolImplementation } import com.typesafe.config.Config import scala.concurrent.duration.Duration @@ -15,26 +15,30 @@ import scala.concurrent.duration.Duration /** INTERNAL API */ @InternalApi private[akka] final case class ConnectionPoolSettingsImpl( - maxConnections: Int, - minConnections: Int, - maxRetries: Int, - maxOpenRequests: Int, - pipeliningLimit: Int, - idleTimeout: Duration, - connectionSettings: ClientConnectionSettings, - transport: ClientTransport) + maxConnections: Int, + minConnections: Int, + maxRetries: Int, + maxOpenRequests: Int, + pipeliningLimit: Int, + idleTimeout: Duration, + connectionSettings: ClientConnectionSettings, + poolImplementation: PoolImplementation, + responseEntitySubscriptionTimeout: Duration, + transport: ClientTransport) extends ConnectionPoolSettings { def this( - maxConnections: Int, - minConnections: Int, - maxRetries: Int, - maxOpenRequests: Int, - pipeliningLimit: Int, - idleTimeout: Duration, - connectionSettings: ClientConnectionSettings) = + maxConnections: Int, + minConnections: Int, + maxRetries: Int, + maxOpenRequests: Int, + pipeliningLimit: Int, + idleTimeout: Duration, + connectionSettings: ClientConnectionSettings, + poolImplementation: PoolImplementation, + responseEntitySubscriptionTimeout: Duration) = this(maxConnections, minConnections, maxRetries, maxOpenRequests, pipeliningLimit, idleTimeout, connectionSettings, - ClientTransport.TCP) + poolImplementation, responseEntitySubscriptionTimeout, ClientTransport.TCP) require(maxConnections > 0, "max-connections must be > 0") require(minConnections >= 0, "min-connections must be >= 0") @@ -56,7 +60,12 @@ object ConnectionPoolSettingsImpl extends SettingsCompanion[ConnectionPoolSettin c getInt "max-open-requests", c getInt "pipelining-limit", c getPotentiallyInfiniteDuration "idle-timeout", - ClientConnectionSettingsImpl.fromSubConfig(root, c.getConfig("client")) + ClientConnectionSettingsImpl.fromSubConfig(root, c.getConfig("client")), + c.getString("pool-implementation").toLowerCase match { + case "legacy" ⇒ PoolImplementation.Legacy + case "new" ⇒ PoolImplementation.New + }, + c getPotentiallyInfiniteDuration "response-entity-subscription-timeout" ) } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/JavaMapping.scala b/akka-http-core/src/main/scala/akka/http/impl/util/JavaMapping.scala index 725f56bcecb..6ac9b4d7db1 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/JavaMapping.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/JavaMapping.scala @@ -195,6 +195,7 @@ private[http] object JavaMapping { implicit object PreviewServerSettings extends Inherited[js.PreviewServerSettings, akka.http.scaladsl.settings.PreviewServerSettings] implicit object ServerSettingsT extends Inherited[js.ServerSettings.Timeouts, akka.http.scaladsl.settings.ServerSettings.Timeouts] implicit object Http2ServerSettingT extends Inherited[js.Http2ServerSettings, akka.http.scaladsl.settings.Http2ServerSettings] + implicit object PoolImplementationT extends Inherited[js.PoolImplementation, akka.http.scaladsl.settings.PoolImplementation] implicit object OutgoingConnection extends JavaMapping[jdsl.OutgoingConnection, sdsl.Http.OutgoingConnection] { def toScala(javaObject: jdsl.OutgoingConnection): sdsl.Http.OutgoingConnection = javaObject.delegate diff --git a/akka-http-core/src/main/scala/akka/http/javadsl/settings/ConnectionPoolSettings.scala b/akka-http-core/src/main/scala/akka/http/javadsl/settings/ConnectionPoolSettings.scala index 9a1118f8c87..19a2f9e350c 100644 --- a/akka-http-core/src/main/scala/akka/http/javadsl/settings/ConnectionPoolSettings.scala +++ b/akka-http-core/src/main/scala/akka/http/javadsl/settings/ConnectionPoolSettings.scala @@ -4,7 +4,7 @@ package akka.http.javadsl.settings import akka.actor.ActorSystem -import akka.annotation.DoNotInherit +import akka.annotation.{ ApiMayChange, DoNotInherit } import akka.http.impl.settings.ConnectionPoolSettingsImpl import com.typesafe.config.Config @@ -12,6 +12,14 @@ import scala.concurrent.duration.Duration import akka.http.impl.util.JavaMapping.Implicits._ import akka.http.javadsl.ClientTransport +@ApiMayChange +trait PoolImplementation +@ApiMayChange +object PoolImplementation { + def Legacy: PoolImplementation = akka.http.scaladsl.settings.PoolImplementation.Legacy + def New: PoolImplementation = akka.http.scaladsl.settings.PoolImplementation.New +} + /** * Public API but not intended for subclassing */ @@ -25,6 +33,12 @@ abstract class ConnectionPoolSettings private[akka] () { self: ConnectionPoolSet def getIdleTimeout: Duration = idleTimeout def getConnectionSettings: ClientConnectionSettings = connectionSettings + @ApiMayChange + def getPoolImplementation: PoolImplementation = poolImplementation + + @ApiMayChange + def getResponseEntitySubscriptionTimeout: Duration = responseEntitySubscriptionTimeout + /** The underlying transport used to connect to hosts. By default [[ClientTransport.TCP]] is used. */ def getTransport: ClientTransport = transport.asJava @@ -37,6 +51,13 @@ abstract class ConnectionPoolSettings private[akka] () { self: ConnectionPoolSet def withPipeliningLimit(newValue: Int): ConnectionPoolSettings = self.copy(pipeliningLimit = newValue) def withIdleTimeout(newValue: Duration): ConnectionPoolSettings = self.copy(idleTimeout = newValue) def withConnectionSettings(newValue: ClientConnectionSettings): ConnectionPoolSettings = self.copy(connectionSettings = newValue.asScala) + + @ApiMayChange + def withPoolImplementation(newValue: PoolImplementation): ConnectionPoolSettings = self.copy(poolImplementation = newValue.asScala) + + @ApiMayChange + def withResponseEntitySubscriptionTimeout(newValue: Duration): ConnectionPoolSettings = self.copy(responseEntitySubscriptionTimeout = newValue) + def withTransport(newValue: ClientTransport): ConnectionPoolSettings = self.copy(transport = newValue.asScala) } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/settings/ConnectionPoolSettings.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/settings/ConnectionPoolSettings.scala index b881c93f109..0cf36c44359 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/settings/ConnectionPoolSettings.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/settings/ConnectionPoolSettings.scala @@ -3,7 +3,7 @@ */ package akka.http.scaladsl.settings -import akka.annotation.DoNotInherit +import akka.annotation.{ ApiMayChange, DoNotInherit } import akka.http.impl.settings.ConnectionPoolSettingsImpl import akka.http.javadsl.{ settings ⇒ js } import akka.http.scaladsl.ClientTransport @@ -11,6 +11,14 @@ import com.typesafe.config.Config import scala.concurrent.duration.Duration +@ApiMayChange +sealed trait PoolImplementation extends js.PoolImplementation +@ApiMayChange +object PoolImplementation { + case object Legacy extends PoolImplementation + case object New extends PoolImplementation +} + /** * Public API but not intended for subclassing */ @@ -24,6 +32,13 @@ abstract class ConnectionPoolSettings extends js.ConnectionPoolSettings { self: def idleTimeout: Duration def connectionSettings: ClientConnectionSettings + @ApiMayChange + def poolImplementation: PoolImplementation + + /** The time after which the pool will drop an entity automatically if it wasn't read or discarded */ + @ApiMayChange + def responseEntitySubscriptionTimeout: Duration + /** The underlying transport used to connect to hosts. By default [[ClientTransport.TCP]] is used. */ def transport: ClientTransport @@ -39,6 +54,12 @@ abstract class ConnectionPoolSettings extends js.ConnectionPoolSettings { self: // overloads for idiomatic Scala use def withConnectionSettings(newValue: ClientConnectionSettings): ConnectionPoolSettings = self.copy(connectionSettings = newValue) + + @ApiMayChange + def withPoolImplementation(newValue: PoolImplementation): ConnectionPoolSettings = self.copy(poolImplementation = newValue) + + @ApiMayChange + override def withResponseEntitySubscriptionTimeout(newValue: Duration): ConnectionPoolSettings = self.copy(responseEntitySubscriptionTimeout = newValue) def withTransport(newTransport: ClientTransport): ConnectionPoolSettings = self.copy(transport = newTransport) } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala index 7b60c8ca644..30e3ddd95c6 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala @@ -17,7 +17,7 @@ import akka.http.impl.util._ import akka.http.scaladsl.Http.OutgoingConnection import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.settings.{ ClientConnectionSettings, ConnectionPoolSettings, ServerSettings } +import akka.http.scaladsl.settings.{ ClientConnectionSettings, ConnectionPoolSettings, PoolImplementation, ServerSettings } import akka.http.scaladsl.{ ClientTransport, ConnectionContext, Http } import akka.stream.ActorMaterializer import akka.stream.TLSProtocol._ @@ -32,7 +32,7 @@ import scala.concurrent.duration._ import scala.util.control.NonFatal import scala.util.{ Failure, Success, Try } -class ConnectionPoolSpec extends AkkaSpec(""" +abstract class ConnectionPoolSpec(poolImplementation: PoolImplementation) extends AkkaSpec(""" akka.loggers = [] akka.loglevel = OFF akka.io.tcp.windows-connection-abort-workaround-enabled = auto @@ -64,7 +64,7 @@ class ConnectionPoolSpec extends AkkaSpec(""" "The host-level client infrastructure" should { - "properly complete a simple request/response cycle" in new TestSetup { + "complete a simple request/response cycle" in new TestSetup { val (requestIn, responseOut, responseOutSub, hcp) = cachedHostConnectionPool[Int]() requestIn.sendNext(HttpRequest(uri = "/") → 42) @@ -160,7 +160,7 @@ class ConnectionPoolSpec extends AkkaSpec(""" Await.result(idSum, 10.seconds.dilated) shouldEqual N * (N + 1) / 2 } - "properly surface connection-level errors" in new TestSetup(autoAccept = true) { + "surface connection-level errors" in new TestSetup(autoAccept = true) { val (requestIn, responseOut, responseOutSub, hcp) = cachedHostConnectionPool[Int](maxRetries = 0) requestIn.sendNext(HttpRequest(uri = "/a") → 42) @@ -177,7 +177,7 @@ class ConnectionPoolSpec extends AkkaSpec(""" } // akka-http/#416 - "properly surface connection-level and stream-level errors while receiving response entity" in new TestSetup(autoAccept = true) { + "surface connection-level and stream-level errors while receiving response entity" in new TestSetup(autoAccept = true) { val errorOnConnection1 = Promise[ByteString]() val crashingEntity = @@ -550,9 +550,16 @@ class ConnectionPoolSpec extends AkkaSpec(""" ccSettings: ClientConnectionSettings = ClientConnectionSettings(system)) = { val settings = - new ConnectionPoolSettingsImpl(maxConnections, minConnections, - maxRetries, maxOpenRequests, pipeliningLimit, - idleTimeout.dilated, ccSettings) + ConnectionPoolSettings(system) + .withMaxConnections(maxConnections) + .withMinConnections(minConnections) + .withMaxRetries(maxRetries) + .withMaxOpenRequests(maxOpenRequests) + .withPipeliningLimit(pipeliningLimit) + .withIdleTimeout(idleTimeout.dilated) + .withConnectionSettings(ccSettings) + .withPoolImplementation(poolImplementation) + flowTestBench( Http().cachedHostConnectionPool[T](serverHostName, serverPort, settings)) } @@ -565,8 +572,17 @@ class ConnectionPoolSpec extends AkkaSpec(""" pipeliningLimit: Int = 1, idleTimeout: FiniteDuration = 5.seconds, ccSettings: ClientConnectionSettings = ClientConnectionSettings(system)) = { - val settings = new ConnectionPoolSettingsImpl(maxConnections, minConnections, maxRetries, maxOpenRequests, pipeliningLimit, - idleTimeout.dilated, ClientConnectionSettings(system)) + + val settings = + ConnectionPoolSettings(system) + .withMaxConnections(maxConnections) + .withMinConnections(minConnections) + .withMaxRetries(maxRetries) + .withMaxOpenRequests(maxOpenRequests) + .withPipeliningLimit(pipeliningLimit) + .withIdleTimeout(idleTimeout.dilated) + .withConnectionSettings(ccSettings) + .withPoolImplementation(poolImplementation) flowTestBench(Http().superPool[T](settings = settings)) } @@ -622,7 +638,8 @@ class ConnectionPoolSpec extends AkkaSpec(""" def connectTo(host: String, port: Int, settings: ClientConnectionSettings)(implicit system: ActorSystem): Flow[ByteString, ByteString, Future[OutgoingConnection]] = { promise.success((host, port, settings)) - Flow.fromSinkAndSource(in.sink, Source.fromPublisher(out)).mapMaterializedValue(_ ⇒ Promise().future) + Flow.fromSinkAndSource(in.sink, Source.fromPublisher(out)) + .mapMaterializedValue(_ ⇒ Future.successful(Http.OutgoingConnection(InetSocketAddress.createUnresolved("local", 12345), InetSocketAddress.createUnresolved(host, port)))) } } @@ -635,6 +652,7 @@ class ConnectionPoolSpec extends AkkaSpec(""" ConnectionPoolSettings(system) .withTransport(transport) .withConnectionSettings(ClientConnectionSettings(system).withIdleTimeout(CustomIdleTimeout)) + .withPoolImplementation(poolImplementation) val responseFuture = issueRequest(HttpRequest(uri = "http://example.org/test"), settings = poolSettings) @@ -655,3 +673,6 @@ class ConnectionPoolSpec extends AkkaSpec(""" response.entity.dataBytes.utf8String.awaitResult(10.seconds) should ===("Hello World!") } } + +class LegacyConnectionPoolSpec extends ConnectionPoolSpec(PoolImplementation.Legacy) +class NewConnectionPoolSpec extends ConnectionPoolSpec(PoolImplementation.New) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/HostConnectionPoolSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/HostConnectionPoolSpec.scala new file mode 100644 index 00000000000..f43f015853c --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/HostConnectionPoolSpec.scala @@ -0,0 +1,615 @@ +/* + * Copyright (C) 2009-2017 Lightbend Inc. + */ + +package akka.http.impl.engine.client + +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicInteger + +import akka.actor.ActorSystem +import akka.http.impl.util._ +import akka.event.LoggingAdapter +import akka.http.impl.engine.client.PoolFlow.{ RequestContext, ResponseContext } +import akka.http.impl.engine.client.pool.NewHostConnectionPool +import akka.http.impl.engine.ws.ByteStringSinkProbe +import akka.http.scaladsl.{ ClientTransport, ConnectionContext, Http } +import akka.http.scaladsl.Http.ServerBinding +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.Host +import akka.http.scaladsl.settings.{ ClientConnectionSettings, ConnectionPoolSettings } +import akka.stream._ +import akka.stream.scaladsl.{ BidiFlow, Flow, GraphDSL, Keep, Sink, Source, TLSPlacebo } +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } +import akka.stream.testkit.{ TestPublisher, TestSubscriber } +import akka.testkit._ +import akka.util.ByteString +import org.reactivestreams.{ Publisher, Subscriber } + +import scala.concurrent.{ Await, Future, Promise } +import scala.concurrent.duration._ +import scala.util.Failure + +/** + * Tests the host connection pool infrastructure. + * + * Right now it tests against various stacks with various depths. It's debatable whether it should actually be tested + * against plain network bytes instead to show interaction on the HTTP protocol level instead of against the server + * API level. + */ +class HostConnectionPoolSpec extends AkkaSpec( + """ + akka.loglevel = INFO + akka.actor { + serialize-creators = off + serialize-messages = off + default-dispatcher.throughput = 100 + } + akka.http.client.log-unencrypted-network-bytes = 200 + """ +) { + implicit val materializer = ActorMaterializer() + val singleElementBufferMaterializer = materializer // ActorMaterializer(ActorMaterializerSettings(system).withInputBuffer(1, 1)) + val defaultSettings = + ConnectionPoolSettings(system) + .withMaxConnections(1) + + trait PoolImplementation { + def get: (Flow[HttpRequest, HttpResponse, Future[Http.OutgoingConnection]], ConnectionPoolSettings, LoggingAdapter) ⇒ Flow[RequestContext, ResponseContext, Any] + } + trait ClientServerImplementation { + /** Returns a client / server implementation that include the kill switch flow in the middle */ + def get(connectionKillSwitch: SharedKillSwitch): BidiFlow[HttpResponse, HttpResponse, HttpRequest, HttpRequest, Future[Http.OutgoingConnection]] + + /** + * Specifies if the transport implementation will fail the handler request input side if an error is encountered + * at the response output side. + * + * I haven't decided yet what the right behavior should be. + */ + def failsHandlerInputWhenHandlerOutputFails: Boolean + } + + testSet(poolImplementation = NewPoolImplementation, clientServerImplementation = PassThrough) + testSet(poolImplementation = NewPoolImplementation, clientServerImplementation = AkkaHttpEngineNoNetwork) + testSet(poolImplementation = NewPoolImplementation, clientServerImplementation = AkkaHttpEngineTCP) + //testSet(poolImplementation = NewPoolImplementation, clientServerImplementation = AkkaHttpEngineTLS) + + testSet(poolImplementation = LegacyPoolImplementation, clientServerImplementation = PassThrough) + testSet(poolImplementation = LegacyPoolImplementation, clientServerImplementation = AkkaHttpEngineNoNetwork) + testSet(poolImplementation = LegacyPoolImplementation, clientServerImplementation = AkkaHttpEngineTCP) + //testSet(poolImplementation = OldPoolImplementation, clientServerImplementation = AkkaHttpEngineTLS) + + def testSet(poolImplementation: PoolImplementation, clientServerImplementation: ClientServerImplementation) = + s"$poolImplementation on $clientServerImplementation" should { + "complete a simple request/response cycle with a strict request and response" in new SetupWithServerProbes { + pushRequest(HttpRequest(uri = "/simple")) + + val conn1 = expectNextConnection() + val req = conn1.expectRequest() + conn1.pushResponse(HttpResponse(entity = req.uri.path.toString)) + expectResponseEntityAsString() shouldEqual "/simple" + } + "complete a simple request/response cycle with a chunked request and response" in new SetupWithServerProbes { + val reqBody = Source("Hello" :: " World" :: Nil map ByteString.apply) + pushRequest(HttpRequest(uri = "/simple", entity = HttpEntity.Chunked.fromData(ContentTypes.`application/octet-stream`, reqBody))) + + val conn1 = expectNextConnection() + val HttpRequest(_, _, _, reqEntityIn: HttpEntity.Chunked, _) = conn1.expectRequest() + reqEntityIn.dataBytes.runFold(ByteString.empty)(_ ++ _).awaitResult(3.seconds).utf8String shouldEqual "Hello World" + + val resBodyOut = conn1.pushChunkedResponse() + val resBodyIn = expectChunkedResponseBytesAsProbe() + + resBodyOut.sendNext(ByteString("hi")) + resBodyIn.expectUtf8EncodedString("hi") + + resBodyOut.sendComplete() + resBodyIn.request(1) // FIXME: should we support eager completion here? (reason is substreamHandler in PrepareResponse) + resBodyIn.expectComplete() + } + "open up to max-connections when enough requests are pending" in new SetupWithServerProbes(_.withMaxConnections(2)) { + pushRequest(HttpRequest(uri = "/1")) + val conn1 = expectNextConnection() + conn1.expectRequestToPath("/1") + + pushRequest(HttpRequest(uri = "/2")) + val conn2 = expectNextConnection() + conn2.expectRequestToPath("/2") + + pushRequest(HttpRequest(uri = "/3")) + conn1.pushResponse(HttpResponse()) + conn1.expectRequestToPath("/3") + } + "only buffer a reasonable number of extra requests" in pending + "only send next request when last response entity was read completely" in new SetupWithServerProbes() { + pushRequest(HttpRequest(uri = "/chunked-1")) + pushRequest(HttpRequest(uri = "/2")) + val conn1 = expectNextConnection() + conn1.expectRequestToPath("/chunked-1") + //FIXME: expectNoNewConnection() + //conn1.expectNoRequest() + + val resp1BytesOut = conn1.pushChunkedResponse() + val resp1BytesIn = expectChunkedResponseBytesAsProbe() + resp1BytesOut.sendNext(ByteString("test")) + resp1BytesIn.expectUtf8EncodedString("test") + + // FIXME: expectNoNewConnection() + // conn1.expectNoRequest() + + resp1BytesOut.sendComplete() + resp1BytesIn.request(1) // FIXME: should we support eager completion here? + resp1BytesIn.expectComplete() + + conn1.expectRequestToPath("/2") + } + "time out quickly when response entity stream is not subscribed fast enough" in new SetupWithServerProbes { + pendingIn(targetImpl = LegacyPoolImplementation) // not implemented in legacy + pendingIn(targetTrans = PassThrough) // infra seems to be missing something + + // FIXME: set subscription timeout to value relating to below `expectNoMsg` + + pushRequest(HttpRequest(uri = "/1")) + val conn1 = expectNextConnection() + conn1.expectRequestToPath("/1") + + val (resBodyOut, chunks) = + EventFilter.warning(pattern = ".*Response entity was not subscribed.*", occurrences = 1) intercept { + val resBodyOut = conn1.pushChunkedResponse() + val HttpResponse(_, _, HttpEntity.Chunked(_, chunks), _) = expectResponse() + (resBodyOut, chunks) + } + + val streamResult = chunks.runWith(Sink.ignore) + Await.ready(streamResult, 3.seconds) + streamResult.value.get.failed.get.getMessage shouldEqual "Substream Source cannot be materialized more than once" + } + "time out when a connection was unused for a long time" in pending + "time out and reconnect when a request is not handled in time" in pending + "time out when connection cannot be established" in pending + "fail a request if the request entity fails" in new SetupWithServerProbes { + val reqBytesOut = pushChunkedRequest() + + val conn1 = expectNextConnection() + val reqBytesIn = conn1.expectChunkedRequestBytesAsProbe() + reqBytesOut.sendNext(ByteString("hello")) + reqBytesIn.expectUtf8EncodedString("hello") + + reqBytesOut.sendError(new RuntimeException("oops")) + + // expectRequestStreamError(reqBytesIn) + + // FIXME: this is currently part of the implementation that is not tested (request entity error will fail the connection) + // conn1.serverRequests.expectError() + responseOut.expectSubscription() + // FIXME: currently the API is weird in that it contains a promise which is completed with a failure instead + // of properly threading through the context field from request to response + // responseOut.expectError() // actually, only the response should be failed + } + "fail a request if the connection stream fails while waiting for request entity bytes" in new SetupWithServerProbes { + val reqBytesOut = pushChunkedRequest(HttpRequest(method = HttpMethods.POST), numRetries = 0) + + val conn1 = expectNextConnection() + val reqBytesIn = conn1.expectChunkedRequestBytesAsProbe() + + reqBytesOut.sendNext(ByteString("chunk1")) + reqBytesIn.expectUtf8EncodedString("chunk1") + + conn1.failConnection(new RuntimeException("server temporarily out for lunch")) + + // server behavior not tested for now + // expectRequestStreamError(reqBytesIn) // some kind of truncation error + // reqBytesOut.expectCancellation() + expectResponseError() + } + "fail a request if the connection stream fails while waiting for a response" in new SetupWithServerProbes { + pushRequest(HttpRequest(method = HttpMethods.POST), numRetries = 0) + val conn1 = expectNextConnection() + conn1.expectRequest() + + conn1.failConnection(new RuntimeException("solar wind prevented transmission")) + expectResponseError() + } + "fail a request if the connection stream fails while waiting for response entity bytes" in new SetupWithServerProbes { + pushRequest(HttpRequest(method = HttpMethods.POST), numRetries = 0) + val conn1 = expectNextConnection() + conn1.expectRequest() + val resBytesOut = conn1.pushChunkedResponse() + val resBytesIn = expectChunkedResponseBytesAsProbe() + resBytesOut.sendNext(ByteString("hello")) + resBytesIn.expectUtf8EncodedString("hello") + + conn1.failConnection(new RuntimeException("solar wind prevented transmission")) + // server behavior not tested for now + // resBytesIn.expectError() + + // client already received response, no need to report error another time + } + "fail a request if the response entity stream fails during processing" in new SetupWithServerProbes { + pushRequest(HttpRequest(method = HttpMethods.POST), numRetries = 0) + val conn1 = expectNextConnection() + conn1.expectRequest() + val resBytesOut = conn1.pushChunkedResponse() + val resBytesIn = expectChunkedResponseBytesAsProbe() + resBytesOut.sendNext(ByteString("hello")) + resBytesIn.expectUtf8EncodedString("hello") + + resBytesOut.sendError(new RuntimeException("hard disk too soft for reading further")) + resBytesIn.expectError() + conn1.expectError() + + // client already received response, no need to report error another time + } + "create a new connection when previous one was closed regularly between requests" in new SetupWithServerProbes { + pushRequest(HttpRequest(uri = "/simple")) + + val conn1 = expectNextConnection() + val req = conn1.expectRequest() + conn1.pushResponse(HttpResponse(headers = headers.Connection("close") :: Nil, entity = req.uri.path.toString)) + expectResponseEntityAsString() shouldEqual "/simple" + conn1.completeHandler() + + pushRequest(HttpRequest(uri = "/next")) + val conn2 = expectNextConnection() + conn2.expectRequestToPath("/next") + conn2.pushResponse(HttpResponse(entity = "response")) + expectResponseEntityAsString() shouldEqual "response" + } + "create a new connection when previous one failed between requests" in new SetupWithServerProbes { + pushRequest(HttpRequest(uri = "/simple")) + + val conn1 = expectNextConnection() + val req = conn1.expectRequest() + conn1.pushResponse(HttpResponse(headers = headers.Connection("close") :: Nil, entity = req.uri.path.toString)) + expectResponseEntityAsString() shouldEqual "/simple" + conn1.failConnection(new RuntimeException("broken connection")) + + pushRequest(HttpRequest(uri = "/next")) + val conn2 = expectNextConnection() + conn2.expectRequestToPath("/next") + conn2.pushResponse(HttpResponse(entity = "response")) + expectResponseEntityAsString() shouldEqual "response" + } + "support 100-continue" in pending + "without any connections establish the number of configured min-connections" in new SetupWithServerProbes(_.withMaxConnections(2).withMinConnections(1)) { + // expect a new connection immediately + val conn1 = expectNextConnection() + + // should be used for the first request + pushRequest(HttpRequest(uri = "/simple")) + conn1.expectRequest() + } + "re-establish min-connections when number of open connections falls below threshold" in new SetupWithServerProbes(_.withMaxConnections(2).withMinConnections(1)) { + pendingIn(targetImpl = LegacyPoolImplementation) // has failed a few times but I didn't check why exactly + + // expect a new connection immediately + val conn1 = expectNextConnection() + + // should be used for the first request + pushRequest(HttpRequest(uri = "/simple")) + conn1.expectRequestToPath("/simple") + conn1.pushResponse(HttpResponse(headers = headers.Connection("close") :: Nil)) + expectResponse() + conn1.completeConnection() + + val conn2 = expectNextConnection() + } + "not send requests to known-to-be-closed-soon connections" in pending + "support retries" in pending + "strictly enforce number of established connections in longer running case" in pending + "provide access to basic metrics as the materialized value" in pending + "ignore the pipelining setting (for now)" in pending + "work correctly in the presence of `Connection: close` headers" in pending + "if connecting attempt fails, backup the next connection attempts" in pending + + def pendingIn(targetImpl: PoolImplementation = null, targetTrans: ClientServerImplementation = null): Unit = + if ((targetImpl == null || poolImplementation == targetImpl) && + (targetTrans == null || clientServerImplementation == targetTrans)) + pending + + abstract class TestSetup { + lazy val requestIn = TestPublisher.probe[RequestContext]() + lazy val responseOut = TestSubscriber.probe[ResponseContext]() + + protected val server: Flow[HttpRequest, HttpResponse, Future[Http.OutgoingConnection]] + + protected def settings: ConnectionPoolSettings + + lazy val impl = poolImplementation.get( + server, + settings, + system.log + ) + val stream = + Source.fromPublisher(requestIn) + .via(impl) + .runWith(Sink.fromSubscriber(responseOut)) + + def pushRequest(req: HttpRequest, numRetries: Int = 5): Unit = + requestIn.sendNext(RequestContext(req, Promise(), numRetries)) + + def pushChunkedRequest(req: HttpRequest = HttpRequest(), numRetries: Int = 5): TestPublisher.Probe[ByteString] = { + val probe = TestPublisher.probe[ByteString]() + pushRequest(req.withEntity(HttpEntity.Chunked.fromData(ContentTypes.`application/octet-stream`, Source.fromPublisher(probe))), numRetries) + probe + } + + def expectResponse(): HttpResponse = + responseOut.requestNext().response.recover { + case ex ⇒ throw new AssertionError("Expected successful response but got exception", ex) + }.get + + def expectResponseEntityAsString(): String = + expectResponse().entity.dataBytes.runFold(ByteString.empty)(_ ++ _).awaitResult(5.seconds.dilated).utf8String + + /** Expect a chunked response, connect a [[ByteStringSinkProbe]] to it and return it */ + def expectChunkedResponseBytesAsProbe(): ByteStringSinkProbe = { + val HttpResponse(_, _, entity: HttpEntity.Chunked, _) = expectResponse() + val probe = ByteStringSinkProbe() + entity.dataBytes.runWith(probe.sink) + probe + } + + def expectNoRequestDemand(): Unit = + requestIn.pending shouldEqual 0 + + def expectResponseError(): Throwable = + responseOut.requestNext().response.failed.get + } + + class SetupWithServerProbes(changeSettings: ConnectionPoolSettings ⇒ ConnectionPoolSettings = identity) extends TestSetup { + override protected def settings = changeSettings(defaultSettings) + + class ServerConnection(requestPublisher: Publisher[HttpRequest], responseSubscriber: Subscriber[HttpResponse]) { + val serverRequests = TestSubscriber.probe[HttpRequest]() + val serverResponses = TestPublisher.probe[HttpResponse]() + val killSwitch = KillSwitches.shared("connection-kill-switch") + + def expectRequest(): HttpRequest = + serverRequests.requestNext() + + def expectRequestToPath(path: String): Unit = + expectRequest().uri.path.toString shouldEqual path + + /** Expect a chunked response, connect a [[ByteStringSinkProbe]] to it and return it */ + def expectChunkedRequestBytesAsProbe(): ByteStringSinkProbe = { + val HttpRequest(_, _, _, entity: HttpEntity.Chunked, _) = expectRequest() + val probe = ByteStringSinkProbe() + entity.dataBytes.runWith(probe.sink) + probe + } + + def expectNoRequest(): Unit = + serverRequests.expectNoMsg() + + def pushResponse(response: HttpResponse = HttpResponse()) = + serverResponses.sendNext(response) + + def pushChunkedResponse(response: HttpResponse = HttpResponse()): TestPublisher.Probe[ByteString] = { + val res = TestPublisher.probe[ByteString]() + pushResponse(response.withEntity(HttpEntity.Chunked.fromData(ContentTypes.`application/octet-stream`, Source.fromPublisher(res)))) + res + } + + def completeHandler(): Unit = { + serverResponses.sendComplete() + serverRequests.expectComplete() + } + + def failConnection(cause: Exception): Unit = + killSwitch.abort(cause) + + def completeConnection(): Unit = + killSwitch.shutdown() + + def failHandler(cause: Exception): Unit = { + serverResponses.sendError(cause) + // since this is server behavior, it's not really important to check it here + // FIXME: verify server behavior + expectErrorOrCompleteOnRequestSide() + } + def expectError(): Unit = { + serverResponses.expectCancellation() + expectErrorOrCompleteOnRequestSide() + } + def expectErrorOrCompleteOnRequestSide(): Unit = + serverRequests.expectEventPF { + case _: TestSubscriber.OnError ⇒ + case TestSubscriber.OnComplete ⇒ + } + + lazy val outgoingConnection: Future[Http.OutgoingConnection] = + Flow.fromSinkAndSource( + Sink.fromSubscriber(serverRequests), + Source.fromPublisher(serverResponses)) + .joinMat(clientServerImplementation.get(killSwitch))(Keep.right) + .recover { + case ex ⇒ + println(s"Pool failed with error ${ex.getMessage}") + ex.printStackTrace() + throw ex + } + .join( + Flow.fromSinkAndSource( + Sink.fromSubscriber(responseSubscriber), + Source.fromPublisher(requestPublisher) + )) + .run()(singleElementBufferMaterializer) + } + + private val serverConnections = TestProbe() + + def expectNextConnection(): ServerConnection = + serverConnections.expectMsgType[ServerConnection] + + def expectNoNewConnection(): Unit = + serverConnections.expectNoMsg() + + protected lazy val server = + Flow.fromSinkAndSourceMat( + // buffer is needed because the async subscriber/publisher boundary will otherwise request > 1 + Flow[HttpRequest].buffer(1, OverflowStrategy.backpressure) + .toMat(Sink.asPublisher[HttpRequest](false))(Keep.right), + Source.asSubscriber[HttpResponse])(Keep.both) + .mapMaterializedValue { + case (requestPublisher, responseSubscriber) ⇒ + val connection = new ServerConnection(requestPublisher, responseSubscriber) + serverConnections.ref ! connection + connection.outgoingConnection + } + } + } + + case object LegacyPoolImplementation extends PoolImplementation { + override def get = PoolFlow(_, _, _) + } + case object NewPoolImplementation extends PoolImplementation { + override def get = NewHostConnectionPool(_, _, _) + } + + /** Transport that just passes through requests / responses */ + case object PassThrough extends ClientServerImplementation { + def failsHandlerInputWhenHandlerOutputFails: Boolean = true + override def get(connectionKillSwitch: SharedKillSwitch): BidiFlow[HttpResponse, HttpResponse, HttpRequest, HttpRequest, Future[Http.OutgoingConnection]] = + BidiFlow.fromGraph(PassThroughTransport) + .atop(BidiFlow.fromFlows(connectionKillSwitch.flow[HttpResponse], connectionKillSwitch.flow[HttpRequest])) + .mapMaterializedValue(_ ⇒ Future.successful(newOutgoingConnection())) + + object PassThroughTransport extends GraphStage[BidiShape[HttpResponse, HttpResponse, HttpRequest, HttpRequest]] { + val reqIn = Inlet[HttpRequest]("reqIn") + val reqOut = Outlet[HttpRequest]("reqOut") + val resIn = Inlet[HttpResponse]("resIn") + val resOut = Outlet[HttpResponse]("resOut") + + val shape = BidiShape(resIn, resOut, reqIn, reqOut) + + def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + val failureCallback = getAsyncCallback[Throwable](cause ⇒ failStage(cause)) + val killSwitch = KillSwitches.shared("entity") + + object AddKillSwitch extends StreamUtils.EntityStreamOp[Unit] { + def strictM: Unit = () + def apply[T, Mat](source: Source[T, Mat]): (Source[T, Mat], Unit) = + (source.via(killSwitch.flow[T]), ()) + } + class MonitorMessage[T <: HttpMessage](in: Inlet[T], out: Outlet[T]) extends InHandler with OutHandler { + + def onPush(): Unit = { + val msg: T = grab(in) + + val (newEntity, res) = + HttpEntity.captureTermination(msg.entity) + + val finalMsg: T = msg.withEntity( + StreamUtils.transformEntityStream(newEntity, AddKillSwitch)._1 + .asInstanceOf[MessageEntity]).asInstanceOf[T] // FIXME: that cast is probably unsafe for CloseLimited + + res.onComplete { // if entity fails we report back to fail the stage + case Failure(cause) ⇒ failureCallback.invoke(cause) + case _ ⇒ + }(materializer.executionContext) + + push(out, finalMsg) + } + def onPull(): Unit = pull(in) + + override def onUpstreamFailure(ex: Throwable): Unit = { + killSwitch.abort(ex) + super.onUpstreamFailure(ex) + } + + override def onDownstreamFinish(): Unit = failStage(new RuntimeException("was cancelled")) + } + setHandlers(reqIn, reqOut, new MonitorMessage(reqIn, reqOut)) + setHandlers(resIn, resOut, new MonitorMessage(resIn, resOut)) + } + } + } + /** Transport that runs everything through client and server engines but without actual network */ + case object AkkaHttpEngineNoNetwork extends ClientServerImplementation { + def failsHandlerInputWhenHandlerOutputFails: Boolean = false + + override def get(connectionKillSwitch: SharedKillSwitch): BidiFlow[HttpResponse, HttpResponse, HttpRequest, HttpRequest, Future[Http.OutgoingConnection]] = + Http().serverLayer() atop + TLSPlacebo() atop + BidiFlow.fromFlows(connectionKillSwitch.flow[ByteString], connectionKillSwitch.flow[ByteString]) atop + TLSPlacebo().reversed atop + Http().clientLayer(Host("example.org")).reversed mapMaterializedValue (_ ⇒ Future.successful(newOutgoingConnection())) + } + + class KillSwitchedClientTransport(connectionKillSwitch: SharedKillSwitch) extends ClientTransport { + def connectTo(host: String, port: Int, settings: ClientConnectionSettings)(implicit system: ActorSystem): Flow[ByteString, ByteString, Future[Http.OutgoingConnection]] = + Flow[ByteString] + .via(connectionKillSwitch.flow[ByteString]) + .viaMat(ClientTransport.TCP.connectTo(host, port, settings))(Keep.right) + .viaMat(connectionKillSwitch.flow[ByteString])(Keep.left) + } + + /** Transport that uses actual top-level Http APIs to establish a plaintext HTTP connection */ + case object AkkaHttpEngineTCP extends TopLevelApiClientServerImplementation { + protected override def bindServerSource = Http().bind("localhost", 0) + protected def clientConnectionFlow(connectionKillSwitch: SharedKillSwitch): Flow[HttpRequest, HttpResponse, Future[Http.OutgoingConnection]] = + Http().outgoingConnectionUsingTransport(host = "localhost", port = serverBinding.localAddress.getPort, connectionContext = ConnectionContext.noEncryption(), transport = new KillSwitchedClientTransport(connectionKillSwitch)) + } + + /** + * Transport that uses actual top-level Http APIs to establish a HTTPS connection + * + * Currently requires an /etc/hosts entry that points akka.example.org to a locally bindable address. + */ + case object AkkaHttpEngineTLS extends TopLevelApiClientServerImplementation { + protected override def bindServerSource = Http().bind("akka.example.org", 0, connectionContext = ExampleHttpContexts.exampleServerContext) + protected def clientConnectionFlow(connectionKillSwitch: SharedKillSwitch): Flow[HttpRequest, HttpResponse, Future[Http.OutgoingConnection]] = + Http().outgoingConnectionUsingTransport(host = "akka.example.org", port = serverBinding.localAddress.getPort, connectionContext = ExampleHttpContexts.exampleClientContext, transport = new KillSwitchedClientTransport(connectionKillSwitch)) + } + abstract class TopLevelApiClientServerImplementation extends ClientServerImplementation { + def failsHandlerInputWhenHandlerOutputFails: Boolean = false + + protected def bindServerSource: Source[Http.IncomingConnection, Future[ServerBinding]] + protected def clientConnectionFlow(connectionKillSwitch: SharedKillSwitch): Flow[HttpRequest, HttpResponse, Future[Http.OutgoingConnection]] + + val connectionProbe = TestProbe() + val serverBinding: ServerBinding = + bindServerSource + .to(Sink.foreach { serverConnection ⇒ + connectionProbe.ref ! serverConnection + }) + .run().awaitResult(3.seconds) + + override def get(connectionKillSwitch: SharedKillSwitch): BidiFlow[HttpResponse, HttpResponse, HttpRequest, HttpRequest, Future[Http.OutgoingConnection]] = + // needs to be an involved two step process: + // 1. setup client flow and proxies on the server side to be able to return that flow immediately + // 2. when client connection was established, grab server connection as well and attach to proxies + // (cannot be implemented with just mapMaterializedValue because there's no transposing constructor for BidiFlow) + BidiFlow.fromGraph( + GraphDSL.create(Sink.asPublisher[HttpResponse](fanout = false), Source.asSubscriber[HttpRequest], clientConnectionFlow(connectionKillSwitch))((_, _, _)) { implicit builder ⇒ (resIn, reqOut, client) ⇒ + import GraphDSL.Implicits._ + + builder.materializedValue ~> Sink.foreach[(Publisher[HttpResponse], Subscriber[HttpRequest], Future[Http.OutgoingConnection])] { + case (resOut, reqIn, clientConn) ⇒ + clientConn.foreach { _ ⇒ + val serverConn = connectionProbe.expectMsgType[Http.IncomingConnection] + Flow.fromSinkAndSource( + Sink.fromSubscriber(reqIn), + Source.fromPublisher(resOut)).join(serverConn.flow).run() + }(system.dispatcher) + } + + BidiShape(resIn.in, client.out, client.in, reqOut.out) + } + ).mapMaterializedValue(_._3) + } + + /** Generates a new unique outgoingConnection */ + protected val newOutgoingConnection: () ⇒ Http.OutgoingConnection = { + val portCounter = new AtomicInteger(1) + + () ⇒ { + val connId = portCounter.getAndIncrement() + Http.OutgoingConnection( + InetSocketAddress.createUnresolved(s"local-$connId", connId % 65536), + InetSocketAddress.createUnresolved("remote", 5555)) + } + } +}