Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge pull request #35 from jeethu/scripting

Redis 2.6 Lua scripting support
  • Loading branch information...
commit a63ddbabb9ae24b155f9331e478f6e4383c7d74a 2 parents 20b55ff + f13518e
Alexandre Fiori authored
Showing with 326 additions and 3 deletions.
  1. +204 −0 tests/test_scripting.py
  2. +122 −3 txredisapi.py
204 tests/test_scripting.py
View
@@ -0,0 +1,204 @@
+# coding: utf-8
+# Copyright 2009 Alexandre Fiori
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import hashlib
+
+import txredisapi as redis
+
+from twisted.internet import defer
+from twisted.trial import unittest
+from twisted.internet import reactor
+from twisted.python import failure
+
+redis_host = "localhost"
+redis_port = 6379
+
+
+class TestScripting(unittest.TestCase):
+ _SCRIPT = "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}" # From redis example
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.db = yield redis.Connection(redis_host, redis_port,
+ reconnect=False)
+ self.db1 = None
+ self.redis_2_6 = yield self.is_redis_2_6()
+ d = yield self.db.info("server")
+ self.redis_version = d[u'redis_version']
+ yield self.db.script_flush()
+
+ def _skipCheck(self):
+ if not self.redis_2_6:
+ skipMsg = "Redis version < 2.6 (found version: %s)"
+ raise unittest.SkipTest(skipMsg % self.redis_version)
+
+ @defer.inlineCallbacks
+ def tearDown(self):
+ yield self.db.disconnect()
+ if self.db1 is not None:
+ yield self.db1.disconnect()
+
+ @defer.inlineCallbacks
+ def is_redis_2_6(self):
+ """
+ Returns true if the Redis version >= 2.6
+ """
+ d = yield self.db.info("server")
+ if u'redis_version' not in d:
+ defer.returnValue(False)
+ ver = d[u'redis_version']
+ ver_list = [int(x) for x in ver.split(u'.')]
+ if len(ver_list) < 2:
+ defer.returnValue(False)
+ if ver_list[0] > 2:
+ defer.returnValue(True)
+ elif ver_list[0] == 2 and ver_list[1] >= 6:
+ defer.returnValue(True)
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def test_eval(self):
+ self._skipCheck()
+ d = dict(key1="first", key2="second")
+ r = yield self.db.eval(self._SCRIPT, **d)
+ self._check_eval_result(d, r)
+ r = yield self.db.eval("return 10")
+ self.assertEqual(r, 10)
+ r = yield self.db.eval("return {1,2,3.3333,'foo',nil,'bar'}")
+ self.assertEqual(r, [1, 2, 3, "foo"])
+ # Test the case where the hash is in script_hashes,
+ # but redis doesn't have it
+ h = self._hash_script(self._SCRIPT)
+ yield self.db.script_flush()
+ self.db.script_hashes.add(h)
+ r = yield self.db.eval(self._SCRIPT, **d)
+ self._check_eval_result(d, r)
+
+ @defer.inlineCallbacks
+ def test_eval_error(self):
+ self._skipCheck()
+ try:
+ result = yield self.db.eval('return {err="My Error"}')
+ except redis.ResponseError:
+ pass
+ except:
+ raise self.failureException('%s raised instead of %s:\n %s'
+ % (sys.exc_info()[0],
+ 'txredisapi.ResponseError',
+ failure.Failure().getTraceback()))
+ else:
+ raise self.failureException('%s not raised (%r returned)'
+ % ('txredisapi.ResponseError', result))
+
+ @defer.inlineCallbacks
+ def test_evalsha(self):
+ self._skipCheck()
+ r = yield self.db.eval(self._SCRIPT)
+ h = self._hash_script(self._SCRIPT)
+ r = yield self.db.evalsha(h)
+ self._check_eval_result({}, r)
+
+ @defer.inlineCallbacks
+ def test_evalsha_error(self):
+ self._skipCheck()
+ h = self._hash_script(self._SCRIPT)
+ try:
+ result = yield self.db.evalsha(h)
+ except redis.ScriptDoesNotExist:
+ pass
+ except:
+ raise self.failureException('%s raised instead of %s:\n %s'
+ % (sys.exc_info()[0],
+ 'txredisapi.ScriptDoesNotExist',
+ failure.Failure().getTraceback()))
+ else:
+ raise self.failureException('%s not raised (%r returned)'
+ % ('txredisapi.ResponseError', result))
+
+ @defer.inlineCallbacks
+ def test_script_load(self):
+ self._skipCheck()
+ h = self._hash_script(self._SCRIPT)
+ r = yield self.db.script_exists(h)
+ self.assertFalse(r)
+ r = yield self.db.script_load(self._SCRIPT)
+ self.assertEqual(r, h)
+ r = yield self.db.script_exists(h)
+ self.assertTrue(r)
+
+ @defer.inlineCallbacks
+ def test_script_exists(self):
+ self._skipCheck()
+ h = self._hash_script(self._SCRIPT)
+ script1 = "return 1"
+ h1 = self._hash_script(script1)
+ r = yield self.db.script_exists(h)
+ self.assertFalse(r)
+ r = yield self.db.script_exists(h, h1)
+ self.assertEqual(r, [False, False])
+ yield self.db.script_load(script1)
+ r = yield self.db.script_exists(h, h1)
+ self.assertEqual(r, [False, True])
+ yield self.db.script_load(self._SCRIPT)
+ r = yield self.db.script_exists(h, h1)
+ self.assertEqual(r, [True, True])
+
+ @defer.inlineCallbacks
+ def test_script_kill(self):
+ self._skipCheck()
+ try:
+ result = yield self.db.script_kill()
+ except redis.NoScriptRunning:
+ pass
+ except:
+ raise self.failureException('%s raised instead of %s:\n %s'
+ % (sys.exc_info()[0],
+ 'txredisapi.NoScriptRunning',
+ failure.Failure().getTraceback()))
+ else:
+ raise self.failureException('%s not raised (%r returned)'
+ % ('txredisapi.ResponseError', result))
+ # Run an infinite loop script from one connection
+ # and kill it from another.
+ inf_loop = "while 1 do end"
+ self.db1 = yield redis.Connection(redis_host, redis_port,
+ reconnect=False)
+ eval_deferred = self.db1.eval(inf_loop)
+ reactor.iterate()
+ r = yield self.db.script_kill()
+ self.assertEqual(r, 'OK')
+ try:
+ result = yield eval_deferred
+ except redis.ResponseError:
+ pass
+ except:
+ raise self.failureException('%s raised instead of %s:\n %s'
+ % (sys.exc_info()[0],
+ 'txredisapi.ResponseError',
+ failure.Failure().getTraceback()))
+ else:
+ raise self.failureException('%s not raised (%r returned)'
+ % ('txredisapi.ResponseError', result))
+
+ def _check_eval_result(self, d, r):
+ n = len(r)
+ self.assertEqual(n, len(d) * 2)
+ new_d = dict(zip(r[:n/2], r[n/2:]))
+ for k, v in d.items():
+ assert new_d[k] == v
+
+ def _hash_script(self, script):
+ return hashlib.sha1(script).hexdigest()
125 txredisapi.py
View
@@ -30,6 +30,7 @@
import warnings
import zlib
import string
+import hashlib
from twisted.internet import defer
from twisted.internet import protocol
@@ -38,6 +39,7 @@
from twisted.protocols import basic
from twisted.protocols import policies
from twisted.python import log
+from twisted.python.failure import Failure
class RedisError(Exception):
@@ -52,6 +54,14 @@ class ResponseError(RedisError):
pass
+class ScriptDoesNotExist(ResponseError):
+ pass
+
+
+class NoScriptRunning(ResponseError):
+ pass
+
+
class InvalidResponse(RedisError):
pass
@@ -199,6 +209,8 @@ def __init__(self, charset="utf-8", errors="strict"):
self.inTransaction = False
self.unwatch_cc = lambda: ()
+ self.script_hashes = set()
+
@defer.inlineCallbacks
def connectionMade(self):
if self.factory.dbid is not None:
@@ -222,6 +234,7 @@ def connectionMade(self):
def connectionLost(self, why):
self.connected = 0
+ self.script_hashes.clear()
self.factory.delConnection(self)
LineReceiver.connectionLost(self, why)
while self.replyQueue.waiting:
@@ -291,7 +304,7 @@ def lineReceived(self, line):
reply = int(data)
except ValueError:
reply = InvalidResponse(
- "Cannot convert data '%s' to integer" % data)
+ "Cannot convert data '%s' to integer" % data)
if self.multi_bulk.pending:
self.handleMultiBulkElement(reply)
@@ -1295,15 +1308,117 @@ def bgrewriteaof(self):
"""
return self.execute_command("BGREWRITEAOF")
+ def _process_info(self, r):
+ keypairs = [x for x in r.split('\r\n') if
+ u':' in x and not x.startswith(u'#')]
+ d = {}
+ for kv in keypairs:
+ k, v = kv.split(u':')
+ d[k] = v
+ return d
+
# Remote server control commands
- def info(self):
+ def info(self, type=None):
"""
Provide information and statistics about the server
"""
- return self.execute_command("INFO")
+ if type is None:
+ return self.execute_command("INFO")
+ else:
+ r = self.execute_command("INFO", type)
+ return r.addCallback(self._process_info)
# slaveof is missing
+ # Redis 2.6 scripting commands
+ def _eval(self, script, script_hash, **kwargs):
+ n = len(kwargs)
+ if n == 0:
+ args = ()
+ else:
+ keys, values = zip(*kwargs.items())
+ args = keys + values
+ r = self.execute_command("EVAL", script, n, *args)
+ if script_hash in self.script_hashes:
+ return r
+ return r.addCallback(self._eval_success, script_hash)
+
+ def _eval_success(self, r, script_hash):
+ self.script_hashes.add(script_hash)
+ return r
+
+ def _evalsha_failed(self, err, script, script_hash, **kwargs):
+ if err.check(ScriptDoesNotExist):
+ return self._eval(script, script_hash, **kwargs)
+ return err
+
+ def eval(self, script, **kwargs):
+ h = hashlib.sha1(script).hexdigest()
+ if h in self.script_hashes:
+ return self.evalsha(h, **kwargs).addErrback(
+ self._evalsha_failed, script, h, **kwargs)
+ return self._eval(script, h, **kwargs)
+
+ def _evalsha_errback(self, err, script_hash):
+ if err.check(ResponseError):
+ if err.value.args[0].startswith(u'NOSCRIPT'):
+ if script_hash in self.script_hashes:
+ self.script_hashes.remove(script_hash)
+ raise ScriptDoesNotExist("No script matching hash: %s found" %
+ script_hash)
+ return err
+
+ def evalsha(self, sha1_hash, **kwargs):
+ n = len(kwargs)
+ if n == 0:
+ args = ()
+ else:
+ keys, values = zip(*kwargs.items())
+ args = keys + values
+ r = self.execute_command("EVALSHA",
+ sha1_hash, n,
+ *args).addErrback(self._evalsha_errback,
+ sha1_hash)
+ if sha1_hash not in self.script_hashes:
+ r.addCallback(self._eval_success, sha1_hash)
+ return r
+
+ def _script_exists_success(self, r):
+ l = [bool(x) for x in r]
+ if len(l) == 1:
+ return l[0]
+ else:
+ return l
+
+ def script_exists(self, *hashes):
+ return self.execute_command("SCRIPT", "EXISTS",
+ post_proc=self._script_exists_success,
+ *hashes)
+
+ def _script_flush_success(self, r):
+ self.script_hashes.clear()
+ return r
+
+ def script_flush(self):
+ return self.execute_command("SCRIPT", "FLUSH").addCallback(
+ self._script_flush_success)
+
+ def _handle_script_kill(self, r):
+ if isinstance(r, Failure):
+ if r.check(ResponseError):
+ if r.value.args[0].startswith(u'NOTBUSY'):
+ raise NoScriptRunning("No script running")
+ else:
+ pass
+ return r
+
+ def script_kill(self):
+ return self.execute_command("SCRIPT",
+ "KILL").addBoth(self._handle_script_kill)
+
+ def script_load(self, script):
+ return self.execute_command("SCRIPT", "LOAD", script)
+
class MonitorProtocol(RedisProtocol):
"""
@@ -1360,6 +1475,7 @@ def punsubscribe(self, patterns):
patterns = [patterns]
return self.execute_command("PUNSUBSCRIBE", *patterns)
+
class ConnectionHandler(object):
def __init__(self, factory):
self._factory = factory
@@ -1680,6 +1796,7 @@ def getConnection(self):
raise RedisError("In transaction")
+
class SubscriberFactory(RedisFactory):
protocol = SubscriberProtocol
@@ -1687,6 +1804,7 @@ def __init__(self, isLazy=False, handler=ConnectionHandler):
RedisFactory.__init__(self, None, None, 1, isLazy=isLazy,
handler=handler)
+
class MonitorFactory(RedisFactory):
protocol = MonitorProtocol
@@ -1694,6 +1812,7 @@ def __init__(self, isLazy=False, handler=ConnectionHandler):
RedisFactory.__init__(self, None, None, 1, isLazy=isLazy,
handler=handler)
+
def makeConnection(host, port, dbid, poolsize, reconnect, isLazy):
uuid = "%s:%s" % (host, port)
factory = RedisFactory(uuid, dbid, poolsize, isLazy, ConnectionHandler)
Please sign in to comment.
Something went wrong with that request. Please try again.