Skip to content

Commit

Permalink
[SW-1477] Fix bug with setting init on KMeans (#1397)
Browse files Browse the repository at this point in the history
(cherry picked from commit dbf549a)
  • Loading branch information
jakubhava committed Jul 31, 2019
1 parent 0d3603a commit 456673c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
10 changes: 5 additions & 5 deletions py/ai/h2o/sparkling/ml/algos/H2OKmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ai.h2o.sparkling.ml.utils import getDoubleArrayArrayFromIntArrayArray
from py_sparkling.ml.models import H2OMOJOModel
from py_sparkling.ml.util import set_double_values, validateEnumValue
from py_sparkling.ml.util import set_double_values, validateEnumValue, getValidatedEnumValue
from pysparkling.ml.params import H2OAlgoUnsupervisedParams
from pysparkling.spark_specifics import get_input_kwargs

Expand Down Expand Up @@ -59,7 +59,7 @@ def setStandardize(self, value):
return self._set(standardize=value)

def setInit(self, value):
    """Set the ``init`` parameter after checking it against the H2O enum.

    The value is passed, together with the enum class name returned by
    ``__getInitEnum``, to ``getValidatedEnumValue``; whatever that helper
    returns is what gets stored.

    :param value: requested initialization mode (e.g. ``"Furthest"``).
    :return: the result of ``self._set(init=...)``.
    """
    # The helper is imported as ``getValidatedEnumValue``; the previous
    # spelling ``getValidateEnumValue`` was an undefined name.
    validated = getValidatedEnumValue(self.__getInitEnum(), value)
    return self._set(init=validated)

def setUserPoints(self, value):
Expand All @@ -75,7 +75,7 @@ def setK(self, value):
return self._set(k=value)

def __getInitEnum(self):
    """Return the Java class name of the enum used to validate ``init``.

    ``Initialization`` is referenced as a nested class of
    ``hex.kmeans.KMeans``, hence the ``$`` separator in the name.
    """
    return "hex.kmeans.KMeans$Initialization"


class H2OKMeans(H2OKMeansParams, JavaEstimator, JavaMLReadable, JavaMLWritable):
Expand Down Expand Up @@ -174,8 +174,8 @@ def setParams(self,
double_types = ["splitRatio"]
set_double_values(kwargs, double_types)

if "init" in kwargs:
kwargs["init"] = getDoubleArrayArrayFromIntArrayArray(kwargs["init"])
if "userPoints" in kwargs:
kwargs["userPoints"] = getDoubleArrayArrayFromIntArrayArray(kwargs["userPoints"])

return self._set(**kwargs)

Expand Down
52 changes: 52 additions & 0 deletions py/tests/tests_unit_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,58 @@ def test_grid_params(self):
self.assertEquals(grid.getWithDetailedPredictionCol(), False)
self.assertEquals(grid.getConvertInvalidNumbersToNa(), False)

def test_kmeans_params(self):
    """Check that every H2OKMeans constructor argument is readable via its getter.

    Each keyword passed to the constructor is read back through the
    matching ``get*`` method and compared with the value supplied.
    """
    kmeans = H2OKMeans(
        predictionCol="prediction",
        detailedPredictionCol="detailed_prediction",
        withDetailedPredictionCol=False,
        featuresCols=[],
        foldCol=None,
        weightCol=None,
        splitRatio=1.0,
        seed=-1,
        nfolds=0,
        allStringColumnsToCategorical=True,
        columnsToCategorical=[],
        convertUnknownCategoricalLevelsToNa=False,
        convertInvalidNumbersToNa=False,
        modelId=None,
        keepCrossValidationPredictions=False,
        keepCrossValidationFoldAssignment=False,
        parallelizeCrossValidation=True,
        distribution="AUTO",
        maxIterations=10,
        standardize=True,
        init="Furthest",
        userPoints=None,
        estimateK=False,
        k=2)

    # assertEqual is used instead of the deprecated assertEquals alias,
    # which is no longer available in Python 3.12+.
    self.assertEqual(kmeans.getPredictionCol(), "prediction")
    self.assertEqual(kmeans.getDetailedPredictionCol(), "detailed_prediction")
    self.assertEqual(kmeans.getWithDetailedPredictionCol(), False)
    self.assertEqual(kmeans.getFeaturesCols(), [])
    self.assertEqual(kmeans.getFoldCol(), None)
    self.assertEqual(kmeans.getWeightCol(), None)
    self.assertEqual(kmeans.getSplitRatio(), 1.0)
    self.assertEqual(kmeans.getSeed(), -1)
    self.assertEqual(kmeans.getNfolds(), 0)
    self.assertEqual(kmeans.getAllStringColumnsToCategorical(), True)
    self.assertEqual(kmeans.getColumnsToCategorical(), [])
    self.assertEqual(kmeans.getConvertUnknownCategoricalLevelsToNa(), False)
    self.assertEqual(kmeans.getConvertInvalidNumbersToNa(), False)
    self.assertEqual(kmeans.getModelId(), None)
    self.assertEqual(kmeans.getKeepCrossValidationPredictions(), False)
    self.assertEqual(kmeans.getKeepCrossValidationFoldAssignment(), False)
    self.assertEqual(kmeans.getParallelizeCrossValidation(), True)
    self.assertEqual(kmeans.getDistribution(), "AUTO")
    self.assertEqual(kmeans.getMaxIterations(), 10)
    self.assertEqual(kmeans.getStandardize(), True)
    self.assertEqual(kmeans.getInit(), "Furthest")
    self.assertEqual(kmeans.getUserPoints(), None)
    self.assertEqual(kmeans.getEstimateK(), False)
    self.assertEqual(kmeans.getK(), 2)


if __name__ == '__main__':
generic_test_utils.run_tests([H2OConfTest], file_name="py_unit_tests_conf_report")

0 comments on commit 456673c

Please sign in to comment.