"""
Gradient Boosting Decision Tree wrapper interface
"""
import os
import logging
import tempfile
import warnings
from typing import Callable, Optional, Union
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split
from deepchem.data import Dataset
from deepchem.models.sklearn_models import SklearnModel
logger = logging.getLogger(__name__)


class GBDTModel(SklearnModel):
    """Wrapper class that wraps GBDT models as DeepChem models.

    This class supports LightGBM and XGBoost models.
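
    Examples
    --------
    A minimal construction sketch (hypothetical parameters; assumes both
    xgboost and lightgbm are installed, since the constructor imports both):

    >>> from xgboost import XGBClassifier
    >>> model = GBDTModel(XGBClassifier(n_estimators=100),
    ...                   early_stopping_rounds=10)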
"""

    def __init__(self,
                 model: BaseEstimator,
                 model_dir: Optional[str] = None,
                 early_stopping_rounds: int = 50,
                 eval_metric: Optional[Union[str, Callable]] = None,
                 **kwargs):
        """
        Parameters
        ----------
        model: BaseEstimator
            The scikit-learn wrapper instance of a LightGBM/XGBoost model.
        model_dir: str, optional (default None)
            Path to the directory where the model will be stored.
        early_stopping_rounds: int, optional (default 50)
            Activates early stopping. The validation metric needs to improve
            at least once every `early_stopping_rounds` round(s) to continue
            training.
        eval_metric: Union[str, Callable], optional (default None)
            If a string, it should be a built-in evaluation metric to use.
            If a callable, it should be a custom evaluation metric; see the
            official XGBoost/LightGBM notes for more details.
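
        Examples
        --------
        A custom metric is, hypothetically, a plain callable like the one
        below; the exact signature and return type expected for callables
        differ between XGBoost and LightGBM versions, so check their
        documentation before relying on this sketch:

        >>> import numpy as np
        >>> def rmse(y_true, y_pred):
        ...     # Root mean squared error over the validation predictions.
        ...     return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))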
"""
        try:
            import xgboost
            import lightgbm
        except ImportError:
            raise ModuleNotFoundError(
                "XGBoost or LightGBM modules not found. This class requires these modules to be installed."
            )

        if model_dir is not None:
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
        else:
            model_dir = tempfile.mkdtemp()
        self.model_dir = model_dir
        self.model = model
        self.model_class = model.__class__
        self.early_stopping_rounds = early_stopping_rounds
        self.model_type = self._check_model_type()
        if self.early_stopping_rounds <= 0:
            raise ValueError("early_stopping_rounds must be at least 1.")

        # XGBoost and LightGBM expose early stopping through different
        # callback APIs, so choose the callback based on the model class.
        if self.model.__class__.__name__.startswith('XGB'):
            self.callbacks = [
                xgboost.callback.EarlyStopping(
                    rounds=self.early_stopping_rounds)
            ]
        elif self.model.__class__.__name__.startswith('LGBM'):
            self.callbacks = [
                lightgbm.early_stopping(
                    stopping_rounds=self.early_stopping_rounds),
            ]
        if eval_metric is None:
            # Pick a sensible default metric for the detected model type.
            if self.model_type == "classification":
                self.eval_metric: Optional[Union[str, Callable]] = "auc"
            elif self.model_type == "regression":
                self.eval_metric = "mae"
            else:
                self.eval_metric = eval_metric
        else:
            self.eval_metric = eval_metric

    def _check_model_type(self) -> str:
        """Return "classification", "regression", or "none" based on the
        wrapped model's class name."""
        class_name = self.model.__class__.__name__
        if class_name.endswith("Classifier"):
            return "classification"
        elif class_name.endswith("Regressor"):
            return "regression"
        elif class_name == "NoneType":
            return "none"
        else:
            raise ValueError(
                "{} is not a supported model instance.".format(class_name))

    def fit(self, dataset: Dataset):
        """Fits the GBDT model on all of the data.

        First, this method splits the data into train and validation sets
        (80:20) and uses early stopping on the validation set to find the
        best number of boosting rounds. It then retrains on all of the data
        with `n_estimators` set to 1.25 times that best round.

        Parameters
        ----------
        dataset: Dataset
            The `Dataset` to train this model on.
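
        Examples
        --------
        A minimal sketch with random data (hypothetical dataset; assumes
        xgboost and lightgbm are installed in versions compatible with
        this wrapper's callback-based early stopping):

        >>> import numpy as np
        >>> import deepchem as dc
        >>> from xgboost import XGBRegressor
        >>> X, y = np.random.rand(100, 10), np.random.rand(100)
        >>> dataset = dc.data.NumpyDataset(X=X, y=y)
        >>> model = GBDTModel(XGBRegressor(n_estimators=100, random_state=0))
        >>> model.fit(dataset)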
"""
        X = dataset.X
        y = np.squeeze(dataset.y)
        # GBDT models don't support multi-output (multitask) learning.
        if len(y.shape) != 1:
            raise ValueError("GBDT model doesn't support multi-output (multitask) data.")

        seed = self.model.random_state
        # Stratify the split by label for classification tasks.
        stratify = None
        if self.model_type == "classification":
            stratify = y

        # Find the optimal n_estimators for the original learning_rate and
        # early_stopping_rounds on a held-out 20% validation split.
        X_train, X_test, y_train, y_test = train_test_split(X,
                                                            y,
                                                            test_size=0.2,
                                                            random_state=seed,
                                                            stratify=stratify)
        self.model.fit(
            X_train,
            y_train,
            callbacks=self.callbacks,
            eval_metric=self.eval_metric,
            eval_set=[(X_test, y_test)],
        )

        # Retrain on the whole dataset using best n_estimators * 1.25.
        # XGBoost's `best_iteration` is 0-based (hence the +1), while
        # LightGBM's `best_iteration_` already counts rounds.
        if self.model.__class__.__name__.startswith('XGB'):
            estimated_best_round = np.round(
                (self.model.best_iteration + 1) * 1.25)
        else:
            estimated_best_round = np.round(self.model.best_iteration_ * 1.25)
        self.model.n_estimators = int(estimated_best_round)
        self.model.fit(X, y, eval_metric=self.eval_metric)

    def fit_with_eval(self, train_dataset: Dataset, valid_dataset: Dataset):
        """Fits the GBDT model with early stopping on validation data.

        Parameters
        ----------
        train_dataset: Dataset
            The `Dataset` to train this model on.
        valid_dataset: Dataset
            The `Dataset` to validate this model on.
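
        Examples
        --------
        A minimal sketch with a random, hypothetical train/valid split
        (assumes xgboost and lightgbm are installed in versions compatible
        with this wrapper's callback-based early stopping):

        >>> import numpy as np
        >>> import deepchem as dc
        >>> from xgboost import XGBRegressor
        >>> X, y = np.random.rand(100, 10), np.random.rand(100)
        >>> train = dc.data.NumpyDataset(X=X[:80], y=y[:80])
        >>> valid = dc.data.NumpyDataset(X=X[80:], y=y[80:])
        >>> model = GBDTModel(XGBRegressor(n_estimators=100))
        >>> model.fit_with_eval(train, valid)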
"""
        X_train, X_valid = train_dataset.X, valid_dataset.X
        y_train, y_valid = np.squeeze(train_dataset.y), np.squeeze(
            valid_dataset.y)
        # GBDT models don't support multi-output (multitask) learning.
        if len(y_train.shape) != 1 or len(y_valid.shape) != 1:
            raise ValueError("GBDT model doesn't support multi-output (multitask) data.")

        self.model.fit(
            X_train,
            y_train,
            callbacks=self.callbacks,
            eval_metric=self.eval_metric,
            eval_set=[(X_valid, y_valid)],
        )


#########################################
# Deprecation warnings for XGBoostModel
#########################################


class XGBoostModel(GBDTModel):
    """Deprecated alias for `GBDTModel`; emits a `FutureWarning` on construction."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "XGBoostModel is deprecated and has been renamed to GBDTModel.",
            FutureWarning,
        )
        super(XGBoostModel, self).__init__(*args, **kwargs)