From 5ce888ccb31a9c9eddcfca1c744d6e06a02ead93 Mon Sep 17 00:00:00 2001 From: Daniel Pope Date: Sat, 15 Jul 2017 17:50:14 +0200 Subject: [PATCH] Extend pencode to handle most basic Python types --- chopsticks/bubble.py | 31 ++++--- chopsticks/pencode.py | 206 +++++++++++++++++++++++++++--------------- tests/test_pencode.py | 125 +++++++++++++++++++++++++ 3 files changed, 277 insertions(+), 85 deletions(-) create mode 100644 tests/test_pencode.py diff --git a/chopsticks/bubble.py b/chopsticks/bubble.py index f018f60..614114c 100644 --- a/chopsticks/bubble.py +++ b/chopsticks/bubble.py @@ -346,10 +346,6 @@ def handle_start(req_id, host, path, depthlimit): MSG_PACK = 2 -# The source code from chopsticks.pencode will be substituted here -{{ PENCODE }} - - def send_msg(op, req_id, data): """Send a message to the orchestration host. @@ -408,6 +404,10 @@ def reader(): if msg is None: break req_id, op, params = msg + if PY2: + params = dict((str(k), v) for k, v in params.iteritems()) + else: + params = dict((force_str(k), v) for k, v in params.items()) HANDLERS[op](req_id, **params) finally: outqueue.put(done) @@ -422,11 +422,20 @@ def writer(): outpipe.write(msg) -for func in (reader, writer): - threading.Thread(target=func).start() -while True: - task = tasks.get() - if task is done: - break - do_call(*task) +def run(): + for func in (reader, writer): + threading.Thread(target=func).start() + + while True: + task = tasks.get() + if task is done: + break + do_call(*task) + + +# The source code from chopsticks.pencode will be substituted here +# We do this at the end to minimise changes to line numbers +{{ PENCODE }} + +run() diff --git a/chopsticks/pencode.py b/chopsticks/pencode.py index fbad169..dc63115 100644 --- a/chopsticks/pencode.py +++ b/chopsticks/pencode.py @@ -5,6 +5,7 @@ than encoding bytes-as-base64-in-JSON. """ +import sys import struct import codecs @@ -13,60 +14,87 @@ utf8_decode = codecs.getdecoder('utf8') -PY3 = bool(1 / 2) +PY3 = sys.version_info >= (3,) PY2 = not PY3 if PY3: unicode = str + range_ = range + long = int else: bytes = str + range_ = xrange + + +def pencode(obj): + """Encode the given Python primitive structure, returning a byte string.""" + p = Pencoder() + p._pencode(obj) + return p.getvalue() def bsz(seq): - """Encode the length of a sequence as a big-endian 4-byte unsigned int.""" + """Encode the length of a sequence as big-endian 4-byte uint.""" return SZ.pack(len(seq)) -def pencode(obj): - """Encode the given Python primitive structure, returning a byte string.""" - out = [] - _pencode(obj, out) - return b''.join(out) - - -def _pencode(obj, out): - """Inner function for encoding of structures.""" - if isinstance(obj, bytes): - out.extend([b'b', bsz(obj), obj]) - elif isinstance(obj, unicode): - bs = obj.encode('utf8') - out.extend([b's', bsz(bs), bs]) - elif isinstance(obj, bool): - out.extend([b'1', b't' if obj else b'f']) - elif isinstance(obj, int): - bs = str(int(obj)).encode('ascii') - out.extend([b'i', bsz(bs), bs]) - elif isinstance(obj, (tuple, list)): - code = b'l' if isinstance(obj, list) else b't' - out.extend([code, bsz(obj)]) - for item in obj: - _pencode(item, out) - elif isinstance(obj, dict): - out.extend([b'd', bsz(obj)]) - for k in obj: - if isinstance(k, str): - if PY2: - kbs = str(k) - else: - kbs = str(k).encode('utf8') - out.extend([b'k', bsz(kbs), kbs]) - else: - _pencode(k, out) - _pencode(obj[k], out) - elif obj is None: - out.append(b'n') - else: - raise ValueError('Unserialisable type %s' % type(obj)) +SEQTYPE_CODES = { + set: b'q', + frozenset: b'Q', + list: b'l', + tuple: b't', +} +CODE_SEQTYPES = dict((v, k) for k, v in SEQTYPE_CODES.items()) + + +class Pencoder(object): + def __init__(self): + self.out = [] + self.objs = 0 + self.backrefs = {} + + def getvalue(self): + return b''.join(self.out) + + def _pencode(self, obj): + """Inner function for encoding of structures.""" + out = self.out + objid = id(obj) + if objid in self.backrefs: + out.extend([b'R', SZ.pack(self.backrefs[objid])]) + return + else: + self.backrefs[objid] = len(self.backrefs) + + otype = type(obj) + + if isinstance(obj, bytes): + out.extend([b'b', bsz(obj), obj]) + elif isinstance(obj, unicode): + bs = obj.encode('utf8') + out.extend([b's', bsz(bs), bs]) + elif isinstance(obj, bool): + out.extend([b'1', b't' if obj else b'f']) + elif isinstance(obj, (int, long)): + bs = str(int(obj)).encode('ascii') + out.extend([b'i', bsz(bs), bs]) + elif isinstance(obj, float): + bs = str(float(obj)).encode('ascii') + out.extend([b'f', bsz(bs), bs]) + elif otype in SEQTYPE_CODES: + code = SEQTYPE_CODES[otype] + out.extend([code, bsz(obj)]) + for item in obj: + self._pencode(item) + elif isinstance(obj, dict): + out.extend([b'd', bsz(obj)]) + for k in obj: + self._pencode(k) + self._pencode(obj[k]) + elif obj is None: + out.append(b'n') + else: + raise ValueError('Unserialisable type %s' % type(obj)) class obuf(object): @@ -88,35 +116,65 @@ def read_bytes(self, n): def pdecode(buf): """Decode a pencoded byte string to a structure.""" - return _decode(obuf(buf)) - - -def _decode(obuf): - code = obuf.read_bytes(1) - if code == b'k': - code = b'b' if PY2 else b's' - - if code == b'n': - return None - elif code == b'b': - sz = obuf.read_size() - return obuf.read_bytes(sz) - elif code == b's': - sz = obuf.read_size() - return utf8_decode(obuf.read_bytes(sz))[0] - elif code == b'1': - return obuf.read_bytes(1) == b't' - elif code == b'i': - sz = obuf.read_size() - return int(obuf.read_bytes(sz)) - elif code == b'l': - sz = obuf.read_size() - return [_decode(obuf) for _ in range(sz)] - elif code == b't': - sz = obuf.read_size() - return tuple(_decode(obuf) for _ in range(sz)) - elif code == b'd': - sz = obuf.read_size() - return dict((_decode(obuf), _decode(obuf)) for _ in range(sz)) - else: - raise ValueError('Unknown pack opcode %r' % code) + return PDecoder().decode(buf) + + +class PDecoder(object): + def __init__(self): + self.br_count = 0 + self.backrefs = {} + + def decode(self, buf): + return self._decode(obuf(buf)) + + def _decode(self, obuf): + br_id = self.br_count + self.br_count += 1 + + code = obuf.read_bytes(1) + if code == b'n': + obj = None + elif code == b'b': + sz = obuf.read_size() + obj = obuf.read_bytes(sz) + elif code == b's': + sz = obuf.read_size() + obj = utf8_decode(obuf.read_bytes(sz))[0] + elif code == b'1': + obj = obuf.read_bytes(1) == b't' + elif code == b'i': + sz = obuf.read_size() + obj = int(obuf.read_bytes(sz)) + elif code == b'f': + sz = obuf.read_size() + obj = float(obuf.read_bytes(sz)) + elif code == b'l': + sz = obuf.read_size() + obj = [] + self.backrefs[br_id] = obj + obj.extend(self._decode(obuf) for _ in range_(sz)) + elif code == b'q': + sz = obuf.read_size() + obj = set() + self.backrefs[br_id] = obj + obj.update(self._decode(obuf) for _ in range_(sz)) + elif code in (b't', b'Q'): + cls = tuple if code == b't' else frozenset + sz = obuf.read_size() + obj = cls(self._decode(obuf) for _ in range_(sz)) + elif code == b'd': + sz = obuf.read_size() + obj = {} + self.backrefs[br_id] = obj + for _ in range_(sz): + key = self._decode(obuf) + value = self._decode(obuf) + obj[key] = value + elif code == b'R': + ref_id = obuf.read_size() + obj = self.backrefs[ref_id] + else: + raise ValueError('Unknown pack opcode %r' % code) + + self.backrefs[br_id] = obj + return obj diff --git a/tests/test_pencode.py b/tests/test_pencode.py new file mode 100644 index 0000000..72235de --- /dev/null +++ b/tests/test_pencode.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +"""Tests for Python-friendly binary encoding.""" +import pytest +from chopsticks.pencode import pencode, pdecode + +bytes = type(b'') + + +def assert_roundtrip(obj): + """Assert that we can successfully round-trip the given object.""" + buf = pencode(obj) + assert isinstance(buf, bytes) + obj2 = pdecode(buf) + + try: + assert obj == obj2 + except RuntimeError as e: + if 'maximum recursion depth exceeded' not in e.args[0]: + raise + # If we hit a RecursionError, we correctly decoded a recursive + # structure, so test passes :) + except RecursionError: + pass + return obj2 + + +def test_roundtrip_unicode(): + """We can round-trip a unicode string.""" + assert_roundtrip(u'I ❤️ emoji') + + +def test_roundtrip_list(): + """We can round-trip a list.""" + assert_roundtrip([1, 2, 3]) + + +def test_roundtrip_self_referential(): + """We can round-trip a self-referential structure.""" + a = [] + a.append(a) + assert_roundtrip(a) + + +def test_roundtrip_backref(): + """References to the same object are preserved.""" + foo = 'foo' + obj = [foo, foo] + buf = pencode(obj) + assert isinstance(buf, bytes) + a, b = pdecode(buf) + assert a is b + + +def test_roundtrip_set(): + """We can round-trip a set.""" + assert_roundtrip({1, 2, 3}) + + +def test_roundtrip_tuple(): + """We can round-trip a tuple of bytes.""" + assert_roundtrip((b'a', b'b', b'c')) + + +def test_roundtrip_frozenset(): + """We can round-trip a frozenset.""" + assert_roundtrip(frozenset([1, 2, 3])) + + +def test_roundtrip_float(): + """We can round-trip a float.""" + assert_roundtrip(1.1) + + +def test_roundtrip_float_inf(): + """We can round-trip inf.""" + assert_roundtrip(float('inf')) + + +def test_roundtrip_long(): + """We can round trip what, in Python 2, would be a long.""" + assert_roundtrip(10121071034790721094712093712037123) + +def test_roundtrip_float_nan(): + """We can round-trip nan.""" + import math + res = pdecode(pencode(float('nan'))) + assert math.isnan(res) + + +def test_roundtrip_dict(): + """We can round-trip a dict, keyed by frozenset.""" + assert_roundtrip({frozenset([1, 2, 3]): 'abc'}) + + +def test_roundtrip_none(): + """We can round-trip None.""" + assert_roundtrip(None) + + +def test_roundtrip_bool(): + """We can round-trip booleans and preserve their types.""" + res = assert_roundtrip((True, False)) + for r in res: + assert isinstance(r, bool) + + +def test_roundtrip_kdict(): + """We handle string-keyed dicts.""" + assert_roundtrip({'abcd': 'efgh'}) + + +def test_unserialisable(): + """An exception is raised if a type is not serialisable.""" + with pytest.raises(ValueError): + pencode(object()) + + +def test_roundtrip_start(): + """A typical start message is round-tripped.""" + host = u'unittest' + assert_roundtrip({ + 'host': host, + 'path': [host], + 'depthlimit': 2 + })