diff --git a/README.md b/README.md index 094ff16..f16de1a 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,8 @@ redis-cli shutdown |[x]|[ ]|[ ]|Keys| |[x]|[x]|[ ]|Strings| |[x]|[x]|[ ]|Lists| -|[ ]|[ ]|[ ]|Sets| +|[x]|[x]|[ ]|Sets (except SSCAN)| +|[ ]|[ ]|[ ]|Sets (SSCAN)| |[ ]|[ ]|[ ]|Sorted Sets| |[ ]|[ ]|[ ]|Hashes| |[ ]|[ ]|[ ]|HyperLogLog| diff --git a/src/main/scala/com/github/mogproject/redismock/MockRedisClient.scala b/src/main/scala/com/github/mogproject/redismock/MockRedisClient.scala index 26653c2..a9b8a86 100644 --- a/src/main/scala/com/github/mogproject/redismock/MockRedisClient.scala +++ b/src/main/scala/com/github/mogproject/redismock/MockRedisClient.scala @@ -16,7 +16,7 @@ trait MockRedisCommand extends Redis with MockOperations //with MockNodeOperations with MockStringOperations with MockListOperations -//with MockSetOperations +with MockSetOperations //with MockSortedSetOperations //with MockHashOperations //with MockEvalOperations diff --git a/src/main/scala/com/github/mogproject/redismock/MockSetOperations.scala b/src/main/scala/com/github/mogproject/redismock/MockSetOperations.scala new file mode 100644 index 0000000..5b2d1b9 --- /dev/null +++ b/src/main/scala/com/github/mogproject/redismock/MockSetOperations.scala @@ -0,0 +1,157 @@ +package com.github.mogproject.redismock + +import com.github.mogproject.redismock.entity.{Bytes, SET, Key, SetValue} +import com.github.mogproject.redismock.storage.Storage +import com.redis.{SetOperations, Redis} +import com.redis.serialization._ + +trait MockSetOperations extends SetOperations with MockOperations with Storage { + self: Redis => + + private[this] def setRaw(key: Any, rawValue: SET.DataType)(implicit format: Format): Unit = + currentDB.update(Key(key), SetValue(rawValue)) + + private[this] def getRaw(key: Any)(implicit format: Format): Option[SET.DataType] = + currentDB.get(Key(format.apply(key))).map(_.as(SET)) + + private[this] def getRawOrEmpty(key: Any)(implicit format: Format): SET.DataType = + getRaw(key).getOrElse(Set.empty[Bytes]) + + + // SADD (VARIADIC: >= 2.4) + // Add the specified members to the set value stored at key. + override def sadd(key: Any, value: Any, values: Any*)(implicit format: Format): Option[Long] = { + currentDB.synchronized { + val a = getRawOrEmpty(key) + val b = a ++ (value :: values.toList).map(Bytes.apply) + setRaw(key, b) + Some(b.size - a.size) + } + } + + // SREM (VARIADIC: >= 2.4) + // Remove the specified members from the set value stored at key. + override def srem(key: Any, value: Any, values: Any*)(implicit format: Format): Option[Long] = { + currentDB.synchronized { + val a = getRawOrEmpty(key) + val b = a -- (value :: values.toList).map(Bytes.apply) + setRaw(key, b) + Some(a.size - b.size) + } + } + + // SPOP + // Remove and return (pop) a random element from the Set value at key. + override def spop[A](key: Any)(implicit format: Format, parse: Parse[A]): Option[A] = { + currentDB.synchronized { + val (h, t) = getRawOrEmpty(key).splitAt(1) + h.headOption.map { x => + setRaw(key, t) + x.parse(parse) + } + } + } + + // SMOVE + // Move the specified member from one Set to another atomically. + // Integer reply, specifically: + // 1 if the element is moved. + // 0 if the element is not a member of source and no operation was performed. + override def smove(sourceKey: Any, destKey: Any, value: Any)(implicit format: Format): Option[Long] = { + currentDB.synchronized { + val ret = srem(sourceKey, value) + if (ret.exists(_ > 0)) sadd(destKey, value) + ret + } + } + + // SCARD + // Return the number of elements (the cardinality) of the Set at key. + override def scard(key: Any)(implicit format: Format): Option[Long] = Some(getRawOrEmpty(key).size) + + // SISMEMBER + // Test if the specified value is a member of the Set at key. + override def sismember(key: Any, value: Any)(implicit format: Format): Boolean = + getRawOrEmpty(key).contains(Bytes(value)) + + // SINTER + // Return the intersection between the Sets stored at key1, key2, ..., keyN. + override def sinter[A](key: Any, keys: Any*)(implicit format: Format, parse: Parse[A]): Option[Set[Option[A]]] = { + currentDB.synchronized { + Some((key :: keys.toList).map(getRawOrEmpty).reduceLeft(_ & _).map(_.parseOption(parse))) + } + } + + // SINTERSTORE + // Compute the intersection between the Sets stored at key1, key2, ..., keyN, + // and store the resulting Set at dstkey. + // SINTERSTORE returns the size of the intersection, unlike what the documentation says + // refer http://code.google.com/p/redis/issues/detail?id=121 + override def sinterstore(key: Any, keys: Any*)(implicit format: Format): Option[Long] = { + currentDB.synchronized { + val xs = keys.map(getRawOrEmpty).reduceLeft(_ & _) + setRaw(key, xs) + Some(xs.size) + } + } + + // SUNION + // Return the union between the Sets stored at key1, key2, ..., keyN. + override def sunion[A](key: Any, keys: Any*)(implicit format: Format, parse: Parse[A]): Option[Set[Option[A]]] = { + currentDB.synchronized { + Some((key :: keys.toList).map(getRawOrEmpty).reduceLeft(_ | _).map(_.parseOption(parse))) + } + } + + // SUNIONSTORE + // Compute the union between the Sets stored at key1, key2, ..., keyN, + // and store the resulting Set at dstkey. + // SUNIONSTORE returns the size of the union, unlike what the documentation says + // refer http://code.google.com/p/redis/issues/detail?id=121 + override def sunionstore(key: Any, keys: Any*)(implicit format: Format): Option[Long] = { + currentDB.synchronized { + val xs = keys.map(getRawOrEmpty).reduceLeft(_ | _) + setRaw(key, xs) + Some(xs.size) + } + } + + // SDIFF + // Return the difference between the Set stored at key1 and all the Sets key2, ..., keyN. + override def sdiff[A](key: Any, keys: Any*)(implicit format: Format, parse: Parse[A]): Option[Set[Option[A]]] = { + currentDB.synchronized { + Some((key :: keys.toList).map(getRawOrEmpty).reduceLeft(_ -- _).map(_.parseOption(parse))) + } + } + + // SDIFFSTORE + // Compute the difference between the Set key1 and all the Sets key2, ..., keyN, + // and store the resulting Set at dstkey. + override def sdiffstore(key: Any, keys: Any*)(implicit format: Format): Option[Long] = { + currentDB.synchronized { + val xs = keys.map(getRawOrEmpty).reduceLeft(_ -- _) + setRaw(key, xs) + Some(xs.size) + } + } + + // SMEMBERS + // Return all the members of the Set value at key. + override def smembers[A](key: Any)(implicit format: Format, parse: Parse[A]): Option[Set[Option[A]]] = + Some(getRawOrEmpty(key).map(_.parseOption(parse))) + + // SRANDMEMBER + // Return a random element from a Set + override def srandmember[A](key: Any)(implicit format: Format, parse: Parse[A]): Option[A] = + getRawOrEmpty(key).headOption.map(_.parse(parse)) + + // SRANDMEMBER + // Return multiple random elements from a Set (since 2.6) + override def srandmember[A](key: Any, count: Int)(implicit format: Format, parse: Parse[A]): Option[List[Option[A]]] = + Some(getRawOrEmpty(key).take(count).toList.map(_.parseOption(parse))) + + // SSCAN + // Incrementally iterate Set elements (since 2.8) + override def sscan[A](key: Any, cursor: Int, pattern: Any = "*", count: Int = 10)(implicit format: Format, parse: Parse[A]): Option[(Option[Int], Option[List[Option[A]]])] = + send("SSCAN", key :: cursor :: ((x: List[Any]) => if (pattern == "*") x else "match" :: pattern :: x)(if (count == 10) Nil else List("count", count)))(asPair) +} diff --git a/src/main/scala/com/github/mogproject/redismock/entity/Bytes.scala b/src/main/scala/com/github/mogproject/redismock/entity/Bytes.scala index 083a4c4..a086462 100644 --- a/src/main/scala/com/github/mogproject/redismock/entity/Bytes.scala +++ b/src/main/scala/com/github/mogproject/redismock/entity/Bytes.scala @@ -19,6 +19,8 @@ case class Bytes(value: Vector[Byte]) extends IndexedSeqLike[Byte, Bytes] { def parse[A](parse: Parse[A]): A = parse(value.toArray) + def parseOption[A](parse: Parse[A]): Option[A] = Try(parse(value.toArray)).toOption + def ++(bs: Bytes): Bytes = Bytes(value ++ bs.value) def ++(bs: Traversable[Byte]): Bytes = Bytes(value ++ bs) diff --git a/src/main/scala/com/github/mogproject/redismock/entity/Value.scala b/src/main/scala/com/github/mogproject/redismock/entity/Value.scala index dca543e..08d10d9 100644 --- a/src/main/scala/com/github/mogproject/redismock/entity/Value.scala +++ b/src/main/scala/com/github/mogproject/redismock/entity/Value.scala @@ -22,11 +22,11 @@ case class StringValue(value: STRING.DataType) extends Value { val valueType = S case class ListValue(value: LIST.DataType) extends Value { val valueType = LIST } +case class SetValue(value: SET.DataType) extends Value { val valueType = SET } -object StringValue { - def apply(value: Array[Byte]): StringValue = new StringValue(Bytes(value)) - def apply(value: Any)(implicit format: Format): StringValue = apply(format(value)) +object StringValue { + def apply(value: Any)(implicit format: Format): StringValue = apply(value) } object ListValue { @@ -34,3 +34,9 @@ object ListValue { def apply(value: Traversable[Any])(implicit format: Format): ListValue = apply(value.map(format.apply)) } + +object SetValue { + def apply(value: Set[Any])(implicit format: Format): ListValue = apply(value.map(format.apply)) + + def apply(value: Traversable[Any])(implicit format: Format): ListValue = apply(value.map(format.apply)) +} diff --git a/src/test/scala/com/github/mogproject/redismock/MockSetOperationsSpec.scala b/src/test/scala/com/github/mogproject/redismock/MockSetOperationsSpec.scala new file mode 100644 index 0000000..0e3a9b9 --- /dev/null +++ b/src/test/scala/com/github/mogproject/redismock/MockSetOperationsSpec.scala @@ -0,0 +1,327 @@ +package com.github.mogproject.redismock + +import org.scalatest.FunSpec +import org.scalatest.BeforeAndAfterEach +import org.scalatest.BeforeAndAfterAll +import org.scalatest.Matchers + +class MockSetOperationsSpec extends FunSpec +with Matchers +with BeforeAndAfterEach +with BeforeAndAfterAll { + + val r = TestUtil.getRedisClient + + override def beforeEach = { + } + + override def afterEach = { + r.flushdb + } + + override def afterAll = { + r.disconnect + } + + describe("sadd") { + it("should add a non-existent value to the set") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + } + it("should not add an existing value to the set") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "foo").get should equal(0) + } + it("should fail if the key points to a non-set") { + r.lpush("list-1", "foo") should equal(Some(1)) + val thrown = the [Exception] thrownBy { r.sadd("list-1", "foo") } + thrown.getMessage should include ("Operation against a key holding the wrong kind of value") + } + } + + describe("sadd with variadic arguments") { + it("should add a non-existent value to the set") { + r.sadd("set-1", "foo", "bar", "baz").get should equal(3) + r.sadd("set-1", "foo", "bar", "faz").get should equal(1) + r.sadd("set-1", "bar").get should equal(0) + } + } + + describe("srem") { + it("should remove a value from the set") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.srem("set-1", "bar").get should equal(1) + r.srem("set-1", "foo").get should equal(1) + } + it("should not do anything if the value does not exist") { + r.sadd("set-1", "foo").get should equal(1) + r.srem("set-1", "bar").get should equal(0) + } + it("should fail if the key points to a non-set") { + r.lpush("list-1", "foo") should equal(Some(1)) + val thrown = the [Exception] thrownBy { r.srem("list-1", "foo") } + thrown.getMessage should include ("Operation against a key holding the wrong kind of value") + } + } + + describe("srem with variadic arguments") { + it("should remove a value from the set") { + r.sadd("set-1", "foo", "bar", "baz", "faz").get should equal(4) + r.srem("set-1", "foo", "bar").get should equal(2) + r.srem("set-1", "foo").get should equal(0) + r.srem("set-1", "baz", "bar").get should equal(1) + } + } + + describe("spop") { + it("should pop a random element") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.spop("set-1").get should (equal("foo") or equal("bar") or equal("baz")) + } + it("should return nil if the key does not exist") { + r.spop("set-1") should equal(None) + } + } + + describe("smove") { + it("should move from one set to another") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + + r.sadd("set-2", "1").get should equal(1) + r.sadd("set-2", "2").get should equal(1) + + r.smove("set-1", "set-2", "baz").get should equal(1) + r.sadd("set-2", "baz").get should equal(0) + r.sadd("set-1", "baz").get should equal(1) + } + it("should return 0 if the element does not exist in source set") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.smove("set-1", "set-2", "bat").get should equal(0) + r.smove("set-3", "set-2", "bat").get should equal(0) + } + it("should give error if the source or destination key is not a set") { + r.lpush("list-1", "foo") should equal(Some(1)) + r.lpush("list-1", "bar") should equal(Some(2)) + r.lpush("list-1", "baz") should equal(Some(3)) + r.sadd("set-1", "foo").get should equal(1) + val thrown = the [Exception] thrownBy { r.smove("list-1", "set-1", "bat") } + thrown.getMessage should include ("Operation against a key holding the wrong kind of value") + } + } + + describe("scard") { + it("should return cardinality") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.scard("set-1").get should equal(3) + } + it("should return 0 if key does not exist") { + r.scard("set-1").get should equal(0) + } + } + + describe("sismember") { + it("should return true for membership") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.sismember("set-1", "foo") should equal(true) + } + it("should return false for no membership") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.sismember("set-1", "fo") should equal(false) + } + it("should return false if key does not exist") { + r.sismember("set-1", "fo") should equal(false) + } + } + + describe("sinter") { + it("should return intersection") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + + r.sadd("set-2", "foo").get should equal(1) + r.sadd("set-2", "bat").get should equal(1) + r.sadd("set-2", "baz").get should equal(1) + + r.sadd("set-3", "for").get should equal(1) + r.sadd("set-3", "bat").get should equal(1) + r.sadd("set-3", "bay").get should equal(1) + + r.sinter("set-1", "set-2").get should equal(Set(Some("foo"), Some("baz"))) + r.sinter("set-1", "set-3").get should equal(Set.empty) + } + it("should return empty set for non-existing key") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.sinter("set-1", "set-4") should equal(Some(Set())) + } + } + + describe("sinterstore") { + it("should store intersection") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + + r.sadd("set-2", "foo").get should equal(1) + r.sadd("set-2", "bat").get should equal(1) + r.sadd("set-2", "baz").get should equal(1) + + r.sadd("set-3", "for").get should equal(1) + r.sadd("set-3", "bat").get should equal(1) + r.sadd("set-3", "bay").get should equal(1) + + r.sinterstore("set-r", "set-1", "set-2").get should equal(2) + r.scard("set-r").get should equal(2) + r.sinterstore("set-s", "set-1", "set-3").get should equal(0) + r.scard("set-s").get should equal(0) + } + it("should return empty set for non-existing key") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.sinterstore("set-r", "set-1", "set-4").get should equal(0) + r.scard("set-r").get should equal(0) + } + } + + describe("sunion") { + it("should return union") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + + r.sadd("set-2", "foo").get should equal(1) + r.sadd("set-2", "bat").get should equal(1) + r.sadd("set-2", "baz").get should equal(1) + + r.sadd("set-3", "for").get should equal(1) + r.sadd("set-3", "bat").get should equal(1) + r.sadd("set-3", "bay").get should equal(1) + + r.sunion("set-1", "set-2").get should equal(Set(Some("foo"), Some("bar"), Some("baz"), Some("bat"))) + r.sunion("set-1", "set-3").get should equal(Set(Some("foo"), Some("bar"), Some("baz"), Some("for"), Some("bat"), Some("bay"))) + } + it("should return empty set for non-existing key") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.sunion("set-1", "set-2").get should equal(Set(Some("foo"), Some("bar"), Some("baz"))) + } + } + + describe("sunionstore") { + it("should store union") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + + r.sadd("set-2", "foo").get should equal(1) + r.sadd("set-2", "bat").get should equal(1) + r.sadd("set-2", "baz").get should equal(1) + + r.sadd("set-3", "for").get should equal(1) + r.sadd("set-3", "bat").get should equal(1) + r.sadd("set-3", "bay").get should equal(1) + + r.sunionstore("set-r", "set-1", "set-2").get should equal(4) + r.scard("set-r").get should equal(4) + r.sunionstore("set-s", "set-1", "set-3").get should equal(6) + r.scard("set-s").get should equal(6) + } + it("should treat non-existing keys as empty sets") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.sunionstore("set-r", "set-1", "set-4").get should equal(3) + r.scard("set-r").get should equal(3) + } + } + + describe("sdiff") { + it("should return diff") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + + r.sadd("set-2", "foo").get should equal(1) + r.sadd("set-2", "bat").get should equal(1) + r.sadd("set-2", "baz").get should equal(1) + + r.sadd("set-3", "for").get should equal(1) + r.sadd("set-3", "bat").get should equal(1) + r.sadd("set-3", "bay").get should equal(1) + + r.sdiff("set-1", "set-2", "set-3").get should equal(Set(Some("bar"))) + } + it("should treat non-existing keys as empty sets") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.sdiff("set-1", "set-2").get should equal(Set(Some("foo"), Some("bar"), Some("baz"))) + } + } + + describe("smembers") { + it("should return members of a set") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.smembers("set-1").get should equal(Set(Some("foo"), Some("bar"), Some("baz"))) + } + it("should return None for an empty set") { + r.smembers("set-1") should equal(Some(Set())) + } + } + + describe("srandmember") { + it("should return a random member") { + r.sadd("set-1", "foo").get should equal(1) + r.sadd("set-1", "bar").get should equal(1) + r.sadd("set-1", "baz").get should equal(1) + r.srandmember("set-1").get should (equal("foo") or equal("bar") or equal("baz")) + } + it("should return None for a non-existing key") { + r.srandmember("set-1") should equal(None) + } + } + + describe("srandmember with count") { + it("should return a list of random members") { + r.sadd("set-1", "one").get should equal(1) + r.sadd("set-1", "two").get should equal(1) + r.sadd("set-1", "three").get should equal(1) + r.sadd("set-1", "four").get should equal(1) + r.sadd("set-1", "five").get should equal(1) + r.sadd("set-1", "six").get should equal(1) + r.sadd("set-1", "seven").get should equal(1) + r.sadd("set-1", "eight").get should equal(1) + + r.srandmember("set-1", 2).get.size should equal(2) + + // returned elements should be unique + val l = r.srandmember("set-1", 4).get + l.size should equal(l.toSet.size) + + // returned elements may have duplicates + r.srandmember("set-1", -4).get.toSet.size should (be <= (4)) + + // if supplied count > size, then whole set is returned + r.srandmember("set-1", 24).get.toSet.size should equal(8) + } + } +}