diff --git a/core/src/main/scala/org/scalatra/RenderPipeline.scala b/core/src/main/scala/org/scalatra/RenderPipeline.scala new file mode 100644 index 000000000..bd94194bf --- /dev/null +++ b/core/src/main/scala/org/scalatra/RenderPipeline.scala @@ -0,0 +1,76 @@ +package org.scalatra + +import collection.mutable +import java.io.{FileInputStream, File} +import util.using +import util.io.copy + +// Perhaps making renderResponseBody a stackable method this would also give a render pipeline maybe even a better one at that +//trait RenderResponseBody { +// def renderResponseBody(actionResult: Any) +//} + +/** + * Allows overriding and chaining of response body rendering. Overrides [[ScalatraKernel#renderResponseBody]]. + */ +trait RenderPipeline {this: ScalatraKernel => + + object ActionRenderer{ + def apply[A: ClassManifest](fun: A => Any) = new ActionRenderer(fun) + } + private[scalatra] class ActionRenderer[A: ClassManifest](fun: A => Any) extends PartialFunction[Any, Any] { + def apply(v1: Any) = fun(v1.asInstanceOf[A]) + def isDefinedAt(x: Any) = implicitly[ClassManifest[A]].erasure.isInstance(x) + } + + private type RenderAction = PartialFunction[Any, Any] + protected val renderPipeline = new mutable.ArrayBuffer[RenderAction] with mutable.SynchronizedBuffer[RenderAction] + + override def renderResponseBody(actionResult: Any) { + (useRenderPipeline orElse defaultRenderResponse) apply actionResult + } + + private def useRenderPipeline: PartialFunction[Any, Any] = { + case pipelined if renderPipeline.exists(_.isDefinedAt(pipelined)) => { + (pipelined /: renderPipeline) { + case (body, renderer) if (renderer.isDefinedAt(body)) => renderer(body) + case (body, _) => body + } + } + } + + private def defaultRenderResponse: PartialFunction[Any, Any] = { + case bytes: Array[Byte] => + response.getOutputStream.write(bytes) + case file: File => + using(new FileInputStream(file)) { in => copy(in, response.getOutputStream) } + case _: Unit => + // If an action returns Unit, it assumes responsibility for the response + case x: Any => + response.getWriter.print(x.toString) + } + + /** + * Prepend a new renderer to the front of the render pipeline. + */ + def render[A: Manifest](fun: A => Any) { + ActionRenderer(fun) +=: renderPipeline + } + + +} + +trait DefaultRendererPipeline { self: ScalatraKernel with RenderPipeline => + render[Any] { + case _: Unit => // If an action or renderer returns Unit, it assumes responsibility for the response + case x => response.getWriter.print(x.toString) + } + + render[File] {file => + using(new FileInputStream(file)) {in => copy(in, response.getOutputStream)} + } + + render[Array[Byte]] {bytes => + response.getOutputStream.write(bytes) + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/scalatra/ScalatraKernel.scala b/core/src/main/scala/org/scalatra/ScalatraKernel.scala index a118040e8..56969c73b 100644 --- a/core/src/main/scala/org/scalatra/ScalatraKernel.scala +++ b/core/src/main/scala/org/scalatra/ScalatraKernel.scala @@ -5,7 +5,7 @@ import javax.servlet.http._ import scala.util.DynamicVariable import scala.util.matching.Regex import scala.collection.JavaConversions._ -import scala.collection.mutable.{ConcurrentMap, HashMap, ListBuffer} +import scala.collection.mutable.{ConcurrentMap, ListBuffer} import scala.xml.NodeSeq import util.io.copy import java.io.{File, FileInputStream} @@ -38,23 +38,27 @@ import ScalatraKernel._ * methods register a new action to a route for a given HTTP method, possibly * overwriting a previous one. This trait is thread safe. */ -trait ScalatraKernel extends Handler with Initializable +trait ScalatraKernel extends Handler with Initializable //with RenderResponseBody { + protected implicit def map2multimap(map: Map[String, Seq[String]]) = new MultiMap(map) + protected val Routes: ConcurrentMap[String, List[Route]] = { val map = new ConcurrentHashMap[String, List[Route]] httpMethods foreach { x: String => map += ((x, List[Route]())) } map } - def contentType = response.getContentType - def contentType_=(value: String): Unit = response.setContentType(value) + def contentType_=(value: String) { + response.setContentType(value) + } protected val defaultCharacterEncoding = "UTF-8" protected val _response = new DynamicVariable[HttpServletResponse](null) - protected val _request = new DynamicVariable[HttpServletRequest](null) + protected val _request = new DynamicVariable[HttpServletRequest](null) protected implicit def requestWrapper(r: HttpServletRequest) = RichRequest(r) protected implicit def sessionWrapper(s: HttpSession) = new RichSession(s) + protected implicit def servletContextWrapper(sc: ServletContext) = new RichServletContext(sc) protected[scalatra] class Route(val routeMatchers: Iterable[RouteMatcher], val action: Action) { @@ -73,7 +77,6 @@ trait ScalatraKernel extends Handler with Initializable override def toString = routeMatchers.toString() } - protected implicit def map2multimap(map: Map[String, Seq[String]]) = new MultiMap(map) /** * Pluggable way to convert Strings into RouteMatchers. By default, we * interpret them the same way Sinatra does. @@ -91,7 +94,7 @@ trait ScalatraKernel extends Handler with Initializable // By overriding toString, we can list the available routes in the // default notFound handler. - override def toString = pattern.regex.toString + override def toString() = pattern.regex.toString() } protected implicit def regex2RouteMatcher(regex: Regex): RouteMatcher = new RouteMatcher { @@ -100,7 +103,7 @@ trait ScalatraKernel extends Handler with Initializable case xs => Map("captures" -> xs) }} - override def toString = regex.toString + override def toString() = regex.toString() } protected implicit def booleanBlock2RouteMatcher(matcher: => Boolean): RouteMatcher = @@ -122,23 +125,29 @@ trait ScalatraKernel extends Handler with Initializable _multiParams.withValue(Map() ++ realMultiParams) { val result = try { beforeFilters foreach { _() } - Routes(effectiveMethod).toStream.flatMap { _(requestPath) }.headOption.getOrElse(doNotFound()) + val res = Routes(effectiveMethod).toStream.flatMap { _(requestPath) }.headOption.getOrElse(doNotFound()) + renderResponse(res) } catch { - case HaltException(Some(code), Some(msg)) => response.sendError(code, msg) - case HaltException(Some(code), None) => response.sendError(code) - case HaltException(None, _) => - case e => handleError(e) + case e => renderResponse((renderError orElse internalRenderError).apply(e)) } finally { afterFilters foreach { _() } } - renderResponse(result) } } } } + type ErrorRenderer = PartialFunction[Throwable, Any] + def renderError: ErrorRenderer = internalRenderError + private def internalRenderError: ErrorRenderer = { + case HaltException(Some(code), Some(msg)) => response.sendError(code, msg) + case HaltException(Some(code), None) => response.sendError(code) + case HaltException(None, _) => + case e => handleError(e) + } + protected def effectiveMethod = request.getMethod.toUpperCase match { case "HEAD" => "GET" case x => x @@ -153,7 +162,11 @@ trait ScalatraKernel extends Handler with Initializable def after(fun: => Any) = afterFilters += { () => fun } protected var doNotFound: Action - def notFound(fun: => Any) = doNotFound = { () => fun } + def notFound(fun: => Any) { + doNotFound = { + () => fun + } + } protected def handleError(e: Throwable): Any = { status(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) @@ -161,7 +174,11 @@ trait ScalatraKernel extends Handler with Initializable } protected var errorHandler: Action = { () => throw caughtThrowable } - def error(fun: => Any) = errorHandler = { () => fun } + def error(fun: => Any) { + errorHandler = { + () => fun + } + } private val _caughtThrowable = new DynamicVariable[Throwable](null) protected def caughtThrowable = _caughtThrowable.value @@ -202,7 +219,9 @@ trait ScalatraKernel extends Handler with Initializable } def params = _params - def redirect(uri: String) = (_response value) sendRedirect uri + def redirect(uri: String) { + (_response value) sendRedirect uri + } implicit def request = _request value implicit def response = _response value def session = request.getSession @@ -210,12 +229,14 @@ trait ScalatraKernel extends Handler with Initializable case s: HttpSession => Some(s) case null => None } - def status(code: Int) = (_response value) setStatus code + def status(code: Int) { + (_response value) setStatus code + } def halt(code: Int, msg: String) = throw new HaltException(Some(code), Some(msg)) def halt(code: Int) = throw new HaltException(Some(code), None) def halt() = throw new HaltException(None, None) - private case class HaltException(val code: Option[Int], val msg: Option[String]) extends RuntimeException + protected[scalatra] case class HaltException(code: Option[Int], msg: Option[String]) extends RuntimeException def pass() = throw new PassException protected[scalatra] class PassException extends RuntimeException @@ -294,7 +315,7 @@ trait ScalatraKernel extends Handler with Initializable * * @see addRoute */ - protected def removeRoute(verb: String, route: Route): Unit = { + protected def removeRoute(verb: String, route: Route) { modifyRoutes(verb, _ filterNot (_ == route) ) route } @@ -303,7 +324,7 @@ trait ScalatraKernel extends Handler with Initializable * since routes is a ConcurrentMap and we avoid locking, we need to retry if there are * concurrent modifications, this is abstracted here for removeRoute and addRoute */ - @tailrec private def modifyRoutes(protocol: String, f: (List[Route] => List[Route])): Unit = { + @tailrec private def modifyRoutes(protocol: String, f: (List[Route] => List[Route])) { val oldRoutes = Routes(protocol) if (!Routes.replace(protocol, oldRoutes, f(oldRoutes))) { modifyRoutes(protocol,f) @@ -311,7 +332,9 @@ trait ScalatraKernel extends Handler with Initializable } private var config: Config = _ - def initialize(config: Config) = this.config = config + def initialize(config: Config) { + this.config = config + } def initParameter(name: String): Option[String] = config match { case config: ServletConfig => Option(config.getInitParameter(name)) diff --git a/core/src/main/scala/org/scalatra/ScalatraServlet.scala b/core/src/main/scala/org/scalatra/ScalatraServlet.scala index 2ee7a3a19..1f003cfab 100644 --- a/core/src/main/scala/org/scalatra/ScalatraServlet.scala +++ b/core/src/main/scala/org/scalatra/ScalatraServlet.scala @@ -45,3 +45,5 @@ abstract class ScalatraServlet override def initialize(config: ServletConfig): Unit = super.initialize(config) } + +abstract class ScalatraPipelinedServlet extends ScalatraServlet with RenderPipeline with DefaultRendererPipeline diff --git a/core/src/main/scala/org/scalatra/util/package.scala b/core/src/main/scala/org/scalatra/util/package.scala index ed992d214..708252e5f 100644 --- a/core/src/main/scala/org/scalatra/util/package.scala +++ b/core/src/main/scala/org/scalatra/util/package.scala @@ -9,7 +9,7 @@ package object util { * @param closeable the closeable resource * @param f the block */ - def using[A, B <: { def close(): Unit }](closeable: B)(f: B => A) { + def using[A, B <: { def close() }](closeable: B)(f: B => A) { try { f(closeable) } diff --git a/core/src/test/scala/org/scalatra/RenderPipelineTest.scala b/core/src/test/scala/org/scalatra/RenderPipelineTest.scala new file mode 100644 index 000000000..289b2bab2 --- /dev/null +++ b/core/src/test/scala/org/scalatra/RenderPipelineTest.scala @@ -0,0 +1,48 @@ +package org.scalatra + +import org.scalatest.matchers.ShouldMatchers +import test.scalatest.ScalatraFunSuite + + +class RenderPipelineTestServlet extends ScalatraPipelinedServlet { + + render[String] { + case s @ "the string to render" => response.getWriter print ("Rendering string: %s" format s) + case s => "Augmenting string: " + s + } + + get("/any") { + 11111 + } + + get("/string") { + "the string to render" + } + + get("/augment") { + "yet another string" + } +} + +class RenderPipelineTest extends ScalatraFunSuite with ShouldMatchers { + + addServlet(new RenderPipelineTestServlet, "/*") + + test("should still render defaults") { + get("/any") { + body should equal("11111") + } + } + + test("should render the string") { + get("/string") { + body should equal("Rendering string: the string to render") + } + } + + test("should augment a string") { + get("/augment") { + body should equal("Augmenting string: yet another string") + } + } +} \ No newline at end of file diff --git a/example/src/main/scala/org/scalatra/RenderPipelineExample.scala b/example/src/main/scala/org/scalatra/RenderPipelineExample.scala new file mode 100644 index 000000000..3612fb0e1 --- /dev/null +++ b/example/src/main/scala/org/scalatra/RenderPipelineExample.scala @@ -0,0 +1,22 @@ +package org.scalatra + + + +class RenderPipelineExample extends ScalatraPipelinedServlet { + + render[String] { + case s => "rendering: " + s + } + + render[List[String]] { + case l => "Rendering list:" + l.mkString("
\n", "
\n", "
\n") + } + + get("/?") { + "hello I'm rendering" + } + + get("/list") { + "first" :: "second" :: "third" :: "fourth" :: Nil + } +} \ No newline at end of file diff --git a/example/src/main/webapp/WEB-INF/web.xml b/example/src/main/webapp/WEB-INF/web.xml index f21b8ba66..75e4ebf18 100644 --- a/example/src/main/webapp/WEB-INF/web.xml +++ b/example/src/main/webapp/WEB-INF/web.xml @@ -8,6 +8,10 @@ PUBLIC "-//Sun Microsystems, Inc.//DTD Web Application 2.2//EN" TemplateExample org.scalatra.TemplateExample + + RenderPipelineExample + org.scalatra.RenderPipelineExample + BasicAuthExample org.scalatra.BasicAuthExample @@ -57,6 +61,10 @@ PUBLIC "-//Sun Microsystems, Inc.//DTD Web Application 2.2//EN" + + RenderPipelineExample + /pipelined/* + ChatApplication /socket.io/*