Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
DB Tsai committed Nov 26, 2014
1 parent 9b7cb56 commit 6fa616c
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class KMeans private (
}

// Compute squared norms and cache them.
val norms = data.map(_.norm(2.0))
val norms = data.map(Vectors.norm(_, 2.0))
norms.persist()
val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
new BreezeVectorWithNorm(v, norm)
Expand Down Expand Up @@ -424,7 +424,7 @@ object KMeans {
private[clustering]
class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double) extends Serializable {

def this(vector: BV[Double]) = this(vector, Vectors.fromBreeze(vector).norm(2.0))
def this(vector: BV[Double]) = this(vector, Vectors.norm(Vectors.fromBreeze(vector), 2.0))

def this(array: Array[Double]) = this(new BDV[Double](array))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Normalizer(p: Double) extends VectorTransformer {
* @return normalized vector. If the norm of the input is zero, it will return the input vector.
*/
override def transform(vector: Vector): Vector = {
val norm = vector.norm(p)
val norm = Vectors.norm(vector, p)

if (norm != 0.0) {
// For a dense vector, we have to allocate new memory for the new output vector.
Expand Down
25 changes: 13 additions & 12 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,6 @@ sealed trait Vector extends Serializable {
* with type `Double`.
*/
private[spark] def foreachActive(f: (Int, Double) => Unit)

/**
* Returns the p-norm of this vector.
* @param p norm.
* @return norm in L^p^ space.
*/
private[spark] def norm(p: Double): Double
}

/**
Expand Down Expand Up @@ -269,9 +262,21 @@ object Vectors {
}
}

private[linalg] def norm(p: Double, values: Array[Double]): Double = {
/**
* Returns the p-norm of the given vector.
* @param vector input vector.
* @param p order of the norm; must be greater than or equal to 1.0.
* @return norm in L^p^ space.
*/
private[spark] def norm(vector: Vector, p: Double): Double = {
require(p >= 1.0)
val values = vector match {
case dv: DenseVector => dv.values
case sv: SparseVector => sv.values
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
val size = values.size

if (p == 1) {
var sum = 0.0
var i = 0
Expand Down Expand Up @@ -339,8 +344,6 @@ class DenseVector(val values: Array[Double]) extends Vector {
i += 1
}
}

private[spark] override def norm(p: Double): Double = Vectors.norm(p, values)
}

/**
Expand Down Expand Up @@ -389,6 +392,4 @@ class SparseVector(
i += 1
}
}

private[spark] override def norm(p: Double): Double = Vectors.norm(p, values)
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,22 @@ class VectorsSuite extends FunSuite {
val dv = Vectors.dense(0.0, -1.2, 3.1, 0.0, -4.5, 1.9)
val sv = Vectors.sparse(6, Seq((1, -1.2), (2, 3.1), (3, 0.0), (4, -4.5), (5, 1.9)))

assert(dv.norm(1.0) ~== dv.toArray.foldLeft(0.0)((a, v) => a + math.abs(v)) relTol 1E-8)
assert(sv.norm(1.0) ~== sv.toArray.foldLeft(0.0)((a, v) => a + math.abs(v)) relTol 1E-8)
assert(Vectors.norm(dv, 1.0) ~== dv.toArray.foldLeft(0.0)((a, v) =>
a + math.abs(v)) relTol 1E-8)
assert(Vectors.norm(sv, 1.0) ~== sv.toArray.foldLeft(0.0)((a, v) =>
a + math.abs(v)) relTol 1E-8)

assert(dv.norm(2.0) ~== math.sqrt(dv.toArray.foldLeft(0.0)((a, v) => a + v * v)) relTol 1E-8)
assert(sv.norm(2.0) ~== math.sqrt(sv.toArray.foldLeft(0.0)((a, v) => a + v * v)) relTol 1E-8)
assert(Vectors.norm(dv, 2.0) ~== math.sqrt(dv.toArray.foldLeft(0.0)((a, v) =>
a + v * v)) relTol 1E-8)
assert(Vectors.norm(sv, 2.0) ~== math.sqrt(sv.toArray.foldLeft(0.0)((a, v) =>
a + v * v)) relTol 1E-8)

assert(dv.norm(Double.PositiveInfinity) ~== dv.toArray.map(math.abs).max relTol 1E-8)
assert(sv.norm(Double.PositiveInfinity) ~== sv.toArray.map(math.abs).max relTol 1E-8)
assert(Vectors.norm(dv, Double.PositiveInfinity) ~== dv.toArray.map(math.abs).max relTol 1E-8)
assert(Vectors.norm(sv, Double.PositiveInfinity) ~== sv.toArray.map(math.abs).max relTol 1E-8)

assert(dv.norm(3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
assert(Vectors.norm(dv, 3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
assert(sv.norm(3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
assert(Vectors.norm(sv, 3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
}
}

0 comments on commit 6fa616c

Please sign in to comment.