Skip to content

Commit

Permalink
Model builders API update (#1320)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Dmitry Razdoburdin <>
Co-authored-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Co-authored-by: Alexander Andreev <alexander.andreev@intel.com>
  • Loading branch information
3 people committed Jul 12, 2023
1 parent 27d3f39 commit 12b963a
Show file tree
Hide file tree
Showing 14 changed files with 352 additions and 107 deletions.
5 changes: 5 additions & 0 deletions daal4py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@
'[c]sh/psxevars.[c]sh may solve the issue.\n')

raise

from . import mb
from . import sklearn

__all__ = ['mb', 'sklearn']
20 changes: 20 additions & 0 deletions daal4py/mb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/usr/bin/env python
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from .model_builders import GBTDAALBaseModel, convert_model

__all__ = ['GBTDAALBaseModel', 'convert_model']
222 changes: 222 additions & 0 deletions daal4py/mb/model_builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

# daal4py Model builders API

import numpy as np
import daal4py as d4p

try:
from pandas import DataFrame
from pandas.core.dtypes.cast import find_common_type
pandas_is_imported = True
except (ImportError, ModuleNotFoundError):
pandas_is_imported = False


def parse_dtype(dt):
if dt == np.double:
return "double"
if dt == np.single:
return "float"
raise ValueError(f"Input array has unexpected dtype = {dt}")


def getFPType(X):
if pandas_is_imported:
if isinstance(X, DataFrame):
dt = find_common_type(X.dtypes.tolist())
return parse_dtype(dt)

dt = getattr(X, 'dtype', None)
return parse_dtype(dt)


class GBTDAALBaseModel:
def _get_params_from_lightgbm(self, params):
self.n_classes_ = params["num_tree_per_iteration"]
objective_fun = params["objective"]
if self.n_classes_ <= 2:
if "binary" in objective_fun: # nClasses == 1
self.n_classes_ = 2

self.n_features_in_ = params["max_feature_idx"] + 1

def _get_params_from_xgboost(self, params):
self.n_classes_ = int(params["learner"]["learner_model_param"]["num_class"])
objective_fun = params["learner"]["learner_train_param"]["objective"]
if self.n_classes_ <= 2:
if objective_fun in ["binary:logistic", "binary:logitraw"]:
self.n_classes_ = 2

self.n_features_in_ = int(params["learner"]["learner_model_param"]["num_feature"])

def _get_params_from_catboost(self, params):
if 'class_params' in params['model_info']:
self.n_classes_ = len(params['model_info']['class_params']['class_to_label'])
self.n_features_in_ = len(params['features_info']['float_features'])

def _convert_model_from_lightgbm(self, booster):
lgbm_params = d4p.get_lightgbm_params(booster)
self.daal_model_ = d4p.get_gbt_model_from_lightgbm(booster, lgbm_params)
self._get_params_from_lightgbm(lgbm_params)

def _convert_model_from_xgboost(self, booster):
xgb_params = d4p.get_xgboost_params(booster)
self.daal_model_ = d4p.get_gbt_model_from_xgboost(booster, xgb_params)
self._get_params_from_xgboost(xgb_params)

def _convert_model_from_catboost(self, booster):
catboost_params = d4p.get_catboost_params(booster)
self.daal_model_ = d4p.get_gbt_model_from_catboost(booster)
self._get_params_from_catboost(catboost_params)

def _convert_model(self, model):
(submodule_name, class_name) = (model.__class__.__module__,
model.__class__.__name__)
self_class_name = self.__class__.__name__

