Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

re-instate the aggregator within the topology builder

  • Loading branch information...
commit e5bc7c493daef4af9799e5c4c5e6cfc9cf302988 1 parent adb01bf
Federico Brubacher authored
9 storm-ml/src/main/java/com/twitter/MainOnlineTopology.java
View
@@ -1,5 +1,11 @@
package com.twitter;
+import java.util.ArrayList;
+import java.util.List;
+
+import net.spy.memcached.AddrUtil;
+import net.spy.memcached.MemcachedClient;
+import net.spy.memcached.internal.OperationFuture;
import backtype.storm.Config;
import backtype.storm.LocalCluster;
import backtype.storm.LocalDRPC;
@@ -16,6 +22,9 @@
static Double bias = 1.0;
public static void main(String[] args) throws Exception {
+ MemcachedClient memcache = new MemcachedClient(AddrUtil.getAddresses(MEMCACHED_SERVERS));
+ OperationFuture promise = memcache.set("model", 0, "[0.0, 0.0]");
+ promise.get();
Config topology_conf = new Config();
String topology_name;
4 storm-ml/src/main/java/com/twitter/algorithms/Aggregator.java
View
@@ -15,6 +15,7 @@
import backtype.storm.topology.base.BaseRichBolt;
import backtype.storm.tuple.Tuple;
+import com.twitter.util.Datautil;
import com.twitter.util.MathUtil;
public class Aggregator extends BaseRichBolt {
@@ -34,6 +35,7 @@ public void execute(Tuple tuple) {
List<Double> weight = (List<Double>) tuple.getValue(0);
Double parallelUpdateWeight = (Double) tuple.getValue(1);
+ LOG.error("AGGGG" + weight);
if (parallelUpdateWeight != 1.0) {
weight = MathUtil.times(weight, parallelUpdateWeight);
}
@@ -46,7 +48,7 @@ public void execute(Tuple tuple) {
MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
LOG.info("aggregate weights" + aggregateWeights);
if (aggregateWeights != null) {
- memcache.set("model", 3600 * 24, aggregateWeights);
+ memcache.set("model", 3600 * 24, Datautil.toStrVector(aggregateWeights));
}
}
7 storm-ml/src/main/java/com/twitter/algorithms/Learner.java
View
@@ -7,6 +7,7 @@
import net.spy.memcached.MemcachedClient;
+import org.apache.commons.lang.ArrayUtils;
import org.apache.log4j.Logger;
import com.twitter.data.Example;
@@ -31,9 +32,11 @@ public Learner(int dimension) {
}
public void update(Example example, int epoch, MemcachedClient memcache) {
- String cas_weights = (String) memcache.get("weights");
+ String cas_weights = (String) memcache.get("model");
List<Double> weights = Datautil.parse_str_vector(cas_weights);
- LOG.error("double weights" + weights);
+ Double[] weights_double = weights.toArray(new Double[weights.size()]);
+ this.setWeights(ArrayUtils.toPrimitive(weights_double));
+ LOG.error("double weights" + weights_double[0]);
int predicted = predict(example);
updateStats(example, predicted);
LOG.debug("EXAMPLE " + example.label + " PREDICTED: " + predicted);
3  storm-ml/src/main/java/com/twitter/storm/primitives/LocalLearner.java
View
@@ -62,7 +62,8 @@ public void execute(Tuple tuple) {
example.x[1] = (Double) tuple.getValue(1);
example.label = (Double) tuple.getValue(2);
learner.update(example, 1, memcache);
- LOG.debug("getwe" + learner.getWeights());
+ LOG.debug("local weights" + learner.getWeightsArray() + " parallel weights "
+ + learner.getParallelUpdateWeight());
_collector.emit(new Values(learner.getWeightsArray(), learner.getParallelUpdateWeight()));
_collector.ack(tuple);
}
4 storm-ml/src/main/java/com/twitter/storm/primitives/MLTopologyBuilder.java
View
@@ -8,6 +8,8 @@
import backtype.storm.topology.IRichBolt;
import backtype.storm.topology.TopologyBuilder;
+import com.twitter.algorithms.Aggregator;
+
public class MLTopologyBuilder {
public static final String MEMCACHED_SERVERS = "127.0.0.1:11211";
@@ -97,6 +99,8 @@ public TopologyBuilder prepareTopology(String drpcFunctionName, ILocalDRPC drpc,
topology_builder.setBolt(this.topology_prefix + "-training-bolt", this.rich_training_bolt,
this.training_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-training-spout");
}
+ topology_builder.setBolt("aggregator", new Aggregator(MEMCACHED_SERVERS)).globalGrouping(
+ this.topology_prefix + "-training-bolt");
// evaluation
DRPCSpout drpc_spout;
9 storm-ml/src/main/java/com/twitter/util/Datautil.java
View
@@ -54,4 +54,13 @@ public static Double dot_product(List<Double> vector_a, List<Double> vector_b) {
}
return lines;
}
+
+ public static String toStrVector(List<Double> aggregateWeights) {
+ String acc = "[";
+ for (Double weight : aggregateWeights) {
+ acc += weight.toString() + ", ";
+ }
+ acc += "]";
+ return acc;
+ }
}
Please sign in to comment.
Something went wrong with that request. Please try again.