Skip to content

Commit

Permalink
=h2 akka#886 fix leaks in ALPN switcher
Browse files Browse the repository at this point in the history
This technically is very similar to CoupledTerminationFlow,
however since it's "sub" sources/sinks we could not directly use that.
  • Loading branch information
ktoso committed May 1, 2017
1 parent 6be1971 commit bf7b57b
Showing 1 changed file with 62 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import javax.net.ssl.SSLException
import akka.NotUsed
import akka.http.scaladsl.model.{ HttpRequest, HttpResponse }
import akka.stream.TLSProtocol.{ SessionBytes, SessionTruncated, SslTlsInbound, SslTlsOutbound }
import akka.stream.scaladsl.{ BidiFlow, Flow }
import akka.stream.scaladsl.{ BidiFlow, Flow, GraphDSL, Keep, Sink, Source }
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import akka.stream.{ Attributes, BidiShape, Inlet, Outlet }
import akka.stream._

object AlpnSwitch {
type HttpServerBidiFlow = BidiFlow[HttpResponse, SslTlsOutbound, SslTlsInbound, HttpRequest, NotUsed]
Expand All @@ -21,16 +21,29 @@ object AlpnSwitch {
http2Stack: HttpServerBidiFlow): HttpServerBidiFlow =
BidiFlow.fromGraph(
new GraphStage[BidiShape[HttpResponse, SslTlsOutbound, SslTlsInbound, HttpRequest]] {

// --- outer ports ---
val netIn = Inlet[SslTlsInbound]("AlpnSwitch.netIn")
val netOut = Outlet[SslTlsOutbound]("AlpnSwitch.netOut")

val requestOut = Outlet[HttpRequest]("AlpnSwitch.requestOut")
val responseIn = Inlet[HttpResponse]("AlpnSwitch.responseIn")
// --- end of outer ports ---

val shape: BidiShape[HttpResponse, SslTlsOutbound, SslTlsInbound, HttpRequest] =
BidiShape(responseIn, netOut, netIn, requestOut)

def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
logic

// --- inner ports, bound to actual server in install call ---
val serverDataIn = new SubSinkInlet[SslTlsOutbound]("ServerImpl.netIn")
val serverDataOut = new SubSourceOutlet[SslTlsInbound]("ServerImpl.netOut")

val serverRequestIn = new SubSinkInlet[HttpRequest]("ServerImpl.serverRequestIn")
val serverResponseOut = new SubSourceOutlet[HttpResponse]("ServerImpl.serverResponseOut")
// --- end of inner ports ---

override def preStart(): Unit = pull(netIn)

setHandler(netIn, new InHandler {
Expand All @@ -46,20 +59,14 @@ object AlpnSwitch {
}
})

val ignorePull = new OutHandler { def onPull(): Unit = () }
val failPush = new InHandler { def onPush(): Unit = throw new IllegalStateException("Wasn't pulled yet") }
private val ignorePull = new OutHandler { def onPull(): Unit = () }
private val failPush = new InHandler { def onPush(): Unit = throw new IllegalStateException("Wasn't pulled yet") }

setHandler(netOut, ignorePull)
setHandler(requestOut, ignorePull)
setHandler(responseIn, failPush)

def install(serverImplementation: HttpServerBidiFlow, firstElement: SslTlsInbound): Unit = {
val serverDataIn = new SubSinkInlet[SslTlsOutbound]("ServerImpl.netIn")
val serverDataOut = new SubSourceOutlet[SslTlsInbound]("ServerImpl.netOut")

val serverRequestIn = new SubSinkInlet[HttpRequest]("ServerImpl.serverRequestIn")
val serverResponseOut = new SubSourceOutlet[HttpResponse]("ServerImpl.serverResponseOut")

val networkSide = Flow.fromSinkAndSource(serverDataIn.sink, serverDataOut.source)
val userSide = Flow.fromSinkAndSource(serverRequestIn.sink, serverResponseOut.source)

Expand All @@ -75,11 +82,15 @@ object AlpnSwitch {
.run()(interpreter.subFusingMaterializer)
}

// helpers to connect inlets and outlets
// helpers to connect inlets and outlets also binding completion signals of given ports
def connect[T](in: Inlet[T], out: SubSourceOutlet[T], initialElement: Option[T]): Unit = {
val propagatePull =
new OutHandler {
def onPull(): Unit = pull(in)
override def onPull(): Unit = pull(in)
override def onDownstreamFinish(): Unit = {
out.complete()
super.onDownstreamFinish()
}
}

val firstHandler =
Expand All @@ -89,28 +100,59 @@ object AlpnSwitch {
propagatePull
case Some(ele)
new OutHandler {
def onPull(): Unit = {
override def onPull(): Unit = {
out.push(initialElement.get)
out.setHandler(propagatePull)
}

override def onDownstreamFinish(): Unit = {
out.complete()
super.onDownstreamFinish()
}
}
case None propagatePull
}

out.setHandler(firstHandler)
setHandler(in, new InHandler {
def onPush(): Unit = out.push(grab(in))
override def onPush(): Unit = out.push(grab(in))

override def onUpstreamFinish(): Unit = {
out.complete()
super.onUpstreamFinish()
}

override def onUpstreamFailure(ex: Throwable): Unit = {
out.fail(ex)
super.onUpstreamFailure(ex)
}
})

if (out.isAvailable) pull(in) // to account for lost pulls during initialization
}
def connect[T](in: SubSinkInlet[T], out: Outlet[T]): Unit = {
in.setHandler(new InHandler {
def onPush(): Unit = push(out, in.grab())
})
setHandler(out, new OutHandler {
def onPull(): Unit = in.pull()
})
val handler = new InHandler {
override def onPush(): Unit = push(out, in.grab())

override def onUpstreamFinish(): Unit = {
in.cancel()
super.onUpstreamFinish()
}
override def onUpstreamFailure(ex: Throwable): Unit = {
in.cancel()
super.onUpstreamFailure(ex)
}
}

val outHandler = new OutHandler {
override def onPull(): Unit = in.pull()
override def onDownstreamFinish(): Unit = {
in.cancel()
super.onDownstreamFinish()
}
}
in.setHandler(handler)
setHandler(out, outHandler)

if (isAvailable(out)) in.pull() // to account for lost pulls during initialization
}
Expand Down

0 comments on commit bf7b57b

Please sign in to comment.