Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
"""
Serialization support for compiled functions.
"""
import sys
import abc
import io
import copyreg
import pickle
from numba import cloudpickle
#
# Pickle support
#
def _rebuild_reduction(cls, *args):
    """Module-level hook that rebuilds an instance of *cls*.

    Pickle stores a reference to this global function; on load it forwards
    the positional ``__reduce__`` arguments to ``cls._rebuild``.

    Parameters
    ----------
    cls : type
        Class providing a ``_rebuild`` class method.
    *args
        Positional arguments recorded at reduction time.

    Returns
    -------
    object
        Whatever ``cls._rebuild(*args)`` returns.
    """
    rebuild = cls._rebuild
    return rebuild(*args)
# Cache of objects produced by `_numba_unpickle`, keyed by (address, hashed).
# Holding them here keeps them alive; nothing in this module evicts entries.
_unpickled_memo = {}


def _numba_unpickle(address, bytedata, hashed):
    """Unpickle *bytedata* once per ``(address, hashed)`` key and cache it.

    Used by `numba_unpickle` from _helperlib.c.

    Parameters
    ----------
    address : int
        First component of the cache key.
    bytedata : bytes
        Pickled payload, decoded with ``cloudpickle.loads`` on a cache miss.
    hashed : bytes
        Second component of the cache key (presumably a hash of *bytedata*;
        TODO confirm against _helperlib.c).

    Returns
    -------
    obj : object
        The unpickled object; repeated calls with the same key return the
        same cached object without decoding again.
    """
    cache_key = (address, hashed)
    if cache_key not in _unpickled_memo:
        _unpickled_memo[cache_key] = cloudpickle.loads(bytedata)
    return _unpickled_memo[cache_key]
def dumps(obj):
    """Serialize *obj* to bytes, like `pickle.dumps()`.

    The data is written by a `NumbaPickler` using pickle protocol 4.

    Parameters
    ----------
    obj : object
        Object to serialize.

    Returns
    -------
    bytes
        The pickled representation of *obj*.
    """
    with io.BytesIO() as stream:
        NumbaPickler(stream, protocol=4).dump(obj)
        # Read the buffer before the `with` block closes it.
        return stream.getvalue()


# Counterpart of `dumps()`, re-exported from cloudpickle so callers can use
# `serialize.loads()`.
loads = cloudpickle.loads
class _CustomPickled:
    """Wrapper for objects that must be pickled with `NumbaPickler`.

    The standard `pickle` module uses the implementation registered through
    `copyreg`, which spawns a `NumbaPickler` to serialize the wrapped data.
    `NumbaPickler` itself handles this type specially, so it does not spawn
    a nested pickler when it is already the one doing the pickling.

    Attributes
    ----------
    ctor : object
        Constructor (typically a class) to rebuild with.
    states : object
        State passed back alongside *ctor*.
    """

    __slots__ = ('ctor', 'states')

    def __init__(self, ctor, states):
        self.ctor, self.states = ctor, states

    def _reduce(self):
        """Return the ``(callable, args)`` reduction used by NumbaPickler."""
        rebuild = _CustomPickled._rebuild
        payload = (self.ctor, self.states)
        return rebuild, payload

    @classmethod
    def _rebuild(cls, ctor, states):
        """Recreate a wrapper from the arguments produced by `_reduce`."""
        return cls(ctor, states)
def _unpickle__CustomPickled(serialized):
    """Standard-pickle reconstructor for `_CustomPickled`.

    Decodes *serialized* with this module's `loads` into a
    ``(ctor, states)`` pair and wraps it again.

    Parameters
    ----------
    serialized : bytes
        Bytes produced by `_pickle__CustomPickled`.

    Returns
    -------
    _CustomPickled
    """
    decoded = loads(serialized)
    ctor, states = decoded
    return _CustomPickled(ctor=ctor, states=states)
def _pickle__CustomPickled(cp):
    """Standard-pickle reducer for `_CustomPickled`.

    Serializes the wrapper's ``(ctor, states)`` pair with `dumps` (and hence
    `NumbaPickler`), embedding the result as opaque bytes.

    Parameters
    ----------
    cp : _CustomPickled
        Wrapper to reduce.

    Returns
    -------
    tuple
        ``(_unpickle__CustomPickled, (serialized_bytes,))``.
    """
    payload = (cp.ctor, cp.states)
    return _unpickle__CustomPickled, (dumps(payload),)


# Register custom pickling of `_CustomPickled` for the standard pickler.
copyreg.pickle(_CustomPickled, _pickle__CustomPickled)
def custom_reduce(cls, states):
    """Build a `__reduce__` result that serializes *cls* and *states*.

    The pair is wrapped in `_CustomPickled`, so it is serialized with
    `NumbaPickler` even when the outer pickler is the standard one.  On
    unpickling, `custom_rebuild()` calls ``cls._rebuild(**states)``.

    Parameters
    ----------
    cls : type
        Class whose ``_rebuild()`` class method reconstructs the object.
    states : dict
        Dictionary of object states to be serialized; used as keyword
        arguments to ``cls._rebuild()``.

    Returns
    -------
    result : tuple
        A ``(callable, args)`` tuple conforming to the `__reduce__` protocol.
    """
    wrapped = _CustomPickled(cls, states)
    return custom_rebuild, (wrapped,)
def custom_rebuild(custom_pickled):
    """Deserialize an object reduced with `custom_reduce()`.

    Parameters
    ----------
    custom_pickled : _CustomPickled
        Wrapper whose ``ctor`` is the target class and whose ``states`` is
        the dict of keyword arguments.

    Returns
    -------
    object
        The result of ``ctor._rebuild(**states)``.
    """
    target = custom_pickled.ctor
    kwargs = custom_pickled.states
    return target._rebuild(**kwargs)
