diff --git a/docs/src/main/mdoc/middleware.md b/docs/src/main/mdoc/middleware.md index ca6772692f4..c7d666acb58 100644 --- a/docs/src/main/mdoc/middleware.md +++ b/docs/src/main/mdoc/middleware.md @@ -148,6 +148,7 @@ package. These include: * [Jsonp] * [Virtual Host] * [Metrics] +* [`X-Request-ID` header] And a few others. @@ -218,6 +219,40 @@ val meteredRouter: Resource[IO, HttpRoutes[IO]] = ``` +### X-Request-ID Middleware + +Use the `RequestId` middleware to automatically generate a `X-Request-ID` header to a request, +if one wasn't supplied. Adds a `X-Request-ID` header to the response with the id generated +or supplied as part of the request. + +This [heroku guide](https://devcenter.heroku.com/articles/http-request-id) gives a brief explanation +as to why this header is useful. + +```scala mdoc:silent +import org.http4s.server.middleware.RequestId +import org.typelevel.ci.CIString + +val requestIdService = RequestId.httpRoutes(HttpRoutes.of[IO] { + case req => + val reqId = req.headers.get(CIString("X-Request-ID")).fold("null")(_.value) + // use request id to correlate logs with the request + IO(println(s"request received, cid=$reqId")) *> Ok() +}) +val responseIO = requestIdService.orNotFound(goodRequest) +``` + +Note: `req.attributes.lookup(RequestId.requestIdAttrKey)` can also be used to lookup the request id +extracted from the header, or the generated request id. + +```scala mdoc +// generated request id can be correlated with logs +val resp = responseIO.unsafeRunSync() +// X-Request-ID header added to response +resp.headers +// the request id is also available using attributes +resp.attributes.lookup(RequestId.requestIdAttrKey) +``` + [service]: ../service [dsl]: ../dsl [Authentication]: ../auth @@ -228,4 +263,5 @@ val meteredRouter: Resource[IO, HttpRoutes[IO]] = [Jsonp]: ../api/org/http4s/server/middleware/Jsonp$ [Virtual Host]: ../api/org/http4s/server/middleware/VirtualHost$ [Metrics]: ../api/org/http4s/server/middleware/Metrics$ +[`X-Request-ID` header]: ../api/org/http4s/server/middleware/RequestId$ [`Kleisli`]: https://typelevel.org/cats/datatypes/kleisli.html diff --git a/server/src/main/scala/org/http4s/server/middleware/RequestId.scala b/server/src/main/scala/org/http4s/server/middleware/RequestId.scala new file mode 100644 index 00000000000..94bbf19d8a4 --- /dev/null +++ b/server/src/main/scala/org/http4s/server/middleware/RequestId.scala @@ -0,0 +1,96 @@ +/* + * Copyright 2013-2020 http4s.org + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.http4s + +package server +package middleware + +import org.http4s.{Header, Http, Request, Response} +import cats.{FlatMap, ~>} +import cats.arrow.FunctionK +import cats.data.{Kleisli, OptionT} +import cats.effect.{IO, Sync} +import cats.implicits._ +import org.typelevel.ci.CIString +import io.chrisdavenport.vault.Key +import java.util.UUID + +/** Propagate a `X-Request-Id` header to the response, generate a UUID + * when the `X-Request-Id` header is unset. + * https://devcenter.heroku.com/articles/http-request-id + */ +object RequestId { + + private[this] val requestIdHeader = CIString("X-Request-ID") + + val requestIdAttrKey: Key[String] = Key.newKey[IO, String].unsafeRunSync + + def apply[G[_], F[_]](http: Http[G, F])(implicit G: Sync[G]): Http[G, F] = + apply(requestIdHeader)(http) + + def apply[G[_], F[_]]( + headerName: CIString + )(http: Http[G, F])(implicit G: Sync[G]): Http[G, F] = + Kleisli[G, Request[F], Response[F]] { req => + for { + header <- req.headers.get(headerName) match { + case None => G.delay(Header.Raw(headerName, UUID.randomUUID().toString())) + case Some(header) => G.pure[Header](header) + } + reqId = header.value + response <- http(req.withAttribute(requestIdAttrKey, reqId).putHeaders(header)) + } yield response.withAttribute(requestIdAttrKey, reqId).putHeaders(header) + } + + def apply[G[_], F[_]]( + fk: F ~> G, + headerName: CIString = requestIdHeader, + genReqId: F[UUID] + )(http: Http[G, F])(implicit G: FlatMap[G], F: Sync[F]): Http[G, F] = + Kleisli[G, Request[F], Response[F]] { req => + for { + header <- fk(req.headers.get(headerName) match { + case None => genReqId.map(reqId => Header.Raw(headerName, reqId.show)) + case Some(header) => F.pure[Header](header) + }) + reqId = header.value + response <- http(req.withAttribute(requestIdAttrKey, reqId).putHeaders(header)) + } yield response.withAttribute(requestIdAttrKey, reqId).putHeaders(header) + } + + object httpApp { + def apply[F[_]: Sync](httpApp: HttpApp[F]): HttpApp[F] = + RequestId.apply(requestIdHeader)(httpApp) + + def apply[F[_]: Sync]( + headerName: CIString + )(httpApp: HttpApp[F]): HttpApp[F] = + RequestId.apply(headerName)(httpApp) + + def apply[F[_]: Sync]( + headerName: CIString = requestIdHeader, + genReqId: F[UUID] + )(httpApp: HttpApp[F]): HttpApp[F] = + RequestId.apply(FunctionK.id[F], headerName, genReqId)(httpApp) + } + + object httpRoutes { + def apply[F[_]: Sync](httpRoutes: HttpRoutes[F]): HttpRoutes[F] = + RequestId.apply(requestIdHeader)(httpRoutes) + + def apply[F[_]: Sync]( + headerName: CIString + )(httpRoutes: HttpRoutes[F]): HttpRoutes[F] = + RequestId.apply(headerName)(httpRoutes) + + def apply[F[_]: Sync]( + headerName: CIString = requestIdHeader, + genReqId: F[UUID] + )(httpRoutes: HttpRoutes[F]): HttpRoutes[F] = + RequestId.apply(OptionT.liftK[F], headerName, genReqId)(httpRoutes) + } +} diff --git a/server/src/test/scala/org/http4s/server/middleware/RequestIdSpec.scala b/server/src/test/scala/org/http4s/server/middleware/RequestIdSpec.scala new file mode 100644 index 00000000000..668e790ddcf --- /dev/null +++ b/server/src/test/scala/org/http4s/server/middleware/RequestIdSpec.scala @@ -0,0 +1,126 @@ +/* + * Copyright 2013-2020 http4s.org + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.http4s.server.middleware + +import cats.effect._ +import cats.implicits._ +import org.http4s._ +import org.http4s.dsl.io._ +import org.http4s.Uri.uri +import org.typelevel.ci.CIString +import java.util.UUID + +class RequestIdSpec extends Http4sSpec { + private def testService(headerKey: CIString = CIString("X-Request-ID")) = + HttpRoutes.of[IO] { + case req @ GET -> Root / "request" => + Ok(show"request-id: ${req.headers.get(headerKey).fold("None")(_.value)}") + case req @ GET -> Root / "attribute" => + Ok( + show"request-id: ${req.attributes.lookup(RequestId.requestIdAttrKey).getOrElse[String]("None")}") + } + + private def requestIdFromBody(resp: Response[IO]) = + resp.as[String].map(_.stripPrefix("request-id: ")) + + private def requestIdFromHeaders( + resp: Response[IO], + headerKey: CIString = CIString("X-Request-ID")) = + resp.headers.get(headerKey).fold("None")(_.value) + + "RequestId middleware" should { + "propagate X-Request-ID header from request to response" in { + val req = + Request[IO](uri = uri("/request"), headers = Headers.of(Header("X-Request-ID", "123"))) + val (reqReqId, respReqId) = RequestId + .httpRoutes(testService()) + .orNotFound(req) + .flatMap { resp => + requestIdFromBody(resp).map(_ -> requestIdFromHeaders(resp)) + } + .unsafeRunSync() + + (reqReqId must_=== "123").and(respReqId must_=== "123") + } + "generate X-Request-ID header when unset" in { + val req = Request[IO](uri = uri("/request")) + val (reqReqId, respReqId) = RequestId + .httpRoutes(testService()) + .orNotFound(req) + .flatMap { resp => + requestIdFromBody(resp).map(_ -> requestIdFromHeaders(resp)) + } + .unsafeRunSync() + + (reqReqId must_=== respReqId).and( + Either.catchNonFatal(UUID.fromString(respReqId)) must (beRight)) + } + "generate different request ids on subsequent requests" in { + val req = Request[IO](uri = uri("/request")) + val resp = RequestId.httpRoutes(testService()).orNotFound(req) + val requestId1 = resp.map(requestIdFromHeaders(_)).unsafeRunSync() + val requestId2 = resp.map(requestIdFromHeaders(_)).unsafeRunSync() + + (requestId1 must_!== requestId2) + } + "propagate custom request id header from request to response" in { + val req = Request[IO]( + uri = uri("/request"), + headers = Headers.of(Header("X-Request-ID", "123"), Header("X-Correlation-ID", "abc"))) + val (reqReqId, respReqId) = RequestId + .httpRoutes(CIString("X-Correlation-ID"))(testService(CIString("X-Correlation-ID"))) + .orNotFound(req) + .flatMap { resp => + requestIdFromBody(resp).map(_ -> requestIdFromHeaders(resp, CIString("X-Correlation-ID"))) + } + .unsafeRunSync() + + (reqReqId must_=== "abc").and(respReqId must_=== "abc") + } + "generate custom request id header when unset" in { + val req = + Request[IO](uri = uri("/request"), headers = Headers.of(Header("X-Request-ID", "123"))) + val (reqReqId, respReqId) = RequestId + .httpRoutes(CIString("X-Correlation-ID"))(testService(CIString("X-Correlation-ID"))) + .orNotFound(req) + .flatMap { resp => + requestIdFromBody(resp).map(_ -> requestIdFromHeaders(resp, CIString("X-Correlation-ID"))) + } + .unsafeRunSync() + + (reqReqId must_=== respReqId).and( + Either.catchNonFatal(UUID.fromString(respReqId)) must (beRight)) + } + "generate X-Request-ID header when unset using supplied generator" in { + val uuid = UUID.fromString("00000000-0000-0000-0000-000000000000") + val req = Request[IO](uri = uri("/request")) + val (reqReqId, respReqId) = RequestId + .httpRoutes(genReqId = IO.pure(uuid))(testService()) + .orNotFound(req) + .flatMap { resp => + requestIdFromBody(resp).map(_ -> requestIdFromHeaders(resp)) + } + .unsafeRunSync() + + (reqReqId must_=== uuid.show).and(respReqId must_=== uuid.show) + } + "include requestId attribute with request and response" in { + val req = + Request[IO](uri = uri("/attribute"), headers = Headers.of(Header("X-Request-ID", "123"))) + val (reqReqId, respReqId) = RequestId + .httpRoutes(testService()) + .orNotFound(req) + .flatMap { resp => + requestIdFromBody(resp).map( + _ -> resp.attributes.lookup(RequestId.requestIdAttrKey).getOrElse("None")) + } + .unsafeRunSync() + + (reqReqId must_=== "123").and(respReqId must_=== "123") + } + } +}