forked from microsoft/qlib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
serial.py
126 lines (103 loc) · 3.91 KB
/
serial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
import pickle
import typing
import dill
from typing import Union
class Serializable:
    """
    Serializable will change the behaviors of pickle.

    - It only saves the state whose name **does not** start with `_`

    It provides a syntactic sugar for distinguishing the attributes which the user doesn't want to dump.

    - For example, a learnable Datahandler just wants to save the parameters without data when dumping to disk
    """

    pickle_backend = "pickle"  # another optional value is "dill" which can pickle more things of python.
    default_dump_all = False  # if dump all things
    # Temporary marker stored in an instance `__dict__` while `config(recursive=True)` is visiting it.
    FLAG_KEY = "_qlib_serial_flag"

    def __init__(self):
        # Both attributes start with `_`, so they are not part of the pickled state unless `dump_all` is on.
        self._dump_all = self.default_dump_all
        self._exclude = []

    def __getstate__(self) -> dict:
        """
        Build the state that the pickle backend will serialize.

        Returns
        -------
        dict
            The attributes of `__dict__` which are not listed in `exclude` and which are
            either public (no leading `_`) or kept because `dump_all` is true.
        """
        excluded = self.exclude
        keep_private = self.dump_all
        return {
            name: value
            for name, value in self.__dict__.items()
            if name not in excluded and (keep_private or not name.startswith("_"))
        }

    def __setstate__(self, state: dict):
        """
        Restore the instance attributes from an unpickled `state` dict.
        """
        self.__dict__.update(state)

    @property
    def dump_all(self):
        """
        will the object dump all object
        """
        # `getattr` with a default: instances restored by `__setstate__` may not carry `_dump_all`.
        return getattr(self, "_dump_all", False)

    @property
    def exclude(self):
        """
        What attribute will not be dumped
        """
        # `getattr` with a default: instances restored by `__setstate__` may not carry `_exclude`.
        return getattr(self, "_exclude", [])

    def config(
        self,
        dump_all: typing.Optional[bool] = None,
        exclude: typing.Optional[list] = None,
        recursive: bool = False,
    ):
        """
        configure the serializable object

        Parameters
        ----------
        dump_all : bool
            will the object dump all object. `None` leaves the current setting untouched.
        exclude : list
            What attribute will not be dumped. `None` leaves the current setting untouched.
        recursive : bool
            will the configuration be recursive, i.e. also applied to every `Serializable`
            that is stored directly as an attribute of this object.
        """
        params = {"dump_all": dump_all, "exclude": exclude}
        for key, value in params.items():
            if value is not None:
                setattr(self, f"_{key}", value)

        if not recursive:
            return

        # Set the flag once to prevent an endless loop on reference cycles: an object that
        # is currently being visited is skipped by its descendants.
        self.__dict__[self.FLAG_KEY] = True
        try:
            # Iterate over a snapshot so that the flag bookkeeping (here or in a child that
            # refers back to us) never mutates the dict that is being iterated.
            for obj in list(self.__dict__.values()):
                if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__:
                    obj.config(**params, recursive=True)
        finally:
            # Always remove the marker, even when a child fails; a stale flag would make
            # later recursive calls silently skip this object.
            del self.__dict__[self.FLAG_KEY]

    def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
        """
        Dump self to a pickle file.

        The `dump_all` / `exclude` values are applied through `config`, so they stay set on
        this instance after the dump.

        Args:
            path (Union[Path, str]): the path to dump
            dump_all (bool, optional): if need to dump all things. Defaults to None.
            exclude (list, optional): will exclude the attributes in this list when dumping. Defaults to None.
        """
        self.config(dump_all=dump_all, exclude=exclude)
        with Path(path).open("wb") as f:
            self.get_backend().dump(self, f)

    @classmethod
    def load(cls, filepath):
        """
        Load the serializable class from a filepath.

        Args:
            filepath (str): the path of file

        Raises:
            TypeError: the pickled file must contain an instance of `cls`

        Returns:
            `cls`: the instance of `cls`
        """
        # NOTE: unpickling can execute arbitrary code; only load files from trusted sources.
        with open(filepath, "rb") as f:
            loaded = cls.get_backend().load(f)
        if not isinstance(loaded, cls):
            # Report `cls` itself; `type(cls)` would be the metaclass and hide the expected type.
            raise TypeError(f"The instance of {type(loaded)} is not a valid `{cls}`!")
        return loaded

    @classmethod
    def get_backend(cls):
        """
        Return the real backend of a Serializable class. The pickle_backend value can be "pickle" or "dill".

        Returns:
            module: pickle or dill module based on pickle_backend

        Raises:
            ValueError: `pickle_backend` is neither "pickle" nor "dill"
        """
        if cls.pickle_backend == "pickle":
            return pickle
        if cls.pickle_backend == "dill":
            return dill
        raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")