Permalink
Browse files

Add aggregator which sums up the different weights from the learner_b…

…olts
  • Loading branch information...
1 parent 909cf16 commit 2c58f6efb9128f2078add07faddaf5a5f6688003 @fbrubacher committed Jul 4, 2012
@@ -5,6 +5,7 @@
import backtype.storm.topology.TopologyBuilder;
import backtype.storm.utils.Utils;
+import com.twitter.algorithms.Aggregator;
import com.twitter.storm.primitives.LocalLearner;
import com.twitter.storm.primitives.TrainingSpout;
@@ -15,6 +16,7 @@ public static void main(String[] args) throws Exception {
TopologyBuilder builder = new TopologyBuilder();
builder.setSpout("example_spitter", new TrainingSpout());
builder.setBolt("local_learner", new LocalLearner(2), 1).shuffleGrouping("example_spitter");
+ builder.setBolt("aggregator", new Aggregator()).globalGrouping("local_learner");
Config conf = new Config();
conf.setDebug(true);
LocalCluster cluster = new LocalCluster();
@@ -5,31 +5,25 @@
import org.apache.log4j.Logger;
-import backtype.storm.coordination.BatchOutputCollector;
import backtype.storm.task.OutputCollector;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.topology.base.BaseRichBolt;
-import backtype.storm.transactional.ICommitter;
import backtype.storm.tuple.Tuple;
import com.twitter.util.MathUtil;
-public class Aggregator extends BaseRichBolt implements ICommitter {
+public class Aggregator extends BaseRichBolt {
public static Logger LOG = Logger.getLogger(Aggregator.class);
double[] aggregateWeights = null;
- double totalUpdateWeight = 0.0;
-
- public void prepare(Map conf, TopologyContext context, BatchOutputCollector collector, Object id) {
- // TODO Auto-generated method stub
-
- }
+ double totalUpdateWeight = 1.0;
public void execute(Tuple tuple) {
- double[] weight = (double[]) tuple.getValue(1);
- double parallelUpdateWeight = (Double) tuple.getValue(2);
+ double[] weight = (double[]) tuple.getValue(0);
+ Double parallelUpdateWeight = (Double) tuple.getValue(1);
+
if (parallelUpdateWeight != 1.0) {
weight = MathUtil.times(weight, parallelUpdateWeight);
}
@@ -39,22 +33,19 @@ public void execute(Tuple tuple) {
MathUtil.plus(aggregateWeights, weight);
}
totalUpdateWeight += parallelUpdateWeight;
- }
-
- public void finishBatch() {
+ LOG.info("totalUpdate");
+ LOG.info(totalUpdateWeight);
if (aggregateWeights != null) {
MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
- LOG.info("New weight vector: " + Arrays.toString(aggregateWeights));
+ LOG.info("New AGGREGATE vector: " + Arrays.toString(aggregateWeights));
}
}
public void declareOutputFields(OutputFieldsDeclarer declarer) {
- // TODO Auto-generated method stub
}
public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) {
- // TODO Auto-generated method stub
}
@@ -1,7 +1,9 @@
package com.twitter.algorithms;
import java.io.Serializable;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import org.apache.log4j.Logger;
@@ -48,6 +50,13 @@ protected double getLearningRate(Example example, int timestamp) {
return weights;
}
+ public List<Object> getWeightsArray() {
+ List<Object> weight_array = new ArrayList<Object>();
+ for (double weight : weights)
+ weight_array.add(weight);
+ return weight_array;
+ }
+
public double getParallelUpdateWeight() {
return gradientSum;
}
@@ -54,14 +54,15 @@ public void execute(Tuple tuple) {
example.label = (Double) tuple.getValue(2);
example.isLabeled = true;
learner.update(example, 1);
+ _collector.emit(Arrays.asList((Object) learner.getWeights(), (Object) learner.getParallelUpdateWeight()));
_collector.ack(tuple);
LOG.debug("New weights" + Arrays.toString(learner.getWeights()));
// example.parseFrom((String) tuple.getValue(1), hashFunction);
// buffer.add(example);
}
public void declareOutputFields(OutputFieldsDeclarer declarer) {
- declarer.declare(new Fields("id", "weight_vector", "parallel_update_weights"));
+ declarer.declare(new Fields("weight_vector", "parallel_weight"));
}
public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) {

0 comments on commit 2c58f6e

Please sign in to comment.