Skip to content

Commit

Permalink
Python: Added a pickle transport to support object serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathan Farrington committed Jun 15, 2014
1 parent 49fe999 commit 7f50f34
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 13 deletions.
6 changes: 6 additions & 0 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ calcaultor.div(2)
assert calculator.val() == 4
```

The Python client also takes the additional keyword parameters `timeout`
and `transport`. If `transport='pickle'` then it is possible to send Python
objects as function arguments, but only if the server is also implemented in
Python, and only if the server has access to the same libraries in order to
unpickle the objects.

### server.py

```python
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.6rc3
0.3.6rc4
9 changes: 7 additions & 2 deletions python/examples/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python

import argparse
import logging
import traceback
import sys
Expand All @@ -16,7 +17,11 @@
# Direct all RedisPRC logging messages to stderr.
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)


parser = argparse.ArgumentParser(description='Example calculator server')
parser.add_argument('--transport', choices=('json', 'pickle'), default='json',
help='data encoding used for transport')
args = parser.parse_args()

def do_calculations(calculator):
calculator.clr()
calculator.add(5)
Expand All @@ -38,6 +43,6 @@ def do_calculations(calculator):
# 2. Remote object, should act like local object
redis_server = redis.Redis()
message_queue = 'calc'
calculator = redisrpc.Client(redis_server, message_queue, timeout=1)
calculator = redisrpc.Client(redis_server, message_queue, timeout=1, transport=args.transport)
do_calculations(calculator)
print('success!')
2 changes: 1 addition & 1 deletion python/examples/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python

import argparse
import logging
import sys

Expand All @@ -15,7 +16,6 @@
# Direct all RedisPRC logging messages to stderr.
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)


redis_server = redis.Redis()
message_queue = 'calc'
local_object = calc.Calculator()
Expand Down
61 changes: 52 additions & 9 deletions python/redisrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import json
import logging
import pickle
import random
import string
import sys
Expand Down Expand Up @@ -90,29 +91,72 @@ def as_python_code(self):
return '%s(%s)' % (self['name'], params)


def decode_message(message):
"""Returns a (transport, decoded_message) pair."""
# Try JSON, then try Python pickle, then fail.
try:
return JSONTransport.create(), json.loads(message)
except:
pass
return PickleTransport.create(), pickle.loads(message)


class JSONTransport(object):
"""Cross platform transport."""
_singleton = None
@classmethod
def create(cls):
if cls._singleton is None:
cls._singleton = JSONTransport()
return cls._singleton
def dumps(self, obj):
return json.dumps(obj)
def loads(self, obj):
return json.loads(obj)


class PickleTransport(object):
"""Only works with Python clients and servers."""
_singleton = None
@classmethod
def create(cls):
if cls._singleton is None:
cls._singleton = PickleTransport()
return cls._singleton
def dumps(self, obj):
return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
def loads(self, obj):
return pickle.loads(obj)

class Client(object):
"""Calls remote functions using Redis as a message queue."""

def __init__(self, redis_server, message_queue, timeout=0):
def __init__(self, redis_server, message_queue, timeout=0, transport='json'):
self.redis_server = redis_server
self.message_queue = message_queue
self.timeout = timeout
if transport == 'json':
self.transport = JSONTransport()
elif transport == 'pickle':
self.transport = PickleTransport()
else:
raise Exception('invalid transport {0}'.format(transport))

def call(self, method_name, *args, **kwargs):
function_call = FunctionCall(method_name, args, kwargs)
response_queue = self.message_queue + ':rpc:' + random_string()
rpc_request = dict(function_call=function_call, response_queue=response_queue)
message = json.dumps(rpc_request)
message = self.transport.dumps(rpc_request)
logging.debug('RPC Request: %s' % message)
self.redis_server.rpush(self.message_queue, message)
result = self.redis_server.blpop(response_queue, self.timeout)
if result is None: raise TimeoutException()
if result is None:
raise TimeoutException()
message_queue, message = result
message_queue = message_queue.decode()
message = message.decode()
assert message_queue == response_queue
logging.debug('RPC Response: %s' % message)
rpc_response = json.loads(message)
rpc_response = self.transport.loads(message)
exception = rpc_response.get('exception')
if exception is not None:
raise RemoteException(exception)
Expand All @@ -132,17 +176,16 @@ def __init__(self, redis_server, message_queue, local_object):
self.redis_server = redis_server
self.message_queue = message_queue
self.local_object = local_object

