diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala index 0907c647596da..a6e6ad9ef52b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala @@ -55,9 +55,26 @@ class GaussianMixtureModelEM private ( // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 + // an initializing GMM can be provided rather than using the + // default random starting point + private var initialGmm: Option[GaussianMixtureModel] = None + /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ def this() = this(2, 0.01, 100) + /** Set the initial GMM starting point, bypassing the random initialization */ + def setInitialGmm(gmm: GaussianMixtureModel): this.type = { + if (gmm.k == k) { + initialGmm = Some(gmm) + } else { + throw new IllegalArgumentException("initialing GMM has mismatched cluster count (gmm.k != k)") + } + this + } + + /** Return the user supplied initial GMM, if supplied */ + def getInitialiGmm: Option[GaussianMixtureModel] = initialGmm + /** Set the number of Gaussians in the mixture model. Default: 2 */ def setK(k: Int): this.type = { this.k = k @@ -103,20 +120,35 @@ class GaussianMixtureModelEM private ( // Get length of the input vectors val d = breezeData.first.length - // For each Gaussian, we will initialize the mean as the average - // of some random samples from the data - val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) - - // gaussians will be array of (weight, mean, covariance) tuples + // gaussians will be array of (weight, mean, covariance) tuples. + // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and // diagonal covariance matrices using component variances // derived from the samples - var gaussians = (0 until k).map{ i => + var gaussians = initialGmm match { + case Some(gmm) => (0 until k).map{ i => + (gmm.weight(i), gmm.mu(i).toBreeze.toDenseVector, gmm.sigma(i).toBreeze.toDenseMatrix) + }.toArray + + case None => { + // For each Gaussian, we will initialize the mean as the average + // of some random samples from the data + val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) + + (0 until k).map{ i => + (1.0 / k, + vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)), + initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples))) + }.toArray + } + } + + /*var gaussians = (0 until k).map{ i => (1.0 / k, vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)), initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples))) }.toArray - + */ val accW = new Array[Accumulator[Double]](k) val accMu = new Array[Accumulator[DenseDoubleVector]](k) val accSigma = new Array[Accumulator[DenseDoubleMatrix]](k) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala index d1f3fe34bfb09..e44db28ceb614 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -42,4 +42,37 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex assert(gmm.mu(0) ~== Emu absTol 1E-5) assert(gmm.sigma(0) ~== Esigma absTol 1E-5) } + + test("two clusters") { + val data = sc.parallelize(Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + )) + + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array(Vectors.dense(-1.0), Vectors.dense(1.0)), + Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0))) + ) + + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val gmm = new GaussianMixtureModelEM() + .setK(2) + .setInitialGmm(initialGmm) + .run(data) + + assert(gmm.weight(0) ~== Ew(0) absTol 1E-3) + assert(gmm.weight(1) ~== Ew(1) absTol 1E-3) + assert(gmm.mu(0) ~== Emu(0) absTol 1E-3) + assert(gmm.mu(1) ~== Emu(1) absTol 1E-3) + assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3) + assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3) + } }