diff --git a/CHANGES.md b/CHANGES.md index 004a2914c..e21ae720d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,11 @@ +1.1.0 +===== + +- Track the provenance of dynamic classes and enums so as to preserve the + usual `isinstance` relationship between pickled objects and their + original class definitions. + ([issue #246](https://github.com/cloudpipe/cloudpickle/pull/246)) + 1.0.0 ===== diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 7df5f6c74..d84cce76d 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -44,7 +44,6 @@ import dis from functools import partial -import importlib import io import itertools import logging @@ -56,12 +55,26 @@ import traceback import types import weakref +import uuid +import threading + + +try: + from enum import Enum +except ImportError: + Enum = None # cloudpickle is meant for inter process communication: we expect all # communicating processes to run the same Python version hence we favor # communication speed over compatibility: DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL +# Track the provenance of reconstructed dynamic classes to make it possible to +# reconstruct instances from the matching singleton class definition when +# appropriate and preserve the usual "isinstance" semantics of Python objects. 
+_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary() +_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary() +_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock() if sys.version_info[0] < 3: # pragma: no branch from pickle import Pickler @@ -71,12 +84,37 @@ from StringIO import StringIO string_types = (basestring,) # noqa PY3 = False + PY2 = True + PY2_WRAPPER_DESCRIPTOR_TYPE = type(object.__init__) + PY2_METHOD_WRAPPER_TYPE = type(object.__eq__) + PY2_CLASS_DICT_BLACKLIST = (PY2_METHOD_WRAPPER_TYPE, + PY2_WRAPPER_DESCRIPTOR_TYPE) else: types.ClassType = type from pickle import _Pickler as Pickler from io import BytesIO as StringIO string_types = (str,) PY3 = True + PY2 = False + + +def _ensure_tracking(class_def): + with _DYNAMIC_CLASS_TRACKER_LOCK: + class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def) + if class_tracker_id is None: + class_tracker_id = uuid.uuid4().hex + _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id + _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def + return class_tracker_id + + +def _lookup_class_or_track(class_tracker_id, class_def): + if class_tracker_id is not None: + with _DYNAMIC_CLASS_TRACKER_LOCK: + class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault( + class_tracker_id, class_def) + _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id + return class_def def _make_cell_set_template_code(): @@ -112,7 +150,7 @@ def inner(value): # NOTE: we are marking the cell variable as a free variable intentionally # so that we simulate an inner function instead of the outer function. This # is what gives us the ``nonlocal`` behavior in a Python 2 compatible way. - if not PY3: # pragma: no branch + if PY2: # pragma: no branch return types.CodeType( co.co_argcount, co.co_nlocals, @@ -220,7 +258,7 @@ def _walk_global_ops(code): global-referencing instructions in *code*. 
""" code = getattr(code, 'co_code', b'') - if not PY3: # pragma: no branch + if PY2: # pragma: no branch code = map(ord, code) n = len(code) @@ -250,6 +288,39 @@ def _walk_global_ops(code): yield op, instr.arg +def _extract_class_dict(cls): + """Retrieve a copy of the dict of a class without the inherited methods""" + clsdict = dict(cls.__dict__) # copy dict proxy to a dict + if len(cls.__bases__) == 1: + inherited_dict = cls.__bases__[0].__dict__ + else: + inherited_dict = {} + for base in reversed(cls.__bases__): + inherited_dict.update(base.__dict__) + to_remove = [] + for name, value in clsdict.items(): + try: + base_value = inherited_dict[name] + if value is base_value: + to_remove.append(name) + elif PY2: + # backward compat for Python 2 + if hasattr(value, "im_func"): + if value.im_func is getattr(base_value, "im_func", None): + to_remove.append(name) + elif isinstance(value, PY2_CLASS_DICT_BLACKLIST): + # On Python 2 we have no way to pickle those specific + # methods types nor to check that they are actually + # inherited. So we assume that they are always inherited + # from builtin types. + to_remove.append(name) + except KeyError: + pass + for name in to_remove: + clsdict.pop(name) + return clsdict + + class CloudPickler(Pickler): dispatch = Pickler.dispatch.copy() @@ -277,7 +348,7 @@ def save_memoryview(self, obj): dispatch[memoryview] = save_memoryview - if not PY3: # pragma: no branch + if PY2: # pragma: no branch def save_buffer(self, obj): self.save(str(obj)) @@ -460,15 +531,40 @@ def func(): # then discards the reference to it self.write(pickle.POP) - def save_dynamic_class(self, obj): + def _save_dynamic_enum(self, obj, clsdict): + """Special handling for dynamic Enum subclasses + + Use a dedicated Enum constructor (inspired by EnumMeta.__call__) as the + EnumMeta metaclass has complex initialization that makes the Enum + subclasses hold references to their own instances. """ - Save a class that can't be stored as module global. 
+ members = dict((e.name, e.value) for e in obj) + + # Python 2.7 with enum34 can have no qualname: + qualname = getattr(obj, "__qualname__", None) + + self.save_reduce(_make_skeleton_enum, + (obj.__bases__, obj.__name__, qualname, members, + obj.__module__, _ensure_tracking(obj), None), + obj=obj) + + # Cleanup the clsdict that will be passed to _rehydrate_skeleton_class: + # Those attributes are already handled by the metaclass. + for attrname in ["_generate_next_value_", "_member_names_", + "_member_map_", "_member_type_", + "_value2member_map_"]: + clsdict.pop(attrname, None) + for member in members: + clsdict.pop(member) + + def save_dynamic_class(self, obj): + """Save a class that can't be stored as module global. This method is used to serialize classes that are defined inside functions, or that otherwise can't be serialized as attribute lookups from global modules. """ - clsdict = dict(obj.__dict__) # copy dict proxy to a dict + clsdict = _extract_class_dict(obj) clsdict.pop('__weakref__', None) # For ABCMeta in python3.7+, remove _abc_impl as it is not picklable. @@ -496,8 +592,8 @@ def save_dynamic_class(self, obj): for k in obj.__slots__: clsdict.pop(k, None) - # If type overrides __dict__ as a property, include it in the type kwargs. - # In Python 2, we can't set this attribute after construction. + # If type overrides __dict__ as a property, include it in the type + # kwargs. In Python 2, we can't set this attribute after construction. __dict__ = clsdict.pop('__dict__', None) if isinstance(__dict__, property): type_kwargs['__dict__'] = __dict__ @@ -524,8 +620,16 @@ def save_dynamic_class(self, obj): write(pickle.MARK) # Create and memoize an skeleton class with obj's name and bases. 
- tp = type(obj) - self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) + if Enum is not None and issubclass(obj, Enum): + # Special handling of Enum subclasses + self._save_dynamic_enum(obj, clsdict) + else: + # "Regular" class definition: + tp = type(obj) + self.save_reduce(_make_skeleton_class, + (tp, obj.__name__, obj.__bases__, type_kwargs, + _ensure_tracking(obj), None), + obj=obj) # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. @@ -778,7 +882,7 @@ def save_inst(self, obj): save(stuff) write(pickle.BUILD) - if not PY3: # pragma: no branch + if PY2: # pragma: no branch dispatch[types.InstanceType] = save_inst def save_property(self, obj): @@ -1119,6 +1223,22 @@ def _make_skel_func(code, cell_count, base_globals=None): return types.FunctionType(code, base_globals, None, None, closure) +def _make_skeleton_class(type_constructor, name, bases, type_kwargs, + class_tracker_id, extra): + """Build dynamic class with an empty __dict__ to be filled once memoized + + If class_tracker_id is not None, try to lookup an existing class definition + matching that id. If none is found, track a newly reconstructed class + definition under that id so that other instances stemming from the same + class id will also reuse this class definition. + + The "extra" variable is meant to be a dict (or None) that can be used for + forward compatibility shall the need arise. + """ + skeleton_class = type_constructor(name, bases, type_kwargs) + return _lookup_class_or_track(class_tracker_id, skeleton_class) + + def _rehydrate_skeleton_class(skeleton_class, class_dict): """Put attributes from `class_dict` back on `skeleton_class`. 
@@ -1137,6 +1257,39 @@ def _rehydrate_skeleton_class(skeleton_class, class_dict): return skeleton_class +def _make_skeleton_enum(bases, name, qualname, members, module, + class_tracker_id, extra): + """Build dynamic enum with an empty __dict__ to be filled once memoized + + The creation of the enum class is inspired by the code of + EnumMeta._create_. + + If class_tracker_id is not None, try to lookup an existing enum definition + matching that id. If none is found, track a newly reconstructed enum + definition under that id so that other instances stemming from the same + class id will also reuse this enum definition. + + The "extra" variable is meant to be a dict (or None) that can be used for + forward compatibility shall the need arise. + """ + # enums always inherit from their base Enum class at the last position in + # the list of base classes: + enum_base = bases[-1] + metacls = enum_base.__class__ + classdict = metacls.__prepare__(name, bases) + + for member_name, member_value in members.items(): + classdict[member_name] = member_value + enum_class = metacls.__new__(metacls, name, bases, classdict) + enum_class.__module__ = module + + # Python 2.7 compat + if qualname is not None: + enum_class.__qualname__ = qualname + + return _lookup_class_or_track(class_tracker_id, enum_class) + + def _is_dynamic(module): """ Return True if the module is special module that cannot be imported by its diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index e5afc3dd6..8f358ac64 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -42,6 +42,7 @@ import cloudpickle from cloudpickle.cloudpickle import _is_dynamic from cloudpickle.cloudpickle import _make_empty_cell, cell_set +from cloudpickle.cloudpickle import _extract_class_dict from .testutils import subprocess_pickle_echo from .testutils import assert_run_python_script @@ -71,6 +72,32 @@ def _escape(raw_filepath): return raw_filepath.replace("\\", r"\\\\") +def 
test_extract_class_dict(): + class A(int): + """A docstring""" + def method(self): + return "a" + + class B: + """B docstring""" + B_CONSTANT = 42 + + def method(self): + return "b" + + class C(A, B): + C_CONSTANT = 43 + + def method_c(self): + return "c" + + clsdict = _extract_class_dict(C) + assert sorted(clsdict.keys()) == ["C_CONSTANT", "__doc__", "method_c"] + assert clsdict["C_CONSTANT"] == 43 + assert clsdict["__doc__"] is None + assert clsdict["method_c"](C()) == C().method_c() + + class CloudPickleTest(unittest.TestCase): protocol = cloudpickle.DEFAULT_PROTOCOL @@ -924,21 +951,18 @@ def func(x): self.assertEqual(cloned.__qualname__, func.__qualname__) def test_namedtuple(self): - MyTuple = collections.namedtuple('MyTuple', ['a', 'b', 'c']) - t = MyTuple(1, 2, 3) + t1 = MyTuple(1, 2, 3) + t2 = MyTuple(3, 2, 1) - depickled_t, depickled_MyTuple = pickle_depickle( - [t, MyTuple], protocol=self.protocol) - self.assertTrue(isinstance(depickled_t, depickled_MyTuple)) + depickled_t1, depickled_MyTuple, depickled_t2 = pickle_depickle( + [t1, MyTuple, t2], protocol=self.protocol) - self.assertEqual((depickled_t.a, depickled_t.b, depickled_t.c), - (1, 2, 3)) - self.assertEqual((depickled_t[0], depickled_t[1], depickled_t[2]), - (1, 2, 3)) - - self.assertEqual(depickled_MyTuple.__name__, 'MyTuple') - self.assertTrue(issubclass(depickled_MyTuple, tuple)) + assert isinstance(depickled_t1, MyTuple) + assert depickled_t1 == t1 + assert depickled_MyTuple is MyTuple + assert isinstance(depickled_t2, MyTuple) + assert depickled_t2 == t2 def test_builtin_type__new__(self): # Functions occasionally take the __new__ of these types as default @@ -1197,6 +1221,123 @@ def is_in_main(name): """.format(protocol=self.protocol) assert_run_python_script(code) + def test_interactive_dynamic_type_and_remote_instances(self): + code = """if __name__ == "__main__": + from testutils import subprocess_worker + + with subprocess_worker(protocol={protocol}) as w: + + class CustomCounter: + def 
__init__(self): + self.count = 0 + def increment(self): + self.count += 1 + return self + + counter = CustomCounter().increment() + assert counter.count == 1 + + returned_counter = w.run(counter.increment) + assert returned_counter.count == 2, returned_counter.count + + # Check that the class definition of the returned instance was + # matched back to the original class definition living in __main__. + + assert isinstance(returned_counter, CustomCounter) + + # Check that memoization does not break provenance tracking: + + def echo(*args): + return args + + C1, C2, c1, c2 = w.run(echo, CustomCounter, CustomCounter, + CustomCounter(), returned_counter) + assert C1 is CustomCounter + assert C2 is CustomCounter + assert isinstance(c1, CustomCounter) + assert isinstance(c2, CustomCounter) + + """.format(protocol=self.protocol) + assert_run_python_script(code) + + def test_interactive_dynamic_type_and_stored_remote_instances(self): + """Simulate objects stored on workers to check isinstance semantics + + Such instances stored in the memory of running worker processes are + similar to dask-distributed futures for instance. + """ + code = """if __name__ == "__main__": + import cloudpickle, uuid + from testutils import subprocess_worker + + with subprocess_worker(protocol={protocol}) as w: + + class A: + '''Original class definition''' + pass + + def store(x): + storage = getattr(cloudpickle, "_test_storage", None) + if storage is None: + storage = cloudpickle._test_storage = dict() + obj_id = uuid.uuid4().hex + storage[obj_id] = x + return obj_id + + def lookup(obj_id): + return cloudpickle._test_storage[obj_id] + + id1 = w.run(store, A()) + + # The stored object on the worker is matched to a singleton class + # definition thanks to provenance tracking: + assert w.run(lambda obj_id: isinstance(lookup(obj_id), A), id1) + + # Retrieving the object from the worker yields a local copy that + # is matched back to the local class definition this instance + # originally stems from. 
+ assert isinstance(w.run(lookup, id1), A) + + # Changing the local class definition should be taken into account + # in all subsequent calls. In particular the old instances on the + # worker do not map back to the new class definition, neither on + # the worker itself, nor locally on the main program when the old + # instance is retrieved: + + class A: + '''Updated class definition''' + pass + + assert not w.run(lambda obj_id: isinstance(lookup(obj_id), A), id1) + retrieved1 = w.run(lookup, id1) + assert not isinstance(retrieved1, A) + assert retrieved1.__class__ is not A + assert retrieved1.__class__.__doc__ == "Original class definition" + + # New instances on the other hand are proper instances of the new + # class definition everywhere: + + a = A() + id2 = w.run(store, a) + assert w.run(lambda obj_id: isinstance(lookup(obj_id), A), id2) + assert isinstance(w.run(lookup, id2), A) + + # Monkeypatch the class definition in the main process to a new + # class method: + A.echo = lambda cls, x: x + + # Calling this method on an instance will automatically update + # the remote class definition on the worker to propagate the monkey + # patch dynamically. + assert w.run(a.echo, 42) == 42 + + # The stored instance can therefore also access the new class + # method: + assert w.run(lambda obj_id: lookup(obj_id).echo(43), id2) == 43 + + """.format(protocol=self.protocol) + assert_run_python_script(code) + @pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="Skip PyPy because memory grows too much") def test_interactive_remote_function_calls_no_memory_leak(self): @@ -1226,9 +1367,10 @@ def process_data(): import gc w.run(gc.collect) - # By this time the worker process has processed worth of 100MB of - # data passed in the closures its memory size should now have - # grown by more than a few MB. + # By this time the worker process has processed 100MB worth of data + # passed in the closures. 
The worker memory size should not have + # grown by more than a few MB as closures are garbage collected at + # the end of each remote function call. growth = w.memsize() - reference_size assert growth < 1e7, growth @@ -1368,6 +1510,88 @@ def test_dataclass(self): pickle_depickle(DataClass, protocol=self.protocol) assert data.x == pickle_depickle(data, protocol=self.protocol).x == 42 + def test_locally_defined_enum(self): + enum = pytest.importorskip("enum") + + class StringEnum(str, enum.Enum): + """Enum when all members are also (and must be) strings""" + + class Color(StringEnum): + """3-element color space""" + RED = "1" + GREEN = "2" + BLUE = "3" + + def is_green(self): + return self is Color.GREEN + + green1, green2, ClonedColor = pickle_depickle( + [Color.GREEN, Color.GREEN, Color], protocol=self.protocol) + assert green1 is green2 + assert green1 is ClonedColor.GREEN + assert green1 is not ClonedColor.BLUE + assert isinstance(green1, str) + assert green1.is_green() + + # cloudpickle systematically tracks provenance of class definitions + # and ensure reconciliation in case of round trips: + assert green1 is Color.GREEN + assert ClonedColor is Color + + green3 = pickle_depickle(Color.GREEN, protocol=self.protocol) + assert green3 is Color.GREEN + + def test_locally_defined_intenum(self): + enum = pytest.importorskip("enum") + # Try again with a IntEnum defined with the functional API + DynamicColor = enum.IntEnum("Color", {"RED": 1, "GREEN": 2, "BLUE": 3}) + + green1, green2, ClonedDynamicColor = pickle_depickle( + [DynamicColor.GREEN, DynamicColor.GREEN, DynamicColor], + protocol=self.protocol) + + assert green1 is green2 + assert green1 is ClonedDynamicColor.GREEN + assert green1 is not ClonedDynamicColor.BLUE + assert ClonedDynamicColor is DynamicColor + + def test_interactively_defined_enum(self): + pytest.importorskip("enum") + code = """if __name__ == "__main__": + from enum import Enum + from testutils import subprocess_worker + + with 
subprocess_worker(protocol={protocol}) as w: + + class Color(Enum): + RED = 1 + GREEN = 2 + + def check_positive(x): + return Color.GREEN if x >= 0 else Color.RED + + result = w.run(check_positive, 1) + + # Check that the returned enum instance is reconciled with the + # locally defined Color enum type definition: + + assert result is Color.GREEN + + # Check that changing the definition of the Enum class is taken + # into account on the worker for subsequent calls: + + class Color(Enum): + RED = 1 + BLUE = 2 + + def check_positive(x): + return Color.BLUE if x >= 0 else Color.RED + + result = w.run(check_positive, 1) + assert result is Color.BLUE + """.format(protocol=self.protocol) + assert_run_python_script(code) + def test_relative_import_inside_function(self): # Make sure relative imports inside round-tripped functions is not # broken.This was a bug in cloudpickle versions <= 0.5.3 and was