Skip to content

Commit

Permalink
Issue zio#2321: MonoQuery and MultiQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
eshu committed Oct 7, 2023
1 parent 588ab4f commit 3eddb79
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
package zio.http.endpoint.cli

import scala.util.Try

import zio.cli._

import zio.schema._

import zio.http._
import zio.http.codec.HttpCodec.Metadata
import zio.http.codec._
import zio.http.codec.internal._
import zio.http.endpoint._

/**
Expand Down Expand Up @@ -133,10 +125,10 @@ private[cli] object CliEndpoint {
case HttpCodec.Path(pathCodec, _) =>
CliEndpoint(url = HttpOptions.Path(pathCodec) :: List())

case HttpCodec.Query(name, textCodec, _) =>
textCodec.asInstanceOf[TextCodec[_]] match {
case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(name, value) :: List())
case _ => CliEndpoint(url = HttpOptions.Query(name, textCodec) :: List())
case query: HttpCodec.Query[Input, ?] =>
query.textCodec match {
case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(query.name, value) :: List())
case _ => CliEndpoint(url = HttpOptions.Query(query.name, query.textCodec) :: List())
}

case HttpCodec.Status(_, _) => CliEndpoint.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ object EndpointGen {
}

lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] =
Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { case (name, codec) =>
Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).zip(Gen.boolean).map { case (name, codec, isMono) =>
CliRepr(
HttpCodec.Query(name, codec),
if (isMono) HttpCodec.MonoQuery(name, codec) else HttpCodec.MultiQuery(name, codec),
codec match {
case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(name, value) :: Nil)
case _ => CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil)
Expand Down
10 changes: 5 additions & 5 deletions zio-http/src/main/scala/zio/http/FormField.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,15 @@ object FormField {
Chunk.empty[FormAST.Content],
),
) {
case (accum, header: FormAST.Header) if header.name == "Content-Disposition" =>
case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Disposition") =>
(Some(header), accum._2, accum._3, accum._4)
case (accum, content: FormAST.Content) =>
case (accum, content: FormAST.Content) =>
(accum._1, accum._2, accum._3, accum._4 :+ content)
case (accum, header: FormAST.Header) if header.name == "Content-Type" =>
case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Type") =>
(accum._1, Some(header), accum._3, accum._4)
case (accum, header: FormAST.Header) if header.name == "Content-Transfer-Encoding" =>
case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Transfer-Encoding") =>
(accum._1, accum._2, Some(header), accum._4)
case (accum, _) => accum
case (accum, _) => accum
}

for {
Expand Down
78 changes: 78 additions & 0 deletions zio-http/src/main/scala/zio/http/Middleware.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
*/
package zio.http

import java.io.File

import zio._
import zio.metrics._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.http.codec.{PathCodec, SegmentCodec}

trait Middleware[-UpperEnv] { self =>
def apply[Env1 <: UpperEnv, Err](
routes: Routes[Env1, Err],
Expand Down Expand Up @@ -244,10 +248,84 @@ object Middleware extends HandlerAspects {
}
}

private sealed trait StaticServe[-R, +E] { self =>
def run(path: Path, req: Request): Handler[R, E, Request, Response]

}

private object StaticServe {
def make[R, E](f: (Path, Request) => Handler[R, E, Request, Response]): StaticServe[R, E] =
new StaticServe[R, E] {
override def run(path: Path, request: Request) = f(path, request)
}

def fromDirectory(docRoot: File)(implicit trace: Trace): StaticServe[Any, Throwable] = make { (path, _) =>
val target = new File(docRoot.getAbsolutePath() + path.encode)
if (target.getCanonicalPath.startsWith(docRoot.getCanonicalPath)) Handler.fromFile(target)
else {
Handler.fromZIO(
ZIO.logWarning(s"attempt to access file outside of docRoot: ${target.getAbsolutePath}"),
) *> Handler.badRequest
}
}

def fromResource(implicit trace: Trace): StaticServe[Any, Throwable] = make { (path, _) =>
Handler.fromResource(path.dropLeadingSlash.encode)
}

}

private def toMiddleware[E](path: Path, staticServe: StaticServe[Any, E])(implicit trace: Trace): Middleware[Any] =
new Middleware[Any] {

private def checkFishy(acc: Boolean, segment: String): Boolean = {
val stop = segment.indexOf('/') >= 0 || segment.indexOf('\\') >= 0 || segment == ".."
acc || stop
}

override def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = {
val mountpoint = Method.GET / path.segments.map(PathCodec.literal).reduceLeft(_ / _)
val pattern = mountpoint / trailing
val other = Routes(
pattern -> Handler
.identity[Request]
.flatMap { request =>
val isFishy = request.path.segments.foldLeft(false)(checkFishy)
if (isFishy) {
Handler.fromZIO(ZIO.logWarning(s"fishy request detected: ${request.path.encode}")) *> Handler.badRequest
} else {
val segs = pattern.pathCodec.segments.collect { case SegmentCodec.Literal(v, _) =>
v
}
val unnest = segs.foldLeft(Path.empty)(_ / _).addLeadingSlash
val path = request.path.unnest(unnest).addLeadingSlash
staticServe.run(path, request).sandbox
}
},
)
routes ++ other
}
}

/**
* Creates a middleware for serving static files from the directory `docRoot`
* at the path `path`.
*/
def serveDirectory(path: Path, docRoot: File)(implicit trace: Trace): Middleware[Any] =
toMiddleware(path, StaticServe.fromDirectory(docRoot))

/**
* Creates a middleware for serving static files from resources at the path
* `path`.
*/
def serveResources(path: Path)(implicit trace: Trace): Middleware[Any] =
toMiddleware(path, StaticServe.fromResource)

/**
* Creates a middleware for managing the flash scope.
*/
def flashScopeHandling: HandlerAspect[Any, Unit] = Middleware.intercept { (req, resp) =>
req.cookie("zio-http-flash").fold(resp)(flash => resp.addCookie(Cookie.clear(flash.name)))
}

}
45 changes: 40 additions & 5 deletions zio-http/src/main/scala/zio/http/codec/HttpCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,49 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with

