Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.github.mogproject.redismock

import com.github.mogproject.redismock.entity.{Bytes, HASH, Key, HashValue}
import com.github.mogproject.redismock.storage.Storage
import com.github.mogproject.redismock.util.ops._
import com.redis.{HashOperations, Redis}
import com.redis.serialization._
import com.redis.serialization.Parse.parseDefault
Expand All @@ -12,80 +13,67 @@ import scala.util.Try
trait MockHashOperations extends HashOperations with MockOperations with Storage {
self: Redis =>

private[this] def setRaw(key: Any, rawValue: Map[Bytes, Bytes])(implicit format: Format): Unit = {
private def setRaw(key: Any, rawValue: Map[Bytes, Bytes])(implicit format: Format): Unit = {
currentDB.update(Key(key), HashValue(rawValue))
}

private[this] def getRaw(key: Any)(implicit format: Format): Option[HASH.DataType] =
currentDB.get(Key(format.apply(key))).map(_.as(HASH))
private def getRaw(key: Any)(implicit format: Format): Option[HASH.DataType] =
currentDB.get(Key(key)).map(_.as(HASH))

private[this] def getRawOrEmpty(key: Any)(implicit format: Format): HASH.DataType =
private def getRawOrEmpty(key: Any)(implicit format: Format): HASH.DataType =
getRaw(key).getOrElse(Map.empty[Bytes, Bytes])

private def getLong(key: Any, field: Any)(implicit format: Format): Option[Long] = hget(key, field).map { v =>
Try(v.toLong).getOrElse(throw new RuntimeException("ERR hash value is not an integer or out of range"))
}

private def getLongOrZero(key: Any, field: Any)(implicit format: Format): Long = getLong(key, field).getOrElse(0L)

private def getFloat(key: Any, field: Any)(implicit format: Format): Option[Float] = hget(key, field).map { v =>
Try(v.toFloat).getOrElse(throw new RuntimeException("ERR hash value is not a valid float"))
}

private def getFloatOrZero(key: Any, field: Any)(implicit format: Format): Float =
getFloat(key, field).getOrElse(0.0f)


override def hset(key: Any, field: Any, value: Any)(implicit format: Format): Boolean = currentDB.synchronized {
val m = getRawOrEmpty(key)
setRaw(key, m.updated(Bytes(field), Bytes(value)))
m.isEmpty
override def hset(key: Any, field: Any, value: Any)(implicit format: Format): Boolean = withDB {
getRawOrEmpty(key) <| { h => setRaw(key, h.updated(Bytes(field), Bytes(value)))} |> {_.isEmpty}
}

override def hsetnx(key: Any, field: Any, value: Any)(implicit format: Format): Boolean = currentDB.synchronized {
getRaw(key) match {
case Some(_) => false
case None =>
setRaw(key, Map(Bytes(field) -> Bytes(value)))
true
}
override def hsetnx(key: Any, field: Any, value: Any)(implicit format: Format): Boolean = withDB {
!exists(key) <| { _ => setRaw(key, Map(Bytes(field) -> Bytes(value)))}
}

override def hget[A](key: Any, field: Any)(implicit format: Format, parse: Parse[A]): Option[A] =
getRaw(key).flatMap(_.get(Bytes(field))).map(_.parse(parse))

override def hmset(key: Any, map: Iterable[Product2[Any, Any]])(implicit format: Format): Boolean =
currentDB.synchronized {
val m = getRawOrEmpty(key)
setRaw(key, m ++ map.map { case (k, v) => Bytes(k) -> Bytes(v) }.toMap)
m.isEmpty
}
override def hmset(key: Any, map: Iterable[Product2[Any, Any]])(implicit format: Format): Boolean = withDB {
val m = map.map { case (k: Any, v: Any) => Bytes(k) -> Bytes(v)}.toMap
getRawOrEmpty(key) ++ m <| {setRaw(key, _)} |> {_.isEmpty}
}

override def hmget[K, V](key: Any, fields: K*)(implicit format: Format, parseV: Parse[V]): Option[Map[K, V]] =
getRaw(key).map(m => fields.flatMap(f => m.get(Bytes(f)).map(_.parse(parseV)).map(f -> _)).toMap)

override def hincrby(key: Any, field: Any, value: Int)(implicit format: Format): Option[Long] =
currentDB.synchronized {
val n = (for {
m <- getRaw(key)
v <- m.get(Bytes(field))
} yield {
Try(v.parse(parseDefault).toLong).getOrElse(throw new RuntimeException("ERR hash value is not an integer or out of range"))
}).getOrElse(0L) + value
hset(key, field, n)
Some(n)
}

override def hincrbyfloat(key: Any, field: Any, value: Float)(implicit format: Format): Option[Float] =
currentDB.synchronized {
val n = (for {
m <- getRaw(key)
v <- m.get(Bytes(field))
} yield {
Try(v.parse(parseDefault).toFloat).getOrElse(throw new RuntimeException("ERR hash value is not a valid float"))
}).getOrElse(0.0f) + value
hset(key, field, n)
Some(n)
}
override def hincrby(key: Any, field: Any, value: Int)(implicit format: Format): Option[Long] = withDB {
getLongOrZero(key, field) + value <| { x => hset(key, field, x)} |> Some.apply
}

override def hincrbyfloat(key: Any, field: Any, value: Float)(implicit format: Format): Option[Float] = withDB {
getFloatOrZero(key, field) + value <| { x => hset(key, field, x)} |> Some.apply
}

override def hexists(key: Any, field: Any)(implicit format: Format): Boolean =
getRawOrEmpty(key).contains(Bytes(field))

override def hdel(key: Any, field: Any, fields: Any*)(implicit format: Format): Option[Long] =
currentDB.synchronized {
val fs = (field :: fields.toList).map(Bytes.apply)
val x = getRawOrEmpty(key)
val y = x.filterKeys(!fs.contains(_))
setRaw(key, y)
Some(x.size - y.size)
}
override def hdel(key: Any, field: Any, fields: Any*)(implicit format: Format): Option[Long] = withDB {
val fs = (field :: fields.toList).map(Bytes.apply)
val x = getRawOrEmpty(key)
val y = x.filterKeys(!fs.contains(_))
setRaw(key, y)
Some(x.size - y.size)
}

override def hlen(key: Any)(implicit format: Format): Option[Long] = getRaw(key).map(_.size)

Expand All @@ -96,7 +84,7 @@ trait MockHashOperations extends HashOperations with MockOperations with Storage
getRaw(key).map(_.values.toList.map(_.parse(parse)))

override def hgetall[K, V](key: Any)(implicit format: Format, parseK: Parse[K], parseV: Parse[V]): Option[Map[K, V]] =
getRaw(key).map(_.map { case (k, v) => k.parse(parseK) -> v.parse(parseV) })
getRaw(key).map(_.map { case (k, v) => k.parse(parseK) -> v.parse(parseV)})

// HSCAN
// Incrementally iterate hash fields and associated values (since 2.8)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.github.mogproject.redismock.entity.{Bytes, LIST, Key, ListValue}
import com.github.mogproject.redismock.storage.Storage
import com.redis.{Redis, ListOperations}
import com.redis.serialization._
import com.github.mogproject.redismock.util.ops._
import com.github.mogproject.redismock.util.Implicits._
import scala.annotation.tailrec

Expand All @@ -14,9 +15,8 @@ trait MockListOperations extends ListOperations with MockOperations with Storage
// helper functions
//

private def setRaw(key: Any, rawValue: Traversable[Bytes])(implicit format: Format): Unit = {
private def setRaw(key: Any, rawValue: Traversable[Bytes])(implicit format: Format): Unit =
currentDB.update(Key(key), ListValue(rawValue.toVector))
}

private def getRaw(key: Any)(implicit format: Format): Option[LIST.DataType] = currentDB.get(Key(key)).map(_.as(LIST))

Expand All @@ -26,9 +26,7 @@ trait MockListOperations extends ListOperations with MockOperations with Storage
// LPUSH (Variadic: >= 2.4)
// add values to the head of the list stored at key
override def lpush(key: Any, value: Any, values: Any*)(implicit format: Format): Option[Long] = withDB {
val v = (value :: values.toList).map(Bytes.apply).toVector ++ getRawOrEmpty(key)
setRaw(key, v)
Some(v.size)
(value :: values.toList).map(Bytes.apply) ++ getRawOrEmpty(key) <| (setRaw(key, _)) |> { v => Some(v.size)}
}

// LPUSHX (Variadic: >= 2.4)
Expand All @@ -38,9 +36,7 @@ trait MockListOperations extends ListOperations with MockOperations with Storage
// RPUSH (Variadic: >= 2.4)
// add values to the tail of the list stored at key
override def rpush(key: Any, value: Any, values: Any*)(implicit format: Format): Option[Long] = withDB {
val v = getRawOrEmpty(key) ++ (value :: values.toList).map(Bytes.apply).toVector
setRaw(key, v)
Some(v.size)
getRawOrEmpty(key) ++ (value :: values.toList).map(Bytes.apply) <| (setRaw(key, _)) |> { v => Some(v.size)}
}

// RPUSHX (Variadic: >= 2.4)
Expand All @@ -56,7 +52,8 @@ trait MockListOperations extends ListOperations with MockOperations with Storage
// LRANGE
// return the specified elements of the list stored at the specified key.
// Start and end are zero-based indexes.
override def lrange[A](key: Any, start: Int, end: Int)(implicit format: Format, parse: Parse[A]): Option[List[Option[A]]] = {
override def lrange[A](key: Any, start: Int, end: Int)
(implicit format: Format, parse: Parse[A]): Option[List[Option[A]]] = {
getRaw(key) map {
_.sliceFromTo(start, end).map(_.parseOption(parse)).toList
}
Expand Down Expand Up @@ -148,65 +145,38 @@ trait MockListOperations extends ListOperations with MockOperations with Storage
}
}

override def brpoplpush[A](srcKey: Any, dstKey: Any, timeoutInSeconds: Int)(implicit format: Format, parse: Parse[A]): Option[A] = {
@tailrec
def loop(limit: Long): Option[A] = {
if (limit != 0 && limit <= System.currentTimeMillis()) {
None
} else {
getRaw(srcKey) match {
case Some(bs) if bs.nonEmpty => rpoplpush(srcKey, dstKey) // TODO: reduce # of getRaw (to be once) and be atomic
case _ => Thread.sleep(500L); loop(limit)
}
}
}
override def brpoplpush[A](srcKey: Any, dstKey: Any, timeoutInSeconds: Int)
(implicit format: Format, parse: Parse[A]): Option[A] =
loopUntilFound(List(srcKey))(rpoplpush(_, dstKey))(getTimeLimit(timeoutInSeconds))

loop(if (timeoutInSeconds == 0) 0 else System.currentTimeMillis() + timeoutInSeconds * 1000L)
override def blpop[K, V](timeoutInSeconds: Int, key: K, keys: K*)
(implicit format: Format, parseK: Parse[K], parseV: Parse[V]): Option[(K, V)] = {
loopUntilFound(key :: keys.toList){ k => lpop(k)(format, parseV).map((k, _))}(getTimeLimit(timeoutInSeconds))
}

override def blpop[K, V](timeoutInSeconds: Int, key: K, keys: K*)(implicit format: Format, parseK: Parse[K], parseV: Parse[V]): Option[(K, V)] = {
val ks = key #:: keys.toStream

@tailrec
def loop(limit: Long): Option[(K, V)] = {
if (limit != 0 && limit <= System.currentTimeMillis()) {
None
} else {
val results = ks.map(k => (k, getRawOrEmpty(k)))
results.find(_._2.nonEmpty) match {
case Some((k, _)) =>
// TODO: refactor and be atomic
Some((k, lpop(k)(format, parseV).get))
case _ =>
Thread.sleep(500L)
loop(limit)
}
}
}

loop(if (timeoutInSeconds == 0) 0 else System.currentTimeMillis() + timeoutInSeconds * 1000L)
override def brpop[K, V](timeoutInSeconds: Int, key: K, keys: K*)
(implicit format: Format, parseK: Parse[K], parseV: Parse[V]): Option[(K, V)] = {
loopUntilFound(key :: keys.toList){ k => rpop(k)(format, parseV).map((k, _))}(getTimeLimit(timeoutInSeconds))
}

override def brpop[K, V](timeoutInSeconds: Int, key: K, keys: K*)(implicit format: Format, parseK: Parse[K], parseV: Parse[V]): Option[(K, V)] = {
val ks = key #:: keys.toStream

@tailrec
def loop(limit: Long): Option[(K, V)] = {
if (limit != 0 && limit <= System.currentTimeMillis()) {
None
} else {
val results = ks.map(k => (k, getRawOrEmpty(k)))
results.find(_._2.nonEmpty) match {
case Some((k, _)) =>
// TODO: refactor and be atomic
Some((k, rpop(k)(format, parseV).get))
case _ =>
Thread.sleep(500L)
loop(limit)
}
private def findFirstNonEmpty[K](keys: Seq[K])(implicit format: Format): Option[K] =
keys.find(getRawOrEmpty(_).nonEmpty)

@tailrec
private def loopUntilFound[K, A](keys: Seq[K])(task: K => Option[A])(limit: Long): Option[A] = {
if (limit != 0 && limit <= System.currentTimeMillis()) {
None
} else {
val result = withDB {findFirstNonEmpty(keys).flatMap(task)}
result match {
case Some(_) => result
case None =>
Thread.sleep(500L)
loopUntilFound(keys)(task)(limit)
}
}

loop(if (timeoutInSeconds == 0) 0 else System.currentTimeMillis() + timeoutInSeconds * 1000L)
}

private def getTimeLimit(timeoutInSeconds: Int): Long =
if (timeoutInSeconds == 0) 0 else System.currentTimeMillis() + timeoutInSeconds * 1000L
}
Loading