Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP

We’re showing branches in this repository, but you can also compare across forks.

base fork: nathanmarz/storm-contrib
base: dda29767cc
...
head fork: nathanmarz/storm-contrib
compare: 8a15a397e8
  • 3 commits
  • 6 files changed
  • 0 commit comments
  • 1 contributor
58 storm-ml/src/main/java/com/twitter/MainOnlineTopology.java
View
@@ -1,15 +1,19 @@
package com.twitter;
+import java.util.ArrayList;
+import java.util.List;
+
+import net.spy.memcached.AddrUtil;
+import net.spy.memcached.MemcachedClient;
+import net.spy.memcached.internal.OperationFuture;
import backtype.storm.Config;
import backtype.storm.LocalCluster;
import backtype.storm.LocalDRPC;
-import backtype.storm.drpc.LinearDRPCTopologyBuilder;
-import backtype.storm.topology.TopologyBuilder;
import backtype.storm.utils.Utils;
-import com.twitter.algorithms.Aggregator;
import com.twitter.storm.primitives.EvaluationBolt;
import com.twitter.storm.primitives.LocalLearner;
+import com.twitter.storm.primitives.MLTopologyBuilder;
import com.twitter.storm.primitives.TrainingSpout;
public class MainOnlineTopology {
@@ -18,25 +22,33 @@
static Double bias = 1.0;
public static void main(String[] args) throws Exception {
- TopologyBuilder builder = new TopologyBuilder();
- LocalDRPC drpc = new LocalDRPC();
-
- builder.setSpout("example_spitter", new TrainingSpout());
- builder.setBolt("local_learner", new LocalLearner(2, MEMCACHED_SERVERS), 1).shuffleGrouping("example_spitter");
- builder.setBolt("aggregator", new Aggregator(MEMCACHED_SERVERS)).globalGrouping("local_learner");
-
- LinearDRPCTopologyBuilder drpc_builder = new LinearDRPCTopologyBuilder("evaluate");
- drpc_builder.addBolt(new EvaluationBolt(bias, threshold, MEMCACHED_SERVERS), 3);
-
- Config conf = new Config();
- conf.setDebug(true);
- LocalCluster cluster = new LocalCluster();
- cluster.submitTopology("learning", conf, builder.createTopology());
- // cluster.submitTopology("evaluation", conf, drpc_builder.createLocalTopology(drpc));
-
- Utils.sleep(10000);
- cluster.killTopology("learning");
- cluster.shutdown();
-
+ MemcachedClient memcache = new MemcachedClient(AddrUtil.getAddresses(MEMCACHED_SERVERS));
+ OperationFuture promise = memcache.set("model", 0, "[0.0, 0.0]");
+ promise.get();
+
+ Config topology_conf = new Config();
+ String topology_name;
+ if (args == null || args.length == 0)
+ topology_name = "perceptron";
+ else
+ topology_name = args[0];
+
+ MLTopologyBuilder ml_topology_builder = new MLTopologyBuilder(topology_name);
+
+ ml_topology_builder.setTrainingSpout(new TrainingSpout());
+ ml_topology_builder.setTrainingBolt(new LocalLearner(2, MEMCACHED_SERVERS));
+ ml_topology_builder.setEvaluationBolt(new EvaluationBolt(1.0, 2.0, MEMCACHED_SERVERS));
+
+ if (args == null || args.length == 0) {
+ LocalDRPC drpc = new LocalDRPC();
+ LocalCluster cluster = new LocalCluster();
+
+ cluster.submitTopology(topology_name, topology_conf,
+ ml_topology_builder.createLocalTopology("evaluate", drpc));
+
+ Utils.sleep(10000);
+ cluster.killTopology("perceptron");
+ cluster.shutdown();
+ }
}
}
6 storm-ml/src/main/java/com/twitter/algorithms/Aggregator.java
View
@@ -15,6 +15,7 @@
import backtype.storm.topology.base.BaseRichBolt;
import backtype.storm.tuple.Tuple;
+import com.twitter.util.Datautil;
import com.twitter.util.MathUtil;
public class Aggregator extends BaseRichBolt {
@@ -34,6 +35,7 @@ public void execute(Tuple tuple) {
List<Double> weight = (List<Double>) tuple.getValue(0);
Double parallelUpdateWeight = (Double) tuple.getValue(1);
+ LOG.error("AGGGG" + weight);
if (parallelUpdateWeight != 1.0) {
weight = MathUtil.times(weight, parallelUpdateWeight);
}
@@ -43,10 +45,10 @@ public void execute(Tuple tuple) {
MathUtil.plus(aggregateWeights, weight);
}
totalUpdateWeight += parallelUpdateWeight;
- LOG.info("aggregate weights" + aggregateWeights);
MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
+ LOG.info("aggregate weights" + aggregateWeights);
if (aggregateWeights != null) {
- memcache.set("model", 3600 * 24, aggregateWeights);
+ memcache.set("model", 3600 * 24, Datautil.toStrVector(aggregateWeights));
}
}
7 storm-ml/src/main/java/com/twitter/algorithms/Learner.java
View
@@ -7,6 +7,7 @@
import net.spy.memcached.MemcachedClient;
+import org.apache.commons.lang.ArrayUtils;
import org.apache.log4j.Logger;
import com.twitter.data.Example;
@@ -31,9 +32,11 @@ public Learner(int dimension) {
}
public void update(Example example, int epoch, MemcachedClient memcache) {
- String cas_weights = (String) memcache.get("weights");
+ String cas_weights = (String) memcache.get("model");
List<Double> weights = Datautil.parse_str_vector(cas_weights);
- LOG.error("double weights" + weights);
+ Double[] weights_double = weights.toArray(new Double[weights.size()]);
+ this.setWeights(ArrayUtils.toPrimitive(weights_double));
+ LOG.error("double weights" + weights_double[0]);
int predicted = predict(example);
updateStats(example, predicted);
LOG.debug("EXAMPLE " + example.label + " PREDICTED: " + predicted);
3  storm-ml/src/main/java/com/twitter/storm/primitives/LocalLearner.java
View
@@ -62,7 +62,8 @@ public void execute(Tuple tuple) {
example.x[1] = (Double) tuple.getValue(1);
example.label = (Double) tuple.getValue(2);
learner.update(example, 1, memcache);
- LOG.debug("getwe" + learner.getWeights());
+ LOG.debug("local weights" + learner.getWeightsArray() + " parallel weights "
+ + learner.getParallelUpdateWeight());
_collector.emit(new Values(learner.getWeightsArray(), learner.getParallelUpdateWeight()));
_collector.ack(tuple);
}
128 storm-ml/src/main/java/com/twitter/storm/primitives/MLTopologyBuilder.java
View
@@ -4,64 +4,134 @@
import backtype.storm.drpc.DRPCSpout;
import backtype.storm.drpc.ReturnResults;
import backtype.storm.generated.StormTopology;
+import backtype.storm.topology.IBasicBolt;
+import backtype.storm.topology.IRichBolt;
import backtype.storm.topology.TopologyBuilder;
-import backtype.storm.topology.base.BaseRichBolt;
-import backtype.storm.topology.base.BaseRichSpout;
+
+import com.twitter.algorithms.Aggregator;
public class MLTopologyBuilder {
public static final String MEMCACHED_SERVERS = "127.0.0.1:11211";
- private BaseRichBolt trainingBolt;
- private BaseRichSpout trainingSpout;
- public TopologyBuilder prepareTopology(ILocalDRPC drpc) {
- return prepareTopology(drpc, 3.0, 0.0, 3.0, MEMCACHED_SERVERS);
+ String topology_prefix;
+
+ TrainingSpout training_spout;
+ Number training_spout_parallelism;
+
+ IBasicBolt basic_training_bolt;
+ IRichBolt rich_training_bolt;
+ Number training_bolt_parallelism;
+
+ IBasicBolt basic_evaluation_bolt;
+ IRichBolt rich_evaluation_bolt;
+ Number evaluation_bolt_parallelism;
+
+ public MLTopologyBuilder(String topologyPrefix) {
+ this.topology_prefix = topologyPrefix;
+ }
+
+ public TopologyBuilder prepareTopology(String drpcFunctionName, ILocalDRPC drpc) {
+ return prepareTopology(drpcFunctionName, drpc, 1.0, 0.0, 0.5, MEMCACHED_SERVERS);
+ }
+
+ public void setTrainingSpout(TrainingSpout trainingSpout, Number parallelism) {
+ this.training_spout = trainingSpout;
+ this.training_spout_parallelism = training_spout_parallelism;
+ }
+
+ public void setTrainingSpout(TrainingSpout trainingSpout) {
+ setTrainingSpout(trainingSpout, 1);
+ }
+
+ public void setTrainingBolt(IBasicBolt training_bolt, Number parallelism) {
+ this.basic_training_bolt = training_bolt;
+ this.rich_training_bolt = null;
+ this.training_bolt_parallelism = training_bolt_parallelism;
}
- public void setTrainingBolt(BaseRichBolt trainingBolt) {
- this.trainingBolt = trainingBolt;
+ public void setTrainingBolt(IBasicBolt training_bolt) {
+ setTrainingBolt(training_bolt, 1);
}
- public void setTrainingSpout(BaseRichSpout trainingSpout) {
- this.trainingSpout = trainingSpout;
+ public void setTrainingBolt(IRichBolt training_bolt, Number parallelism) {
+ this.rich_training_bolt = training_bolt;
+ this.basic_training_bolt = null;
+ this.training_bolt_parallelism = training_bolt_parallelism;
}
- public TopologyBuilder prepareTopology(ILocalDRPC drpc, double bias, double threshold, double learning_rate,
- String memcached_servers) {
+ public void setTrainingBolt(IRichBolt training_bolt) {
+ setTrainingBolt(training_bolt, 1);
+ }
+
+ public void setEvaluationBolt(IBasicBolt evaluation_bolt, Number parallelism) {
+ this.basic_evaluation_bolt = evaluation_bolt;
+ this.rich_evaluation_bolt = null;
+ this.evaluation_bolt_parallelism = evaluation_bolt_parallelism;
+ }
+
+ public void setEvaluationBolt(IBasicBolt evaluation_bolt) {
+ setEvaluationBolt(evaluation_bolt, 1);
+ }
+
+ public void setEvaluationBolt(IRichBolt evaluation_bolt, Number parallelism) {
+ this.rich_evaluation_bolt = evaluation_bolt;
+ this.basic_evaluation_bolt = null;
+ this.evaluation_bolt_parallelism = evaluation_bolt_parallelism;
+ }
+
+ public void setEvaluationBolt(IRichBolt evaluation_bolt) {
+ setEvaluationBolt(evaluation_bolt, 1);
+ }
+
+ public TopologyBuilder prepareTopology(String drpcFunctionName, ILocalDRPC drpc, double bias, double threshold,
+ double learning_rate, String memcached_servers) {
TopologyBuilder topology_builder = new TopologyBuilder();
// training
- topology_builder.setSpout("training-spout", new ExampleTrainingSpout());
-
- topology_builder.setBolt("training-bolt", new LocalLearner(bias, threshold, learning_rate, MEMCACHED_SERVERS))
- .shuffleGrouping("training-spout");
+ topology_builder.setSpout(this.topology_prefix + "-training-spout", this.training_spout,
+ this.training_spout_parallelism);
+
+ if (this.rich_training_bolt == null) {
+ topology_builder.setBolt(this.topology_prefix + "-training-bolt", this.basic_training_bolt,
+ this.training_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-training-spout");
+ } else {
+ topology_builder.setBolt(this.topology_prefix + "-training-bolt", this.rich_training_bolt,
+ this.training_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-training-spout");
+ }
+ topology_builder.setBolt("aggregator", new Aggregator(MEMCACHED_SERVERS)).globalGrouping(
+ this.topology_prefix + "-training-bolt");
// evaluation
DRPCSpout drpc_spout;
+
if (drpc != null)
- drpc_spout = new DRPCSpout("evaluate", drpc);
+ drpc_spout = new DRPCSpout(drpcFunctionName, drpc);
else
- drpc_spout = new DRPCSpout("evaluate");
-
- topology_builder.setSpout("drpc-spout", drpc_spout);
+ drpc_spout = new DRPCSpout(drpcFunctionName);
- topology_builder.setBolt(
- "drpc-evaluation",
- new EvaluationBolt(PerceptronDRPCTopology.bias, PerceptronDRPCTopology.threshold,
- PerceptronDRPCTopology.MEMCACHED_SERVERS)).shuffleGrouping("drpc-spout");
+ topology_builder.setSpout(this.topology_prefix + "-drpc-spout", drpc_spout);
- topology_builder.setBolt("drpc-return", new ReturnResults()).shuffleGrouping("drpc-evaluation");
+ if (this.rich_evaluation_bolt == null) {
+ topology_builder.setBolt(this.topology_prefix + "-drpc-evaluation", this.basic_evaluation_bolt,
+ this.evaluation_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-drpc-spout");
+ } else {
+ topology_builder.setBolt(this.topology_prefix + "-drpc-evaluation", this.rich_evaluation_bolt,
+ this.evaluation_bolt_parallelism).shuffleGrouping(this.topology_prefix + "-drpc-spout");
+ }
+ topology_builder.setBolt(this.topology_prefix + "-drpc-return", new ReturnResults()).shuffleGrouping(
+ this.topology_prefix + "-drpc-evaluation");
// return
return topology_builder;
}
- public StormTopology createLocalTopology(ILocalDRPC drpc) {
- return prepareTopology(drpc).createTopology();
+ public StormTopology createLocalTopology(String drpcFunctionName, ILocalDRPC drpc) {
+ return prepareTopology(drpcFunctionName, drpc).createTopology();
}
- public StormTopology createRemoteTopology() {
- return prepareTopology(null).createTopology();
+ public StormTopology createRemoteTopology(String drpcFunctionName) {
+ return prepareTopology(drpcFunctionName, null).createTopology();
}
}
9 storm-ml/src/main/java/com/twitter/util/Datautil.java
View
@@ -54,4 +54,13 @@ public static Double dot_product(List<Double> vector_a, List<Double> vector_b) {
}
return lines;
}
+
+ public static String toStrVector(List<Double> aggregateWeights) {
+ String acc = "[";
+ for (Double weight : aggregateWeights) {
+ acc += weight.toString() + ", ";
+ }
+ acc += "]";
+ return acc;
+ }
}

No commit comments for this range

Something went wrong with that request. Please try again.