-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathserialize.py
260 lines (187 loc) · 6.14 KB
/
serialize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""
Serialization support for compiled functions.
"""
import abc
import copyreg
import importlib
import io
import pickle
import sys

from llvmlite import ir

from numba import cloudpickle
#
# Pickle support
#
def _rebuild_reduction(cls, *args):
    """Module-level hook that reconstructs an object from reduce arguments.

    Delegates to the ``_rebuild`` class method of *cls*, forwarding every
    positional argument unchanged, and returns whatever it produces.
    """
    rebuild = cls._rebuild
    return rebuild(*args)
# Memo of objects produced by `_numba_unpickle`, keyed by (address, hashed).
# Holding them here keeps every unpickled object alive for the lifetime of
# the process.
_unpickled_memo = {}


def _numba_unpickle(address, bytedata, hashed):
    """Unpickle *bytedata*, memoizing the result (used by `numba_unpickle`
    in _helperlib.c).

    Parameters
    ----------
    address : int
        First half of the memo key.
    bytedata : bytes
        Pickled payload, decoded with ``cloudpickle.loads`` on a memo miss.
    hashed : bytes
        Second half of the memo key (presumably a digest of *bytedata* --
        confirm against the C caller).

    Returns
    -------
    obj : object
        The unpickled object; repeated calls with the same
        ``(address, hashed)`` return the very same object.
    """
    memo_key = (address, hashed)
    if memo_key in _unpickled_memo:
        return _unpickled_memo[memo_key]
    result = cloudpickle.loads(bytedata)
    _unpickled_memo[memo_key] = result
    return result
def dumps(obj):
    """Serialize *obj* with `NumbaPickler` and return the pickled bytes.

    Analogous to `pickle.dumps()`; pickle protocol 4 is always used.
    """
    with io.BytesIO() as stream:
        NumbaPickler(stream, protocol=4).dump(obj)
        return stream.getvalue()
def runtime_build_excinfo_struct(static_exc, exc_args):
    """Merge pickled static exception info with runtime argument values.

    *static_exc* is unpickled into ``(exc, static_args, locinfo)``. Each
    element of ``static_args`` that is an ``ir.Value`` is replaced, in order,
    by the next item taken from *exc_args*; every other element is kept
    unchanged.

    Returns
    -------
    tuple
        ``(exc, merged_args_tuple, locinfo)``.
    """
    exc, static_args, locinfo = cloudpickle.loads(static_exc)
    runtime_values = iter(exc_args)
    # A list comprehension (not a generator expression) so that running out
    # of runtime values still surfaces as StopIteration.
    merged = [
        next(runtime_values) if isinstance(arg, ir.Value) else arg
        for arg in static_args
    ]
    return exc, tuple(merged), locinfo
# Module-level alias so callers can write `serialize.loads()`. This is
# `cloudpickle.loads`; this module uses it to decode the bytes produced by
# `dumps()` (see `_unpickle__CustomPickled`).
loads = cloudpickle.loads
class _CustomPickled:
    """Container for a ``(ctor, states)`` pair that must be serialized by
    `NumbaPickler`.

    When a plain `pickle` pickler meets an instance, the reducer registered
    through `copyreg` takes over and delegates to a fresh `NumbaPickler`.
    A `NumbaPickler` that is already doing the serialization uses its own
    dispatch-table entry instead, so no nested pickler gets created.
    """

    __slots__ = ('ctor', 'states')

    def __init__(self, ctor, states):
        self.ctor, self.states = ctor, states

    def _reduce(self):
        # Reduce tuple used by the in-pickler (dispatch table) path.
        rebuild = _CustomPickled._rebuild
        return rebuild, (self.ctor, self.states)

    @classmethod
    def _rebuild(cls, ctor, states):
        instance = cls(ctor, states)
        return instance
def _unpickle__CustomPickled(serialized):
    """Inverse of `_pickle__CustomPickled` for the standard unpickler.

    *serialized* holds the ``(ctor, states)`` pair produced by `dumps()`; it
    is decoded with the module's `loads()` and re-wrapped in a
    `_CustomPickled`.
    """
    payload = loads(serialized)
    ctor, states = payload
    return _CustomPickled(ctor, states)
def _pickle__CustomPickled(cp):
    """`copyreg` reducer for `_CustomPickled` under the standard pickler.

    The wrapped ``(ctor, states)`` pair is serialized up front with `dumps()`
    (i.e. by a `NumbaPickler`). The resulting bytes are embedded in the outer
    pickle and decoded later by `_unpickle__CustomPickled`.
    """
    payload = (cp.ctor, cp.states)
    return _unpickle__CustomPickled, (dumps(payload),)
# Register `_pickle__CustomPickled` with `copyreg` so the standard `pickle`
# machinery uses it to reduce `_CustomPickled` instances. That reducer in turn
# serializes the payload with a `NumbaPickler` via `dumps()`.
copyreg.pickle(_CustomPickled, _pickle__CustomPickled)
def custom_reduce(cls, states):
    """Build a ``__reduce__`` result that rebuilds *cls* from *states*.

    On unpickling, ``cls._rebuild(**states)`` is called (see
    `custom_rebuild()`), so the keys of *states* become keyword arguments to
    that class method.

    Parameters
    ----------
    cls : type
        Object providing a ``_rebuild`` class method; normally the class of
        the object being reduced.
    states : dict
        Object state to be serialized.

    Returns
    -------
    result : tuple
        A ``(callable, args)`` pair that satisfies the return-type
        requirement of ``__reduce__``.
    """
    wrapped = _CustomPickled(cls, states)
    return custom_rebuild, (wrapped,)
def custom_rebuild(custom_pickled):
    """Reconstruct an object serialized through `custom_reduce()`.

    Reads the constructor and state stored on *custom_pickled* and returns
    ``ctor._rebuild(**states)``.
    """
    ctor = custom_pickled.ctor
    states = custom_pickled.states
    return ctor._rebuild(**states)
def is_serialiable(obj):
    """Check if *obj* can be serialized by `NumbaPickler`.

    Parameters
    ----------
    obj : object

    Returns
    -------
    can_serialize : bool
        True if pickling *obj* succeeds, False if it is rejected as
        unpicklable.
    """
    with io.BytesIO() as fout:
        pickler = NumbaPickler(fout)
        try:
            pickler.dump(obj)
        except (pickle.PicklingError, TypeError):
            # Many unpicklable builtins (locks, generators, sockets, ...) are
            # reported by the stdlib as ``TypeError("cannot pickle ...")``
            # rather than ``PicklingError``; treat both as "not serializable"
            # instead of letting the exception escape this bool query.
            return False
        else:
            return True
def _no_pickle(obj):
    """Unconditionally refuse to pickle *obj* by raising `PicklingError`."""
    message = f"Pickling of {type(obj)} is unsupported"
    raise pickle.PicklingError(message)
def disable_pickling(typ):
    """Forbid `NumbaPickler` from serializing instances of *typ*.

    *typ* is added to ``NumbaPickler.disabled_types`` and returned unchanged,
    which lets this function double as a class decorator.
    """
    registry = NumbaPickler.disabled_types
    registry.add(typ)
    return typ
class NumbaPickler(cloudpickle.CloudPickler):
    """`cloudpickle.CloudPickler` subclass used for Numba serialization.

    Instances whose exact type is listed in `disabled_types` are refused
    with `pickle.PicklingError`; everything else is handled by
    `CloudPickler`.
    """

    # Types whose instances this pickler refuses to serialize; populated via
    # `disable_pickling()`. The set is a class attribute shared by all
    # instances. Matching is on the exact type, so subclasses of a disabled
    # type are not affected.
    disabled_types = set()

    def reducer_override(self, obj):
        # Refuse disabled types before cloudpickle gets a chance to reduce them.
        if type(obj) in self.disabled_types:
            _no_pickle(obj)  # noreturn: always raises PicklingError
        return super().reducer_override(obj)
def _custom_reduce__custompickled(cp):
    """`NumbaPickler` reducer for `_CustomPickled`: defer to ``cp._reduce()``.

    Unlike the `copyreg` path, this does not spawn a nested pickler.
    """
    reduced = cp._reduce()
    return reduced


# Route `_CustomPickled` through the in-pickler reducer above.
# NOTE(review): depending on the cloudpickle version, `dispatch_table` may be
# inherited from `CloudPickler` (and so shared with other cloudpickle
# picklers) rather than owned by `NumbaPickler`; confirm the intended scope.
NumbaPickler.dispatch_table[_CustomPickled] = _custom_reduce__custompickled
class ReduceMixin(abc.ABC):
    """Mixin that routes pickling through `custom_reduce()` so the object is
    handled by `NumbaPickler` rather than the standard pickler.

    Subclasses implement `_reduce_states()` and the `_rebuild()` class
    method. The pickled form carries the dict returned by `_reduce_states()`,
    and unpickling calls ``_rebuild(**states)`` on the class returned by
    `_reduce_class()`.
    """

    # --- Required overrides ---------------------------------------------

    @abc.abstractmethod
    def _reduce_states(self):
        """Return a dict of keyword arguments for `_rebuild()`."""
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def _rebuild(cls, **kwargs):
        """Reconstruct an instance from the given keyword arguments."""
        raise NotImplementedError

    # --- Optional overrides ---------------------------------------------

    def _reduce_class(self):
        """Return the class whose `_rebuild()` is used when unpickling."""
        return self.__class__

    # --- Pickle protocol ------------------------------------------------

    def __reduce__(self):
        cls = self._reduce_class()
        states = self._reduce_states()
        return custom_reduce(cls, states)
class PickleCallableByPath:
    """Wrap a callable so that it is pickled by reference (module + name).

    This works around pickling failures caused by non-picklable objects in
    the function's non-locals. Only the module name and the function's
    ``__name__`` are stored, and the function is looked up again when
    unpickling.

    Note:

    - Do not use this as a decorator.
    - The wrapped object must be a global that exists in its parent module,
      i.e. it can be imported by ``from the_module import the_object``.

    Usage::

        def my_fn(x):
            ...

        wrapped_fn = PickleCallableByPath(my_fn)
        # refer to `wrapped_fn` instead of `my_fn`
    """

    def __init__(self, fn):
        self._fn = fn

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)

    def __reduce__(self):
        return type(self)._rebuild, (self._fn.__module__, self._fn.__name__,)

    @classmethod
    def _rebuild(cls, modname, fn_path):
        """Look up ``fn_path`` in module ``modname`` and wrap it again.

        ``importlib.import_module`` returns the already-loaded module when
        it is in ``sys.modules`` and imports it otherwise. A bare
        ``sys.modules[modname]`` lookup would raise ``KeyError`` in a process
        (e.g. a freshly spawned worker) that has not imported it yet.
        """
        module = importlib.import_module(modname)
        return cls(getattr(module, fn_path))