Skip to content

Commit

Permalink
Add ElasticNet, Lasso and Ridge support (#625)
Browse files Browse the repository at this point in the history
Using the same function as for LinearRegression this adds support for
ElasticNet, Lasso and Ridge.

closes #624
  • Loading branch information
fd0r committed Aug 29, 2022
1 parent 77b1561 commit 0a1f21e
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 6 deletions.
5 changes: 4 additions & 1 deletion hummingbird/ml/operator_converters/sklearn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def convert_sklearn_linear_model(operator, device, extra_config):

def convert_sklearn_linear_regression_model(operator, device, extra_config):
"""
Converter for `sklearn.linear_model.LinearRegression`, `sklearn.svm.LinearSVR` and `sklearn.linear_model.RidgeCV`
Converter for `sklearn.linear_model.LinearRegression`, `sklearn.linear_model.Lasso`, `sklearn.linear_model.ElasticNet`, `sklearn.linear_model.Ridge`, `sklearn.svm.LinearSVR` and `sklearn.linear_model.RidgeCV`
Args:
operator: An operator wrapping a `sklearn.linear_model.LinearRegression`, `sklearn.svm.LinearSVR`
Expand All @@ -85,6 +85,9 @@ def convert_sklearn_linear_regression_model(operator, device, extra_config):


register_converter("SklearnLinearRegression", convert_sklearn_linear_regression_model)
register_converter("SklearnLasso", convert_sklearn_linear_regression_model)
register_converter("SklearnElasticNet", convert_sklearn_linear_regression_model)
register_converter("SklearnRidge", convert_sklearn_linear_regression_model)
register_converter("SklearnLogisticRegression", convert_sklearn_linear_model)
register_converter("SklearnLinearSVC", convert_sklearn_linear_model)
register_converter("SklearnLinearSVR", convert_sklearn_linear_regression_model)
Expand Down
14 changes: 13 additions & 1 deletion hummingbird/ml/supported.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,16 @@ def _build_sklearn_operator_list():
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Linear-based models
from sklearn.linear_model import LinearRegression, LogisticRegression, LogisticRegressionCV, SGDClassifier, RidgeCV
from sklearn.linear_model import (
LinearRegression,
LogisticRegression,
LogisticRegressionCV,
SGDClassifier,
RidgeCV,
ElasticNet,
Ridge,
Lasso,
)

# SVM-based models
from sklearn.svm import LinearSVC, SVC, NuSVC, LinearSVR
Expand Down Expand Up @@ -211,6 +220,9 @@ def _build_sklearn_operator_list():
LogisticRegressionCV,
SGDClassifier,
RidgeCV,
Lasso,
ElasticNet,
Ridge,
# Clustering
KMeans,
MeanShift,
Expand Down
89 changes: 88 additions & 1 deletion tests/test_sklearn_linear_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@

import numpy as np
import torch
from sklearn.linear_model import LinearRegression, LogisticRegression, SGDClassifier, LogisticRegressionCV, RidgeCV
from sklearn.linear_model import (
LinearRegression,
LogisticRegression,
SGDClassifier,
LogisticRegressionCV,
RidgeCV,
Lasso,
ElasticNet,
Ridge,
)
from sklearn import datasets

import hummingbird.ml
Expand Down Expand Up @@ -111,6 +120,84 @@ def test_linear_regression_float(self):
np.random.seed(0)
self._test_linear_regression(np.random.rand(100))

# Lasso test function to be parameterized
def _test_lasso(self, y_input):
model = Lasso()

np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = y_input

model.fit(X, y)

torch_model = hummingbird.ml.convert(model, "torch")

self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)

# Lasso with ints
def test_lasso_int(self):
np.random.seed(0)
self._test_lasso(np.random.randint(2, size=100))

# Lasso with floats
def test_lasso_float(self):
np.random.seed(0)
self._test_lasso(np.random.rand(100))

# Ridge test function to be parameterized
def _test_ridge(self, y_input):
model = Ridge()

np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = y_input

model.fit(X, y)

torch_model = hummingbird.ml.convert(model, "torch")

self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)

# Ridge with ints
def test_ridge_int(self):
np.random.seed(0)
self._test_ridge(np.random.randint(2, size=100))

# Ridge with floats
def test_ridge_float(self):
np.random.seed(0)
self._test_ridge(np.random.rand(100))

# ElasticNet test function to be parameterized
def _test_elastic_net(self, y_input):
model = ElasticNet()

np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = y_input

model.fit(X, y)

torch_model = hummingbird.ml.convert(model, "torch")

self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)

# ElasticNet with ints
def test_elastic_net_int(self):
np.random.seed(0)
self._test_elastic_net(np.random.randint(2, size=100))

# ElasticNet with floats
def test_elastic_net_float(self):
np.random.seed(0)
self._test_elastic_net(np.random.rand(100))

# RidgeCV test function to be parameterized
def _test_ridge_cv(self, y_input):
model = RidgeCV()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_sklearn_sv_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_linear_svc_multi(self):
def test_linear_svc_shifted(self):
self._test_linear_svc(3, labels_shift=2)

# RidgeCV test function to be parameterized
# LinearSVR test function to be parameterized
def _test_svr(self, y_input):
model = LinearSVR()

Expand All @@ -53,12 +53,12 @@ def _test_svr(self, y_input):
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)

# RidgeCV with ints
# LinearSVR with ints
def test_svr_int(self):
np.random.seed(0)
self._test_svr(np.random.randint(2, size=100))

# RidgeCV with floats
# LinearSVR with floats
def test_svr_float(self):
np.random.seed(0)
self._test_svr(np.random.rand(100))
Expand Down

0 comments on commit 0a1f21e

Please sign in to comment.