Skip to content

Commit

Permalink
Support custom classloader for Redis deserialization
Browse files Browse the repository at this point in the history
This is needed when using ScalaCache with Play.

Fixes #32
  • Loading branch information
cb372 committed Mar 14, 2015
1 parent dd086c7 commit 59fdc59
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 93 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ target
coveralls-token.txt
/*.html
dump.rdb
logs
9 changes: 8 additions & 1 deletion project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ object ScalaCacheBuild extends Build {
.settings(
libraryDependencies ++= Seq(
"redis.clients" % "jedis" % "2.6.0"
)
) ++ playTesting
)
.dependsOn(core)
.disablePlugins(CoverallsPlugin)
Expand All @@ -101,6 +101,12 @@ object ScalaCacheBuild extends Build {
Seq("org.scala-lang.modules" %% "scala-xml" % "1.0.1" % "test")
} else Nil)

val playVersion = "2.3.8"
lazy val playTesting = Seq(
"com.typesafe.play" %% "play-test" % playVersion % Test,
"org.scalatestplus" %% "play" % "1.2.0" % Test
)

// Dependencies common to all projects
lazy val commonDeps =
scalaLogging ++
Expand All @@ -120,6 +126,7 @@ object ScalaCacheBuild extends Build {
organization := "com.github.cb372",
scalaVersion := Versions.scala,
scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature"),
resolvers += Resolver.typesafeRepo("releases"),
libraryDependencies ++= commonDeps,
parallelExecution in Test := false,
publishArtifactsAction := PgpKeys.publishSigned.value,
Expand Down
11 changes: 8 additions & 3 deletions redis/src/main/scala/scalacache/redis/RedisCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import scala.concurrent.{ Future, ExecutionContext, blocking }

/**
* Thin wrapper around Jedis
* @param customClassloader a classloader to use when deserializing objects from the cache.
* If you are using Play, you should pass in `app.classloader`.
*/
class RedisCache(jedisPool: JedisPool)(implicit execContext: ExecutionContext = ExecutionContext.global)
class RedisCache(jedisPool: JedisPool, override val customClassloader: Option[ClassLoader] = None)(implicit execContext: ExecutionContext = ExecutionContext.global)
extends Cache
with RedisSerialization
with LoggingSupport
Expand Down Expand Up @@ -92,10 +94,13 @@ object RedisCache {
def apply(host: String, port: Int): RedisCache = apply(new JedisPool(host, port))

/**
* Create a cache that uses the given Jedis client pool
* Create a cache that uses the given Jedis client
* @param jedisPool a Jedis pool
* @param customClassloader a classloader to use when deserializing objects from the cache.
* If you are using Play, you should pass in `app.classloader`.
*/
def apply(jedisPool: JedisPool): RedisCache = new RedisCache(jedisPool)
def apply(jedisPool: JedisPool, customClassloader: Option[ClassLoader] = None): RedisCache =
new RedisCache(jedisPool, customClassloader)

private val utf8 = Charset.forName("UTF-8")

Expand Down
15 changes: 14 additions & 1 deletion redis/src/main/scala/scalacache/redis/RedisSerialization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import java.io._
*/
trait RedisSerialization {

protected def customClassloader: Option[ClassLoader] = None

object MagicNumbers {
val STRING: Byte = 0
val BYTE_ARRAY: Byte = 1
Expand Down Expand Up @@ -42,7 +44,7 @@ trait RedisSerialization {
def deserialize[A](bytes: Array[Byte]): A = {
val bais = new ByteArrayInputStream(bytes)
val typeId = bais.read().toByte // Read the next byte to discover the type
val ois = new ObjectInputStream(bais) // The rest of the array is in ObjectInputStream format
val ois = createObjectInputStream(bais) // The rest of the array is in ObjectInputStream format
val result = typeId match {
case MagicNumbers.STRING => ois.readUTF()
case MagicNumbers.BYTE_ARRAY => {
Expand All @@ -59,4 +61,15 @@ trait RedisSerialization {
result.asInstanceOf[A]
}

private def createObjectInputStream(inputStream: InputStream): ObjectInputStream = customClassloader match {
case Some(classloader) => new ClassLoaderOIS(inputStream, classloader)
case None => new ObjectInputStream(inputStream)
}

}

class ClassLoaderOIS(stream: InputStream, customClassloader: ClassLoader) extends ObjectInputStream(stream) {
override protected def resolveClass(desc: ObjectStreamClass) = {
Class.forName(desc.getName, false, customClassloader)
}
}
35 changes: 35 additions & 0 deletions redis/src/test/scala/scalacache/redis/Issue32Spec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package scalacache.redis

import org.scalatest.{ BeforeAndAfter, Matchers, FlatSpec }

import scalacache._
import memoization._
import redis._

case class User(id: Int, name: String)

/**
* Test to check the sample code in issue #32 runs OK
* (just to isolate the use of the List[User] type from the Play classloader problem)
*/
class Issue32Spec
extends FlatSpec
with Matchers
with BeforeAndAfter
with RedisTestUtil {

assumingRedisIsRunning { (pool, client) =>

implicit val scalaCache = ScalaCache(RedisCache(pool))

def getUser(id: Int): List[User] = memoize {
List(User(id, "Taro"))
}

"memoize and Redis" should "work with List[User]" in {
getUser(1) should be(List(User(1, "Taro")))
getUser(1) should be(List(User(1, "Taro")))
}
}

}
48 changes: 48 additions & 0 deletions redis/src/test/scala/scalacache/redis/PlayIntegrationSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package scalacache.redis

import _root_.redis.clients.jedis.JedisPool
import org.scalatest.{ FlatSpec, Matchers, TestData }
import org.scalatestplus.play.OneAppPerTest
import play.api.test.FakeApplication
import play.api.{ Application, GlobalSettings }

import scalacache._
import scalacache.memoization._

class PlayIntegrationSpec extends FlatSpec with Matchers with OneAppPerTest {

override def newAppForTest(testData: TestData) = new FakeApplication(
withGlobal = Some(Global)
)

"Redis and memoization" should "work with Play in one application" in {
Global.getItems(List(1, 2)) should be(List(Item(1, "Chris"), Item(2, "Chris")))
Global.getItems(List(1, 2)) should be(List(Item(1, "Chris"), Item(2, "Chris")))
}

"Redis and memoization" should "work with Play in another application" in {
Global.getItems(List(1, 2)) should be(List(Item(1, "Chris"), Item(2, "Chris")))
Global.getItems(List(1, 2)) should be(List(Item(1, "Chris"), Item(2, "Chris")))
}

}

case class Item(id: Int, name: String)

object Global extends GlobalSettings {
@volatile implicit var jedisPool: JedisPool = _
@volatile implicit var scalaCache: ScalaCache = _

override def onStart(app: Application): Unit = {
jedisPool = new JedisPool("localhost", 6379)
scalaCache = ScalaCache(RedisCache(jedisPool, customClassloader = Some(app.classloader)))
}

override def onStop(app: Application): Unit = {
jedisPool.destroy()
}

def getItems(ids: List[Int]): List[Item] = memoize {
ids map { Item(_, "Chris") }
}
}
168 changes: 80 additions & 88 deletions redis/src/test/scala/scalacache/redis/RedisCacheSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,132 +6,124 @@ import org.scalatest.concurrent.{ ScalaFutures, Eventually, IntegrationPatience
import org.scalatest.time.{ Span, Seconds }

import scala.language.postfixOps
import scala.util.{ Success, Failure, Try }
import redis.clients.jedis.JedisPool
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

class RedisCacheSpec
extends FlatSpec with Matchers with Eventually with BeforeAndAfter with RedisSerialization with ScalaFutures with IntegrationPatience {

Try {
val jedisPool = new JedisPool("localhost", 6379)
val jedis = jedisPool.getResource()
try {
jedis.ping()
} finally {
jedis.close()
}
(jedisPool, jedis)
} match {
case Failure(_) => alert("Skipping tests because Redis does not appear to be running on localhost.")
case Success((pool, client)) => {
extends FlatSpec
with Matchers
with Eventually
with BeforeAndAfter
with RedisSerialization
with ScalaFutures
with IntegrationPatience
with RedisTestUtil {

val cache = RedisCache(pool)
assumingRedisIsRunning { (pool, client) =>

before {
client.flushDB()
}
val cache = RedisCache(pool)

behavior of "get"
before {
client.flushDB()
}

it should "return the value stored in Redis" in {
client.set(bytes("key1"), serialize(123))
whenReady(cache.get("key1")) { _ should be(Some(123)) }
}
behavior of "get"

it should "return None if the given key does not exist in the underlying cache" in {
whenReady(cache.get("non-existent-key")) { _ should be(None) }
}
it should "return the value stored in Redis" in {
client.set(bytes("key1"), serialize(123))
whenReady(cache.get("key1")) { _ should be(Some(123)) }
}

behavior of "put"
it should "return None if the given key does not exist in the underlying cache" in {
whenReady(cache.get("non-existent-key")) { _ should be(None) }
}

it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key2", 123, None)) { _ =>
deserialize[Int](client.get(bytes("key2"))) should be(123)
}
behavior of "put"

it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key2", 123, None)) { _ =>
deserialize[Int](client.get(bytes("key2"))) should be(123)
}
}

behavior of "put with TTL"
behavior of "put with TTL"

it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key3", 123, Some(1 second))) { _ =>
deserialize[Int](client.get(bytes("key3"))) should be(123)
it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key3", 123, Some(1 second))) { _ =>
deserialize[Int](client.get(bytes("key3"))) should be(123)

// Should expire after 1 second
eventually(timeout(Span(2, Seconds))) {
client.get(bytes("key3")) should be(null)
}
// Should expire after 1 second
eventually(timeout(Span(2, Seconds))) {
client.get(bytes("key3")) should be(null)
}
}
}

behavior of "put with TTL of zero"
behavior of "put with TTL of zero"

it should "store the given key-value pair in the underlying cache with no expiry" in {
whenReady(cache.put("key4", 123, Some(Duration.Zero))) { _ =>
deserialize[Int](client.get(bytes("key4"))) should be(123)
client.ttl("key4") should be(-1L)
}
it should "store the given key-value pair in the underlying cache with no expiry" in {
whenReady(cache.put("key4", 123, Some(Duration.Zero))) { _ =>
deserialize[Int](client.get(bytes("key4"))) should be(123)
client.ttl("key4") should be(-1L)
}
}

behavior of "put with TTL of less than 1 second"
behavior of "put with TTL of less than 1 second"

it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key5", 123, Some(100 milliseconds))) { _ =>
deserialize[Int](client.get(bytes("key5"))) should be(123)
client.pttl("key5").toLong should be > 0L
it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key5", 123, Some(100 milliseconds))) { _ =>
deserialize[Int](client.get(bytes("key5"))) should be(123)
client.pttl("key5").toLong should be > 0L

// Should expire after 1 second
eventually(timeout(Span(2, Seconds))) {
client.get("key5") should be(null)
}
// Should expire after 1 second
eventually(timeout(Span(2, Seconds))) {
client.get("key5") should be(null)
}
}
}

behavior of "caching with serialization"
behavior of "caching with serialization"

def roundTrip[V](key: String, value: V): Future[Option[V]] = {
cache.put(key, value, None).flatMap(_ => cache.get(key))
}
def roundTrip[V](key: String, value: V): Future[Option[V]] = {
cache.put(key, value, None).flatMap(_ => cache.get(key))
}

it should "round-trip a String" in {
whenReady(roundTrip("string", "hello")) { _ should be(Some("hello")) }
}
it should "round-trip a String" in {
whenReady(roundTrip("string", "hello")) { _ should be(Some("hello")) }
}

it should "round-trip a byte array" in {
whenReady(roundTrip("bytearray", bytes("world"))) { result =>
new String(result.get, "UTF-8") should be("world")
}
it should "round-trip a byte array" in {
whenReady(roundTrip("bytearray", bytes("world"))) { result =>
new String(result.get, "UTF-8") should be("world")
}
}

it should "round-trip an Int" in {
whenReady(roundTrip("int", 345)) { _ should be(Some(345)) }
}
it should "round-trip an Int" in {
whenReady(roundTrip("int", 345)) { _ should be(Some(345)) }
}

it should "round-trip a Double" in {
whenReady(roundTrip("double", 1.23)) { _ should be(Some(1.23)) }
}
it should "round-trip a Double" in {
whenReady(roundTrip("double", 1.23)) { _ should be(Some(1.23)) }
}

it should "round-trip a Long" in {
whenReady(roundTrip("long", 3456L)) { _ should be(Some(3456L)) }
}
it should "round-trip a Long" in {
whenReady(roundTrip("long", 3456L)) { _ should be(Some(3456L)) }
}

it should "round-trip a Serializable case class" in {
val cc = CaseClass(123, "wow")
whenReady(roundTrip("caseclass", cc)) { _ should be(Some(cc)) }
}
it should "round-trip a Serializable case class" in {
val cc = CaseClass(123, "wow")
whenReady(roundTrip("caseclass", cc)) { _ should be(Some(cc)) }
}

behavior of "remove"
behavior of "remove"

it should "delete the given key and its value from the underlying cache" in {
client.set(bytes("key1"), serialize(123))
deserialize[Int](client.get(bytes("key1"))) should be(123)
it should "delete the given key and its value from the underlying cache" in {
client.set(bytes("key1"), serialize(123))
deserialize[Int](client.get(bytes("key1"))) should be(123)

whenReady(cache.remove("key1")) { _ =>
client.get("key1") should be(null)
}
whenReady(cache.remove("key1")) { _ =>
client.get("key1") should be(null)
}

}

}
Expand Down
Loading

0 comments on commit 59fdc59

Please sign in to comment.