From 59fdc59a8bdf07df6bef20c2861ccd5a2542a6d8 Mon Sep 17 00:00:00 2001 From: Chris Birchall Date: Mon, 9 Mar 2015 20:50:34 +0000 Subject: [PATCH] Support custom classloader for Redis deserialization This is needed when using ScalaCache with Play. Fixes #32 --- .gitignore | 1 + project/Build.scala | 9 +- .../scala/scalacache/redis/RedisCache.scala | 11 +- .../scalacache/redis/RedisSerialization.scala | 15 +- .../scala/scalacache/redis/Issue32Spec.scala | 35 ++++ .../redis/PlayIntegrationSpec.scala | 48 +++++ .../scalacache/redis/RedisCacheSpec.scala | 168 +++++++++--------- .../scalacache/redis/RedisTestUtil.scala | 23 +++ 8 files changed, 217 insertions(+), 93 deletions(-) create mode 100644 redis/src/test/scala/scalacache/redis/Issue32Spec.scala create mode 100644 redis/src/test/scala/scalacache/redis/PlayIntegrationSpec.scala create mode 100644 redis/src/test/scala/scalacache/redis/RedisTestUtil.scala diff --git a/.gitignore b/.gitignore index eeb3224a..9dedaeb6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ target coveralls-token.txt /*.html dump.rdb +logs diff --git a/project/Build.scala b/project/Build.scala index ba63b12e..a534bba3 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -80,7 +80,7 @@ object ScalaCacheBuild extends Build { .settings( libraryDependencies ++= Seq( "redis.clients" % "jedis" % "2.6.0" - ) + ) ++ playTesting ) .dependsOn(core) .disablePlugins(CoverallsPlugin) @@ -101,6 +101,12 @@ object ScalaCacheBuild extends Build { Seq("org.scala-lang.modules" %% "scala-xml" % "1.0.1" % "test") } else Nil) + val playVersion = "2.3.8" + lazy val playTesting = Seq( + "com.typesafe.play" %% "play-test" % playVersion % Test, + "org.scalatestplus" %% "play" % "1.2.0" % Test + ) + // Dependencies common to all projects lazy val commonDeps = scalaLogging ++ @@ -120,6 +126,7 @@ object ScalaCacheBuild extends Build { organization := "com.github.cb372", scalaVersion := Versions.scala, scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature"), + resolvers += Resolver.typesafeRepo("releases"), libraryDependencies ++= commonDeps, parallelExecution in Test := false, publishArtifactsAction := PgpKeys.publishSigned.value, diff --git a/redis/src/main/scala/scalacache/redis/RedisCache.scala b/redis/src/main/scala/scalacache/redis/RedisCache.scala index cdef07cf..49d14e03 100644 --- a/redis/src/main/scala/scalacache/redis/RedisCache.scala +++ b/redis/src/main/scala/scalacache/redis/RedisCache.scala @@ -9,8 +9,10 @@ import scala.concurrent.{ Future, ExecutionContext, blocking } /** * Thin wrapper around Jedis + * @param customClassloader a classloader to use when deserializing objects from the cache. + * If you are using Play, you should pass in `app.classloader`. */ -class RedisCache(jedisPool: JedisPool)(implicit execContext: ExecutionContext = ExecutionContext.global) +class RedisCache(jedisPool: JedisPool, override val customClassloader: Option[ClassLoader] = None)(implicit execContext: ExecutionContext = ExecutionContext.global) extends Cache with RedisSerialization with LoggingSupport @@ -92,10 +94,13 @@ object RedisCache { def apply(host: String, port: Int): RedisCache = apply(new JedisPool(host, port)) /** - * Create a cache that uses the given Jedis client pool + * Create a cache that uses the given Jedis client * @param jedisPool a Jedis pool + * @param customClassloader a classloader to use when deserializing objects from the cache. + * If you are using Play, you should pass in `app.classloader`. */ - def apply(jedisPool: JedisPool): RedisCache = new RedisCache(jedisPool) + def apply(jedisPool: JedisPool, customClassloader: Option[ClassLoader] = None): RedisCache = + new RedisCache(jedisPool, customClassloader) private val utf8 = Charset.forName("UTF-8") diff --git a/redis/src/main/scala/scalacache/redis/RedisSerialization.scala b/redis/src/main/scala/scalacache/redis/RedisSerialization.scala index bc93144b..5d5bdb58 100644 --- a/redis/src/main/scala/scalacache/redis/RedisSerialization.scala +++ b/redis/src/main/scala/scalacache/redis/RedisSerialization.scala @@ -9,6 +9,8 @@ import java.io._ */ trait RedisSerialization { + protected def customClassloader: Option[ClassLoader] = None + object MagicNumbers { val STRING: Byte = 0 val BYTE_ARRAY: Byte = 1 @@ -42,7 +44,7 @@ trait RedisSerialization { def deserialize[A](bytes: Array[Byte]): A = { val bais = new ByteArrayInputStream(bytes) val typeId = bais.read().toByte // Read the next byte to discover the type - val ois = new ObjectInputStream(bais) // The rest of the array is in ObjectInputStream format + val ois = createObjectInputStream(bais) // The rest of the array is in ObjectInputStream format val result = typeId match { case MagicNumbers.STRING => ois.readUTF() case MagicNumbers.BYTE_ARRAY => { @@ -59,4 +61,15 @@ trait RedisSerialization { result.asInstanceOf[A] } + private def createObjectInputStream(inputStream: InputStream): ObjectInputStream = customClassloader match { + case Some(classloader) => new ClassLoaderOIS(inputStream, classloader) + case None => new ObjectInputStream(inputStream) + } + +} + +class ClassLoaderOIS(stream: InputStream, customClassloader: ClassLoader) extends ObjectInputStream(stream) { + override protected def resolveClass(desc: ObjectStreamClass) = { + Class.forName(desc.getName, false, customClassloader) + } } diff --git a/redis/src/test/scala/scalacache/redis/Issue32Spec.scala b/redis/src/test/scala/scalacache/redis/Issue32Spec.scala new file mode 100644 index 00000000..9db6a5a4 --- /dev/null +++ b/redis/src/test/scala/scalacache/redis/Issue32Spec.scala @@ -0,0 +1,35 @@ +package scalacache.redis + +import org.scalatest.{ BeforeAndAfter, Matchers, FlatSpec } + +import scalacache._ +import memoization._ +import redis._ + +case class User(id: Int, name: String) + +/** + * Test to check the sample code in issue #32 runs OK + * (just to isolate the use of the List[User] type from the Play classloader problem) + */ +class Issue32Spec + extends FlatSpec + with Matchers + with BeforeAndAfter + with RedisTestUtil { + + assumingRedisIsRunning { (pool, client) => + + implicit val scalaCache = ScalaCache(RedisCache(pool)) + + def getUser(id: Int): List[User] = memoize { + List(User(id, "Taro")) + } + + "memoize and Redis" should "work with List[User]" in { + getUser(1) should be(List(User(1, "Taro"))) + getUser(1) should be(List(User(1, "Taro"))) + } + } + +} diff --git a/redis/src/test/scala/scalacache/redis/PlayIntegrationSpec.scala b/redis/src/test/scala/scalacache/redis/PlayIntegrationSpec.scala new file mode 100644 index 00000000..9a4d1ee6 --- /dev/null +++ b/redis/src/test/scala/scalacache/redis/PlayIntegrationSpec.scala @@ -0,0 +1,48 @@ +package scalacache.redis + +import _root_.redis.clients.jedis.JedisPool +import org.scalatest.{ FlatSpec, Matchers, TestData } +import org.scalatestplus.play.OneAppPerTest +import play.api.test.FakeApplication +import play.api.{ Application, GlobalSettings } + +import scalacache._ +import scalacache.memoization._ + +class PlayIntegrationSpec extends FlatSpec with Matchers with OneAppPerTest { + + override def newAppForTest(testData: TestData) = new FakeApplication( + withGlobal = Some(Global) + ) + + "Redis and memoization" should "work with Play in one application" in { + Global.getItems(List(1, 2)) should be(List(Item(1, "Chris"), Item(2, "Chris"))) + Global.getItems(List(1, 2)) should be(List(Item(1, "Chris"), Item(2, "Chris"))) + } + + "Redis and memoization" should "work with Play in another application" in { + Global.getItems(List(1, 2)) should be(List(Item(1, "Chris"), Item(2, "Chris"))) + Global.getItems(List(1, 2)) should be(List(Item(1, "Chris"), Item(2, "Chris"))) + } + +} + +case class Item(id: Int, name: String) + +object Global extends GlobalSettings { + @volatile implicit var jedisPool: JedisPool = _ + @volatile implicit var scalaCache: ScalaCache = _ + + override def onStart(app: Application): Unit = { + jedisPool = new JedisPool("localhost", 6379) + scalaCache = ScalaCache(RedisCache(jedisPool, customClassloader = Some(app.classloader))) + } + + override def onStop(app: Application): Unit = { + jedisPool.destroy() + } + + def getItems(ids: List[Int]): List[Item] = memoize { + ids map { Item(_, "Chris") } + } +} diff --git a/redis/src/test/scala/scalacache/redis/RedisCacheSpec.scala b/redis/src/test/scala/scalacache/redis/RedisCacheSpec.scala index 8049b2fa..fbfc01bc 100644 --- a/redis/src/test/scala/scalacache/redis/RedisCacheSpec.scala +++ b/redis/src/test/scala/scalacache/redis/RedisCacheSpec.scala @@ -6,132 +6,124 @@ import org.scalatest.concurrent.{ ScalaFutures, Eventually, IntegrationPatience import org.scalatest.time.{ Span, Seconds } import scala.language.postfixOps -import scala.util.{ Success, Failure, Try } -import redis.clients.jedis.JedisPool import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global class RedisCacheSpec - extends FlatSpec with Matchers with Eventually with BeforeAndAfter with RedisSerialization with ScalaFutures with IntegrationPatience { - - Try { - val jedisPool = new JedisPool("localhost", 6379) - val jedis = jedisPool.getResource() - try { - jedis.ping() - } finally { - jedis.close() - } - (jedisPool, jedis) - } match { - case Failure(_) => alert("Skipping tests because Redis does not appear to be running on localhost.") - case Success((pool, client)) => { + extends FlatSpec + with Matchers + with Eventually + with BeforeAndAfter + with RedisSerialization + with ScalaFutures + with IntegrationPatience + with RedisTestUtil { - val cache = RedisCache(pool) + assumingRedisIsRunning { (pool, client) => - before { - client.flushDB() - } + val cache = RedisCache(pool) - behavior of "get" + before { + client.flushDB() + } - it should "return the value stored in Redis" in { - client.set(bytes("key1"), serialize(123)) - whenReady(cache.get("key1")) { _ should be(Some(123)) } - } + behavior of "get" - it should "return None if the given key does not exist in the underlying cache" in { - whenReady(cache.get("non-existent-key")) { _ should be(None) } - } + it should "return the value stored in Redis" in { + client.set(bytes("key1"), serialize(123)) + whenReady(cache.get("key1")) { _ should be(Some(123)) } + } - behavior of "put" + it should "return None if the given key does not exist in the underlying cache" in { + whenReady(cache.get("non-existent-key")) { _ should be(None) } + } - it should "store the given key-value pair in the underlying cache" in { - whenReady(cache.put("key2", 123, None)) { _ => - deserialize[Int](client.get(bytes("key2"))) should be(123) - } + behavior of "put" + + it should "store the given key-value pair in the underlying cache" in { + whenReady(cache.put("key2", 123, None)) { _ => + deserialize[Int](client.get(bytes("key2"))) should be(123) } + } - behavior of "put with TTL" + behavior of "put with TTL" - it should "store the given key-value pair in the underlying cache" in { - whenReady(cache.put("key3", 123, Some(1 second))) { _ => - deserialize[Int](client.get(bytes("key3"))) should be(123) + it should "store the given key-value pair in the underlying cache" in { + whenReady(cache.put("key3", 123, Some(1 second))) { _ => + deserialize[Int](client.get(bytes("key3"))) should be(123) - // Should expire after 1 second - eventually(timeout(Span(2, Seconds))) { - client.get(bytes("key3")) should be(null) - } + // Should expire after 1 second + eventually(timeout(Span(2, Seconds))) { + client.get(bytes("key3")) should be(null) } } + } - behavior of "put with TTL of zero" + behavior of "put with TTL of zero" - it should "store the given key-value pair in the underlying cache with no expiry" in { - whenReady(cache.put("key4", 123, Some(Duration.Zero))) { _ => - deserialize[Int](client.get(bytes("key4"))) should be(123) - client.ttl("key4") should be(-1L) - } + it should "store the given key-value pair in the underlying cache with no expiry" in { + whenReady(cache.put("key4", 123, Some(Duration.Zero))) { _ => + deserialize[Int](client.get(bytes("key4"))) should be(123) + client.ttl("key4") should be(-1L) } + } - behavior of "put with TTL of less than 1 second" + behavior of "put with TTL of less than 1 second" - it should "store the given key-value pair in the underlying cache" in { - whenReady(cache.put("key5", 123, Some(100 milliseconds))) { _ => - deserialize[Int](client.get(bytes("key5"))) should be(123) - client.pttl("key5").toLong should be > 0L + it should "store the given key-value pair in the underlying cache" in { + whenReady(cache.put("key5", 123, Some(100 milliseconds))) { _ => + deserialize[Int](client.get(bytes("key5"))) should be(123) + client.pttl("key5").toLong should be > 0L - // Should expire after 1 second - eventually(timeout(Span(2, Seconds))) { - client.get("key5") should be(null) - } + // Should expire after 1 second + eventually(timeout(Span(2, Seconds))) { + client.get("key5") should be(null) } } + } - behavior of "caching with serialization" + behavior of "caching with serialization" - def roundTrip[V](key: String, value: V): Future[Option[V]] = { - cache.put(key, value, None).flatMap(_ => cache.get(key)) - } + def roundTrip[V](key: String, value: V): Future[Option[V]] = { + cache.put(key, value, None).flatMap(_ => cache.get(key)) + } - it should "round-trip a String" in { - whenReady(roundTrip("string", "hello")) { _ should be(Some("hello")) } - } + it should "round-trip a String" in { + whenReady(roundTrip("string", "hello")) { _ should be(Some("hello")) } + } - it should "round-trip a byte array" in { - whenReady(roundTrip("bytearray", bytes("world"))) { result => - new String(result.get, "UTF-8") should be("world") - } + it should "round-trip a byte array" in { + whenReady(roundTrip("bytearray", bytes("world"))) { result => + new String(result.get, "UTF-8") should be("world") } + } - it should "round-trip an Int" in { - whenReady(roundTrip("int", 345)) { _ should be(Some(345)) } - } + it should "round-trip an Int" in { + whenReady(roundTrip("int", 345)) { _ should be(Some(345)) } + } - it should "round-trip a Double" in { - whenReady(roundTrip("double", 1.23)) { _ should be(Some(1.23)) } - } + it should "round-trip a Double" in { + whenReady(roundTrip("double", 1.23)) { _ should be(Some(1.23)) } + } - it should "round-trip a Long" in { - whenReady(roundTrip("long", 3456L)) { _ should be(Some(3456L)) } - } + it should "round-trip a Long" in { + whenReady(roundTrip("long", 3456L)) { _ should be(Some(3456L)) } + } - it should "round-trip a Serializable case class" in { - val cc = CaseClass(123, "wow") - whenReady(roundTrip("caseclass", cc)) { _ should be(Some(cc)) } - } + it should "round-trip a Serializable case class" in { + val cc = CaseClass(123, "wow") + whenReady(roundTrip("caseclass", cc)) { _ should be(Some(cc)) } + } - behavior of "remove" + behavior of "remove" - it should "delete the given key and its value from the underlying cache" in { - client.set(bytes("key1"), serialize(123)) - deserialize[Int](client.get(bytes("key1"))) should be(123) + it should "delete the given key and its value from the underlying cache" in { + client.set(bytes("key1"), serialize(123)) + deserialize[Int](client.get(bytes("key1"))) should be(123) - whenReady(cache.remove("key1")) { _ => - client.get("key1") should be(null) - } + whenReady(cache.remove("key1")) { _ => + client.get("key1") should be(null) } - } } diff --git a/redis/src/test/scala/scalacache/redis/RedisTestUtil.scala b/redis/src/test/scala/scalacache/redis/RedisTestUtil.scala new file mode 100644 index 00000000..31c00ae1 --- /dev/null +++ b/redis/src/test/scala/scalacache/redis/RedisTestUtil.scala @@ -0,0 +1,23 @@ +package scalacache.redis + +import org.scalatest.Alerting +import redis.clients.jedis.{ JedisPool, Jedis } + +import scala.util.{ Success, Failure, Try } + +trait RedisTestUtil { self: Alerting => + + def assumingRedisIsRunning(f: (JedisPool, Jedis) => Unit): Unit = { + Try { + val jedisPool = new JedisPool("localhost", 6379) + val jedis = jedisPool.getResource() + jedis.ping() + (jedisPool, jedis) + } match { + case Failure(_) => alert("Skipping tests because Redis does not appear to be running on localhost.") + case Success((pool, client)) => + f(pool, client) + } + } + +}