Comparing changes

base fork: fbrubacher/storm-contrib (base: ccfb1b226b)
head fork: fbrubacher/storm-contrib (compare: 9a87175652)
  • 2 commits
  • 6 files changed
  • 0 commit comments
  • 1 contributor
storm-ml/src/main/java/com/twitter/MainOnlineTopology.java (7 lines changed)
@@ -18,13 +18,12 @@
static Double bias = 1.0;
public static void main(String[] args) throws Exception {
-
TopologyBuilder builder = new TopologyBuilder();
LocalDRPC drpc = new LocalDRPC();
builder.setSpout("example_spitter", new TrainingSpout());
builder.setBolt("local_learner", new LocalLearner(2, MEMCACHED_SERVERS), 1).shuffleGrouping("example_spitter");
- builder.setBolt("aggregator", new Aggregator()).globalGrouping("local_learner");
+ builder.setBolt("aggregator", new Aggregator(MEMCACHED_SERVERS)).globalGrouping("local_learner");
LinearDRPCTopologyBuilder drpc_builder = new LinearDRPCTopologyBuilder("evaluate");
drpc_builder.addBolt(new EvaluationBolt(bias, threshold, MEMCACHED_SERVERS), 3);
@@ -33,10 +32,10 @@ public static void main(String[] args) throws Exception {
conf.setDebug(true);
LocalCluster cluster = new LocalCluster();
cluster.submitTopology("learning", conf, builder.createTopology());
- cluster.submitTopology("evaluation", conf, drpc_builder.createLocalTopology(drpc));
+ // cluster.submitTopology("evaluation", conf, drpc_builder.createLocalTopology(drpc));
Utils.sleep(10000);
- cluster.killTopology("test");
+ cluster.killTopology("learning");
cluster.shutdown();
}
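
This change passes MEMCACHED_SERVERS into the Aggregator constructor, stops submitting the DRPC evaluation topology, and kills the topology under the name it was actually submitted with ("learning" instead of "test"). A minimal sketch of the resulting learning-only wiring, assuming MEMCACHED_SERVERS is an address such as "localhost:11211" and that TrainingSpout lives alongside the other spouts in com.twitter.storm.primitives (neither detail is shown in this diff):

import backtype.storm.Config;
import backtype.storm.LocalCluster;
import backtype.storm.topology.TopologyBuilder;
import backtype.storm.utils.Utils;

import com.twitter.algorithms.Aggregator;
import com.twitter.storm.primitives.LocalLearner;
import com.twitter.storm.primitives.TrainingSpout; // assumed package

public class LearningTopologySketch {
    // Assumed value; the real constant is defined in MainOnlineTopology.
    static final String MEMCACHED_SERVERS = "localhost:11211";

    public static void main(String[] args) throws Exception {
        TopologyBuilder builder = new TopologyBuilder();
        builder.setSpout("example_spitter", new TrainingSpout());
        builder.setBolt("local_learner", new LocalLearner(2, MEMCACHED_SERVERS), 1)
               .shuffleGrouping("example_spitter");
        // The aggregator now receives the memcached address so it can publish the model itself.
        builder.setBolt("aggregator", new Aggregator(MEMCACHED_SERVERS))
               .globalGrouping("local_learner");

        Config conf = new Config();
        conf.setDebug(true);

        LocalCluster cluster = new LocalCluster();
        cluster.submitTopology("learning", conf, builder.createTopology());
        Utils.sleep(10000);
        cluster.killTopology("learning"); // must match the name used at submission
        cluster.shutdown();
    }
}
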
storm-ml/src/main/java/com/twitter/algorithms/Aggregator.java (31 lines changed)
@@ -1,8 +1,12 @@
package com.twitter.algorithms;
+import java.io.IOException;
import java.util.List;
import java.util.Map;
+import net.spy.memcached.AddrUtil;
+import net.spy.memcached.MemcachedClient;
+
import org.apache.log4j.Logger;
import backtype.storm.task.OutputCollector;
@@ -18,6 +22,12 @@
public static Logger LOG = Logger.getLogger(Aggregator.class);
List<Double> aggregateWeights = null;
double totalUpdateWeight = 1.0;
+ MemcachedClient memcache;
+ String memcached_servers;
+
+ public Aggregator(String memcached_servers) {
+ this.memcached_servers = memcached_servers;
+ }
public void execute(Tuple tuple) {
@@ -33,20 +43,25 @@ public void execute(Tuple tuple) {
MathUtil.plus(aggregateWeights, weight);
}
totalUpdateWeight += parallelUpdateWeight;
- LOG.info("totalUpdate");
- LOG.info(totalUpdateWeight);
+ LOG.info("aggregate weights" + aggregateWeights);
+ MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
if (aggregateWeights != null) {
- MathUtil.times(aggregateWeights, 1.0 / totalUpdateWeight);
- // LOG.info("New AGGREGATE vector: " + Arrays.toString(aggregateWeights));
+ memcache.set("model", 3600 * 24, aggregateWeights);
}
- }
-
- public void declareOutputFields(OutputFieldsDeclarer declarer) {
}
public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) {
-
+ try {
+ memcache = new MemcachedClient(AddrUtil.getAddresses(memcached_servers));
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
}
+ public void declareOutputFields(OutputFieldsDeclarer declarer) {
+ // TODO Auto-generated method stub
+
+ }
}
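
The Aggregator now carries the memcached address, opens a MemcachedClient in prepare(), averages the accumulated weights by totalUpdateWeight, and publishes the result under the key "model" with a 24-hour expiry. A small standalone sketch of that write path using the spymemcached calls seen above (the address and sample weights are placeholders):

import java.util.Arrays;
import java.util.List;

import net.spy.memcached.AddrUtil;
import net.spy.memcached.MemcachedClient;

public class ModelStoreSketch {
    public static void main(String[] args) throws Exception {
        // Placeholder address; the bolt receives the real one through its constructor.
        MemcachedClient memcache =
                new MemcachedClient(AddrUtil.getAddresses("localhost:11211"));

        // Mirrors memcache.set("model", 3600 * 24, aggregateWeights) in the bolt:
        // the averaged weight vector is stored under a fixed key with a 24h expiry.
        List<Double> aggregateWeights = Arrays.asList(0.5, -1.2);
        memcache.set("model", 3600 * 24, aggregateWeights).get(); // wait for the async set

        // The default transcoder uses Java serialization, so the same object comes back.
        @SuppressWarnings("unchecked")
        List<Double> model = (List<Double>) memcache.get("model");
        System.out.println(model);

        memcache.shutdown();
    }
}

Keeping only the server string on the bolt and building the client in prepare() avoids trying to serialize a live connection with the topology; the same pattern is applied to LocalLearner below.
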
storm-ml/src/main/java/com/twitter/algorithms/Learner.java (12 lines changed)
@@ -5,7 +5,6 @@
import java.util.Arrays;
import java.util.List;
-import net.spy.memcached.CASValue;
import net.spy.memcached.MemcachedClient;
import org.apache.log4j.Logger;
@@ -25,17 +24,16 @@
double totalLoss = 0.0;
double gradientSum = 0.0;
protected double learningRate = 0.0;
- MemcachedClient memcache;
- public Learner(int dimension, MemcachedClient memcache) {
+ public Learner(int dimension) {
weights = new double[dimension];
lossFunction = new LossFunction(2);
- this.memcache = memcache;
}
- public void update(Example example, int epoch) {
- CASValue<Object> cas_weights = (CASValue<Object>) this.memcache.get("weights");
- List<Double> weights = Datautil.parse_str_vector((String) cas_weights.getValue());
+ public void update(Example example, int epoch, MemcachedClient memcache) {
+ String cas_weights = (String) memcache.get("weights");
+ List<Double> weights = Datautil.parse_str_vector(cas_weights);
+ LOG.error("double weights" + weights);
int predicted = predict(example);
updateStats(example, predicted);
LOG.debug("EXAMPLE " + example.label + " PREDICTED: " + predicted);
storm-ml/src/main/java/com/twitter/storm/primitives/BaseTrainingSpout.java (new file, 21 lines)
@@ -0,0 +1,21 @@
+package com.twitter.storm.primitives;
+
+import java.util.Map;
+
+import backtype.storm.spout.SpoutOutputCollector;
+import backtype.storm.task.TopologyContext;
+import backtype.storm.topology.OutputFieldsDeclarer;
+import backtype.storm.topology.base.BaseRichSpout;
+import backtype.storm.tuple.Fields;
+
+public abstract class BaseTrainingSpout extends BaseRichSpout {
+ SpoutOutputCollector _collector;
+
+ public void open(Map conf, TopologyContext context, SpoutOutputCollector collector) {
+ this._collector = collector;
+ }
+
+ public void declareOutputFields(OutputFieldsDeclarer declarer) {
+ declarer.declare(new Fields("example", "label"));
+ }
+}
storm-ml/src/main/java/com/twitter/storm/primitives/ExampleTrainingSpout.java (new file, 33 lines)
@@ -0,0 +1,33 @@
+package com.twitter.storm.primitives;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import backtype.storm.tuple.Values;
+
+public class ExampleTrainingSpout extends BaseTrainingSpout {
+ int samples_count = 0;
+ int max_samples = 100;
+
+ public static int get_label(Double x, Double y) {
+ // arbitrary expected output (for testing purposes)
+ return (2 * x + 1 < y) ? 1 : 0;
+ }
+
+ public void nextTuple() {
+ if (this.samples_count < this.max_samples) {
+ Double x = 100 * Math.random();
+ Double y = 100 * Math.random();
+
+ List<Double> example = new ArrayList<Double>();
+ example.add(x);
+ example.add(y);
+
+ int label = ExampleTrainingSpout.get_label(x, y);
+
+ _collector.emit(new Values(example.toString(), label));
+
+ this.samples_count++;
+ }
+ }
+}
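
The example spout draws uniform points from [0, 100) x [0, 100) and labels them against the fixed line y = 2x + 1, giving the learner a linearly separable target. A quick standalone check of the labelling rule (the method mirrors get_label above; the two sample points are illustrative):

public class LabelRuleCheck {
    // Same rule as ExampleTrainingSpout.get_label: 1 above the line y = 2x + 1, else 0.
    static int getLabel(double x, double y) {
        return (2 * x + 1 < y) ? 1 : 0;
    }

    public static void main(String[] args) {
        System.out.println(getLabel(10.0, 50.0)); // 1: (10, 50) lies above y = 21
        System.out.println(getLabel(10.0, 5.0));  // 0: (10, 5) lies below y = 21
    }
}
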
storm-ml/src/main/java/com/twitter/storm/primitives/LocalLearner.java (23 lines changed)
@@ -2,7 +2,6 @@
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -18,6 +17,7 @@
import backtype.storm.transactional.ICommitter;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Tuple;
+import backtype.storm.tuple.Values;
import com.twitter.algorithms.Learner;
import com.twitter.data.Example;
@@ -34,19 +34,19 @@
HashAll hashFunction;
Learner learner;
double[] weightVector;
+ String memcached_servers;
MemcachedClient memcache;
public LocalLearner(int dimension, String memcached_servers) throws IOException {
- this(dimension, new Learner(dimension, new MemcachedClient(AddrUtil.getAddresses(memcached_servers))));
+ this(dimension, new Learner(dimension), memcached_servers);
}
- public LocalLearner(int dimension, Learner onlinePerceptron) {// , HashAll hashAll) {
+ public LocalLearner(int dimension, Learner onlinePerceptron, String memcached_servers) {// , HashAll hashAll) {
try {
this.dimension = dimension;
this.learner = onlinePerceptron;
// this.hashFunction = hashAll;
-
- weightVector = new double[dimension];
+ this.memcached_servers = memcached_servers;
weightVector = new double[dimension];
weightVector[0] = -6.8;
weightVector[1] = -0.8;
@@ -61,8 +61,9 @@ public void execute(Tuple tuple) {
example.x[0] = (Double) tuple.getValue(0);
example.x[1] = (Double) tuple.getValue(1);
example.label = (Double) tuple.getValue(2);
- learner.update(example, 1);
- _collector.emit(Arrays.asList((Object) learner.getWeights(), (Object) learner.getParallelUpdateWeight()));
+ learner.update(example, 1, memcache);
+ LOG.debug("getwe" + learner.getWeights());
+ _collector.emit(new Values(learner.getWeightsArray(), learner.getParallelUpdateWeight()));
_collector.ack(tuple);
}
@@ -74,7 +75,13 @@ public void prepare(Map stormConf, TopologyContext context, OutputCollector coll
this.collector = collector;
learner.initWeights(weightVector);
_collector = collector;
+ memcache = (MemcachedClient) context.getTaskData();
weightVector = (double[]) context.getTaskData();
- context.setTaskData(weightVector);
+ try {
+ memcache = new MemcachedClient(AddrUtil.getAddresses(memcached_servers));
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
}
}
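
LocalLearner follows the same pattern as the Aggregator: it keeps only the memcached_servers string as a field, builds its MemcachedClient in prepare(), and passes that client into learner.update(...). The likely reason is that MemcachedClient wraps a live connection and cannot be serialized with the bolt when the topology is shipped to workers, while a plain address string can. A minimal sketch of that pattern on its own (class name and fields are illustrative):

import java.io.IOException;
import java.util.Map;

import net.spy.memcached.AddrUtil;
import net.spy.memcached.MemcachedClient;

import backtype.storm.task.OutputCollector;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.topology.base.BaseRichBolt;
import backtype.storm.tuple.Tuple;

public class MemcachedBackedBoltSketch extends BaseRichBolt {
    private final String memcachedServers;      // serialized with the bolt
    private transient MemcachedClient memcache; // created per task in prepare()
    private OutputCollector collector;

    public MemcachedBackedBoltSketch(String memcachedServers) {
        this.memcachedServers = memcachedServers;
    }

    public void prepare(Map stormConf, TopologyContext context, OutputCollector collector) {
        this.collector = collector;
        try {
            memcache = new MemcachedClient(AddrUtil.getAddresses(memcachedServers));
        } catch (IOException e) {
            throw new RuntimeException("could not connect to memcached at " + memcachedServers, e);
        }
    }

    public void execute(Tuple tuple) {
        // read or write shared state through memcache here, then ack
        collector.ack(tuple);
    }

    public void declareOutputFields(OutputFieldsDeclarer declarer) {
        // this sketch emits nothing
    }
}
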

No commit comments for this range
