diff --git a/.gitignore b/.gitignore index 9e136965..477f7cec 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,4 @@ /docs/build /build /README -/dill/info.py \ No newline at end of file +/dill/__info__.py diff --git a/dill/_dill.py b/dill/_dill.py index e341357c..296a7328 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -322,16 +322,21 @@ def loads(str, ignore=None, **kwds): ### Pickle the Interpreter Session import pathlib import tempfile +from types import SimpleNamespace + +SESSION_IMPORTED_AS_TYPES = (BuiltinMethodType, FunctionType, MethodType, + ModuleType, TypeType) -SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception, - FunctionType, MethodType, BuiltinMethodType) TEMPDIR = pathlib.PurePath(tempfile.gettempdir()) def _module_map(): """get map of imported modules""" - from collections import defaultdict, namedtuple - modmap = namedtuple('Modmap', ['by_name', 'by_id', 'top_level']) - modmap = modmap(defaultdict(list), defaultdict(list), {}) + from collections import defaultdict + modmap = SimpleNamespace( + by_name=defaultdict(list), + by_id=defaultdict(list), + top_level={}, + ) for modname, module in sys.modules.items(): if not isinstance(module, ModuleType): continue @@ -359,36 +364,38 @@ def _stash_modules(main_module): imported = [] imported_as = [] - imported_top_level = [] # keep separeted for backwards compatibility + imported_top_level = [] # keep separeted for backward compatibility original = {} for name, obj in main_module.__dict__.items(): if obj is main_module: original[name] = newmod # self-reference - continue - + elif obj is main_module.__dict__: + original[name] = newmod.__dict__ # Avoid incorrectly matching a singleton value in another package (ex.: __doc__). - if any(obj is singleton for singleton in (None, False, True)) or \ - isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref + elif any(obj is singleton for singleton in (None, False, True)) \ + or isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref original[name] = obj - continue - - source_module, objname = _lookup_module(modmap, name, obj, main_module) - if source_module: - if objname == name: - imported.append((source_module, name)) - else: - imported_as.append((source_module, objname, name)) else: - try: - imported_top_level.append((modmap.top_level[id(obj)], name)) - except KeyError: - original[name] = obj + source_module, objname = _lookup_module(modmap, name, obj, main_module) + if source_module: + if objname == name: + imported.append((source_module, name)) + else: + imported_as.append((source_module, objname, name)) + else: + try: + imported_top_level.append((modmap.top_level[id(obj)], name)) + except KeyError: + original[name] = obj if len(original) < len(main_module.__dict__): newmod.__dict__.update(original) newmod.__dill_imported = imported newmod.__dill_imported_as = imported_as newmod.__dill_imported_top_level = imported_top_level + if getattr(newmod, '__loader__', None) is None and _is_imported_module(main_module): + # Trick _is_imported_module() to force saving as an imported module. + newmod.__loader__ = True # will be discarded by save_module() return newmod else: return main_module @@ -407,7 +414,7 @@ def _restore_modules(unpickler, main_module): #NOTE: 06/03/15 renamed main_module to main def dump_module( filename = str(TEMPDIR/'session.pkl'), - main: Optional[Union[ModuleType, str]] = None, + module: Union[ModuleType, str] = None, refimported: bool = False, **kwds ) -> None: @@ -420,7 +427,8 @@ def dump_module( Parameters: filename: a path-like object or a writable stream. - main: a module object or the name of an importable module. + module: a module object or the name of an importable module. If `None` + (the default), :py:mod:`__main__` is saved. refimported: if `True`, all objects imported into the module's namespace are saved by reference. *Note:* this is similar but independent from ``dill.settings[`byref`]``, as ``refimported`` @@ -432,17 +440,18 @@ def dump_module( :py:exc:`PicklingError`: if pickling fails. Examples: + - Save current interpreter session state: >>> import dill - >>> squared = lambda x:x*x + >>> squared = lambda x: x*x >>> dill.dump_module() # save state of __main__ to /tmp/session.pkl - Save the state of an imported/importable module: >>> import dill >>> import pox - >>> pox.plus_one = lambda x:x+1 + >>> pox.plus_one = lambda x: x+1 >>> dill.dump_module('pox_session.pkl', main=pox) - Save the state of a non-importable, module-type object: @@ -468,24 +477,28 @@ def dump_module( >>> [foo.sin(x) for x in foo.values] [0.8414709848078965, 0.9092974268256817, 0.1411200080598672] - *Changed in version 0.3.6:* the function ``dump_session()`` was renamed to - ``dump_module()``. - - *Changed in version 0.3.6:* the parameter ``byref`` was renamed to - ``refimported``. + *Changed in version 0.3.6:* Function ``dump_session()`` was renamed to + ``dump_module()``. Parameters ``main`` and ``byref`` were renamed to + ``module`` and ``refimported``, respectively. """ - if 'byref' in kwds: - warnings.warn( - "The argument 'byref' has been renamed 'refimported'" - " to distinguish it from dill.settings['byref'].", - PendingDeprecationWarning - ) - if refimported: - raise TypeError("both 'refimported' and 'byref' were used") - refimported = kwds.pop('byref') + for old_par, par in [('main', 'module'), ('byref', 'refimported')]: + if old_par in kwds: + message = "The argument %r has been renamed %r" % (old_par, par) + if old_par == 'byref': + message += " to distinguish it from dill.settings['byref']" + warnings.warn(message + ".", PendingDeprecationWarning) + if locals()[par]: # the defaults are None and False + raise TypeError("both %r and %r arguments were used" % (par, old_par)) + refimported = kwds.pop('byref', refimported) + module = kwds.pop('main', module) + from .settings import settings protocol = settings['protocol'] - if main is None: main = _main_module + main = module + if main is None: + main = _main_module + elif isinstance(main, str): + main = _import_module(main) if hasattr(filename, 'write'): file = filename else: @@ -510,7 +523,7 @@ def dump_module( # Backward compatibility. def dump_session(filename=str(TEMPDIR/'session.pkl'), main=None, byref=False, **kwds): warnings.warn("dump_session() has been renamed dump_module()", PendingDeprecationWarning) - dump_module(filename, main, refimported=byref, **kwds) + dump_module(filename, module=main, refimported=byref, **kwds) dump_session.__doc__ = dump_module.__doc__ class _PeekableReader: @@ -574,7 +587,7 @@ def _identify_module(file, main=None): def load_module( filename = str(TEMPDIR/'session.pkl'), - main: Union[ModuleType, str] = None, + module: Union[ModuleType, str] = None, **kwds ) -> Optional[ModuleType]: """Update :py:mod:`__main__` or another module with the state from the @@ -592,7 +605,7 @@ def load_module( Parameters: filename: a path-like object or a readable stream. - main: a module object or the name of an importable module. + module: a module object or the name of an importable module. **kwds: extra keyword arguments passed to :py:class:`Unpickler()`. Raises: @@ -609,11 +622,11 @@ def load_module( - Save the state of some modules: >>> import dill - >>> squared = lambda x:x*x + >>> squared = lambda x: x*x >>> dill.dump_module() # save state of __main__ to /tmp/session.pkl >>> >>> import pox # an imported module - >>> pox.plus_one = lambda x:x+1 + >>> pox.plus_one = lambda x: x+1 >>> dill.dump_module('pox_session.pkl', main=pox) >>> >>> from types import ModuleType @@ -659,19 +672,27 @@ def load_module( >>> from types import ModuleType >>> foo = ModuleType('foo') >>> foo.values = ['a','b'] - >>> foo.sin = lambda x:x*x + >>> foo.sin = lambda x: x*x >>> dill.load_module('foo_session.pkl', main=foo) >>> [foo.sin(x) for x in foo.values] [0.8414709848078965, 0.9092974268256817, 0.1411200080598672] - *Changed in version 0.3.6:* the function ``load_session()`` was renamed to - ``load_module()``. + *Changed in version 0.3.6:* Function ``load_session()`` was renamed to + ``load_module()``. Parameter ``main`` was renamed to ``module``. See also: :py:func:`load_module_asdict` to load the contents of module saved with :py:func:`dump_module` into a dictionary. """ - main_arg = main + if 'main' in kwds: + warnings.warn( + "The argument 'main' has been renamed 'module'.", + PendingDeprecationWarning + ) + if module is not None: + raise TypeError("both 'module' and 'main' arguments were used") + module = kwds.pop('main') + main = module if hasattr(filename, 'read'): file = filename else: @@ -681,9 +702,9 @@ def load_module( #FIXME: dill.settings are disabled unpickler = Unpickler(file, **kwds) unpickler._session = True - pickle_main = _identify_module(file, main) # Resolve unpickler._main + pickle_main = _identify_module(file, main) if main is None and pickle_main is not None: main = pickle_main if isinstance(main, str): @@ -705,28 +726,26 @@ def load_module( is_runtime_mod = pickle_main.startswith('__runtime__.') if is_runtime_mod: pickle_main = pickle_main.partition('.')[-1] + error_msg = "can't update{} module{} %r with the saved state of{} module{} %r" if is_runtime_mod and is_main_imported: raise ValueError( - "can't restore non-imported module %r into an imported one" - % pickle_main + error_msg.format(" imported", "", "", "-type object") + % (main.__name__, pickle_main) ) if not is_runtime_mod and not is_main_imported: raise ValueError( - "can't restore imported module %r into a non-imported one" - % pickle_main - ) - if main.__name__ != pickle_main: - raise ValueError( - "can't restore module %r into module %r" + error_msg.format("", "-type object", " imported", "") % (pickle_main, main.__name__) ) + if main.__name__ != pickle_main: + raise ValueError(error_msg.format("", "", "", "") % (main.__name__, pickle_main)) # This is for find_class() to be able to locate it. if not is_main_imported: runtime_main = '__runtime__.%s' % main.__name__ sys.modules[runtime_main] = main - module = unpickler.load() + loaded = unpickler.load() finally: if not hasattr(filename, 'read'): # if newly opened file file.close() @@ -734,15 +753,17 @@ def load_module( del sys.modules[runtime_main] except (KeyError, NameError): pass - assert module is main - _restore_modules(unpickler, module) - if not (module is _main_module or module is main_arg): - return module + assert loaded is main + _restore_modules(unpickler, main) + if main is _main_module or main is module: + return None + else: + return main # Backward compatibility. def load_session(filename=str(TEMPDIR/'session.pkl'), main=None, **kwds): warnings.warn("load_session() has been renamed load_module().", PendingDeprecationWarning) - load_module(filename, main, **kwds) + load_module(filename, module=main, **kwds) load_session.__doc__ = load_module.__doc__ def load_module_asdict( @@ -774,6 +795,7 @@ def load_module_asdict( Note: If ``update`` is True, the saved module may be imported then updated. + If imported, the loaded module remains unchanged as in the general case. Example: >>> import dill @@ -796,8 +818,8 @@ def load_module_asdict( >>> new_var in main # would be True if the option 'update' was set False """ - if 'main' in kwds: - raise TypeError("'main' is an invalid keyword argument for load_module_asdict()") + if 'module' in kwds: + raise TypeError("'module' is an invalid keyword argument for load_module_asdict()") if hasattr(filename, 'read'): file = filename else: @@ -815,7 +837,6 @@ def load_module_asdict( main.__builtins__ = __builtin__ sys.modules[main_name] = main load_module(file, **kwds) - main.__session__ = str(filename) finally: if not hasattr(filename, 'read'): # if newly opened file file.close() @@ -826,6 +847,7 @@ def load_module_asdict( sys.modules[main_name] = old_main except NameError: # failed before setting old_main pass + main.__session__ = str(filename) return main.__dict__ ### End: Pickle the Interpreter diff --git a/dill/tests/test_session.py b/dill/tests/test_session.py index 8f687934..8a054eb8 100644 --- a/dill/tests/test_session.py +++ b/dill/tests/test_session.py @@ -9,6 +9,7 @@ import os import sys import __main__ +from contextlib import suppress from io import BytesIO import dill @@ -27,7 +28,7 @@ def _error_line(error, obj, refimported): if __name__ == '__main__' and len(sys.argv) >= 3 and sys.argv[1] == '--child': # Test session loading in a fresh interpreter session. refimported = (sys.argv[2] == 'True') - dill.load_module(session_file % refimported) + dill.load_module(session_file % refimported, module='__main__') def test_modules(refimported): # FIXME: In this test setting with CPython 3.7, 'calendar' is not included @@ -111,10 +112,8 @@ def _clean_up_cache(module): cached = module.__cached__ if hasattr(module, '__cached__') else cached pycache = os.path.join(os.path.dirname(module.__file__), '__pycache__') for remove, file in [(os.remove, cached), (os.removedirs, pycache)]: - try: + with suppress(OSError): remove(file) - except OSError: - pass atexit.register(_clean_up_cache, local_mod) @@ -163,16 +162,14 @@ def test_session_main(refimported): error = sp.call([python, __file__, '--child', str(refimported)], shell=shell) if error: sys.exit(error) finally: - try: + with suppress(OSError): os.remove(session_file % refimported) - except OSError: - pass # Test session loading in the same session. session_buffer = BytesIO() dill.dump_module(session_buffer, refimported=refimported) session_buffer.seek(0) - dill.load_module(session_buffer) + dill.load_module(session_buffer, module='__main__') ns.backup['_test_objects'](__main__, ns.backup, refimported) def test_session_other(): @@ -183,13 +180,13 @@ def test_session_other(): dict_objects = [obj for obj in module.__dict__.keys() if not obj.startswith('__')] session_buffer = BytesIO() - dill.dump_module(session_buffer, main=module) + dill.dump_module(session_buffer, module) for obj in dict_objects: del module.__dict__[obj] session_buffer.seek(0) - dill.load_module(session_buffer) #, main=module) + dill.load_module(session_buffer, module) assert all(obj in module.__dict__ for obj in dict_objects) assert module.selfref is module @@ -210,12 +207,12 @@ def test_runtime_module(): # without imported objects in the namespace. It's a contrived example because # even dill can't be in it. This should work after fixing #462. session_buffer = BytesIO() - dill.dump_module(session_buffer, main=runtime, refimported=True) + dill.dump_module(session_buffer, module=runtime, refimported=True) session_dump = session_buffer.getvalue() # Pass a new runtime created module with the same name. runtime = ModuleType(modname) # empty - return_val = dill.load_module(BytesIO(session_dump), main=runtime) + return_val = dill.load_module(BytesIO(session_dump), module=runtime) assert return_val is None assert runtime.__name__ == modname assert runtime.x == 42