Skip to content

Commit

Permalink
Add more linear models (#216)
Browse files Browse the repository at this point in the history
* convert more linear classifier
* add RidgeCV
  • Loading branch information
xadupre committed Jul 11, 2019
1 parent 8009430 commit edf8eb4
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 25 deletions.
38 changes: 27 additions & 11 deletions skl2onnx/_supported_operators.py
Expand Up @@ -11,16 +11,18 @@
from sklearn.calibration import CalibratedClassifierCV

# Linear classifiers
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.linear_model import SGDClassifier
from sklearn.svm import LinearSVC

# Linear regressors
from sklearn.linear_model import ElasticNet
from sklearn.linear_model import Lasso
from sklearn.linear_model import LassoLars
from sklearn.linear_model import ElasticNet, ElasticNetCV
from sklearn.linear_model import Lars, LarsCV
from sklearn.linear_model import Lasso, LassoCV
from sklearn.linear_model import LassoLars, LassoLarsCV
from sklearn.linear_model import LassoLarsIC
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Ridge
from sklearn.linear_model import Ridge, RidgeCV
from sklearn.linear_model import SGDRegressor
from sklearn.svm import LinearSVR

Expand Down Expand Up @@ -105,9 +107,12 @@
# included in the following list and one output for everything not in
# the list.
sklearn_classifier_list = [
LogisticRegression, SGDClassifier, LinearSVC, SVC, NuSVC,
GradientBoostingClassifier, RandomForestClassifier, DecisionTreeClassifier,
ExtraTreesClassifier, BernoulliNB, MultinomialNB, KNeighborsClassifier,
LogisticRegression, LogisticRegressionCV, SGDClassifier,
LinearSVC, SVC, NuSVC,
GradientBoostingClassifier, RandomForestClassifier,
DecisionTreeClassifier,
ExtraTreesClassifier, BernoulliNB, MultinomialNB,
KNeighborsClassifier,
CalibratedClassifierCV, OneVsRestClassifier, VotingClassifier,
AdaBoostClassifier, MLPClassifier
]
Expand All @@ -129,7 +134,7 @@ def build_sklearn_operator_name_map():
GradientBoostingClassifier, GradientBoostingRegressor,
KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors,
LinearSVC, LinearSVR, SVC, SVR,
LinearRegression, Lasso, LassoLars, Ridge,
LinearRegression,
MLPClassifier, MLPRegressor,
MultinomialNB, BernoulliNB,
OneVsRestClassifier,
Expand All @@ -143,14 +148,25 @@ def build_sklearn_operator_name_map():
RobustScaler, OneHotEncoder, DictVectorizer,
GenericUnivariateSelect, RFE, RFECV, SelectFdr, SelectFpr,
SelectFromModel, SelectFwe, SelectKBest, SelectPercentile,
VarianceThreshold,
VarianceThreshold
] if k is not None}
res.update({
ElasticNet: 'SklearnElasticNetRegressor',
ElasticNet: 'SklearnLinearRegressor',
ElasticNetCV: 'SklearnLinearRegressor',
LinearRegression: 'SklearnLinearRegressor',
Lars: 'SklearnLinearRegressor',
LarsCV: 'SklearnLinearRegressor',
Lasso: 'SklearnLinearRegressor',
LassoCV: 'SklearnLinearRegressor',
LassoLars: 'SklearnLinearRegressor',
LassoLarsCV: 'SklearnLinearRegressor',
LassoLarsIC: 'SklearnLinearRegressor',
LogisticRegression: 'SklearnLinearClassifier',
LogisticRegressionCV: 'SklearnLinearClassifier',
NuSVC: 'SklearnSVC',
NuSVR: 'SklearnSVR',
Ridge: 'SklearnLinearRegressor',
RidgeCV: 'SklearnLinearRegressor',
SGDRegressor: 'SklearnLinearRegressor',
StandardScaler: 'SklearnScaler',
})
Expand Down
3 changes: 2 additions & 1 deletion skl2onnx/operator_converters/linear_classifier.py
Expand Up @@ -7,6 +7,7 @@
import numbers
import numpy as np
import six
from sklearn.linear_model import LogisticRegression
from ..common._registration import register_converter
from ..proto import onnx_proto

Expand Down Expand Up @@ -46,7 +47,7 @@ def convert_sklearn_linear_classifier(scope, operator, container):
classifier_attrs['multi_class'] = 1 if multi_class == 2 else 0
if op.__class__.__name__ == 'LinearSVC':
classifier_attrs['post_transform'] = 'NONE'
elif op.__class__.__name__ == 'LogisticRegression':
elif isinstance(op, LogisticRegression):
ovr = (op.multi_class in ["ovr", "warn"] or
(op.multi_class == 'auto' and (op.classes_.size <= 2 or
op.solver == 'liblinear')))
Expand Down
5 changes: 0 additions & 5 deletions skl2onnx/operator_converters/linear_regressor.py
Expand Up @@ -34,10 +34,5 @@ def convert_sklearn_linear_regressor(scope, operator, container):
**attrs)


