Skip to content

Commit

Permalink
MLlib Python API consistency check
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Jun 23, 2015
1 parent 6ceb169 commit e763d32
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 44 deletions.
7 changes: 5 additions & 2 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def k(self):

def predict(self, x):
"""Find the cluster to which x belongs in this model."""
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))

best = 0
best_distance = float("inf")
if isinstance(x, RDD):
Expand All @@ -114,12 +117,12 @@ def predict(self, x):
best_distance = distance
return best

def computeCost(self, rdd):
def computeCost(self, data):
"""
Return the K-means cost (sum of squared distances of points to
their nearest center) for this model on the given data.
"""
cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector),
cost = callMLlibFunc("computeCostKmeansModel", data.map(_convert_to_vector),
[_convert_to_vector(c) for c in self.centers])
return cost

Expand Down
51 changes: 17 additions & 34 deletions python/pyspark/mllib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
"""

def transform(self, vector):
"""
Applies transformation on a vector or an RDD[Vector].
Note: In Python, transform cannot currently be used within
an RDD transformation or action.
Call transform directly on the RDD instead.
:param vector: Vector or RDD of Vector to be transformed.
"""
if isinstance(vector, RDD):
vector = vector.map(_convert_to_vector)
else:
Expand All @@ -124,20 +133,6 @@ class StandardScalerModel(JavaVectorTransformer):
Represents a StandardScaler model that can transform vectors.
"""
def transform(self, vector):
"""
Applies standardization transformation on a vector.
Note: In Python, transform cannot currently be used within
an RDD transformation or action.
Call transform directly on the RDD instead.
:param vector: Vector or RDD of Vector to be standardized.
:return: Standardized vector. If the variance of a column is
zero, it will return default `0.0` for the column with
zero variance.
"""
return JavaVectorTransformer.transform(self, vector)

def setWithMean(self, withMean):
"""
Expand Down Expand Up @@ -186,7 +181,7 @@ def __init__(self, withMean=False, withStd=True):
self.withMean = withMean
self.withStd = withStd

def fit(self, dataset):
def fit(self, data):
"""
Computes the mean and variance and stores as a model to be used
for later scaling.
Expand All @@ -195,8 +190,8 @@ def fit(self, dataset):
to build the transformation model.
:return: a StandardScalerModel
"""
dataset = dataset.map(_convert_to_vector)
jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset)
data = data.map(_convert_to_vector)
jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, data)
return StandardScalerModel(jmodel)


Expand All @@ -206,14 +201,6 @@ class ChiSqSelectorModel(JavaVectorTransformer):
Represents a Chi Squared selector model.
"""
def transform(self, vector):
"""
Applies transformation on a vector.
:param vector: Vector or RDD of Vector to be transformed.
:return: transformed vector.
"""
return JavaVectorTransformer.transform(self, vector)


class ChiSqSelector(object):
Expand Down Expand Up @@ -330,7 +317,7 @@ class IDFModel(JavaVectorTransformer):
"""
Represents an IDF model that can transform term frequency vectors.
"""
def transform(self, x):
def transform(self, vector):
"""
Transforms term frequency (TF) vectors to TF-IDF vectors.
Expand All @@ -346,11 +333,7 @@ def transform(self, x):
vector
:return: an RDD of TF-IDF vectors or a TF-IDF vector
"""
if isinstance(x, RDD):
return JavaVectorTransformer.transform(self, x)

x = _convert_to_vector(x)
return JavaVectorTransformer.transform(self, x)
return JavaVectorTransformer.transform(self, vector)

def idf(self):
"""
Expand Down Expand Up @@ -540,16 +523,16 @@ def setMinCount(self, minCount):
self.minCount = minCount
return self

def fit(self, data):
def fit(self, dataset):
"""
Computes the vector representation of each word in vocabulary.
:param dataset: training data. RDD of list of string
:return: Word2VecModel instance
"""
if not isinstance(data, RDD):
if not isinstance(dataset, RDD):
raise TypeError("data should be an RDD of list of string")
jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
jmodel = callMLlibFunc("trainWord2Vec", dataset, int(self.vectorSize),
float(self.learningRate), int(self.numPartitions),
int(self.numIterations), int(self.seed),
int(self.minCount))
Expand Down
20 changes: 12 additions & 8 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ def predict(self, x):
else:
return self.call("predict", _convert_to_vector(x))

@property
def numTrees(self):
"""
Get number of trees in ensemble.
"""
return self.call("numTrees")

@property
def totalNumNodes(self):
"""
Get total number of nodes, summed over all trees in the
Expand Down Expand Up @@ -90,9 +92,11 @@ def predict(self, x):
else:
return self.call("predict", _convert_to_vector(x))

@property
def numNodes(self):
return self._java_model.numNodes()

@property
def depth(self):
return self._java_model.depth()

Expand Down Expand Up @@ -316,9 +320,9 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
... LabeledPoint(1.0, [3.0])
... ]
>>> model = RandomForest.trainClassifier(sc.parallelize(data), 2, {}, 3, seed=42)
>>> model.numTrees()
>>> model.numTrees
3
>>> model.totalNumNodes()
>>> model.totalNumNodes
7
>>> print(model)
TreeEnsembleModel classifier with 3 trees
Expand Down Expand Up @@ -396,9 +400,9 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt
... ]
>>>
>>> model = RandomForest.trainRegressor(sc.parallelize(sparse_data), {}, 2, seed=42)
>>> model.numTrees()
>>> model.numTrees
2
>>> model.totalNumNodes()
>>> model.totalNumNodes
4
>>> model.predict(SparseVector(2, {1: 1.0}))
1.0
Expand Down Expand Up @@ -486,9 +490,9 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
... ]
>>>
>>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10)
>>> model.numTrees()
>>> model.numTrees
10
>>> model.totalNumNodes()
>>> model.totalNumNodes
30
>>> print(model) # it already has newline
TreeEnsembleModel classifier with 10 trees
Expand Down Expand Up @@ -549,9 +553,9 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
>>>
>>> data = sc.parallelize(sparse_data)
>>> model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10)
>>> model.numTrees()
>>> model.numTrees
10
>>> model.totalNumNodes()
>>> model.totalNumNodes
12
>>> model.predict(SparseVector(2, {1: 1.0}))
1.0
Expand Down

0 comments on commit e763d32

Please sign in to comment.