diff --git a/server/src/main/scala/org/http4s/server/middleware/HttpMethodOverrider.scala b/server/src/main/scala/org/http4s/server/middleware/HttpMethodOverrider.scala new file mode 100644 index 00000000000..745eb11195c --- /dev/null +++ b/server/src/main/scala/org/http4s/server/middleware/HttpMethodOverrider.scala @@ -0,0 +1,139 @@ +package org.http4s +package server +package middleware + +import cats.data.Kleisli +import cats.effect._ +import cats.instances.option._ +import cats.syntax.functor._ +import cats.syntax.flatMap._ +import cats.syntax.alternative._ +import cats.{Monad, ~>} +import io.chrisdavenport.vault.Key +import org.http4s.Http +import org.http4s.util.CaseInsensitiveString + +object HttpMethodOverrider { + + /** + * HttpMethodOverrider middleware config options. + */ + class HttpMethodOverriderConfig[F[_], G[_]]( + val overrideStrategy: OverrideStrategy[F, G], + val overridableMethods: Set[Method]) { + + type Self = HttpMethodOverriderConfig[F, G] + + private def copy( + overrideStrategy: OverrideStrategy[F, G] = overrideStrategy, + overridableMethods: Set[Method] = overridableMethods + ): Self = + new HttpMethodOverriderConfig[F, G](overrideStrategy, overridableMethods) + + def withOverrideStrategy(overrideStrategy: OverrideStrategy[F, G]): Self = + copy(overrideStrategy = overrideStrategy) + + def withOverridableMethods(overridableMethods: Set[Method]): Self = + copy(overridableMethods = overridableMethods) + } + + object HttpMethodOverriderConfig { + def apply[F[_], G[_]]( + overrideStrategy: OverrideStrategy[F, G], + overridableMethods: Set[Method]): HttpMethodOverriderConfig[F, G] = + new HttpMethodOverriderConfig[F, G](overrideStrategy, overridableMethods) + } + + sealed trait OverrideStrategy[F[_], G[_]] + final case class HeaderOverrideStrategy[F[_], G[_]](headerName: CaseInsensitiveString) + extends OverrideStrategy[F, G] + final case class QueryOverrideStrategy[F[_], G[_]](paramName: String) + extends OverrideStrategy[F, G] + final case class FormOverrideStrategy[F[_], G[_]]( + fieldName: String, + naturalTransformation: G ~> F) + extends OverrideStrategy[F, G] + + def defaultConfig[F[_], G[_]]: HttpMethodOverriderConfig[F, G] = + HttpMethodOverriderConfig[F, G]( + HeaderOverrideStrategy(CaseInsensitiveString("X-HTTP-Method-Override")), + Set(Method.POST)) + + val overriddenMethodAttrKey: Key[Method] = Key.newKey[IO, Method].unsafeRunSync + + /** Simple middleware for HTTP Method Override. + * + * This middleware lets you use HTTP verbs such as PUT or DELETE in places where the client + * doesn't support it. Camouflage your request with another HTTP verb(usually POST) and sneak + * the desired one using a custom header or request parameter. The middleware will '''override''' + * the original verb with the new one for you, allowing the request the be dispatched properly. + * + * @param http [[Http]] to transform + * @param config http method overrider config + */ + def apply[F[_], G[_]](http: Http[F, G], config: HttpMethodOverriderConfig[F, G])( + implicit F: Monad[F], + S: Sync[G]): Http[F, G] = { + + val parseMethod = (m: String) => Method.fromString(m.toUpperCase) + + val processRequestWithOriginalMethod = (req: Request[G]) => http(req) + + def processRequestWithMethod( + req: Request[G], + parseResult: ParseResult[Method]): F[Response[G]] = parseResult match { + case Left(_) => F.pure(Response[G](Status.BadRequest)) + case Right(om) => http(updateRequestWithMethod(req, om)).map(updateVaryHeader) + } + + def updateVaryHeader(resp: Response[G]): Response[G] = { + val varyHeaderName = CaseInsensitiveString("Vary") + config.overrideStrategy match { + case HeaderOverrideStrategy(headerName) => + val updatedVaryHeader = + resp.headers + .get(varyHeaderName) + .map((h: Header) => Header(h.name.value, s"${h.value}, ${headerName.value}")) + .getOrElse(Header(varyHeaderName.value, headerName.value)) + + resp.withHeaders(resp.headers.put(updatedVaryHeader)) + case _ => resp + } + } + + def updateRequestWithMethod(req: Request[G], om: Method): Request[G] = { + val attrs = req.attributes.insert(overriddenMethodAttrKey, req.method) + req.withAttributes(attrs).withMethod(om) + } + + def getUnsafeOverrideMethod(req: Request[G]): F[Option[String]] = + config.overrideStrategy match { + case HeaderOverrideStrategy(headerName) => F.pure(req.headers.get(headerName).map(_.value)) + case QueryOverrideStrategy(parameter) => F.pure(req.params.get(parameter)) + case FormOverrideStrategy(field, f) => + for { + formFields <- f( + UrlForm + .entityDecoder[G] + .decode(req, strict = true) + .value + .map(_.toOption.map(_.values))) + } yield formFields.flatMap(_.get(field).flatMap(_.uncons.map(_._1))) + } + + def processRequest(req: Request[G]): F[Response[G]] = getUnsafeOverrideMethod(req).flatMap { + case Some(m: String) => parseMethod.andThen(processRequestWithMethod(req, _)).apply(m) + case None => processRequestWithOriginalMethod(req) + } + + Kleisli { req: Request[G] => + { + config.overridableMethods + .contains(req.method) + .guard[Option] + .as(processRequest(req)) + .getOrElse(processRequestWithOriginalMethod(req)) + } + } + } +} diff --git a/server/src/test/scala/org/http4s/server/middleware/HttpMethodOverriderSpec.scala b/server/src/test/scala/org/http4s/server/middleware/HttpMethodOverriderSpec.scala new file mode 100644 index 00000000000..abe7f542e2d --- /dev/null +++ b/server/src/test/scala/org/http4s/server/middleware/HttpMethodOverriderSpec.scala @@ -0,0 +1,323 @@ +package org.http4s.server.middleware + +import cats.effect.IO +import cats.~> +import org.http4s._ +import org.http4s.dsl.io._ +import org.http4s.server.Router +import org.http4s.server.middleware.HttpMethodOverrider._ +import org.http4s.util.CaseInsensitiveString + +class HttpMethodOverriderSpec extends Http4sSpec { + + private final val overrideHeader = "X-HTTP-Method-Override" + private final val overrideParam, overrideField: String = "_method" + private final val varyHeader = "Vary" + private final val customHeader = "X-Custom-Header" + + private def headerOverrideStrategy[F[_], G[_]] = + HeaderOverrideStrategy[F, G](CaseInsensitiveString(overrideHeader)) + private def queryOverrideStrategy[F[_], G[_]] = QueryOverrideStrategy[F, G](overrideParam) + private val formOverrideStrategy = FormOverrideStrategy(overrideParam, λ[IO ~> IO](i => i)) + + private def postHeaderOverriderConfig[F[_], G[_]] = defaultConfig[F, G] + private def postQueryOverriderConfig[F[_], G[_]] = + HttpMethodOverriderConfig[F, G](queryOverrideStrategy, Set(POST)) + private val postFormOverriderConfig = + HttpMethodOverriderConfig(formOverrideStrategy, Set(POST)) + private def deleteHeaderOverriderConfig[F[_], G[_]] = + HttpMethodOverriderConfig[F, G](headerOverrideStrategy, Set(DELETE)) + private def deleteQueryOverriderConfig[F[_], G[_]] = + HttpMethodOverriderConfig[F, G](queryOverrideStrategy, Set(DELETE)) + private val deleteFormOverriderConfig = + HttpMethodOverriderConfig(formOverrideStrategy, Set(DELETE)) + private def noMethodHeaderOverriderConfig[F[_], G[_]] = + HttpMethodOverriderConfig[F, G](headerOverrideStrategy, Set.empty) + + private val testApp = Router("/" -> HttpRoutes.of[IO] { + case r @ GET -> Root / "resources" / "id" => + Ok(responseText[IO](msg = "resource's details", r)) + case r @ PUT -> Root / "resources" / "id" => + Ok(responseText(msg = "resource updated", r), Header(varyHeader, customHeader)) + case r @ DELETE -> Root / "resources" / "id" => + Ok(responseText(msg = "resource deleted", r)) + }).orNotFound + + private def mkResponseText( + msg: String, + reqMethod: Method, + overriddenMethod: Option[Method]): String = + overriddenMethod + .map(om => s"[$om ~> $reqMethod] => $msg") + .getOrElse(s"[$reqMethod] => $msg") + + private def responseText[F[_]](msg: String, req: Request[F]): String = { + val overriddenMethod = req.attributes.lookup(HttpMethodOverrider.overriddenMethodAttrKey) + mkResponseText(msg, req.method, overriddenMethod) + } + + "MethodOverrider middleware" should { + "ignore method override if request method not in the overridable method list" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(GET) + .withHeaders(Header(overrideHeader, "PUT")) + val app = HttpMethodOverrider(testApp, noMethodHeaderOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource's details", reqMethod = GET, overriddenMethod = None)) + } + + "override request method when using header method overrider strategy if override method provided" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(POST) + .withHeaders(Header(overrideHeader, "PUT")) + val app = HttpMethodOverrider(testApp, postHeaderOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + } + + "not override request method when using header method overrider strategy if override method not provided" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(DELETE) + val app = HttpMethodOverrider(testApp, deleteHeaderOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource deleted", reqMethod = DELETE, overriddenMethod = None)) + } + + "override request method and store the original method when using query method overrider strategy" in { + val req = Request[IO](uri = Uri.uri("/resources/id?_method=PUT")) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postQueryOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + } + + "not override request method when using query method overrider strategy if override method not provided" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(DELETE) + val app = HttpMethodOverrider(testApp, deleteQueryOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource deleted", reqMethod = DELETE, overriddenMethod = None)) + } + + "override request method and store the original method when using form method overrider strategy" in { + val urlForm = UrlForm("foo" -> "bar", overrideField -> "PUT") + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withEntity(urlForm) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postFormOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + } + + "not override request method when using form method overrider strategy if override method not provided" in { + val urlForm = UrlForm("foo" -> "bar") + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withEntity(urlForm) + .withMethod(DELETE) + val app = HttpMethodOverrider(testApp, deleteFormOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource deleted", reqMethod = DELETE, overriddenMethod = None)) + } + + "return 404 when using header method overrider strategy if override method provided is not recognized" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(POST) + .withHeaders(Header(overrideHeader, "INVALID")) + val app = HttpMethodOverrider(testApp, postHeaderOverriderConfig) + + val res = app(req) + res must returnStatus(Status.NotFound) + } + + "return 404 when using query method overrider strategy if override method provided is not recognized" in { + val req = Request[IO](uri = Uri.uri("/resources/id?_method=INVALID")) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postQueryOverriderConfig) + + val res = app(req) + res must returnStatus(Status.NotFound) + } + + "return 404 when using form method overrider strategy if override method provided is not recognized" in { + val urlForm = UrlForm("foo" -> "bar", overrideField -> "INVALID") + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withEntity(urlForm) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postFormOverriderConfig) + + val res = app(req) + res must returnStatus(Status.NotFound) + } + + "return 400 when using header method overrider strategy if override method provided is duped" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(POST) + .withHeaders(Header(overrideHeader, "")) + val app = HttpMethodOverrider(testApp, postHeaderOverriderConfig) + + val res = app(req) + res must returnStatus(Status.BadRequest) + } + + "return 400 when using query method overrider strategy if override method provided is duped" in { + val req = Request[IO](uri = Uri.uri("/resources/id?_method=")) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postQueryOverriderConfig) + + val res = app(req) + res must returnStatus(Status.BadRequest) + } + + "return 400 when using form method overrider strategy if override method provided is duped" in { + val urlForm = UrlForm("foo" -> "bar", overrideField -> "") + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withEntity(urlForm) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postFormOverriderConfig) + + val res = app(req) + res must returnStatus(Status.BadRequest) + } + + "override request method when using header method overrider strategy and be case insensitive" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(POST) + .withHeaders(Header(overrideHeader, "pUt")) + val app = HttpMethodOverrider(testApp, postHeaderOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + } + + "override request method when using query method overrider strategy and be case insensitive" in { + val req = Request[IO](uri = Uri.uri("/resources/id?_method=pUt")) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postQueryOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + } + + "override request method when form query method overrider strategy and be case insensitive" in { + val urlForm = UrlForm("foo" -> "bar", overrideField -> "pUt") + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withEntity(urlForm) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postFormOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + } + + "updates vary header when using query method overrider strategy and vary header comes pre-populated" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(POST) + .withHeaders(Header(overrideHeader, "PUT")) + val app = HttpMethodOverrider(testApp, postHeaderOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + + res must returnValue(containsHeader(Header(varyHeader, s"$customHeader, $overrideHeader"))) + } + + "set vary header when using header method overrider strategy and vary header has not been set" in { + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withMethod(POST) + .withHeaders(Header(overrideHeader, "DELETE")) + val app = HttpMethodOverrider(testApp, postHeaderOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource deleted", reqMethod = DELETE, overriddenMethod = Some(POST))) + + res must returnValue(containsHeader(Header(varyHeader, s"$overrideHeader"))) + } + + "not set vary header when using query method overrider strategy and vary header has not been set" in { + val req = Request[IO](uri = Uri.uri("/resources/id?_method=DELETE")) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postQueryOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource deleted", reqMethod = DELETE, overriddenMethod = Some(POST))) + + res must returnValue(doesntContainHeader(CaseInsensitiveString(varyHeader))) + } + + "not update vary header when using query method overrider strategy and vary header comes pre-populated" in { + val req = Request[IO](uri = Uri.uri("/resources/id?_method=PUT")) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postQueryOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + + res must returnValue(containsHeader(Header(varyHeader, s"$customHeader"))) + } + + "not set vary header when using form method overrider strategy and vary header has not been set" in { + val urlForm = UrlForm("foo" -> "bar", overrideField -> "DELETE") + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withEntity(urlForm) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postFormOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource deleted", reqMethod = DELETE, overriddenMethod = Some(POST))) + + res must returnValue(doesntContainHeader(CaseInsensitiveString(varyHeader))) + } + + "not update vary header when using form method overrider strategy and vary header comes pre-populated" in { + val urlForm = UrlForm("foo" -> "bar", overrideField -> "PUT") + val req = Request[IO](uri = Uri.uri("/resources/id")) + .withEntity(urlForm) + .withMethod(POST) + val app = HttpMethodOverrider(testApp, postFormOverriderConfig) + + val res = app(req) + res must returnStatus(Status.Ok) + res must returnBody( + mkResponseText(msg = "resource updated", reqMethod = PUT, overriddenMethod = Some(POST))) + + res must returnValue(containsHeader(Header(varyHeader, s"$customHeader"))) + } + } +} diff --git a/testing/src/main/scala/org/http4s/testing/Http4sMatchers.scala b/testing/src/main/scala/org/http4s/testing/Http4sMatchers.scala index 751960a75a7..35638f0e688 100644 --- a/testing/src/main/scala/org/http4s/testing/Http4sMatchers.scala +++ b/testing/src/main/scala/org/http4s/testing/Http4sMatchers.scala @@ -4,6 +4,7 @@ package testing import cats.syntax.flatMap._ import cats.data.EitherT import org.http4s.headers._ +import org.http4s.util.CaseInsensitiveString import org.specs2.matcher._ /** This might be useful in a testkit spinoff. Let's see what they do for us. */ @@ -34,6 +35,16 @@ trait Http4sMatchers[F[_]] extends Matchers with RunTimedMatchers[F] { m.headers.aka("the headers") } + def containsHeader(h: Header): Matcher[Message[F]] = + beSome(h.value) ^^ { m: Message[F] => + m.headers.get(h.name).map(_.value).aka("the particular header") + } + + def doesntContainHeader(h: CaseInsensitiveString): Matcher[Message[F]] = + beNone ^^ { m: Message[F] => + m.headers.get(h).aka("the particular header") + } + def haveMediaType(mt: MediaType): Matcher[Message[F]] = beSome(mt) ^^ { m: Message[F] => m.headers.get(`Content-Type`).map(_.mediaType).aka("the media type header")