diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index be4ce5e891f27..2fa01fbc9d9dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -124,7 +124,7 @@ class KMeans private (
     }

     // Compute squared norms and cache them.
-    val norms = data.map(_.norm(2.0))
+    val norms = data.map(Vectors.norm(_, 2.0))
     norms.persist()
     val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
       new BreezeVectorWithNorm(v, norm)
@@ -424,7 +424,7 @@ object KMeans {
 private[clustering]
 class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double) extends Serializable {

-  def this(vector: BV[Double]) = this(vector, Vectors.fromBreeze(vector).norm(2.0))
+  def this(vector: BV[Double]) = this(vector, Vectors.norm(Vectors.fromBreeze(vector), 2.0))

   def this(array: Array[Double]) = this(new BDV[Double](array))

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index 5d36337f5887e..1ced26a9b70a2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -45,7 +45,7 @@ class Normalizer(p: Double) extends VectorTransformer {
    * @return normalized vector. If the norm of the input is zero, it will return the input vector.
    */
   override def transform(vector: Vector): Vector = {
-    val norm = vector.norm(p)
+    val norm = Vectors.norm(vector, p)

     if (norm != 0.0) {
       // For dense vector, we've to allocate new memory for new output vector.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 81f11cf147a74..47d1a76fa361d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -85,13 +85,6 @@ sealed trait Vector extends Serializable {
    * with type `Double`.
    */
   private[spark] def foreachActive(f: (Int, Double) => Unit)
-
-  /**
-   * Returns the p-norm of this vector.
-   * @param p norm.
-   * @return norm in L^p^ space.
-   */
-  private[spark] def norm(p: Double): Double
 }

 /**
@@ -269,9 +262,21 @@ object Vectors {
     }
   }

-  private[linalg] def norm(p: Double, values: Array[Double]): Double = {
+  /**
+   * Returns the p-norm of this vector.
+   * @param vector input vector.
+   * @param p norm.
+   * @return norm in L^p^ space.
+   */
+  private[spark] def norm(vector: Vector, p: Double): Double = {
     require(p >= 1.0)
+    val values = vector match {
+      case dv: DenseVector => dv.values
+      case sv: SparseVector => sv.values
+      case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+    }
     val size = values.size
+
     if (p == 1) {
       var sum = 0.0
       var i = 0
@@ -339,8 +344,6 @@ class DenseVector(val values: Array[Double]) extends Vector {
       i += 1
     }
   }
-
-  private[spark] override def norm(p: Double): Double = Vectors.norm(p, values)
 }

 /**
@@ -389,6 +392,4 @@ class SparseVector(
       i += 1
     }
   }
-
-  private[spark] override def norm(p: Double): Double = Vectors.norm(p, values)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 1181e8ffc73a4..f04ce6b6d8e8f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -203,18 +203,22 @@ class VectorsSuite extends FunSuite {
     val dv = Vectors.dense(0.0, -1.2, 3.1, 0.0, -4.5, 1.9)
     val sv = Vectors.sparse(6, Seq((1, -1.2), (2, 3.1), (3, 0.0), (4, -4.5), (5, 1.9)))

-    assert(dv.norm(1.0) ~== dv.toArray.foldLeft(0.0)((a, v) => a + math.abs(v)) relTol 1E-8)
-    assert(sv.norm(1.0) ~== sv.toArray.foldLeft(0.0)((a, v) => a + math.abs(v)) relTol 1E-8)
+    assert(Vectors.norm(dv, 1.0) ~== dv.toArray.foldLeft(0.0)((a, v) =>
+      a + math.abs(v)) relTol 1E-8)
+    assert(Vectors.norm(sv, 1.0) ~== sv.toArray.foldLeft(0.0)((a, v) =>
+      a + math.abs(v)) relTol 1E-8)

-    assert(dv.norm(2.0) ~== math.sqrt(dv.toArray.foldLeft(0.0)((a, v) => a + v * v)) relTol 1E-8)
-    assert(sv.norm(2.0) ~== math.sqrt(sv.toArray.foldLeft(0.0)((a, v) => a + v * v)) relTol 1E-8)
+    assert(Vectors.norm(dv, 2.0) ~== math.sqrt(dv.toArray.foldLeft(0.0)((a, v) =>
+      a + v * v)) relTol 1E-8)
+    assert(Vectors.norm(sv, 2.0) ~== math.sqrt(sv.toArray.foldLeft(0.0)((a, v) =>
+      a + v * v)) relTol 1E-8)

-    assert(dv.norm(Double.PositiveInfinity) ~== dv.toArray.map(math.abs).max relTol 1E-8)
-    assert(sv.norm(Double.PositiveInfinity) ~== sv.toArray.map(math.abs).max relTol 1E-8)
+    assert(Vectors.norm(dv, Double.PositiveInfinity) ~== dv.toArray.map(math.abs).max relTol 1E-8)
+    assert(Vectors.norm(sv, Double.PositiveInfinity) ~== sv.toArray.map(math.abs).max relTol 1E-8)

-    assert(dv.norm(3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
+    assert(Vectors.norm(dv, 3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
       a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
-    assert(sv.norm(3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
+    assert(Vectors.norm(sv, 3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
       a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
   }
 }
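For reference, a minimal sketch of how call sites change after this patch: the instance method `vector.norm(p)` is removed from the `Vector` trait and callers go through `Vectors.norm(vector, p)` instead. Since the new helper is declared `private[spark]`, the sketch assumes it is compiled inside Spark's own source tree (the package declaration and object name below are illustrative, not part of the patch); the sample vectors mirror the ones in `VectorsSuite`.

```scala
package org.apache.spark.mllib.linalg

// Hypothetical illustration of the relocated norm helper; mirrors the call
// sites touched by this patch (KMeans, Normalizer, VectorsSuite).
object NormUsageSketch {
  def main(args: Array[String]): Unit = {
    val dv = Vectors.dense(0.0, -1.2, 3.1, 0.0, -4.5, 1.9)
    val sv = Vectors.sparse(6, Seq((1, -1.2), (2, 3.1), (3, 0.0), (4, -4.5), (5, 1.9)))

    // Old API (removed by this patch): dv.norm(2.0)
    // New API: a helper on the Vectors companion object.
    val l2   = Vectors.norm(dv, 2.0)                       // Euclidean norm
    val l1   = Vectors.norm(sv, 1.0)                        // sum of absolute values
    val lInf = Vectors.norm(dv, Double.PositiveInfinity)    // max absolute value

    println(s"L2 = $l2, L1 = $l1, Linf = $lInf")
  }
}
```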