Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Merge pull request #16 from jakemannix/allow_prior_truncation

Allow prior truncation
  • Loading branch information...
commit 4d9380f0c559ead140c32e5fd34314968292168e 2 parents 5f43cc7 + 0517e48
Andy Schlaikjer sagemintblue authored
7 core/src/main/java/org/apache/mahout/clustering/lda/cvb/CVBConfig.java
@@ -365,6 +365,7 @@ public CVBConfig read(Configuration conf) {
365 365 public void check() {
366 366 checkPositive(NUM_TOPICS_PARAM, numTopics);
367 367 checkPositive(NUM_TERMS_PARAM, numTerms);
  368 + checkGreater(NUM_TERMS_PARAM, numTerms, numTopics);
368 369 checkPositive(DOC_TOPIC_SMOOTHING_PARAM, alpha);
369 370 checkPositive(TERM_TOPIC_SMOOTHING_PARAM, eta);
370 371 checkPositive(RANDOM_SEED_PARAM, randomSeed);
@@ -378,11 +379,13 @@ public void check() {
378 379 }
379 380
380 381 protected void checkGreater(String param, Number value, Number threshold) {
381   - Preconditions.checkArgument(value.doubleValue() > threshold.doubleValue(), "Expecting %s > %d but found %s", param, threshold, value);
  382 + Preconditions.checkArgument(value.doubleValue() > threshold.doubleValue(),
  383 + "Expecting %s > %d but found %s", param, threshold, value);
382 384 }
383 385
384 386 protected void checkGreaterOrEqual(String param, Number value, Number threshold) {
385   - Preconditions.checkArgument(value.doubleValue() >= threshold.doubleValue(), "Expecting %s >= %d but found %s", param, threshold, value);
  387 + Preconditions.checkArgument(value.doubleValue() >= threshold.doubleValue(),
  388 + "Expecting %s >= %d but found %s", param, threshold, value);
386 389 }
387 390
388 391 protected void checkPositive(String param, Number value) {
14 core/src/main/java/org/apache/mahout/clustering/lda/cvb/PriorTrainingReducer.java
@@ -54,6 +54,7 @@
54 54 private ModelTrainer modelTrainer;
55 55 private int maxIters;
56 56 private int numTopics;
  57 + private int numTerms;
57 58 private boolean onlyLabeledDocs;
58 59 private MultipleOutputs multipleOutputs;
59 60 private Reporter reporter;
@@ -79,7 +80,7 @@ public void configure(JobConf conf) {
79 80 double eta = c.getEta();
80 81 double alpha = c.getAlpha();
81 82 numTopics = c.getNumTopics();
82   - int numTerms = c.getNumTerms();
  83 + numTerms = c.getNumTerms();
83 84 int numUpdateThreads = c.getNumUpdateThreads();
84 85 int numTrainThreads = c.getNumTrainThreads();
85 86 maxIters = c.getMaxItersPerDoc();
@@ -127,10 +128,15 @@ public void reduce(IntWritable docId, Iterator<VectorWritable> vectors,
127 128 Vector document = null;
128 129 while(vectors.hasNext()) {
129 130 VectorWritable v = vectors.next();
130   - if(v.get().size() == numTopics) {
131   - topicVector = v.get();
132   - } else {
  131 + /*
  132 + * NOTE: a vector is treated as the document if and only if its size() == numTerms; any other
  133 + * vector is taken to be the "prior". We are therefore susceptible to the pathological case of
  134 + * numTerms == numTopics (which should never happen, as that would generate a horrible topic model)
  135 + */
  136 + if(v.get().size() == numTerms) {
133 137 document = v.get();
  138 + } else {
  139 + topicVector = v.get();
134 140 }
135 141 }
136 142 if(document == null) {
4 core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java
@@ -374,6 +374,8 @@ private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix termTopic
374 374 double topicSum = topicSums.get(x);
375 375 // get p(topic x | term a) distribution to update
376 376 Vector termTopicRow = termTopicDist.viewRow(x);
  377 + // cache factor which is the same for all terms, for this fixed topic.
  378 + double topicMult = (topicWeight + alpha) / (topicSum + eta * numTerms);
377 379
378 380 // for each term a in document i with non-zero weight
379 381 Iterator<Vector.Element> it = document.iterateNonZero();
@@ -382,7 +384,7 @@ private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix termTopic
382 384 int termIndex = e.index();
383 385
384 386 // calc un-normalized p(topic x | term a, document i)
385   - double termTopicLikelihood = (topicTermRow.get(termIndex) + eta) * (topicWeight + alpha) / (topicSum + eta * numTerms);
  387 + double termTopicLikelihood = (topicTermRow.get(termIndex) + eta) * topicMult;
386 388 termTopicRow.set(termIndex, termTopicLikelihood);
387 389 }
388 390 }
6 core/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java
@@ -198,7 +198,9 @@ public void testPriorDocTopics() throws Exception {
198 198 Path topicModelStateTempPath = getTestTempDirPath("topicTemp");
199 199 Path outputPath = new Path(getTestTempDirPath(), "finalOutput");
200 200 int numIterations = 10;
201   - CVBConfig cvbConfig = new CVBConfig().setAlpha(ALPHA).setEta(ETA).setNumTopics(numGeneratingTopics)
  201 + int numTopics = numGeneratingTopics - 2;
  202 + CVBConfig cvbConfig = new CVBConfig().setAlpha(ALPHA).setEta(ETA)
  203 + .setNumTopics(numTopics)
202 204 .setBackfillPerplexity(false).setConvergenceDelta(0).setDictionaryPath(null)
203 205 .setModelTempPath(topicModelStateTempPath).setTestFraction(0.2f).setNumTerms(numTerms)
204 206 .setMaxIterations(numIterations).setInputPath(sampleCorpusPath).setNumTrainThreads(1)
@@ -213,7 +215,7 @@ public void testPriorDocTopics() throws Exception {
213 215 modelParts.add(fileStatus.getPath());
214 216 }
215 217 Pair<Matrix, Vector> model = TopicModel.loadModel(conf, modelParts.toArray(new Path[0]));
216   - for(int topic = 0; topic < numGeneratingTopics; topic++) {
  218 + for(int topic = 0; topic < numTopics; topic++) {
217 219 Vector topicDist = model.getFirst().viewRow(topic);
218 220 int term = mostProminentFeature(topicDist);
219 221 int expectedTopicForTerm = expectedTopicForTerm(matrix, term);

0 comments on commit 4d9380f

Please sign in to comment.
Something went wrong with that request. Please try again.