# xgboost_extensions.py
import dataclasses
from os import PathLike
from typing import Any, Collection, Dict, Tuple, Type, Union
# XGBoost is an optional dependency. A missing install is surfaced as
# NotImplementedError (the exception type and message are unchanged), and the
# original ImportError is chained as the explicit cause so the traceback still
# shows why the import failed.
try:
    import xgboost
except ImportError as e:
    raise NotImplementedError("XGBoost is not installed.") from e
from hamilton import registry
from hamilton.io import utils
from hamilton.io.data_adapters import DataLoader, DataSaver
# Static-typing view of the supported model classes, used in the adapters'
# method signatures.
XGBOOST_MODEL_TYPES_ANNOTATION = Union[xgboost.XGBModel, xgboost.Booster]

# Runtime view of the same classes, returned by the adapters'
# ``applicable_types`` methods.
XGBOOST_MODEL_TYPES = [
    xgboost.XGBModel,
    xgboost.Booster,
]
@dataclasses.dataclass
class XGBoostJsonWriter(DataSaver):
    """Save XGBoost models and boosters in json format.

    The model object's own ``save_model`` method does the writing.
    See differences with pickle format: https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html
    """

    # Destination passed unchanged to ``save_model``.
    path: Union[str, PathLike]

    @classmethod
    def name(cls) -> str:
        """Return the identifier of this adapter: ``"json"``."""
        return "json"

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the model classes this saver can handle."""
        return XGBOOST_MODEL_TYPES

    def save_data(self, data: XGBOOST_MODEL_TYPES_ANNOTATION) -> Dict[str, Any]:
        """Write ``data`` to ``self.path`` and return metadata about the file.

        :param data: the model or booster to persist.
        :return: the result of ``utils.get_file_metadata`` for ``self.path``.
        """
        destination = self.path
        data.save_model(destination)
        file_metadata = utils.get_file_metadata(destination)
        return file_metadata
@dataclasses.dataclass
class XGBoostJsonReader(DataLoader):
    """Load XGBoost models and boosters from json format.

    See differences with pickle format: https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html
    """

    # Either a filesystem location (str / PathLike) or an in-memory model
    # buffer (bytearray); passed unchanged to ``load_model``.
    path: Union[str, bytearray, PathLike]

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the model classes this loader can produce."""
        return XGBOOST_MODEL_TYPES

    def load_data(self, type_: Type) -> Tuple[XGBOOST_MODEL_TYPES_ANNOTATION, Dict[str, Any]]:
        """Instantiate ``type_`` and populate it from ``self.path``.

        :param type_: the model class to build; it is called with no
            arguments and must expose ``load_model``.
        :return: a ``(model, metadata)`` tuple. ``metadata`` comes from
            ``utils.get_file_metadata`` when ``self.path`` is a file location,
            and is an empty dict when ``self.path`` is a ``bytearray`` buffer.
        """
        model = type_()
        model.load_model(self.path)
        if isinstance(self.path, bytearray):
            # A bytearray is the serialized model itself, not a location on
            # disk, so there is no file to describe.
            metadata: Dict[str, Any] = {}
        else:
            metadata = utils.get_file_metadata(self.path)
        return model, metadata

    @classmethod
    def name(cls) -> str:
        """Return the identifier of this adapter: ``"json"``."""
        return "json"
def register_data_loaders():
    """Register the XGBoost reader and writer via ``registry.register_adapter``.

    The reader is registered first, then the writer.
    """
    registry.register_adapter(XGBoostJsonReader)
    registry.register_adapter(XGBoostJsonWriter)
# Runs at import time: importing this module registers the adapters above.
register_data_loaders()
# NOTE(review): presumably read by Hamilton's extension registry to signal that
# this plugin does not define a dataframe/column type -- confirm against
# hamilton.registry before changing.
COLUMN_FRIENDLY_DF_TYPE = False