From a570c9a5cbdbf0ac7b7a4eae1e3b571e0060e5f0 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 11 Mar 2015 19:23:13 +0800 Subject: [PATCH] use sample to pick up batch --- .../apache/spark/mllib/clustering/LDA.scala | 49 ++++++++++++++----- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 76ecdf92f26ed..a3681e34a147d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -247,9 +247,34 @@ class LDA private ( new DistributedLDAModel(state, iterationTimes) } - def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = { - val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k) - (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next()) + + /** + * Learn an LDA model using the given dataset, using online variational Bayes (VB) algorithm. + * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * Document IDs must be unique and >= 0. + * @param batchNumber Number of batches. For each batch, recommendation size is [4, 16384]. + * -1 for automatic batchNumber. + * @return Inferred LDA model + */ + def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = { + val D = documents.count().toInt + val batchSize = + if (batchNumber == -1) { // auto mode + if (D / 100 > 16384) 16384 + else if (D / 100 < 4) 4 + else D / 100 + } + else { + require(batchNumber > 0, "batchNumber should be positive or -1") + D / batchNumber + } + + val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize) + (0 until onlineLDA.actualBatchNumber).map(_ => onlineLDA.next()) new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose) } @@ -411,28 +436,26 @@ private[clustering] object LDA { * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010. */ private[clustering] class OnlineLDAOptimizer( - private val documents: RDD[(Long, Vector)], - private val k: Int) extends Serializable{ + private val documents: RDD[(Long, Vector)], + private val k: Int, + private val batchSize: Int) extends Serializable{ private val vocabSize = documents.first._2.size private val D = documents.count().toInt - private val batchSize = if (D / 1000 > 4096) 4096 - else if (D / 1000 < 4) 4 - else D / 1000 - val batchNumber = D/batchSize + val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt - // Initialize the variational distribution q(beta|lambda) + //Initialize the variational distribution q(beta|lambda) var lambda = getGammaMatrix(k, vocabSize) // K * V private var Elogbeta = dirichlet_expectation(lambda) // K * V private var expElogbeta = exp(Elogbeta) // K * V private var batchId = 0 def next(): Unit = { - require(batchId < batchNumber) + require(batchId < actualBatchNumber) // weight of the mini-batch. 1024 down weights early iterations val weight = math.pow(1024 + batchId, -0.5) - val batch = documents.filter(doc => doc._1 % batchNumber == batchId) - + val batch = documents.sample(true, batchSize.toDouble / D) + batch.cache() // Given a mini-batch of documents, estimates the parameters gamma controlling the // variational distribution over the topic weights for each document in the mini-batch. var stat = BDM.zeros[Double](k, vocabSize)