Permalink
Browse files

Merge branch 'master' into doc_topic_priors_merge

  • Loading branch information...
2 parents ce4afed + 1c8624d commit 53c22753b712baaa7e36d99d395e5e43a50da582 @jakemannix committed Mar 16, 2012
View
@@ -20,6 +20,7 @@ core/testdata/
core/temp
temp
distribution/.settings/
+distribution/target/
examples/.settings/
foo
math-tests/
View
@@ -23,7 +23,7 @@
<parent>
<groupId>org.apache.mahout</groupId>
<artifactId>mahout</artifactId>
- <version>0.7-SNAPSHOT</version>
+ <version>0.7-T1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
View
@@ -23,7 +23,7 @@
<parent>
<groupId>org.apache.mahout</groupId>
<artifactId>mahout</artifactId>
- <version>0.7-SNAPSHOT</version>
+ <version>0.7-T1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -22,19 +22,19 @@
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.IOException;
public class CVB0DocInferenceMapper extends CachingCVB0Mapper {
+ private static final Logger log = LoggerFactory.getLogger(CVB0DocInferenceMapper.class);
@Override
public void map(IntWritable docId, VectorWritable doc, Context context)
throws IOException, InterruptedException {
- int numTopics = getNumTopics();
- Vector docTopics = new DenseVector(new double[numTopics]).assign(1.0 /numTopics);
+ Vector docTopics = new DenseVector(new double[numTopics]).assign(1.0 / numTopics);
Matrix docModel = new SparseRowMatrix(numTopics, doc.get().size());
- int maxIters = getMaxIters();
- ModelTrainer modelTrainer = getModelTrainer();
for(int i = 0; i < maxIters; i++) {
modelTrainer.getReadModel().trainDocTopicModel(doc.get(), docTopics, docModel);
}
@@ -43,6 +43,7 @@ public void map(IntWritable docId, VectorWritable doc, Context context)
@Override
protected void cleanup(Context context) {
- getModelTrainer().stop();
+ log.info("Stopping model trainer");
+ modelTrainer.stop();
}
}

Large diffs are not rendered by default.

Oops, something went wrong.
@@ -59,17 +59,17 @@
public void configure(org.apache.hadoop.mapred.JobConf conf) {
try {
multipleOutputs = new MultipleOutputs(conf);
-
- double eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
- double alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
- long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
+ CVBConfig c = new CVBConfig().read(conf);
+ double eta = c.getEta();
+ double alpha = c.getAlpha();
+ long seed = c.getRandomSeed();
random = RandomUtils.getRandom(seed);
- numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
- int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
- int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
- int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
- double modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1f);
- testFraction = conf.getFloat(CVB0Driver.TEST_SET_FRACTION, 0.1f);
+ numTopics = c.getNumTopics();
+ int numTerms = c.getNumTerms();
+ int numUpdateThreads = c.getNumUpdateThreads();
+ int numTrainThreads = c.getNumTrainThreads();
+ double modelWeight = c.getModelWeight();
+ testFraction = c.getTestFraction();
log.info("Initializing read model");
TopicModel readModel;
Path[] modelPaths = CVB0Driver.getModelPaths(conf);
@@ -16,32 +16,77 @@
*/
package org.apache.mahout.clustering.lda.cvb;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+
+import java.util.Map;
public class CVBConfig {
+ // parameter names
+ public static final String NUM_TOPICS_PARAM = "num_topics";
+ public static final String NUM_TERMS_PARAM = "num_terms";
+ public static final String DOC_TOPIC_SMOOTHING_PARAM = "doc_topic_smoothing";
+ public static final String TERM_TOPIC_SMOOTHING_PARAM = "term_topic_smoothing";
+ public static final String MAX_ITERATIONS_PARAM = DefaultOptionCreator.MAX_ITERATIONS_OPTION;
+ public static final String CONVERGENCE_DELTA_PARAM = DefaultOptionCreator.CONVERGENCE_DELTA_OPTION;
+ public static final String DICTIONARY_PARAM = "dictionary";
+ public static final String DOC_TOPIC_OUTPUT_PARAM = "doc_topic_output";
+ public static final String MODEL_TEMP_DIR_PARAM = "topic_model_temp_dir";
+ public static final String ITERATION_BLOCK_SIZE_PARAM = "iteration_block_size";
+ public static final String RANDOM_SEED_PARAM = "random_seed";
+ public static final String TEST_SET_FRACTION_PARAM = "test_set_fraction";
+ public static final String NUM_TRAIN_THREADS_PARAM = "num_train_threads";
+ public static final String NUM_UPDATE_THREADS_PARAM = "num_update_threads";
+ public static final String MAX_ITERATIONS_PER_DOC_PARAM = "max_doc_topic_iters";
+ public static final String MODEL_WEIGHT_PARAM = "prev_iter_mult";
+ public static final String NUM_REDUCE_TASKS_PARAM = "num_reduce_tasks";
+ public static final String PERSIST_INTERMEDIATE_DOCTOPICS_PARAM = "persist_intermediate_doctopics";
+ public static final String DOC_TOPIC_PRIOR_PARAM = "doc_topic_prior_path";
+ public static final String BACKFILL_PERPLEXITY_PARAM = "backfill_perplexity";
+ public static final String ONLY_LABELED_DOCS_PARAM = "labeled_only";
+
+ // default values
+ public static final float DOC_TOPIC_SMOOTHING_DEFAULT = 1e-5f;
+ public static final float TERM_TOPIC_SMOOTHING_DEFAULT = 1e-5f;
+ public static final int MAX_ITERATIONS_DEFAULT = 30;
+ public static final int ITERATION_BLOCK_SIZE_DEFAULT = 1;
+ public static final float CONVERGENCE_DELTA_DEFAULT = 1e-5f;
+ public static final long RANDOM_SEED_DEFAULT = 1234l;
+ public static final float TEST_SET_FRACTION_DEFAULT = 1e-2f;
+ public static final int NUM_TRAIN_THREADS_DEFAULT = 4;
+ public static final int NUM_UPDATE_THREADS_DEFAULT = 1;
+ public static final int MAX_ITERATIONS_PER_DOC_DEFAULT = 10;
+ public static final int NUM_REDUCE_TASKS_DEFAULT = 1;
+ public static final float MODEL_WEIGHT_DEFAULT = 1f;
+
// TODO: sensible defaults and/or checks for validity
private Path inputPath;
private Path outputPath;
private int numTopics;
private int numTerms;
- private double alpha;
- private double eta;
- private int maxIterations = 1;
- private int iterationBlockSize = 1;
- private double convergenceDelta;
+ private float alpha = DOC_TOPIC_SMOOTHING_DEFAULT;
+ private float eta = TERM_TOPIC_SMOOTHING_DEFAULT;
+ private int maxIterations = MAX_ITERATIONS_DEFAULT;
+ private int iterationBlockSize = ITERATION_BLOCK_SIZE_DEFAULT;
+ private float convergenceDelta = CONVERGENCE_DELTA_DEFAULT;
private Path dictionaryPath;
private Path docTopicOutputPath;
private Path modelTempPath;
private Path docTopicPriorPath;
private boolean persistDocTopics;
- private long seed = 1234;
- private double testFraction = 0;
- private int numTrainThreads = 1;
- private int numUpdateThreads = 1;
- private int maxItersPerDoc = 1;
- private int numReduceTasks = 1;
+ private long randomSeed = RANDOM_SEED_DEFAULT;
+ private float testFraction = TEST_SET_FRACTION_DEFAULT;
+ private int numTrainThreads = NUM_TRAIN_THREADS_DEFAULT;
+ private int numUpdateThreads = NUM_UPDATE_THREADS_DEFAULT;
+ private int maxItersPerDoc = MAX_ITERATIONS_PER_DOC_DEFAULT;
+ private int numReduceTasks = NUM_REDUCE_TASKS_DEFAULT;
private boolean backfillPerplexity;
private boolean useOnlyLabeledDocs;
+ private float modelWeight = MODEL_WEIGHT_DEFAULT;
public boolean isUseOnlyLabeledDocs() {
return useOnlyLabeledDocs;
@@ -88,20 +133,20 @@ public CVBConfig setNumTerms(int numTerms) {
return this;
}
- public double getAlpha() {
+ public float getAlpha() {
return alpha;
}
- public CVBConfig setAlpha(double alpha) {
+ public CVBConfig setAlpha(float alpha) {
this.alpha = alpha;
return this;
}
- public double getEta() {
+ public float getEta() {
return eta;
}
- public CVBConfig setEta(double eta) {
+ public CVBConfig setEta(float eta) {
this.eta = eta;
return this;
}
@@ -124,11 +169,11 @@ public CVBConfig setIterationBlockSize(int iterationBlockSize) {
return this;
}
- public double getConvergenceDelta() {
+ public float getConvergenceDelta() {
return convergenceDelta;
}
- public CVBConfig setConvergenceDelta(double convergenceDelta) {
+ public CVBConfig setConvergenceDelta(float convergenceDelta) {
this.convergenceDelta = convergenceDelta;
return this;
}
@@ -178,20 +223,20 @@ public CVBConfig setPersistDocTopics(boolean persistDocTopics) {
return this;
}
- public long getSeed() {
- return seed;
+ public long getRandomSeed() {
+ return randomSeed;
}
- public CVBConfig setSeed(long seed) {
- this.seed = seed;
+ public CVBConfig setRandomSeed(long randomSeed) {
+ this.randomSeed = randomSeed;
return this;
}
- public double getTestFraction() {
+ public float getTestFraction() {
return testFraction;
}
- public CVBConfig setTestFraction(double testFraction) {
+ public CVBConfig setTestFraction(float testFraction) {
this.testFraction = testFraction;
return this;
}
@@ -240,4 +285,68 @@ public CVBConfig setBackfillPerplexity(boolean backfillPerplexity) {
this.backfillPerplexity = backfillPerplexity;
return this;
}
+
+ public float getModelWeight() {
+ return modelWeight;
+ }
+
+ public CVBConfig setModelWeight(float modelWeight) {
+ this.modelWeight = modelWeight;
+ return this;
+ }
+
+ /**
+ * Copies this config's topic counts, smoothing values, random seed, test fraction,
+ * thread counts, per-document iteration limit, model weight and labeled-only flag
+ * into {@code conf} under the corresponding {@code *_PARAM} keys.
+ * Paths and the remaining fields (e.g. maxIterations, numReduceTasks) are not written.
+ *
+ * @param conf the Hadoop configuration to populate
+ */
+ public void write(Configuration conf) {
+ conf.setInt(NUM_TOPICS_PARAM, numTopics);
+ conf.setInt(NUM_TERMS_PARAM, numTerms);
+ conf.setFloat(DOC_TOPIC_SMOOTHING_PARAM, alpha);
+ conf.setFloat(TERM_TOPIC_SMOOTHING_PARAM, eta);
+ conf.setLong(RANDOM_SEED_PARAM, randomSeed);
+ conf.setFloat(TEST_SET_FRACTION_PARAM, testFraction);
+ conf.setInt(NUM_TRAIN_THREADS_PARAM, numTrainThreads);
+ conf.setInt(NUM_UPDATE_THREADS_PARAM, numUpdateThreads);
+ conf.setInt(MAX_ITERATIONS_PER_DOC_PARAM, maxItersPerDoc);
+ conf.setFloat(MODEL_WEIGHT_PARAM, modelWeight);
+ conf.setBoolean(ONLY_LABELED_DOCS_PARAM, useOnlyLabeledDocs);
+ }
+
+ /**
+ * Populates this config from {@code conf} and validates the result via {@link #check()}.
+ * Keys that are absent fall back to the declared {@code *_DEFAULT} constants; the number
+ * of topics and terms have no default and therefore fail validation when missing.
+ *
+ * @param conf the Hadoop configuration to read from
+ * @return this config, for call chaining
+ * @throws IllegalArgumentException if a value read from {@code conf} fails {@link #check()}
+ */
+ public CVBConfig read(Configuration conf) {
+ // required: no sensible default exists, so a missing key yields 0 and is rejected by check()
+ setNumTopics(conf.getInt(NUM_TOPICS_PARAM, 0));
+ setNumTerms(conf.getInt(NUM_TERMS_PARAM, 0));
+ // optional: fall back to the documented defaults instead of 0, which check() would reject
+ setAlpha(conf.getFloat(DOC_TOPIC_SMOOTHING_PARAM, DOC_TOPIC_SMOOTHING_DEFAULT));
+ setEta(conf.getFloat(TERM_TOPIC_SMOOTHING_PARAM, TERM_TOPIC_SMOOTHING_DEFAULT));
+ setRandomSeed(conf.getLong(RANDOM_SEED_PARAM, RANDOM_SEED_DEFAULT));
+ setTestFraction(conf.getFloat(TEST_SET_FRACTION_PARAM, TEST_SET_FRACTION_DEFAULT));
+ setNumTrainThreads(conf.getInt(NUM_TRAIN_THREADS_PARAM, NUM_TRAIN_THREADS_DEFAULT));
+ setNumUpdateThreads(conf.getInt(NUM_UPDATE_THREADS_PARAM, NUM_UPDATE_THREADS_DEFAULT));
+ setMaxItersPerDoc(conf.getInt(MAX_ITERATIONS_PER_DOC_PARAM, MAX_ITERATIONS_PER_DOC_DEFAULT));
+ setModelWeight(conf.getFloat(MODEL_WEIGHT_PARAM, MODEL_WEIGHT_DEFAULT));
+ setUseOnlyLabeledDocs(conf.getBoolean(ONLY_LABELED_DOCS_PARAM, false));
+ check();
+ return this;
+ }
+
+ /**
+ * Validates the settings that the mappers depend on.
+ *
+ * @throws IllegalArgumentException if a topic/term count, smoothing value, thread count or
+ * per-document iteration limit is not positive, if the test fraction is negative, or if the
+ * model weight is below 1
+ */
+ public void check() {
+ checkPositive(NUM_TOPICS_PARAM, numTopics);
+ checkPositive(NUM_TERMS_PARAM, numTerms);
+ checkPositive(DOC_TOPIC_SMOOTHING_PARAM, alpha);
+ checkPositive(TERM_TOPIC_SMOOTHING_PARAM, eta);
+ // randomSeed is deliberately not range-checked: zero and negative longs are valid seeds
+ // a test fraction of 0 is legal; it means no documents are held out
+ checkGreaterOrEqual(TEST_SET_FRACTION_PARAM, testFraction, 0);
+ checkPositive(NUM_TRAIN_THREADS_PARAM, numTrainThreads);
+ checkPositive(NUM_UPDATE_THREADS_PARAM, numUpdateThreads);
+ checkPositive(MAX_ITERATIONS_PER_DOC_PARAM, maxItersPerDoc);
+ checkGreaterOrEqual(MODEL_WEIGHT_PARAM, modelWeight, 1);
+ }
+
+ /**
+ * Asserts that {@code value} is strictly greater than {@code threshold}, comparing both as doubles.
+ *
+ * @param param name of the parameter being validated, used in the error message
+ * @param value the value to validate
+ * @param threshold exclusive lower bound
+ * @throws IllegalArgumentException if {@code value} is not greater than {@code threshold}
+ */
+ protected void checkGreater(String param, Number value, Number threshold) {
+ // Guava's Preconditions templates substitute only %s; a %d would stay in the message verbatim
+ Preconditions.checkArgument(value.doubleValue() > threshold.doubleValue(),
+ "Expecting %s > %s but found %s", param, threshold, value);
+ }
+
+ /**
+ * Asserts that {@code value} is greater than or equal to {@code threshold}, comparing both as doubles.
+ *
+ * @param param name of the parameter being validated, used in the error message
+ * @param value the value to validate
+ * @param threshold inclusive lower bound
+ * @throws IllegalArgumentException if {@code value} is less than {@code threshold} (or is NaN)
+ */
+ protected void checkGreaterOrEqual(String param, Number value, Number threshold) {
+ // Guava's Preconditions templates substitute only %s; a %d would stay in the message verbatim
+ Preconditions.checkArgument(value.doubleValue() >= threshold.doubleValue(),
+ "Expecting %s >= %s but found %s", param, threshold, value);
+ }
+
+ /**
+ * Asserts that {@code value} is strictly greater than zero by delegating to
+ * {@code checkGreater} with a threshold of 0.
+ *
+ * @param param name of the parameter being validated, used in the error message
+ * @param value the value to validate
+ */
+ protected void checkPositive(String param, Number value) {
+ checkGreater(param, value, 0);
+ }
}
@@ -29,6 +29,7 @@
import org.slf4j.LoggerFactory;
import java.io.IOException;
+import java.util.Random;
/**
* Run ensemble learning via loading the {@link ModelTrainer} with two {@link TopicModel} instances:
@@ -54,38 +55,29 @@
*/
public class CachingCVB0Mapper
extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
-
private static final Logger log = LoggerFactory.getLogger(CachingCVB0Mapper.class);
-
- private ModelTrainer modelTrainer;
- private int maxIters;
- private int numTopics;
-
- protected ModelTrainer getModelTrainer() {
- return modelTrainer;
- }
-
- protected int getMaxIters() {
- return maxIters;
- }
-
- protected int getNumTopics() {
- return numTopics;
- }
+ protected ModelTrainer modelTrainer;
+ protected int maxIters;
+ protected int numTopics;
+ protected float testFraction;
+ protected Random random;
@Override
protected void setup(Context context) throws IOException, InterruptedException {
log.info("Retrieving configuration");
Configuration conf = context.getConfiguration();
- float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
- float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
- long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
- numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
- int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
- int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
- int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
- maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
- float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);
+ CVBConfig config = new CVBConfig().read(conf);
+ float eta = config.getEta();
+ float alpha = config.getAlpha();
+ long seed = config.getRandomSeed();
+ numTopics = config.getNumTopics();
+ int numTerms = config.getNumTerms();
+ int numUpdateThreads = config.getNumUpdateThreads();
+ int numTrainThreads = config.getNumTrainThreads();
+ maxIters = config.getMaxItersPerDoc();
+ float modelWeight = config.getModelWeight();
+ testFraction = config.getTestFraction();
+ random = RandomUtils.getRandom(seed);
log.info("Initializing read model");
TopicModel readModel;
@@ -112,7 +104,12 @@ protected void setup(Context context) throws IOException, InterruptedException {
public void map(IntWritable docId, VectorWritable document, Context context)
throws IOException, InterruptedException{
/* where to get docTopics? */
- Vector topicVector = new DenseVector(new double[numTopics]).assign(1.0/numTopics);
+ if (0 < testFraction && random.nextFloat() < testFraction) {
+ // don't train on test doc
+ return;
+ }
+ context.getCounter(CVB0Driver.Counters.SAMPLED_DOCUMENTS).increment(1);
+ Vector topicVector = new DenseVector(new double[numTopics]).assign(1.0 / numTopics);
modelTrainer.train(document.get(), topicVector, true, maxIters);
}
Oops, something went wrong.

0 comments on commit 53c2275

Please sign in to comment.