Skip to content

Commit

Permalink
Fix appendBias return type
Browse files Browse the repository at this point in the history
  • Loading branch information
Lewuathe committed May 10, 2015
1 parent 454c73d commit 62a9c7e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ private[python] class PythonMLLibAPI extends Serializable {
* @param path file or directory path in any Hadoop-supported file system URI
* @return serialized vectors in a RDD
*/
def loadVectors(jsc: JavaSparkContext,
path: String): RDD[Vector] =
def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] =
MLUtils.loadVectors(jsc.sc, path)

private def trainRegressionModel(
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,13 @@ def test_append_bias(self):
data = [2.0, 2.0, 2.0]
ret = MLUtils.appendBias(data)
self.assertEqual(ret[3], 1.0)
self.assertEqual(type(ret), list)

def test_append_bias_with_vector(self):
data = Vectors.dense([2.0, 2.0, 2.0])
ret = MLUtils.appendBias(data)
self.assertEqual(ret[3], 1.0)
self.assertEqual(type(ret), list)

def test_load_vectors(self):
import shutil
Expand Down
9 changes: 6 additions & 3 deletions python/pyspark/mllib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
xrange = range

from pyspark.mllib.common import callMLlibFunc, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
from pyspark.mllib.linalg import Vector, Vectors, SparseVector, _convert_to_vector


class MLUtils(object):
Expand Down Expand Up @@ -172,10 +172,13 @@ def loadLabeledPoints(sc, path, minPartitions=None):
@staticmethod
def appendBias(data):
"""
Returns a new vector with `1.0` (bias) appended to the input vector.
Returns a new vector with `1.0` (bias) appended to
the end of the input vector.
"""
vec = _convert_to_vector(data)
return np.append(vec, 1.0)
if isinstance(vec, Vector):
vec = vec.toArray()
return np.append(vec, 1.0).tolist()

@staticmethod
def loadVectors(sc, path):
Expand Down

0 comments on commit 62a9c7e

Please sign in to comment.