Skip to content

Commit

Permalink
Added functionality to allow setting of GMM starting point.
Browse files Browse the repository at this point in the history
Added two cluster test to testing suite.
  • Loading branch information
tgaloppo committed Dec 17, 2014
1 parent 8b633f3 commit 42b2142
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,26 @@ class GaussianMixtureModelEM private (
// number of samples per cluster to use when initializing Gaussians
private val nSamples = 5

// an initializing GMM can be provided rather than using the
// default random starting point
private var initialGmm: Option[GaussianMixtureModel] = None

/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
def this() = this(2, 0.01, 100)

/** Set the initial GMM starting point, bypassing the random initialization */
def setInitialGmm(gmm: GaussianMixtureModel): this.type = {
if (gmm.k == k) {
initialGmm = Some(gmm)
} else {
throw new IllegalArgumentException("initialing GMM has mismatched cluster count (gmm.k != k)")
}
this
}

/** Return the user supplied initial GMM, if supplied */
def getInitialiGmm: Option[GaussianMixtureModel] = initialGmm

/** Set the number of Gaussians in the mixture model. Default: 2 */
def setK(k: Int): this.type = {
this.k = k
Expand Down Expand Up @@ -103,20 +120,35 @@ class GaussianMixtureModelEM private (
// Get length of the input vectors
val d = breezeData.first.length

// For each Gaussian, we will initialize the mean as the average
// of some random samples from the data
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)

// gaussians will be array of (weight, mean, covariance) tuples
// gaussians will be array of (weight, mean, covariance) tuples.
// If the user supplied an initial GMM, we use those values, otherwise
// we start with uniform weights, a random mean from the data, and
// diagonal covariance matrices using component variances
// derived from the samples
var gaussians = (0 until k).map{ i =>
var gaussians = initialGmm match {
case Some(gmm) => (0 until k).map{ i =>
(gmm.weight(i), gmm.mu(i).toBreeze.toDenseVector, gmm.sigma(i).toBreeze.toDenseMatrix)
}.toArray

case None => {
// For each Gaussian, we will initialize the mean as the average
// of some random samples from the data
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)

(0 until k).map{ i =>
(1.0 / k,
vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)),
initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples)))
}.toArray
}
}

/*var gaussians = (0 until k).map{ i =>
(1.0 / k,
vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)),
initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples)))
}.toArray

*/
val accW = new Array[Accumulator[Double]](k)
val accMu = new Array[Accumulator[DenseDoubleVector]](k)
val accSigma = new Array[Accumulator[DenseDoubleMatrix]](k)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,37 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
assert(gmm.mu(0) ~== Emu absTol 1E-5)
assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
}

test("two clusters") {
val data = sc.parallelize(Array(
Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
))

// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Array(0.5, 0.5),
Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
)

val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))

val gmm = new GaussianMixtureModelEM()
.setK(2)
.setInitialGmm(initialGmm)
.run(data)

assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
}
}

0 comments on commit 42b2142

Please sign in to comment.