diff --git a/README.md b/README.md
index 2a1818d..f66f918 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ This library is in beta, and currently not all models are supported. The library
| [Ridge](sklearn_pmml_model/linear_model) | ✅2 | ✅ | ✅3 |
| [Lasso](sklearn_pmml_model/linear_model) | ✅2 | ✅ | ✅3 |
| [ElasticNet](sklearn_pmml_model/linear_model) | ✅2 | ✅ | ✅ |
-| [Gaussian Naive Bayes](sklearn_pmml_model/naive_bayes) | ✅ | | |
+| [Gaussian Naive Bayes](sklearn_pmml_model/naive_bayes) | ✅ | | ✅3 |
1 Categorical feature support using slightly modified internals, based on [scikit-learn#12866](https://github.com/scikit-learn/scikit-learn/pull/12866).
diff --git a/models/tree-iris.pmml b/models/tree-iris.pmml
index b8ace33..f2f2fbf 100644
--- a/models/tree-iris.pmml
+++ b/models/tree-iris.pmml
@@ -1,396 +1,111 @@
-
+
-
- 2018-06-18T14:47:30Z
+
+ 2021-07-06T10:18:03Z
- PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
- max_features=None, max_leaf_nodes=None,
- min_impurity_decrease=0.0, min_impurity_split=None,
- min_samples_leaf=1, min_samples_split=2,
- min_weight_fraction_leaf=0.0, presort=False, random_state=1,
- splitter='best'))])
+ PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(random_state=1))])
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
+
+
+
+
-
-
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
diff --git a/sklearn_pmml_model/base.py b/sklearn_pmml_model/base.py
index c50b7ac..c2dc373 100644
--- a/sklearn_pmml_model/base.py
+++ b/sklearn_pmml_model/base.py
@@ -1,11 +1,13 @@
from sklearn.base import BaseEstimator
-from sklearn.preprocessing import LabelBinarizer
+from sklearn.preprocessing import LabelBinarizer, OneHotEncoder
+from sklearn.compose import ColumnTransformer
from xml.etree import cElementTree as eTree
from cached_property import cached_property
from sklearn_pmml_model.datatypes import Category
from collections import OrderedDict
import datetime
import numpy as np
+import pandas as pd
class PMMLBaseEstimator(BaseEstimator):
@@ -137,13 +139,18 @@ def fit(self, x, y):
raise Exception('Not supported.')
def _prepare_data(self, X):
- X = np.asarray(X)
+ pmml_features = [f for f,e in self.fields.items() if e is not self.target_field and e.tag == 'DataField']
- for column, (index, field_type) in self.field_mapping.items():
- if type(field_type) is Category and index is not None and type(X[0,index]) is str:
- categories = [str(v) for v in field_type.categories]
- categories += [c for c in np.unique(X[:,index]) if c not in categories]
- X[:,index] = [categories.index(x) for x in X[:,index]]
+ if isinstance(X, pd.DataFrame):
+ X.columns = X.columns.map(str)
+
+ try:
+ X = X[pmml_features]
+ except KeyError:
+ raise Exception('The features in the input data do not match features expected by the PMML model.')
+ elif X.shape[1] != len(pmml_features):
+ raise Exception('The number of features in provided data does not match expected number of features in the PMML. '
+ 'Provide pandas.Dataframe, or provide data matching the DataFields in the PMML document.')
return X
@@ -258,3 +265,60 @@ def findall(element, path):
if element is None:
return []
return element.findall(path)
+
+
+class OneHotEncodingMixin:
+ """
+ Mixin class to automatically one-hot encode categorical variables.
+
+ """
+ def __init__(self):
+ # Setup a column transformer to encode categorical variables
+ target = self.target_field.get('name')
+ fields = [field for name, field in self.fields.items() if name != target]
+
+ def encoder_for(field):
+ if field.get('optype') != 'categorical':
+ return 'passthrough'
+
+ encoder = OneHotEncoder()
+ encoder.categories_ = np.array([self.field_mapping[field.get('name')][1].categories])
+ encoder.drop_idx_ = np.array([None for x in encoder.categories_])
+ encoder._legacy_mode = False
+ return encoder
+
+ transformer = ColumnTransformer(
+ transformers=[
+ (field.get('name'), encoder_for(field), [self.field_mapping[field.get('name')][0]])
+ for field in fields
+ if field.tag == 'DataField'
+ ]
+ )
+
+ X = np.array([[0 for field in fields if field.tag == "DataField"]])
+ transformer._validate_transformers()
+ transformer._validate_column_callables(X)
+ transformer._validate_remainder(X)
+ transformer.transformers_ = transformer.transformers
+ transformer.sparse_output_ = False
+ transformer._feature_names_in = None
+
+ self.transformer = transformer
+
+ def _prepare_data(self, X):
+ X = super()._prepare_data(X)
+ return self.transformer.transform(X)
+
+
+class IntegerEncodingMixin:
+ def _prepare_data(self, X):
+ X = super()._prepare_data(X)
+ X = np.asarray(X)
+
+ for column, (index, field_type) in self.field_mapping.items():
+ if type(field_type) is Category and index is not None and type(X[0, index]) is str:
+ categories = [str(v) for v in field_type.categories]
+ categories += [c for c in np.unique(X[:, index]) if c not in categories]
+ X[:, index] = [categories.index(x) for x in X[:, index]]
+
+ return X
\ No newline at end of file
diff --git a/sklearn_pmml_model/ensemble/forest.py b/sklearn_pmml_model/ensemble/forest.py
index 771e1fe..e29f37b 100644
--- a/sklearn_pmml_model/ensemble/forest.py
+++ b/sklearn_pmml_model/ensemble/forest.py
@@ -1,11 +1,11 @@
import numpy as np
import warnings
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
-from sklearn_pmml_model.base import PMMLBaseClassifier, PMMLBaseRegressor
+from sklearn_pmml_model.base import PMMLBaseClassifier, PMMLBaseRegressor, IntegerEncodingMixin
from sklearn_pmml_model.tree import get_tree
-class PMMLForestClassifier(PMMLBaseClassifier, RandomForestClassifier):
+class PMMLForestClassifier(IntegerEncodingMixin, PMMLBaseClassifier, RandomForestClassifier):
"""
A random forest classifier.
@@ -92,7 +92,7 @@ def _more_tags(self):
return RandomForestClassifier._more_tags(self)
-class PMMLForestRegressor(PMMLBaseRegressor, RandomForestRegressor):
+class PMMLForestRegressor(IntegerEncodingMixin, PMMLBaseRegressor, RandomForestRegressor):
"""
A random forest regressor.
diff --git a/sklearn_pmml_model/ensemble/gb.py b/sklearn_pmml_model/ensemble/gb.py
index a1ab5ee..8cdd4cd 100644
--- a/sklearn_pmml_model/ensemble/gb.py
+++ b/sklearn_pmml_model/ensemble/gb.py
@@ -3,13 +3,13 @@
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor, _gb_losses
-from sklearn_pmml_model.base import PMMLBaseClassifier, PMMLBaseRegressor
+from sklearn_pmml_model.base import PMMLBaseClassifier, PMMLBaseRegressor, IntegerEncodingMixin
from sklearn_pmml_model.tree import get_tree
from scipy.special import expit
from ._gradient_boosting import predict_stages
-class PMMLGradientBoostingClassifier(PMMLBaseClassifier, GradientBoostingClassifier, ABC):
+class PMMLGradientBoostingClassifier(IntegerEncodingMixin, PMMLBaseClassifier, GradientBoostingClassifier, ABC):
"""
Gradient Boosting for classification.
@@ -135,7 +135,7 @@ def _more_tags(self):
return GradientBoostingClassifier._more_tags(self)
-class PMMLGradientBoostingRegressor(PMMLBaseRegressor, GradientBoostingRegressor, ABC):
+class PMMLGradientBoostingRegressor(IntegerEncodingMixin, PMMLBaseRegressor, GradientBoostingRegressor, ABC):
"""
Gradient Boosting for regression.
diff --git a/sklearn_pmml_model/linear_model/base.py b/sklearn_pmml_model/linear_model/base.py
index 75870fd..6211a68 100644
--- a/sklearn_pmml_model/linear_model/base.py
+++ b/sklearn_pmml_model/linear_model/base.py
@@ -1,107 +1,9 @@
-from sklearn_pmml_model.base import PMMLBaseRegressor, PMMLBaseClassifier
-from sklearn.preprocessing import OneHotEncoder
-from sklearn.compose import ColumnTransformer
+from sklearn_pmml_model.base import PMMLBaseRegressor, PMMLBaseClassifier, OneHotEncodingMixin
import numpy as np
from itertools import chain
-class PMMLLinearModel(PMMLBaseRegressor):
- """
- Abstract class for linear models.
-
- """
- def __init__(self, pmml):
- PMMLBaseRegressor.__init__(self, pmml)
-
- # Setup a column transformer to deal with categorical variables
- target = self.target_field.get('name')
- fields = [field for name, field in self.fields.items() if name != target]
-
- def encoder_for(field):
- if field.get('optype') != 'categorical':
- return 'passthrough'
-
- encoder = OneHotEncoder()
- encoder.categories_ = np.array([self.field_mapping[field.get('name')][1].categories])
- encoder.drop_idx_ = np.array([None for x in encoder.categories_])
- encoder._legacy_mode = False
- return encoder
-
- transformer = ColumnTransformer(
- transformers=[
- (field.get('name'), encoder_for(field), [self.field_mapping[field.get('name')][0]])
- for field in fields
- if field.tag == 'DataField'
- ]
- )
-
- X = np.array([[0 for field in fields if field.tag == "DataField"]])
- transformer._validate_transformers()
- transformer._validate_column_callables(X)
- transformer._validate_remainder(X)
- transformer.transformers_ = transformer.transformers
- transformer.sparse_output_ = False
- transformer._feature_names_in = None
-
- self.transformer = transformer
-
- def _prepare_data(self, X):
- """
- Overrides the default data preparation operation by one-hot encoding
- categorical variables.
- """
- return self.transformer.transform(X)
-
-
-class PMMLLinearClassifier(PMMLBaseClassifier):
- """
- Abstract class for linear models.
-
- """
- def __init__(self, pmml):
- PMMLBaseClassifier.__init__(self, pmml)
-
- # Setup a column transformer to deal with categorical variables
- target = self.target_field.get('name')
- fields = [field for name, field in self.fields.items() if name != target]
-
- def encoder_for(field):
- if field.get('optype') != 'categorical':
- return 'passthrough'
-
- encoder = OneHotEncoder()
- encoder.categories_ = np.array([self.field_mapping[field.get('name')][1].categories])
- encoder.drop_idx_ = np.array([None for x in encoder.categories_])
- encoder._legacy_mode = False
- return encoder
-
- transformer = ColumnTransformer(
- transformers=[
- (field.get('name'), encoder_for(field), [self.field_mapping[field.get('name')][0]])
- for field in fields
- if field.tag == 'DataField'
- ]
- )
-
- X = np.array([[0 for field in fields if field.tag == "DataField"]])
- transformer._validate_transformers()
- transformer._validate_column_callables(X)
- transformer._validate_remainder(X)
- transformer.transformers_ = transformer.transformers
- transformer.sparse_output_ = False
- transformer._feature_names_in = None
-
- self.transformer = transformer
-
- def _prepare_data(self, X):
- """
- Overrides the default data preparation operation by one-hot encoding
- categorical variables.
- """
- return self.transformer.transform(X)
-
-
-class PMMLGeneralizedLinearRegressor(PMMLLinearModel):
+class PMMLGeneralizedLinearRegressor(OneHotEncodingMixin, PMMLBaseRegressor):
"""
Abstract class for Generalized Linear Models (GLMs).
@@ -122,7 +24,8 @@ class PMMLGeneralizedLinearRegressor(PMMLLinearModel):
"""
def __init__(self, pmml):
- PMMLLinearModel.__init__(self, pmml)
+ PMMLBaseRegressor.__init__(self, pmml)
+ OneHotEncodingMixin.__init__(self)
# Import coefficients and intercepts
model = self.root.find('GeneralRegressionModel')
@@ -134,7 +37,7 @@ def __init__(self, pmml):
self.intercept_ = _get_intercept(model)
-class PMMLGeneralizedLinearClassifier(PMMLLinearClassifier):
+class PMMLGeneralizedLinearClassifier(OneHotEncodingMixin, PMMLBaseClassifier):
"""
Abstract class for Generalized Linear Models (GLMs).
@@ -155,7 +58,8 @@ class PMMLGeneralizedLinearClassifier(PMMLLinearClassifier):
"""
def __init__(self, pmml):
- PMMLLinearClassifier.__init__(self, pmml)
+ PMMLBaseClassifier.__init__(self, pmml)
+ OneHotEncodingMixin.__init__(self)
# Import coefficients and intercepts
model = self.root.find('GeneralRegressionModel')
diff --git a/sklearn_pmml_model/linear_model/implementations.py b/sklearn_pmml_model/linear_model/implementations.py
index 9109899..457ee67 100644
--- a/sklearn_pmml_model/linear_model/implementations.py
+++ b/sklearn_pmml_model/linear_model/implementations.py
@@ -1,11 +1,11 @@
from sklearn.linear_model import LinearRegression, Ridge, RidgeClassifier, Lasso, ElasticNet, LogisticRegression
-from sklearn_pmml_model.linear_model.base import PMMLLinearModel, PMMLLinearClassifier, PMMLGeneralizedLinearRegressor,\
- PMMLGeneralizedLinearClassifier
+from sklearn_pmml_model.base import PMMLBaseRegressor, PMMLBaseClassifier, OneHotEncodingMixin
+from sklearn_pmml_model.linear_model.base import PMMLGeneralizedLinearRegressor, PMMLGeneralizedLinearClassifier
from itertools import chain
import numpy as np
-class PMMLLinearRegression(PMMLLinearModel, LinearRegression):
+class PMMLLinearRegression(OneHotEncodingMixin, PMMLBaseRegressor, LinearRegression):
"""
Ordinary least squares Linear Regression.
@@ -25,7 +25,8 @@ class PMMLLinearRegression(PMMLLinearModel, LinearRegression):
"""
def __init__(self, pmml):
- PMMLLinearModel.__init__(self, pmml)
+ PMMLBaseRegressor.__init__(self, pmml)
+ OneHotEncodingMixin.__init__(self)
# Import coefficients and intercepts
model = self.root.find('RegressionModel')
@@ -51,13 +52,13 @@ def __init__(self, pmml):
self.intercept_ = self.intercept_[0]
def fit(self, x, y):
- return PMMLLinearModel.fit(self, x, y)
+ return PMMLBaseRegressor.fit(self, x, y)
def _more_tags(self):
return LinearRegression._more_tags(self)
-class PMMLLogisticRegression(PMMLLinearClassifier, LogisticRegression):
+class PMMLLogisticRegression(OneHotEncodingMixin, PMMLBaseClassifier, LogisticRegression):
"""
Logistic Regression (aka logit, MaxEnt) classifier.
@@ -77,7 +78,8 @@ class PMMLLogisticRegression(PMMLLinearClassifier, LogisticRegression):
"""
def __init__(self, pmml):
- PMMLLinearClassifier.__init__(self, pmml)
+ PMMLBaseClassifier.__init__(self, pmml)
+ OneHotEncodingMixin.__init__(self)
# Import coefficients and intercepts
model = self.root.find('RegressionModel')
@@ -111,7 +113,7 @@ def __init__(self, pmml):
self.solver = 'lbfgs'
def fit(self, x, y):
- return PMMLLinearClassifier.fit(self, x, y)
+ return PMMLBaseClassifier.fit(self, x, y)
def _more_tags(self):
return LogisticRegression._more_tags(self)
diff --git a/sklearn_pmml_model/naive_bayes/implementations.py b/sklearn_pmml_model/naive_bayes/implementations.py
index a94d7c5..5cc91d0 100644
--- a/sklearn_pmml_model/naive_bayes/implementations.py
+++ b/sklearn_pmml_model/naive_bayes/implementations.py
@@ -1,10 +1,10 @@
-from sklearn_pmml_model.base import PMMLBaseClassifier
+from sklearn_pmml_model.base import PMMLBaseClassifier, OneHotEncodingMixin
from sklearn.naive_bayes import GaussianNB
import numpy as np
from itertools import chain
-class PMMLGaussianNB(PMMLBaseClassifier, GaussianNB):
+class PMMLGaussianNB(OneHotEncodingMixin, PMMLBaseClassifier, GaussianNB):
"""
Gaussian Naive Bayes (GaussianNB)
@@ -26,6 +26,7 @@ class PMMLGaussianNB(PMMLBaseClassifier, GaussianNB):
"""
def __init__(self, pmml):
PMMLBaseClassifier.__init__(self, pmml)
+ OneHotEncodingMixin.__init__(self)
model = self.root.find('NaiveBayesModel')
diff --git a/tests/naive_bayes/test_naive_bayes.py b/tests/naive_bayes/test_naive_bayes.py
index 7376ea4..e31e513 100644
--- a/tests/naive_bayes/test_naive_bayes.py
+++ b/tests/naive_bayes/test_naive_bayes.py
@@ -74,15 +74,16 @@ class TestGaussianNBIntegration(TestCase):
def setUp(self):
df = pd.read_csv(path.join(BASE_DIR, '../models/categorical-test.csv'))
Xte = df.iloc[:, 1:]
- Xte = pd.get_dummies(Xte, prefix_sep='')
+ Xenc = pd.get_dummies(Xte, prefix_sep='')
yte = df.iloc[:, 0]
self.test = (Xte, yte)
+ self.enc = (Xenc, yte)
pmml = path.join(BASE_DIR, '../models/nb-cat-pima.pmml')
self.clf = PMMLGaussianNB(pmml)
self.ref = GaussianNB()
- self.ref.fit(Xte, yte)
+ self.ref.fit(Xenc, yte)
def test_predict_proba(self):
Xte, _ = self.test
@@ -110,7 +111,7 @@ def test_sklearn2pmml(self):
pipeline = PMMLPipeline([
("classifier", self.ref)
])
- pipeline.fit(self.test[0], self.test[1])
+ pipeline.fit(self.enc[0], self.enc[1])
sklearn2pmml(pipeline, "gnb-sklearn2pmml.pmml", with_repr = True)
try:
@@ -118,10 +119,10 @@ def test_sklearn2pmml(self):
model = PMMLGaussianNB(pmml='gnb-sklearn2pmml.pmml')
# Verify classification
- Xte, _ = self.test
- assert np.array_equal(
- self.ref.predict_proba(Xte),
- model.predict_proba(Xte)
+ Xenc, _ = self.enc
+ assert np.allclose(
+ self.ref.predict_proba(Xenc),
+ model.predict_proba(Xenc)
)
finally:
@@ -152,7 +153,7 @@ def test_sklearn2pmml(self):
# Verify classification
Xte, _ = self.test
- assert np.array_equal(
+ assert np.allclose(
self.ref.predict_proba(Xte),
model.predict_proba(Xte)
)
diff --git a/tests/test_base.py b/tests/test_base.py
index 22f3436..92b044b 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -221,3 +221,102 @@ def test_fit_exception(self):
clf.fit(X, y)
assert str(cm.exception) == "Not supported."
+
+ def test_prepare_data_removes_unused_columns(self):
+ clf = PMMLBaseEstimator(pmml=StringIO("""
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """))
+
+ X = pd.DataFrame(data=[[1, 2], [3, 4], [5, 6]], columns=["test1", "test2"])
+ result = clf._prepare_data(X)
+
+ assert list(X.columns) == ["test1", "test2"]
+ assert list(result.columns) == ["test1"]
+
+ def test_prepare_data_reorders_columns(self):
+ clf = PMMLBaseEstimator(pmml=StringIO("""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """))
+
+ X = pd.DataFrame(data=[[1, 2], [3, 4], [5, 6]], columns=["test1", "test2"])
+ result = clf._prepare_data(X)
+
+ assert list(X.columns) == ["test1", "test2"]
+ assert list(result.columns) == ["test2", "test1"]
+
+ def test_prepare_data_exception_mismatch_columns_numpy(self):
+ clf = PMMLBaseEstimator(pmml=StringIO("""
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """))
+
+ X = pd.DataFrame(data=[[1, 2], [3, 4], [5, 6]], columns=["test1", "test2"])
+
+ with self.assertRaises(Exception) as cm:
+ clf._prepare_data(np.asanyarray(X))
+
+ assert str(cm.exception) == "The number of features in provided data does not match expected number of features " \
+ "in the PMML. Provide pandas.Dataframe, or provide data matching the DataFields in " \
+ "the PMML document."
+
+ def test_prepare_data_exception_mismatch_columns_pandas(self):
+ clf = PMMLBaseEstimator(pmml=StringIO("""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """))
+
+ X = pd.DataFrame(data=[[1, 2], [3, 4], [5, 6]], columns=["Test_1", "Test_2"])
+
+ with self.assertRaises(Exception) as cm:
+ clf._prepare_data(X)
+
+ assert str(cm.exception) == "The features in the input data do not match features expected by the PMML model."
+
diff --git a/tests/tree/test_tree.py b/tests/tree/test_tree.py
index dd744dd..d2ffde8 100644
--- a/tests/tree/test_tree.py
+++ b/tests/tree/test_tree.py
@@ -88,15 +88,12 @@ def test_more_tags(self):
class TestIrisTreeIntegration(TestCase):
def setUp(self):
pair = [0, 1]
- data = load_iris()
+ data = load_iris(as_frame=True)
- X = pd.DataFrame(data.data[:, pair])
- X.columns = np.array(data.feature_names)[pair]
- y = pd.Series(np.array(data.target_names)[data.target])
+ X = data.data
+ y = data.target
y.name = "Class"
- X, Xte, y, yte = train_test_split(X, y, test_size=0.33, random_state=123)
- self.test = (Xte, yte)
- self.train = (X, y)
+ self.test = (X, y)
pmml = path.join(BASE_DIR, '../models/tree-iris.pmml')
self.clf = PMMLTreeClassifier(pmml=pmml)
@@ -122,7 +119,7 @@ def test_sklearn2pmml(self):
pipeline = PMMLPipeline([
("classifier", self.ref)
])
- pipeline.fit(self.train[0], self.train[1])
+ pipeline.fit(self.test[0], self.test[1])
sklearn2pmml(pipeline, "tree-sklearn2pmml.pmml", with_repr = True)
try:
@@ -145,9 +142,6 @@ class TestDigitsTreeIntegration(TestCase):
def setUp(self):
data = load_digits()
- self.columns = [2, 3, 4, 5, 6, 7, 9, 10, 13, 14, 17, 18, 19, 20, 21, 25, 26,
- 27, 28, 29, 30, 33, 34, 35, 36, 37, 38, 41, 42, 43, 45, 46,
- 50, 51, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63]
X = pd.DataFrame(data.data)
y = pd.Series(np.array(data.target_names)[data.target])
y.name = "Class"
@@ -161,19 +155,19 @@ def test_predict(self):
Xte, _ = self.test
assert np.array_equal(
self.ref.predict(Xte),
- self.clf.predict(Xte[self.columns]).astype(np.int64)
+ self.clf.predict(Xte)
)
def test_predict_proba(self):
Xte, _ = self.test
assert np.array_equal(
self.ref.predict_proba(Xte),
- self.clf.predict_proba(Xte[self.columns])
+ self.clf.predict_proba(Xte)
)
def test_score(self):
Xte, yte = self.test
- assert self.ref.score(Xte, yte) == self.clf.score(Xte[self.columns], yte)
+ assert self.ref.score(Xte, yte) == self.clf.score(Xte, yte)
class TestCategoricalTreeIntegration(TestCase):