Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Add aggregator which sums up the different weights from the learner_b…

…olts
  • Loading branch information...
commit 2c58f6efb9128f2078add07faddaf5a5f6688003 1 parent 909cf16
@fbrubacher authored
View
2  storm-ml/src/main/java/com/twitter/MainOnlineTopology.java
@@ -5,6 +5,7 @@
import backtype.storm.topology.TopologyBuilder;
import backtype.storm.utils.Utils;
+import com.twitter.algorithms.Aggregator;
import com.twitter.storm.primitives.LocalLearner;
import com.twitter.storm.primitives.TrainingSpout;
@@ -15,6 +16,7 @@ public static void main(String[] args) throws Exception {
TopologyBuilder builder = new TopologyBuilder();
builder.setSpout("example_spitter", new TrainingSpout());
builder.setBolt("local_learner", new LocalLearner(2), 1).shuffleGrouping("example_spitter");
+ builder.setBolt("aggregator", new Aggregator()).globalGrouping("local_learner");
Config conf = new Config();
conf.setDebug(true);
LocalCluster cluster = new LocalCluster();
View
25 storm-ml/src/main/java/com/twitter/algorithms/Aggregator.java
@@ -5,31 +5,25 @@
import org.apache.log4j.Logger;
-import backtype.storm.coordination.BatchOutputCollector;
import backtype.storm.task.OutputCollector;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.topology.base.BaseRichBolt;
-import backtype.storm.transactional.ICommitter;
import backtype.storm.tuple.Tuple;
import com.twitter.util.MathUtil;
-public class Aggregator extends BaseRichBolt implements ICommitter {
+public class Aggregator extends BaseRichBolt {
public static Logger LOG = Logger.getLogger(Aggregator.class);
double[] aggregateWeights = null;
- double totalUpdateWeight = 0.0;
-
- public void prepare(Map conf, TopologyContext context, BatchOutputCollector collector, Object id) {
- // TODO Auto-generated method stub
-
- }
+ double totalUpdateWeight = 1.0;
public void execute(Tuple tuple) {
- double[] weight = (double[]) tuple.getValue(1);
- double parallelUpdateWeight = (Double) tuple.getValue(2);
+ double[] weight = (double[]) tuple.getValue(0);
+ Double parallelUpdateWeight = (Double) tuple.getValue(1);
+
if (parallelUpdateWeight != 1.0) {
weight = MathUtil.times(weight, parallelUpdateWeight);
}
@@ -39,22 +33,19 @@ public void execute(Tuple tuple) {
MathUtil.plus(aggregateWeights, weight);
}
totalUpdateWeight += parallelUpdateWeight;
- }
-
- public void finishBatch() {
+ LOG.info("totalUpdate");
+ LOG.info(totalUpdateWeight);
if (aggregateWeights != null) {
MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
- LOG.info("New weight vector: " + Arrays.toString(aggregateWeights));
+ LOG.info("New AGGREGATE vector: " + Arrays.toString(aggregateWeights));
}
}
public void declareOutputFields(OutputFieldsDeclarer declarer) {
- // TODO Auto-generated method stub
}
public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) {
- // TODO Auto-generated method stub
}
View
9 storm-ml/src/main/java/com/twitter/algorithms/Learner.java
@@ -1,7 +1,9 @@
package com.twitter.algorithms;
import java.io.Serializable;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import org.apache.log4j.Logger;
@@ -48,6 +50,13 @@ protected double getLearningRate(Example example, int timestamp) {
return weights;
}
+ public List<Object> getWeightsArray() {
+ List<Object> weight_array = new ArrayList<Object>();
+ for (double weight : weights)
+ weight_array.add(weight);
+ return weight_array;
+ }
+
public double getParallelUpdateWeight() {
return gradientSum;
}
View
3  storm-ml/src/main/java/com/twitter/storm/primitives/LocalLearner.java
@@ -54,6 +54,7 @@ public void execute(Tuple tuple) {
example.label = (Double) tuple.getValue(2);
example.isLabeled = true;
learner.update(example, 1);
+ _collector.emit(Arrays.asList((Object) learner.getWeights(), (Object) learner.getParallelUpdateWeight()));
_collector.ack(tuple);
LOG.debug("New weights" + Arrays.toString(learner.getWeights()));
// example.parseFrom((String) tuple.getValue(1), hashFunction);
@@ -61,7 +62,7 @@ public void execute(Tuple tuple) {
}
public void declareOutputFields(OutputFieldsDeclarer declarer) {
- declarer.declare(new Fields("id", "weight_vector", "parallel_update_weights"));
+ declarer.declare(new Fields("weight_vector", "parallel_weight"));
}
public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) {
Please sign in to comment.
Something went wrong with that request. Please try again.