
Add support for collection serialization #2948

Merged: 14 commits into dask:master on Aug 16, 2019

Conversation

@pentschev (Member) commented on Aug 11, 2019

This PR adds support for serializing collections using each element's own registered serializer, rather than pickling the entire collection. It is required for efficient serialization of CUDA objects in rapidsai/dask-cuda#110.
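As a rough illustration of the idea (a minimal sketch using distributed's public serialize function, not this PR's actual code; the header keys echo those discussed later in this review, while serialize_collection and the frame-lengths layout are hypothetical):

from distributed.protocol import serialize

def serialize_collection(t):
    # Serialize each element with its own registered serializer,
    # then gather the per-element headers and flatten all frames.
    parts = [serialize(v) for v in t]
    header = {
        "is-collection": True,
        "sub-headers": [h for h, _ in parts],
        "frame-lengths": [len(f) for _, f in parts],
    }
    frames = [frame for _, f in parts for frame in f]
    return header, frames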

@pentschev (Member, Author) commented on Aug 11, 2019

@mrocklin I think this is a good, simple solution; let me know what you think. For now I have only added support for tuple, but if you agree with the approach here, I will add support for the other collections (list, set, and dict) as well.

I didn't do this only on the CUDA side of things because of mixed CUDA/non-CUDA collections such as (cudf.DataFrame, None), which is another real case that comes up during cudf.sort(). I tried extending only the CUDA side, but that becomes a much dirtier and more complex solution, whereas this approach seems to work well.

@mrocklin (Member) commented on Aug 11, 2019

> I didn't do this only on the CUDA side of things because of mixed CUDA/non-CUDA collections such as (cudf.DataFrame, None), which is another real case that comes up during cudf.sort(). I tried extending only the CUDA side, but that becomes a much dirtier and more complex solution, whereas this approach seems to work well.

Ah, I missed this comment. Yes, I can understand how mixed-type situations might behave oddly. Does anyone have any suggestions?

@pentschev (Member, Author) commented on Aug 11, 2019

Alternatively, we can allow tuples and such to be serialized entirely by pickle even on CUDA. That will definitely be a step back in device-to-host serialization performance, but slow is probably better than non-functional. For collections of really small Python objects, I'm not sure we can make them faster than they are today (before this PR) while still allowing the collection's objects to be serialized individually.

@mrocklin (Member) commented on Aug 14, 2019

> For collections of really small Python objects, I'm not sure we can make them faster than they are today (before this PR) while still allowing the collection's objects to be serialized individually.

My concern is that we may be making them 10-100x slower with this change.

@mrocklin (Member) commented on Aug 14, 2019

I think that for short-term work we might consider doing this only one level deep, and only in cuda_serialize.

@cuda_serialize.register((tuple, list, set))
def f(obj):
    # Bail out unless the collection actually contains a device object,
    # letting other serializers handle pure-host collections.
    if not any(is_device_object(o) for o in obj):
        raise NotImplementedError

    ...

We might also include it here, but only if the length is small, like less than five or something. Do you have a sense for common cases where this is important?

@mrocklin (Member) commented on Aug 14, 2019

@pentschev and I met by video and decided to handle unpacking tuples, lists, sets, and dicts when the length of the collection is less than some small number, like 5. This should handle common cases where we return an array and something else in a small pair, without significantly affecting overhead in the case of a Python collection of many objects. A sketch of such a gate is below.
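For reference, a hedged sketch of what that gate might look like (the constant and function name are illustrative, not the PR's final code):

MAX_COLLECTION_LEN = 5  # "some small number, like 5"

def should_unpack(obj):
    # Exact type() check rather than isinstance(), so that subclasses
    # (e.g. an EvilTuple) are not unpacked along this path.
    return type(obj) in (tuple, list, set, dict) and len(obj) < MAX_COLLECTION_LEN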

headers = [i[0] for i in t]
headers = {"sub-headers": headers, "is-collection": True}
frames = [i[1] for i in t]

@mrocklin (Member) commented on Aug 14, 2019

Also, a note from our conversation: we'll probably want something like the following:

frames = []
locations = [0]
loc = 0
for _, _frames in t:
    frames.extend(_frames)
    loc += len(_frames)
    locations.append(loc)
 
header["frame-locations"] = locations

And then we'll have to unpack the frames when we deserialize, using header["frame-locations"].
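For instance (a sketch, assuming the cumulative offsets built above):

locs = header["frame-locations"]
# Pair consecutive offsets to slice the flat frame list back into
# one chunk of frames per collection element.
sub_frames = [frames[start:stop] for start, stop in zip(locs, locs[1:])]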

@pentschev (Member, Author) commented on Aug 15, 2019

@mrocklin I updated the code to reflect yesterday's discussion; I think everything is in now. Let me know if I missed something.

@mrocklin (Member) left a review

Thanks @pentschev. This looks good generally. As usual, I have many tiny comments :)

h, f = serialize(
    v, serializers=serializers, on_error=on_error, context=context
)
h["key"] = pickle.dumps(k)

@mrocklin (Member) commented on Aug 15, 2019

We have to be a little bit careful with pickle. Some of our users that work in secure settings want to be able to turn pickle off. Otherwise people can construct types that, when unpickled, do bad things:

class MyTuple(tuple):
    def __setstate__(self, state):  # this is what is called when you call pickle.loads
        os.remove("/")

So I recommend that we check whether the keys are msgpack-serializable and, if so, place them into the header directly; if not, avoid this code path entirely and let the system handle this as if it were not a tuple.

@mrocklin (Member) commented on Aug 15, 2019

try:
    msgpack.dumps(k)
except Exception:
    ...  # continue on to other serializers somehow
else:
    _header["key"] = k

@pentschev (Member, Author) commented on Aug 16, 2019

I understand. I pushed some changes, but I'm not totally sure they comply with your suggestion (in particular the exception-handling part). Could you please check and let me know if that's what you meant?

from dask.dataframe.utils import assert_eq


@pytest.mark.parametrize("collection", [tuple, dict])

@mrocklin (Member) commented on Aug 15, 2019

These four tests seem fairly similar. Do you have thoughts on how we might combine them further with parametrization?

@pentschev (Member, Author) commented on Aug 16, 2019

I'm not sure I understand your question. Are you asking whether we could combine the NumPy and Pandas tests, or which four tests are you referring to?

@mrocklin (Member) commented on Aug 16, 2019

Yes, the numpy and the pandas tests, and maybe also the cuda-equivalent tests, if it's easy to isolate the cuda imports with try-except blocks:

values = [
    (np.arange(50), "dask"),
    (pd.Series([1, 2, 3]), "dask"),
    (None, "pickle"),
]

try:
    import cudf
except ImportError:
    pass
else:
    values.append((cudf.Series([1, 2, 3]), "cuda"))

try:
...

@pentschev (Member, Author) commented on Aug 16, 2019

Doing that for NumPy/Pandas is quite easy, and I've pushed a commit for it. However, for CUDA types this would incur many more conditionals (maybe someone has CuPy but no cuDF installed), which would clutter the code even more, and to be fair, I think it's already quite cluttered as it is. That said, if you don't mind, I would prefer to keep them as they are now (separate tests for CuPy, cuDF, and their CPU equivalents).

@mrocklin (Member) commented on Aug 16, 2019

I'm fine keeping them separate if you prefer.

if df2 is None:
    assert (t["df2"] if isinstance(t, dict) else t[1]) is None
else:
    assert_eq(t["df2"] if isinstance(t, dict) else t[1], df2)

@mrocklin (Member) commented on Aug 15, 2019

I would also like to see a couple of tests like the following:

def test_large_collections_serialize_simply():
    header, frames = serialize(tuple(range(1000)))
    assert len(frames) == 1

def test_nested_types():
    x = np.ones(5)
    header, frames = serialize([[[x]]])
    assert "dask" in str(header)
    assert len(frames) == 1
    assert x.data in frames

@mrocklin (Member) commented on Aug 15, 2019

Here is a bad-type test:

def test_dont_pickle_types():

    class EvilTuple(tuple):
        def __getstate__(self):
            return None

        def __setstate__(self, state):
            raise Exception("Evil Exception")

    with pytest.raises(TypeError) as info:
        # note: no pickle serializer
        serialize(EvilTuple((1, 2, 3)), serializers=["dask", "msgpack"], on_error="raise")

    assert "EvilTuple" in str(info.value)

@pentschev (Member, Author) commented on Aug 16, 2019

I've added the first two tests, but I'm not sure I understand the last one. msgpack doesn't seem to rely on __{get,set}state__, so the exception is never raised. When exactly do we expect that exception to be raised: during msgpack.dumps(EvilTuple((1, 2, 3)))?

@mrocklin (Member) commented on Aug 16, 2019

Ah, you're right, I should have made the type evil, not the object.

I tried again with the following, but it didn't fail either:

class Evil(type):
    def __getstate__(self):
        return None
    def __setstate__(self, state):
        raise Exception("Evil Exception")

EvilTuple = Evil("EvilTuple", (tuple,), {})

obj = EvilTuple((1, 2, 3))
pickle.loads(pickle.dumps(obj))

@jcrist tends to know these things better than I do. Maybe he can suggest a failing case here?

@pentschev (Member, Author) commented on Aug 16, 2019

I think serialization with pickle is the one that should fail, no? Maybe the test we should add is the following:

def test_dont_pickle_types():

    class EvilTuple(tuple):
        def __getstate__(self):
            raise Exception("Evil Exception")

    evil = EvilTuple((1, 2, 3))

    with pytest.raises(TypeError) as info:
        serialize(evil, serializers=["pickle"], on_error="raise")

    assert "EvilTuple" in str(info.value)

    header, frames = serialize(evil, serializers=["msgpack"], on_error="raise")  # note: no pickle serializer
    assert header["serializer"] == "msgpack"
    assert deserialize(header, frames) == evil

The test above passes; does that cover what you wanted to check, @mrocklin?

Note also that pickle.dumps calls __getstate__, not __setstate__.

@jcrist (Member) commented on Aug 16, 2019

Pickling of types is hardcoded (the same way it is for functions). It looks at the qualname and tries to import it. To make an unpickleable type, just wrap the type creation in a closure:

In [56]: def make_class():
    ...:     class EvilTuple(tuple):
    ...:         pass
    ...:     return EvilTuple
    ...:

In [57]: EvilTuple = make_class()

In [58]: EvilTuple
Out[58]: __main__.make_class.<locals>.EvilTuple

In [59]: import pickle

In [60]: pickle.dumps(EvilTuple)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-60-f2472b1154ee> in <module>
----> 1 pickle.dumps(EvilTuple)

AttributeError: Can't pickle local object 'make_class.<locals>.EvilTuple'

@mrocklin (Member) commented on Aug 16, 2019

Thanks @jcrist. We're looking for something that fails on loads, the kind of thing that could be used for nefarious behavior. Any suggestions?

@jcrist (Member) commented on Aug 16, 2019

If pickle is used for object serialization (not the type), then defining __setstate__ or a reducer that errors should work fine:

In [65]: class ErrorOnLoad(tuple):
    ...:     def __reduce__(self):
    ...:         return ErrorOnLoad._raise, ()
    ...:     @staticmethod
    ...:     def _raise():
    ...:         raise ValueError("Oh no!")
    ...:

In [66]: e = ErrorOnLoad((1, 2, 3))

In [67]: s = pickle.dumps(e)

In [68]: pickle.loads(s)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-68-764e4625bc41> in <module>
----> 1 pickle.loads(s)

<ipython-input-65-c51f58a083ef> in _raise()
      4     @staticmethod
      5     def _raise():
----> 6         raise ValueError("Oh no!")
      7

ValueError: Oh no!

@mrocklin (Member) commented on Aug 16, 2019

We were actually concerned about deserializing the type (sorry for being unspecific above).

This isn't that big of a deal, though: we've stopped calling pickle.loads in this change, so testing this here is just to be extra safe; it's not a hard requirement. That being said, thinking adversarially like this is probably a good thing for us to do.

@jcrist (Member) commented on Aug 16, 2019

> We were actually concerned about deserializing the type (sorry for being unspecific above)

Since you can't override pickling for a type, you'd have to define the type in a way that serializes correctly (so at module level), serialize it, then break it by making it no longer importable (e.g. by deleting it).

Something like:

In [10]: class Foo(object):
    ...:     pass
    ...:

In [11]: def test():
    ...:     global Foo
    ...:     s = pickle.dumps(Foo)
    ...:     del Foo
    ...:     pickle.loads(s)
    ...:

In [12]: test()
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-12-fbd55f77ab7c> in <module>
----> 1 test()

<ipython-input-11-dca8467a1b57> in test()
      3     s = pickle.dumps(Foo)
      4     del Foo
----> 5     pickle.loads(s)
      6

AttributeError: Can't get attribute 'Foo' on <module '__main__'>
if "is-collection" in header:
headers = header["sub-headers"]
lengths = header["frame-lengths"]
cls = pickle.loads(header["type-serialized"])

@mrocklin (Member) commented on Aug 16, 2019

If we wanted to be more paranoid here, we might just include the type name and then map into a dict of known types:

cls = {"tuple": tuple, "list": list, "set": set, "dict": dict}[header["type-serialized"]]

This will likely come up the next time someone does a security pass over the code (which happens from time to time), and it would head off a conversation in the future.

@mrocklin (Member) commented on Aug 16, 2019

(sorry for not saying this originally)

@mrocklin (Member) suggested a change on Aug 16, 2019:

- cls = pickle.loads(header["type-serialized"])
+ cls = {"tuple": tuple, "list": list, "set": set, "dict": dict}[header["type-serialized"]]

@pentschev (Member, Author) commented on Aug 16, 2019

Aren't we already safe? We only enter this block if the type is one of those collections; we don't check whether they're instances.

@pentschev (Member, Author) commented on Aug 16, 2019

I mean, we only enter the serialization block for those types.

@pentschev (Member, Author) commented on Aug 16, 2019

And I also mean we don't check whether they're subclasses (rather than instances, as I said before).

@mrocklin (Member) commented on Aug 16, 2019

Yeah, this is where I'm apologizing for suggesting one thing, having you do that thing, and then suggesting another.

We're safe only if we can assume that the message we've received was also sent by our code. That isn't necessarily true, though: someone could craft a message on their own and send it.

In general we've made some promises to people that we'll only call pickle.loads in very safe situations, and always with an ability to turn it off. So we try to avoid calling it if we can.

@pentschev (Member, Author) commented on Aug 16, 2019

I understand now. But in that case, would msgpack be strictly safer? A user could also craft an unsafe msgpack message, no?

@mrocklin (Member) commented on Aug 16, 2019

pickle.loads can call arbitrary Python code; msgpack.loads just produces base Python objects like int/float/str. Certainly someone could craft a malformed msgpack message, but the worst that would happen is an exception. With pickle.loads, people can run arbitrary code, exfiltrate data, delete files, and so on.
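To illustrate the asymmetry (a small self-contained example, not code from this PR):

import pickle
import msgpack

# msgpack round-trips only plain data; a malformed payload can at
# worst raise an exception at load time.
assert msgpack.loads(msgpack.dumps([1, 2, 3])) == [1, 2, 3]

# A pickle payload, by contrast, can run an arbitrary callable on load:
class RunsOnLoad:
    def __reduce__(self):
        # pickle.loads will call print("payload ran") when loading this
        return (print, ("payload ran",))

pickle.loads(pickle.dumps(RunsOnLoad()))  # prints "payload ran"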

@mrocklin referenced this pull request on Aug 16, 2019
@mrocklin (Member) commented on Aug 16, 2019

@pentschev, if you need to check out for family/weekend reasons, I'm happy to take this over.

@pentschev (Member, Author) commented on Aug 16, 2019

There is one test that has failed for the past 3 builds, but only on Python 3.5 and on Windows: https://travis-ci.org/dask/distributed/jobs/572843992#L2564, https://ci.appveyor.com/project/daskdev/distributed/builds/26753114#L2091. It could be related to this PR, but I'm not sure.

This one has only happened in the last build, and I think it's unrelated as well: https://travis-ci.org/dask/distributed/jobs/572843992#L2696.

There are another two errors that are related to this PR, but they only failed on Windows, which is a bit strange: https://ci.appveyor.com/project/daskdev/distributed/builds/26753114#L2113

Any ideas on this, @mrocklin?

@mrocklin (Member) commented on Aug 16, 2019

I'm taking a look now

@mrocklin (Member) commented on Aug 16, 2019

I'm taking a look at this here: #2958

@pentschev (Member, Author) commented on Aug 16, 2019

Woohoo, this is awesome, everything green! Thanks for the help here @mrocklin and @jcrist.

@mrocklin merged commit 41a4d41 into dask:master on Aug 16, 2019

2 checks passed:
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed