Skip to content

Commit

Permalink
Don't start response until after we add async timeout listener
Browse files Browse the repository at this point in the history
  • Loading branch information
rossabaker committed Sep 24, 2018
1 parent 6220e2b commit 38b7389
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 43 deletions.
Expand Up @@ -2,11 +2,14 @@ package org.http4s
package server
package jetty

import cats.effect.IO
import cats.effect.{IO, Timer}
import cats.implicits._
import java.net.{HttpURLConnection, URL}
import java.io.IOException
import java.nio.charset.StandardCharsets
import org.http4s.dsl.io._
import org.specs2.specification.AfterAll
import scala.concurrent.duration._
import scala.io.Source

class JettyServerSpec extends Http4sSpec with AfterAll {
Expand All @@ -15,6 +18,7 @@ class JettyServerSpec extends Http4sSpec with AfterAll {
val server =
builder
.bindAny()
.withAsyncTimeout(500.millis)
.mountService(
HttpRoutes.of {
case GET -> Root / "thread" / "routing" =>
Expand All @@ -26,6 +30,12 @@ class JettyServerSpec extends Http4sSpec with AfterAll {

case req @ POST -> Root / "echo" =>
Ok(req.body)

case GET -> Root / "never" =>
IO.never

case GET -> Root / "slow" =>
implicitly[Timer[IO]].sleep(50.millis) *> Ok("slow")
},
"/"
)
Expand Down Expand Up @@ -67,4 +77,14 @@ class JettyServerSpec extends Http4sSpec with AfterAll {
post("/echo", input) must startWith(input)
}
}

"Timeout" should {
"not fire prematurely" in {
get("/slow") must_== "slow"
}

"fire on timeout" in {
get("/never") must throwAn[IOException]
}
}
}
33 changes: 16 additions & 17 deletions servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala
Expand Up @@ -3,6 +3,7 @@ package servlet

import cats.data.OptionT
import cats.effect._
import cats.effect.concurrent.Deferred
import cats.implicits.{catsSyntaxEither => _, _}
import javax.servlet._
import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
Expand Down Expand Up @@ -57,22 +58,20 @@ class AsyncHttp4sServlet[F[_]](
private def handleRequest(
ctx: AsyncContext,
request: Request[F],
bodyWriter: BodyWriter[F]): F[Unit] = {
val timeout = F.async[Unit] { cb =>
ctx.addListener(new AsyncTimeoutHandler(cb))
}
bodyWriter: BodyWriter[F]): F[Unit] = Deferred[F, Unit].flatMap { gate =>
// It is an error to add a listener to an async context that is
// already completed, so we must take care to add the listener
// before the response can complete.
val timeout =
F.asyncF[Response[F]](cb => gate.complete(ctx.addListener(new AsyncTimeoutHandler(cb))))
val response =
optionTSync
.suspend(serviceFn(request))
.getOrElse(Response.notFound)
.recoverWith(serviceErrorHandler(request))
gate.get *>
optionTSync
.suspend(serviceFn(request))
.getOrElse(Response.notFound)
.recoverWith(serviceErrorHandler(request))
val servletResponse = ctx.getResponse.asInstanceOf[HttpServletResponse]
F.race(timeout, response).flatMap {
case Left(()) =>
renderResponse(Response.timeout[F], servletResponse, bodyWriter, F.never)
case Right(resp) =>
renderResponse(resp, servletResponse, bodyWriter, timeout)
}
F.race(timeout, response).flatMap(r => renderResponse(r.merge, servletResponse, bodyWriter))
}

private def errorHandler(
Expand All @@ -86,19 +85,19 @@ class AsyncHttp4sServlet[F[_]](
val response = Response[F](Status.InternalServerError)
// We don't know what I/O mode we're in here, and we're not rendering a body
// anyway, so we use a NullBodyWriter.
val f = renderResponse(response, servletResponse, NullBodyWriter, F.unit) *>
val f = renderResponse(response, servletResponse, NullBodyWriter) *>
F.delay(
if (servletRequest.isAsyncStarted)
servletRequest.getAsyncContext.complete()
)
F.runAsync(f)(loggingAsyncCallback(logger)).unsafeRunSync()
}

private class AsyncTimeoutHandler(cb: Callback[Unit]) extends AbstractAsyncListener {
private class AsyncTimeoutHandler(cb: Callback[Response[F]]) extends AbstractAsyncListener {
override def onTimeout(event: AsyncEvent): Unit = {
val req = event.getAsyncContext.getRequest.asInstanceOf[HttpServletRequest]
logger.info(s"Request timed out: ${req.getMethod} ${req.getServletPath}${req.getPathInfo}")
cb(Right(()))
cb(Right(Response.timeout[F]))
}
}
}
Expand Down
Expand Up @@ -43,7 +43,7 @@ class BlockingHttp4sServlet[F[_]](
.suspend(serviceFn(request))
.getOrElse(Response.notFound)
.recoverWith(serviceErrorHandler(request))
.flatMap(renderResponse(_, servletResponse, bodyWriter, F.never))
.flatMap(renderResponse(_, servletResponse, bodyWriter))

private def errorHandler(servletResponse: HttpServletResponse): PartialFunction[Throwable, Unit] = {
case t: Throwable if servletResponse.isCommitted =>
Expand All @@ -54,7 +54,7 @@ class BlockingHttp4sServlet[F[_]](
val response = Response[F](Status.InternalServerError)
// We don't know what I/O mode we're in here, and we're not rendering a body
// anyway, so we use a NullBodyWriter.
val render = renderResponse(response, servletResponse, NullBodyWriter, F.never)
val render = renderResponse(response, servletResponse, NullBodyWriter)
F.runAsync(render)(_ => IO.unit).unsafeRunSync()
}
}
Expand Down
7 changes: 3 additions & 4 deletions servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala
Expand Up @@ -40,14 +40,13 @@ abstract class Http4sServlet[F[_]](service: HttpRoutes[F], servletIo: ServletIo[
bodyWriter: BodyWriter[F]
): F[Unit] = {
val response = Response[F](Status.BadRequest).withEntity(parseFailure.sanitized)
renderResponse(response, servletResponse, bodyWriter, F.async(_ => ()))
renderResponse(response, servletResponse, bodyWriter)
}

protected def renderResponse(
response: Response[F],
servletResponse: HttpServletResponse,
bodyWriter: BodyWriter[F],
timeout: F[Unit]
bodyWriter: BodyWriter[F]
): F[Unit] =
// Note: the servlet API gives us no undeprecated method to both set
// a body and a status reason. We sacrifice the status reason.
Expand All @@ -58,7 +57,7 @@ abstract class Http4sServlet[F[_]](service: HttpRoutes[F], servletIo: ServletIo[
}
.attempt
.flatMap {
case Right(()) => bodyWriter(response, timeout)
case Right(()) => bodyWriter(response)
case Left(t) =>
response.body.drain.compile.drain.handleError {
case t2 => logger.error(t2)("Error draining body")
Expand Down
29 changes: 14 additions & 15 deletions servlet/src/main/scala/org/http4s/servlet/ServletIo.scala
Expand Up @@ -42,20 +42,19 @@ final case class BlockingServletIo[F[_]: Effect: ContextShift](
blockingExecutionContext)

override protected[servlet] def initWriter(
servletResponse: HttpServletResponse): BodyWriter[F] = {
(response: Response[F], timeout: F[Unit]) =>
val out = servletResponse.getOutputStream
val flush = response.isChunked
response.body.chunks
.map { chunk =>
// Avoids copying for specialized chunks
val byteChunk = chunk.toBytes
out.write(byteChunk.values, byteChunk.offset, byteChunk.length)
if (flush)
servletResponse.flushBuffer()
}
.compile
.drain
servletResponse: HttpServletResponse): BodyWriter[F] = { response: Response[F] =>
val out = servletResponse.getOutputStream
val flush = response.isChunked
response.body.chunks
.map { chunk =>
// Avoids copying for specialized chunks
val byteChunk = chunk.toBytes
out.write(byteChunk.values, byteChunk.offset, byteChunk.length)
if (flush)
servletResponse.flushBuffer()
}
.compile
.drain
}
}

Expand Down Expand Up @@ -233,7 +232,7 @@ final case class NonBlockingServletIo[F[_]: Effect](chunkSize: Int) extends Serv
}
}

{ (response: Response[F], timeout: F[Unit]) =>
{ response: Response[F] =>
if (response.isChunked)
autoFlush = true
response.body.chunks
Expand Down
8 changes: 5 additions & 3 deletions servlet/src/main/scala/org/http4s/servlet/package.scala
@@ -1,10 +1,12 @@
package org.http4s

import cats.effect.Async

package object servlet {
protected[servlet] type BodyWriter[F[_]] = (Response[F], F[Unit]) => F[Unit]
protected[servlet] type BodyWriter[F[_]] = Response[F] => F[Unit]

protected[servlet] def NullBodyWriter[F[_]]: BodyWriter[F] =
(_, timeout) => timeout
protected[servlet] def NullBodyWriter[F[_]](implicit F: Async[F]): BodyWriter[F] =
_ => F.unit

protected[servlet] val DefaultChunkSize = 4096
}
21 changes: 20 additions & 1 deletion tomcat/src/test/scala/org/http4s/tomcat/TomcatServerSpec.scala
Expand Up @@ -2,11 +2,14 @@ package org.http4s
package server
package tomcat

import cats.effect.IO
import cats.effect.{IO, Timer}
import cats.implicits._
import java.io.IOException
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import org.http4s.dsl.io._
import org.specs2.specification.AfterAll
import scala.concurrent.duration._
import scala.io.Source
import org.apache.catalina.webresources.TomcatURLStreamHandlerFactory

Expand Down Expand Up @@ -34,6 +37,12 @@ class TomcatServerSpec extends {

case req @ POST -> Root / "echo" =>
Ok(req.body)

case GET -> Root / "never" =>
IO.never

case GET -> Root / "slow" =>
implicitly[Timer[IO]].sleep(50.millis) *> Ok("slow")
},
"/"
)
Expand Down Expand Up @@ -75,4 +84,14 @@ class TomcatServerSpec extends {
post("/echo", input) must startWith(input)
}
}

"Timeout" should {
"not fire prematurely" in {
get("/slow") must_== "slow"
}

"fire on timeout" in {
get("/never") must throwAn[IOException]
}
}
}

0 comments on commit 38b7389

Please sign in to comment.