Skip to content

Commit

Permalink
Use tuples in msgpack (#2000)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Jun 12, 2018
1 parent 668b7fd commit 52749ff
Show file tree
Hide file tree
Showing 12 changed files with 71 additions and 42 deletions.
19 changes: 10 additions & 9 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,8 +960,8 @@ def _handle_report(self):
continue
else:
break
if not isinstance(msgs, list):
msgs = [msgs]
if not isinstance(msgs, (list, tuple)):
msgs = (msgs,)

breakout = False
for msg in msgs:
Expand Down Expand Up @@ -2665,7 +2665,7 @@ def ncores(self, workers=None, **kwargs):
if (isinstance(workers, tuple)
and all(isinstance(i, (str, tuple)) for i in workers)):
workers = list(workers)
if workers is not None and not isinstance(workers, (list, set)):
if workers is not None and not isinstance(workers, (tuple, list, set)):
workers = [workers]
return self.sync(self.scheduler.ncores, workers=workers, **kwargs)

Expand Down Expand Up @@ -2731,7 +2731,7 @@ def has_what(self, workers=None, **kwargs):
if (isinstance(workers, tuple)
and all(isinstance(i, (str, tuple)) for i in workers)):
workers = list(workers)
if workers is not None and not isinstance(workers, (list, set)):
if workers is not None and not isinstance(workers, (tuple, list, set)):
workers = [workers]
return self.sync(self.scheduler.has_what, workers=workers, **kwargs)

Expand Down Expand Up @@ -2760,7 +2760,7 @@ def processing(self, workers=None):
if (isinstance(workers, tuple)
and all(isinstance(i, (str, tuple)) for i in workers)):
workers = list(workers)
if workers is not None and not isinstance(workers, (list, set)):
if workers is not None and not isinstance(workers, (tuple, list, set)):
workers = [workers]
return self.sync(self.scheduler.processing, workers=workers)

Expand Down Expand Up @@ -2910,8 +2910,8 @@ def get_metadata(self, keys, default=no_default):
--------
Client.set_metadata
"""
if not isinstance(keys, list):
keys = [keys]
if not isinstance(keys, (list, tuple)):
keys = (keys,)
return self.sync(self.scheduler.get_metadata, keys=keys,
default=default)

Expand Down Expand Up @@ -3012,7 +3012,7 @@ def set_metadata(self, key, value):
get_metadata
"""
if not isinstance(key, list):
key = [key]
key = (key,)
return self.sync(self.scheduler.set_metadata, keys=key, value=value)

def get_versions(self, check=False):
Expand Down Expand Up @@ -3040,7 +3040,8 @@ def get_versions(self, check=False):
if check:
# we care about the required & optional packages matching
def to_packages(d):
return dict(sum(d['packages'].values(), []))
L = list(d['packages'].values())
return dict(sum(L, type(L[0])()))
client_versions = to_packages(result['client'])
versions = [('scheduler', to_packages(result['scheduler']))]
versions.extend((w, to_packages(d))
Expand Down
4 changes: 2 additions & 2 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ def handle_stream(self, comm, extra=None, every_cycle=[]):
try:
while not closed:
msgs = yield comm.read()
if not isinstance(msgs, list):
msgs = [msgs]
if not isinstance(msgs, (tuple, list)):
msgs = (msgs,)

if not comm.closed():
for msg in msgs:
Expand Down
2 changes: 1 addition & 1 deletion distributed/diagnostics/tests/test_eventstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_eventstream_remote(c, s, a, b):
total = []
while len(total) < 10:
msgs = yield comm.read()
assert isinstance(msgs, list)
assert isinstance(msgs, tuple)
total.extend(msgs)
assert time() < start + 5

Expand Down
23 changes: 18 additions & 5 deletions distributed/protocol/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import print_function, division, absolute_import

import logging
import operator

import msgpack

try:
from cytoolz import get_in
from cytoolz import reduce
except ImportError:
from toolz import get_in
from toolz import reduce

from .compression import compressions, maybe_compress, decompress
from .serialize import (serialize, deserialize, Serialize, Serialized,
Expand Down Expand Up @@ -122,7 +123,19 @@ def loads(frames, deserialize=True, deserializers=None):
else:
value = Serialized(head, fs)

get_in(key[:-1], msg)[key[-1]] = value
def put_in(keys, coll, val):
"""Inverse of get_in, but does type promotion in the case of lists"""
if keys:
holder = reduce(operator.getitem, keys[:-1], coll)
if isinstance(holder, tuple):
holder = list(holder)
coll = put_in(keys[:-1], coll, holder)
holder[keys[-1]] = val
else:
coll = val
return coll

msg = put_in(key, msg, value)

return msg
except Exception:
Expand Down Expand Up @@ -160,7 +173,7 @@ def loads_msgpack(header, payload):
dumps_msgpack
"""
if header:
header = msgpack.loads(header, encoding='utf8')
header = msgpack.loads(header, encoding='utf8', use_list=False)
else:
header = {}

Expand All @@ -172,4 +185,4 @@ def loads_msgpack(header, payload):
raise ValueError("Data is compressed as %s but we don't have this"
" installed" % str(header['compression']))

return msgpack.loads(payload, encoding='utf8')
return msgpack.loads(payload, encoding='utf8', use_list=False)
4 changes: 2 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def msgpack_dumps(x):


def msgpack_loads(header, frames):
return msgpack.loads(b''.join(frames), encoding='utf8')
return msgpack.loads(b''.join(frames), encoding='utf8', use_list=False)


def serialization_error_loads(header, frames):
Expand Down Expand Up @@ -441,7 +441,7 @@ def deserialize_bytes(b):
frames = unpack_frames(b)
header, frames = frames[0], frames[1:]
if header:
header = msgpack.loads(header, encoding='utf8')
header = msgpack.loads(header, encoding='utf8', use_list=False)
else:
header = {}
frames = decompress(header, frames)
Expand Down
2 changes: 1 addition & 1 deletion distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_small():


def test_small_and_big():
d = {'x': [1, 2, 3], 'y': b'0' * 10000000}
d = {'x': (1, 2, 3), 'y': b'0' * 10000000}
L = dumps(d)
assert loads(L) == d
# assert loads([small_header, small]) == {'x': [1, 2, 3]}
Expand Down
7 changes: 7 additions & 0 deletions distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ def test_empty_loads():
assert isinstance(e2[0], Empty)


def test_empty_loads_deep():
from distributed.protocol import loads, dumps
e = Empty()
e2 = loads(dumps([[[to_serialize(e)]]]))
assert isinstance(e2[0][0][0], Empty)


def test_serialize_bytes():
for x in [1, 'abc', np.arange(5)]:
b = serialize_bytes(x)
Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def test_BatchedSend():
b.send('HELLO')

result = yield comm.read()
assert result == ['hello', 'hello', 'world']
assert result == ('hello', 'hello', 'world')
result = yield comm.read()
assert result == ['HELLO', 'HELLO']
assert result == ('HELLO', 'HELLO')

assert b.byte_count > 1

Expand All @@ -88,7 +88,7 @@ def test_send_before_start():

b.start(comm)
result = yield comm.read()
assert result == ['hello', 'world']
assert result == ('hello', 'world')


@gen_test()
Expand All @@ -104,7 +104,7 @@ def test_send_after_stream_start():
result = yield comm.read()
if len(result) < 2:
result += yield comm.read()
assert result == ['hello', 'world']
assert result == ('hello', 'world')


@gen_test()
Expand Down Expand Up @@ -295,7 +295,7 @@ def test_serializers():
assert 'function' in value

msg = yield comm.read()
assert msg == [{'x': 123}, {'x': 'hello'}]
assert list(msg) == [{'x': 123}, {'x': 'hello'}]

with pytest.raises(gen.TimeoutError):
msg = yield gen.with_timeout(timedelta(milliseconds=100), comm.read())
28 changes: 18 additions & 10 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3340,7 +3340,7 @@ def test_default_get():
@gen_cluster(client=True)
def test_get_processing(c, s, a, b):
processing = yield c.processing()
assert processing == valmap(list, s.processing)
assert processing == valmap(tuple, s.processing)

futures = c.map(slowinc, range(10), delay=0.1, workers=[a.address],
allow_other_workers=True)
Expand All @@ -3351,7 +3351,7 @@ def test_get_processing(c, s, a, b):
assert set(x) == {a.address, b.address}

x = yield c.processing(workers=[a.address])
assert isinstance(x[a.address], list)
assert isinstance(x[a.address], (list, tuple))


@gen_cluster(client=True)
Expand Down Expand Up @@ -3384,6 +3384,14 @@ def test_get_foo(c, s, a, b):
assert valmap(sorted, x) == {futures[0].key: sorted(s.who_has[futures[0].key])}


def assert_dict_key_equal(expected, actual):
assert set(expected.keys()) == set(actual.keys())
for k in actual.keys():
ev = expected[k]
av = actual[k]
assert list(ev) == list(av)


@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3)
def test_get_foo_lost_keys(c, s, u, v, w):
x = c.submit(inc, 1, workers=[u.address])
Expand All @@ -3393,27 +3401,27 @@ def test_get_foo_lost_keys(c, s, u, v, w):
ua, va, wa = u.address, v.address, w.address

d = yield c.scheduler.has_what()
assert d == {ua: [x.key], va: [y.key], wa: []}
assert_dict_key_equal(d, {ua: [x.key], va: [y.key], wa: []})
d = yield c.scheduler.has_what(workers=[ua, va])
assert d == {ua: [x.key], va: [y.key]}
assert_dict_key_equal(d, {ua: [x.key], va: [y.key]})
d = yield c.scheduler.who_has()
assert d == {x.key: [ua], y.key: [va]}
assert_dict_key_equal(d, {x.key: [ua], y.key: [va]})
d = yield c.scheduler.who_has(keys=[x.key, y.key])
assert d == {x.key: [ua], y.key: [va]}
assert_dict_key_equal(d, {x.key: [ua], y.key: [va]})

yield u._close()
yield v._close()

d = yield c.scheduler.has_what()
assert d == {wa: []}
assert_dict_key_equal(d, {wa: []})
d = yield c.scheduler.has_what(workers=[ua, va])
assert d == {ua: [], va: []}
assert_dict_key_equal(d, {ua: [], va: []})
# The scattered key cannot be recomputed so it is forgotten
d = yield c.scheduler.who_has()
assert d == {x.key: []}
assert_dict_key_equal(d, {x.key: []})
# ... but when passed explicitly, it is included in the result
d = yield c.scheduler.who_has(keys=[x.key, y.key])
assert d == {x.key: [], y.key: []}
assert_dict_key_equal(d, {x.key: [], y.key: []})


@slow
Expand Down
6 changes: 3 additions & 3 deletions distributed/tests/test_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def test_publish_simple(s, a, b):
assert "data" in str(exc_info.value)

result = yield c.scheduler.publish_list()
assert result == ['data']
assert result == ('data',)

result = yield f.scheduler.publish_list()
assert result == ['data']
assert result == ('data',)

yield c.close()
yield f.close()
Expand Down Expand Up @@ -221,7 +221,7 @@ def test_pickle_safe(c, s, a, b):
try:
yield c2.publish_dataset(x=[1, 2, 3])
result = yield c2.get_dataset('x')
assert result == [1, 2, 3]
assert result == (1, 2, 3)

with pytest.raises(TypeError):
yield c2.publish_dataset(y=lambda x: x)
Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def test_queue_with_data(c, s, a, b):
xx = yield Queue('x')
assert x.client is c

yield x.put([1, 'hello'])
yield x.put((1, 'hello'))
data = yield xx.get()

assert data == [1, 'hello']
assert data == (1, 'hello')

with pytest.raises(gen.TimeoutError):
yield x.get(timeout=0.1)
Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def test_queue_with_data(c, s, a, b):
xx = Variable('x')
assert x.client is c

yield x.set([1, 'hello'])
yield x.set((1, 'hello'))
data = yield xx.get()

assert data == [1, 'hello']
assert data == (1, 'hello')


def test_sync(loop):
Expand Down

0 comments on commit 52749ff

Please sign in to comment.