Skip to content

Commit

Permalink
Merge pull request #4179 from RaasAhsan/ember-server-graceful-shutdown
Browse files Browse the repository at this point in the history
Ember server graceful shutdown
  • Loading branch information
rossabaker committed Jan 16, 2021
2 parents dd9335d + 1ce4236 commit 5b2ed57
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 27 deletions.
9 changes: 8 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,16 @@ lazy val emberServer = libraryProject("ember-server")
ProblemFilters.exclude[DirectMissingMethodProblem]("org.http4s.ember.server.internal.ServerHelpers.server"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.http4s.ember.server.internal.ServerHelpers.server$default$12"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.http4s.ember.server.internal.ServerHelpers.server$default$12"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.http4s.ember.server.EmberServerBuilder.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.http4s.ember.server.internal.ServerHelpers.server"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.http4s.ember.server.internal.ServerHelpers.server$default$5"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.http4s.ember.server.internal.ServerHelpers.server$default$7"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.http4s.ember.server.internal.ServerHelpers.server$default$7"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.http4s.ember.server.internal.ServerHelpers.server$default$5"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.http4s.ember.server.internal.ServerHelpers.server")
)
)
.dependsOn(emberCore % "compile;test->test", server % "compile;test->test")
.dependsOn(emberCore % "compile;test->test", server % "compile;test->test", emberClient % "test->compile")

lazy val emberClient = libraryProject("ember-client")
.settings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ package org.http4s.ember.server
import cats._
import cats.syntax.all._
import cats.effect._
import fs2.concurrent._
import fs2.io.tcp.SocketGroup
import fs2.io.tcp.SocketOptionMapping
import fs2.io.tls._
import org.http4s._
import org.http4s.server.Server

import scala.concurrent.duration._
import java.net.InetSocketAddress
import _root_.io.chrisdavenport.log4cats.Logger
import _root_.io.chrisdavenport.log4cats.slf4j.Slf4jLogger
import org.http4s.ember.server.internal.ServerHelpers
import org.http4s.ember.server.internal.{ServerHelpers, Shutdown}

