From acf1fba6b0084511272cf4e19e0cab4587a11d16 Mon Sep 17 00:00:00 2001 From: Travis Galoppo Date: Mon, 22 Dec 2014 09:26:28 -0500 Subject: [PATCH] Fixed parameter comment in GaussianMixtureModel Made maximum iterations an optional parameter to DenseGmmEM --- .../org/apache/spark/examples/mllib/DenseGmmEM.scala | 8 +++++--- .../spark/mllib/clustering/GaussianMixtureModel.scala | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala index b56c4b3bd7789..e0511eaec9cb5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -30,14 +30,15 @@ import org.apache.spark.mllib.linalg.Vectors */ object DenseGmmEM { def main(args: Array[String]): Unit = { - if (args.length != 3) { + if (args.length < 3) { println("usage: DenseGmmEM ") } else { - run(args(0), args(1).toInt, args(2).toDouble) + val maxIterations = if (args.length > 3) args(3).toInt else 100 + run(args(0), args(1).toInt, args(2).toDouble, maxIterations) } } - private def run(inputFile: String, k: Int, convergenceTol: Double) { + private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") val ctx = new SparkContext(conf) @@ -48,6 +49,7 @@ object DenseGmmEM { val clusters = new GaussianMixtureModelEM() .setK(k) .setConvergenceTol(convergenceTol) + .setMaxIterations(maxIterations) .run(data) for (i <- 0 until clusters.k) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 734d67ea72a26..0285a847bd1b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.stat.impl.MultivariateGaussian * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are * the respective mean and covariance for each Gaussian distribution i=1..k. * - * @param weight Weights for each Gaussian distribution in the mixture, where mu(i) is + * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is * the weight for Gaussian i, and weight.sum == 1 * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the