register_converter('SklearnElasticNetRegressor',
convert_sklearn_linear_regressor)
register_converter('SklearnLasso', convert_sklearn_linear_regressor)
register_converter('SklearnLassoLars', convert_sklearn_linear_regressor)
register_converter('SklearnLinearRegressor', convert_sklearn_linear_regressor)
register_converter('SklearnLinearSVR', convert_sklearn_linear_regressor)
register_converter('SklearnRidge', convert_sklearn_linear_regressor)
8 changes: 0 additions & 8 deletions skl2onnx/shape_calculators/linear_regressor.py
Expand Up @@ -10,8 +10,6 @@

register_shape_calculator('SklearnAdaBoostRegressor',
calculate_linear_regressor_output_shapes)
register_shape_calculator('SklearnElasticNetRegressor',
calculate_linear_regressor_output_shapes)
register_shape_calculator('SklearnLinearRegressor',
calculate_linear_regressor_output_shapes)
register_shape_calculator('SklearnLinearSVR',
Expand All @@ -26,11 +24,5 @@
calculate_linear_regressor_output_shapes)
register_shape_calculator('SklearnKNeighborsRegressor',
calculate_linear_regressor_output_shapes)
register_shape_calculator('SklearnLasso',
calculate_linear_regressor_output_shapes)
register_shape_calculator('SklearnLassoLars',
calculate_linear_regressor_output_shapes)
register_shape_calculator('SklearnMLPRegressor',
calculate_linear_regressor_output_shapes)
register_shape_calculator('SklearnRidge',
calculate_linear_regressor_output_shapes)
20 changes: 20 additions & 0 deletions tests/test_sklearn_glm_classifier_converter.py
Expand Up @@ -53,6 +53,26 @@ def test_model_logistic_regression_binary_class(self):
" <= StrictVersion('0.2.1')",
)

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_logistic_regression_cv_binary_class(self):
model, X = self._fit_model_binary_classification(
linear_model.LogisticRegressionCV())
model_onnx = convert_sklearn(model, "logistic regression cv",
[("input", FloatTensorType([1, 3]))])
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X.astype(numpy.float32),
model,
model_onnx,
basename="SklearnLogitisticCVRegressionBinary",
# Operator cast-1 is not implemented in onnxruntime
allow_failure="StrictVersion(onnx.__version__)"
" < StrictVersion('1.3') or "
"StrictVersion(onnxruntime.__version__)"
" <= StrictVersion('0.2.1')",
)

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_logistic_regression_binary_class_nointercept(self):
Expand Down
93 changes: 93 additions & 0 deletions tests/test_sklearn_glm_regressor_converter.py
Expand Up @@ -188,6 +188,24 @@ def test_model_elastic_net_regressor(self):
"<= StrictVersion('0.2.1')",
)

def test_model_elastic_net_cv_regressor(self):
model, X = _fit_model(linear_model.ElasticNetCV())
model_onnx = convert_sklearn(
model,
"scikit-learn elastic-net regression",
[("input", FloatTensorType(X.shape))],
)
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X.astype(numpy.float32),
model,
model_onnx,
basename="SklearnElasticNetCV-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

def test_model_elastic_net_regressor_int(self):
model, X = _fit_model(linear_model.ElasticNet(), is_int=True)
model_onnx = convert_sklearn(model, "elastic net regression",
Expand All @@ -203,6 +221,36 @@ def test_model_elastic_net_regressor_int(self):
"<= StrictVersion('0.2.1')",
)

def test_model_lars(self):
model, X = _fit_model(linear_model.Lars())
model_onnx = convert_sklearn(model, "lars",
[("input", FloatTensorType(X.shape))])
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X.astype(numpy.float32),
model,
model_onnx,
basename="SklearnLars-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

def test_model_lars_cv(self):
model, X = _fit_model(linear_model.LarsCV())
model_onnx = convert_sklearn(model, "lars",
[("input", FloatTensorType(X.shape))])
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X.astype(numpy.float32),
model,
model_onnx,
basename="SklearnLarsCV-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

def test_model_lasso_lars(self):
model, X = _fit_model(linear_model.LassoLars(alpha=0.01))
model_onnx = convert_sklearn(model, "lasso lars",
Expand All @@ -218,6 +266,51 @@ def test_model_lasso_lars(self):
"<= StrictVersion('0.2.1')",
)

def test_model_lasso_lars_cv(self):
model, X = _fit_model(linear_model.LassoLarsCV())
model_onnx = convert_sklearn(model, "lasso lars cv",
[("input", FloatTensorType(X.shape))])
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X.astype(numpy.float32),
model,
model_onnx,
basename="SklearnLassoLarsCV-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

def test_model_lasso_lars_ic(self):
model, X = _fit_model(linear_model.LassoLarsIC())
model_onnx = convert_sklearn(model, "lasso lars cv",
[("input", FloatTensorType(X.shape))])
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X.astype(numpy.float32),
model,
model_onnx,
basename="SklearnLassoLarsIC-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

def test_model_lasso_cv(self):
model, X = _fit_model(linear_model.LassoCV())
model_onnx = convert_sklearn(model, "lasso cv",
[("input", FloatTensorType(X.shape))])
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X.astype(numpy.float32),
model,
model_onnx,
basename="SklearnLassoCV-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

def test_model_lasso_lars_int(self):
model, X = _fit_model(linear_model.LassoLars(), is_int=True)
model_onnx = convert_sklearn(model, "lasso lars",
Expand Down

0 comments on commit edf8eb4

Please sign in to comment.