diff --git a/core/src/main/scala/org/http4s/rho/ExecutableCompiler.scala b/core/src/main/scala/org/http4s/rho/ExecutableCompiler.scala index eed599ca..11241445 100644 --- a/core/src/main/scala/org/http4s/rho/ExecutableCompiler.scala +++ b/core/src/main/scala/org/http4s/rho/ExecutableCompiler.scala @@ -3,6 +3,7 @@ package rho import bits.HeaderAST._ import bits.QueryAST._ +import org.http4s.rho.bits.ResponseGeneratorInstances.BadRequest import org.http4s.rho.bits._ @@ -29,17 +30,17 @@ trait ExecutableCompiler { case HeaderCapture(key) => req.headers.get(key) match { case Some(h) => ParserSuccess(h::stack) - case None => ValidationFailure(s"Missing header: ${key.name}") + case None => ValidationFailure(BadRequest(s"Missing header: ${key.name}")) } case HeaderRequire(key, f) => req.headers.get(key) match { - case Some(h) => if (f(h)) ParserSuccess(stack) else ValidationFailure(s"Invalid header: $h") - case None => ValidationFailure(s"Missing header: ${key.name}") + case Some(h) => f(h).fold[ParserResult[HList]](ParserSuccess(stack))(r =>ValidationFailure(r)) + case None => ValidationFailure(BadRequest(s"Missing header: ${key.name}")) } case HeaderMapper(key, f) => req.headers.get(key) match { case Some(h) => ParserSuccess(f(h)::stack) - case None => ValidationFailure(s"Missing header: ${key.name}") + case None => ValidationFailure(BadRequest(s"Missing header: ${key.name}")) } case MetaCons(r, _) => runValidation(req, r, stack) diff --git a/core/src/main/scala/org/http4s/rho/Result.scala b/core/src/main/scala/org/http4s/rho/Result.scala index 9e4f6c92..c88ef19e 100644 --- a/core/src/main/scala/org/http4s/rho/Result.scala +++ b/core/src/main/scala/org/http4s/rho/Result.scala @@ -78,9 +78,11 @@ import Result._ trait ResultSyntaxInstances { - implicit class ResultSyntax[T >: Result.TopResult <: BaseResult](r: T) extends MessageOps { + implicit class ResultSyntax[T >: Result.TopResult <: BaseResult](r: T) extends ResponseOps { override type Self = T + def withStatus[S <% Status](status: S): Self = r.copy(resp = r.resp.copy(status = status)) + override def attemptAs[T](implicit decoder: EntityDecoder[T]): DecodeResult[T] = { val t: Task[ParseFailure\/T] = r.resp.attemptAs(decoder).run EitherT[Task, ParseFailure, T](t) @@ -103,9 +105,13 @@ trait ResultSyntaxInstances { } } - implicit class TaskResultSyntax[T >: Result.TopResult <: BaseResult](r: Task[T]) extends MessageOps { + implicit class TaskResultSyntax[T >: Result.TopResult <: BaseResult](r: Task[T]) extends ResponseOps { override type Self = Task[T] + def withStatus[S <% Status](status: S): Self = r.map{ result => + result.copy(resp = result.resp.copy(status = status)) + } + override def attemptAs[T](implicit decoder: EntityDecoder[T]): DecodeResult[T] = { val t: Task[ParseFailure\/T] = r.flatMap { t => t.resp.attemptAs(decoder).run diff --git a/core/src/main/scala/org/http4s/rho/RhoService.scala b/core/src/main/scala/org/http4s/rho/RhoService.scala index 44f2470b..3a86510f 100644 --- a/core/src/main/scala/org/http4s/rho/RhoService.scala +++ b/core/src/main/scala/org/http4s/rho/RhoService.scala @@ -32,7 +32,7 @@ trait RhoService extends bits.MethodAliases case NoMatch => Task.now(None) case ParserSuccess(t) => attempt(t).map(Some(_)) case ParserFailure(s) => onBadRequest(s).map(Some(_)) - case ValidationFailure(s) => onBadRequest(s).map(Some(_)) + case ValidationFailure(r) => r.map( r => Some(r.resp)) } } @@ -55,7 +55,7 @@ trait RhoService extends bits.MethodAliases val w = EntityEncoder.stringEncoder w.toEntity(reason).map{ entity => val hs = entity.length match { - case Some(l) => w.headers.put(Header.`Content-Length`(l)) + case Some(l) => w.headers.put(headers.`Content-Length`(l)) case None => w.headers } Response(status, body = entity.body, headers = hs) diff --git a/core/src/main/scala/org/http4s/rho/Router.scala b/core/src/main/scala/org/http4s/rho/Router.scala index e4e30731..2b8cb03d 100644 --- a/core/src/main/scala/org/http4s/rho/Router.scala +++ b/core/src/main/scala/org/http4s/rho/Router.scala @@ -5,6 +5,7 @@ import bits.PathAST._ import bits.HeaderAST._ import bits.QueryAST.QueryRule import org.http4s.rho.bits.{HeaderAppendable, HListToFunc} +import headers.`Content-Type` import shapeless.{::, HList} import shapeless.ops.hlist.Prepend @@ -70,7 +71,7 @@ case class CodecRouter[T <: HList, R](router: Router[T], decoder: EntityDecoder[ override val headers: HeaderRule = { if (!decoder.consumes.isEmpty) { - val mt = requireThat(Header.`Content-Type`) { h: Header.`Content-Type`.HeaderT => + val mt = requireThat(`Content-Type`) { h: `Content-Type`.HeaderT => decoder.matchesMediaType(h.mediaType) } diff --git a/core/src/main/scala/org/http4s/rho/bits/HeaderAST.scala b/core/src/main/scala/org/http4s/rho/bits/HeaderAST.scala index e7c9e503..d3338651 100644 --- a/core/src/main/scala/org/http4s/rho/bits/HeaderAST.scala +++ b/core/src/main/scala/org/http4s/rho/bits/HeaderAST.scala @@ -1,9 +1,12 @@ package org.http4s package rho.bits +import org.http4s.rho.Result.BaseResult import shapeless.ops.hlist.Prepend import shapeless.{::, HList} +import scalaz.concurrent.Task + /** AST representing the Header operations of the DSL */ object HeaderAST { @@ -23,7 +26,7 @@ object HeaderAST { sealed trait HeaderRule - case class HeaderRequire[T <: HeaderKey.Extractable](key: T, f: T#HeaderT => Boolean) extends HeaderRule + case class HeaderRequire[T <: HeaderKey.Extractable](key: T, f: T#HeaderT => Option[Task[BaseResult]]) extends HeaderRule case class HeaderMapper[T <: HeaderKey.Extractable, R](key: T, f: T#HeaderT => R) extends HeaderRule diff --git a/core/src/main/scala/org/http4s/rho/bits/ParserResult.scala b/core/src/main/scala/org/http4s/rho/bits/ParserResult.scala index d9a5e48c..699dca8e 100644 --- a/core/src/main/scala/org/http4s/rho/bits/ParserResult.scala +++ b/core/src/main/scala/org/http4s/rho/bits/ParserResult.scala @@ -1,5 +1,10 @@ package org.http4s.rho.bits +import org.http4s.Response +import org.http4s.rho.Result.BaseResult + +import scalaz.concurrent.Task + sealed trait RouteResult[+T] case object NoMatch extends RouteResult[Nothing] @@ -26,5 +31,5 @@ sealed trait ParserResult[+T] extends RouteResult[T] { case class ParserSuccess[+T](result: T) extends ParserResult[T] case class ParserFailure(reason: String) extends ParserResult[Nothing] // TODO: I think the reason for failure could be made easier to use with specific failure types -case class ValidationFailure(reason: String) extends ParserResult[Nothing] +case class ValidationFailure(response: Task[BaseResult]) extends ParserResult[Nothing] diff --git a/core/src/main/scala/org/http4s/rho/bits/QueryParser.scala b/core/src/main/scala/org/http4s/rho/bits/QueryParser.scala index 65416cee..b38b19ca 100644 --- a/core/src/main/scala/org/http4s/rho/bits/QueryParser.scala +++ b/core/src/main/scala/org/http4s/rho/bits/QueryParser.scala @@ -1,25 +1,28 @@ package org.http4s package rho.bits +import org.http4s.rho.Result.BaseResult import org.http4s.rho.bits.QueryParser.Params +import org.http4s.rho.bits.ResponseGeneratorInstances.BadRequest import scala.language.higherKinds -import scalaz.{-\/, \/-} import scala.annotation.tailrec import scala.collection.generic.CanBuildFrom +import scalaz.concurrent.Task trait QueryParser[A] { import QueryParser.Params def collect(name: String, params: Params, default: Option[A]): ParserResult[A] } -final class ValidatingParser[A](parent: QueryParser[A], validate: A => Boolean) extends QueryParser[A] { +final class ValidatingParser[A](parent: QueryParser[A], validate: A => Option[Task[BaseResult]]) extends QueryParser[A] { override def collect(name: String, params: Params, default: Option[A]): ParserResult[A] = { val result = parent.collect(name, params, default) - result.flatMap{ r => - if (validate(r)) result - else ValidationFailure("Invalid parameter: \"" + r + '"') + result.flatMap{ r => validate(r) match { + case None => result + case Some(resp) => ValidationFailure(resp) + } } } } @@ -73,14 +76,14 @@ object QueryParser { case Some(Seq()) => default match { case Some(defaultValue) => ParserSuccess(defaultValue) - case None => ValidationFailure(s"Value of query parameter '$name' missing") + case None => ValidationFailure(BadRequest(s"Value of query parameter '$name' missing")) } case None => default match { case Some(defaultValue) => ParserSuccess(defaultValue) - case None => ValidationFailure(s"Missing query param: $name") + case None => ValidationFailure(BadRequest(s"Missing query param: $name")) } } } } - } + diff --git a/core/src/main/scala/org/http4s/rho/package.scala b/core/src/main/scala/org/http4s/rho/package.scala index 72ed2bd2..d7a5f349 100644 --- a/core/src/main/scala/org/http4s/rho/package.scala +++ b/core/src/main/scala/org/http4s/rho/package.scala @@ -1,5 +1,8 @@ package org.http4s +import org.http4s.rho.Result.BaseResult +import org.http4s.rho.bits.ResponseGeneratorInstances.BadRequest + import scala.language.implicitConversions import rho.bits.PathAST._ @@ -10,6 +13,7 @@ import shapeless.{HNil, ::} import org.http4s.rho.bits._ import scala.reflect.runtime.universe.TypeTag +import scalaz.concurrent.Task package object rho extends Http4s with ResultSyntaxInstances { @@ -31,13 +35,36 @@ package object rho extends Http4s with ResultSyntaxInstances { def param[T](name: String)(implicit parser: QueryParser[T], m: TypeTag[T]): TypedQuery[T :: HNil] = TypedQuery(QueryCapture(name, parser, default = None, m)) - /** - * Defines a parameter in query string that should be bound to a route definition. - * @param name name of the parameter in query - * @param default value that should be used if no or an invalid parameter is available - * @param validate predicate to determine if a parameter is valid - */ - def param[T](name: String, default: T, validate: T => Boolean = (_: T) => true)(implicit parser: QueryParser[T], m: TypeTag[T]): TypedQuery[T :: HNil] = + /** Define a query parameter with a default value */ + def param[T](name: String, default: T)(implicit parser: QueryParser[T], m: TypeTag[T]): TypedQuery[T :: HNil] = + TypedQuery(QueryCapture(name, parser, default = Some(default), m)) + + /** Define a query parameter that will be validated with the predicate + * + * Failure of the predicate results in a '403: BadRequest' response. */ + def param[T](name: String, validate: T => Boolean) + (implicit parser: QueryParser[T], m: TypeTag[T]): TypedQuery[T :: HNil] = + paramR(name, {t => + if (validate(t)) None + else Some(BadRequest("Invalid query parameter: \"" + t + "\"")) + }) + + /** Define a query parameter that will be validated with the predicate + * + * Failure of the predicate results in a '403: BadRequest' response. */ + def param[T](name: String, default: T, validate: T => Boolean) + (implicit parser: QueryParser[T], m: TypeTag[T]): TypedQuery[T :: HNil] = + paramR(name, default, {t => + if (validate(t)) None + else Some(BadRequest("Invalid query parameter: \"" + t + "\"")) + }) + + /** Defines a parameter in query string that should be bound to a route definition. */ + def paramR[T](name: String, validate: T => Option[Task[BaseResult]])(implicit parser: QueryParser[T], m: TypeTag[T]): TypedQuery[T :: HNil] = + TypedQuery(QueryCapture(name, new ValidatingParser(parser, validate), default = None, m)) + + /** Defines a parameter in query string that should be bound to a route definition. */ + def paramR[T](name: String, default: T, validate: T => Option[Task[BaseResult]])(implicit parser: QueryParser[T], m: TypeTag[T]): TypedQuery[T :: HNil] = TypedQuery(QueryCapture(name, new ValidatingParser(parser, validate), default = Some(default), m)) /** @@ -65,10 +92,17 @@ package object rho extends Http4s with ResultSyntaxInstances { /////////////////////////////// Header helpers ////////////////////////////////////// /* Checks that the header exists */ - def require(header: HeaderKey.Extractable): TypedHeader[HNil] = requireThat(header)(_ => true) + def require(header: HeaderKey.Extractable): TypedHeader[HNil] = requireThatR(header)(_ => None) /* Check that the header exists and satisfies the condition */ def requireThat[H <: HeaderKey.Extractable](header: H)(f: H#HeaderT => Boolean): TypedHeader[HNil] = + requireThatR(header){ h => + if (f(h)) None + else Some(BadRequest("Invalid header: " + h.value)) + } + + /* Check that the header exists and satisfies the condition */ + def requireThatR[H <: HeaderKey.Extractable](header: H)(f: H#HeaderT => Option[Task[BaseResult]]): TypedHeader[HNil] = TypedHeader(HeaderRequire(header, f)) /** requires the header and will pull this header from the pile and put it into the function args stack */ diff --git a/core/src/test/scala/org/http4s/rho/ApiTest.scala b/core/src/test/scala/org/http4s/rho/ApiTest.scala index a050ca6b..a0877fcd 100644 --- a/core/src/test/scala/org/http4s/rho/ApiTest.scala +++ b/core/src/test/scala/org/http4s/rho/ApiTest.scala @@ -1,11 +1,11 @@ package org.http4s package rho -import org.http4s.rho.bits.MethodAliases._ -import org.http4s.rho.bits.ResponseGeneratorInstances._ +import bits.MethodAliases._ +import bits.ResponseGeneratorInstances._ -import org.http4s.rho.bits.HeaderAST.{TypedHeader, HeaderAnd} -import org.http4s.rho.bits.{RhoPathTree, ParserSuccess, ValidationFailure} +import bits.HeaderAST.{TypedHeader, HeaderAnd} +import bits.{RhoPathTree, ParserSuccess, ValidationFailure} import org.specs2.mutable._ import shapeless.HNil @@ -15,18 +15,18 @@ import scodec.bits.ByteVector // TODO: these tests are a bit of a mess class ApiTest extends Specification { - val lenheader = Header.`Content-Length`(4) - val etag = Header.ETag("foo") + val lenheader = headers.`Content-Length`(4) + val etag = headers.ETag("foo") - val RequireETag = require(Header.ETag) - val RequireNonZeroLen = requireThat(Header.`Content-Length`){ h => h.length != 0 } + val RequireETag = require(headers.ETag) + val RequireNonZeroLen = requireThat(headers.`Content-Length`){ h => h.length != 0 } def fetchETag(p: Task[Option[Response]]): String = { val resp = p.run val mvalue = for { r <- resp - h <- r.headers.get(Header.ETag) + h <- r.headers.get(headers.ETag) } yield h.value mvalue.getOrElse(sys.error("No ETag: " + resp)) @@ -39,8 +39,10 @@ class ApiTest extends Specification { "Fail on a bad request" in { val badreq = Request().withHeaders(Headers(lenheader)) - RhoPathTree.ValidationTools.ensureValidHeaders((RequireETag && RequireNonZeroLen).rule,badreq) should_== - ValidationFailure(s"Missing header: ${etag.name}") + val res = RhoPathTree.ValidationTools.ensureValidHeaders((RequireETag && RequireNonZeroLen).rule,badreq) + + res must beAnInstanceOf[ValidationFailure] + res.asInstanceOf[ValidationFailure].response.run.resp.status must_== Status.BadRequest } "Match captureless route" in { @@ -53,41 +55,41 @@ class ApiTest extends Specification { "Capture params" in { val req = Request().withHeaders(Headers(etag, lenheader)) Seq({ - val c2 = capture(Header.`Content-Length`) && RequireETag + val c2 = capture(headers.`Content-Length`) && RequireETag RhoPathTree.ValidationTools.ensureValidHeaders(c2.rule, req) should_== ParserSuccess(lenheader::HNil) }, { - val c3 = capture(Header.`Content-Length`) && capture(Header.ETag) + val c3 = capture(headers.`Content-Length`) && capture(headers.ETag) RhoPathTree.ValidationTools.ensureValidHeaders(c3.rule, req) should_== ParserSuccess(etag::lenheader::HNil) }).reduce( _ and _) } "Map header params" in { val req = Request().withHeaders(Headers(etag, lenheader)) - val c = requireMap(Header.`Content-Length`)(_.length) + val c = requireMap(headers.`Content-Length`)(_.length) RhoPathTree.ValidationTools.ensureValidHeaders(c.rule, req) should_== ParserSuccess(4::HNil) } "Append headers to a Route" in { val path = POST / "hello" / 'world +? param[Int]("fav") - val validations = requireThat(Header.`Content-Length`){ h => h.length != 0 } + val validations = requireThat(headers.`Content-Length`){ h => h.length != 0 } - val route = (path >>> validations >>> capture(Header.ETag)).decoding(EntityDecoder.text) runWith - {(world: String, fav: Int, tag: Header.ETag, body: String) => + val route = (path >>> validations >>> capture(headers.ETag)).decoding(EntityDecoder.text) runWith + {(world: String, fav: Int, tag: headers.ETag, body: String) => Ok(s"Hello to you too, $world. Your Fav number is $fav. You sent me $body") - .putHeaders(Header.ETag("foo")) + .putHeaders(headers.ETag("foo")) } val body = Process.emit(ByteVector("cool".getBytes)) val req = Request(POST, uri = Uri.fromString("/hello/neptune?fav=23").getOrElse(sys.error("Fail"))) - .putHeaders(Header.ETag("foo")) + .putHeaders(headers.ETag("foo")) .withBody("cool") .run val resp = route(req).run.get - resp.headers.get(Header.ETag).get.value should_== "foo" + resp.headers.get(headers.ETag).get.value should_== "foo" } @@ -95,7 +97,7 @@ class ApiTest extends Specification { val p1 = "one" / 'two val p2 = "three" / 'four - val f = GET / (p1 || p2) runWith { (s: String) => Ok("").withHeaders(Header.ETag(s)) } + val f = GET / (p1 || p2) runWith { (s: String) => Ok("").withHeaders(headers.ETag(s)) } val req1 = Request(uri = Uri.fromString("/one/two").getOrElse(sys.error("Failed."))) fetchETag(f(req1)) should_== "two" @@ -107,23 +109,23 @@ class ApiTest extends Specification { "Execute a complicated route" in { val path = POST / "hello" / 'world +? param[Int]("fav") - val validations = requireThat(Header.`Content-Length`){ h => h.length != 0 } && - capture(Header.ETag) + val validations = requireThat(headers.`Content-Length`){ h => h.length != 0 } && + capture(headers.ETag) val route = - (path >>> validations).decoding(EntityDecoder.text) runWith {(world: String, fav: Int, tag: Header.ETag, body: String) => + (path >>> validations).decoding(EntityDecoder.text) runWith {(world: String, fav: Int, tag: headers.ETag, body: String) => Ok(s"Hello to you too, $world. Your Fav number is $fav. You sent me $body") - .putHeaders(Header.ETag("foo")) + .putHeaders(headers.ETag("foo")) } val req = Request(POST, uri = Uri.fromString("/hello/neptune?fav=23").getOrElse(sys.error("Fail"))) - .putHeaders( Header.ETag("foo")) + .putHeaders( headers.ETag("foo")) .withBody("cool") .run val resp = route(req).run.get - resp.headers.get(Header.ETag).get.value should_== "foo" + resp.headers.get(headers.ETag).get.value should_== "foo" } "Deal with 'no entity' responses" in { @@ -152,14 +154,14 @@ class ApiTest extends Specification { "PathValidator" should { def check(p: Task[Option[Response]], s: String) = { - p.run.get.headers.get(Header.ETag).get.value should_== s + p.run.get.headers.get(headers.ETag).get.value should_== s } "traverse a captureless path" in { val stuff = GET / "hello" val req = Request(uri = Uri.fromString("/hello").getOrElse(sys.error("Failed."))) - val f = stuff runWith { () => Ok("Cool.").withHeaders(Header.ETag("foo")) } + val f = stuff runWith { () => Ok("Cool.").withHeaders(headers.ETag("foo")) } check(f(req), "foo") } @@ -167,7 +169,7 @@ class ApiTest extends Specification { val stuff = GET / "hello" val req = Request(uri = Uri.fromString("/hello/world").getOrElse(sys.error("Failed."))) - val f = stuff runWith { () => Ok("Cool.").withHeaders(Header.ETag("foo")) } + val f = stuff runWith { () => Ok("Cool.").withHeaders(headers.ETag("foo")) } val r = f(req).run r should_== None } @@ -176,7 +178,7 @@ class ApiTest extends Specification { val stuff = GET / 'hello val req = Request(uri = Uri.fromString("/hello").getOrElse(sys.error("Failed."))) - val f = stuff runWith { str: String => Ok("Cool.").withHeaders(Header.ETag(str)) } + val f = stuff runWith { str: String => Ok("Cool.").withHeaders(headers.ETag(str)) } check(f(req), "hello") } @@ -184,7 +186,7 @@ class ApiTest extends Specification { val stuff = GET / "hello" val req = Request(uri = Uri.fromString("/hello").getOrElse(sys.error("Failed."))) - val f = stuff runWith { () => Ok("Cool.").withHeaders(Header.ETag("foo")) } + val f = stuff runWith { () => Ok("Cool.").withHeaders(headers.ETag("foo")) } check(f(req), "foo") } @@ -192,7 +194,7 @@ class ApiTest extends Specification { "capture end with nothing" in { val stuff = GET / "hello" / * val req = Request(uri = Uri.fromString("/hello").getOrElse(sys.error("Failed."))) - val f = stuff runWith { path: List[String] => Ok("Cool.").withHeaders(Header.ETag(if (path.isEmpty) "go" else "nogo")) } + val f = stuff runWith { path: List[String] => Ok("Cool.").withHeaders(headers.ETag(if (path.isEmpty) "go" else "nogo")) } check(f(req), "go") } @@ -200,7 +202,7 @@ class ApiTest extends Specification { "capture remaining" in { val stuff = GET / "hello" / * val req = Request(uri = Uri.fromString("/hello/world/foo").getOrElse(sys.error("Failed."))) - val f = stuff runWith { path: List[String] => Ok("Cool.").withHeaders(Header.ETag(path.mkString)) } + val f = stuff runWith { path: List[String] => Ok("Cool.").withHeaders(headers.ETag(path.mkString)) } check(f(req), "worldfoo") } @@ -211,7 +213,7 @@ class ApiTest extends Specification { val path = GET / "hello" +? param[Int]("jimbo") val req = Request(uri = Uri.fromString("/hello?jimbo=32").getOrElse(sys.error("Failed."))) - val route = path runWith { i: Int => Ok("stuff").withHeaders(Header.ETag((i + 1).toString)) } + val route = path runWith { i: Int => Ok("stuff").withHeaders(headers.ETag((i + 1).toString)) } fetchETag(route(req)) should_== "33" @@ -220,7 +222,7 @@ class ApiTest extends Specification { "Decoders" should { "Decode a body" in { - val reqHeader = requireThat(Header.`Content-Length`){ h => h.length < 10 } + val reqHeader = requireThat(headers.`Content-Length`){ h => h.length < 10 } val path = POST / "hello" >>> reqHeader @@ -234,7 +236,7 @@ class ApiTest extends Specification { .run val route = path.decoding(EntityDecoder.text) runWith { str: String => - Ok("stuff").withHeaders(Header.ETag(str)) + Ok("stuff").withHeaders(headers.ETag(str)) } fetchETag(route(req1)) should_== "foo" @@ -249,7 +251,7 @@ class ApiTest extends Specification { .run val route = path ^ EntityDecoder.text runWith { str: String => - Ok("stuff").withHeaders(Header.ETag(str)) + Ok("stuff").withHeaders(headers.ETag(str)) } fetchETag(route(req)) should_== "foo" @@ -257,17 +259,42 @@ class ApiTest extends Specification { "Fail on a header" in { val path = GET / "hello" - val reqHeader = requireThat(Header.`Content-Length`){ h => h.length < 2} - val body = Process.emit(ByteVector.apply("foo".getBytes())) - val req = Request(uri = Uri.fromString("/hello").getOrElse(sys.error("Failed.")), body = body) - .withHeaders(Headers(Header.`Content-Length`("foo".length))) - val route = path.validate(reqHeader).decoding(EntityDecoder.text) runWith { str: String => - Ok("stuff").withHeaders(Header.ETag(str)) + val req = Request(uri = uri("/hello")) + .withHeaders(Headers(headers.`Content-Length`("foo".length))) + + val reqHeader = requireThat(headers.`Content-Length`){ h => h.length < 2} + val route1 = path.validate(reqHeader) runWith { () => + Ok("shouldn't get here.") + } + + route1(req).run.get.status should_== Status.BadRequest + + val reqHeaderR = requireThatR(headers.`Content-Length`){ h => Some(Unauthorized("Foo."))} + val route2 = path.validate(reqHeaderR) runWith { () => + Ok("shouldn't get here.") + } + + route2(req).run.get.status should_== Status.Unauthorized + } + + "Fail on a query" in { + val path = GET / "hello" + + val req = Request(uri = uri("/hello?foo=bar")) + .withHeaders(Headers(headers.`Content-Length`("foo".length))) + + val route1 = (path +? param[Int]("foo")).runWith { i: Int => + Ok("shouldn't get here.") + } + + route1(req).run.get.status should_== Status.BadRequest + + val route2 = (path +? paramR[String]("foo", (_: String) => Some(Unauthorized("foo")))).runWith { str: String => + Ok("shouldn't get here.") } - val result = route(req) - result.run.get.status should_== Status.BadRequest + route2(req).run.get.status should_== Status.Unauthorized } } } diff --git a/core/src/test/scala/org/http4s/rho/ParamDefaultValueSpec.scala b/core/src/test/scala/org/http4s/rho/ParamDefaultValueSpec.scala index 96228aef..033b664c 100644 --- a/core/src/test/scala/org/http4s/rho/ParamDefaultValueSpec.scala +++ b/core/src/test/scala/org/http4s/rho/ParamDefaultValueSpec.scala @@ -205,10 +205,10 @@ class ParamDefaultValueSpec extends Specification { body(requestGet("/test9")) must be equalTo default } "fail to map parameter with empty value" in { - body(requestGet("/test9?param1=")) must be equalTo "Invalid parameter: \"\"" + body(requestGet("/test9?param1=")) must be equalTo "Invalid query parameter: \"\"" } "fail to map parameter with invalid value" in { - body(requestGet("/test9?param1=fail")) must be equalTo "Invalid parameter: \"fail\"" + body(requestGet("/test9?param1=fail")) must be equalTo "Invalid query parameter: \"fail\"" } "map parameter with valid value" in { body(requestGet("/test9?param1=pass")) must be equalTo "test9:pass" @@ -227,7 +227,7 @@ class ParamDefaultValueSpec extends Specification { body(requestGet("/test10?param1=")) must be equalTo "Invalid Number Format: \"\"" } "fail to map parameter with invalid numeric value" in { - body(requestGet("/test10?param1=-4")) must be equalTo "Invalid parameter: \"-4\"" + body(requestGet("/test10?param1=-4")) must be equalTo "Invalid query parameter: \"-4\"" } "fail to map parameter with non-numeric value" in { body(requestGet("/test10?param1=value1")) must be equalTo "Invalid Number Format: \"value1\"" @@ -249,7 +249,7 @@ class ParamDefaultValueSpec extends Specification { body(requestGet("/test11?param1=")) must be equalTo "Invalid Number Format: \"\"" } "fail to map parameter with invalid numeric value" in { - body(requestGet("/test11?param1=0")) must be equalTo "Invalid parameter: \"Some(0)\"" + body(requestGet("/test11?param1=0")) must be equalTo "Invalid query parameter: \"Some(0)\"" } "fail to map parameter with non-numeric value" in { body(requestGet("/test11?param1=value1")) must be equalTo "Invalid Number Format: \"value1\"" @@ -268,10 +268,10 @@ class ParamDefaultValueSpec extends Specification { body(requestGet("/test12")) must be equalTo default } "fail to map parameter with empty value" in { - body(requestGet("/test12?param1=")) must be equalTo "Invalid parameter: \"Some()\"" + body(requestGet("/test12?param1=")) must be equalTo "Invalid query parameter: \"Some()\"" } "fail to map parameter with invalid value" in { - body(requestGet("/test12?param1=fail")) must be equalTo "Invalid parameter: \"Some(fail)\"" + body(requestGet("/test12?param1=fail")) must be equalTo "Invalid query parameter: \"Some(fail)\"" } "map parameter with valid value" in { body(requestGet("/test12?param1=pass")) must be equalTo "test12:pass" @@ -287,13 +287,13 @@ class ParamDefaultValueSpec extends Specification { body(requestGet("/test13")) must be equalTo default } "fail to map parameter with empty value" in { - body(requestGet("/test13?param1=")) must be equalTo "Invalid parameter: \"List()\"" + body(requestGet("/test13?param1=")) must be equalTo "Invalid query parameter: \"List()\"" } "fail to map parameter with one invalid value" in { - body(requestGet("/test13?param1=z")) must be equalTo "Invalid parameter: \"List(z)\"" + body(requestGet("/test13?param1=z")) must be equalTo "Invalid query parameter: \"List(z)\"" } "map parameter with many values and one invalid" in { - body(requestGet("/test13?param1=z¶m1=aa¶m1=bb")) must be equalTo "Invalid parameter: \"List(z, aa, bb)\"" + body(requestGet("/test13?param1=z¶m1=aa¶m1=bb")) must be equalTo "Invalid query parameter: \"List(z, aa, bb)\"" } "map parameter with many valid values" in { body(requestGet("/test13?param1=c¶m1=d")) must be equalTo "test13:c,d" @@ -312,7 +312,7 @@ class ParamDefaultValueSpec extends Specification { body(requestGet("/test14?param1=")) must be equalTo "Invalid Number Format: \"\"" } "fail to map parameter with one invalid numeric value" in { - body(requestGet("/test14?param1=8¶m1=5¶m1=3")) must be equalTo "Invalid parameter: \"List(8, 5, 3)\"" + body(requestGet("/test14?param1=8¶m1=5¶m1=3")) must be equalTo "Invalid query parameter: \"List(8, 5, 3)\"" } "fail to map parameter with one non-numeric value" in { body(requestGet("/test14?param1=test")) must be equalTo "Invalid Number Format: \"test\"" diff --git a/core/src/test/scala/org/http4s/rho/bits/ParserResultSpec.scala b/core/src/test/scala/org/http4s/rho/bits/ParserResultSpec.scala index 973e4953..d1895abc 100644 --- a/core/src/test/scala/org/http4s/rho/bits/ParserResultSpec.scala +++ b/core/src/test/scala/org/http4s/rho/bits/ParserResultSpec.scala @@ -1,5 +1,8 @@ -package org.http4s.rho.bits +package org.http4s.rho +package bits + +import org.http4s.rho.bits.ResponseGeneratorInstances.BadRequest import org.specs2.mutable.Specification @@ -12,7 +15,7 @@ class ParserResultSpec extends Specification { } "map a ValidationFailure" in { - val result: ParserResult[Int] = ValidationFailure("foo") + val result: ParserResult[Int] = ValidationFailure(BadRequest("foo")) result.map(_.toString) should_== result } @@ -26,7 +29,7 @@ class ParserResultSpec extends Specification { } "flatMap a ValidationFailure" in { - val result: ParserResult[Int] = ValidationFailure("foo") + val result: ParserResult[Int] = ValidationFailure(BadRequest("foo")) result.flatMap(i => ParserSuccess(i.toString)) should_== result } diff --git a/examples/src/main/scala/com/http4s/rho/swagger/demo/MyService.scala b/examples/src/main/scala/com/http4s/rho/swagger/demo/MyService.scala index 0eabd5bc..98afd589 100644 --- a/examples/src/main/scala/com/http4s/rho/swagger/demo/MyService.scala +++ b/examples/src/main/scala/com/http4s/rho/swagger/demo/MyService.scala @@ -16,8 +16,19 @@ object MyService extends RhoService with SwaggerSupport { import org.http4s.rho._ import org.http4s.rho.swagger._ + import org.http4s.headers + import org.http4s.{Request, Headers, DateTime} + case class JsonResult(name: String, number: Int) extends AutoSerializable + val requireCookie = requireThatR(headers.Cookie){ cookie => + cookie.values.toList.find(c => c.name == "Foo" && c.content == "bar") match { + case Some(_) => None // Cookie found, good to go. + case None => // Didn't find cookie + Some(TemporaryRedirect(uri("/addcookie"))) + } + } + "We don't want to have a real 'root' route anyway... " ** GET |>> TemporaryRedirect(Uri(path="/swagger-ui")) @@ -45,4 +56,28 @@ object MyService extends RhoService with SwaggerSupport { val i = new AtomicInteger(0) Task(

{ s"The number is ${i.getAndIncrement()}" }

) } + + "Adds the cookie Foo=bar to the client" ** + GET / "addcookie" |>> { + Ok("You now have a good cookie!").addCookie("Foo", "bar") + } + + "Sets the cookie Foo=barr to the client" ** + GET / "addbadcookie" |>> { + Ok("You now have an evil cookie!").addCookie("Foo", "barr") + } + + "Checks the Foo cookie to make sure its 'bar'" ** + GET / "checkcookie" >>> requireCookie |>> Ok("Good job, you have the cookie!") + + "Clears the cookies" ** + GET / "clearcookies" |>> { req: Request => + val hs = req.headers.get(headers.Cookie) match { + case None => Headers.empty + case Some(cookie) => + Headers(cookie.values.toList.map{ c => headers.`Set-Cookie`(c.copy(expires = Some(DateTime.UnixEpoch), maxAge = Some(0))) }) + } + + Ok("Deleted cookies!").withHeaders(hs) + } }