# Build GBTDAALClassifier from LightGBM
if (submodule_name, class_name) == ("lightgbm.sklearn", "LGBMClassifier"):
if self_class_name == "GBTDAALClassifier":
self._convert_model_from_lightgbm(model.booster_)
else:
raise TypeError(f"Only GBTDAALClassifier can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
# Build GBTDAALClassifier from XGBoost
elif (submodule_name, class_name) == ("xgboost.sklearn", "XGBClassifier"):
if self_class_name == "GBTDAALClassifier":
self._convert_model_from_xgboost(model.get_booster())
else:
raise TypeError(f"Only GBTDAALClassifier can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
# Build GBTDAALClassifier from CatBoost
elif (submodule_name, class_name) == ("catboost.core", "CatBoostClassifier"):
if self_class_name == "GBTDAALClassifier":
self._convert_model_from_catboost(model)
else:
raise TypeError(f"Only GBTDAALClassifier can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
# Build GBTDAALRegressor from LightGBM
elif (submodule_name, class_name) == ("lightgbm.sklearn", "LGBMRegressor"):
if self_class_name == "GBTDAALRegressor":
self._convert_model_from_lightgbm(model.booster_)
else:
raise TypeError(f"Only GBTDAALRegressor can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
# Build GBTDAALRegressor from XGBoost
elif (submodule_name, class_name) == ("xgboost.sklearn", "XGBRegressor"):
if self_class_name == "GBTDAALRegressor":
self._convert_model_from_xgboost(model.get_booster())
else:
raise TypeError(f"Only GBTDAALRegressor can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
# Build GBTDAALRegressor from CatBoost
elif (submodule_name, class_name) == ("catboost.core", "CatBoostRegressor"):
if self_class_name == "GBTDAALRegressor":
self._convert_model_from_catboost(model)
else:
raise TypeError(f"Only GBTDAALRegressor can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
# Build GBTDAALModel from LightGBM
elif (submodule_name, class_name) == ("lightgbm.basic", "Booster"):
if self_class_name == "GBTDAALModel":
self._convert_model_from_lightgbm(model)
else:
raise TypeError(f"Only GBTDAALModel can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
# Build GBTDAALModel from XGBoost
elif (submodule_name, class_name) == ("xgboost.core", "Booster"):
if self_class_name == "GBTDAALModel":
self._convert_model_from_xgboost(model)
else:
raise TypeError(f"Only GBTDAALModel can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
# Build GBTDAALModel from CatBoost
elif (submodule_name, class_name) == ("catboost.core", "CatBoost"):
if self_class_name == "GBTDAALModel":
self._convert_model_from_catboost(model)
else:
raise TypeError(f"Only GBTDAALModel can be created from\
{submodule_name}.{class_name} (got {self_class_name})")
else:
raise TypeError(f"Unknown model format {submodule_name}.{class_name}")

def _predict_classification(self, X, fptype, resultsToEvaluate):
if X.shape[1] != self.n_features_in_:
raise ValueError('Shape of input is different from what was seen in `fit`')

if not hasattr(self, 'daal_model_'):
raise ValueError((
"The class {} instance does not have 'daal_model_' attribute set. "
"Call 'fit' with appropriate arguments before using this method.")
.format(type(self).__name__))

# Prediction
predict_algo = d4p.gbt_classification_prediction(
fptype=fptype,
nClasses=self.n_classes_,
resultsToEvaluate=resultsToEvaluate)
predict_result = predict_algo.compute(X, self.daal_model_)

if resultsToEvaluate == "computeClassLabels":
return predict_result.prediction.ravel().astype(np.int64, copy=False)
else:
return predict_result.probabilities

def _predict_regression(self, X, fptype):
if X.shape[1] != self.n_features_in_:
raise ValueError('Shape of input is different from what was seen in `fit`')

if not hasattr(self, 'daal_model_'):
raise ValueError((
"The class {} instance does not have 'daal_model_' attribute set. "
"Call 'fit' with appropriate arguments before using this method.").format(
type(self).__name__))

# Prediction
predict_algo = d4p.gbt_regression_prediction(fptype=fptype)
predict_result = predict_algo.compute(X, self.daal_model_)

return predict_result.prediction.ravel()


class GBTDAALModel(GBTDAALBaseModel):
def __init__(self):
pass

def predict(self, X):
fptype = getFPType(X)
if self._is_regression:
return self._predict_regression(X, fptype)
else:
return self._predict_classification(X, fptype, "computeClassLabels")

def predict_proba(self, X):
fptype = getFPType(X)
if self._is_regression:
raise NotImplementedError("Can't predict probabilities for regression task")
else:
return self._predict_classification(X, fptype, "computeClassProbabilities")


def convert_model(model):
gbm = GBTDAALModel()
gbm._convert_model(model)

gbm._is_regression = isinstance(gbm.daal_model_, d4p.gbt_regression_model)

return gbm
74 changes: 35 additions & 39 deletions daal4py/sklearn/ensemble/GBTDAAL.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .._utils import getFPType


class GBTDAALBase(BaseEstimator):
class GBTDAALBase(BaseEstimator, d4p.mb.GBTDAALBaseModel):
def __init__(self,
split_method='inexact',
max_iterations=50,
Expand Down Expand Up @@ -101,6 +101,11 @@ def _check_params(self):
raise ValueError('Parameter "min_bin_size" must be '
'non-zero positive integer value.')

allow_nan_ = False

def _more_tags(self):
return {"allow_nan": self.allow_nan_}


class GBTDAALClassifier(GBTDAALBase, ClassifierMixin):
def fit(self, X, y):
Expand Down Expand Up @@ -165,41 +170,28 @@ def fit(self, X, y):
return self

def _predict(self, X, resultsToEvaluate):
# Input validation
if not self.allow_nan_:
X = check_array(X, dtype=[np.single, np.double])
else:
X = check_array(X, dtype=[np.single, np.double], force_all_finite='allow-nan')

# Check is fit had been called
check_is_fitted(self, ['n_features_in_', 'n_classes_'])

# Input validation
X = check_array(X, dtype=[np.single, np.double])
if X.shape[1] != self.n_features_in_:
raise ValueError('Shape of input is different from what was seen in `fit`')

# Trivial case
if self.n_classes_ == 1:
return np.full(X.shape[0], self.classes_[0])

if not hasattr(self, 'daal_model_'):
raise ValueError((
"The class {} instance does not have 'daal_model_' attribute set. "
"Call 'fit' with appropriate arguments before using this method.").format(
type(self).__name__))

# Define type of data
fptype = getFPType(X)

# Prediction
predict_algo = d4p.gbt_classification_prediction(
fptype=fptype,
nClasses=self.n_classes_,
resultsToEvaluate=resultsToEvaluate)
predict_result = predict_algo.compute(X, self.daal_model_)
predict_result = self._predict_classification(X, fptype, resultsToEvaluate)

if resultsToEvaluate == "computeClassLabels":
# Decode labels
le = preprocessing.LabelEncoder()
le.classes_ = self.classes_
return le.inverse_transform(
predict_result.prediction.ravel().astype(np.int64, copy=False))
return predict_result.probabilities
return le.inverse_transform(predict_result)
return predict_result

def predict(self, X):
return self._predict(X, "computeClassLabels")
Expand All @@ -218,6 +210,14 @@ def predict_log_proba(self, X):

return proba

def convert_model(model):
gbm = GBTDAALClassifier()
gbm._convert_model(model)

gbm.classes_ = model.classes_
gbm.allow_nan_ = True
return gbm


class GBTDAALRegressor(GBTDAALBase, RegressorMixin):
def fit(self, X, y):
Expand Down Expand Up @@ -264,25 +264,21 @@ def fit(self, X, y):
return self

def predict(self, X):
# Check is fit had been called
check_is_fitted(self, ['n_features_in_'])

# Input validation
X = check_array(X, dtype=[np.single, np.double])
if X.shape[1] != self.n_features_in_:
raise ValueError('Shape of input is different from what was seen in `fit`')
if not self.allow_nan_:
X = check_array(X, dtype=[np.single, np.double])
else:
X = check_array(X, dtype=[np.single, np.double], force_all_finite='allow-nan')

if not hasattr(self, 'daal_model_'):
raise ValueError((
"The class {} instance does not have 'daal_model_' attribute set. "
"Call 'fit' with appropriate arguments before using this method.").format(
type(self).__name__))
# Check is fit had been called
check_is_fitted(self, ['n_features_in_'])

# Define type of data
fptype = getFPType(X)
return self._predict_regression(X, fptype)

# Prediction
predict_algo = d4p.gbt_regression_prediction(fptype=fptype)
predict_result = predict_algo.compute(X, self.daal_model_)
def convert_model(model):
gbm = GBTDAALRegressor()
gbm._convert_model(model)

return predict_result.prediction.ravel()
gbm.allow_nan_ = True
return gbm
4 changes: 2 additions & 2 deletions daal4py/sklearn/ensemble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
from .GBTDAAL import (GBTDAALClassifier, GBTDAALRegressor)
from .AdaBoostClassifier import AdaBoostClassifier

__all__ = ['RandomForestClassifier', 'RandomForestRegressor', 'GBTDAALClassifier',
'GBTDAALRegressor', 'AdaBoostClassifier']
__all__ = ['RandomForestClassifier', 'RandomForestRegressor',
'GBTDAALClassifier', 'GBTDAALRegressor', 'AdaBoostClassifier']
6 changes: 3 additions & 3 deletions doc/daal4py/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ General usage

Building models from Gradient Boosting frameworks

- `XGBoost* model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/gbt_cls_model_create_from_xgboost_batch.py>`_
- `LightGBM* model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/gbt_cls_model_create_from_lightgbm_batch.py>`_
- `CatBoost* model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/gbt_cls_model_create_from_catboost_batch.py>`_
- `XGBoost* model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/model_builders_xgboost.py>`_
- `LightGBM* model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/model_builders_lightgbm.py>`_
- `CatBoost* model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/model_builders_catboost.py>`_


Principal Component Analysis (PCA) Transform
Expand Down

0 comments on commit 12b963a

Please sign in to comment.