From 263a9e0eaf4327601211a30a7bddac96e4981b4e Mon Sep 17 00:00:00 2001 From: Eshu Date: Sat, 7 Oct 2023 18:30:16 +0900 Subject: [PATCH] Issue #2321: MonoQuery and MultiQuery --- .../zio/http/endpoint/cli/CliEndpoint.scala | 16 +--- .../zio/http/endpoint/cli/EndpointGen.scala | 4 +- .../src/main/scala/zio/http/FormField.scala | 10 +-- .../src/main/scala/zio/http/Middleware.scala | 78 +++++++++++++++++++ .../main/scala/zio/http/codec/HttpCodec.scala | 45 +++++++++-- .../scala/zio/http/codec/HttpCodecError.scala | 3 + .../scala/zio/http/codec/QueryCodecs.scala | 22 ++++-- .../http/codec/internal/AtomizedCodecs.scala | 4 +- .../http/codec/internal/EncoderDecoder.scala | 20 +---- .../scala/zio/http/codec/HttpCodecSpec.scala | 15 +++- 10 files changed, 163 insertions(+), 54 deletions(-) diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala index a5bc1938ba2..255f4ce6f43 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala @@ -1,15 +1,7 @@ package zio.http.endpoint.cli -import scala.util.Try - -import zio.cli._ - -import zio.schema._ - import zio.http._ -import zio.http.codec.HttpCodec.Metadata import zio.http.codec._ -import zio.http.codec.internal._ import zio.http.endpoint._ /** @@ -133,10 +125,10 @@ private[cli] object CliEndpoint { case HttpCodec.Path(pathCodec, _) => CliEndpoint(url = HttpOptions.Path(pathCodec) :: List()) - case HttpCodec.Query(name, textCodec, _) => - textCodec.asInstanceOf[TextCodec[_]] match { - case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(name, value) :: List()) - case _ => CliEndpoint(url = HttpOptions.Query(name, textCodec) :: List()) + case query: HttpCodec.Query[Input, ?] => + query.textCodec match { + case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(query.name, value) :: List()) + case _ => CliEndpoint(url = HttpOptions.Query(query.name, query.textCodec) :: List()) } case HttpCodec.Status(_, _) => CliEndpoint.empty diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala index 5c7ba70f94c..ff50eca5d97 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala @@ -91,9 +91,9 @@ object EndpointGen { } lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] = - Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { case (name, codec) => + Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).zip(Gen.boolean).map { case (name, codec, isMono) => CliRepr( - HttpCodec.Query(name, codec), + if (isMono) HttpCodec.MonoQuery(name, codec) else HttpCodec.MultiQuery(name, codec), codec match { case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(name, value) :: Nil) case _ => CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil) diff --git a/zio-http/src/main/scala/zio/http/FormField.scala b/zio-http/src/main/scala/zio/http/FormField.scala index 41bd7863379..55e5fc2b7e7 100644 --- a/zio-http/src/main/scala/zio/http/FormField.scala +++ b/zio-http/src/main/scala/zio/http/FormField.scala @@ -155,15 +155,15 @@ object FormField { Chunk.empty[FormAST.Content], ), ) { - case (accum, header: FormAST.Header) if header.name == "Content-Disposition" => + case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Disposition") => (Some(header), accum._2, accum._3, accum._4) - case (accum, content: FormAST.Content) => + case (accum, content: FormAST.Content) => (accum._1, accum._2, accum._3, accum._4 :+ content) - case (accum, header: FormAST.Header) if header.name == "Content-Type" => + case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Type") => (accum._1, Some(header), accum._3, accum._4) - case (accum, header: FormAST.Header) if header.name == "Content-Transfer-Encoding" => + case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Transfer-Encoding") => (accum._1, accum._2, Some(header), accum._4) - case (accum, _) => accum + case (accum, _) => accum } for { diff --git a/zio-http/src/main/scala/zio/http/Middleware.scala b/zio-http/src/main/scala/zio/http/Middleware.scala index a6daeec3902..71570f9e14d 100644 --- a/zio-http/src/main/scala/zio/http/Middleware.scala +++ b/zio-http/src/main/scala/zio/http/Middleware.scala @@ -15,10 +15,14 @@ */ package zio.http +import java.io.File + import zio._ import zio.metrics._ import zio.stacktracer.TracingImplicits.disableAutoTrace +import zio.http.codec.{PathCodec, SegmentCodec} + trait Middleware[-UpperEnv] { self => def apply[Env1 <: UpperEnv, Err]( routes: Routes[Env1, Err], @@ -244,10 +248,84 @@ object Middleware extends HandlerAspects { } } + private sealed trait StaticServe[-R, +E] { self => + def run(path: Path, req: Request): Handler[R, E, Request, Response] + + } + + private object StaticServe { + def make[R, E](f: (Path, Request) => Handler[R, E, Request, Response]): StaticServe[R, E] = + new StaticServe[R, E] { + override def run(path: Path, request: Request) = f(path, request) + } + + def fromDirectory(docRoot: File)(implicit trace: Trace): StaticServe[Any, Throwable] = make { (path, _) => + val target = new File(docRoot.getAbsolutePath() + path.encode) + if (target.getCanonicalPath.startsWith(docRoot.getCanonicalPath)) Handler.fromFile(target) + else { + Handler.fromZIO( + ZIO.logWarning(s"attempt to access file outside of docRoot: ${target.getAbsolutePath}"), + ) *> Handler.badRequest + } + } + + def fromResource(implicit trace: Trace): StaticServe[Any, Throwable] = make { (path, _) => + Handler.fromResource(path.dropLeadingSlash.encode) + } + + } + + private def toMiddleware[E](path: Path, staticServe: StaticServe[Any, E])(implicit trace: Trace): Middleware[Any] = + new Middleware[Any] { + + private def checkFishy(acc: Boolean, segment: String): Boolean = { + val stop = segment.indexOf('/') >= 0 || segment.indexOf('\\') >= 0 || segment == ".." + acc || stop + } + + override def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = { + val mountpoint = Method.GET / path.segments.map(PathCodec.literal).reduceLeft(_ / _) + val pattern = mountpoint / trailing + val other = Routes( + pattern -> Handler + .identity[Request] + .flatMap { request => + val isFishy = request.path.segments.foldLeft(false)(checkFishy) + if (isFishy) { + Handler.fromZIO(ZIO.logWarning(s"fishy request detected: ${request.path.encode}")) *> Handler.badRequest + } else { + val segs = pattern.pathCodec.segments.collect { case SegmentCodec.Literal(v, _) => + v + } + val unnest = segs.foldLeft(Path.empty)(_ / _).addLeadingSlash + val path = request.path.unnest(unnest).addLeadingSlash + staticServe.run(path, request).sandbox + } + }, + ) + routes ++ other + } + } + + /** + * Creates a middleware for serving static files from the directory `docRoot` + * at the path `path`. + */ + def serveDirectory(path: Path, docRoot: File)(implicit trace: Trace): Middleware[Any] = + toMiddleware(path, StaticServe.fromDirectory(docRoot)) + + /** + * Creates a middleware for serving static files from resources at the path + * `path`. + */ + def serveResources(path: Path)(implicit trace: Trace): Middleware[Any] = + toMiddleware(path, StaticServe.fromResource) + /** * Creates a middleware for managing the flash scope. */ def flashScopeHandling: HandlerAspect[Any, Unit] = Middleware.intercept { (req, resp) => req.cookie("zio-http-flash").fold(resp)(flash => resp.addCookie(Cookie.clear(flash.name))) } + } diff --git a/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala index f80665ac302..4dd554ec083 100644 --- a/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala @@ -574,14 +574,49 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): ContentStream[A] = copy(index = index) } - private[http] final case class Query[A](name: String, textCodec: TextCodec[A], index: Int = 0) - extends Atom[HttpCodecType.Query, A] { - self => - def erase: Query[Any] = self.asInstanceOf[Query[Any]] + + private[http] sealed trait Query[A, I] extends Atom[HttpCodecType.Query, A] { + def erase: Query[Any, I] = asInstanceOf[Query[Any, I]] + + def name: String + + def textCodec: TextCodec[I] + + def index: Int def tag: AtomTag = AtomTag.Query - def index(index: Int): Query[A] = copy(index = index) + def index(index: Int): Query[A, I] + + def encode(value: A): Chunk[String] + + def decode(values: Chunk[String]): A + + @inline final private[HttpCodec] def decodeItem(value: String): I = + if (textCodec.isDefinedAt(value)) textCodec(value) + else throw HttpCodecError.MalformedQueryParam(name, textCodec) + } + + private[http] final case class MonoQuery[A](name: String, textCodec: TextCodec[A], index: Int = 0) + extends Query[A, A] { + def index(index: Int): Query[A, A] = copy(index = index) + + def encode(value: A): Chunk[String] = Chunk(textCodec.encode(value)) + + def decode(values: Chunk[String]): A = values match { + case Chunk(value) => decodeItem(value) + case empty if empty.isEmpty => throw HttpCodecError.MissingQueryParam(name) + case _ => throw HttpCodecError.SingleQueryParamValueExpected(name) + } + } + + private[http] final case class MultiQuery[I](name: String, textCodec: TextCodec[I], index: Int = 0) + extends Query[Chunk[I], I] { + def index(index: Int): Query[Chunk[I], I] = copy(index = index) + + def encode(value: Chunk[I]): Chunk[String] = value map textCodec.encode + + def decode(values: Chunk[String]): Chunk[I] = values map decodeItem } private[http] final case class Method[A](codec: SimpleCodec[zio.http.Method, A], index: Int = 0) diff --git a/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala index 8c486cc6683..822b3638212 100644 --- a/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -51,6 +51,9 @@ object HttpCodecError { final case class MissingQueryParam(queryParamName: String) extends HttpCodecError { def message = s"Missing query parameter $queryParamName" } + final case class SingleQueryParamValueExpected(queryParamName: String) extends HttpCodecError { + def message = s"Single query parameter $queryParamName value expected, but multiple values are found" + } final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed query parameter $queryParamName failed to decode using $textCodec" } diff --git a/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala b/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala index 5bf72e57e31..4d73bb0026f 100644 --- a/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala +++ b/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala @@ -15,30 +15,36 @@ */ package zio.http.codec +import zio.Chunk import zio.stacktracer.TracingImplicits.disableAutoTrace private[codec] trait QueryCodecs { def query(name: String): QueryCodec[String] = - HttpCodec.Query(name, TextCodec.string) + HttpCodec.MonoQuery(name, TextCodec.string) def queryBool(name: String): QueryCodec[Boolean] = - HttpCodec.Query(name, TextCodec.boolean) + HttpCodec.MonoQuery(name, TextCodec.boolean) def queryInt(name: String): QueryCodec[Int] = - HttpCodec.Query(name, TextCodec.int) + HttpCodec.MonoQuery(name, TextCodec.int) def queryAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] = - HttpCodec.Query(name, codec) + HttpCodec.MonoQuery(name, codec) + + def queries[I](name: String)(implicit codec: TextCodec[I]): QueryCodec[Chunk[I]] = + HttpCodec.MultiQuery(name, codec) def paramStr(name: String): QueryCodec[String] = - HttpCodec.Query(name, TextCodec.string) + HttpCodec.MonoQuery(name, TextCodec.string) def paramBool(name: String): QueryCodec[Boolean] = - HttpCodec.Query(name, TextCodec.boolean) + HttpCodec.MonoQuery(name, TextCodec.boolean) def paramInt(name: String): QueryCodec[Int] = - HttpCodec.Query(name, TextCodec.int) + HttpCodec.MonoQuery(name, TextCodec.int) def paramAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] = - HttpCodec.Query(name, codec) + HttpCodec.MonoQuery(name, codec) + def params[I](name: String)(implicit codec: TextCodec[I]): QueryCodec[Chunk[I]] = + HttpCodec.MultiQuery(name, codec) } diff --git a/zio-http/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala b/zio-http/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala index 4c18c8e4664..b84fbbd9b86 100644 --- a/zio-http/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala +++ b/zio-http/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala @@ -25,7 +25,7 @@ import zio.http.codec._ private[http] final case class AtomizedCodecs( method: Chunk[SimpleCodec[zio.http.Method, _]], path: Chunk[PathCodec[_]], - query: Chunk[Query[_]], + query: Chunk[Query[_, _]], header: Chunk[Header[_]], content: Chunk[BodyCodec[_]], status: Chunk[SimpleCodec[zio.http.Status, _]], @@ -33,7 +33,7 @@ private[http] final case class AtomizedCodecs( def append(atom: Atom[_, _]): AtomizedCodecs = atom match { case path0: Path[_] => self.copy(path = path :+ path0.pathCodec) case method0: Method[_] => self.copy(method = method :+ method0.codec) - case query0: Query[_] => self.copy(query = query :+ query0) + case query0: Query[_, _] => self.copy(query = query :+ query0) case header0: Header[_] => self.copy(header = header :+ header0) case content0: Content[_] => self.copy(content = content :+ BodyCodec.Single(content0.schema, content0.mediaType, content0.name)) diff --git a/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index 720ec51e4ab..99c3970110b 100644 --- a/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -279,20 +279,8 @@ private[codec] object EncoderDecoder { var i = 0 val queries = flattened.query while (i < queries.length) { - val query = queries(i).erase - - val queryParamValue = - queryParams - .getAllOrElse(query.name, Nil) - .collectFirst(query.textCodec) - - queryParamValue match { - case Some(value) => - inputs(i) = value - case None => - throw HttpCodecError.MissingQueryParam(query.name) - } - + val query = queries(i) + inputs(i) = query.decode(queryParams.getAllOrElse(query.name, Nil)) i = i + 1 } } @@ -478,9 +466,7 @@ private[codec] object EncoderDecoder { val query = flattened.query(i).erase val input = inputs(i) - val value = query.textCodec.encode(input) - - queryParams = queryParams.add(query.name, value) + queryParams = queryParams.addAll(query.name, query.encode(input)) i = i + 1 } diff --git a/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala b/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala index 2f4c5aabb3b..d3ad9de0ba2 100644 --- a/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala +++ b/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala @@ -16,12 +16,11 @@ package zio.http.codec -import java.util.UUID - import zio._ +import zio.http._ import zio.test._ -import zio.http._ +import java.util.UUID object HttpCodecSpec extends ZIOHttpSpec { val googleUrl = URL.decode("http://google.com").toOption.get @@ -39,6 +38,8 @@ object HttpCodecSpec extends ZIOHttpSpec { val isAge = "isAge" val codecBool = QueryCodec.paramBool(isAge) + val intSeq = "intSeq" + val codecSeq = QueryCodec.params[Int](intSeq) def makeRequest(paramValue: String) = Request.get(googleUrl.queryParams(QueryParams(isAge -> paramValue))) def spec = suite("HttpCodecSpec")( @@ -140,6 +141,14 @@ object HttpCodecSpec extends ZIOHttpSpec { assert(requestTrue.url.queryParams.get(isAge).get)(Assertion.equalTo("true")) && assert(requestFalse.url.queryParams.get(isAge).get)(Assertion.equalTo("false")) }, + test("paramSeq encoding with empty value") { + val requestNil = codecSeq.encodeRequest(Chunk.empty) + assert(requestNil.url.queryParams.get(intSeq))(Assertion.isNone) + }, + test("paramSeq encoding with non-empty value") { + val requestNil = codecSeq.encodeRequest(Chunk(1974, 5, 3)) + assert(requestNil.url.queryParams.getAll(intSeq).get)(Assertion.equalTo(Chunk("1974", "5", "3"))) + }, ) + suite("Codec with examples") { test("with examples") {