Skip to content

Commit

Permalink
Moved multivariate Gaussian utility class to mllib/stat/impl
Browse files Browse the repository at this point in the history
Improved comments
  • Loading branch information
tgaloppo committed Dec 12, 2014
1 parent 9770261 commit e7d413b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import org.apache.spark.mllib.linalg.Matrix
import org.apache.spark.mllib.linalg.Vector

/**
* Multivariate Gaussian mixture model consisting of k Gaussians, where points are drawn
* from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are the respective
* mean and covariance for each Gaussian distribution i=1..k.
* Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
* are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
* the respective mean and covariance for each Gaussian distribution i=1..k.
*
* @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is
* the weight for Gaussian i, and weight.sum == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,30 @@
package org.apache.spark.mllib.clustering

import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix}
import breeze.linalg.{Transpose, det, inv}
import breeze.linalg.Transpose

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
import org.apache.spark.SparkContext.DoubleAccumulatorParam

/**
* This class performs multivariate Gaussian expectation maximization. It will
* maximize the log-likelihood for a mixture of k Gaussians, iterating until
* the log-likelihood changes by less than delta, or until it has reached
* the max number of iterations.
* This class performs expectation maximization for multivariate Gaussian
* Mixture Models (GMMs). A GMM represents a composite distribution of
* independent Gaussian distributions with associated "mixing" weights
* specifying each Gaussian's contribution to the composite.
*
* Given a set of sample points, this class will maximize the log-likelihood
* for a mixture of k Gaussians, iterating until the log-likelihood changes by
* less than convergenceTol, or until it has reached the max number of iterations.
* While this process is generally guaranteed to converge, it is not guaranteed
* to find a global optimum.
*
* @param k The number of independent Gaussians in the mixture model
* @param convergenceTol The maximum change in log-likelihood at which convergence
* is considered to have occurred.
* @param maxIterations The maximum number of iterations to perform
*/
class GaussianMixtureModelEM private (
private var k: Int,
Expand All @@ -40,7 +52,7 @@ class GaussianMixtureModelEM private (
private type DenseDoubleVector = BreezeVector[Double]
private type DenseDoubleMatrix = BreezeMatrix[Double]

/** number of samples per cluster to use when initializing Gaussians */
// number of samples per cluster to use when initializing Gaussians
private val nSamples = 5

/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
Expand Down Expand Up @@ -219,21 +231,4 @@ class GaussianMixtureModelEM private (
a += b
}
}

/**
* Utility class to implement the density function for multivariate Gaussian distribution.
* Breeze provides this functionality, but it requires the Apache Commons Math library,
* so this class is here so as not to introduce a new dependency in Spark.
*
* @param mu Mean vector of the distribution
* @param sigma Covariance matrix of the distribution
*/
private class MultivariateGaussian(val mu: DenseDoubleVector, val sigma: DenseDoubleMatrix)
extends Serializable {
// -0.5 * sigma^-1, computed once at construction so pdf() never re-inverts sigma.
// NOTE(review): presumably requires sigma to be invertible -- confirm what callers guarantee.
private val sigmaInv2 = inv(sigma) * -0.5
// Scaling factor applied to every density value:
// (2 * pi)^(-k / 2) * det(sigma)^(-1/2), where k = mu.length.
private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5)

/**
* Evaluates the density at x as U * exp(delta^T * sigmaInv2 * delta), where delta = x - mu.
*
* @param x Point at which to evaluate the density
* @return The density value at x
*/
def pdf(x: DenseDoubleVector): Double = {
val delta = x - mu
val deltaTranspose = new Transpose(delta)
U * math.exp(deltaTranspose * sigmaInv2 * delta)
}
}
}

0 comments on commit e7d413b

Please sign in to comment.