From f407b4c22257a03409f4ab4dd15ed92738f58140 Mon Sep 17 00:00:00 2001 From: FlytxtRnD Date: Tue, 16 Dec 2014 17:17:40 +0530 Subject: [PATCH] Added predict() to return the cluster labels and membership values --- .../spark/examples/mllib/DenseGmmEM.scala | 4 +++ .../clustering/GaussianMixtureModel.scala | 9 ++++++ .../clustering/GaussianMixtureModelEM.scala | 30 ++++++++++++++++++- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala index f7301f533b1c2..01b8d92aabf07 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -47,5 +47,9 @@ object DenseGmmEM { println("weight=%f mu=%s sigma=\n%s\n" format (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) } + val (responsibility_matrix, cluster_labels) = clusters.predict(data) + for(x <- cluster_labels.collect()){ + print(" " + x) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index a64b3108de0bd..df11bbeb89ef0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.clustering +import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Matrix import org.apache.spark.mllib.linalg.Vector @@ -38,4 +39,12 @@ class GaussianMixtureModel( /** Number of gaussians in mixture */ def k: Int = weight.length; + + /** Maps given points to their cluster indices. */ + def predict(points: RDD[Vector]): (RDD[Array[Double]],RDD[Int]) = { + val responsibility_matrix = new GaussianMixtureModelEM() + .predictClusters(points,mu,sigma,weight,k) + val cluster_labels = responsibility_matrix.map(r => r.indexOf(r.max)) + (responsibility_matrix,cluster_labels) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala index c21631b5715e1..ef610aca8a3e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala @@ -21,7 +21,7 @@ import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix} import breeze.linalg.Transpose import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.stat.impl.MultivariateGaussian import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext} import org.apache.spark.SparkContext.DoubleAccumulatorParam @@ -208,6 +208,34 @@ class GaussianMixtureModelEM private ( cov } + /** + Given the input vectors, return the membership value of each vector + to all mixture components. + */ + def predictClusters(points:RDD[Vector],mu:Array[Vector],sigma:Array[Matrix], + weight:Array[Double],k:Int):RDD[Array[Double]]= { + val ctx = points.sparkContext + val dists = ctx.broadcast((0 until k).map(i => + new MultivariateGaussian(mu(i).toBreeze.toDenseVector,sigma(i).toBreeze.toDenseMatrix)) + .toArray) + val weights = ctx.broadcast((0 until k).map(i => weight(i)).toArray) + points.map(x=>compute_log_likelihood(x.toBreeze.toDenseVector,dists.value,weights.value,k)) + + } + /** + * Compute the log density of each vector + */ + def compute_log_likelihood(pt:DenseDoubleVector,dists:Array[MultivariateGaussian], + weights:Array[Double],k:Int):Array[Double]={ + val p = (0 until k).map(i => + eps + weights(i) * dists(i).pdf(pt)).toArray + val pSum = p.sum + for(i<- 0 until k){ + p(i) /= pSum + } + p + } + /** AccumulatorParam for Dense Breeze Vectors */ private object DenseDoubleVectorAccumulatorParam extends AccumulatorParam[DenseDoubleVector] { def zero(initialVector: DenseDoubleVector): DenseDoubleVector = {