final class EmberServerBuilder[F[_]: Concurrent: Timer: ContextShift] private (
val host: String,
Expand All @@ -45,6 +45,7 @@ final class EmberServerBuilder[F[_]: Concurrent: Timer: ContextShift] private (
val maxHeaderSize: Int,
val requestHeaderReceiveTimeout: Duration,
val idleTimeout: Duration,
val shutdownTimeout: Duration,
val additionalSocketOptions: List[SocketOptionMapping[_]],
private val logger: Logger[F]
) { self =>
Expand Down Expand Up @@ -79,6 +80,7 @@ final class EmberServerBuilder[F[_]: Concurrent: Timer: ContextShift] private (
maxHeaderSize = maxHeaderSize,
requestHeaderReceiveTimeout = requestHeaderReceiveTimeout,
idleTimeout = EmberServerBuilder.Defaults.idleTimeout,
shutdownTimeout = EmberServerBuilder.Defaults.shutdownTimeout,
additionalSocketOptions = additionalSocketOptions,
logger = logger
)
Expand All @@ -97,6 +99,7 @@ final class EmberServerBuilder[F[_]: Concurrent: Timer: ContextShift] private (
maxHeaderSize: Int = self.maxHeaderSize,
requestHeaderReceiveTimeout: Duration = self.requestHeaderReceiveTimeout,
idleTimeout: Duration = self.idleTimeout,
shutdownTimeout: Duration = self.shutdownTimeout,
additionalSocketOptions: List[SocketOptionMapping[_]] = self.additionalSocketOptions,
logger: Logger[F] = self.logger
): EmberServerBuilder[F] =
Expand All @@ -114,6 +117,7 @@ final class EmberServerBuilder[F[_]: Concurrent: Timer: ContextShift] private (
maxHeaderSize = maxHeaderSize,
requestHeaderReceiveTimeout = requestHeaderReceiveTimeout,
idleTimeout = idleTimeout,
shutdownTimeout = shutdownTimeout,
additionalSocketOptions = additionalSocketOptions,
logger = logger
)
Expand All @@ -136,6 +140,9 @@ final class EmberServerBuilder[F[_]: Concurrent: Timer: ContextShift] private (
def withIdleTimeout(idleTimeout: Duration) =
copy(idleTimeout = idleTimeout)

def withShutdownTimeout(shutdownTimeout: Duration) =
copy(shutdownTimeout = shutdownTimeout)

def withOnError(onError: Throwable => Response[F]) = copy(onError = onError)
def withOnWriteFailure(onWriteFailure: (Option[Request[F]], Response[F], Throwable) => F[Unit]) =
copy(onWriteFailure = onWriteFailure)
Expand All @@ -148,20 +155,20 @@ final class EmberServerBuilder[F[_]: Concurrent: Timer: ContextShift] private (

def build: Resource[F, Server[F]] =
for {
bindAddress <- Resource.liftF(Sync[F].delay(new InetSocketAddress(host, port)))
blocker <- blockerOpt.fold(Blocker[F])(_.pure[Resource[F, *]])
sg <- sgOpt.fold(SocketGroup[F](blocker))(_.pure[Resource[F, *]])
bindAddress <- Resource.liftF(Sync[F].delay(new InetSocketAddress(host, port)))
shutdownSignal <- Resource.liftF(SignallingRef[F, Boolean](false))
shutdown <- Resource.liftF(Shutdown[F](shutdownTimeout))
_ <- Concurrent[F].background(
ServerHelpers
.server(
bindAddress,
httpApp,
sg,
tlsInfoOpt,
shutdown,
onError,
onWriteFailure,
shutdownSignal.some,
maxConcurrency,
receiveBufferSize,
maxHeaderSize,
Expand All @@ -173,7 +180,7 @@ final class EmberServerBuilder[F[_]: Concurrent: Timer: ContextShift] private (
.compile
.drain
)
_ <- Resource.make(Applicative[F].unit)(_ => shutdownSignal.set(true))
_ <- Resource.make(Applicative[F].unit)(_ => shutdown.await)
} yield new Server[F] {
def address: InetSocketAddress = bindAddress
def isSecure: Boolean = tlsInfoOpt.isDefined
Expand All @@ -196,6 +203,7 @@ object EmberServerBuilder {
maxHeaderSize = Defaults.maxHeaderSize,
requestHeaderReceiveTimeout = Defaults.requestHeaderReceiveTimeout,
idleTimeout = Defaults.idleTimeout,
shutdownTimeout = Defaults.shutdownTimeout,
additionalSocketOptions = Defaults.additionalSocketOptions,
logger = Slf4jLogger.getLogger[F]
)
Expand All @@ -217,6 +225,7 @@ object EmberServerBuilder {
val maxHeaderSize: Int = server.defaults.MaxHeadersSize
val requestHeaderReceiveTimeout: Duration = 5.seconds
val idleTimeout: Duration = server.defaults.IdleTimeout
val shutdownTimeout: Duration = server.defaults.ShutdownTimeout
val additionalSocketOptions = List.empty[SocketOptionMapping[_]]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,25 @@ private[server] object ServerHelpers {
private val close = Connection(NonEmptyList.of(closeCi))
private val keepAlive = Connection(NonEmptyList.one("keep-alive".ci))

def server[F[_]: Concurrent: ContextShift](
def server[F[_]: ContextShift](
bindAddress: InetSocketAddress,
httpApp: HttpApp[F],
sg: SocketGroup,
tlsInfoOpt: Option[(TLSContext, TLSParameters)],
shutdown: Shutdown[F],
// Defaults
onError: Throwable => Response[F] = { (_: Throwable) =>
Response[F](Status.InternalServerError)
},
onWriteFailure: (Option[Request[F]], Response[F], Throwable) => F[Unit],
terminationSignal: Option[SignallingRef[F, Boolean]] = None,
maxConcurrency: Int = Int.MaxValue,
receiveBufferSize: Int = 256 * 1024,
maxHeaderSize: Int = 10 * 1024,
requestHeaderReceiveTimeout: Duration = 5.seconds,
idleTimeout: Duration = 60.seconds,
additionalSocketOptions: List[SocketOptionMapping[_]] = List.empty,
logger: Logger[F]
)(implicit C: Clock[F]): Stream[F, Nothing] = {
// Termination Signal, if not present then does not terminate.
val termSignal: F[SignallingRef[F, Boolean]] =
terminationSignal.fold(SignallingRef[F, Boolean](false))(_.pure[F])

)(implicit F: Concurrent[F], C: Clock[F]): Stream[F, Nothing] = {
def socketReadRequest(
socket: Socket[F],
requestHeaderReceiveTimeout: Duration,
Expand Down Expand Up @@ -155,18 +151,18 @@ private[server] object ServerHelpers {
}
.drain

Stream
.eval(termSignal)
.flatMap(terminationSignal =>
sg.server[F](bindAddress, additionalSocketOptions = additionalSocketOptions)
.map { connect =>
Stream
.resource(connect.flatMap(upgradeSocket(_, tlsInfoOpt)))
.flatMap(withUpgradedSocket(_))
}
.parJoin(maxConcurrency)
.interruptWhen(terminationSignal)
.drain)

sg.server[F](bindAddress, additionalSocketOptions = additionalSocketOptions)
.interruptWhen(shutdown.signal.attempt)
// Divorce the scopes of the server stream and handler streams so the
// former can be terminated while handlers complete.
.prefetch
.map { connect =>
shutdown.trackConnection >>
Stream
.resource(connect.flatMap(upgradeSocket(_, tlsInfoOpt)))
.flatMap(withUpgradedSocket(_))
}
.parJoin(maxConcurrency)
.drain
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright 2019 http4s.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.http4s.ember.server.internal

import cats.syntax.all._
import cats.effect._
import cats.effect.implicits._
import cats.effect.concurrent._
import fs2.Stream

import scala.concurrent.duration.{Duration, FiniteDuration}

private[server] abstract class Shutdown[F[_]] {
def await: F[Unit]
def signal: F[Unit]
def newConnection: F[Unit]
def removeConnection: F[Unit]

def trackConnection: Stream[F, Unit] =
Stream.bracket(newConnection)(_ => removeConnection)
}

private[server] object Shutdown {

def apply[F[_]](timeout: Duration)(implicit F: Concurrent[F], timer: Timer[F]): F[Shutdown[F]] =
timeout match {
case fi: FiniteDuration =>
if (fi.length == 0) immediateShutdown else timedShutdown(timeout)
case _ => timedShutdown(timeout)
}

private def timedShutdown[F[_]](
timeout: Duration)(implicit F: Concurrent[F], timer: Timer[F]): F[Shutdown[F]] = {
case class State(isShutdown: Boolean, active: Int)

for {
unblockStart <- Deferred[F, Unit]
unblockFinish <- Deferred[F, Unit]
state <- Ref.of[F, State](State(false, 0))
} yield new Shutdown[F] {
override val await: F[Unit] =
unblockStart
.complete(())
.flatMap { _ =>
state.modify { case s @ State(_, active) =>
val fa = if (active == 0) {
F.unit
} else {
timeout match {
case fi: FiniteDuration => unblockFinish.get.timeoutTo(fi, F.unit)
case _ => unblockFinish.get
}
}
s.copy(isShutdown = true) -> fa
}
}
.uncancelable
.flatten

override val signal: F[Unit] =
unblockStart.get

override val newConnection: F[Unit] =
state.update { s =>
s.copy(active = s.active + 1)
}

override val removeConnection: F[Unit] =
state
.modify { case s @ State(isShutdown, active) =>
val conns = active - 1
if (isShutdown && conns <= 0) {
s.copy(active = conns) -> unblockFinish.complete(())
} else {
s.copy(active = conns) -> F.unit
}
}
.flatten
.uncancelable
}
}

private def immediateShutdown[F[_]](implicit F: Concurrent[F]): F[Shutdown[F]] =
Deferred[F, Unit].map { unblock =>
new Shutdown[F] {
override val await: F[Unit] = unblock.complete(())
override val signal: F[Unit] = unblock.get
override val newConnection: F[Unit] = F.unit
override val removeConnection: F[Unit] = F.unit
override val trackConnection: Stream[F, Unit] = Stream.empty
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2019 http4s.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.http4s.ember.server

import cats.syntax.all._
import cats.effect._
import org.http4s._
import org.http4s.client.Client
import org.http4s.server.Server
import org.http4s.implicits._
import org.http4s.dsl.Http4sDsl
import org.http4s.ember.client.EmberClientBuilder

import scala.concurrent.duration._

class EmberServerSuite extends Http4sSuite {

def service[F[_]](implicit F: Async[F]): HttpApp[F] = {
val dsl = new Http4sDsl[F] {}
import dsl._

HttpRoutes
.of[F] { case GET -> Root =>
Ok("Hello!")
}
.orNotFound
}

def serverResource: Resource[IO, Server[IO]] =
EmberServerBuilder
.default[IO]
.withHttpApp(service[IO])
.build

def client: FunFixture[Client[IO]] =
ResourceFixture(EmberClientBuilder.default[IO].build)

def server: FunFixture[Server[IO]] =
ResourceFixture(serverResource)

def fixture: FunFixture[(Server[IO], Client[IO])] =
FunFixture.map2(server, client)

fixture.test("server responds to requests") { case (server, client) =>
IO.sleep(3.seconds) >> client
.get(s"http://${server.address.getHostName}:${server.address.getPort}")(_.status.pure[IO])
.timeout(5.seconds)
.assertEquals(Status.Ok)
}
}

0 comments on commit 5b2ed57

Please sign in to comment.