def is_serialiable(obj):
    """Check whether *obj* can be serialized by `NumbaPickler`.

    The object is actually pickled into an in-memory buffer; the resulting
    bytes are discarded.

    Parameters
    ----------
    obj : object
        Object to test.

    Returns
    -------
    can_serialize : bool
        ``True`` if pickling succeeded.  ``False`` if the pickler rejected
        *obj* with ``pickle.PicklingError`` or ``TypeError``.
    """
    with io.BytesIO() as fout:
        pickler = NumbaPickler(fout)
        try:
            pickler.dump(obj)
        except (pickle.PicklingError, TypeError):
            # Many unpicklable builtins (locks, generators, ...) are rejected
            # by ``object.__reduce_ex__`` with TypeError rather than
            # PicklingError.  Both mean "cannot be serialized", so report
            # False instead of letting the TypeError escape the predicate.
            return False
        return True
def _no_pickle(obj):
    """Reducer that unconditionally refuses to pickle *obj*.

    Raises
    ------
    pickle.PicklingError
        Always; the message names the type of *obj*.
    """
    message = f"Pickling of {type(obj)} is unsupported"
    raise pickle.PicklingError(message)
def disable_pickling(typ):
    """Make `NumbaPickler` refuse to pickle instances of *typ*.

    Parameters
    ----------
    typ : type
        Type to disable.

    Returns
    -------
    type
        *typ* itself, so this function can be used as a class decorator.
    """
    pickler_cls = NumbaPickler
    pickler_cls.disabled_types.add(typ)
    # Also route *typ* through the dispatch table.  This path is needed for
    # Python 3.7, where `reducer_override` (added in 3.8) is not consulted.
    pickler_cls.dispatch_table[typ] = _no_pickle
    return typ
class NumbaPickler(cloudpickle.CloudPickler):
    """A `cloudpickle.CloudPickler` that refuses to pickle selected types.

    Instances whose exact type is listed in `disabled_types` (populated by
    `disable_pickling()`) are rejected with `pickle.PicklingError` via
    `_no_pickle`.  All other objects are handled by `CloudPickler`.
    """

    # Set of types whose instances must not be pickled.  It is a class
    # attribute, so it is shared by every NumbaPickler instance.  Membership
    # is tested on ``type(obj)``, so subclasses of a listed type are not
    # covered.
    disabled_types = set()

    def reducer_override(self, obj):
        """Reject instances of disabled types; defer everything else."""
        if type(obj) in self.disabled_types:
            _no_pickle(obj)  # noreturn: always raises PicklingError
        return super().reducer_override(obj)
def _custom_reduce__custompickled(cp):
    """Dispatch-table reducer that lets a `_CustomPickled` reduce itself.

    Returns the ``(callable, args)`` pair from ``cp._reduce()``.
    """
    reduce_method = cp._reduce
    return reduce_method()


# Inside a NumbaPickler, reduce `_CustomPickled` directly instead of going
# through the `copyreg` hook, which would spawn a nested NumbaPickler.
NumbaPickler.dispatch_table[_CustomPickled] = _custom_reduce__custompickled
class ReduceMixin(abc.ABC):
    """A mixin class for objects that should be reduced by the NumbaPickler
    instead of the standard pickler.

    Subclasses implement `_reduce_states()` (returning a dict) and the class
    method `_rebuild(**states)`.  `__reduce__` routes both through
    `custom_reduce()`, so the states are serialized with `NumbaPickler`.
    """

    # Subclass MUST override the below methods

    @abc.abstractmethod
    def _reduce_states(self):
        """Return a dict of states; passed as keyword args to `_rebuild()`."""
        raise NotImplementedError

    # `abc.abstractclassmethod` is deprecated since Python 3.3; stacking
    # `classmethod` over `abstractmethod` is the documented replacement.
    @classmethod
    @abc.abstractmethod
    def _rebuild(cls, **kwargs):
        """Reconstruct an instance from the keyword states."""
        raise NotImplementedError

    # Subclass can override the below methods

    def _reduce_class(self):
        """Return the class whose `_rebuild()` is called on unpickling."""
        return self.__class__

    # Private methods

    def __reduce__(self):
        return custom_reduce(self._reduce_class(), self._reduce_states())
class PickleCallableByPath:
    """Wrap a callable object to be pickled by path to workaround limitation
    in pickling due to non-pickleable objects in function non-locals.

    Only ``fn.__module__`` and ``fn.__name__`` are pickled.  On unpickling,
    the module is imported (if needed) and the callable is looked up by name.

    Note:
    - Do not use this as a decorator.
    - Wrapped object must be a global that exists in its parent module and
      can be imported by `from the_module import the_object`.

    Usage::

        >>> def my_fn(x):
        ...     ...
        >>> wrapped_fn = PickleCallableByPath(my_fn)
        >>> # refer to `wrapped_fn` instead of `my_fn`
    """

    def __init__(self, fn):
        self._fn = fn

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)

    def __reduce__(self):
        # Pickle only the (module name, attribute name) pair; the callable
        # itself is looked up again at unpickling time.
        return type(self)._rebuild, (self._fn.__module__, self._fn.__name__,)

    @classmethod
    def _rebuild(cls, modname, fn_path):
        """Re-wrap the callable named *fn_path* in module *modname*.

        The module is imported if the unpickling process has not loaded it
        yet (e.g. a freshly spawned worker).  Indexing ``sys.modules`` would
        raise KeyError in that case.  `import_module` returns the
        already-loaded module from ``sys.modules`` when present.
        """
        import importlib

        module = importlib.import_module(modname)
        return cls(getattr(module, fn_path))