def run(self):
# Flush the message queue.
self.redis_server.delete(self.message_queue)
while True:
message_queue, message = self.redis_server.blpop(self.message_queue)
message_queue = message_queue.decode()
message = message.decode()
assert message_queue == self.message_queue
logging.debug('RPC Request: %s' % message)
rpc_request = json.loads(message)
transport, rpc_request = decode_message(message)
response_queue = rpc_request['response_queue']
function_call = FunctionCall.from_dict(rpc_request['function_call'])
code = 'self.return_value = self.local_object.' + function_call.as_python_code()
Expand All @@ -152,7 +195,7 @@ def run(self):
except:
(type, value, traceback) = sys.exc_info()
rpc_response = dict(exception=repr(value))
message = json.dumps(rpc_response)
message = transport.dumps(rpc_response)
logging.debug('RPC Response: %s' % message)
self.redis_server.rpush(response_queue, message)

Expand Down
41 changes: 41 additions & 0 deletions python/tests/test_redisrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,39 @@

TIMEOUT=5

class Bar(object):
"""A class used only for serialization."""
def __init__(self):
self.a = 1
self.b = 2
self.c = 3
def __eq__(self, rhs):
try:
return self.a == rhs.a and self.b == rhs.b and self.c == rhs.c
except:
return False

class Foo(object):
"""A class used for redisrpc testing."""
MESSAGE_QUEUE = 'foo'
def return_none(self):
return None
def return_true(self):
return True
def return_false(self):
return False
def return_int(self):
return 1
def return_float(self):
return 3.14159
def return_string(self):
return "STRING"
def return_list(self):
return [1, 2, 3]
def return_dict(self):
return dict(a=1, b=2)
def return_obj(self):
return Bar()

def start_thread(redisdb):
"""Run the redisrpc server in a thread."""
Expand All @@ -42,11 +62,26 @@ def test_redisrpc_none(redisdb):
client = redisrpc.Client(redisdb, Foo.MESSAGE_QUEUE, timeout=TIMEOUT)
assert client.return_none() is None

def test_redisrpc_true(redisdb):
server = start_thread(redisdb)
client = redisrpc.Client(redisdb, Foo.MESSAGE_QUEUE, timeout=TIMEOUT)
assert client.return_true() is True

def test_redisrpc_false(redisdb):
server = start_thread(redisdb)
client = redisrpc.Client(redisdb, Foo.MESSAGE_QUEUE, timeout=TIMEOUT)
assert client.return_false() is False

def test_redisrpc_int(redisdb):
server = start_thread(redisdb)
client = redisrpc.Client(redisdb, Foo.MESSAGE_QUEUE, timeout=TIMEOUT)
assert client.return_int() == 1

def test_redisrpc_float(redisdb):
server = start_thread(redisdb)
client = redisrpc.Client(redisdb, Foo.MESSAGE_QUEUE, timeout=TIMEOUT)
assert abs(client.return_float() - 3.14159) < 1e-15

def test_redisrpc_string(redisdb):
server = start_thread(redisdb)
client = redisrpc.Client(redisdb, Foo.MESSAGE_QUEUE, timeout=TIMEOUT)
Expand All @@ -61,3 +96,9 @@ def test_redisrpc_dict(redisdb):
server = start_thread(redisdb)
client = redisrpc.Client(redisdb, Foo.MESSAGE_QUEUE, timeout=TIMEOUT)
assert client.return_dict() == dict(a=1, b=2)

def test_redisrpc_obj(redisdb):
"""Objects are not JSON serializable by default."""
server = start_thread(redisdb)
client = redisrpc.Client(redisdb, Foo.MESSAGE_QUEUE, timeout=TIMEOUT, transport='pickle')
assert client.return_obj() == Bar()

0 comments on commit 7f50f34

Please sign in to comment.