diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 32469ac7e2995..a64b3108de0bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -21,9 +21,9 @@ import org.apache.spark.mllib.linalg.Matrix import org.apache.spark.mllib.linalg.Vector /** - * Multivariate Gaussian mixture model consisting of k Gaussians, where points are drawn - * from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are the respective - * mean and covariance for each Gaussian distribution i=1..k. + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. * * @param weight Weights for each Gaussian distribution in the mixture, where mu(i) is * the weight for Gaussian i, and weight.sum == 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala index e5568c252ed5c..ccea01277c683 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala @@ -18,18 +18,30 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix} -import breeze.linalg.{Transpose, det, inv} +import breeze.linalg.Transpose import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} +import org.apache.spark.mllib.stat.impl.MultivariateGaussian import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext} import 
org.apache.spark.SparkContext.DoubleAccumulatorParam /** - * This class performs multivariate Gaussian expectation maximization. It will - * maximize the log-likelihood for a mixture of k Gaussians, iterating until - * the log-likelihood changes by less than delta, or until it has reached - * the max number of iterations. + * This class performs expectation maximization for multivariate Gaussian + * Mixture Models (GMMs). A GMM represents a composite distribution of + * independent Gaussian distributions with associated "mixing" weights + * specifying each one's contribution to the composite. + * + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * less than convergenceTol, or until it has reached the max number of iterations. + * While this process is generally guaranteed to converge, it is not guaranteed + * to find a global optimum. + * + * @param k The number of independent Gaussians in the mixture model + * @param convergenceTol The maximum change in log-likelihood at which convergence + * is considered to have occurred. + * @param maxIterations The maximum number of iterations to perform */ class GaussianMixtureModelEM private ( private var k: Int, @@ -40,7 +52,7 @@ class GaussianMixtureModelEM private ( private type DenseDoubleVector = BreezeVector[Double] private type DenseDoubleMatrix = BreezeMatrix[Double] - /** number of samples per cluster to use when initializing Gaussians */ + // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ @@ -219,21 +231,4 @@ class GaussianMixtureModelEM private ( a += b } } - - /** - * Utility class to implement the density function for multivariate Gaussian distribution. 
- * Breeze provides this functionality, but it requires the Apache Commons Math library, - * so this class is here so-as to not introduce a new dependency in Spark. - */ - private class MultivariateGaussian(val mu: DenseDoubleVector, val sigma: DenseDoubleMatrix) - extends Serializable { - private val sigmaInv2 = inv(sigma) * -0.5 - private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5) - - def pdf(x: DenseDoubleVector): Double = { - val delta = x - mu - val deltaTranspose = new Transpose(delta) - U * math.exp(deltaTranspose * sigmaInv2 * delta) - } - } }