address comments
yanboliang committed Jun 23, 2015
1 parent e763d32 commit 9e7ec3c
Showing 3 changed files with 35 additions and 23 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/mllib/clustering.py
@@ -117,12 +117,12 @@ def predict(self, x):
best_distance = distance
return best

-def computeCost(self, data):
+def computeCost(self, rdd):
"""
Return the K-means cost (sum of squared distances of points to
their nearest center) for this model on the given data.
"""
-cost = callMLlibFunc("computeCostKmeansModel", data.map(_convert_to_vector),
+cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector),
[_convert_to_vector(c) for c in self.centers])
return cost

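
For context, a minimal usage sketch of computeCost with the renamed parameter (a sketch only, assuming a live SparkContext sc; the sample points are illustrative):

>>> from pyspark.mllib.clustering import KMeans
>>> points = sc.parallelize([[0.0], [1.0], [9.0], [10.0]])  # two obvious clusters
>>> model = KMeans.train(points, 2, seed=42)
>>> wssse = model.computeCost(points)  # sum of squared distances to the nearest center
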
34 changes: 25 additions & 9 deletions python/pyspark/mllib/feature.py
@@ -36,6 +36,7 @@
from pyspark.mllib.linalg import (
Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import inherit_doc

__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
@@ -133,6 +134,20 @@ class StandardScalerModel(JavaVectorTransformer):
Represents a StandardScaler model that can transform vectors.
"""
+def transform(self, vector):
+"""
+Applies standardization transformation on a vector.
+Note: In Python, transform cannot currently be used within
+an RDD transformation or action.
+Call transform directly on the RDD instead.
+:param vector: Vector or RDD of Vector to be standardized.
+:return: Standardized vector. If the variance of a column is
+zero, it will return default `0.0` for the column with
+zero variance.
+"""
+return JavaVectorTransformer.transform(self, vector)

def setWithMean(self, withMean):
"""
@@ -181,20 +196,21 @@ def __init__(self, withMean=False, withStd=True):
self.withMean = withMean
self.withStd = withStd

-def fit(self, data):
+def fit(self, dataset):
"""
Computes the mean and variance and stores as a model to be used
for later scaling.
-:param data: The data used to compute the mean and variance
+:param dataset: The data used to compute the mean and variance
to build the transformation model.
:return: a StandardScalerModel
"""
-data = data.map(_convert_to_vector)
-jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, data)
+dataset = dataset.map(_convert_to_vector)
+jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset)
return StandardScalerModel(jmodel)
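
A minimal fit/transform sketch matching the renamed parameter and the transform note above (a sketch only, assuming a live SparkContext sc; the vectors are illustrative):

>>> from pyspark.mllib.feature import StandardScaler
>>> from pyspark.mllib.linalg import Vectors
>>> vecs = sc.parallelize([Vectors.dense([1.0, 2.0]), Vectors.dense([3.0, 4.0])])
>>> model = StandardScaler(withMean=True, withStd=True).fit(vecs)
>>> # Per the note above, call transform on the RDD directly rather than
>>> # inside another RDD transformation or action.
>>> scaled = model.transform(vecs)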


+@inherit_doc
class ChiSqSelectorModel(JavaVectorTransformer):
"""
.. note:: Experimental
@@ -317,7 +333,7 @@ class IDFModel(JavaVectorTransformer):
"""
Represents an IDF model that can transform term frequency vectors.
"""
-def transform(self, vector):
+def transform(self, x):
"""
Transforms term frequency (TF) vectors to TF-IDF vectors.
@@ -333,7 +349,7 @@ def transform(self, vector):
vector
:return: an RDD of TF-IDF vectors or a TF-IDF vector
"""
-return JavaVectorTransformer.transform(self, vector)
+return JavaVectorTransformer.transform(self, x)
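
For context, a minimal sketch of the TF-IDF flow this method belongs to (a sketch only, assuming a live SparkContext sc; the toy documents are illustrative):

>>> from pyspark.mllib.feature import HashingTF, IDF
>>> docs = sc.parallelize([["spark", "mllib"], ["spark", "python"]])
>>> tf = HashingTF().transform(docs)  # term-frequency vectors, one per document
>>> idf_model = IDF().fit(tf)
>>> tfidf = idf_model.transform(tf)   # an RDD in, an RDD of TF-IDF vectors out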

def idf(self):
"""
@@ -523,16 +539,16 @@ def setMinCount(self, minCount):
self.minCount = minCount
return self

-def fit(self, dataset):
+def fit(self, data):
"""
Computes the vector representation of each word in vocabulary.
:param data: training data. RDD of list of string
:return: Word2VecModel instance
"""
-if not isinstance(dataset, RDD):
+if not isinstance(data, RDD):
raise TypeError("data should be an RDD of list of string")
-jmodel = callMLlibFunc("trainWord2Vec", dataset, int(self.vectorSize),
+jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
float(self.learningRate), int(self.numPartitions),
int(self.numIterations), int(self.seed),
int(self.minCount))
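
A minimal sketch of fit on an RDD of tokenized sentences (a sketch only, assuming a live SparkContext sc; the tiny corpus is illustrative and repeated so that words clear the default minCount of 5):

>>> from pyspark.mllib.feature import Word2Vec
>>> sentences = sc.parallelize([["spark", "is", "fast"], ["spark", "uses", "rdds"]] * 20)
>>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(sentences)
>>> vec = model.transform("spark")  # vector representation of a single word
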
20 changes: 8 additions & 12 deletions python/pyspark/mllib/tree.py
@@ -45,14 +45,12 @@ def predict(self, x):
else:
return self.call("predict", _convert_to_vector(x))

-@property
def numTrees(self):
"""
Get number of trees in ensemble.
"""
return self.call("numTrees")

-@property
def totalNumNodes(self):
"""
Get total number of nodes, summed over all trees in the
@@ -92,11 +90,9 @@ def predict(self, x):
else:
return self.call("predict", _convert_to_vector(x))

-@property
def numNodes(self):
return self._java_model.numNodes()

-@property
def depth(self):
return self._java_model.depth()
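
With @property removed here, numNodes and depth are plain method calls; a minimal sketch (a sketch only, assuming a live SparkContext sc; the training data is illustrative):

>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.regression import LabeledPoint
>>> data = sc.parallelize([LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0])])
>>> model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={})
>>> n_nodes = model.numNodes()  # method call, not attribute access
>>> tree_depth = model.depth()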

@@ -320,9 +316,9 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
... LabeledPoint(1.0, [3.0])
... ]
>>> model = RandomForest.trainClassifier(sc.parallelize(data), 2, {}, 3, seed=42)
->>> model.numTrees
+>>> model.numTrees()
3
->>> model.totalNumNodes
+>>> model.totalNumNodes()
7
>>> print(model)
TreeEnsembleModel classifier with 3 trees
@@ -400,9 +396,9 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt
... ]
>>>
>>> model = RandomForest.trainRegressor(sc.parallelize(sparse_data), {}, 2, seed=42)
->>> model.numTrees
+>>> model.numTrees()
2
->>> model.totalNumNodes
+>>> model.totalNumNodes()
4
>>> model.predict(SparseVector(2, {1: 1.0}))
1.0
@@ -490,9 +486,9 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
... ]
>>>
>>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10)
->>> model.numTrees
+>>> model.numTrees()
10
->>> model.totalNumNodes
+>>> model.totalNumNodes()
30
>>> print(model) # it already has newline
TreeEnsembleModel classifier with 10 trees
@@ -553,9 +549,9 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
>>>
>>> data = sc.parallelize(sparse_data)
>>> model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10)
->>> model.numTrees
+>>> model.numTrees()
10
->>> model.totalNumNodes
+>>> model.totalNumNodes()
12
>>> model.predict(SparseVector(2, {1: 1.0}))
1.0
