diff --git a/core/src/main/scala/org/scalatra/RenderPipeline.scala b/core/src/main/scala/org/scalatra/RenderPipeline.scala
new file mode 100644
index 000000000..bd94194bf
--- /dev/null
+++ b/core/src/main/scala/org/scalatra/RenderPipeline.scala
@@ -0,0 +1,76 @@
+package org.scalatra
+
+import collection.mutable
+import java.io.{FileInputStream, File}
+import util.using
+import util.io.copy
+
+// Perhaps making renderResponseBody a stackable method this would also give a render pipeline maybe even a better one at that
+//trait RenderResponseBody {
+// def renderResponseBody(actionResult: Any)
+//}
+
+/**
+ * Allows overriding and chaining of response body rendering. Overrides [[ScalatraKernel#renderResponseBody]].
+ */
+trait RenderPipeline {this: ScalatraKernel =>
+
+ object ActionRenderer{
+ def apply[A: ClassManifest](fun: A => Any) = new ActionRenderer(fun)
+ }
+ private[scalatra] class ActionRenderer[A: ClassManifest](fun: A => Any) extends PartialFunction[Any, Any] {
+ def apply(v1: Any) = fun(v1.asInstanceOf[A])
+ def isDefinedAt(x: Any) = implicitly[ClassManifest[A]].erasure.isInstance(x)
+ }
+
+ private type RenderAction = PartialFunction[Any, Any]
+ protected val renderPipeline = new mutable.ArrayBuffer[RenderAction] with mutable.SynchronizedBuffer[RenderAction]
+
+ override def renderResponseBody(actionResult: Any) {
+ (useRenderPipeline orElse defaultRenderResponse) apply actionResult
+ }
+
+ private def useRenderPipeline: PartialFunction[Any, Any] = {
+ case pipelined if renderPipeline.exists(_.isDefinedAt(pipelined)) => {
+ (pipelined /: renderPipeline) {
+ case (body, renderer) if (renderer.isDefinedAt(body)) => renderer(body)
+ case (body, _) => body
+ }
+ }
+ }
+
+ private def defaultRenderResponse: PartialFunction[Any, Any] = {
+ case bytes: Array[Byte] =>
+ response.getOutputStream.write(bytes)
+ case file: File =>
+ using(new FileInputStream(file)) { in => copy(in, response.getOutputStream) }
+ case _: Unit =>
+ // If an action returns Unit, it assumes responsibility for the response
+ case x: Any =>
+ response.getWriter.print(x.toString)
+ }
+
+ /**
+ * Prepend a new renderer to the front of the render pipeline.
+ */
+ def render[A: Manifest](fun: A => Any) {
+ ActionRenderer(fun) +=: renderPipeline
+ }
+
+
+}
+
+trait DefaultRendererPipeline { self: ScalatraKernel with RenderPipeline =>
+ render[Any] {
+ case _: Unit => // If an action or renderer returns Unit, it assumes responsibility for the response
+ case x => response.getWriter.print(x.toString)
+ }
+
+ render[File] {file =>
+ using(new FileInputStream(file)) {in => copy(in, response.getOutputStream)}
+ }
+
+ render[Array[Byte]] {bytes =>
+ response.getOutputStream.write(bytes)
+ }
+}
\ No newline at end of file
diff --git a/core/src/main/scala/org/scalatra/ScalatraKernel.scala b/core/src/main/scala/org/scalatra/ScalatraKernel.scala
index a118040e8..56969c73b 100644
--- a/core/src/main/scala/org/scalatra/ScalatraKernel.scala
+++ b/core/src/main/scala/org/scalatra/ScalatraKernel.scala
@@ -5,7 +5,7 @@ import javax.servlet.http._
import scala.util.DynamicVariable
import scala.util.matching.Regex
import scala.collection.JavaConversions._
-import scala.collection.mutable.{ConcurrentMap, HashMap, ListBuffer}
+import scala.collection.mutable.{ConcurrentMap, ListBuffer}
import scala.xml.NodeSeq
import util.io.copy
import java.io.{File, FileInputStream}
@@ -38,23 +38,27 @@ import ScalatraKernel._
* methods register a new action to a route for a given HTTP method, possibly
* overwriting a previous one. This trait is thread safe.
*/
-trait ScalatraKernel extends Handler with Initializable
+trait ScalatraKernel extends Handler with Initializable //with RenderResponseBody
{
+ protected implicit def map2multimap(map: Map[String, Seq[String]]) = new MultiMap(map)
+
protected val Routes: ConcurrentMap[String, List[Route]] = {
val map = new ConcurrentHashMap[String, List[Route]]
httpMethods foreach { x: String => map += ((x, List[Route]())) }
map
}
-
def contentType = response.getContentType
- def contentType_=(value: String): Unit = response.setContentType(value)
+ def contentType_=(value: String) {
+ response.setContentType(value)
+ }
protected val defaultCharacterEncoding = "UTF-8"
protected val _response = new DynamicVariable[HttpServletResponse](null)
- protected val _request = new DynamicVariable[HttpServletRequest](null)
+ protected val _request = new DynamicVariable[HttpServletRequest](null)
protected implicit def requestWrapper(r: HttpServletRequest) = RichRequest(r)
protected implicit def sessionWrapper(s: HttpSession) = new RichSession(s)
+
protected implicit def servletContextWrapper(sc: ServletContext) = new RichServletContext(sc)
protected[scalatra] class Route(val routeMatchers: Iterable[RouteMatcher], val action: Action) {
@@ -73,7 +77,6 @@ trait ScalatraKernel extends Handler with Initializable
override def toString = routeMatchers.toString()
}
- protected implicit def map2multimap(map: Map[String, Seq[String]]) = new MultiMap(map)
/**
* Pluggable way to convert Strings into RouteMatchers. By default, we
* interpret them the same way Sinatra does.
@@ -91,7 +94,7 @@ trait ScalatraKernel extends Handler with Initializable
// By overriding toString, we can list the available routes in the
// default notFound handler.
- override def toString = pattern.regex.toString
+ override def toString() = pattern.regex.toString()
}
protected implicit def regex2RouteMatcher(regex: Regex): RouteMatcher = new RouteMatcher {
@@ -100,7 +103,7 @@ trait ScalatraKernel extends Handler with Initializable
case xs => Map("captures" -> xs)
}}
- override def toString = regex.toString
+ override def toString() = regex.toString()
}
protected implicit def booleanBlock2RouteMatcher(matcher: => Boolean): RouteMatcher =
@@ -122,23 +125,29 @@ trait ScalatraKernel extends Handler with Initializable
_multiParams.withValue(Map() ++ realMultiParams) {
val result = try {
beforeFilters foreach { _() }
- Routes(effectiveMethod).toStream.flatMap { _(requestPath) }.headOption.getOrElse(doNotFound())
+ val res = Routes(effectiveMethod).toStream.flatMap { _(requestPath) }.headOption.getOrElse(doNotFound())
+ renderResponse(res)
}
catch {
- case HaltException(Some(code), Some(msg)) => response.sendError(code, msg)
- case HaltException(Some(code), None) => response.sendError(code)
- case HaltException(None, _) =>
- case e => handleError(e)
+ case e => renderResponse((renderError orElse internalRenderError).apply(e))
}
finally {
afterFilters foreach { _() }
}
- renderResponse(result)
}
}
}
}
+ type ErrorRenderer = PartialFunction[Throwable, Any]
+ def renderError: ErrorRenderer = internalRenderError
+ private def internalRenderError: ErrorRenderer = {
+ case HaltException(Some(code), Some(msg)) => response.sendError(code, msg)
+ case HaltException(Some(code), None) => response.sendError(code)
+ case HaltException(None, _) =>
+ case e => handleError(e)
+ }
+
protected def effectiveMethod = request.getMethod.toUpperCase match {
case "HEAD" => "GET"
case x => x
@@ -153,7 +162,11 @@ trait ScalatraKernel extends Handler with Initializable
def after(fun: => Any) = afterFilters += { () => fun }
protected var doNotFound: Action
- def notFound(fun: => Any) = doNotFound = { () => fun }
+ def notFound(fun: => Any) {
+ doNotFound = {
+ () => fun
+ }
+ }
protected def handleError(e: Throwable): Any = {
status(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
@@ -161,7 +174,11 @@ trait ScalatraKernel extends Handler with Initializable
}
protected var errorHandler: Action = { () => throw caughtThrowable }
- def error(fun: => Any) = errorHandler = { () => fun }
+ def error(fun: => Any) {
+ errorHandler = {
+ () => fun
+ }
+ }
private val _caughtThrowable = new DynamicVariable[Throwable](null)
protected def caughtThrowable = _caughtThrowable.value
@@ -202,7 +219,9 @@ trait ScalatraKernel extends Handler with Initializable
}
def params = _params
- def redirect(uri: String) = (_response value) sendRedirect uri
+ def redirect(uri: String) {
+ (_response value) sendRedirect uri
+ }
implicit def request = _request value
implicit def response = _response value
def session = request.getSession
@@ -210,12 +229,14 @@ trait ScalatraKernel extends Handler with Initializable
case s: HttpSession => Some(s)
case null => None
}
- def status(code: Int) = (_response value) setStatus code
+ def status(code: Int) {
+ (_response value) setStatus code
+ }
def halt(code: Int, msg: String) = throw new HaltException(Some(code), Some(msg))
def halt(code: Int) = throw new HaltException(Some(code), None)
def halt() = throw new HaltException(None, None)
- private case class HaltException(val code: Option[Int], val msg: Option[String]) extends RuntimeException
+ protected[scalatra] case class HaltException(code: Option[Int], msg: Option[String]) extends RuntimeException
def pass() = throw new PassException
protected[scalatra] class PassException extends RuntimeException
@@ -294,7 +315,7 @@ trait ScalatraKernel extends Handler with Initializable
*
* @see addRoute
*/
- protected def removeRoute(verb: String, route: Route): Unit = {
+ protected def removeRoute(verb: String, route: Route) {
modifyRoutes(verb, _ filterNot (_ == route) )
route
}
@@ -303,7 +324,7 @@ trait ScalatraKernel extends Handler with Initializable
* since routes is a ConcurrentMap and we avoid locking, we need to retry if there are
* concurrent modifications, this is abstracted here for removeRoute and addRoute
*/
- @tailrec private def modifyRoutes(protocol: String, f: (List[Route] => List[Route])): Unit = {
+ @tailrec private def modifyRoutes(protocol: String, f: (List[Route] => List[Route])) {
val oldRoutes = Routes(protocol)
if (!Routes.replace(protocol, oldRoutes, f(oldRoutes))) {
modifyRoutes(protocol,f)
@@ -311,7 +332,9 @@ trait ScalatraKernel extends Handler with Initializable
}
private var config: Config = _
- def initialize(config: Config) = this.config = config
+ def initialize(config: Config) {
+ this.config = config
+ }
def initParameter(name: String): Option[String] = config match {
case config: ServletConfig => Option(config.getInitParameter(name))
diff --git a/core/src/main/scala/org/scalatra/ScalatraServlet.scala b/core/src/main/scala/org/scalatra/ScalatraServlet.scala
index 2ee7a3a19..1f003cfab 100644
--- a/core/src/main/scala/org/scalatra/ScalatraServlet.scala
+++ b/core/src/main/scala/org/scalatra/ScalatraServlet.scala
@@ -45,3 +45,5 @@ abstract class ScalatraServlet
override def initialize(config: ServletConfig): Unit = super.initialize(config)
}
+
+abstract class ScalatraPipelinedServlet extends ScalatraServlet with RenderPipeline with DefaultRendererPipeline
diff --git a/core/src/main/scala/org/scalatra/util/package.scala b/core/src/main/scala/org/scalatra/util/package.scala
index ed992d214..708252e5f 100644
--- a/core/src/main/scala/org/scalatra/util/package.scala
+++ b/core/src/main/scala/org/scalatra/util/package.scala
@@ -9,7 +9,7 @@ package object util {
* @param closeable the closeable resource
* @param f the block
*/
- def using[A, B <: { def close(): Unit }](closeable: B)(f: B => A) {
+ def using[A, B <: { def close() }](closeable: B)(f: B => A) {
try {
f(closeable)
}
diff --git a/core/src/test/scala/org/scalatra/RenderPipelineTest.scala b/core/src/test/scala/org/scalatra/RenderPipelineTest.scala
new file mode 100644
index 000000000..289b2bab2
--- /dev/null
+++ b/core/src/test/scala/org/scalatra/RenderPipelineTest.scala
@@ -0,0 +1,48 @@
+package org.scalatra
+
+import org.scalatest.matchers.ShouldMatchers
+import test.scalatest.ScalatraFunSuite
+
+
+class RenderPipelineTestServlet extends ScalatraPipelinedServlet {
+
+ render[String] {
+ case s @ "the string to render" => response.getWriter print ("Rendering string: %s" format s)
+ case s => "Augmenting string: " + s
+ }
+
+ get("/any") {
+ 11111
+ }
+
+ get("/string") {
+ "the string to render"
+ }
+
+ get("/augment") {
+ "yet another string"
+ }
+}
+
+class RenderPipelineTest extends ScalatraFunSuite with ShouldMatchers {
+
+ addServlet(new RenderPipelineTestServlet, "/*")
+
+ test("should still render defaults") {
+ get("/any") {
+ body should equal("11111")
+ }
+ }
+
+ test("should render the string") {
+ get("/string") {
+ body should equal("Rendering string: the string to render")
+ }
+ }
+
+ test("should augment a string") {
+ get("/augment") {
+ body should equal("Augmenting string: yet another string")
+ }
+ }
+}
\ No newline at end of file
diff --git a/example/src/main/scala/org/scalatra/RenderPipelineExample.scala b/example/src/main/scala/org/scalatra/RenderPipelineExample.scala
new file mode 100644
index 000000000..3612fb0e1
--- /dev/null
+++ b/example/src/main/scala/org/scalatra/RenderPipelineExample.scala
@@ -0,0 +1,22 @@
+package org.scalatra
+
+
+
+class RenderPipelineExample extends ScalatraPipelinedServlet {
+
+ render[String] {
+ case s => "rendering: " + s
+ }
+
+ render[List[String]] {
+ case l => "Rendering list:" + l.mkString("
\n", "
\n", "
\n")
+ }
+
+ get("/?") {
+ "hello I'm rendering"
+ }
+
+ get("/list") {
+ "first" :: "second" :: "third" :: "fourth" :: Nil
+ }
+}
\ No newline at end of file
diff --git a/example/src/main/webapp/WEB-INF/web.xml b/example/src/main/webapp/WEB-INF/web.xml
index f21b8ba66..75e4ebf18 100644
--- a/example/src/main/webapp/WEB-INF/web.xml
+++ b/example/src/main/webapp/WEB-INF/web.xml
@@ -8,6 +8,10 @@ PUBLIC "-//Sun Microsystems, Inc.//DTD Web Application 2.2//EN"
TemplateExample
org.scalatra.TemplateExample
+
+ RenderPipelineExample
+ org.scalatra.RenderPipelineExample
+
BasicAuthExample
org.scalatra.BasicAuthExample
@@ -57,6 +61,10 @@ PUBLIC "-//Sun Microsystems, Inc.//DTD Web Application 2.2//EN"
+
+ RenderPipelineExample
+ /pipelined/*
+
ChatApplication
/socket.io/*