Skip to content

Commit

Permalink
Add ability to use ShardedJedis for sharded caching with Redis
Browse files Browse the repository at this point in the history
  • Loading branch information
Jared Dellitt committed Oct 17, 2015
1 parent ba028c2 commit b9618de
Show file tree
Hide file tree
Showing 4 changed files with 305 additions and 2 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,26 @@ val jedis = new Jedis(...)
implicit val scalaCache = ScalaCache(RedisCache(jedis))
```

For [sharded](https://github.com/xetorthio/jedis/wiki/AdvancedUsage#shardedjedis) caching, use a ShardedRedisCache:

```scala
import scalacache._
import redis._

implicit val scalaCache = ScalaCache(ShardedRedisCache(("host1", 6379), ("host2", 6380)))
```

or provide a ShardedJedisPool:

```scala
import scalacache._
import redis._
import redis.clients.jedis._

val jedis = new ShardedJedisPool(...)
implicit val scalaCache = ScalaCache(ShardedRedisCache(jedis))
```

### LruMap (twitter-util)

SBT:
Expand Down
131 changes: 131 additions & 0 deletions redis/src/main/scala/scalacache/redis/ShardedRedisCache.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package scalacache.redis

import java.nio.charset.Charset

import com.typesafe.scalalogging.StrictLogging
import redis.clients.jedis._

import scala.concurrent._
import scala.concurrent.duration._
import scalacache.{ Cache, LoggingSupport }

class ShardedRedisCache(jedisPool: ShardedJedisPool, override val customClassloader: Option[ClassLoader] = None)(implicit execContext: ExecutionContext = ExecutionContext.global)
extends Cache
with RedisSerialization
with LoggingSupport
with StrictLogging {

import ShardedRedisCache.StringWithUtf8Bytes

/**
* Get the value corresponding to the given key from the cache
* @param key cache key
* @tparam V the type of the corresponding value
* @return the value, if there is one
*/
override def get[V](key: String) = Future {
blocking {
withJedisClient { client =>
val resultBytes = Option(client.get(key.utf8bytes))
val result = resultBytes.map(deserialize[V])
logCacheHitOrMiss(key, result)
result
}
}
}

/**
* Insert the given key-value pair into the cache, with an optional Time To Live.
* @param key cache key
* @param value corresponding value
* @param ttl Time To Live
* @tparam V the type of the corresponding value
*/
override def put[V](key: String, value: V, ttl: Option[Duration]) = Future {
blocking {
withJedisClient { client =>
val keyBytes = key.utf8bytes
val valueBytes = serialize(value)
ttl match {
case None => client.set(keyBytes, valueBytes)
case Some(Duration.Zero) => client.set(keyBytes, valueBytes)
case Some(d) if d < 1.second =>
logger.warn("Because Redis (pre 2.6.12) does not support sub-second expiry, TTL of $d will be rounded up to 1 second")
client.setex(keyBytes, 1, valueBytes)
case Some(d) => client.setex(keyBytes, d.toSeconds.toInt, valueBytes)
}
}
}
}

/**
* Remove the given key and its associated value from the cache, if it exists.
* If the key is not in the cache, do nothing.
* @param key cache key
*/
override def remove(key: String) = Future {
blocking {
withJedisClient { client =>
client.del(key.utf8bytes)
}
}
}

override def removeAll() = Future {
blocking {
withJedisClient { client =>
import scala.collection.JavaConversions.collectionAsScalaIterable
client.getAllShards.foreach(_.flushDB())
}
}
}

override def close(): Unit = {
jedisPool.close()
}

private def withJedisClient[T](f: ShardedJedis => T): T = {
val jedis = jedisPool.getResource
try {
f(jedis)
} finally {
jedis.close()
}
}
}

object ShardedRedisCache {

/**
* Create a sharded Redis client connecting to the given hosts and use them for caching
*/
def apply(hosts: (String, Int)*): ShardedRedisCache = {
import scala.collection.JavaConversions.seqAsJavaList

val pool = new ShardedJedisPool(new JedisPoolConfig(), hosts.map {
case (host, port) new JedisShardInfo(host, port)
})

apply(pool)
}

/**
* Create a cache that uses the given ShardedJedis client pool
* @param jedisPool a ShardedJedis pool
* @param customClassloader a classloader to use when deserializing objects from the cache.
* If you are using Play, you should pass in `app.classloader`.
*/
def apply(jedisPool: ShardedJedisPool, customClassloader: Option[ClassLoader] = None): ShardedRedisCache =
new ShardedRedisCache(jedisPool, customClassloader)

private val utf8 = Charset.forName("UTF-8")

/**
* Enrichment class to convert String to UTF-8 byte array
*/
private implicit class StringWithUtf8Bytes(val string: String) extends AnyVal {
def utf8bytes = string.getBytes(utf8)
}

}

24 changes: 22 additions & 2 deletions redis/src/test/scala/scalacache/redis/RedisTestUtil.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package scalacache.redis

import org.scalatest.Alerting
import redis.clients.jedis.{ JedisPool, Jedis }

import redis.clients.jedis._
import scala.collection.JavaConversions.collectionAsScalaIterable
import scala.util.{ Success, Failure, Try }

trait RedisTestUtil { self: Alerting =>
Expand All @@ -20,4 +20,24 @@ trait RedisTestUtil { self: Alerting =>
}
}

def assumingMultipleRedisAreRunning(f: (ShardedJedisPool, ShardedJedis) => Unit): Unit = {
import scala.collection.JavaConversions.seqAsJavaList

Try {
val shard1 = new JedisShardInfo("localhost", 6379)
val shard2 = new JedisShardInfo("localhost", 6380)

val jedisPool = new ShardedJedisPool(new JedisPoolConfig(), Seq(shard1, shard2))
val jedis = jedisPool.getResource

jedis.getAllShards.foreach(_.ping())

(jedisPool, jedis)
} match {
case Failure(_) => alert("Skipping tests because Redis does not appear to be running on localhost.")
case Success((pool, client)) =>
f(pool, client)
}
}

}
132 changes: 132 additions & 0 deletions redis/src/test/scala/scalacache/redis/ShardedRedisCacheSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package scalacache.redis

import org.scalatest.concurrent.{ Eventually, IntegrationPatience, ScalaFutures }
import org.scalatest.time.{ Seconds, Span }
import org.scalatest.{ BeforeAndAfter, FlatSpec, Matchers }

import scala.concurrent.Future
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

class ShardedRedisCacheSpec
extends FlatSpec
with Matchers
with Eventually
with BeforeAndAfter
with RedisSerialization
with ScalaFutures
with IntegrationPatience
with RedisTestUtil {

assumingMultipleRedisAreRunning { (pool, client) =>

val cache = ShardedRedisCache(pool)

before {
import scala.collection.JavaConversions.collectionAsScalaIterable
client.getAllShards.foreach(_.flushDB())
}

behavior of "get"

it should "return the value stored in Redis" in {
client.set(bytes("key1"), serialize(123))
whenReady(cache.get("key1")) { _ should be(Some(123)) }
}

it should "return None if the given key does not exist in the underlying cache" in {
whenReady(cache.get("non-existent-key")) { _ should be(None) }
}

behavior of "put"

it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key2", 123, None)) { _ =>
deserialize[Int](client.get(bytes("key2"))) should be(123)
}
}

behavior of "put with TTL"

it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key3", 123, Some(1 second))) { _ =>
deserialize[Int](client.get(bytes("key3"))) should be(123)

// Should expire after 1 second
eventually(timeout(Span(2, Seconds))) {
client.get(bytes("key3")) should be(null)
}
}
}

behavior of "put with TTL of zero"

it should "store the given key-value pair in the underlying cache with no expiry" in {
whenReady(cache.put("key4", 123, Some(Duration.Zero))) { _ =>
deserialize[Int](client.get(bytes("key4"))) should be(123)
client.ttl("key4") should be(-1L)
}
}

behavior of "put with TTL of less than 1 second"

it should "store the given key-value pair in the underlying cache" in {
whenReady(cache.put("key5", 123, Some(100 milliseconds))) { _ =>
deserialize[Int](client.get(bytes("key5"))) should be(123)
client.ttl("key5").toLong should be > 0L

// Should expire after 1 second
eventually(timeout(Span(2, Seconds))) {
client.get("key5") should be(null)
}
}
}

behavior of "caching with serialization"

def roundTrip[V](key: String, value: V): Future[Option[V]] = {
cache.put(key, value, None).flatMap(_ => cache.get(key))
}

it should "round-trip a String" in {
whenReady(roundTrip("string", "hello")) { _ should be(Some("hello")) }
}

it should "round-trip a byte array" in {
whenReady(roundTrip("bytearray", bytes("world"))) { result =>
new String(result.get, "UTF-8") should be("world")
}
}

it should "round-trip an Int" in {
whenReady(roundTrip("int", 345)) { _ should be(Some(345)) }
}

it should "round-trip a Double" in {
whenReady(roundTrip("double", 1.23)) { _ should be(Some(1.23)) }
}

it should "round-trip a Long" in {
whenReady(roundTrip("long", 3456L)) { _ should be(Some(3456L)) }
}

it should "round-trip a Serializable case class" in {
val cc = CaseClass(123, "wow")
whenReady(roundTrip("caseclass", cc)) { _ should be(Some(cc)) }
}

behavior of "remove"

it should "delete the given key and its value from the underlying cache" in {
client.set(bytes("key1"), serialize(123))
deserialize[Int](client.get(bytes("key1"))) should be(123)

whenReady(cache.remove("key1")) { _ =>
client.get("key1") should be(null)
}
}

}

def bytes(s: String) = s.getBytes("utf-8")
}

0 comments on commit b9618de

Please sign in to comment.