
use sample to pick up batch
hhbyyh committed Mar 11, 2015
1 parent 4a3f27e commit a570c9a
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -247,9 +247,34 @@ class LDA private (
new DistributedLDAModel(state, iterationTimes)
}

def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
(0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())

/**
* Learn an LDA model from the given dataset using the online variational Bayes (VB) algorithm.
* Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
*
* @param documents RDD of documents, which are term (word) count vectors paired with IDs.
* The term count vectors are "bags of words" with a fixed-size vocabulary
* (where the vocabulary size is the length of the vector).
* Document IDs must be unique and >= 0.
* @param batchNumber Number of mini-batches to use. The recommended size of each batch is
*                    in [4, 16384] documents. Use -1 to choose the batch number automatically.
* @return Inferred LDA model
*/
def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
val D = documents.count().toInt
val batchSize =
if (batchNumber == -1) { // auto mode
if (D / 100 > 16384) 16384
else if (D / 100 < 4) 4
else D / 100
} else {
require(batchNumber > 0, "batchNumber should be positive or -1")
D / batchNumber
}

val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchSize)
(0 until onlineLDA.actualBatchNumber).map(_ => onlineLDA.next())
new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
}
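
For context, here is a minimal usage sketch of the new entry point. It is hedged: `sc` (an existing SparkContext), the tiny in-memory corpus, and the chained setK call are assumptions made for illustration, not part of this commit.

import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Hypothetical corpus: each document is a term-count vector paired with a unique ID >= 0.
val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 0.0, 3.0)),
  (1L, Vectors.dense(0.0, 2.0, 1.0)),
  (2L, Vectors.dense(4.0, 1.0, 0.0))
))

val lda = new LDA().setK(2)
// batchNumber = -1 asks the method to pick the number of mini-batches automatically.
val model = lda.runOnlineLDA(corpus, batchNumber = -1)
println(s"Learned ${model.k} topics over a vocabulary of ${model.vocabSize} terms")
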

@@ -411,28 +436,26 @@ private[clustering] object LDA {
* Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
*/
private[clustering] class OnlineLDAOptimizer(
private val documents: RDD[(Long, Vector)],
private val k: Int) extends Serializable{
private val documents: RDD[(Long, Vector)],
private val k: Int,
private val batchSize: Int) extends Serializable {

private val vocabSize = documents.first._2.size
private val D = documents.count().toInt
private val batchSize = if (D / 1000 > 4096) 4096
else if (D / 1000 < 4) 4
else D / 1000
val batchNumber = D/batchSize
val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt

// Initialize the variational distribution q(beta|lambda)
// Initialize the variational distribution q(beta|lambda)
var lambda = getGammaMatrix(k, vocabSize) // K * V
private var Elogbeta = dirichlet_expectation(lambda) // K * V
private var expElogbeta = exp(Elogbeta) // K * V

private var batchId = 0
def next(): Unit = {
require(batchId < batchNumber)
require(batchId < actualBatchNumber)
// Weight of the mini-batch; the 1024 offset down-weights early iterations
val weight = math.pow(1024 + batchId, -0.5)
val batch = documents.filter(doc => doc._1 % batchNumber == batchId)

val batch = documents.sample(true, batchSize.toDouble / D)
batch.cache()
// Given a mini-batch of documents, estimates the parameters gamma controlling the
// variational distribution over the topic weights for each document in the mini-batch.
var stat = BDM.zeros[Double](k, vocabSize)
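
To make the behavioral change concrete, here is a hedged sketch contrasting the new sampling-based mini-batch selection with the removed modulo filter. The values of D and batchSize are illustrative, and corpus refers to the hypothetical RDD from the sketch above.

// Each call to next() now draws an independent random mini-batch with replacement whose
// expected size is batchSize, instead of deterministically slicing the corpus by document ID.
val D = 10000                          // illustrative corpus size
val batchSize = 100                    // illustrative mini-batch size
val fraction = batchSize.toDouble / D  // expected fraction of documents per mini-batch

val newBatch = corpus.sample(withReplacement = true, fraction = fraction)          // ~batchSize docs; exact count varies per call
val oldBatch = corpus.filter { case (docId, _) => docId % (D / batchSize) == 0 }   // fixed, deterministic slice
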