def index(index: Int): ContentStream[A] = copy(index = index)
}
private[http] final case class Query[A](name: String, textCodec: TextCodec[A], index: Int = 0)
extends Atom[HttpCodecType.Query, A] {
self =>
def erase: Query[Any] = self.asInstanceOf[Query[Any]]

private[http] sealed trait Query[A, I] extends Atom[HttpCodecType.Query, A] {
def erase: Query[Any, I] = asInstanceOf[Query[Any, I]]

def name: String

def textCodec: TextCodec[I]

def index: Int

def tag: AtomTag = AtomTag.Query

def index(index: Int): Query[A] = copy(index = index)
def index(index: Int): Query[A, I]

def encode(value: A): Chunk[String]

def decode(values: Chunk[String]): A

@inline final private[HttpCodec] def decodeItem(value: String): I =
if (textCodec.isDefinedAt(value)) textCodec(value)
else throw HttpCodecError.MalformedQueryParam(name, textCodec)
}

private[http] final case class MonoQuery[A](name: String, textCodec: TextCodec[A], index: Int = 0)
extends Query[A, A] {
def index(index: Int): Query[A, A] = copy(index = index)

def encode(value: A): Chunk[String] = Chunk(textCodec.encode(value))

def decode(values: Chunk[String]): A = values match {
case Chunk(value) => decodeItem(value)
case empty if empty.isEmpty => throw HttpCodecError.MissingQueryParam(name)
case _ => throw HttpCodecError.SingleQueryParamValueExpected(name)
}
}

private[http] final case class MultiQuery[I](name: String, textCodec: TextCodec[I], index: Int = 0)
extends Query[Chunk[I], I] {
def index(index: Int): Query[Chunk[I], I] = copy(index = index)

def encode(value: Chunk[I]): Chunk[String] = value map textCodec.encode

def decode(values: Chunk[String]): Chunk[I] = values map decodeItem
}

private[http] final case class Method[A](codec: SimpleCodec[zio.http.Method, A], index: Int = 0)
Expand Down
3 changes: 3 additions & 0 deletions zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ object HttpCodecError {
final case class MissingQueryParam(queryParamName: String) extends HttpCodecError {
def message = s"Missing query parameter $queryParamName"
}
final case class SingleQueryParamValueExpected(queryParamName: String) extends HttpCodecError {
def message = s"Single query parameter $queryParamName value expected, but multiple values are found"
}
final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError {
def message = s"Malformed query parameter $queryParamName failed to decode using $textCodec"
}
Expand Down
22 changes: 14 additions & 8 deletions zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,36 @@
*/

package zio.http.codec
import zio.Chunk
import zio.stacktracer.TracingImplicits.disableAutoTrace
private[codec] trait QueryCodecs {
def query(name: String): QueryCodec[String] =
HttpCodec.Query(name, TextCodec.string)
HttpCodec.MonoQuery(name, TextCodec.string)

def queryBool(name: String): QueryCodec[Boolean] =
HttpCodec.Query(name, TextCodec.boolean)
HttpCodec.MonoQuery(name, TextCodec.boolean)

def queryInt(name: String): QueryCodec[Int] =
HttpCodec.Query(name, TextCodec.int)
HttpCodec.MonoQuery(name, TextCodec.int)

def queryAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] =
HttpCodec.Query(name, codec)
HttpCodec.MonoQuery(name, codec)

def queries[I](name: String)(implicit codec: TextCodec[I]): QueryCodec[Chunk[I]] =
HttpCodec.MultiQuery(name, codec)

def paramStr(name: String): QueryCodec[String] =
HttpCodec.Query(name, TextCodec.string)
HttpCodec.MonoQuery(name, TextCodec.string)

def paramBool(name: String): QueryCodec[Boolean] =
HttpCodec.Query(name, TextCodec.boolean)
HttpCodec.MonoQuery(name, TextCodec.boolean)

def paramInt(name: String): QueryCodec[Int] =
HttpCodec.Query(name, TextCodec.int)
HttpCodec.MonoQuery(name, TextCodec.int)

def paramAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] =
HttpCodec.Query(name, codec)
HttpCodec.MonoQuery(name, codec)

def params[I](name: String)(implicit codec: TextCodec[I]): QueryCodec[Chunk[I]] =
HttpCodec.MultiQuery(name, codec)
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ import zio.http.codec._
private[http] final case class AtomizedCodecs(
method: Chunk[SimpleCodec[zio.http.Method, _]],
path: Chunk[PathCodec[_]],
query: Chunk[Query[_]],
query: Chunk[Query[_, _]],
header: Chunk[Header[_]],
content: Chunk[BodyCodec[_]],
status: Chunk[SimpleCodec[zio.http.Status, _]],
) { self =>
def append(atom: Atom[_, _]): AtomizedCodecs = atom match {
case path0: Path[_] => self.copy(path = path :+ path0.pathCodec)
case method0: Method[_] => self.copy(method = method :+ method0.codec)
case query0: Query[_] => self.copy(query = query :+ query0)
case query0: Query[_, _] => self.copy(query = query :+ query0)
case header0: Header[_] => self.copy(header = header :+ header0)
case content0: Content[_] =>
self.copy(content = content :+ BodyCodec.Single(content0.schema, content0.mediaType, content0.name))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,20 +279,8 @@ private[codec] object EncoderDecoder {
var i = 0
val queries = flattened.query
while (i < queries.length) {
val query = queries(i).erase

val queryParamValue =
queryParams
.getAllOrElse(query.name, Nil)
.collectFirst(query.textCodec)

queryParamValue match {
case Some(value) =>
inputs(i) = value
case None =>
throw HttpCodecError.MissingQueryParam(query.name)
}

val query = queries(i)
inputs(i) = query.decode(queryParams.getAllOrElse(query.name, Nil))
i = i + 1
}
}
Expand Down Expand Up @@ -478,9 +466,7 @@ private[codec] object EncoderDecoder {
val query = flattened.query(i).erase
val input = inputs(i)

val value = query.textCodec.encode(input)

queryParams = queryParams.add(query.name, value)
queryParams = queryParams.addAll(query.name, query.encode(input))

i = i + 1
}
Expand Down
13 changes: 11 additions & 2 deletions zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ package zio.http.codec
import java.util.UUID

import zio._
import zio.test._

import zio.http._
import zio.test._

object HttpCodecSpec extends ZIOHttpSpec {
val googleUrl = URL.decode("http://google.com").toOption.get
Expand All @@ -39,6 +38,8 @@ object HttpCodecSpec extends ZIOHttpSpec {

val isAge = "isAge"
val codecBool = QueryCodec.paramBool(isAge)
val intSeq = "intSeq"
val codecSeq = QueryCodec.params[Int](intSeq)
def makeRequest(paramValue: String) = Request.get(googleUrl.queryParams(QueryParams(isAge -> paramValue)))

def spec = suite("HttpCodecSpec")(
Expand Down Expand Up @@ -140,6 +141,14 @@ object HttpCodecSpec extends ZIOHttpSpec {
assert(requestTrue.url.queryParams.get(isAge).get)(Assertion.equalTo("true")) &&
assert(requestFalse.url.queryParams.get(isAge).get)(Assertion.equalTo("false"))
},
test("paramSeq encoding with empty value") {
val requestNil = codecSeq.encodeRequest(Chunk.empty)
assert(requestNil.url.queryParams.get(intSeq))(Assertion.isNone)
},
test("paramSeq encoding with non-empty value") {
val requestNil = codecSeq.encodeRequest(Chunk(1974, 5, 3))
assert(requestNil.url.queryParams.getAll(intSeq).get)(Assertion.equalTo(Chunk("1974", "5", "3")))
},
) +
suite("Codec with examples") {
test("with examples") {
Expand Down

0 comments on commit 3eddb79

Please sign in to comment.